diff --git a/cmd/ultid/main.go b/cmd/ultid/main.go index bd9b5ab..c619f1b 100644 --- a/cmd/ultid/main.go +++ b/cmd/ultid/main.go @@ -135,7 +135,7 @@ func main() { } // WebSocket hub - hub := realtime.NewHub() + hub := realtime.NewHub(verifier, pool) healthChecker := observability.NewHealthChecker(cfg, pool, rdb) rulesEngine := rules.NewEngineWithWebhooks(pool, webhooks.NewExecutor(pool)) diff --git a/internal/mail/imap/pipeline.go b/internal/mail/imap/pipeline.go index ada0184..672fe91 100644 --- a/internal/mail/imap/pipeline.go +++ b/internal/mail/imap/pipeline.go @@ -36,13 +36,7 @@ func newSyncPipeline(db *pgxpool.Pool, rulesEngine *rules.Engine, hub *realtime. func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) { if ev.kind == "deleted" { - p.broadcast(ev, realtime.Event{ - Type: "mail.deleted", - Payload: map[string]any{ - "message_id": ev.messageID, - "account_id": ev.accountID, - }, - }) + p.broadcast(ev, realtime.NewMailDeletedEvent(ev.messageID, ev.accountID)) return } @@ -55,17 +49,11 @@ func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) { } } - eventType := "mail.updated" + event := realtime.NewMailUpdatedEvent(ev.messageID, ev.accountID) if ev.kind == "created" { - eventType = "mail.created" + event = realtime.NewMailCreatedEvent(ev.messageID, ev.accountID) } - p.broadcast(ev, realtime.Event{ - Type: eventType, - Payload: map[string]any{ - "message_id": ev.messageID, - "account_id": ev.accountID, - }, - }) + p.broadcast(ev, event) } func (p *syncPipeline) broadcast(ev postSyncEvent, event realtime.Event) { diff --git a/internal/realtime/events.go b/internal/realtime/events.go new file mode 100644 index 0000000..6d1b3f5 --- /dev/null +++ b/internal/realtime/events.go @@ -0,0 +1,79 @@ +package realtime + +// WebSocket event type names. +const ( + TypeMailCreated = "mail.created" + TypeMailUpdated = "mail.updated" + TypeMailDeleted = "mail.deleted" + TypeOutboxUpdated = "outbox.updated" + TypeContactUpdated = "contact.updated" + + TypeWSPing = "ws.ping" + TypeWSPong = "ws.pong" +) + +// MailEventPayload is the payload for mail.created, mail.updated, and mail.deleted. +type MailEventPayload struct { + MessageID string `json:"message_id"` + AccountID string `json:"account_id"` +} + +// OutboxEventPayload is the payload for outbox.updated. +type OutboxEventPayload struct { + AccountID string `json:"account_id"` +} + +// ContactEventPayload is the payload for contact.updated. +type ContactEventPayload struct { + ContactID string `json:"contact_id"` + AccountID string `json:"account_id,omitempty"` +} + +func NewMailCreatedEvent(messageID, accountID string) Event { + return Event{ + Type: TypeMailCreated, + Payload: MailEventPayload{ + MessageID: messageID, + AccountID: accountID, + }, + } +} + +func NewMailUpdatedEvent(messageID, accountID string) Event { + return Event{ + Type: TypeMailUpdated, + Payload: MailEventPayload{ + MessageID: messageID, + AccountID: accountID, + }, + } +} + +func NewMailDeletedEvent(messageID, accountID string) Event { + return Event{ + Type: TypeMailDeleted, + Payload: MailEventPayload{ + MessageID: messageID, + AccountID: accountID, + }, + } +} + +func NewOutboxUpdatedEvent(accountID string) Event { + return Event{ + Type: TypeOutboxUpdated, + Payload: OutboxEventPayload{ + AccountID: accountID, + }, + } +} + +func NewContactUpdatedEvent(contactID, accountID string) Event { + return Event{ + Type: TypeContactUpdated, + Payload: ContactEventPayload{ + ContactID: contactID, + AccountID: accountID, + }, + } +} diff --git a/internal/realtime/events_test.go b/internal/realtime/events_test.go new file mode 100644 index 0000000..b671ab9 --- /dev/null +++ b/internal/realtime/events_test.go @@ -0,0 +1,77 @@ +package realtime + +import ( + "encoding/json" + "testing" +) + +func TestEventTypeConstants(t *testing.T) { + tests := map[string]string{ + "mail.created": TypeMailCreated, + "mail.updated": TypeMailUpdated, + "mail.deleted": TypeMailDeleted, + "outbox.updated": TypeOutboxUpdated, + "contact.updated": TypeContactUpdated, + "ws.ping": TypeWSPing, + "ws.pong": TypeWSPong, + } + for want, got := range tests { + if got != want { + t.Errorf("constant = %q, want %q", got, want) + } + } +} + +func TestMailEventConstructorsJSON(t *testing.T) { + tests := []struct { + name string + event Event + want string + }{ + { + name: "created", + event: NewMailCreatedEvent("msg-1", "acct-1"), + want: `{"type":"mail.created","payload":{"message_id":"msg-1","account_id":"acct-1"}}`, + }, + { + name: "updated", + event: NewMailUpdatedEvent("msg-2", "acct-2"), + want: `{"type":"mail.updated","payload":{"message_id":"msg-2","account_id":"acct-2"}}`, + }, + { + name: "deleted", + event: NewMailDeletedEvent("msg-3", "acct-3"), + want: `{"type":"mail.deleted","payload":{"message_id":"msg-3","account_id":"acct-3"}}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := json.Marshal(tt.event) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + if string(got) != tt.want { + t.Fatalf("json = %s, want %s", got, tt.want) + } + }) + } +} + +func TestOutboxAndContactEventConstructorsJSON(t *testing.T) { + outbox, err := json.Marshal(NewOutboxUpdatedEvent("acct-1")) + if err != nil { + t.Fatalf("json.Marshal(outbox) error = %v", err) + } + if string(outbox) != `{"type":"outbox.updated","payload":{"account_id":"acct-1"}}` { + t.Fatalf("outbox json = %s", outbox) + } + + contact, err := json.Marshal(NewContactUpdatedEvent("contact-1", "acct-1")) + if err != nil { + t.Fatalf("json.Marshal(contact) error = %v", err) + } + if string(contact) != `{"type":"contact.updated","payload":{"contact_id":"contact-1","account_id":"acct-1"}}` { + t.Fatalf("contact json = %s", contact) + } +} diff --git a/internal/realtime/hub.go b/internal/realtime/hub.go index ddaf079..7bbc4eb 100644 --- a/internal/realtime/hub.go +++ b/internal/realtime/hub.go @@ -2,43 +2,92 @@ package realtime import ( "context" + "crypto/sha256" + "encoding/hex" + "errors" "log/slog" "net/http" + "strconv" + "strings" "sync" + "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/ultisuite/ulti-backend/internal/auth" + "github.com/ultisuite/ulti-backend/internal/users" ) type Event struct { Type string `json:"type"` Payload any `json:"payload"` + Seq uint64 `json:"seq,omitempty"` } type Hub struct { - mu sync.RWMutex - clients map[string]map[*conn]struct{} // userID -> connections - logger *slog.Logger + mu sync.RWMutex + + clients map[string]map[*conn]struct{} // userID -> active connections + sessionCounts map[string]map[string]int // userID -> session fingerprint -> count + history map[string][]Event // userID -> recent ordered events + nextSeq uint64 + + maxConnectionsPerUser int + maxConnectionsPerSess int + replayBufferSizePerUser int + + logger *slog.Logger + verifier *auth.Verifier + db *pgxpool.Pool } type conn struct { - ws *websocket.Conn - userID string + ws *websocket.Conn + userID string + sessionKey string + writeMu sync.Mutex } -func NewHub() *Hub { +const ( + defaultMaxConnectionsPerUser = 10 + defaultMaxConnectionsPerSess = 3 + defaultReplayBufferSizePerUser = 500 + heartbeatInterval = 25 * time.Second + heartbeatTimeout = 90 * time.Second +) + +var ( + errVerifierUnavailable = errors.New("authentication unavailable") + errUserStoreUnavailable = errors.New("user store unavailable") + errMissingToken = errors.New("missing token") + errInvalidToken = errors.New("invalid token") + errUserProvisioning = errors.New("failed to provision user") +) + +func NewHub(verifier *auth.Verifier, db *pgxpool.Pool) *Hub { return &Hub{ - clients: make(map[string]map[*conn]struct{}), - logger: slog.Default().With("component", "ws-hub"), + clients: make(map[string]map[*conn]struct{}), + sessionCounts: make(map[string]map[string]int), + history: make(map[string][]Event), + maxConnectionsPerUser: defaultMaxConnectionsPerUser, + maxConnectionsPerSess: defaultMaxConnectionsPerSess, + replayBufferSizePerUser: defaultReplayBufferSizePerUser, + logger: slog.Default().With("component", "ws-hub"), + verifier: verifier, + db: db, } } func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) { - userID := r.URL.Query().Get("user_id") - if userID == "" { - http.Error(w, "missing user_id", http.StatusBadRequest) + userID, token, err := h.authenticate(r) + if err != nil { + h.writeAuthError(w, err) return } + sessionKey := tokenSessionKey(token) + since := parseSinceCursor(r) ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, @@ -48,50 +97,308 @@ func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) { return } - c := &conn{ws: ws, userID: userID} - h.register(c) + c := &conn{ws: ws, userID: userID, sessionKey: sessionKey} + if err := h.register(c); err != nil { + h.logger.Warn("ws rejected", "user_id", userID, "session", sessionKey, "error", err) + _ = ws.Close(websocket.StatusPolicyViolation, err.Error()) + return + } defer h.unregister(c) + if err := h.replaySince(c, since); err != nil { + h.logger.Warn("ws replay failed", "user_id", userID, "error", err) + return + } + ctx := r.Context() + lastPong := newLastPongTracker() + go h.runHeartbeat(ctx, c, lastPong) + for { - _, _, err := ws.Read(ctx) - if err != nil { + var incoming Event + if err := wsjson.Read(ctx, ws, &incoming); err != nil { + break + } + + switch incoming.Type { + case TypeWSPong: + lastPong.touch() + case TypeWSPing: + if err := h.writeConn(ctx, c, Event{Type: TypeWSPong, Payload: map[string]any{}}); err != nil { + return + } + lastPong.touch() + } + + if isHeartbeatTimedOut(lastPong) { + _ = ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout") break } } } -func (h *Hub) Broadcast(userID string, event Event) { - h.mu.RLock() - conns := h.clients[userID] - h.mu.RUnlock() +func (h *Hub) authenticate(r *http.Request) (string, string, error) { + if h.verifier == nil { + return "", "", errVerifierUnavailable + } + if h.db == nil { + return "", "", errUserStoreUnavailable + } - for c := range conns { - if err := wsjson.Write(context.Background(), c.ws, event); err != nil { - h.logger.Error("ws write", "error", err, "user_id", userID) - go h.unregister(c) + token, ok := extractToken(r) + if !ok { + return "", "", errMissingToken + } + claims, err := h.verifier.Verify(r.Context(), token) + if err != nil { + return "", "", errInvalidToken + } + userID, err := users.EnsureUser(r.Context(), h.db, claims) + if err != nil { + return "", "", errUserProvisioning + } + return userID, token, nil +} + +func (h *Hub) writeAuthError(w http.ResponseWriter, err error) { + switch err { + case errVerifierUnavailable, errUserStoreUnavailable: + http.Error(w, err.Error(), http.StatusServiceUnavailable) + case errMissingToken, errInvalidToken: + http.Error(w, err.Error(), http.StatusUnauthorized) + default: + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +func extractToken(r *http.Request) (string, bool) { + if authz := strings.TrimSpace(r.Header.Get("Authorization")); authz != "" { + if token, ok := strings.CutPrefix(authz, "Bearer "); ok && strings.TrimSpace(token) != "" { + return strings.TrimSpace(token), true + } + return "", false + } + for _, key := range []string{"token", "access_token"} { + if token := strings.TrimSpace(r.URL.Query().Get(key)); token != "" { + return token, true + } + } + return "", false +} + +func tokenSessionKey(token string) string { + sum := sha256.Sum256([]byte(token)) + return hex.EncodeToString(sum[:8]) +} + +func parseSinceCursor(r *http.Request) uint64 { + raw := strings.TrimSpace(r.URL.Query().Get("since")) + if raw == "" { + return 0 + } + v, err := strconv.ParseUint(raw, 10, 64) + if err != nil { + return 0 + } + return v +} + +func (h *Hub) runHeartbeat(ctx context.Context, c *conn, tracker *lastPongTracker) { + ticker := time.NewTicker(heartbeatInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if isHeartbeatTimedOut(tracker) { + _ = c.ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout") + return + } + if err := h.writeConn(ctx, c, Event{Type: TypeWSPing, Payload: map[string]any{}}); err != nil { + return + } } } } -func (h *Hub) register(c *conn) { +type lastPongTracker struct { + mu sync.RWMutex + ts time.Time +} + +func newLastPongTracker() *lastPongTracker { + return &lastPongTracker{ts: time.Now()} +} + +func (t *lastPongTracker) touch() { + t.mu.Lock() + defer t.mu.Unlock() + t.ts = time.Now() +} + +func (t *lastPongTracker) value() time.Time { + t.mu.RLock() + defer t.mu.RUnlock() + return t.ts +} + +func isHeartbeatTimedOut(t *lastPongTracker) bool { + return time.Since(t.value()) > heartbeatTimeout +} + +func (h *Hub) replaySince(c *conn, since uint64) error { + if since == 0 { + return nil + } + + h.mu.RLock() + history := append([]Event(nil), h.history[c.userID]...) + h.mu.RUnlock() + + for _, event := range history { + if event.Seq <= since { + continue + } + if err := h.writeConn(context.Background(), c, event); err != nil { + return err + } + } + return nil +} + +func (h *Hub) writeConn(ctx context.Context, c *conn, event Event) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + return wsjson.Write(ctx, c.ws, event) +} + +func (h *Hub) nextSequenceLocked() uint64 { + h.nextSeq++ + return h.nextSeq +} + +func (h *Hub) appendHistoryLocked(userID string, event Event) { + history := append(h.history[userID], event) + if len(history) > h.replayBufferSizePerUser { + history = history[len(history)-h.replayBufferSizePerUser:] + } + h.history[userID] = history +} + +func (h *Hub) snapshotConnectionsLocked(userID string) []*conn { + conns := h.clients[userID] + if len(conns) == 0 { + return nil + } + out := make([]*conn, 0, len(conns)) + for c := range conns { + out = append(out, c) + } + return out +} + +func (h *Hub) canRegisterLocked(c *conn) error { + if len(h.clients[c.userID]) >= h.maxConnectionsPerUser { + return errors.New("too many websocket connections for user") + } + if h.sessionCounts[c.userID] == nil { + return nil + } + if h.sessionCounts[c.userID][c.sessionKey] >= h.maxConnectionsPerSess { + return errors.New("too many websocket connections for session") + } + return nil +} + +func (h *Hub) register(c *conn) error { h.mu.Lock() defer h.mu.Unlock() + if h.clients[c.userID] == nil { h.clients[c.userID] = make(map[*conn]struct{}) } + if err := h.canRegisterLocked(c); err != nil { + return err + } + h.clients[c.userID][c] = struct{}{} - h.logger.Info("ws connected", "user_id", c.userID) + if h.sessionCounts[c.userID] == nil { + h.sessionCounts[c.userID] = make(map[string]int) + } + h.sessionCounts[c.userID][c.sessionKey]++ + + h.logger.Info("ws connected", + "user_id", c.userID, + "session", c.sessionKey, + "user_conns", len(h.clients[c.userID]), + "session_conns", h.sessionCounts[c.userID][c.sessionKey], + ) + return nil } func (h *Hub) unregister(c *conn) { h.mu.Lock() defer h.mu.Unlock() + if conns, ok := h.clients[c.userID]; ok { delete(conns, c) if len(conns) == 0 { delete(h.clients, c.userID) } } - c.ws.Close(websocket.StatusNormalClosure, "") + + if sessions := h.sessionCounts[c.userID]; sessions != nil { + next := sessions[c.sessionKey] - 1 + if next <= 0 { + delete(sessions, c.sessionKey) + } else { + sessions[c.sessionKey] = next + } + if len(sessions) == 0 { + delete(h.sessionCounts, c.userID) + } + } + + _ = c.ws.Close(websocket.StatusNormalClosure, "") +} + +func (h *Hub) Broadcast(userID string, event Event) { + h.mu.Lock() + event.Seq = h.nextSequenceLocked() + h.appendHistoryLocked(userID, event) + conns := h.snapshotConnectionsLocked(userID) + h.mu.Unlock() + + for _, c := range conns { + if err := h.writeConn(context.Background(), c, event); err != nil { + h.logger.Error("ws write", "error", err, "user_id", userID) + go h.unregister(c) + } + } +} + +func (h *Hub) HistoryHead(userID string) uint64 { + h.mu.RLock() + defer h.mu.RUnlock() + history := h.history[userID] + if len(history) == 0 { + return 0 + } + return history[len(history)-1].Seq +} + +func (h *Hub) SetLimits(maxUser, maxSession, replaySize int) { + h.mu.Lock() + defer h.mu.Unlock() + if maxUser > 0 { + h.maxConnectionsPerUser = maxUser + } + if maxSession > 0 { + h.maxConnectionsPerSess = maxSession + } + if replaySize > 0 { + h.replayBufferSizePerUser = replaySize + } } diff --git a/internal/realtime/hub_test.go b/internal/realtime/hub_test.go new file mode 100644 index 0000000..432f05e --- /dev/null +++ b/internal/realtime/hub_test.go @@ -0,0 +1,98 @@ +package realtime + +import ( + "net/http/httptest" + "testing" +) + +func TestExtractToken(t *testing.T) { + req := httptest.NewRequest("GET", "/ws", nil) + req.Header.Set("Authorization", "Bearer abc") + token, ok := extractToken(req) + if !ok || token != "abc" { + t.Fatalf("extractToken auth header = (%q, %v), want (abc, true)", token, ok) + } + + req = httptest.NewRequest("GET", "/ws?token=q1", nil) + token, ok = extractToken(req) + if !ok || token != "q1" { + t.Fatalf("extractToken token query = (%q, %v), want (q1, true)", token, ok) + } + + req = httptest.NewRequest("GET", "/ws?access_token=q2", nil) + token, ok = extractToken(req) + if !ok || token != "q2" { + t.Fatalf("extractToken access_token query = (%q, %v), want (q2, true)", token, ok) + } +} + +func TestParseSinceCursor(t *testing.T) { + req := httptest.NewRequest("GET", "/ws?since=42", nil) + if got := parseSinceCursor(req); got != 42 { + t.Fatalf("parseSinceCursor() = %d, want 42", got) + } + + req = httptest.NewRequest("GET", "/ws?since=bad", nil) + if got := parseSinceCursor(req); got != 0 { + t.Fatalf("parseSinceCursor() invalid = %d, want 0", got) + } +} + +func TestTokenSessionKeyStable(t *testing.T) { + a := tokenSessionKey("same-token") + b := tokenSessionKey("same-token") + c := tokenSessionKey("other-token") + if a != b { + t.Fatalf("session keys for same token differ: %q != %q", a, b) + } + if a == c { + t.Fatalf("session keys for different tokens should differ") + } +} + +func TestRegisterRespectsLimits(t *testing.T) { + h := NewHub(nil, nil) + h.SetLimits(2, 1, 10) + + c1 := &conn{userID: "u1", sessionKey: "s1"} + c2 := &conn{userID: "u1", sessionKey: "s2"} + c3 := &conn{userID: "u1", sessionKey: "s3"} + + if err := h.register(c1); err != nil { + t.Fatalf("register c1 error = %v", err) + } + if err := h.register(c2); err != nil { + t.Fatalf("register c2 error = %v", err) + } + if err := h.register(c3); err == nil { + t.Fatalf("register c3 expected user limit error") + } + + h2 := NewHub(nil, nil) + h2.SetLimits(5, 1, 10) + if err := h2.register(&conn{userID: "u1", sessionKey: "s1"}); err != nil { + t.Fatalf("register first session conn error = %v", err) + } + if err := h2.register(&conn{userID: "u1", sessionKey: "s1"}); err == nil { + t.Fatalf("register second session conn expected session limit error") + } +} + +func TestBroadcastAssignsSequenceAndKeepsReplayBuffer(t *testing.T) { + h := NewHub(nil, nil) + h.SetLimits(10, 10, 2) + + h.Broadcast("u1", NewMailCreatedEvent("m1", "a1")) + h.Broadcast("u1", NewMailUpdatedEvent("m2", "a1")) + h.Broadcast("u1", NewMailDeletedEvent("m3", "a1")) + + if head := h.HistoryHead("u1"); head != 3 { + t.Fatalf("HistoryHead() = %d, want 3", head) + } + if len(h.history["u1"]) != 2 { + t.Fatalf("history length = %d, want 2", len(h.history["u1"])) + } + if h.history["u1"][0].Seq != 2 || h.history["u1"][1].Seq != 3 { + t.Fatalf("history seqs = [%d, %d], want [2, 3]", h.history["u1"][0].Seq, h.history["u1"][1].Seq) + } +} diff --git a/project-plan/checklist-execution.md b/project-plan/checklist-execution.md index bf1695e..b3cde0a 100644 --- a/project-plan/checklist-execution.md +++ b/project-plan/checklist-execution.md @@ -120,11 +120,11 @@ Objectif: transformer état actuel (partiellement implémenté) vers produit fon ### 2.5 Realtime (`/ws`) -- [ ] Remplacer `user_id` query param non sûr par auth token WS. -- [ ] Ajouter événements typés (mail.created, mail.updated, outbox.updated, contact.updated...). -- [ ] Ajouter heartbeat/ping/pong. -- [ ] Gérer reconnexion client + rattrapage delta. -- [ ] Limiter connexions par user/session. +- [x] Remplacer `user_id` query param non sûr par auth token WS. +- [x] Ajouter événements typés (mail.created, mail.updated, outbox.updated, contact.updated...). +- [x] Ajouter heartbeat/ping/pong. +- [x] Gérer reconnexion client + rattrapage delta. +- [x] Limiter connexions par user/session. ### 2.6 Search