310 lines
6.7 KiB
Go
310 lines
6.7 KiB
Go
|
|
package soju
|
||
|
|
|
||
|
|
import (
|
||
|
|
"bufio"
|
||
|
|
"context"
|
||
|
|
"crypto/tls"
|
||
|
|
"fmt"
|
||
|
|
"log/slog"
|
||
|
|
"net"
|
||
|
|
"strconv"
|
||
|
|
"strings"
|
||
|
|
"sync/atomic"
|
||
|
|
"time"
|
||
|
|
|
||
|
|
"sojuboy/internal/store"
|
||
|
|
|
||
|
|
irc "github.com/sorcix/irc"
|
||
|
|
)
|
||
|
|
|
||
|
|
// RawClient is a minimal IRC client that speaks the raw protocol to a soju
// bouncer. It registers with PASS/NICK/USER, joins Channels after the
// welcome numeric, optionally backfills recent history with CHATHISTORY, and
// reports every PRIVMSG it sees to OnPrivmsg.
type RawClient struct {
	// Server and Port form the dial address; Server is also used as the TLS
	// ServerName when UseTLS is set.
	Server string
	Port   int
	UseTLS bool
	// Nick is sent with NICK and is also the USER username when Username is empty.
	Nick     string
	Username string // full identity: username/network@client
	Realname string // sent as the realname (last) parameter of USER
	Password string // PASS <password>; PASS is only sent when this is non-empty
	// Channels are JOINed once the server sends RPL_WELCOME (001).
	Channels []string

	// Number of messages to fetch via CHATHISTORY LATEST per channel after join.
	// Zero or negative disables backfill; backfill is also skipped when the
	// server has enabled draft/event-playback.
	BackfillLatest int

	// OnPrivmsg, when non-nil, is called synchronously from the read loop for
	// every PRIVMSG. channel is the message target (NOTE(review): for a direct
	// message this is presumably our own nick — confirm), author is taken from
	// the message prefix, msgid comes from the soju-msgid tag (falling back to
	// msgid), and at is the server-time tag, or the local receive time when
	// that tag is absent or unparseable.
	OnPrivmsg func(channel, author, text, msgid string, at time.Time)

	// Logger is optional; a nil Logger disables all logging.
	Logger *slog.Logger
	// Debug logs every raw line sent and received at debug level (PASS is masked).
	Debug bool

	// Store is used to compute last-seen timestamp for CHATHISTORY.
	// When nil, or when it has no usable timestamp for the channel, backfill
	// starts 24 hours before the JOIN.
	Store *store.Store

	// Readiness/metrics hooks, written atomically by setConnected. Either may
	// be nil, in which case it is not updated.
	ConnectedGauge *int64 // 0/1
	IsReady        *int32 // 0/1 atomic flag
}
|
||
|
|
|
||
|
|
func (c *RawClient) setConnected(v bool) {
|
||
|
|
if c.ConnectedGauge != nil {
|
||
|
|
if v {
|
||
|
|
atomic.StoreInt64(c.ConnectedGauge, 1)
|
||
|
|
} else {
|
||
|
|
atomic.StoreInt64(c.ConnectedGauge, 0)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if c.IsReady != nil {
|
||
|
|
if v {
|
||
|
|
atomic.StoreInt32(c.IsReady, 1)
|
||
|
|
} else {
|
||
|
|
atomic.StoreInt32(c.IsReady, 0)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *RawClient) Run(ctx context.Context) error {
|
||
|
|
backoff := time.Second
|
||
|
|
for {
|
||
|
|
if err := c.runOnce(ctx); err != nil {
|
||
|
|
if ctx.Err() != nil {
|
||
|
|
return ctx.Err()
|
||
|
|
}
|
||
|
|
if c.Logger != nil {
|
||
|
|
c.Logger.Error("raw soju client stopped", "err", err)
|
||
|
|
}
|
||
|
|
time.Sleep(backoff)
|
||
|
|
if backoff < 30*time.Second {
|
||
|
|
backoff *= 2
|
||
|
|
}
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *RawClient) runOnce(ctx context.Context) error {
|
||
|
|
address := net.JoinHostPort(c.Server, strconv.Itoa(c.Port))
|
||
|
|
var conn net.Conn
|
||
|
|
var err error
|
||
|
|
if c.UseTLS {
|
||
|
|
tlsCfg := &tls.Config{ServerName: c.Server, MinVersion: tls.VersionTLS12}
|
||
|
|
conn, err = tls.Dial("tcp", address, tlsCfg)
|
||
|
|
} else {
|
||
|
|
conn, err = net.Dial("tcp", address)
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
defer conn.Close()
|
||
|
|
|
||
|
|
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
|
||
|
|
|
||
|
|
write := func(line string) error {
|
||
|
|
out := line
|
||
|
|
if strings.HasPrefix(strings.ToUpper(line), "PASS ") {
|
||
|
|
out = "PASS ********"
|
||
|
|
}
|
||
|
|
if c.Debug && c.Logger != nil {
|
||
|
|
c.Logger.Debug("irc>", "line", out)
|
||
|
|
}
|
||
|
|
if _, err := rw.WriteString(line + "\r\n"); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
return rw.Flush()
|
||
|
|
}
|
||
|
|
|
||
|
|
// Request capabilities needed for chathistory and accurate timestamps.
|
||
|
|
_ = write("CAP LS 302")
|
||
|
|
_ = write("CAP REQ :server-time batch message-tags draft/chathistory draft/event-playback echo-message cap-notify")
|
||
|
|
_ = write("CAP END")
|
||
|
|
|
||
|
|
// Authenticate with PASS/NICK/USER
|
||
|
|
if c.Password != "" {
|
||
|
|
if err := write("PASS " + c.Password); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if err := write("NICK " + c.Nick); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
user := c.Username
|
||
|
|
if user == "" {
|
||
|
|
user = c.Nick
|
||
|
|
}
|
||
|
|
host := c.Server
|
||
|
|
if err := write(fmt.Sprintf("USER %s %s %s :%s", user, user, host, c.Realname)); err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
|
||
|
|
// Reader loop
|
||
|
|
connected := false
|
||
|
|
eventPlayback := false
|
||
|
|
selfJoined := map[string]bool{}
|
||
|
|
|
||
|
|
for {
|
||
|
|
select {
|
||
|
|
case <-ctx.Done():
|
||
|
|
return ctx.Err()
|
||
|
|
default:
|
||
|
|
}
|
||
|
|
rawLine, err := rw.ReadString('\n')
|
||
|
|
if err != nil {
|
||
|
|
return err
|
||
|
|
}
|
||
|
|
rawLine = strings.TrimRight(rawLine, "\r\n")
|
||
|
|
if rawLine == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if c.Debug && c.Logger != nil {
|
||
|
|
c.Logger.Debug("irc<", "line", rawLine)
|
||
|
|
}
|
||
|
|
|
||
|
|
// Parse IRCv3 tags if present
|
||
|
|
var tags map[string]string
|
||
|
|
line := rawLine
|
||
|
|
if strings.HasPrefix(line, "@") {
|
||
|
|
sp := strings.IndexByte(line, ' ')
|
||
|
|
if sp > 0 {
|
||
|
|
tags = parseTags(line[1:sp])
|
||
|
|
line = strings.TrimSpace(line[sp+1:])
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
msg := irc.ParseMessage(line)
|
||
|
|
if msg == nil {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
cmd := strings.ToUpper(msg.Command)
|
||
|
|
switch cmd {
|
||
|
|
case "CAP":
|
||
|
|
// Examples: :bnc CAP * ACK :server-time batch message-tags draft/chathistory draft/event-playback
|
||
|
|
if len(msg.Params) >= 3 {
|
||
|
|
sub := strings.ToUpper(msg.Params[1])
|
||
|
|
caps := strings.TrimPrefix(msg.Params[2], ":")
|
||
|
|
switch sub {
|
||
|
|
case "ACK":
|
||
|
|
if strings.Contains(caps, "draft/event-playback") {
|
||
|
|
eventPlayback = true
|
||
|
|
if c.Logger != nil {
|
||
|
|
c.Logger.Info("cap enabled", "cap", "draft/event-playback")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
case "NEW":
|
||
|
|
if strings.Contains(caps, "draft/event-playback") && !eventPlayback {
|
||
|
|
_ = write("CAP REQ :draft/event-playback")
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
case "PING":
|
||
|
|
if len(msg.Params) > 0 {
|
||
|
|
_ = write("PONG :" + msg.Params[len(msg.Params)-1])
|
||
|
|
}
|
||
|
|
case "001": // welcome
|
||
|
|
connected = true
|
||
|
|
c.setConnected(true)
|
||
|
|
if c.Logger != nil {
|
||
|
|
c.Logger.Info("connected", "server", c.Server, "auth", "raw")
|
||
|
|
}
|
||
|
|
for _, ch := range c.Channels {
|
||
|
|
_ = write("JOIN " + ch)
|
||
|
|
if c.Logger != nil {
|
||
|
|
c.Logger.Info("join requested", "channel", ch)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
case "JOIN":
|
||
|
|
if len(msg.Params) == 0 {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
ch := msg.Params[0]
|
||
|
|
nick := nickFromPrefix(msg.Prefix)
|
||
|
|
if c.Logger != nil {
|
||
|
|
c.Logger.Info("joined", "channel", ch, "nick", nick)
|
||
|
|
}
|
||
|
|
if nick == c.Nick && !selfJoined[ch] {
|
||
|
|
selfJoined[ch] = true
|
||
|
|
if !eventPlayback && c.BackfillLatest > 0 {
|
||
|
|
// Use last seen timestamp if available
|
||
|
|
since := time.Now().Add(-24 * time.Hour) // default fallback
|
||
|
|
if c.Store != nil {
|
||
|
|
if t, ok, err := c.Store.LastMessageTime(ctx, ch); err == nil && ok {
|
||
|
|
since = t
|
||
|
|
}
|
||
|
|
}
|
||
|
|
// ISO-8601 / RFC3339 format
|
||
|
|
ts := since.UTC().Format(time.RFC3339Nano)
|
||
|
|
_ = write(fmt.Sprintf("CHATHISTORY LATEST %s timestamp=%s %d", ch, ts, c.BackfillLatest))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
case "PRIVMSG":
|
||
|
|
if len(msg.Params) < 1 {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
target := msg.Params[0]
|
||
|
|
var text string
|
||
|
|
if len(msg.Params) >= 2 {
|
||
|
|
text = msg.Params[1]
|
||
|
|
} else if msg.Trailing != "" {
|
||
|
|
text = msg.Trailing
|
||
|
|
} else {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
at := time.Now()
|
||
|
|
if ts, ok := tags["time"]; ok && ts != "" {
|
||
|
|
if t, e := time.Parse(time.RFC3339Nano, ts); e == nil {
|
||
|
|
at = t
|
||
|
|
} else if t2, e2 := time.Parse(time.RFC3339, ts); e2 == nil {
|
||
|
|
at = t2
|
||
|
|
}
|
||
|
|
}
|
||
|
|
msgid := tags["soju-msgid"]
|
||
|
|
if msgid == "" {
|
||
|
|
msgid = tags["msgid"]
|
||
|
|
}
|
||
|
|
if c.OnPrivmsg != nil {
|
||
|
|
c.OnPrivmsg(target, nickFromPrefix(msg.Prefix), text, msgid, at)
|
||
|
|
}
|
||
|
|
case "ERROR":
|
||
|
|
c.setConnected(false)
|
||
|
|
return fmt.Errorf("server closed: %s", strings.Join(msg.Params, " "))
|
||
|
|
}
|
||
|
|
|
||
|
|
_ = connected
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func nickFromPrefix(pfx *irc.Prefix) string {
|
||
|
|
if pfx == nil {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
if pfx.Name != "" {
|
||
|
|
return pfx.Name
|
||
|
|
}
|
||
|
|
if pfx.User != "" {
|
||
|
|
return pfx.User
|
||
|
|
}
|
||
|
|
if pfx.Host != "" {
|
||
|
|
return pfx.Host
|
||
|
|
}
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
|
||
|
|
// parseTags parses the tag section of an IRCv3 message into a key/value map.
// The input is the text between the leading '@' and the first space, without
// the '@'.
//
// Tag parsing rules:
//   - Segments are separated by ';'; empty segments are skipped.
//   - A tag without '=' (or with an empty value) maps to "".
//   - When a key repeats, the last occurrence wins.
//   - Values are unescaped per the IRCv3 message-tags specification (see
//     unescapeTagValue).
//
// The returned map is never nil.
func parseTags(s string) map[string]string {
	out := make(map[string]string)
	for _, part := range strings.Split(s, ";") {
		if part == "" {
			continue
		}
		key, val, _ := strings.Cut(part, "=")
		out[key] = unescapeTagValue(val)
	}
	return out
}

// unescapeTagValue reverses IRCv3 tag-value escaping:
//   - "\:" becomes ';'
//   - "\s" becomes ' '
//   - "\\" becomes '\'
//   - "\r" becomes CR
//   - "\n" becomes LF
//
// A backslash before any other character is dropped (the character is kept),
// and a lone trailing backslash is dropped.
func unescapeTagValue(v string) string {
	if !strings.Contains(v, `\`) {
		return v // fast path: nothing to unescape
	}
	var b strings.Builder
	b.Grow(len(v))
	for i := 0; i < len(v); i++ {
		ch := v[i]
		if ch != '\\' {
			b.WriteByte(ch)
			continue
		}
		i++
		if i == len(v) {
			break
		}
		switch v[i] {
		case ':':
			b.WriteByte(';')
		case 's':
			b.WriteByte(' ')
		case 'r':
			b.WriteByte('\r')
		case 'n':
			b.WriteByte('\n')
		default: // covers "\\" and unknown escapes
			b.WriteByte(v[i])
		}
	}
	return b.String()
}
|