sojuboy/internal/soju/rawclient.go

310 lines
6.7 KiB
Go
Raw Normal View History

package soju
import (
"bufio"
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
"sojuboy/internal/store"
irc "github.com/sorcix/irc"
)
// RawClient is a minimal IRC client for a soju bouncer that speaks the wire
// protocol directly (PASS/NICK/USER registration, IRCv3 tags and
// CHATHISTORY) and hands every received PRIVMSG to OnPrivmsg.
type RawClient struct {
// Server is the host to dial; it is also used as the TLS ServerName and as
// the third USER parameter.
Server string
// Port is the TCP port to dial.
Port int
// UseTLS selects a TLS (minimum TLS 1.2) connection instead of plain TCP.
UseTLS bool
// Nick is the nickname sent in NICK (and the USER fallback when Username is empty).
Nick string
Username string // full identity: username/network@client; sent as the USER username
// Realname is sent as the USER realname (trailing parameter).
Realname string
Password string // sent as PASS <password> when non-empty; masked in debug logs
// Channels are JOINed after the server's 001 welcome.
Channels []string
// Number of messages to fetch via CHATHISTORY LATEST per channel after join.
// Zero or negative disables the backfill request.
BackfillLatest int
// OnPrivmsg, if non-nil, receives every PRIVMSG: the message target (a channel,
// or presumably our own nick for direct messages — confirm), the sender's nick,
// the text, the soju-msgid tag (falling back to msgid), and the server-time tag
// (falling back to the local receive time).
OnPrivmsg func(channel, author, text, msgid string, at time.Time)
// Logger is optional; nil disables all logging.
Logger *slog.Logger
// Debug logs every raw line sent and received (requires Logger).
Debug bool
// Store is used to compute last-seen timestamp for CHATHISTORY.
// When nil (or it has no entry), the backfill starts 24 hours back.
Store *store.Store
// Readiness/metrics hooks, both optional and written atomically by setConnected.
ConnectedGauge *int64 // 0/1
IsReady *int32 // 0/1 atomic flag
}
// setConnected publishes the session state to the optional readiness hooks.
// Both ConnectedGauge and IsReady are atomically set to 1 when v is true and
// to 0 otherwise; nil hooks are skipped.
func (c *RawClient) setConnected(v bool) {
	state := int32(0)
	if v {
		state = 1
	}
	if gauge := c.ConnectedGauge; gauge != nil {
		atomic.StoreInt64(gauge, int64(state))
	}
	if ready := c.IsReady; ready != nil {
		atomic.StoreInt32(ready, state)
	}
}
// Run keeps a bouncer session alive until ctx is cancelled.
//
// Each failed session is logged and retried after an exponential backoff
// (starting at 1s, doubling, capped at 30s). The backoff is reset after a
// session that stayed up for at least a minute, so an occasional disconnect of
// a healthy connection reconnects quickly. The wait between attempts is
// interrupted by cancellation.
//
// Run returns ctx.Err() once ctx is done, or nil if a session ends without
// an error.
func (c *RawClient) Run(ctx context.Context) error {
	const (
		minBackoff   = time.Second
		maxBackoff   = 30 * time.Second
		stableUptime = time.Minute
	)
	backoff := minBackoff
	for {
		started := time.Now()
		err := c.runOnce(ctx)
		if err == nil {
			return nil
		}
		if ctx.Err() != nil {
			return ctx.Err()
		}
		if c.Logger != nil {
			c.Logger.Error("raw soju client stopped", "err", err)
		}
		// A long-lived session was healthy; don't carry over the delay
		// accumulated by earlier failures.
		if time.Since(started) >= stableUptime {
			backoff = minBackoff
		}
		// Sleep, but wake immediately on shutdown instead of blocking it.
		timer := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
		backoff = min(backoff*2, maxBackoff)
	}
}
// runOnce runs a single bouncer session. It dials the server, requests IRCv3
// capabilities and registers with PASS/NICK/USER. After the welcome (001) it
// JOINs c.Channels, optionally requests a CHATHISTORY backfill for each channel
// we join, and delivers every PRIVMSG to c.OnPrivmsg.
//
// It only returns on failure: a dial or I/O error, an ERROR line from the
// server, or cancellation of ctx (reported as ctx.Err()). The connected/ready
// hooks are cleared whenever it returns.
func (c *RawClient) runOnce(ctx context.Context) error {
	conn, err := c.dial(ctx)
	if err != nil {
		return err
	}
	defer conn.Close()
	// ReadString below does not watch ctx; closing the connection on
	// cancellation is what unblocks it.
	stopCloseOnCancel := context.AfterFunc(ctx, func() { _ = conn.Close() })
	defer stopCloseOnCancel()
	// Whatever ends the session (read error, ERROR line, cancellation), stop
	// reporting the client as connected/ready.
	defer c.setConnected(false)

	rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
	// write sends one IRC line and flushes it. With Debug enabled the line is
	// logged, except that the PASS argument is masked.
	write := func(line string) error {
		logged := line
		if strings.HasPrefix(strings.ToUpper(line), "PASS ") {
			logged = "PASS ********"
		}
		if c.Debug && c.Logger != nil {
			c.Logger.Debug("irc>", "line", logged)
		}
		if _, err := rw.WriteString(line + "\r\n"); err != nil {
			return err
		}
		return rw.Flush()
	}

	user := c.Username
	if user == "" {
		user = c.Nick
	}
	// The capability request is pipelined with CAP END and registration; the
	// server processes lines in order, so the ACK is handled in the read loop.
	handshake := []string{
		"CAP LS 302",
		"CAP REQ :server-time batch message-tags draft/chathistory draft/event-playback echo-message cap-notify",
		"CAP END",
	}
	if c.Password != "" {
		handshake = append(handshake, "PASS "+c.Password)
	}
	handshake = append(handshake,
		"NICK "+c.Nick,
		fmt.Sprintf("USER %s %s %s :%s", user, user, c.Server, c.Realname),
	)
	for _, line := range handshake {
		if err := write(line); err != nil {
			return err
		}
	}

	// The server may register us under a different nick than requested; track
	// the one from 001 (and later NICK changes) so our own JOINs are recognised.
	currentNick := c.Nick
	eventPlayback := false
	selfJoined := map[string]bool{}
	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		rawLine, err := rw.ReadString('\n')
		if err != nil {
			// A read failing because AfterFunc closed the connection is a
			// cancellation, not a network error.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
			return err
		}
		rawLine = strings.TrimRight(rawLine, "\r\n")
		if rawLine == "" {
			continue
		}
		if c.Debug && c.Logger != nil {
			c.Logger.Debug("irc<", "line", rawLine)
		}

		// Split off IRCv3 message tags ("@k=v;k2 ...") before handing the
		// rest of the line to irc.ParseMessage.
		var tags map[string]string
		line := rawLine
		if strings.HasPrefix(line, "@") {
			if sp := strings.IndexByte(line, ' '); sp > 0 {
				tags = parseTags(line[1:sp])
				line = strings.TrimSpace(line[sp+1:])
			}
		}
		msg := irc.ParseMessage(line)
		if msg == nil {
			continue
		}
		params := rawMessageParams(msg)

		switch strings.ToUpper(msg.Command) {
		case "CAP":
			// e.g. ":bnc CAP * ACK :server-time batch ..." gives
			// params ["*", "ACK", "server-time batch ..."].
			if len(params) < 3 {
				break
			}
			caps := params[len(params)-1]
			switch strings.ToUpper(params[1]) {
			case "ACK":
				if capListHas(caps, "draft/event-playback") {
					eventPlayback = true
					if c.Logger != nil {
						c.Logger.Info("cap enabled", "cap", "draft/event-playback")
					}
				}
			case "NEW":
				if !eventPlayback && capListHas(caps, "draft/event-playback") {
					if err := write("CAP REQ :draft/event-playback"); err != nil {
						return err
					}
				}
			case "DEL":
				if capListHas(caps, "draft/event-playback") {
					eventPlayback = false
				}
			}
		case "PING":
			if len(params) > 0 {
				if err := write("PONG :" + params[len(params)-1]); err != nil {
					return err
				}
			}
		case "001": // welcome: registration is complete
			if len(params) > 0 && params[0] != "" {
				currentNick = params[0]
			}
			c.setConnected(true)
			if c.Logger != nil {
				c.Logger.Info("connected", "server", c.Server, "auth", "raw")
			}
			for _, ch := range c.Channels {
				if err := write("JOIN " + ch); err != nil {
					return err
				}
				if c.Logger != nil {
					c.Logger.Info("join requested", "channel", ch)
				}
			}
		case "NICK":
			if len(params) > 0 && strings.EqualFold(nickFromPrefix(msg.Prefix), currentNick) {
				currentNick = params[0]
			}
		case "JOIN":
			if len(params) == 0 {
				break
			}
			ch := params[0]
			nick := nickFromPrefix(msg.Prefix)
			if c.Logger != nil {
				c.Logger.Info("joined", "channel", ch, "nick", nick)
			}
			// Only our own first JOIN per channel triggers a backfill.
			if !strings.EqualFold(nick, currentNick) || selfJoined[ch] {
				break
			}
			selfJoined[ch] = true
			// NOTE(review): backfill is skipped once draft/event-playback is
			// ACKed; confirm that is the intended behaviour with soju.
			if eventPlayback || c.BackfillLatest <= 0 {
				break
			}
			since := time.Now().Add(-24 * time.Hour) // fallback when nothing is stored yet
			if c.Store != nil {
				if t, ok, err := c.Store.LastMessageTime(ctx, ch); err == nil && ok {
					since = t
				}
			}
			ts := since.UTC().Format(time.RFC3339Nano)
			if err := write(fmt.Sprintf("CHATHISTORY LATEST %s timestamp=%s %d", ch, ts, c.BackfillLatest)); err != nil {
				return err
			}
		case "PRIVMSG":
			if len(params) < 2 {
				break
			}
			at := time.Now()
			if ts := tags["time"]; ts != "" {
				// Parsing with RFC3339Nano also accepts values without fractional seconds.
				if t, err := time.Parse(time.RFC3339Nano, ts); err == nil {
					at = t
				}
			}
			msgid := tags["soju-msgid"]
			if msgid == "" {
				msgid = tags["msgid"]
			}
			if c.OnPrivmsg != nil {
				c.OnPrivmsg(params[0], nickFromPrefix(msg.Prefix), params[1], msgid, at)
			}
		case "ERROR":
			return fmt.Errorf("server closed: %s", strings.Join(params, " "))
		}
	}
}

// dial opens the TCP connection to Server:Port (TLS 1.2+ when UseTLS is set),
// giving up if ctx is cancelled while connecting.
func (c *RawClient) dial(ctx context.Context) (net.Conn, error) {
	address := net.JoinHostPort(c.Server, strconv.Itoa(c.Port))
	if c.UseTLS {
		d := &tls.Dialer{Config: &tls.Config{ServerName: c.Server, MinVersion: tls.VersionTLS12}}
		return d.DialContext(ctx, "tcp", address)
	}
	var d net.Dialer
	return d.DialContext(ctx, "tcp", address)
}

// rawMessageParams returns all parameters of msg, appending the trailing
// parameter that github.com/sorcix/irc keeps separately in msg.Trailing.
func rawMessageParams(msg *irc.Message) []string {
	params := msg.Params
	if msg.Trailing != "" {
		// Clip so the append never writes into msg.Params' backing array.
		params = append(params[:len(params):len(params)], msg.Trailing)
	}
	return params
}

// capListHas reports whether the space-separated capability list contains
// name. Any "=value" suffix (CAP 302 syntax) is ignored. Entries with a "-"
// removal prefix do not match.
func capListHas(list, name string) bool {
	for _, entry := range strings.Fields(list) {
		if capName, _, _ := strings.Cut(entry, "="); capName == name {
			return true
		}
	}
	return false
}
// nickFromPrefix returns the most specific non-empty identifier in an IRC
// message prefix. It tries the name first, then the user, then the host, and
// returns "" for a nil prefix or when all three are empty.
func nickFromPrefix(pfx *irc.Prefix) string {
	if pfx == nil {
		return ""
	}
	for _, candidate := range [...]string{pfx.Name, pfx.User, pfx.Host} {
		if candidate != "" {
			return candidate
		}
	}
	return ""
}
// parseTags parses the body of an IRCv3 message-tags section (without the
// leading '@'), e.g. "time=2024-01-01T00:00:00Z;msgid=abc;+flag".
//
// Tags without a value map to "". Empty entries are skipped, and when a key
// repeats the last value wins. Values are unescaped per the IRCv3 message-tags
// rules, so ids and other free-form values round-trip correctly.
func parseTags(s string) map[string]string {
	out := make(map[string]string)
	for _, part := range strings.Split(s, ";") {
		if part == "" {
			continue
		}
		key, val, _ := strings.Cut(part, "=")
		out[key] = unescapeTagValue(val)
	}
	return out
}

// unescapeTagValue reverses IRCv3 tag-value escaping: \: -> ';', \s -> ' ',
// \\ -> '\', \r -> CR, \n -> LF. Any other escaped character stands for
// itself, and a lone trailing backslash is dropped.
func unescapeTagValue(v string) string {
	if !strings.Contains(v, `\`) {
		return v
	}
	var b strings.Builder
	b.Grow(len(v))
	for i := 0; i < len(v); i++ {
		if v[i] != '\\' {
			b.WriteByte(v[i])
			continue
		}
		i++
		if i == len(v) {
			break
		}
		switch v[i] {
		case ':':
			b.WriteByte(';')
		case 's':
			b.WriteByte(' ')
		case 'r':
			b.WriteByte('\r')
		case 'n':
			b.WriteByte('\n')
		default: // includes `\\`
			b.WriteByte(v[i])
		}
	}
	return b.String()
}