ultisuite-backend/internal/api/ai/handlers.go
R3D347HR4Y 0466a1c169
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
wow
2026-06-11 01:22:52 +02:00

287 lines
9.3 KiB
Go

package aiapi
import (
"encoding/json"
"errors"
"io"
"net/http"
"net/url"
"strings"
"github.com/go-chi/chi/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/ai"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/api/middleware"
"github.com/ultisuite/ulti-backend/internal/apitokens"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/config"
"github.com/ultisuite/ulti-backend/internal/nextcloud"
)
// sessionAccessCookie is the cookie resolveClaims reads an access token from
// when the Authorization header does not yield verified claims.
const sessionAccessCookie = "ulti_access_token"

// Handler serves the AI assistant HTTP endpoints (config, embed auth, quota,
// models, chat completions, sessions and chat sync).
type Handler struct {
	// db is used for policy lookups and chat-session creation.
	db *pgxpool.Pool
	// cfg holds static overrides (public path, enabled flag); may be nil.
	cfg *config.Config
	// gateway lists models and proxies chat completions.
	gateway *ai.Gateway
	// quota reports per-user quota status.
	quota *ai.QuotaService
	// chats gets, deletes and syncs stored chat records.
	chats *ai.ChatStore
	// verify validates access tokens in resolveClaims; may be nil or not ready.
	verify *auth.Holder
}
// NewHandler builds a Handler from its dependencies. The gateway and quota
// service are constructed on top of db, and the chat store on top of nc and db.
// verifier is kept as-is and consulted by EmbedAuth via resolveClaims.
func NewHandler(db *pgxpool.Pool, cfg *config.Config, nc *nextcloud.Client, verifier *auth.Holder) *Handler {
	h := &Handler{db: db, cfg: cfg, verify: verifier}
	h.gateway = ai.NewGateway(db)
	h.quota = ai.NewQuotaService(db)
	h.chats = ai.NewChatStore(nc, db)
	return h
}
// Routes returns the AI sub-router.
//
// /config and /embed-auth are mounted without authMiddleware (EmbedAuth does
// its own token resolution). Every other route is wrapped by authMiddleware.
func (h *Handler) Routes(authMiddleware func(http.Handler) http.Handler) chi.Router {
	router := chi.NewRouter()

	router.Get("/config", h.GetConfig)
	router.Get("/embed-auth", h.EmbedAuth)

	// With() yields an inline sub-router sharing the parent's tree; this is
	// what chi's Group+Use does internally.
	protected := router.With(authMiddleware)
	protected.Get("/quota", h.GetQuota)
	protected.Get("/models", h.ListModels)
	// Both the bare and OpenAI-style /v1 paths hit the same proxy handler.
	protected.Post("/chat/completions", h.ChatCompletions)
	protected.Post("/v1/chat/completions", h.ChatCompletions)
	protected.Post("/sessions", h.CreateSession)
	protected.Get("/chats/{chatID}", h.GetChat)
	protected.Delete("/chats/{chatID}", h.DeleteChat)
	protected.Post("/chats/sync", h.SyncChat)

	return router
}
// GetConfig returns the public AI assistant configuration.
//
// The public path resolves in this order:
//  1. the non-blank config override (cfg.AIAssistantPublicPath);
//  2. the stored policy path;
//  3. "/ai".
//
// Candidate paths are whitespace-trimmed before use. "enabled" is true if
// either the stored policy or the static config enables the assistant.
// A policy load failure yields 500.
func (h *Handler) GetConfig(w http.ResponseWriter, r *http.Request) {
	policy, err := ai.LoadAssistantPolicy(r.Context(), h.db)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to load ai config", nil)
		return
	}
	// Trim before use: previously the override was checked trimmed but
	// assigned raw, so surrounding whitespace leaked into the response.
	publicPath := strings.TrimSpace(policy.PublicPath)
	if publicPath == "" {
		publicPath = "/ai"
	}
	if h.cfg != nil {
		if override := strings.TrimSpace(h.cfg.AIAssistantPublicPath); override != "" {
			publicPath = override
		}
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
		"enabled":                 policy.Enabled || (h.cfg != nil && h.cfg.AIAssistantEnabled),
		"public_path":             publicPath,
		"embed_default_temporary": policy.EmbedDefaultTemporary,
		"default_model":           policy.DefaultModel,
		"enabled_tools":           policy.EnabledTools,
		"chat_sync_enabled":       policy.ChatSyncEnabled,
	})
}
// EmbedAuth resolves the caller's identity from a bearer token or the session
// cookie. It answers 401 with no body when that fails.
//
// On success it answers 200 with the identity exposed as response headers:
//   - X-Ulti-User-Email: the claims' email.
//   - X-Ulti-User-Name: the display name, falling back to the email when the
//     name is blank.
//   - X-Ulti-User-Role: always "user".
//
// NOTE(review): presumably consumed by a reverse proxy's forward-auth hook in
// front of the embedded assistant — confirm.
func (h *Handler) EmbedAuth(w http.ResponseWriter, r *http.Request) {
	claims, authenticated := h.resolveClaims(r)
	if !authenticated {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	displayName := claims.Email
	if strings.TrimSpace(claims.Name) != "" {
		displayName = claims.Name
	}

	hdr := w.Header()
	hdr.Set("X-Ulti-User-Email", claims.Email)
	hdr.Set("X-Ulti-User-Name", displayName)
	hdr.Set("X-Ulti-User-Role", "user")
	w.WriteHeader(http.StatusOK)
}
// GetQuota reports the authenticated user's quota status as JSON.
// It answers 401 without claims in the request context. When the quota check
// fails, it answers 500 with the check's error message.
func (h *Handler) GetQuota(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	claims := middleware.ClaimsFromContext(ctx)
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	status, checkErr := h.quota.Check(ctx, claims.Sub)
	switch {
	case checkErr != nil:
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, checkErr.Error(), nil)
	default:
		apiresponse.WriteJSON(w, http.StatusOK, status)
	}
}
// ListModels returns the models available to the authenticated user, wrapped
// in an OpenAI-style list envelope ({"object": "list", "data": [...]}).
// It answers 401 without claims. Any gateway error is reported as 400 with the
// error's message.
func (h *Handler) ListModels(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	claims := middleware.ClaimsFromContext(ctx)
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	models, listErr := h.gateway.ListModels(ctx, claims.Sub)
	if listErr != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, listErr.Error(), nil)
		return
	}

	payload := map[string]any{
		"object": "list",
		"data":   models,
	}
	apiresponse.WriteJSON(w, http.StatusOK, payload)
}
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
if claims == nil {
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
return
}
body, err := io.ReadAll(io.LimitReader(r.Body, 8<<20))
if err != nil {
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid body", nil)
return
}
if err := h.gateway.ProxyChatCompletions(r.Context(), claims.Sub, body, w); err != nil {
if errors.Is(err, ai.ErrQuotaExceeded) {
apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, err.Error(), nil)
return
}
apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
return
}
}
// CreateSession mints a chat-session API token for the authenticated user.
// It returns the session ID, token secret, embed URL and temporary flag.
//
// The session preset is derived from req.App, case-insensitively:
//   - "mail", "drive", "contacts" and "docs" map to their presets.
//   - Anything else maps to standalone.
//
// Write access is granted only to the docs preset. The embed URL is the public
// path (config override > policy path > "/ai") plus a trailing slash, with
// optional temporary-chat and app query parameters.
//
// The assistant policy is loaded before the token is created, so a policy
// failure does not leave behind a session the client never receives.
func (h *Handler) CreateSession(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}
	var req ai.SessionContext
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid json", nil)
		return
	}
	// Previously this error was discarded and the load happened after the
	// session was persisted; handle it up front, consistent with GetConfig.
	policy, err := ai.LoadAssistantPolicy(r.Context(), h.db)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to load ai config", nil)
		return
	}

	app := strings.TrimSpace(req.App)
	preset := apitokens.ChatSessionStandalone
	switch strings.ToLower(app) {
	case "mail":
		preset = apitokens.ChatSessionMail
	case "drive":
		preset = apitokens.ChatSessionDrive
	case "contacts":
		preset = apitokens.ChatSessionContacts
	case "docs":
		preset = apitokens.ChatSessionDocs
	}
	allowWrite := preset == apitokens.ChatSessionDocs
	created, err := apitokens.CreateChatSession(r.Context(), h.db, claims.Sub, claims.Email, apitokens.ChatSessionInput{
		Preset:     preset,
		DrivePath:  req.DrivePath,
		AllowWrite: allowWrite,
	})
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}

	publicPath := strings.TrimSpace(policy.PublicPath)
	if publicPath == "" {
		publicPath = "/ai"
	}
	if h.cfg != nil {
		if override := strings.TrimSpace(h.cfg.AIAssistantPublicPath); override != "" {
			publicPath = override
		}
	}
	temporary := req.Temporary || policy.EmbedDefaultTemporary
	q := url.Values{}
	if temporary {
		q.Set("temporary-chat", "true")
	}
	if app != "" {
		q.Set("app", app)
	}
	embedURL := strings.TrimRight(publicPath, "/") + "/"
	if enc := q.Encode(); enc != "" {
		embedURL += "?" + enc
	}
	apiresponse.WriteJSON(w, http.StatusOK, ai.SessionResponse{
		SessionID:   created.ID,
		EmbedURL:    embedURL,
		TokenSecret: created.TokenSecret,
		Temporary:   temporary,
	})
}
// SyncChat stores a chat record for the authenticated user (keyed by email).
//
// Error responses:
//   - 401 without claims.
//   - 500 if the assistant policy cannot be loaded.
//   - 403 when the policy disables chat sync.
//   - 400 for an undecodable body.
//   - 502 when the store fails.
//
// On success it returns {"ok": true}.
func (h *Handler) SyncChat(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}
	// Previously the load error was discarded, so a DB failure surfaced as a
	// misleading 403 "chat sync disabled" (or a nil dereference). Report it
	// as an internal error instead, consistent with GetConfig.
	policy, err := ai.LoadAssistantPolicy(r.Context(), h.db)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to load ai config", nil)
		return
	}
	if !policy.ChatSyncEnabled {
		apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "chat sync disabled", nil)
		return
	}
	var record ai.ChatRecord
	if err := json.NewDecoder(r.Body).Decode(&record); err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid json", nil)
		return
	}
	if err := h.chats.Sync(r.Context(), claims.Email, record); err != nil {
		apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
}
// GetChat returns the chat identified by the {chatID} path parameter for the
// authenticated user (keyed by email). It answers 401 without claims.
//
// Any store error is reported as 404 "chat not found".
// NOTE(review): this also masks backend outages as 404 — consider
// distinguishing not-found from other failures if ChatStore exposes that.
func (h *Handler) GetChat(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	record, lookupErr := h.chats.Get(r.Context(), claims.Email, chi.URLParam(r, "chatID"))
	if lookupErr == nil {
		apiresponse.WriteJSON(w, http.StatusOK, record)
		return
	}
	apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, "chat not found", nil)
}
// DeleteChat removes the chat identified by the {chatID} path parameter for
// the authenticated user (keyed by email). On success it returns {"ok": true}.
// It answers 401 without claims, and any store error is reported as 404
// "chat not found".
func (h *Handler) DeleteChat(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	deleteErr := h.chats.Delete(r.Context(), claims.Email, chi.URLParam(r, "chatID"))
	if deleteErr == nil {
		apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
		return
	}
	apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, "chat not found", nil)
}
// resolveClaims tries to authenticate r without the auth middleware.
//
// It tries, in order:
//  1. a Bearer token from the Authorization header;
//  2. the sessionAccessCookie cookie.
//
// It returns the claims from the first candidate the verifier accepts. It
// returns (nil, false) when the verifier is nil or not ready, or when no
// candidate verifies. An invalid header token still falls through to the
// cookie.
func (h *Handler) resolveClaims(r *http.Request) (*auth.Claims, bool) {
	if h.verify == nil || !h.verify.Ready() {
		return nil, false
	}

	candidates := make([]string, 0, 2)
	// Auth schemes are case-insensitive (RFC 7235 §2.1), so "bearer x" must be
	// accepted as well as "Bearer x".
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	if scheme, rest, found := strings.Cut(header, " "); found && strings.EqualFold(scheme, "Bearer") {
		if token := strings.TrimSpace(rest); token != "" {
			candidates = append(candidates, token)
		}
	}
	if cookie, err := r.Cookie(sessionAccessCookie); err == nil {
		if token := strings.TrimSpace(cookie.Value); token != "" {
			candidates = append(candidates, token)
		}
	}

	for _, token := range candidates {
		if claims, err := h.verify.Verify(r.Context(), token); err == nil {
			return claims, true
		}
	}
	return nil, false
}