sojuboy/internal/store/store.go

172 lines
4.3 KiB
Go
Raw Normal View History

package store
import (
"context"
"database/sql"
"errors"
"time"
_ "modernc.org/sqlite"
)
type (
	// Store is a handle on the SQLite database holding the messages table.
	// Obtain one from Open and release it with Close.
	Store struct {
		db *sql.DB
	}

	// Message mirrors one row of the messages table.
	Message struct {
		ID      int64     // id column (INTEGER PRIMARY KEY AUTOINCREMENT)
		Channel string    // channel column; lookups compare it case-insensitively
		Author  string    // author column
		Body    string    // body column
		Time    time.Time // at column; converted to UTC before insertion
		MsgID   string    // msgid column; "" is stored as NULL
	}
)
// Open opens the SQLite database at path and prepares it for use. It enables
// WAL journaling and foreign keys, creates the schema if needed, and migrates
// databases created before the msgid column existed. On any failure the
// database handle is closed and the error is returned.
//
// The pool is capped at a single connection. PRAGMA foreign_keys is a
// per-connection setting, so it is applied on that connection here.
func Open(ctx context.Context, path string) (*Store, error) {
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, err
	}
	db.SetMaxOpenConns(1)
	if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON;`); err != nil {
		_ = db.Close()
		return nil, err
	}
	if err := initSchema(ctx, db); err != nil {
		_ = db.Close()
		return nil, err
	}
	if err := migrateMsgID(ctx, db); err != nil {
		_ = db.Close()
		return nil, err
	}
	return &Store{db: db}, nil
}

// migrateMsgID adds the msgid column when it is missing and ensures the
// partial unique index on it exists.
//
// The column is probed rather than blindly ALTERed so that the only ALTER
// failures are real ones, and those are reported instead of discarded.
// Index-creation errors are reported too: InsertMessage uses INSERT OR IGNORE,
// which only de-duplicates by msgid if this index is present.
func migrateMsgID(ctx context.Context, db *sql.DB) error {
	var present int
	if err := db.QueryRowContext(ctx,
		`SELECT COUNT(*) FROM pragma_table_info('messages') WHERE name = 'msgid'`,
	).Scan(&present); err != nil {
		return err
	}
	if present == 0 {
		if _, err := db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN msgid TEXT`); err != nil {
			return err
		}
	}
	_, err := db.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS idx_messages_msgid ON messages(msgid) WHERE msgid IS NOT NULL`)
	return err
}
// Close releases the Store's database handle, returning whatever error
// closing the underlying *sql.DB reports.
func (s *Store) Close() error {
	db := s.db
	return db.Close()
}
// initSchema creates the messages table and its indexes when they do not
// already exist, so it is safe to run on every start-up.
//
// Every channel lookup in this package filters on lower(channel) = lower(?).
// The plain (channel, at) index cannot serve that predicate, so those queries
// would scan and sort the whole table. The expression index on
// (lower(channel), at) matches the predicate exactly. It lets SQLite seek to
// a channel and read its rows in time order. The original (channel, at)
// index is kept for compatibility with existing databases.
func initSchema(ctx context.Context, db *sql.DB) error {
	const schema = `
CREATE TABLE IF NOT EXISTS messages (
	id INTEGER PRIMARY KEY AUTOINCREMENT,
	channel TEXT NOT NULL,
	author TEXT NOT NULL,
	body TEXT NOT NULL,
	at TIMESTAMP NOT NULL,
	msgid TEXT
);
CREATE INDEX IF NOT EXISTS idx_messages_channel_at ON messages(channel, at);
CREATE INDEX IF NOT EXISTS idx_messages_lower_channel_at ON messages(lower(channel), at);
`
	_, err := db.ExecContext(ctx, schema)
	return err
}
// InsertMessage stores m as a new row, with its timestamp converted to UTC
// and an empty MsgID written as NULL. The statement is INSERT OR IGNORE, so a
// row that hits a constraint conflict is skipped without error. One example is
// a MsgID already present, given the unique index on msgid.
func (s *Store) InsertMessage(ctx context.Context, m Message) error {
	const insertSQL = "INSERT OR IGNORE INTO messages(channel, author, body, at, msgid) VALUES(?,?,?,?,?)"
	params := []any{m.Channel, m.Author, m.Body, m.Time.UTC(), nullIfEmpty(m.MsgID)}
	if _, execErr := s.db.ExecContext(ctx, insertSQL, params...); execErr != nil {
		return execErr
	}
	return nil
}
// ListChannels returns every distinct channel name stored in the messages
// table, sorted case-insensitively.
//
// De-duplication uses the column's default (case-sensitive) comparison, so
// names differing only in letter case are listed separately. If iteration
// fails part-way, the channels read so far are returned alongside the error.
func (s *Store) ListChannels(ctx context.Context) ([]string, error) {
	const query = "SELECT DISTINCT channel FROM messages ORDER BY lower(channel)"
	rows, qErr := s.db.QueryContext(ctx, query)
	if qErr != nil {
		return nil, qErr
	}
	defer rows.Close()

	var channels []string
	for rows.Next() {
		var name string
		if scanErr := rows.Scan(&name); scanErr != nil {
			return nil, scanErr
		}
		channels = append(channels, name)
	}
	return channels, rows.Err()
}
// nullIfEmpty converts s into a SQL bind argument. The empty string becomes
// an untyped nil, which binds as NULL. Any other string, including
// whitespace-only strings, is passed through unchanged.
func nullIfEmpty(s string) any {
	if len(s) > 0 {
		return s
	}
	return nil
}
// ListMessagesSince returns every stored message for channel whose timestamp
// is at or after since (compared in UTC), oldest first.
//
// The channel is matched case-insensitively via SQLite lower() on both sides.
// Built-in lower() folds ASCII letters only. Rows sharing a timestamp are
// ordered by id so the result order is deterministic.
func (s *Store) ListMessagesSince(ctx context.Context, channel string, since time.Time) ([]Message, error) {
	rows, err := s.db.QueryContext(ctx,
		"SELECT id, channel, author, body, at, msgid FROM messages WHERE lower(channel) = lower(?) AND at >= ? ORDER BY at ASC, id ASC",
		channel, since.UTC())
	if err != nil {
		return nil, err
	}
	return scanMessageRows(rows)
}

// ListRecentMessages returns the most recent limit messages for channel,
// newest first. A non-positive limit falls back to 50.
//
// The channel is matched case-insensitively via SQLite lower(), as in
// ListMessagesSince. Rows sharing a timestamp are ordered by id (newest id
// first), so the cut-off at limit is deterministic.
func (s *Store) ListRecentMessages(ctx context.Context, channel string, limit int) ([]Message, error) {
	if limit <= 0 {
		limit = 50
	}
	rows, err := s.db.QueryContext(ctx,
		"SELECT id, channel, author, body, at, msgid FROM messages WHERE lower(channel) = lower(?) ORDER BY at DESC, id DESC LIMIT ?",
		channel, limit,
	)
	if err != nil {
		return nil, err
	}
	return scanMessageRows(rows)
}

// scanMessageRows reads rows produced by
// "SELECT id, channel, author, body, at, msgid" into Messages and always
// closes rows. A NULL msgid becomes an empty MsgID. On any scan or iteration
// error it returns nil, so callers never receive a partially filled slice.
func scanMessageRows(rows *sql.Rows) ([]Message, error) {
	defer rows.Close()
	var out []Message
	for rows.Next() {
		var m Message
		var msgid sql.NullString
		if err := rows.Scan(&m.ID, &m.Channel, &m.Author, &m.Body, &m.Time, &msgid); err != nil {
			return nil, err
		}
		m.MsgID = msgid.String // "" when msgid is NULL
		out = append(out, m)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return out, nil
}
// LastMessageTime returns the timestamp of the newest stored message for
// channel, matched case-insensitively via SQLite lower(). The bool is false,
// with a nil error, when the channel has no messages.
//
// The newest row's at column is selected directly instead of MAX(at). SQLite
// reports no declared type for an aggregate result column, so the driver's
// TIMESTAMP-to-time.Time conversion does not apply to MAX(at). The scans of
// at elsewhere in this file rely on that conversion. The aggregate instead
// comes back as text, which cannot be scanned into a time value, so the old
// query failed whenever the channel had messages. Selecting the column itself
// keeps its declared type.
func (s *Store) LastMessageTime(ctx context.Context, channel string) (time.Time, bool, error) {
	var at time.Time
	err := s.db.QueryRowContext(ctx,
		"SELECT at FROM messages WHERE lower(channel) = lower(?) ORDER BY at DESC LIMIT 1",
		channel).Scan(&at)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return time.Time{}, false, nil
	case err != nil:
		return time.Time{}, false, err
	}
	return at, true, nil
}
// DeleteOlderThan removes messages in every channel whose timestamp is
// strictly before cutoff (compared in UTC). It reports the number of rows
// the driver says were deleted.
func (s *Store) DeleteOlderThan(ctx context.Context, cutoff time.Time) (int64, error) {
	const purge = "DELETE FROM messages WHERE at < ?"
	result, execErr := s.db.ExecContext(ctx, purge, cutoff.UTC())
	if execErr != nil {
		return 0, execErr
	}
	return result.RowsAffected()
}
// ErrNotFound is a sentinel error signalling that a requested item does not
// exist. Compare against it with errors.Is so wrapped errors still match.
var (
	ErrNotFound = errors.New("not found")
)