98 lines
1.9 KiB
Go
98 lines
1.9 KiB
Go
package realtime
|
|
|
|
import (
|
|
"context"
|
|
"log/slog"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"github.com/coder/websocket"
|
|
"github.com/coder/websocket/wsjson"
|
|
)
|
|
|
|
// Event is the message envelope the hub writes to clients as JSON: a type
// tag plus an arbitrary payload.
type Event struct {
	// Type is serialized under the "type" key.
	Type string `json:"type"`
	// Payload is serialized under the "payload" key; any JSON-encodable
	// value may be stored here.
	Payload any `json:"payload"`
}
|
|
|
|
type Hub struct {
|
|
mu sync.RWMutex
|
|
clients map[string]map[*conn]struct{} // userID -> connections
|
|
logger *slog.Logger
|
|
}
|
|
|
|
type conn struct {
|
|
ws *websocket.Conn
|
|
userID string
|
|
}
|
|
|
|
func NewHub() *Hub {
|
|
return &Hub{
|
|
clients: make(map[string]map[*conn]struct{}),
|
|
logger: slog.Default().With("component", "ws-hub"),
|
|
}
|
|
}
|
|
|
|
func (h *Hub) HandleWS(w http.ResponseWriter, r *http.Request) {
|
|
userID := r.URL.Query().Get("user_id")
|
|
if userID == "" {
|
|
http.Error(w, "missing user_id", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
ws, err := websocket.Accept(w, r, &websocket.AcceptOptions{
|
|
OriginPatterns: []string{"*"},
|
|
})
|
|
if err != nil {
|
|
h.logger.Error("websocket accept", "error", err)
|
|
return
|
|
}
|
|
|
|
c := &conn{ws: ws, userID: userID}
|
|
h.register(c)
|
|
defer h.unregister(c)
|
|
|
|
ctx := r.Context()
|
|
for {
|
|
_, _, err := ws.Read(ctx)
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) Broadcast(userID string, event Event) {
|
|
h.mu.RLock()
|
|
conns := h.clients[userID]
|
|
h.mu.RUnlock()
|
|
|
|
for c := range conns {
|
|
if err := wsjson.Write(context.Background(), c.ws, event); err != nil {
|
|
h.logger.Error("ws write", "error", err, "user_id", userID)
|
|
go h.unregister(c)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (h *Hub) register(c *conn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
if h.clients[c.userID] == nil {
|
|
h.clients[c.userID] = make(map[*conn]struct{})
|
|
}
|
|
h.clients[c.userID][c] = struct{}{}
|
|
h.logger.Info("ws connected", "user_id", c.userID)
|
|
}
|
|
|
|
func (h *Hub) unregister(c *conn) {
|
|
h.mu.Lock()
|
|
defer h.mu.Unlock()
|
|
if conns, ok := h.clients[c.userID]; ok {
|
|
delete(conns, c)
|
|
if len(conns) == 0 {
|
|
delete(h.clients, c.userID)
|
|
}
|
|
}
|
|
c.ws.Close(websocket.StatusNormalClosure, "")
|
|
}
|