Enhance WebSocket hub with authentication and event handling improvements
- Updated the WebSocket hub to replace the insecure `user_id` query parameter with an authentication token for secure connections. - Introduced typed events for mail operations (created, updated, deleted) to streamline event handling. - Implemented heartbeat functionality (ping/pong) to maintain connection health. - Enhanced client reconnection logic and delta replay for improved user experience. - Added limits on connections per user/session to prevent abuse and ensure stability.
This commit is contained in:
parent
1b9a3394e2
commit
a2e17c5b6c
@ -135,7 +135,7 @@ func main() {
|
||||
}
|
||||
|
||||
// WebSocket hub
|
||||
hub := realtime.NewHub()
|
||||
hub := realtime.NewHub(verifier, pool)
|
||||
healthChecker := observability.NewHealthChecker(cfg, pool, rdb)
|
||||
|
||||
rulesEngine := rules.NewEngineWithWebhooks(pool, webhooks.NewExecutor(pool))
|
||||
|
||||
@ -36,13 +36,7 @@ func newSyncPipeline(db *pgxpool.Pool, rulesEngine *rules.Engine, hub *realtime.
|
||||
|
||||
func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) {
|
||||
if ev.kind == "deleted" {
|
||||
p.broadcast(ev, realtime.Event{
|
||||
Type: "mail.deleted",
|
||||
Payload: map[string]any{
|
||||
"message_id": ev.messageID,
|
||||
"account_id": ev.accountID,
|
||||
},
|
||||
})
|
||||
p.broadcast(ev, realtime.NewMailDeletedEvent(ev.messageID, ev.accountID))
|
||||
return
|
||||
}
|
||||
|
||||
@ -55,17 +49,11 @@ func (p *syncPipeline) handle(ctx context.Context, ev postSyncEvent) {
|
||||
}
|
||||
}
|
||||
|
||||
eventType := "mail.updated"
|
||||
event := realtime.NewMailUpdatedEvent(ev.messageID, ev.accountID)
|
||||
if ev.kind == "created" {
|
||||
eventType = "mail.created"
|
||||
event = realtime.NewMailCreatedEvent(ev.messageID, ev.accountID)
|
||||
}
|
||||
p.broadcast(ev, realtime.Event{
|
||||
Type: eventType,
|
||||
Payload: map[string]any{
|
||||
"message_id": ev.messageID,
|
||||
"account_id": ev.accountID,
|
||||
},
|
||||
})
|
||||
p.broadcast(ev, event)
|
||||
}
|
||||
|
||||
func (p *syncPipeline) broadcast(ev postSyncEvent, event realtime.Event) {
|
||||
|
||||
79
internal/realtime/events.go
Normal file
79
internal/realtime/events.go
Normal file
@ -0,0 +1,79 @@
|
||||
package realtime
|
||||
|
||||
// WebSocket event type names.
|
||||
const (
|
||||
TypeMailCreated = "mail.created"
|
||||
TypeMailUpdated = "mail.updated"
|
||||
TypeMailDeleted = "mail.deleted"
|
||||
TypeOutboxUpdated = "outbox.updated"
|
||||
TypeContactUpdated = "contact.updated"
|
||||
|
||||
TypeWSPing = "ws.ping"
|
||||
TypeWSPong = "ws.pong"
|
||||
)
|
||||
|
||||
// MailEventPayload is the payload for mail.created, mail.updated, and mail.deleted.
|
||||
type MailEventPayload struct {
|
||||
MessageID string `json:"message_id"`
|
||||
AccountID string `json:"account_id"`
|
||||
}
|
||||
|
||||
// OutboxEventPayload is the payload for outbox.updated.
|
||||
type OutboxEventPayload struct {
|
||||
AccountID string `json:"account_id"`
|
||||
}
|
||||
|
||||
// ContactEventPayload is the payload for contact.updated.
|
||||
type ContactEventPayload struct {
|
||||
ContactID string `json:"contact_id"`
|
||||
AccountID string `json:"account_id,omitempty"`
|
||||
}
|
||||
|
||||
func NewMailCreatedEvent(messageID, accountID string) Event {
|
||||
return Event{
|
||||
Type: TypeMailCreated,
|
||||
Payload: MailEventPayload{
|
||||
MessageID: messageID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewMailUpdatedEvent(messageID, accountID string) Event {
|
||||
return Event{
|
||||
Type: TypeMailUpdated,
|
||||
Payload: MailEventPayload{
|
||||
MessageID: messageID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewMailDeletedEvent(messageID, accountID string) Event {
|
||||
return Event{
|
||||
Type: TypeMailDeleted,
|
||||
Payload: MailEventPayload{
|
||||
MessageID: messageID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewOutboxUpdatedEvent(accountID string) Event {
|
||||
return Event{
|
||||
Type: TypeOutboxUpdated,
|
||||
Payload: OutboxEventPayload{
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewContactUpdatedEvent(contactID, accountID string) Event {
|
||||
return Event{
|
||||
Type: TypeContactUpdated,
|
||||
Payload: ContactEventPayload{
|
||||
ContactID: contactID,
|
||||
AccountID: accountID,
|
||||
},
|
||||
}
|
||||
}
|
||||
77
internal/realtime/events_test.go
Normal file
77
internal/realtime/events_test.go
Normal file
@ -0,0 +1,77 @@
|
||||
package realtime
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEventTypeConstants(t *testing.T) {
|
||||
tests := map[string]string{
|
||||
"mail.created": TypeMailCreated,
|
||||
"mail.updated": TypeMailUpdated,
|
||||
"mail.deleted": TypeMailDeleted,
|
||||
"outbox.updated": TypeOutboxUpdated,
|
||||
"contact.updated": TypeContactUpdated,
|
||||
"ws.ping": TypeWSPing,
|
||||
"ws.pong": TypeWSPong,
|
||||
}
|
||||
for want, got := range tests {
|
||||
if got != want {
|
||||
t.Errorf("constant = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMailEventConstructorsJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
event Event
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "created",
|
||||
event: NewMailCreatedEvent("msg-1", "acct-1"),
|
||||
want: `{"type":"mail.created","payload":{"message_id":"msg-1","account_id":"acct-1"}}`,
|
||||
},
|
||||
{
|
||||
name: "updated",
|
||||
event: NewMailUpdatedEvent("msg-2", "acct-2"),
|
||||
want: `{"type":"mail.updated","payload":{"message_id":"msg-2","account_id":"acct-2"}}`,
|
||||
},
|
||||
{
|
||||
name: "deleted",
|
||||
event: NewMailDeletedEvent("msg-3", "acct-3"),
|
||||
want: `{"type":"mail.deleted","payload":{"message_id":"msg-3","account_id":"acct-3"}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := json.Marshal(tt.event)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal() error = %v", err)
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Fatalf("json = %s, want %s", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOutboxAndContactEventConstructorsJSON(t *testing.T) {
|
||||
outbox, err := json.Marshal(NewOutboxUpdatedEvent("acct-1"))
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal(outbox) error = %v", err)
|
||||
}
|
||||
if string(outbox) != `{"type":"outbox.updated","payload":{"account_id":"acct-1"}}` {
|
||||
t.Fatalf("outbox json = %s", outbox)
|
||||
}
|
||||
|
||||
contact, err := json.Marshal(NewContactUpdatedEvent("contact-1", "acct-1"))
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal(contact) error = %v", err)
|
||||
}
|
||||
if string(contact) != `{"type":"contact.updated","payload":{"contact_id":"contact-1","account_id":"acct-1"}}` {
|
||||
t.Fatalf("contact json = %s", contact)
|
||||
}
|
||||
}
|
||||
@ -2,43 +2,92 @@ package realtime
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
Type string `json:"type"`
|
||||
Payload any `json:"payload"`
|
||||
Seq uint64 `json:"seq,omitempty"`
|
||||
}
|
||||
|
||||
type Hub struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]map[*conn]struct{} // userID -> connections
|
||||
|
||||
clients map[string]map[*conn]struct{} // userID -> active connections
|
||||
sessionCounts map[string]map[string]int // userID -> session fingerprint -> count
|
||||
history map[string][]Event // userID -> recent ordered events
|
||||
nextSeq uint64
|
||||
|
||||
maxConnectionsPerUser int
|
||||
maxConnectionsPerSess int
|
||||
replayBufferSizePerUser int
|
||||
|
||||
logger *slog.Logger
|
||||
verifier *auth.Verifier
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
ws *websocket.Conn
|
||||
userID string
|
||||
sessionKey string
|
||||
writeMu sync.Mutex
|
||||
}
|
||||
|
||||
func NewHub() *Hub {
|
||||
const (
|
||||
defaultMaxConnectionsPerUser = 10
|
||||
defaultMaxConnectionsPerSess = 3
|
||||
defaultReplayBufferSizePerUser = 500
|
||||
heartbeatInterval = 25 * time.Second
|
||||
heartbeatTimeout = 90 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errVerifierUnavailable = errors.New("authentication unavailable")
|
||||
errUserStoreUnavailable = errors.New("user store unavailable")
|
||||
errMissingToken = errors.New("missing token")
|
||||
errInvalidToken = errors.New("invalid token")
|
||||
errUserProvisioning = errors.New("failed to provision user")
|
||||
)
|
||||
|
||||
func NewHub(verifier *auth.Verifier, db *pgxpool.Pool) *Hub {
|
||||
return &Hub{
|
||||
clients: make(map[string]map[*conn]struct{}),
|
||||
sessionCounts: make(map[string]map[string]int),
|
||||
history: make(map[string][]Event),
|
||||
maxConnectionsPerUser: defaultMaxConnectionsPerUser,
|
||||
maxConnectionsPerSess: defaultMaxConnectionsPerSess,
|
||||
replayBufferSizePerUser: defaultReplayBufferSizePerUser,
|
||||
logger: slog.Default().With("component", "ws-hub"),
|
||||
verifier: verifier,
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.URL.Query().Get("user_id")
|
||||
if userID == "" {
|
||||
http.Error(w, "missing user_id", http.StatusBadRequest)
|
||||
userID, token, err := h.authenticate(r)
|
||||
if err != nil {
|
||||
h.writeAuthError(w, err)
|
||||
return
|
||||
}
|
||||
sessionKey := tokenSessionKey(token)
|
||||
since := parseSinceCursor(r)
|
||||
|
||||
ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
||||
OriginPatterns: []string{"*"},
|
||||
@ -48,50 +97,308 @@ func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
c := &conn{ws: ws, userID: userID}
|
||||
h.register(c)
|
||||
c := &conn{ws: ws, userID: userID, sessionKey: sessionKey}
|
||||
if err := h.register(c); err != nil {
|
||||
h.logger.Warn("ws rejected", "user_id", userID, "session", sessionKey, "error", err)
|
||||
_ = ws.Close(websocket.StatusPolicyViolation, err.Error())
|
||||
return
|
||||
}
|
||||
defer h.unregister(c)
|
||||
|
||||
if err := h.replaySince(c, since); err != nil {
|
||||
h.logger.Warn("ws replay failed", "user_id", userID, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
lastPong := newLastPongTracker()
|
||||
go h.runHeartbeat(ctx, c, lastPong)
|
||||
|
||||
for {
|
||||
_, _, err := ws.Read(ctx)
|
||||
if err != nil {
|
||||
var incoming Event
|
||||
if err := wsjson.Read(ctx, ws, &incoming); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
switch incoming.Type {
|
||||
case TypeWSPong:
|
||||
lastPong.touch()
|
||||
case TypeWSPing:
|
||||
if err := h.writeConn(ctx, c, Event{Type: TypeWSPong, Payload: map[string]any{}}); err != nil {
|
||||
return
|
||||
}
|
||||
lastPong.touch()
|
||||
}
|
||||
|
||||
if isHeartbeatTimedOut(lastPong) {
|
||||
_ = ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) Broadcast(userID string, event Event) {
|
||||
func (h *Hub) authenticate(r *http.Request) (string, string, error) {
|
||||
if h.verifier == nil {
|
||||
return "", "", errVerifierUnavailable
|
||||
}
|
||||
if h.db == nil {
|
||||
return "", "", errUserStoreUnavailable
|
||||
}
|
||||
|
||||
token, ok := extractToken(r)
|
||||
if !ok {
|
||||
return "", "", errMissingToken
|
||||
}
|
||||
claims, err := h.verifier.Verify(r.Context(), token)
|
||||
if err != nil {
|
||||
return "", "", errInvalidToken
|
||||
}
|
||||
userID, err := users.EnsureUser(r.Context(), h.db, claims)
|
||||
if err != nil {
|
||||
return "", "", errUserProvisioning
|
||||
}
|
||||
return userID, token, nil
|
||||
}
|
||||
|
||||
func (h *Hub) writeAuthError(w http.ResponseWriter, err error) {
|
||||
switch err {
|
||||
case errVerifierUnavailable, errUserStoreUnavailable:
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
case errMissingToken, errInvalidToken:
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
default:
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}
|
||||
|
||||
func extractToken(r *http.Request) (string, bool) {
|
||||
if authz := strings.TrimSpace(r.Header.Get("Authorization")); authz != "" {
|
||||
if token, ok := strings.CutPrefix(authz, "Bearer "); ok && strings.TrimSpace(token) != "" {
|
||||
return strings.TrimSpace(token), true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
for _, key := range []string{"token", "access_token"} {
|
||||
if token := strings.TrimSpace(r.URL.Query().Get(key)); token != "" {
|
||||
return token, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func tokenSessionKey(token string) string {
|
||||
sum := sha256.Sum256([]byte(token))
|
||||
return hex.EncodeToString(sum[:8])
|
||||
}
|
||||
|
||||
func parseSinceCursor(r *http.Request) uint64 {
|
||||
raw := strings.TrimSpace(r.URL.Query().Get("since"))
|
||||
if raw == "" {
|
||||
return 0
|
||||
}
|
||||
v, err := strconv.ParseUint(raw, 10, 64)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func (h *Hub) runHeartbeat(ctx context.Context, c *conn, tracker *lastPongTracker) {
|
||||
ticker := time.NewTicker(heartbeatInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
if isHeartbeatTimedOut(tracker) {
|
||||
_ = c.ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
|
||||
return
|
||||
}
|
||||
if err := h.writeConn(ctx, c, Event{Type: TypeWSPing, Payload: map[string]any{}}); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type lastPongTracker struct {
|
||||
mu sync.RWMutex
|
||||
ts time.Time
|
||||
}
|
||||
|
||||
func newLastPongTracker() *lastPongTracker {
|
||||
return &lastPongTracker{ts: time.Now()}
|
||||
}
|
||||
|
||||
func (t *lastPongTracker) touch() {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.ts = time.Now()
|
||||
}
|
||||
|
||||
func (t *lastPongTracker) value() time.Time {
|
||||
t.mu.RLock()
|
||||
defer t.mu.RUnlock()
|
||||
return t.ts
|
||||
}
|
||||
|
||||
func isHeartbeatTimedOut(t *lastPongTracker) bool {
|
||||
return time.Since(t.value()) > heartbeatTimeout
|
||||
}
|
||||
|
||||
func (h *Hub) replaySince(c *conn, since uint64) error {
|
||||
if since == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
h.mu.RLock()
|
||||
conns := h.clients[userID]
|
||||
history := append([]Event(nil), h.history[c.userID]...)
|
||||
h.mu.RUnlock()
|
||||
|
||||
for c := range conns {
|
||||
if err := wsjson.Write(context.Background(), c.ws, event); err != nil {
|
||||
h.logger.Error("ws write", "error", err, "user_id", userID)
|
||||
go h.unregister(c)
|
||||
for _, event := range history {
|
||||
if event.Seq <= since {
|
||||
continue
|
||||
}
|
||||
if err := h.writeConn(context.Background(), c, event); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hub) register(c *conn) {
|
||||
func (h *Hub) writeConn(ctx context.Context, c *conn, event Event) error {
|
||||
c.writeMu.Lock()
|
||||
defer c.writeMu.Unlock()
|
||||
return wsjson.Write(ctx, c.ws, event)
|
||||
}
|
||||
|
||||
func (h *Hub) nextSequenceLocked() uint64 {
|
||||
h.nextSeq++
|
||||
return h.nextSeq
|
||||
}
|
||||
|
||||
func (h *Hub) appendHistoryLocked(userID string, event Event) {
|
||||
history := append(h.history[userID], event)
|
||||
if len(history) > h.replayBufferSizePerUser {
|
||||
history = history[len(history)-h.replayBufferSizePerUser:]
|
||||
}
|
||||
h.history[userID] = history
|
||||
}
|
||||
|
||||
func (h *Hub) snapshotConnectionsLocked(userID string) []*conn {
|
||||
conns := h.clients[userID]
|
||||
if len(conns) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]*conn, 0, len(conns))
|
||||
for c := range conns {
|
||||
out = append(out, c)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (h *Hub) canRegisterLocked(c *conn) error {
|
||||
if len(h.clients[c.userID]) >= h.maxConnectionsPerUser {
|
||||
return errors.New("too many websocket connections for user")
|
||||
}
|
||||
if h.sessionCounts[c.userID] == nil {
|
||||
return nil
|
||||
}
|
||||
if h.sessionCounts[c.userID][c.sessionKey] >= h.maxConnectionsPerSess {
|
||||
return errors.New("too many websocket connections for session")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hub) register(c *conn) error {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if h.clients[c.userID] == nil {
|
||||
h.clients[c.userID] = make(map[*conn]struct{})
|
||||
}
|
||||
if err := h.canRegisterLocked(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h.clients[c.userID][c] = struct{}{}
|
||||
h.logger.Info("ws connected", "user_id", c.userID)
|
||||
if h.sessionCounts[c.userID] == nil {
|
||||
h.sessionCounts[c.userID] = make(map[string]int)
|
||||
}
|
||||
h.sessionCounts[c.userID][c.sessionKey]++
|
||||
|
||||
h.logger.Info("ws connected",
|
||||
"user_id", c.userID,
|
||||
"session", c.sessionKey,
|
||||
"user_conns", len(h.clients[c.userID]),
|
||||
"session_conns", h.sessionCounts[c.userID][c.sessionKey],
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *Hub) unregister(c *conn) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
|
||||
if conns, ok := h.clients[c.userID]; ok {
|
||||
delete(conns, c)
|
||||
if len(conns) == 0 {
|
||||
delete(h.clients, c.userID)
|
||||
}
|
||||
}
|
||||
c.ws.Close(websocket.StatusNormalClosure, "")
|
||||
|
||||
if sessions := h.sessionCounts[c.userID]; sessions != nil {
|
||||
next := sessions[c.sessionKey] - 1
|
||||
if next <= 0 {
|
||||
delete(sessions, c.sessionKey)
|
||||
} else {
|
||||
sessions[c.sessionKey] = next
|
||||
}
|
||||
if len(sessions) == 0 {
|
||||
delete(h.sessionCounts, c.userID)
|
||||
}
|
||||
}
|
||||
|
||||
_ = c.ws.Close(websocket.StatusNormalClosure, "")
|
||||
}
|
||||
|
||||
func (h *Hub) Broadcast(userID string, event Event) {
|
||||
h.mu.Lock()
|
||||
event.Seq = h.nextSequenceLocked()
|
||||
h.appendHistoryLocked(userID, event)
|
||||
conns := h.snapshotConnectionsLocked(userID)
|
||||
h.mu.Unlock()
|
||||
|
||||
for _, c := range conns {
|
||||
if err := h.writeConn(context.Background(), c, event); err != nil {
|
||||
h.logger.Error("ws write", "error", err, "user_id", userID)
|
||||
go h.unregister(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Hub) HistoryHead(userID string) uint64 {
|
||||
h.mu.RLock()
|
||||
defer h.mu.RUnlock()
|
||||
history := h.history[userID]
|
||||
if len(history) == 0 {
|
||||
return 0
|
||||
}
|
||||
return history[len(history)-1].Seq
|
||||
}
|
||||
|
||||
func (h *Hub) SetLimits(maxUser, maxSession, replaySize int) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if maxUser > 0 {
|
||||
h.maxConnectionsPerUser = maxUser
|
||||
}
|
||||
if maxSession > 0 {
|
||||
h.maxConnectionsPerSess = maxSession
|
||||
}
|
||||
if replaySize > 0 {
|
||||
h.replayBufferSizePerUser = replaySize
|
||||
}
|
||||
}
|
||||
|
||||
98
internal/realtime/hub_test.go
Normal file
98
internal/realtime/hub_test.go
Normal file
@ -0,0 +1,98 @@
|
||||
package realtime
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExtractToken(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/ws", nil)
|
||||
req.Header.Set("Authorization", "Bearer abc")
|
||||
token, ok := extractToken(req)
|
||||
if !ok || token != "abc" {
|
||||
t.Fatalf("extractToken auth header = (%q, %v), want (abc, true)", token, ok)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest("GET", "/ws?token=q1", nil)
|
||||
token, ok = extractToken(req)
|
||||
if !ok || token != "q1" {
|
||||
t.Fatalf("extractToken token query = (%q, %v), want (q1, true)", token, ok)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest("GET", "/ws?access_token=q2", nil)
|
||||
token, ok = extractToken(req)
|
||||
if !ok || token != "q2" {
|
||||
t.Fatalf("extractToken access_token query = (%q, %v), want (q2, true)", token, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSinceCursor(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/ws?since=42", nil)
|
||||
if got := parseSinceCursor(req); got != 42 {
|
||||
t.Fatalf("parseSinceCursor() = %d, want 42", got)
|
||||
}
|
||||
|
||||
req = httptest.NewRequest("GET", "/ws?since=bad", nil)
|
||||
if got := parseSinceCursor(req); got != 0 {
|
||||
t.Fatalf("parseSinceCursor() invalid = %d, want 0", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTokenSessionKeyStable(t *testing.T) {
|
||||
a := tokenSessionKey("same-token")
|
||||
b := tokenSessionKey("same-token")
|
||||
c := tokenSessionKey("other-token")
|
||||
if a != b {
|
||||
t.Fatalf("session keys for same token differ: %q != %q", a, b)
|
||||
}
|
||||
if a == c {
|
||||
t.Fatalf("session keys for different tokens should differ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterRespectsLimits(t *testing.T) {
|
||||
h := NewHub(nil, nil)
|
||||
h.SetLimits(2, 1, 10)
|
||||
|
||||
c1 := &conn{userID: "u1", sessionKey: "s1"}
|
||||
c2 := &conn{userID: "u1", sessionKey: "s2"}
|
||||
c3 := &conn{userID: "u1", sessionKey: "s3"}
|
||||
|
||||
if err := h.register(c1); err != nil {
|
||||
t.Fatalf("register c1 error = %v", err)
|
||||
}
|
||||
if err := h.register(c2); err != nil {
|
||||
t.Fatalf("register c2 error = %v", err)
|
||||
}
|
||||
if err := h.register(c3); err == nil {
|
||||
t.Fatalf("register c3 expected user limit error")
|
||||
}
|
||||
|
||||
h2 := NewHub(nil, nil)
|
||||
h2.SetLimits(5, 1, 10)
|
||||
if err := h2.register(&conn{userID: "u1", sessionKey: "s1"}); err != nil {
|
||||
t.Fatalf("register first session conn error = %v", err)
|
||||
}
|
||||
if err := h2.register(&conn{userID: "u1", sessionKey: "s1"}); err == nil {
|
||||
t.Fatalf("register second session conn expected session limit error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBroadcastAssignsSequenceAndKeepsReplayBuffer(t *testing.T) {
|
||||
h := NewHub(nil, nil)
|
||||
h.SetLimits(10, 10, 2)
|
||||
|
||||
h.Broadcast("u1", NewMailCreatedEvent("m1", "a1"))
|
||||
h.Broadcast("u1", NewMailUpdatedEvent("m2", "a1"))
|
||||
h.Broadcast("u1", NewMailDeletedEvent("m3", "a1"))
|
||||
|
||||
if head := h.HistoryHead("u1"); head != 3 {
|
||||
t.Fatalf("HistoryHead() = %d, want 3", head)
|
||||
}
|
||||
if len(h.history["u1"]) != 2 {
|
||||
t.Fatalf("history length = %d, want 2", len(h.history["u1"]))
|
||||
}
|
||||
if h.history["u1"][0].Seq != 2 || h.history["u1"][1].Seq != 3 {
|
||||
t.Fatalf("history seqs = [%d, %d], want [2, 3]", h.history["u1"][0].Seq, h.history["u1"][1].Seq)
|
||||
}
|
||||
}
|
||||
@ -120,11 +120,11 @@ Objectif: transformer état actuel (partiellement implémenté) vers produit fon
|
||||
|
||||
### 2.5 Realtime (`/ws`)
|
||||
|
||||
- [ ] Remplacer `user_id` query param non sûr par auth token WS.
|
||||
- [ ] Ajouter événements typés (mail.created, mail.updated, outbox.updated, contact.updated...).
|
||||
- [ ] Ajouter heartbeat/ping/pong.
|
||||
- [ ] Gérer reconnexion client + rattrapage delta.
|
||||
- [ ] Limiter connexions par user/session.
|
||||
- [x] Remplacer `user_id` query param non sûr par auth token WS.
|
||||
- [x] Ajouter événements typés (mail.created, mail.updated, outbox.updated, contact.updated...).
|
||||
- [x] Ajouter heartbeat/ping/pong.
|
||||
- [x] Gérer reconnexion client + rattrapage delta.
|
||||
- [x] Limiter connexions par user/session.
|
||||
|
||||
### 2.6 Search
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user