171 lines
4.3 KiB
Go
171 lines
4.3 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"time"
|
|
|
|
_ "modernc.org/sqlite"
|
|
)
|
|
|
|
type (
	// Store wraps the SQLite handle holding the messages table. Obtain one
	// from Open and release it with Close.
	Store struct {
		db *sql.DB // handle configured by Open
	}

	// Message is one row of the messages table.
	Message struct {
		ID      int64     // row id assigned by the database (messages.id); ignored on insert
		Channel string    // channel name; lookups compare it case-insensitively via lower()
		Author  string    // author name (messages.author)
		Body    string    // message text (messages.body)
		Time    time.Time // message timestamp (messages.at); InsertMessage stores it in UTC
		MsgID   string    // optional id used for de-duplication; "" is stored as NULL
	}
)
|
|
|
|
func Open(ctx context.Context, path string) (*Store, error) {
|
|
db, err := sql.Open("sqlite", path)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
db.SetMaxOpenConns(1)
|
|
if _, err := db.ExecContext(ctx, `PRAGMA journal_mode = WAL; PRAGMA foreign_keys = ON;`); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
if err := initSchema(ctx, db); err != nil {
|
|
_ = db.Close()
|
|
return nil, err
|
|
}
|
|
// Best-effort migration: add msgid column and unique index if missing
|
|
_, _ = db.ExecContext(ctx, `ALTER TABLE messages ADD COLUMN msgid TEXT`)
|
|
_, _ = db.ExecContext(ctx, `CREATE UNIQUE INDEX IF NOT EXISTS idx_messages_msgid ON messages(msgid) WHERE msgid IS NOT NULL`)
|
|
return &Store{db: db}, nil
|
|
}
|
|
|
|
// Close closes the database handle opened by Open and returns whatever error
// database/sql reports while doing so.
func (s *Store) Close() error {
	if err := s.db.Close(); err != nil {
		return err
	}
	return nil
}
|
|
|
|
// initSchema creates the messages table and its (channel, at) index if they
// do not exist yet. Both statements are sent in a single ExecContext call.
//
// NOTE(review): the TIMESTAMP declared type on `at` is presumably what lets the
// driver hand the column back as time.Time when it is read directly; keep it.
func initSchema(ctx context.Context, db *sql.DB) error {
	const ddl = `
CREATE TABLE IF NOT EXISTS messages (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    channel TEXT NOT NULL,
    author TEXT NOT NULL,
    body TEXT NOT NULL,
    at TIMESTAMP NOT NULL,
    msgid TEXT
);
CREATE INDEX IF NOT EXISTS idx_messages_channel_at ON messages(channel, at);
`
	if _, err := db.ExecContext(ctx, ddl); err != nil {
		return err
	}
	return nil
}
|
|
|
|
func (s *Store) InsertMessage(ctx context.Context, m Message) error {
|
|
_, err := s.db.ExecContext(ctx,
|
|
"INSERT OR IGNORE INTO messages(channel, author, body, at, msgid) VALUES(?,?,?,?,?)",
|
|
m.Channel, m.Author, m.Body, m.Time.UTC(), nullIfEmpty(m.MsgID))
|
|
return err
|
|
}
|
|
|
|
// ListChannels returns distinct channel identifiers seen in the database.
|
|
func (s *Store) ListChannels(ctx context.Context) ([]string, error) {
|
|
rows, err := s.db.QueryContext(ctx, "SELECT DISTINCT channel FROM messages ORDER BY lower(channel)")
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []string
|
|
for rows.Next() {
|
|
var ch string
|
|
if err := rows.Scan(&ch); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, ch)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// nullIfEmpty turns "" into a nil value, which database/sql binds as SQL NULL.
// Any other string is passed through unchanged.
func nullIfEmpty(s string) any {
	var v any
	if s != "" {
		v = s
	}
	return v
}
|
|
|
|
func (s *Store) ListMessagesSince(ctx context.Context, channel string, since time.Time) ([]Message, error) {
|
|
rows, err := s.db.QueryContext(ctx,
|
|
"SELECT id, channel, author, body, at, msgid FROM messages WHERE lower(channel) = lower(?) AND at >= ? ORDER BY at ASC",
|
|
channel, since.UTC())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []Message
|
|
for rows.Next() {
|
|
var m Message
|
|
var at time.Time
|
|
var msgid sql.NullString
|
|
if err := rows.Scan(&m.ID, &m.Channel, &m.Author, &m.Body, &at, &msgid); err != nil {
|
|
return nil, err
|
|
}
|
|
m.Time = at
|
|
if msgid.Valid {
|
|
m.MsgID = msgid.String
|
|
}
|
|
out = append(out, m)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// ListRecentMessages returns the most recent N messages for a channel.
|
|
func (s *Store) ListRecentMessages(ctx context.Context, channel string, limit int) ([]Message, error) {
|
|
if limit <= 0 {
|
|
limit = 50
|
|
}
|
|
rows, err := s.db.QueryContext(ctx,
|
|
"SELECT id, channel, author, body, at, msgid FROM messages WHERE lower(channel) = lower(?) ORDER BY at DESC LIMIT ?",
|
|
channel, limit,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []Message
|
|
for rows.Next() {
|
|
var m Message
|
|
var at time.Time
|
|
var msgid sql.NullString
|
|
if err := rows.Scan(&m.ID, &m.Channel, &m.Author, &m.Body, &at, &msgid); err != nil {
|
|
return nil, err
|
|
}
|
|
m.Time = at
|
|
if msgid.Valid {
|
|
m.MsgID = msgid.String
|
|
}
|
|
out = append(out, m)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
// LastMessageTime returns the last stored timestamp for a channel.
|
|
func (s *Store) LastMessageTime(ctx context.Context, channel string) (time.Time, bool, error) {
|
|
var nt sql.NullTime
|
|
err := s.db.QueryRowContext(ctx, "SELECT MAX(at) FROM messages WHERE lower(channel) = lower(?)", channel).Scan(&nt)
|
|
if err != nil {
|
|
return time.Time{}, false, err
|
|
}
|
|
if !nt.Valid {
|
|
return time.Time{}, false, nil
|
|
}
|
|
return nt.Time, true, nil
|
|
}
|
|
|
|
func (s *Store) DeleteOlderThan(ctx context.Context, cutoff time.Time) (int64, error) {
|
|
res, err := s.db.ExecContext(ctx, "DELETE FROM messages WHERE at < ?", cutoff.UTC())
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return res.RowsAffected()
|
|
}
|
|
|
|
// ErrNotFound is a sentinel error for a missing record; compare against it
// with errors.Is.
// NOTE(review): no function in this file returns it — confirm it is still used.
var (
	ErrNotFound = errors.New("not found")
)
|