- Introduced new endpoints for contact discovery, including scanning, listing, and managing discovered contacts.
- Implemented retry logic for handling missing DAV credentials during contact operations.
- Added public-share functionality to the drive API, allowing users to manage public shares, including upload, delete, and rename operations.
- Updated the Nextcloud configuration to support public share links and improved error handling for public-share permissions.
- Enhanced logging and validation across the contact and drive APIs for better error tracking and user feedback.
- Added tests for the new contact matching and ranking functionality to ensure accuracy and reliability.
405 lines
9.4 KiB
Go
405 lines
9.4 KiB
Go
package realtime
|
|
|
|
import (
|
|
"context"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"errors"
|
|
"log/slog"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/coder/websocket"
|
|
"github.com/coder/websocket/wsjson"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/auth"
|
|
"github.com/ultisuite/ulti-backend/internal/users"
|
|
)
|
|
|
|
// Event is the JSON envelope exchanged over the socket in both directions.
//
// Seq is assigned by Broadcast from a hub-wide counter and is omitted from the
// JSON when zero, which is the case for heartbeat frames and for messages
// decoded from the client.
type Event struct {
	Type    string `json:"type"`          // message kind, e.g. TypeWSPing / TypeWSPong
	Payload any    `json:"payload"`       // type-specific body; encoded as-is
	Seq     uint64 `json:"seq,omitempty"` // replay cursor; 0 means "not sequenced"
}
|
|
|
|
type Hub struct {
|
|
mu sync.RWMutex
|
|
|
|
clients map[string]map[*conn]struct{} // userID -> active connections
|
|
sessionCounts map[string]map[string]int // userID -> session fingerprint -> count
|
|
history map[string][]Event // userID -> recent ordered events
|
|
nextSeq uint64
|
|
|
|
maxConnectionsPerUser int
|
|
maxConnectionsPerSess int
|
|
replayBufferSizePerUser int
|
|
|
|
logger *slog.Logger
|
|
verifier *auth.Holder
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
type conn struct {
|
|
ws *websocket.Conn
|
|
userID string
|
|
sessionKey string
|
|
writeMu sync.Mutex
|
|
}
|
|
|
|
const (
|
|
defaultMaxConnectionsPerUser = 10
|
|
defaultMaxConnectionsPerSess = 3
|
|
defaultReplayBufferSizePerUser = 500
|
|
heartbeatInterval = 25 * time.Second
|
|
heartbeatTimeout = 90 * time.Second
|
|
)
|
|
|
|
var (
|
|
errVerifierUnavailable = errors.New("authentication unavailable")
|
|
errUserStoreUnavailable = errors.New("user store unavailable")
|
|
errMissingToken = errors.New("missing token")
|
|
errInvalidToken = errors.New("invalid token")
|
|
errUserProvisioning = errors.New("failed to provision user")
|
|
)
|
|
|
|
// NewHub builds a Hub with empty connection, session and history tables and
// the default limits (adjustable later with SetLimits).
//
// verifier authenticates incoming tokens and db backs user provisioning;
// requests are rejected with 503 while verifier is nil or not ready, or db is nil.
func NewHub(verifier *auth.Holder, db *pgxpool.Pool) *Hub {
	hub := &Hub{
		verifier: verifier,
		db:       db,
		logger:   slog.Default().With("component", "ws-hub"),
	}

	hub.clients = make(map[string]map[*conn]struct{})
	hub.sessionCounts = make(map[string]map[string]int)
	hub.history = make(map[string][]Event)

	hub.maxConnectionsPerUser = defaultMaxConnectionsPerUser
	hub.maxConnectionsPerSess = defaultMaxConnectionsPerSess
	hub.replayBufferSizePerUser = defaultReplayBufferSizePerUser

	return hub
}
|
|
|
|
// HandleWS authenticates r, upgrades it to a WebSocket and serves it until
// the client disconnects, a write fails, or the heartbeat expires.
//
// The optional "since" query parameter requests replay of buffered events
// whose Seq is greater than the given cursor before live delivery starts.
// Incoming TypeWSPong / TypeWSPing messages refresh the heartbeat; a client
// ping is answered with a pong. Other message types are read and ignored.
//
// NOTE(review): OriginPatterns accepts every origin. Auth is token-based rather
// than cookie-based, so this is not classic cross-site WebSocket hijacking,
// but confirm the wildcard is intentional.
//
// NOTE(review): replay runs after registration, so an event broadcast between
// register and replaySince's history snapshot can be delivered twice, and a
// live event can arrive before older replayed ones. Clients should
// de-duplicate and order by Seq.
func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
	userID, token, err := h.authenticate(r)
	if err != nil {
		h.writeAuthError(w, err)
		return
	}
	sessionKey := tokenSessionKey(token)
	since := parseSinceCursor(r)

	ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
		OriginPatterns: []string{"*"},
	})
	if err != nil {
		h.logger.Error("websocket accept", "error", err)
		return
	}

	c := &conn{ws: ws, userID: userID, sessionKey: sessionKey}
	if err := h.register(c); err != nil {
		h.logger.Warn("ws rejected", "user_id", userID, "session", sessionKey, "error", err)
		_ = ws.Close(websocket.StatusPolicyViolation, err.Error())
		return
	}
	defer h.unregister(c)

	if err := h.replaySince(c, since); err != nil {
		h.logger.Warn("ws replay failed", "user_id", userID, "error", err)
		return
	}

	// Scope the heartbeat goroutine explicitly to this handler. Deferred
	// calls run LIFO, so cancel fires before the deferred unregister and the
	// pinger is told to stop before the connection is released.
	ctx, cancel := context.WithCancel(r.Context())
	defer cancel()

	lastPong := newLastPongTracker()
	go h.runHeartbeat(ctx, c, lastPong)

	for {
		var incoming Event
		if err := wsjson.Read(ctx, ws, &incoming); err != nil {
			// Peer closed, connection closed elsewhere, or ctx cancelled.
			return
		}

		switch incoming.Type {
		case TypeWSPong:
			lastPong.touch()
		case TypeWSPing:
			if err := h.writeConn(ctx, c, Event{Type: TypeWSPong, Payload: map[string]any{}}); err != nil {
				return
			}
			// A client-initiated ping also proves liveness.
			lastPong.touch()
		}

		if isHeartbeatTimedOut(lastPong) {
			_ = ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
			return
		}
	}
}
|
|
|
|
// authenticate resolves the caller of r to a local user ID.
//
// On success it returns the user ID and the raw bearer token (the token is
// later hashed into the session key). On failure it returns one of the
// sentinels errVerifierUnavailable, errUserStoreUnavailable, errMissingToken,
// errInvalidToken or errUserProvisioning. Those sentinels' messages are sent
// to the client by writeAuthError, so underlying causes from the verifier and
// the user store are logged here instead of being returned.
func (h *Hub) authenticate(r *http.Request) (string, string, error) {
	if h.verifier == nil || !h.verifier.Ready() {
		return "", "", errVerifierUnavailable
	}
	if h.db == nil {
		return "", "", errUserStoreUnavailable
	}

	token, ok := extractToken(r)
	if !ok {
		return "", "", errMissingToken
	}

	ctx := r.Context()
	claims, err := h.verifier.Verify(ctx, token)
	if err != nil {
		// Debug level: rejected tokens are client errors and can be frequent.
		h.logger.Debug("ws token verification failed", "error", err)
		return "", "", errInvalidToken
	}

	userID, err := users.EnsureUser(ctx, h.db, claims)
	if err != nil {
		h.logger.Error("ws user provisioning failed", "error", err)
		return "", "", errUserProvisioning
	}
	return userID, token, nil
}
|
|
|
|
// writeAuthError writes the HTTP response for a failed authenticate call:
// 503 when a backing dependency is unavailable, 401 for a missing or invalid
// token, and 500 for anything else (including errUserProvisioning). The
// error's message is used as the response body.
//
// errors.Is is used rather than == so the mapping keeps working if these
// sentinels are ever wrapped with extra context.
func (h *Hub) writeAuthError(w http.ResponseWriter, err error) {
	status := http.StatusInternalServerError
	switch {
	case errors.Is(err, errVerifierUnavailable), errors.Is(err, errUserStoreUnavailable):
		status = http.StatusServiceUnavailable
	case errors.Is(err, errMissingToken), errors.Is(err, errInvalidToken):
		status = http.StatusUnauthorized
	}
	http.Error(w, err.Error(), status)
}
|
|
|
|
// extractToken returns the bearer token carried by r and whether one was found.
//
// If an Authorization header is present it is authoritative. It must use the
// Bearer scheme followed by a non-blank token; otherwise extraction fails and
// the query string is not consulted. The scheme name is matched
// case-insensitively, as RFC 7235 §2.1 requires for auth-scheme names.
// Without the header, the "token" and then "access_token" query parameters are
// tried (presumably because browser WebSocket clients cannot set headers).
//
// NOTE(review): query-string tokens can end up in proxy and access logs.
func extractToken(r *http.Request) (string, bool) {
	if authz := strings.TrimSpace(r.Header.Get("Authorization")); authz != "" {
		scheme, credentials, found := strings.Cut(authz, " ")
		token := strings.TrimSpace(credentials)
		if found && strings.EqualFold(scheme, "Bearer") && token != "" {
			return token, true
		}
		return "", false
	}

	// Parse the query once rather than once per candidate key.
	query := r.URL.Query()
	for _, key := range []string{"token", "access_token"} {
		if token := strings.TrimSpace(query.Get(key)); token != "" {
			return token, true
		}
	}
	return "", false
}
|
|
|
|
// tokenSessionKey derives a short fingerprint of token: the lowercase hex of
// the first 8 bytes (16 characters) of its SHA-256 digest. Connections that
// share a token share a key, which drives the per-session connection limit.
// The key is what gets logged, so the raw token never is.
func tokenSessionKey(token string) string {
	digest := sha256.Sum256([]byte(token))
	prefix := digest[:8]
	return hex.EncodeToString(prefix)
}
|
|
|
|
// parseSinceCursor reads the "since" query parameter as a base-10 uint64
// replay cursor. Surrounding whitespace is ignored; an absent, blank or
// unparsable value (including negatives and overflow) yields 0, meaning
// "replay nothing".
func parseSinceCursor(r *http.Request) uint64 {
	raw := r.URL.Query().Get("since")
	cursor, err := strconv.ParseUint(strings.TrimSpace(raw), 10, 64)
	if err != nil {
		return 0
	}
	return cursor
}
|
|
|
|
// runHeartbeat sends a TypeWSPing to c every heartbeatInterval until ctx is
// cancelled. Before each ping it checks tracker. If nothing has refreshed it
// for longer than heartbeatTimeout, the connection is closed with
// StatusPolicyViolation and the loop exits. The loop also exits when a ping
// cannot be written.
func (h *Hub) runHeartbeat(ctx context.Context, c *conn, tracker *lastPongTracker) {
	ticker := time.NewTicker(heartbeatInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}

		if isHeartbeatTimedOut(tracker) {
			_ = c.ws.Close(websocket.StatusPolicyViolation, "heartbeat timeout")
			return
		}

		ping := Event{Type: TypeWSPing, Payload: map[string]any{}}
		if h.writeConn(ctx, c, ping) != nil {
			return
		}
	}
}
|
|
|
|
type lastPongTracker struct {
|
|
mu sync.RWMutex
|
|
ts time.Time
|
|
}
|
|
|
|
func newLastPongTracker() *lastPongTracker {
|
|
return &lastPongTracker{ts: time.Now()}
|
|
}
|
|
|
|
func (t *lastPongTracker) touch() {
|
|
t.mu.Lock()
|
|
defer t.mu.Unlock()
|
|
t.ts = time.Now()
|
|
}
|
|
|
|
func (t *lastPongTracker) value() time.Time {
|
|
t.mu.RLock()
|
|
defer t.mu.RUnlock()
|
|
return t.ts
|
|
}
|
|
|
|
// isHeartbeatTimedOut reports whether strictly more than heartbeatTimeout has
// elapsed since t was last stamped.
func isHeartbeatTimedOut(t *lastPongTracker) bool {
	elapsed := time.Since(t.value())
	return elapsed > heartbeatTimeout
}
|
|
|
|
// replaySince sends c, oldest first, every buffered event for c's user whose
// Seq is greater than since. A cursor of 0 means "no cursor" and replays
// nothing. The buffer is copied under a read lock so the writes happen
// without holding h.mu. The first write error aborts the replay and is
// returned.
//
// NOTE(review): HandleWS calls this after register, so an event broadcast in
// between can reach c both here and via Broadcast (duplicate, and possibly
// ahead of older replayed events).
func (h *Hub) replaySince(c *conn, since uint64) error {
	if since == 0 {
		return nil
	}

	h.mu.RLock()
	buffered := make([]Event, len(h.history[c.userID]))
	copy(buffered, h.history[c.userID])
	h.mu.RUnlock()

	for i := range buffered {
		if buffered[i].Seq > since {
			if err := h.writeConn(context.Background(), c, buffered[i]); err != nil {
				return err
			}
		}
	}
	return nil
}
|
|
|
|
func (h *Hub) writeConn(ctx context.Context, c *conn, event Event) error {
|
|
c.writeMu.Lock()
|
|
defer c.writeMu.Unlock()
|
|
return wsjson.Write(ctx, c.ws, event)
|
|
}
|
|
|
|
func (h *Hub) nextSequenceLocked() uint64 {
|
|
h.nextSeq++
|
|
return h.nextSeq
|
|
}
|
|
|
|
func (h *Hub) appendHistoryLocked(userID string, event Event) {
|
|
history := append(h.history[userID], event)
|
|
if len(history) > h.replayBufferSizePerUser {
|
|
history = history[len(history)-h.replayBufferSizePerUser:]
|
|
}
|
|
h.history[userID] = history
|
|
}
|
|
|
|
func (h *Hub) snapshotConnectionsLocked(userID string) []*conn {
|
|
conns := h.clients[userID]
|
|
if len(conns) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]*conn, 0, len(conns))
|
|
for c := range conns {
|
|
out = append(out, c)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// canRegisterLocked reports whether c may be added without exceeding the
// per-user or per-session connection limit. It returns a descriptive error
// when a limit is reached, otherwise nil. The caller must hold h.mu.
func (h *Hub) canRegisterLocked(c *conn) error {
	if len(h.clients[c.userID]) >= h.maxConnectionsPerUser {
		return errors.New("too many websocket connections for user")
	}
	if sessions := h.sessionCounts[c.userID]; sessions != nil && sessions[c.sessionKey] >= h.maxConnectionsPerSess {
		return errors.New("too many websocket connections for session")
	}
	return nil
}
|
|
|
|
func (h *Hub) register(c *conn) error {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if h.clients[c.userID] == nil {
|
|
h.clients[c.userID] = make(map[*conn]struct{})
|
|
}
|
|
if err := h.canRegisterLocked(c); err != nil {
|
|
return err
|
|
}
|
|
|
|
h.clients[c.userID][c] = struct{}{}
|
|
if h.sessionCounts[c.userID] == nil {
|
|
h.sessionCounts[c.userID] = make(map[string]int)
|
|
}
|
|
h.sessionCounts[c.userID][c.sessionKey]++
|
|
|
|
h.logger.Info("ws connected",
|
|
"user_id", c.userID,
|
|
"session", c.sessionKey,
|
|
"user_conns", len(h.clients[c.userID]),
|
|
"session_conns", h.sessionCounts[c.userID][c.sessionKey],
|
|
)
|
|
return nil
|
|
}
|
|
|
|
func (h *Hub) unregister(c *conn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
|
|
if conns, ok := h.clients[c.userID]; ok {
|
|
delete(conns, c)
|
|
if len(conns) == 0 {
|
|
delete(h.clients, c.userID)
|
|
}
|
|
}
|
|
|
|
if sessions := h.sessionCounts[c.userID]; sessions != nil {
|
|
next := sessions[c.sessionKey] - 1
|
|
if next <= 0 {
|
|
delete(sessions, c.sessionKey)
|
|
} else {
|
|
sessions[c.sessionKey] = next
|
|
}
|
|
if len(sessions) == 0 {
|
|
delete(h.sessionCounts, c.userID)
|
|
}
|
|
}
|
|
|
|
_ = c.ws.Close(websocket.StatusNormalClosure, "")
|
|
}
|
|
|
|
// Broadcast assigns event the next hub-wide sequence number, records it in
// userID's replay history, and sends it to every connection userID currently
// has open. With no open connections the event is still buffered for replay.
//
// Writes happen after h.mu is released, so a slow connection does not block
// registration or other users' broadcasts. A connection whose write fails is
// only closed here. Its HandleWS goroutine then sees the read error and runs
// its deferred unregister. That keeps the owning handler as the single place
// that unregisters a connection, so its session count is decremented exactly
// once.
func (h *Hub) Broadcast(userID string, event Event) {
	h.mu.Lock()
	event.Seq = h.nextSequenceLocked()
	h.appendHistoryLocked(userID, event)
	targets := h.snapshotConnectionsLocked(userID)
	h.mu.Unlock()

	for _, c := range targets {
		if err := h.writeConn(context.Background(), c, event); err != nil {
			h.logger.Error("ws write", "error", err, "user_id", userID)
			// Close in the background: the close handshake can wait on the
			// peer and must not delay delivery to the remaining targets.
			go func(c *conn) {
				_ = c.ws.Close(websocket.StatusInternalError, "write failed")
			}(c)
		}
	}
}
|
|
|
|
func (h *Hub) HistoryHead(userID string) uint64 {
|
|
h.mu.RLock()
|
|
defer h.mu.RUnlock()
|
|
history := h.history[userID]
|
|
if len(history) == 0 {
|
|
return 0
|
|
}
|
|
return history[len(history)-1].Seq
|
|
}
|
|
|
|
func (h *Hub) SetLimits(maxUser, maxSession, replaySize int) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
if maxUser > 0 {
|
|
h.maxConnectionsPerUser = maxUser
|
|
}
|
|
if maxSession > 0 {
|
|
h.maxConnectionsPerSess = maxSession
|
|
}
|
|
if replaySize > 0 {
|
|
h.replayBufferSizePerUser = replaySize
|
|
}
|
|
}
|