ultisuite-backend/internal/realtime/hub.go
R3D347HR4Y a2e17c5b6c Enhance WebSocket hub with authentication and event handling improvements
- Updated the WebSocket hub to replace the insecure `user_id` query parameter with an authentication token for secure connections.
- Introduced typed events for mail operations (created, updated, deleted) to streamline event handling.
- Implemented heartbeat functionality (ping/pong) to maintain connection health.
- Enhanced client reconnection logic and delta replay for improved user experience.
- Added limits on connections per user/session to prevent abuse and ensure stability.
2026-05-22 18:09:02 +02:00

405 lines
9.4 KiB
Go

package realtime
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"log/slog"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/users"
)
// Event is the JSON envelope exchanged over the realtime WebSocket.
//
// Type names the event kind and Payload carries its body. Seq is the
// hub-assigned sequence number stamped by Broadcast; it is omitted from the
// JSON when zero. replaySince compares Seq against the `since` query
// parameter to decide which buffered events a reconnecting client receives.
type Event struct {
	Type    string `json:"type"`
	Payload any    `json:"payload"`
	Seq     uint64 `json:"seq,omitempty"`
}
// Hub tracks live WebSocket connections per user and enforces per-user and
// per-session connection limits. It also keeps a bounded per-user history of
// broadcast events so reconnecting clients can replay what they missed.
//
// The maps, nextSeq and the limit fields are guarded by mu.
type Hub struct {
	mu sync.RWMutex

	// clients maps userID -> set of active connections.
	clients map[string]map[*conn]struct{}
	// sessionCounts maps userID -> session fingerprint -> open connection count.
	sessionCounts map[string]map[string]int
	// history maps userID -> most recent broadcast events, in Seq order.
	history map[string][]Event
	// nextSeq is the last sequence number handed out; shared across all users.
	nextSeq uint64

	maxConnectionsPerUser   int
	maxConnectionsPerSess   int
	replayBufferSizePerUser int

	logger   *slog.Logger
	verifier *auth.Verifier
	db       *pgxpool.Pool
}
// conn is one accepted WebSocket connection plus the identity it was
// authenticated as.
type conn struct {
	ws         *websocket.Conn
	userID     string
	sessionKey string // tokenSessionKey of the bearer token used to connect
	// writeMu serialises writes: heartbeat pings, pong replies, replay and
	// broadcasts can all target this socket from different goroutines.
	writeMu sync.Mutex
}
// Default hub limits and heartbeat timing.
const (
	// Connection caps enforced by canRegisterLocked; overridable via SetLimits.
	defaultMaxConnectionsPerUser = 10
	defaultMaxConnectionsPerSess = 3

	// Events retained per user for replay; overridable via SetLimits.
	defaultReplayBufferSizePerUser = 500

	// Interval between server-sent application-level pings.
	heartbeatInterval = 25 * time.Second
	// A connection is closed once no ping/pong has been seen for this long.
	heartbeatTimeout = 90 * time.Second
)
// Sentinel errors returned by authenticate and mapped to HTTP status codes by
// writeAuthError. Their messages are written to the client verbatim.
var (
	// A server-side dependency is not configured (503).
	errVerifierUnavailable  = errors.New("authentication unavailable")
	errUserStoreUnavailable = errors.New("user store unavailable")

	// The caller did not present a usable credential (401).
	errMissingToken = errors.New("missing token")
	errInvalidToken = errors.New("invalid token")

	// The token verified but the local user record could not be ensured (500).
	errUserProvisioning = errors.New("failed to provision user")
)
// NewHub builds a Hub with the package default limits and empty state.
//
// verifier and db may be nil. In that case authenticate reports the missing
// dependency and HandleWS answers every handshake with 503.
func NewHub(verifier *auth.Verifier, db *pgxpool.Pool) *Hub {
	h := &Hub{
		verifier: verifier,
		db:       db,
		logger:   slog.Default().With("component", "ws-hub"),

		clients:       make(map[string]map[*conn]struct{}),
		sessionCounts: make(map[string]map[string]int),
		history:       make(map[string][]Event),

		maxConnectionsPerUser:   defaultMaxConnectionsPerUser,
		maxConnectionsPerSess:   defaultMaxConnectionsPerSess,
		replayBufferSizePerUser: defaultReplayBufferSizePerUser,
	}
	return h
}
// HandleWS upgrades an authenticated HTTP request to a WebSocket and serves
// it until the client disconnects, a write fails, or the heartbeat times out.
//
// Sequence:
//  1. Authenticate. The token comes from the "Authorization: Bearer" header or
//     the `token` / `access_token` query parameters. Failures are answered as
//     plain HTTP errors via writeAuthError.
//  2. Accept the upgrade.
//  3. Register against the per-user and per-session limits. Limits are checked
//     after the upgrade, so a rejected client gets a WebSocket close
//     (StatusPolicyViolation) rather than an HTTP status.
//  4. Replay buffered events whose Seq is greater than the `since` query
//     parameter.
//  5. Start the server ping loop and read client events. Only TypeWSPing and
//     TypeWSPong are acted on; any other event type is read and ignored.
//
// Every exit after a successful register runs the deferred unregister.
//
// NOTE(review): the connection is registered, and so starts receiving live
// Broadcasts, before replaySince snapshots the history. As a result:
//   - An event broadcast in that window can be delivered twice.
//   - Live events may be interleaved with replayed ones.
//
// Clients presumably need to de-duplicate and order by Seq; confirm on the
// client side.
func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
	userID, token, err := h.authenticate(r)
	if err != nil {
		h.writeAuthError(w, err)
		return
	}
	// Limits are keyed by a hash of the token, so the raw token is not kept on conn.
	sessionKey := tokenSessionKey(token)
	since := parseSinceCursor(r)
	// NOTE(review): "*" disables the library's Origin check. Credentials come
	// from an explicit token rather than ambient cookies, which limits
	// cross-site misuse, but consider restricting this to known frontend origins.
	ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	})
	if err != nil {
		h.logger.Error("websocket accept", "error", err)
		return
	}
	c := &conn{ws: ws, userID: userID, sessionKey: sessionKey}
	if err := h.register(c); err != nil {
		h.logger.Warn("ws rejected", "user_id", userID, "session", sessionKey, "error", err)
		// The limit error text is sent to the client as the close reason.
		_ = ws.Close(websocket.StatusPolicyViolation, err.Error())
		return
	}
	defer h.unregister(c)
	if err := h.replaySince(c, since); err != nil {
		h.logger.Warn("ws replay failed", "user_id", userID, "error", err)
		return
	}
	// The heartbeat goroutine exits when ctx is done. net/http cancels the
	// request context once this handler returns.
	ctx := r.Context()
	lastPong := newLastPongTracker()
	go h.runHeartbeat(ctx, c, lastPong)
	for {
		var incoming Event
		if err := wsjson.Read(ctx, ws, &incoming); err != nil {
			// A read error (client gone, socket closed by the heartbeat, ctx done) ends the session.
			break
		}
		switch incoming.Type {
		case TypeWSPong:
			lastPong.touch()
		case TypeWSPing:
			// Answer client-initiated pings; a client ping also counts as liveness.
			if err := h.writeConn(ctx, c, Event{Type: TypeWSPong, Payload: map[string]any{}}); err != nil {
				return
			}
			lastPong.touch()
		}
		// Also enforce the timeout here. A client that keeps sending other
		// events but never pings or pongs is still closed.
		if isHeartbeatTimedOut(lastPong) {
			_ = ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
			break
		}
	}
}
// authenticate resolves the caller of a WebSocket handshake to a local user.
//
// It returns the user ID and the raw bearer token; HandleWS derives the
// session key from the token. On failure it returns one of the package
// sentinel errors, which writeAuthError maps to an HTTP status.
//
// The sentinels deliberately hide the underlying cause from the client. That
// cause (verifier or user-store error) is logged here rather than discarded,
// so provisioning failures are diagnosable.
func (h *Hub) authenticate(r *http.Request) (string, string, error) {
	if h.verifier == nil {
		return "", "", errVerifierUnavailable
	}
	if h.db == nil {
		return "", "", errUserStoreUnavailable
	}
	token, ok := extractToken(r)
	if !ok {
		return "", "", errMissingToken
	}
	ctx := r.Context()
	claims, err := h.verifier.Verify(ctx, token)
	if err != nil {
		// Rejected tokens are client errors; log at debug to avoid noise.
		h.logger.Debug("ws token rejected", "error", err)
		return "", "", errInvalidToken
	}
	userID, err := users.EnsureUser(ctx, h.db, claims)
	if err != nil {
		h.logger.Error("ws user provisioning", "error", err)
		return "", "", errUserProvisioning
	}
	return userID, token, nil
}
// writeAuthError translates an authenticate failure into an HTTP response.
//
// Status mapping:
//   - 503 when a server-side dependency (verifier, user store) is missing.
//   - 401 for a missing or rejected token.
//   - 500 for anything else, e.g. user provisioning.
//
// The error text is written as the response body. Matching uses errors.Is, so
// the mapping still holds if the sentinels are ever wrapped with more context.
func (h *Hub) writeAuthError(w http.ResponseWriter, err error) {
	status := http.StatusInternalServerError
	switch {
	case errors.Is(err, errVerifierUnavailable), errors.Is(err, errUserStoreUnavailable):
		status = http.StatusServiceUnavailable
	case errors.Is(err, errMissingToken), errors.Is(err, errInvalidToken):
		status = http.StatusUnauthorized
	}
	http.Error(w, err.Error(), status)
}
// extractToken returns the bearer token presented by the client, if any.
//
// A present Authorization header is authoritative. It must use the Bearer
// scheme, matched case-insensitively as RFC 7235 §2.1 requires, and carry a
// non-empty credential. Otherwise no token is returned and the query string is
// not consulted.
//
// Without the header, the `token` query parameter is tried, then
// `access_token`. This is presumably for browser WebSocket clients, which
// cannot set custom headers. Surrounding whitespace is trimmed in all cases.
func extractToken(r *http.Request) (string, bool) {
	if authz := strings.TrimSpace(r.Header.Get("Authorization")); authz != "" {
		scheme, credential, found := strings.Cut(authz, " ")
		if !found || !strings.EqualFold(scheme, "Bearer") {
			return "", false
		}
		if token := strings.TrimSpace(credential); token != "" {
			return token, true
		}
		return "", false
	}
	// Parse the query once rather than per key.
	query := r.URL.Query()
	for _, key := range []string{"token", "access_token"} {
		if token := strings.TrimSpace(query.Get(key)); token != "" {
			return token, true
		}
	}
	return "", false
}
// tokenSessionKey derives a short, non-reversible session fingerprint from a
// bearer token: the first 8 bytes of its SHA-256 digest, hex-encoded to 16
// characters. The hub counts and logs connections per session with this
// value instead of the token itself.
func tokenSessionKey(token string) string {
	const fingerprintBytes = 8
	digest := sha256.Sum256([]byte(token))
	prefix := digest[:fingerprintBytes]
	return hex.EncodeToString(prefix)
}
// parseSinceCursor reads the `since` query parameter: the last event Seq the
// client has already seen. It returns 0, meaning "replay nothing", when the
// parameter is absent, blank, or not a valid unsigned 64-bit integer.
func parseSinceCursor(r *http.Request) uint64 {
	if raw := strings.TrimSpace(r.URL.Query().Get("since")); raw != "" {
		if cursor, err := strconv.ParseUint(raw, 10, 64); err == nil {
			return cursor
		}
	}
	return 0
}
// runHeartbeat sends an application-level ping event to c every
// heartbeatInterval.
//
// On each tick it first checks the tracker. If nothing has been recorded for
// longer than heartbeatTimeout, it closes the socket with
// StatusPolicyViolation. It returns when ctx is done, after closing a
// timed-out connection, or when a ping cannot be written.
func (h *Hub) runHeartbeat(ctx context.Context, c *conn, tracker *lastPongTracker) {
	pinger := time.NewTicker(heartbeatInterval)
	defer pinger.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-pinger.C:
		}

		if isHeartbeatTimedOut(tracker) {
			_ = c.ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
			return
		}
		ping := Event{Type: TypeWSPing, Payload: map[string]any{}}
		if err := h.writeConn(ctx, c, ping); err != nil {
			return
		}
	}
}
// lastPongTracker records when a connection last showed a sign of life. The
// read loop writes it and the heartbeat goroutine reads it, so ts is guarded
// by mu.
type lastPongTracker struct {
	mu sync.RWMutex
	ts time.Time
}

// newLastPongTracker returns a tracker stamped with the current time, so a
// fresh connection starts with a full timeout window.
func newLastPongTracker() *lastPongTracker {
	tracker := new(lastPongTracker)
	tracker.ts = time.Now()
	return tracker
}

// touch records the current time as the latest sign of life.
func (t *lastPongTracker) touch() {
	t.mu.Lock()
	t.ts = time.Now()
	t.mu.Unlock()
}

// value returns the most recently recorded time.
func (t *lastPongTracker) value() time.Time {
	t.mu.RLock()
	stamp := t.ts
	t.mu.RUnlock()
	return stamp
}
// isHeartbeatTimedOut reports whether more than heartbeatTimeout has elapsed
// since the tracker was last touched (or created).
func isHeartbeatTimedOut(t *lastPongTracker) bool {
	elapsed := time.Since(t.value())
	return elapsed > heartbeatTimeout
}
// replaySince writes to c, in history order, every buffered event for c's
// user whose Seq is greater than since.
//
// A zero cursor replays nothing. Matching events are collected under the read
// lock and written after it is released, so socket I/O never holds h.mu.
// It returns the first write error.
//
// NOTE(review): events that have already aged out of the per-user buffer are
// skipped silently. The client is not told that its cursor is older than the
// retained history.
func (h *Hub) replaySince(c *conn, since uint64) error {
	if since == 0 {
		return nil
	}

	h.mu.RLock()
	var pending []Event
	for _, ev := range h.history[c.userID] {
		if ev.Seq > since {
			pending = append(pending, ev)
		}
	}
	h.mu.RUnlock()

	for _, ev := range pending {
		if err := h.writeConn(context.Background(), c, ev); err != nil {
			return err
		}
	}
	return nil
}
// writeConn JSON-encodes event onto c's socket.
//
// Writes are serialised with c.writeMu because pings, pong replies, replay
// and broadcasts can target the same connection from different goroutines.
//
// Each write is bounded by writeTimeout in addition to ctx; if ctx has an
// earlier deadline, that one still wins. This matters because replaySince and
// Broadcast pass context.Background(). Without a deadline, one stalled client
// would block its writer, and every other writer queued on writeMu,
// indefinitely.
func (h *Hub) writeConn(ctx context.Context, c *conn, event Event) error {
	const writeTimeout = 10 * time.Second

	c.writeMu.Lock()
	defer c.writeMu.Unlock()

	// Start the deadline after acquiring the lock so each write gets the full budget.
	ctx, cancel := context.WithTimeout(ctx, writeTimeout)
	defer cancel()
	return wsjson.Write(ctx, c.ws, event)
}
// nextSequenceLocked advances and returns the hub-wide event sequence number.
// Starting from a zero hub, the first value is 1, so 0 stays free to mean "no
// cursor" in replaySince. Callers must hold h.mu for writing.
func (h *Hub) nextSequenceLocked() uint64 {
	seq := h.nextSeq + 1
	h.nextSeq = seq
	return seq
}
// appendHistoryLocked appends event to userID's replay buffer and keeps only
// the newest replayBufferSizePerUser entries. Callers must hold h.mu for
// writing.
//
// Trimming reslices from the front, so dropped entries stay in the backing
// array until the next reallocation. They are zeroed first so their payloads
// are not kept reachable by the buffer.
func (h *Hub) appendHistoryLocked(userID string, event Event) {
	history := append(h.history[userID], event)
	if excess := len(history) - h.replayBufferSizePerUser; excess > 0 {
		clear(history[:excess])
		history = history[excess:]
	}
	h.history[userID] = history
}
// snapshotConnectionsLocked copies userID's active connections into a slice
// so they can be written to after h.mu is released. It returns nil when the
// user has no connections. Callers must hold h.mu.
func (h *Hub) snapshotConnectionsLocked(userID string) []*conn {
	set := h.clients[userID]
	if len(set) == 0 {
		return nil
	}
	snapshot := make([]*conn, len(set))
	i := 0
	for c := range set {
		snapshot[i] = c
		i++
	}
	return snapshot
}
// canRegisterLocked returns an error if admitting c would exceed the per-user
// limit, or the per-session limit for c.sessionKey; otherwise it returns nil.
// The user limit is checked first. Callers must hold h.mu.
func (h *Hub) canRegisterLocked(c *conn) error {
	switch sessions := h.sessionCounts[c.userID]; {
	case len(h.clients[c.userID]) >= h.maxConnectionsPerUser:
		return errors.New("too many websocket connections for user")
	case sessions != nil && sessions[c.sessionKey] >= h.maxConnectionsPerSess:
		return errors.New("too many websocket connections for session")
	default:
		return nil
	}
}
// register admits c into the hub if the connection limits allow it.
//
// On success it adds c to the user's connection set, increments the session
// counter, and logs the resulting totals. Otherwise it returns the limit error
// from canRegisterLocked.
func (h *Hub) register(c *conn) error {
	h.mu.Lock()
	defer h.mu.Unlock()

	userConns := h.clients[c.userID]
	if userConns == nil {
		userConns = make(map[*conn]struct{})
		h.clients[c.userID] = userConns
	}
	if err := h.canRegisterLocked(c); err != nil {
		return err
	}
	userConns[c] = struct{}{}

	sessions := h.sessionCounts[c.userID]
	if sessions == nil {
		sessions = make(map[string]int)
		h.sessionCounts[c.userID] = sessions
	}
	sessions[c.sessionKey]++

	h.logger.Info("ws connected",
		"user_id", c.userID,
		"session", c.sessionKey,
		"user_conns", len(userConns),
		"session_conns", sessions[c.sessionKey],
	)
	return nil
}
// unregister removes c from the hub and closes its socket.
//
// It is safe to call more than once for the same connection. HandleWS defers
// it, and Broadcast also starts it after a failed write. Counters are only
// decremented when c is actually found in the connection set, so a repeated
// call cannot under-count another connection of the same session.
//
// The socket is closed after h.mu is released: a WebSocket close involves
// network I/O and must not stall every other hub operation behind the global
// lock.
func (h *Hub) unregister(c *conn) {
	h.mu.Lock()
	if conns, ok := h.clients[c.userID]; ok {
		if _, present := conns[c]; present {
			delete(conns, c)
			if sessions := h.sessionCounts[c.userID]; sessions != nil {
				if next := sessions[c.sessionKey] - 1; next <= 0 {
					delete(sessions, c.sessionKey)
				} else {
					sessions[c.sessionKey] = next
				}
				if len(sessions) == 0 {
					delete(h.sessionCounts, c.userID)
				}
			}
		}
		if len(conns) == 0 {
			delete(h.clients, c.userID)
		}
	}
	h.mu.Unlock()

	// Close errors are expected when the peer or another caller already closed it.
	_ = c.ws.Close(websocket.StatusNormalClosure, "")
}
// Broadcast stamps event with the next hub-wide sequence number and records it
// in userID's replay history, even when the user has no live connections. It
// then writes the event to each of the user's current connections, in turn.
// A connection whose write fails is logged and unregistered asynchronously.
func (h *Hub) Broadcast(userID string, event Event) {
	h.mu.Lock()
	event.Seq = h.nextSequenceLocked()
	h.appendHistoryLocked(userID, event)
	targets := h.snapshotConnectionsLocked(userID)
	h.mu.Unlock()

	// Deliver outside the lock so socket I/O never happens while h.mu is held.
	for _, target := range targets {
		err := h.writeConn(context.Background(), target, event)
		if err == nil {
			continue
		}
		h.logger.Error("ws write", "error", err, "user_id", userID)
		go h.unregister(target)
	}
}
// HistoryHead returns the Seq of the newest event buffered for userID, or 0
// when nothing is buffered. Presumably callers hand this to clients as a
// starting `since` cursor; confirm against callers.
func (h *Hub) HistoryHead(userID string) uint64 {
	h.mu.RLock()
	defer h.mu.RUnlock()

	if events := h.history[userID]; len(events) > 0 {
		return events[len(events)-1].Seq
	}
	return 0
}
// SetLimits overrides the per-user connection cap, the per-session connection
// cap, and the per-user replay buffer size. A non-positive argument leaves the
// corresponding limit unchanged.
//
// New connection caps apply to later registrations only; existing connections
// are not dropped. A smaller replay size takes effect the next time a user's
// history is appended to.
func (h *Hub) SetLimits(maxUser, maxSession, replaySize int) {
	h.mu.Lock()
	defer h.mu.Unlock()

	for _, limit := range []struct {
		dst *int
		val int
	}{
		{&h.maxConnectionsPerUser, maxUser},
		{&h.maxConnectionsPerSess, maxSession},
		{&h.replayBufferSizePerUser, replaySize},
	} {
		if limit.val > 0 {
			*limit.dst = limit.val
		}
	}
}