Enhance WebSocket hub with authentication and event handling improvements

- Updated the WebSocket hub to replace the insecure `user_id` query parameter with an authentication token for secure connections.
- Introduced typed events for mail operations (created, updated, deleted) to streamline event handling.
- Implemented heartbeat functionality (ping/pong) to maintain connection health.
- Enhanced client reconnection logic and delta replay for improved user experience.
- Added limits on connections per user/session to prevent abuse and ensure stability.
This commit is contained in:
R3D347HR4Y 2026-05-22 18:09:02 +02:00
parent 1b9a3394e2
commit a2e17c5b6c
7 changed files with 597 additions and 48 deletions

View File

@ -135,7 +135,7 @@ func main() {
} }
// WebSocket hub // WebSocket hub
hub := realtime.NewHub() hub := realtime.NewHub(verifier, pool)
healthChecker := observability.NewHealthChecker(cfg, pool, rdb) healthChecker := observability.NewHealthChecker(cfg, pool, rdb)
rulesEngine := rules.NewEngineWithWebhooks(pool, webhooks.NewExecutor(pool)) rulesEngine := rules.NewEngineWithWebhooks(pool, webhooks.NewExecutor(pool))

View File

@ -36,13 +36,7 @@ func newSyncPipeline(db *pgxpool.Pool, rulesEngine *rules.Engine, hub *realtime.
func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) { func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) {
if ev.kind == "deleted" { if ev.kind == "deleted" {
p.broadcast(ev, realtime.Event{ p.broadcast(ev, realtime.NewMailDeletedEvent(ev.messageID, ev.accountID))
Type: "mail.deleted",
Payload: map[string]any{
"message_id": ev.messageID,
"account_id": ev.accountID,
},
})
return return
} }
@ -55,17 +49,11 @@ func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) {
} }
} }
eventType := "mail.updated" event := realtime.NewMailUpdatedEvent(ev.messageID, ev.accountID)
if ev.kind == "created" { if ev.kind == "created" {
eventType = "mail.created" event = realtime.NewMailCreatedEvent(ev.messageID, ev.accountID)
} }
p.broadcast(ev, realtime.Event{ p.broadcast(ev, event)
Type: eventType,
Payload: map[string]any{
"message_id": ev.messageID,
"account_id": ev.accountID,
},
})
} }
func (p *syncPipeline) broadcast(ev postSyncEvent, event realtime.Event) { func (p *syncPipeline) broadcast(ev postSyncEvent, event realtime.Event) {

View File

@ -0,0 +1,79 @@
package realtime
// WebSocket event type names.
const (
TypeMailCreated = "mail.created"
TypeMailUpdated = "mail.updated"
TypeMailDeleted = "mail.deleted"
TypeOutboxUpdated = "outbox.updated"
TypeContactUpdated = "contact.updated"
TypeWSPing = "ws.ping"
TypeWSPong = "ws.pong"
)
// MailEventPayload is the payload for mail.created, mail.updated, and mail.deleted.
type MailEventPayload struct {
MessageID string `json:"message_id"`
AccountID string `json:"account_id"`
}
// OutboxEventPayload is the payload for outbox.updated.
type OutboxEventPayload struct {
AccountID string `json:"account_id"`
}
// ContactEventPayload is the payload for contact.updated.
type ContactEventPayload struct {
ContactID string `json:"contact_id"`
AccountID string `json:"account_id,omitempty"`
}
func NewMailCreatedEvent(messageID, accountID string) Event {
return Event{
Type: TypeMailCreated,
Payload: MailEventPayload{
MessageID: messageID,
AccountID: accountID,
},
}
}
func NewMailUpdatedEvent(messageID, accountID string) Event {
return Event{
Type: TypeMailUpdated,
Payload: MailEventPayload{
MessageID: messageID,
AccountID: accountID,
},
}
}
func NewMailDeletedEvent(messageID, accountID string) Event {
return Event{
Type: TypeMailDeleted,
Payload: MailEventPayload{
MessageID: messageID,
AccountID: accountID,
},
}
}
func NewOutboxUpdatedEvent(accountID string) Event {
return Event{
Type: TypeOutboxUpdated,
Payload: OutboxEventPayload{
AccountID: accountID,
},
}
}
func NewContactUpdatedEvent(contactID, accountID string) Event {
return Event{
Type: TypeContactUpdated,
Payload: ContactEventPayload{
ContactID: contactID,
AccountID: accountID,
},
}
}

View File

@ -0,0 +1,77 @@
package realtime
import (
"encoding/json"
"testing"
)
func TestEventTypeConstants(t *testing.T) {
tests := map[string]string{
"mail.created": TypeMailCreated,
"mail.updated": TypeMailUpdated,
"mail.deleted": TypeMailDeleted,
"outbox.updated": TypeOutboxUpdated,
"contact.updated": TypeContactUpdated,
"ws.ping": TypeWSPing,
"ws.pong": TypeWSPong,
}
for want, got := range tests {
if got != want {
t.Errorf("constant = %q, want %q", got, want)
}
}
}
func TestMailEventConstructorsJSON(t *testing.T) {
tests := []struct {
name string
event Event
want string
}{
{
name: "created",
event: NewMailCreatedEvent("msg-1", "acct-1"),
want: `{"type":"mail.created","payload":{"message_id":"msg-1","account_id":"acct-1"}}`,
},
{
name: "updated",
event: NewMailUpdatedEvent("msg-2", "acct-2"),
want: `{"type":"mail.updated","payload":{"message_id":"msg-2","account_id":"acct-2"}}`,
},
{
name: "deleted",
event: NewMailDeletedEvent("msg-3", "acct-3"),
want: `{"type":"mail.deleted","payload":{"message_id":"msg-3","account_id":"acct-3"}}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := json.Marshal(tt.event)
if err != nil {
t.Fatalf("json.Marshal() error = %v", err)
}
if string(got) != tt.want {
t.Fatalf("json = %s, want %s", got, tt.want)
}
})
}
}
func TestOutboxAndContactEventConstructorsJSON(t *testing.T) {
outbox, err := json.Marshal(NewOutboxUpdatedEvent("acct-1"))
if err != nil {
t.Fatalf("json.Marshal(outbox) error = %v", err)
}
if string(outbox) != `{"type":"outbox.updated","payload":{"account_id":"acct-1"}}` {
t.Fatalf("outbox json = %s", outbox)
}
contact, err := json.Marshal(NewContactUpdatedEvent("contact-1", "acct-1"))
if err != nil {
t.Fatalf("json.Marshal(contact) error = %v", err)
}
if string(contact) != `{"type":"contact.updated","payload":{"contact_id":"contact-1","account_id":"acct-1"}}` {
t.Fatalf("contact json = %s", contact)
}
}

View File

@ -2,43 +2,92 @@ package realtime
import ( import (
"context" "context"
"crypto/sha256"
"encoding/hex"
"errors"
"log/slog" "log/slog"
"net/http" "net/http"
"strconv"
"strings"
"sync" "sync"
"time"
"github.com/coder/websocket" "github.com/coder/websocket"
"github.com/coder/websocket/wsjson" "github.com/coder/websocket/wsjson"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/users"
) )
type Event struct { type Event struct {
Type string `json:"type"` Type string `json:"type"`
Payload any `json:"payload"` Payload any `json:"payload"`
Seq uint64 `json:"seq,omitempty"`
} }
type Hub struct { type Hub struct {
mu sync.RWMutex mu sync.RWMutex
clients map[string]map[*conn]struct{} // userID -> connections
logger *slog.Logger clients map[string]map[*conn]struct{} // userID -> active connections
sessionCounts map[string]map[string]int // userID -> session fingerprint -> count
history map[string][]Event // userID -> recent ordered events
nextSeq uint64
maxConnectionsPerUser int
maxConnectionsPerSess int
replayBufferSizePerUser int
logger *slog.Logger
verifier *auth.Verifier
db *pgxpool.Pool
} }
type conn struct { type conn struct {
ws *websocket.Conn ws *websocket.Conn
userID string userID string
sessionKey string
writeMu sync.Mutex
} }
func NewHub() *Hub { const (
defaultMaxConnectionsPerUser = 10
defaultMaxConnectionsPerSess = 3
defaultReplayBufferSizePerUser = 500
heartbeatInterval = 25 * time.Second
heartbeatTimeout = 90 * time.Second
)
var (
errVerifierUnavailable = errors.New("authentication unavailable")
errUserStoreUnavailable = errors.New("user store unavailable")
errMissingToken = errors.New("missing token")
errInvalidToken = errors.New("invalid token")
errUserProvisioning = errors.New("failed to provision user")
)
func NewHub(verifier *auth.Verifier, db *pgxpool.Pool) *Hub {
return &Hub{ return &Hub{
clients: make(map[string]map[*conn]struct{}), clients: make(map[string]map[*conn]struct{}),
logger: slog.Default().With("component", "ws-hub"), sessionCounts: make(map[string]map[string]int),
history: make(map[string][]Event),
maxConnectionsPerUser: defaultMaxConnectionsPerUser,
maxConnectionsPerSess: defaultMaxConnectionsPerSess,
replayBufferSizePerUser: defaultReplayBufferSizePerUser,
logger: slog.Default().With("component", "ws-hub"),
verifier: verifier,
db: db,
} }
} }
func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) { func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
userID := r.URL.Query().Get("user_id") userID, token, err := h.authenticate(r)
if userID == "" { if err != nil {
http.Error(w, "missing user_id", http.StatusBadRequest) h.writeAuthError(w, err)
return return
} }
sessionKey := tokenSessionKey(token)
since := parseSinceCursor(r)
ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"*"}, OriginPatterns: []string{"*"},
@ -48,50 +97,308 @@ func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
return return
} }
c := &conn{ws: ws, userID: userID} c := &conn{ws: ws, userID: userID, sessionKey: sessionKey}
h.register(c) if err := h.register(c); err != nil {
h.logger.Warn("ws rejected", "user_id", userID, "session", sessionKey, "error", err)
_ = ws.Close(websocket.StatusPolicyViolation, err.Error())
return
}
defer h.unregister(c) defer h.unregister(c)
if err := h.replaySince(c, since); err != nil {
h.logger.Warn("ws replay failed", "user_id", userID, "error", err)
return
}
ctx := r.Context() ctx := r.Context()
lastPong := newLastPongTracker()
go h.runHeartbeat(ctx, c, lastPong)
for { for {
_, _, err := ws.Read(ctx) var incoming Event
if err != nil { if err := wsjson.Read(ctx, ws, &incoming); err != nil {
break
}
switch incoming.Type {
case TypeWSPong:
lastPong.touch()
case TypeWSPing:
if err := h.writeConn(ctx, c, Event{Type: TypeWSPong, Payload: map[string]any{}}); err != nil {
return
}
lastPong.touch()
}
if isHeartbeatTimedOut(lastPong) {
_ = ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
break break
} }
} }
} }
func (h *Hub) Broadcast(userID string, event Event) { func (h *Hub) authenticate(r *http.Request) (string, string, error) {
h.mu.RLock() if h.verifier == nil {
conns := h.clients[userID] return "", "", errVerifierUnavailable
h.mu.RUnlock() }
if h.db == nil {
return "", "", errUserStoreUnavailable
}
for c := range conns { token, ok := extractToken(r)
if err := wsjson.Write(context.Background(), c.ws, event); err != nil { if !ok {
h.logger.Error("ws write", "error", err, "user_id", userID) return "", "", errMissingToken
go h.unregister(c) }
claims, err := h.verifier.Verify(r.Context(), token)
if err != nil {
return "", "", errInvalidToken
}
userID, err := users.EnsureUser(r.Context(), h.db, claims)
if err != nil {
return "", "", errUserProvisioning
}
return userID, token, nil
}
func (h *Hub) writeAuthError(w http.ResponseWriter, err error) {
switch err {
case errVerifierUnavailable, errUserStoreUnavailable:
http.Error(w, err.Error(), http.StatusServiceUnavailable)
case errMissingToken, errInvalidToken:
http.Error(w, err.Error(), http.StatusUnauthorized)
default:
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}
func extractToken(r *http.Request) (string, bool) {
if authz := strings.TrimSpace(r.Header.Get("Authorization")); authz != "" {
if token, ok := strings.CutPrefix(authz, "Bearer "); ok && strings.TrimSpace(token) != "" {
return strings.TrimSpace(token), true
}
return "", false
}
for _, key := range []string{"token", "access_token"} {
if token := strings.TrimSpace(r.URL.Query().Get(key)); token != "" {
return token, true
}
}
return "", false
}
func tokenSessionKey(token string) string {
sum := sha256.Sum256([]byte(token))
return hex.EncodeToString(sum[:8])
}
func parseSinceCursor(r *http.Request) uint64 {
raw := strings.TrimSpace(r.URL.Query().Get("since"))
if raw == "" {
return 0
}
v, err := strconv.ParseUint(raw, 10, 64)
if err != nil {
return 0
}
return v
}
func (h *Hub) runHeartbeat(ctx context.Context, c *conn, tracker *lastPongTracker) {
ticker := time.NewTicker(heartbeatInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
if isHeartbeatTimedOut(tracker) {
_ = c.ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
return
}
if err := h.writeConn(ctx, c, Event{Type: TypeWSPing, Payload: map[string]any{}}); err != nil {
return
}
} }
} }
} }
func (h *Hub) register(c *conn) { type lastPongTracker struct {
mu sync.RWMutex
ts time.Time
}
func newLastPongTracker() *lastPongTracker {
return &lastPongTracker{ts: time.Now()}
}
func (t *lastPongTracker) touch() {
t.mu.Lock()
defer t.mu.Unlock()
t.ts = time.Now()
}
func (t *lastPongTracker) value() time.Time {
t.mu.RLock()
defer t.mu.RUnlock()
return t.ts
}
func isHeartbeatTimedOut(t *lastPongTracker) bool {
return time.Since(t.value()) > heartbeatTimeout
}
func (h *Hub) replaySince(c *conn, since uint64) error {
if since == 0 {
return nil
}
h.mu.RLock()
history := append([]Event(nil), h.history[c.userID]...)
h.mu.RUnlock()
for _, event := range history {
if event.Seq <= since {
continue
}
if err := h.writeConn(context.Background(), c, event); err != nil {
return err
}
}
return nil
}
func (h *Hub) writeConn(ctx context.Context, c *conn, event Event) error {
c.writeMu.Lock()
defer c.writeMu.Unlock()
return wsjson.Write(ctx, c.ws, event)
}
func (h *Hub) nextSequenceLocked() uint64 {
h.nextSeq++
return h.nextSeq
}
func (h *Hub) appendHistoryLocked(userID string, event Event) {
history := append(h.history[userID], event)
if len(history) > h.replayBufferSizePerUser {
history = history[len(history)-h.replayBufferSizePerUser:]
}
h.history[userID] = history
}
func (h *Hub) snapshotConnectionsLocked(userID string) []*conn {
conns := h.clients[userID]
if len(conns) == 0 {
return nil
}
out := make([]*conn, 0, len(conns))
for c := range conns {
out = append(out, c)
}
return out
}
func (h *Hub) canRegisterLocked(c *conn) error {
if len(h.clients[c.userID]) >= h.maxConnectionsPerUser {
return errors.New("too many websocket connections for user")
}
if h.sessionCounts[c.userID] == nil {
return nil
}
if h.sessionCounts[c.userID][c.sessionKey] >= h.maxConnectionsPerSess {
return errors.New("too many websocket connections for session")
}
return nil
}
func (h *Hub) register(c *conn) error {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
if h.clients[c.userID] == nil { if h.clients[c.userID] == nil {
h.clients[c.userID] = make(map[*conn]struct{}) h.clients[c.userID] = make(map[*conn]struct{})
} }
if err := h.canRegisterLocked(c); err != nil {
return err
}
h.clients[c.userID][c] = struct{}{} h.clients[c.userID][c] = struct{}{}
h.logger.Info("ws connected", "user_id", c.userID) if h.sessionCounts[c.userID] == nil {
h.sessionCounts[c.userID] = make(map[string]int)
}
h.sessionCounts[c.userID][c.sessionKey]++
h.logger.Info("ws connected",
"user_id", c.userID,
"session", c.sessionKey,
"user_conns", len(h.clients[c.userID]),
"session_conns", h.sessionCounts[c.userID][c.sessionKey],
)
return nil
} }
func (h *Hub) unregister(c *conn) { func (h *Hub) unregister(c *conn) {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
if conns, ok := h.clients[c.userID]; ok { if conns, ok := h.clients[c.userID]; ok {
delete(conns, c) delete(conns, c)
if len(conns) == 0 { if len(conns) == 0 {
delete(h.clients, c.userID) delete(h.clients, c.userID)
} }
} }
c.ws.Close(websocket.StatusNormalClosure, "")
if sessions := h.sessionCounts[c.userID]; sessions != nil {
next := sessions[c.sessionKey] - 1
if next <= 0 {
delete(sessions, c.sessionKey)
} else {
sessions[c.sessionKey] = next
}
if len(sessions) == 0 {
delete(h.sessionCounts, c.userID)
}
}
_ = c.ws.Close(websocket.StatusNormalClosure, "")
}
func (h *Hub) Broadcast(userID string, event Event) {
h.mu.Lock()
event.Seq = h.nextSequenceLocked()
h.appendHistoryLocked(userID, event)
conns := h.snapshotConnectionsLocked(userID)
h.mu.Unlock()
for _, c := range conns {
if err := h.writeConn(context.Background(), c, event); err != nil {
h.logger.Error("ws write", "error", err, "user_id", userID)
go h.unregister(c)
}
}
}
func (h *Hub) HistoryHead(userID string) uint64 {
h.mu.RLock()
defer h.mu.RUnlock()
history := h.history[userID]
if len(history) == 0 {
return 0
}
return history[len(history)-1].Seq
}
func (h *Hub) SetLimits(maxUser, maxSession, replaySize int) {
h.mu.Lock()
defer h.mu.Unlock()
if maxUser > 0 {
h.maxConnectionsPerUser = maxUser
}
if maxSession > 0 {
h.maxConnectionsPerSess = maxSession
}
if replaySize > 0 {
h.replayBufferSizePerUser = replaySize
}
} }

View File

@ -0,0 +1,98 @@
package realtime
import (
"net/http/httptest"
"testing"
)
func TestExtractToken(t *testing.T) {
req := httptest.NewRequest("GET", "/ws", nil)
req.Header.Set("Authorization", "Bearer abc")
token, ok := extractToken(req)
if !ok || token != "abc" {
t.Fatalf("extractToken auth header = (%q, %v), want (abc, true)", token, ok)
}
req = httptest.NewRequest("GET", "/ws?token=q1", nil)
token, ok = extractToken(req)
if !ok || token != "q1" {
t.Fatalf("extractToken token query = (%q, %v), want (q1, true)", token, ok)
}
req = httptest.NewRequest("GET", "/ws?access_token=q2", nil)
token, ok = extractToken(req)
if !ok || token != "q2" {
t.Fatalf("extractToken access_token query = (%q, %v), want (q2, true)", token, ok)
}
}
func TestParseSinceCursor(t *testing.T) {
req := httptest.NewRequest("GET", "/ws?since=42", nil)
if got := parseSinceCursor(req); got != 42 {
t.Fatalf("parseSinceCursor() = %d, want 42", got)
}
req = httptest.NewRequest("GET", "/ws?since=bad", nil)
if got := parseSinceCursor(req); got != 0 {
t.Fatalf("parseSinceCursor() invalid = %d, want 0", got)
}
}
func TestTokenSessionKeyStable(t *testing.T) {
a := tokenSessionKey("same-token")
b := tokenSessionKey("same-token")
c := tokenSessionKey("other-token")
if a != b {
t.Fatalf("session keys for same token differ: %q != %q", a, b)
}
if a == c {
t.Fatalf("session keys for different tokens should differ")
}
}
func TestRegisterRespectsLimits(t *testing.T) {
h := NewHub(nil, nil)
h.SetLimits(2, 1, 10)
c1 := &conn{userID: "u1", sessionKey: "s1"}
c2 := &conn{userID: "u1", sessionKey: "s2"}
c3 := &conn{userID: "u1", sessionKey: "s3"}
if err := h.register(c1); err != nil {
t.Fatalf("register c1 error = %v", err)
}
if err := h.register(c2); err != nil {
t.Fatalf("register c2 error = %v", err)
}
if err := h.register(c3); err == nil {
t.Fatalf("register c3 expected user limit error")
}
h2 := NewHub(nil, nil)
h2.SetLimits(5, 1, 10)
if err := h2.register(&conn{userID: "u1", sessionKey: "s1"}); err != nil {
t.Fatalf("register first session conn error = %v", err)
}
if err := h2.register(&conn{userID: "u1", sessionKey: "s1"}); err == nil {
t.Fatalf("register second session conn expected session limit error")
}
}
func TestBroadcastAssignsSequenceAndKeepsReplayBuffer(t *testing.T) {
h := NewHub(nil, nil)
h.SetLimits(10, 10, 2)
h.Broadcast("u1", NewMailCreatedEvent("m1", "a1"))
h.Broadcast("u1", NewMailUpdatedEvent("m2", "a1"))
h.Broadcast("u1", NewMailDeletedEvent("m3", "a1"))
if head := h.HistoryHead("u1"); head != 3 {
t.Fatalf("HistoryHead() = %d, want 3", head)
}
if len(h.history["u1"]) != 2 {
t.Fatalf("history length = %d, want 2", len(h.history["u1"]))
}
if h.history["u1"][0].Seq != 2 || h.history["u1"][1].Seq != 3 {
t.Fatalf("history seqs = [%d, %d], want [2, 3]", h.history["u1"][0].Seq, h.history["u1"][1].Seq)
}
}

View File

@ -120,11 +120,11 @@ Objectif: transformer état actuel (partiellement implémenté) vers produit fon
### 2.5 Realtime (`/ws`) ### 2.5 Realtime (`/ws`)
- [ ] Remplacer `user_id` query param non sûr par auth token WS. - [x] Remplacer `user_id` query param non sûr par auth token WS.
- [ ] Ajouter événements typés (mail.created, mail.updated, outbox.updated, contact.updated...). - [x] Ajouter événements typés (mail.created, mail.updated, outbox.updated, contact.updated...).
- [ ] Ajouter heartbeat/ping/pong. - [x] Ajouter heartbeat/ping/pong.
- [ ] Gérer reconnexion client + rattrapage delta. - [x] Gérer reconnexion client + rattrapage delta.
- [ ] Limiter connexions par user/session. - [x] Limiter connexions par user/session.
### 2.6 Search ### 2.6 Search