package aiapi

import (
	"bytes"
	"context"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/ultisuite/ulti-backend/internal/ai"
	"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
	"github.com/ultisuite/ulti-backend/internal/api/middleware"
	"github.com/ultisuite/ulti-backend/internal/apitokens"
	"github.com/ultisuite/ulti-backend/internal/auth"
	"github.com/ultisuite/ulti-backend/internal/config"
	"github.com/ultisuite/ulti-backend/internal/nextcloud"
	"github.com/ultisuite/ulti-backend/internal/permission"
)

const (
	// sessionAccessCookie holds an access token that resolveClaims falls back
	// to when the request carries no verifiable bearer token.
	sessionAccessCookie = "ulti_access_token"

	// defaultPublicPath is the assistant UI path used when neither the
	// deployment config nor the stored policy provides one.
	defaultPublicPath = "/ai"

	// defaultOpenWebUIURL is the OpenWebUI base URL used when
	// cfg.OpenWebUIInternalURL is empty or cfg is nil.
	defaultOpenWebUIURL = "http://openwebui:8080"

	// gatewaySubject is the subject handed to the gateway for requests
	// authenticated with the shared AI gateway API key.
	gatewaySubject = "openwebui-gateway"

	// Size caps. Request bodies above their cap are rejected (400) instead of
	// being silently truncated and forwarded as broken JSON; the upstream
	// signin response is relayed only up to maxSigninResponseBytes.
	maxSigninBodyBytes     = 1 << 20
	maxSigninResponseBytes = 8 << 20
	maxChatBodyBytes       = 8 << 20

	// openWebUITimeout bounds a whole signin round trip to OpenWebUI,
	// including reading the response body.
	openWebUITimeout = 30 * time.Second
)

// openWebUIClient forwards signin requests to OpenWebUI. Unlike
// http.DefaultClient it has an overall timeout, and it does not follow
// redirects, so a 3xx from upstream is relayed to the caller unchanged as a
// reverse proxy should.
var openWebUIClient = &http.Client{
	Timeout: openWebUITimeout,
	CheckRedirect: func(*http.Request, []*http.Request) error {
		return http.ErrUseLastResponse
	},
}

// Handler serves the AI assistant HTTP API. cfg and verify may be nil (or,
// for verify, not yet ready); every use below guards against that.
type Handler struct {
	db      *pgxpool.Pool
	cfg     *config.Config
	gateway *ai.Gateway
	quota   *ai.QuotaService
	chats   *ai.ChatStore
	verify  *auth.Holder
}

// NewHandler builds a Handler whose gateway and quota service use db and
// whose chat store uses nc and db. verifier validates bearer and cookie
// tokens on the routes that authenticate themselves.
func NewHandler(db *pgxpool.Pool, cfg *config.Config, nc *nextcloud.Client, verifier *auth.Holder) *Handler {
	return &Handler{
		db:      db,
		cfg:     cfg,
		gateway: ai.NewGateway(db),
		quota:   ai.NewQuotaService(db),
		chats:   ai.NewChatStore(nc, db),
		verify:  verifier,
	}
}

// Routes returns the AI API router. /config, /embed-auth, /embed-signin and
// the OpenAI-compatible /models and /chat/completions endpoints are mounted
// without authMiddleware (they authenticate themselves, or not at all for
// /config); all remaining routes are wrapped in authMiddleware.
func (h *Handler) Routes(authMiddleware func(http.Handler) http.Handler) chi.Router {
	r := chi.NewRouter()
	r.Get("/config", h.GetConfig)
	r.Get("/embed-auth", h.EmbedAuth)
	r.Post("/embed-signin", h.EmbedSignin)

	// OpenWebUI gateway (Bearer AI_GATEWAY_API_KEY) or user JWT — not behind Auth middleware
	r.Get("/models", h.ListModels)
	r.Post("/chat/completions", h.ChatCompletions)
	r.Post("/v1/chat/completions", h.ChatCompletions)

	r.Group(func(r chi.Router) {
		r.Use(authMiddleware)
		r.Get("/quota", h.GetQuota)
		r.Post("/sessions", h.CreateSession)
		r.Get("/chats/{chatID}", h.GetChat)
		r.Delete("/chats/{chatID}", h.DeleteChat)
		r.Post("/chats/sync", h.SyncChat)
	})
	return r
}

// GetConfig reports whether the assistant is enabled (deployment flag combined
// with the stored policy via ai.IsAssistantEnabled) together with the
// policy's UI-facing settings and model list.
func (h *Handler) GetConfig(w http.ResponseWriter, r *http.Request) {
	deployEnabled := h.cfg != nil && h.cfg.AIAssistantEnabled
	policy, enabled := ai.IsAssistantEnabled(r.Context(), h.db, deployEnabled)

	models := make([]map[string]any, 0, len(policy.Models))
	for _, entry := range policy.Models {
		models = append(models, map[string]any{
			"model_id": entry.ModelID,
			"label":    entry.Label,
			"enabled":  entry.Enabled,
		})
	}

	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
		"enabled":                 enabled,
		"public_path":             h.publicPath(policy.PublicPath),
		"embed_default_temporary": policy.EmbedDefaultTemporary,
		"default_model":           policy.DefaultModel,
		"enabled_tools":           policy.EnabledTools,
		"chat_sync_enabled":       policy.ChatSyncEnabled,
		"models":                  models,
		"restrict_models":         len(policy.Models) > 0,
	})
}

// publicPath resolves the URL path the assistant UI is served from. A
// non-blank cfg.AIAssistantPublicPath wins; otherwise policyPath is used,
// falling back to defaultPublicPath when it is blank.
func (h *Handler) publicPath(policyPath string) string {
	if h.cfg != nil && strings.TrimSpace(h.cfg.AIAssistantPublicPath) != "" {
		return h.cfg.AIAssistantPublicPath
	}
	if strings.TrimSpace(policyPath) == "" {
		return defaultPublicPath
	}
	return policyPath
}

// EmbedAuth answers an auth subrequest: 401 with no body when the caller has
// no verified claims with an email, otherwise 200 with the X-Ulti-User-*
// identity headers set on the response.
func (h *Handler) EmbedAuth(w http.ResponseWriter, r *http.Request) {
	claims, ok := h.resolveClaims(r)
	if !ok || strings.TrimSpace(claims.Email) == "" {
		w.WriteHeader(http.StatusUnauthorized)
		return
	}
	h.writeTrustedUserHeaders(w, claims)
	w.WriteHeader(http.StatusOK)
}

// EmbedSignin proxies OpenWebUI trusted-header signin (nginx routes
// /api/v1/auths/signin here). The caller must present a valid access token.
// The request body (or an empty-credentials JSON object when the body is
// empty) is POSTed to OpenWebUI with X-Ulti-User-* headers derived from the
// verified claims, and the upstream status, end-to-end headers and body are
// relayed back.
func (h *Handler) EmbedSignin(w http.ResponseWriter, r *http.Request) {
	claims, ok := h.resolveClaims(r)
	if !ok || strings.TrimSpace(claims.Email) == "" {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxSigninBodyBytes))
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid body", nil)
		return
	}
	if len(body) == 0 {
		body = []byte(`{"email":"","password":""}`)
	}

	upstreamReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost,
		h.openWebUIBaseURL()+"/api/v1/auths/signin", bytes.NewReader(body))
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	contentType := strings.TrimSpace(r.Header.Get("Content-Type"))
	if contentType == "" {
		contentType = "application/json"
	}
	upstreamReq.Header.Set("Content-Type", contentType)
	setTrustedUserHeaders(upstreamReq.Header, claims)

	resp, err := openWebUIClient.Do(upstreamReq)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	defer resp.Body.Close()

	copyResponseHeaders(w.Header(), resp.Header)
	w.WriteHeader(resp.StatusCode)
	// Best-effort: the status line has been sent, so a copy error cannot be reported.
	_, _ = io.Copy(w, io.LimitReader(resp.Body, maxSigninResponseBytes))
}

// openWebUIBaseURL returns cfg.OpenWebUIInternalURL without surrounding
// whitespace or trailing slashes, or defaultOpenWebUIURL when cfg is nil or
// the setting is blank.
func (h *Handler) openWebUIBaseURL() string {
	if h.cfg != nil {
		if base := strings.TrimRight(strings.TrimSpace(h.cfg.OpenWebUIInternalURL), "/"); base != "" {
			return base
		}
	}
	return defaultOpenWebUIURL
}

// copyResponseHeaders appends every upstream header in src to dst except
// those isUnrelayedHeader rejects.
func copyResponseHeaders(dst, src http.Header) {
	for name, values := range src {
		if isUnrelayedHeader(name) {
			continue
		}
		for _, v := range values {
			dst.Add(name, v)
		}
	}
}

// isUnrelayedHeader reports whether an upstream response header must not be
// copied downstream. Hop-by-hop headers (RFC 7230 §6.1) describe the upstream
// connection only; Content-Length is dropped because the relayed body passes
// through a size limit and may be shorter than upstream declared, so net/http
// must frame the downstream response itself.
func isUnrelayedHeader(name string) bool {
	switch http.CanonicalHeaderKey(name) {
	case "Connection", "Keep-Alive", "Proxy-Authenticate", "Proxy-Authorization",
		"Proxy-Connection", "Te", "Trailer", "Transfer-Encoding", "Upgrade",
		"Content-Length":
		return true
	}
	return false
}

// writeTrustedUserHeaders sets the X-Ulti-User-* identity headers for claims
// on the response w.
func (h *Handler) writeTrustedUserHeaders(w http.ResponseWriter, claims *auth.Claims) {
	setTrustedUserHeaders(w.Header(), claims)
}

// setTrustedUserHeaders writes X-Ulti-User-Email, X-Ulti-User-Name (the
// trimmed display name, or the email when the name is blank) and
// X-Ulti-User-Role into hdr. claims must be non-nil.
func setTrustedUserHeaders(hdr http.Header, claims *auth.Claims) {
	name := strings.TrimSpace(claims.Name)
	if name == "" {
		name = claims.Email
	}
	hdr.Set("X-Ulti-User-Email", claims.Email)
	hdr.Set("X-Ulti-User-Name", name)
	hdr.Set("X-Ulti-User-Role", openWebUIRole(claims))
}

// openWebUIRole maps claims to an OpenWebUI role: "admin" for members of
// permission.RoleAdmin, "user" for everyone else (including nil claims).
func openWebUIRole(claims *auth.Claims) string {
	if claims != nil && permission.HasRole(claims.Groups, permission.RoleAdmin) {
		return "admin"
	}
	return "user"
}

// GetQuota returns the quota status reported by the quota service for the
// authenticated user.
func (h *Handler) GetQuota(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}
	status, err := h.quota.Check(r.Context(), claims.Sub)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, status)
}

// ListModels returns an OpenAI-style model list: the organisation models for
// gateway-key callers, or the per-user models for token-authenticated users.
func (h *Handler) ListModels(w http.ResponseWriter, r *http.Request) {
	externalUserID, useOrg, ok := h.resolveAIAccess(r)
	if !ok {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	var (
		models []map[string]any
		err    error
	)
	if useOrg {
		models, err = h.gateway.ListOrgModels(r.Context())
	} else {
		models, err = h.gateway.ListModels(r.Context(), externalUserID)
	}
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
		"object": "list",
		"data":   models,
	})
}

// ChatCompletions forwards an OpenAI-compatible chat completion request body
// to the gateway, which writes the upstream response to w. Quota exhaustion
// maps to 429; any other gateway error maps to 502.
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
	externalUserID, useOrg, ok := h.resolveAIAccess(r)
	if !ok {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxChatBodyBytes))
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid body", nil)
		return
	}

	subject := externalUserID
	if useOrg {
		subject = gatewaySubject
	}

	// NOTE(review): ProxyChatCompletions receives w directly; if it can fail
	// after it has started writing (e.g. mid-stream), the error responses below
	// cannot change the status already sent — confirm it only errors before writing.
	if err := h.gateway.ProxyChatCompletions(r.Context(), subject, body, w); err != nil {
		if errors.Is(err, ai.ErrQuotaExceeded) {
			apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, err.Error(), nil)
			return
		}
		apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
	}
}

// CreateSession creates a chat-session API token for the authenticated user,
// scoped by the requesting app (write access only for "docs"), and returns the
// embed URL for the assistant UI.
func (h *Handler) CreateSession(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	var req ai.SessionContext
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid json", nil)
		return
	}

	// Unknown or empty app names fall back to the standalone preset.
	preset := apitokens.ChatSessionStandalone
	switch strings.ToLower(strings.TrimSpace(req.App)) {
	case "mail":
		preset = apitokens.ChatSessionMail
	case "drive":
		preset = apitokens.ChatSessionDrive
	case "contacts":
		preset = apitokens.ChatSessionContacts
	case "docs":
		preset = apitokens.ChatSessionDocs
	}

	created, err := apitokens.CreateChatSession(r.Context(), h.db, claims.Sub, claims.Email, apitokens.ChatSessionInput{
		Preset:     preset,
		DrivePath:  req.DrivePath,
		AllowWrite: preset == apitokens.ChatSessionDocs,
	})
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}

	// The policy error is deliberately ignored: whatever policy value is
	// returned is used as-is. NOTE(review): confirm LoadAssistantPolicy returns
	// usable defaults alongside an error.
	policy, _ := ai.LoadAssistantPolicy(r.Context(), h.db)
	temporary := req.Temporary || policy.EmbedDefaultTemporary

	q := url.Values{}
	if temporary {
		q.Set("temporary-chat", "true")
	}
	if strings.TrimSpace(req.App) != "" {
		q.Set("app", req.App)
	}
	embedURL := strings.TrimRight(h.publicPath(policy.PublicPath), "/") + "/"
	if enc := q.Encode(); enc != "" {
		embedURL += "?" + enc
	}

	apiresponse.WriteJSON(w, http.StatusOK, ai.SessionResponse{
		SessionID:   created.ID,
		EmbedURL:    embedURL,
		TokenSecret: created.TokenSecret,
		Temporary:   temporary,
	})
}

// SyncChat stores a chat record for the authenticated user via the chat
// store, provided the assistant policy enables chat sync.
func (h *Handler) SyncChat(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}

	// The policy error is deliberately ignored; this check relies on the
	// returned policy only. NOTE(review): if a failed load yields a zero-value
	// policy, sync is refused (fail closed) — confirm that is intended.
	policy, _ := ai.LoadAssistantPolicy(r.Context(), h.db)
	if !policy.ChatSyncEnabled {
		apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "chat sync disabled", nil)
		return
	}

	var record ai.ChatRecord
	if err := json.NewDecoder(r.Body).Decode(&record); err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid json", nil)
		return
	}
	if err := h.chats.Sync(r.Context(), claims.Email, record); err != nil {
		apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
}

// GetChat returns the authenticated user's chat identified by {chatID}. Any
// chat store error is reported as 404.
func (h *Handler) GetChat(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}
	record, err := h.chats.Get(r.Context(), claims.Email, chi.URLParam(r, "chatID"))
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, "chat not found", nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, record)
}

// DeleteChat deletes the authenticated user's chat identified by {chatID}.
// Any chat store error is reported as 404.
func (h *Handler) DeleteChat(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
		return
	}
	if err := h.chats.Delete(r.Context(), claims.Email, chi.URLParam(r, "chatID")); err != nil {
		apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, "chat not found", nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
}

// resolveAIAccess authenticates a gateway request. A bearer token equal to
// the configured AI gateway API key grants organisation access
// (useOrgSettings=true, empty externalUserID); otherwise a verifiable bearer
// token, or claims already placed in the context, identify the user by Sub.
// ok is false when none of these apply.
func (h *Handler) resolveAIAccess(r *http.Request) (externalUserID string, useOrgSettings bool, ok bool) {
	if token := bearerToken(r); token != "" {
		if h.isGatewayKey(token) {
			return "", true, true
		}
		if claims := h.verifyToken(r.Context(), token); claims != nil && strings.TrimSpace(claims.Sub) != "" {
			return claims.Sub, false, true
		}
	}
	if claims := middleware.ClaimsFromContext(r.Context()); claims != nil && strings.TrimSpace(claims.Sub) != "" {
		return claims.Sub, false, true
	}
	return "", false, false
}

// isGatewayKey reports whether token matches the configured AI gateway API
// key. The comparison is constant-time so response timing does not reveal
// how much of a guessed key is correct.
func (h *Handler) isGatewayKey(token string) bool {
	if h.cfg == nil || h.cfg.AIGatewayAPIKey == "" {
		return false
	}
	return subtle.ConstantTimeCompare([]byte(token), []byte(h.cfg.AIGatewayAPIKey)) == 1
}

// bearerToken returns the trimmed token from an "Authorization: Bearer ..."
// header, or "" when the header is absent or uses another scheme.
func bearerToken(r *http.Request) string {
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	if !strings.HasPrefix(header, "Bearer ") {
		return ""
	}
	return strings.TrimSpace(strings.TrimPrefix(header, "Bearer "))
}

// verifyToken validates token with the configured verifier. It returns nil
// when token is empty, no verifier is ready, verification fails, or the
// verifier yields no claims, so a non-nil result is always usable.
func (h *Handler) verifyToken(ctx context.Context, token string) *auth.Claims {
	if token == "" || h.verify == nil || !h.verify.Ready() {
		return nil
	}
	claims, err := h.verify.Verify(ctx, token)
	if err != nil {
		return nil
	}
	return claims
}

// resolveClaims authenticates the caller from the bearer token, falling back
// to the sessionAccessCookie token. On success the returned claims are
// guaranteed non-nil.
func (h *Handler) resolveClaims(r *http.Request) (*auth.Claims, bool) {
	if claims := h.verifyToken(r.Context(), bearerToken(r)); claims != nil {
		return claims, true
	}
	if cookie, err := r.Cookie(sessionAccessCookie); err == nil {
		if claims := h.verifyToken(r.Context(), strings.TrimSpace(cookie.Value)); claims != nil {
			return claims, true
		}
	}
	return nil, false
}