package realtime import ( "context" "crypto/sha256" "encoding/hex" "errors" "log/slog" "net/http" "strconv" "strings" "sync" "time" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/users" ) type Event struct { Type string `json:"type"` Payload any `json:"payload"` Seq uint64 `json:"seq,omitempty"` } type Hub struct { mu sync.RWMutex clients map[string]map[*conn]struct{} // userID -> active connections sessionCounts map[string]map[string]int // userID -> session fingerprint -> count history map[string][]Event // userID -> recent ordered events nextSeq uint64 maxConnectionsPerUser int maxConnectionsPerSess int replayBufferSizePerUser int logger *slog.Logger verifier *auth.Holder db *pgxpool.Pool } type conn struct { ws *websocket.Conn userID string sessionKey string writeMu sync.Mutex } const ( defaultMaxConnectionsPerUser = 10 defaultMaxConnectionsPerSess = 3 defaultReplayBufferSizePerUser = 500 heartbeatInterval = 25 * time.Second heartbeatTimeout = 90 * time.Second ) var ( errVerifierUnavailable = errors.New("authentication unavailable") errUserStoreUnavailable = errors.New("user store unavailable") errMissingToken = errors.New("missing token") errInvalidToken = errors.New("invalid token") errUserProvisioning = errors.New("failed to provision user") ) func NewHub(verifier *auth.Holder, db *pgxpool.Pool) *Hub { return &Hub{ clients: make(map[string]map[*conn]struct{}), sessionCounts: make(map[string]map[string]int), history: make(map[string][]Event), maxConnectionsPerUser: defaultMaxConnectionsPerUser, maxConnectionsPerSess: defaultMaxConnectionsPerSess, replayBufferSizePerUser: defaultReplayBufferSizePerUser, logger: slog.Default().With("component", "ws-hub"), verifier: verifier, db: db, } } func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) { userID, token, err := h.authenticate(r) if err != nil { h.writeAuthError(w, err) return } sessionKey := tokenSessionKey(token) since := parseSinceCursor(r) ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{ OriginPatterns: []string{"*"}, }) if err != nil { h.logger.Error("websocket accept", "error", err) return } c := &conn{ws: ws, userID: userID, sessionKey: sessionKey} if err := h.register(c); err != nil { h.logger.Warn("ws rejected", "user_id", userID, "session", sessionKey, "error", err) _ = ws.Close(websocket.StatusPolicyViolation, err.Error()) return } defer h.unregister(c) if err := h.replaySince(c, since); err != nil { h.logger.Warn("ws replay failed", "user_id", userID, "error", err) return } ctx := r.Context() lastPong := newLastPongTracker() go h.runHeartbeat(ctx, c, lastPong) for { var incoming Event if err := wsjson.Read(ctx, ws, &incoming); err != nil { break } switch incoming.Type { case TypeWSPong: lastPong.touch() case TypeWSPing: if err := h.writeConn(ctx, c, Event{Type: TypeWSPong, Payload: map[string]any{}}); err != nil { return } lastPong.touch() } if isHeartbeatTimedOut(lastPong) { _ = ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout") break } } } func (h *Hub) authenticate(r *http.Request) (string, string, error) { if h.verifier == nil || !h.verifier.Ready() { return "", "", errVerifierUnavailable } if h.db == nil { return "", "", errUserStoreUnavailable } token, ok := extractToken(r) if !ok { return "", "", errMissingToken } claims, err := h.verifier.Verify(r.Context(), token) if err != nil { return "", "", errInvalidToken } userID, err := users.EnsureUser(r.Context(), h.db, claims) if err != nil { return "", "", errUserProvisioning } return userID, token, nil } func (h *Hub) writeAuthError(w http.ResponseWriter, err error) { switch err { case errVerifierUnavailable, errUserStoreUnavailable: http.Error(w, err.Error(), http.StatusServiceUnavailable) case errMissingToken, errInvalidToken: http.Error(w, err.Error(), http.StatusUnauthorized) default: http.Error(w, err.Error(), http.StatusInternalServerError) } } func extractToken(r *http.Request) (string, bool) { if authz := strings.TrimSpace(r.Header.Get("Authorization")); authz != "" { if token, ok := strings.CutPrefix(authz, "Bearer "); ok && strings.TrimSpace(token) != "" { return strings.TrimSpace(token), true } return "", false } for _, key := range []string{"token", "access_token"} { if token := strings.TrimSpace(r.URL.Query().Get(key)); token != "" { return token, true } } return "", false } func tokenSessionKey(token string) string { sum := sha256.Sum256([]byte(token)) return hex.EncodeToString(sum[:8]) } func parseSinceCursor(r *http.Request) uint64 { raw := strings.TrimSpace(r.URL.Query().Get("since")) if raw == "" { return 0 } v, err := strconv.ParseUint(raw, 10, 64) if err != nil { return 0 } return v } func (h *Hub) runHeartbeat(ctx context.Context, c *conn, tracker *lastPongTracker) { ticker := time.NewTicker(heartbeatInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return case <-ticker.C: if isHeartbeatTimedOut(tracker) { _ = c.ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout") return } if err := h.writeConn(ctx, c, Event{Type: TypeWSPing, Payload: map[string]any{}}); err != nil { return } } } } type lastPongTracker struct { mu sync.RWMutex ts time.Time } func newLastPongTracker() *lastPongTracker { return &lastPongTracker{ts: time.Now()} } func (t *lastPongTracker) touch() { t.mu.Lock() defer t.mu.Unlock() t.ts = time.Now() } func (t *lastPongTracker) value() time.Time { t.mu.RLock() defer t.mu.RUnlock() return t.ts } func isHeartbeatTimedOut(t *lastPongTracker) bool { return time.Since(t.value()) > heartbeatTimeout } func (h *Hub) replaySince(c *conn, since uint64) error { if since == 0 { return nil } h.mu.RLock() history := append([]Event(nil), h.history[c.userID]...) h.mu.RUnlock() for _, event := range history { if event.Seq <= since { continue } if err := h.writeConn(context.Background(), c, event); err != nil { return err } } return nil } func (h *Hub) writeConn(ctx context.Context, c *conn, event Event) error { c.writeMu.Lock() defer c.writeMu.Unlock() return wsjson.Write(ctx, c.ws, event) } func (h *Hub) nextSequenceLocked() uint64 { h.nextSeq++ return h.nextSeq } func (h *Hub) appendHistoryLocked(userID string, event Event) { history := append(h.history[userID], event) if len(history) > h.replayBufferSizePerUser { history = history[len(history)-h.replayBufferSizePerUser:] } h.history[userID] = history } func (h *Hub) snapshotConnectionsLocked(userID string) []*conn { conns := h.clients[userID] if len(conns) == 0 { return nil } out := make([]*conn, 0, len(conns)) for c := range conns { out = append(out, c) } return out } func (h *Hub) canRegisterLocked(c *conn) error { if len(h.clients[c.userID]) >= h.maxConnectionsPerUser { return errors.New("too many websocket connections for user") } if h.sessionCounts[c.userID] == nil { return nil } if h.sessionCounts[c.userID][c.sessionKey] >= h.maxConnectionsPerSess { return errors.New("too many websocket connections for session") } return nil } func (h *Hub) register(c *conn) error { h.mu.Lock() defer h.mu.Unlock() if h.clients[c.userID] == nil { h.clients[c.userID] = make(map[*conn]struct{}) } if err := h.canRegisterLocked(c); err != nil { return err } h.clients[c.userID][c] = struct{}{} if h.sessionCounts[c.userID] == nil { h.sessionCounts[c.userID] = make(map[string]int) } h.sessionCounts[c.userID][c.sessionKey]++ h.logger.Info("ws connected", "user_id", c.userID, "session", c.sessionKey, "user_conns", len(h.clients[c.userID]), "session_conns", h.sessionCounts[c.userID][c.sessionKey], ) return nil } func (h *Hub) unregister(c *conn) { h.mu.Lock() defer h.mu.Unlock() if conns, ok := h.clients[c.userID]; ok { delete(conns, c) if len(conns) == 0 { delete(h.clients, c.userID) } } if sessions := h.sessionCounts[c.userID]; sessions != nil { next := sessions[c.sessionKey] - 1 if next <= 0 { delete(sessions, c.sessionKey) } else { sessions[c.sessionKey] = next } if len(sessions) == 0 { delete(h.sessionCounts, c.userID) } } _ = c.ws.Close(websocket.StatusNormalClosure, "") } func (h *Hub) Broadcast(userID string, event Event) { h.mu.Lock() event.Seq = h.nextSequenceLocked() h.appendHistoryLocked(userID, event) conns := h.snapshotConnectionsLocked(userID) h.mu.Unlock() for _, c := range conns { if err := h.writeConn(context.Background(), c, event); err != nil { h.logger.Error("ws write", "error", err, "user_id", userID) go h.unregister(c) } } } func (h *Hub) HistoryHead(userID string) uint64 { h.mu.RLock() defer h.mu.RUnlock() history := h.history[userID] if len(history) == 0 { return 0 } return history[len(history)-1].Seq } func (h *Hub) SetLimits(maxUser, maxSession, replaySize int) { h.mu.Lock() defer h.mu.Unlock() if maxUser > 0 { h.maxConnectionsPerUser = maxUser } if maxSession > 0 { h.maxConnectionsPerSess = maxSession } if replaySize > 0 { h.replayBufferSizePerUser = replaySize } }