- Updated .env.example to include new configuration options for AI gateway and WebUI secret key. - Modified Nginx configuration to support additional API routes for model management and migration. - Implemented new API endpoints for discovering organization-level LLM models and managing hosted mail services. - Enhanced AI gateway logic to support organization-specific model access and permissions. - Improved error handling and response structures in the AI and mail APIs. - Added integration tests for new features and updated existing tests for model access control.
403 lines
13 KiB
Go
403 lines
13 KiB
Go
package aiapi
|
|
|
|
import (
	"bytes"
	"crypto/subtle"
	"encoding/json"
	"errors"
	"io"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/ultisuite/ulti-backend/internal/ai"
	"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
	"github.com/ultisuite/ulti-backend/internal/api/middleware"
	"github.com/ultisuite/ulti-backend/internal/apitokens"
	"github.com/ultisuite/ulti-backend/internal/auth"
	"github.com/ultisuite/ulti-backend/internal/config"
	"github.com/ultisuite/ulti-backend/internal/nextcloud"
	"github.com/ultisuite/ulti-backend/internal/permission"
)
|
|
|
|
const (
	// sessionAccessCookie names the cookie holding the user's access token.
	// resolveClaims reads it when Bearer-header authentication does not succeed.
	sessionAccessCookie = "ulti_access_token"
)
|
|
|
|
type Handler struct {
|
|
db *pgxpool.Pool
|
|
cfg *config.Config
|
|
gateway *ai.Gateway
|
|
quota *ai.QuotaService
|
|
chats *ai.ChatStore
|
|
verify *auth.Holder
|
|
}
|
|
|
|
func NewHandler(db *pgxpool.Pool, cfg *config.Config, nc *nextcloud.Client, verifier *auth.Holder) *Handler {
|
|
return &Handler{
|
|
db: db,
|
|
cfg: cfg,
|
|
gateway: ai.NewGateway(db),
|
|
quota: ai.NewQuotaService(db),
|
|
chats: ai.NewChatStore(nc, db),
|
|
verify: verifier,
|
|
}
|
|
}
|
|
|
|
func (h *Handler) Routes(authMiddleware func(http.Handler) http.Handler) chi.Router {
|
|
r := chi.NewRouter()
|
|
r.Get("/config", h.GetConfig)
|
|
r.Get("/embed-auth", h.EmbedAuth)
|
|
r.Post("/embed-signin", h.EmbedSignin)
|
|
// OpenWebUI gateway (Bearer AI_GATEWAY_API_KEY) or user JWT — not behind Auth middleware
|
|
r.Get("/models", h.ListModels)
|
|
r.Post("/chat/completions", h.ChatCompletions)
|
|
r.Post("/v1/chat/completions", h.ChatCompletions)
|
|
r.Group(func(r chi.Router) {
|
|
r.Use(authMiddleware)
|
|
r.Get("/quota", h.GetQuota)
|
|
r.Post("/sessions", h.CreateSession)
|
|
r.Get("/chats/{chatID}", h.GetChat)
|
|
r.Delete("/chats/{chatID}", h.DeleteChat)
|
|
r.Post("/chats/sync", h.SyncChat)
|
|
})
|
|
return r
|
|
}
|
|
|
|
func (h *Handler) GetConfig(w http.ResponseWriter, r *http.Request) {
|
|
deployEnabled := h.cfg != nil && h.cfg.AIAssistantEnabled
|
|
policy, enabled := ai.IsAssistantEnabled(r.Context(), h.db, deployEnabled)
|
|
publicPath := policy.PublicPath
|
|
if strings.TrimSpace(publicPath) == "" {
|
|
publicPath = "/ai"
|
|
}
|
|
if h.cfg != nil && strings.TrimSpace(h.cfg.AIAssistantPublicPath) != "" {
|
|
publicPath = h.cfg.AIAssistantPublicPath
|
|
}
|
|
models := make([]map[string]any, 0, len(policy.Models))
|
|
for _, entry := range policy.Models {
|
|
models = append(models, map[string]any{
|
|
"model_id": entry.ModelID,
|
|
"label": entry.Label,
|
|
"enabled": entry.Enabled,
|
|
})
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
|
"enabled": enabled,
|
|
"public_path": publicPath,
|
|
"embed_default_temporary": policy.EmbedDefaultTemporary,
|
|
"default_model": policy.DefaultModel,
|
|
"enabled_tools": policy.EnabledTools,
|
|
"chat_sync_enabled": policy.ChatSyncEnabled,
|
|
"models": models,
|
|
"restrict_models": len(policy.Models) > 0,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) EmbedAuth(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := h.resolveClaims(r)
|
|
if !ok || strings.TrimSpace(claims.Email) == "" {
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
h.writeTrustedUserHeaders(w, claims)
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
// EmbedSignin proxies OpenWebUI trusted-header signin (nginx routes /api/v1/auths/signin here).
|
|
func (h *Handler) EmbedSignin(w http.ResponseWriter, r *http.Request) {
|
|
claims, ok := h.resolveClaims(r)
|
|
if !ok || strings.TrimSpace(claims.Email) == "" {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
|
|
baseURL := strings.TrimRight(strings.TrimSpace(h.cfg.OpenWebUIInternalURL), "/")
|
|
if baseURL == "" {
|
|
baseURL = "http://openwebui:8080"
|
|
}
|
|
|
|
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid body", nil)
|
|
return
|
|
}
|
|
if len(body) == 0 {
|
|
body = []byte(`{"email":"","password":""}`)
|
|
}
|
|
|
|
upstreamReq, err := http.NewRequestWithContext(r.Context(), http.MethodPost, baseURL+"/api/v1/auths/signin", bytes.NewReader(body))
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
|
return
|
|
}
|
|
contentType := strings.TrimSpace(r.Header.Get("Content-Type"))
|
|
if contentType == "" {
|
|
contentType = "application/json"
|
|
}
|
|
upstreamReq.Header.Set("Content-Type", contentType)
|
|
upstreamReq.Header.Set("X-Ulti-User-Email", claims.Email)
|
|
name := strings.TrimSpace(claims.Name)
|
|
if name == "" {
|
|
name = claims.Email
|
|
}
|
|
upstreamReq.Header.Set("X-Ulti-User-Name", name)
|
|
upstreamReq.Header.Set("X-Ulti-User-Role", openWebUIRole(claims))
|
|
|
|
resp, err := http.DefaultClient.Do(upstreamReq)
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
for k, vals := range resp.Header {
|
|
for _, v := range vals {
|
|
w.Header().Add(k, v)
|
|
}
|
|
}
|
|
w.WriteHeader(resp.StatusCode)
|
|
_, _ = io.Copy(w, io.LimitReader(resp.Body, 8<<20))
|
|
}
|
|
|
|
func (h *Handler) writeTrustedUserHeaders(w http.ResponseWriter, claims *auth.Claims) {
|
|
w.Header().Set("X-Ulti-User-Email", claims.Email)
|
|
if strings.TrimSpace(claims.Name) != "" {
|
|
w.Header().Set("X-Ulti-User-Name", claims.Name)
|
|
} else {
|
|
w.Header().Set("X-Ulti-User-Name", claims.Email)
|
|
}
|
|
w.Header().Set("X-Ulti-User-Role", openWebUIRole(claims))
|
|
}
|
|
|
|
func openWebUIRole(claims *auth.Claims) string {
|
|
if claims != nil && permission.HasRole(claims.Groups, permission.RoleAdmin) {
|
|
return "admin"
|
|
}
|
|
return "user"
|
|
}
|
|
|
|
func (h *Handler) GetQuota(w http.ResponseWriter, r *http.Request) {
|
|
claims := middleware.ClaimsFromContext(r.Context())
|
|
if claims == nil {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
status, err := h.quota.Check(r.Context(), claims.Sub)
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
|
return
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, status)
|
|
}
|
|
|
|
func (h *Handler) ListModels(w http.ResponseWriter, r *http.Request) {
|
|
externalUserID, useOrg, ok := h.resolveAIAccess(r)
|
|
if !ok {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
var (
|
|
models []map[string]any
|
|
err error
|
|
)
|
|
if useOrg {
|
|
models, err = h.gateway.ListOrgModels(r.Context())
|
|
} else {
|
|
models, err = h.gateway.ListModels(r.Context(), externalUserID)
|
|
}
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, err.Error(), nil)
|
|
return
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
|
"object": "list",
|
|
"data": models,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) ChatCompletions(w http.ResponseWriter, r *http.Request) {
|
|
externalUserID, useOrg, ok := h.resolveAIAccess(r)
|
|
if !ok {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
body, err := io.ReadAll(io.LimitReader(r.Body, 8<<20))
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid body", nil)
|
|
return
|
|
}
|
|
subject := externalUserID
|
|
if useOrg {
|
|
subject = "openwebui-gateway"
|
|
}
|
|
if err := h.gateway.ProxyChatCompletions(r.Context(), subject, body, w); err != nil {
|
|
if errors.Is(err, ai.ErrQuotaExceeded) {
|
|
apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, err.Error(), nil)
|
|
return
|
|
}
|
|
apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (h *Handler) CreateSession(w http.ResponseWriter, r *http.Request) {
|
|
claims := middleware.ClaimsFromContext(r.Context())
|
|
if claims == nil {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
var req ai.SessionContext
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid json", nil)
|
|
return
|
|
}
|
|
preset := apitokens.ChatSessionStandalone
|
|
switch strings.ToLower(strings.TrimSpace(req.App)) {
|
|
case "mail":
|
|
preset = apitokens.ChatSessionMail
|
|
case "drive":
|
|
preset = apitokens.ChatSessionDrive
|
|
case "contacts":
|
|
preset = apitokens.ChatSessionContacts
|
|
case "docs":
|
|
preset = apitokens.ChatSessionDocs
|
|
}
|
|
allowWrite := preset == apitokens.ChatSessionDocs
|
|
created, err := apitokens.CreateChatSession(r.Context(), h.db, claims.Sub, claims.Email, apitokens.ChatSessionInput{
|
|
Preset: preset,
|
|
DrivePath: req.DrivePath,
|
|
AllowWrite: allowWrite,
|
|
})
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
|
return
|
|
}
|
|
policy, _ := ai.LoadAssistantPolicy(r.Context(), h.db)
|
|
publicPath := policy.PublicPath
|
|
if strings.TrimSpace(publicPath) == "" {
|
|
publicPath = "/ai"
|
|
}
|
|
if h.cfg != nil && strings.TrimSpace(h.cfg.AIAssistantPublicPath) != "" {
|
|
publicPath = h.cfg.AIAssistantPublicPath
|
|
}
|
|
temporary := req.Temporary || policy.EmbedDefaultTemporary
|
|
q := url.Values{}
|
|
if temporary {
|
|
q.Set("temporary-chat", "true")
|
|
}
|
|
if strings.TrimSpace(req.App) != "" {
|
|
q.Set("app", req.App)
|
|
}
|
|
embedURL := strings.TrimRight(publicPath, "/") + "/"
|
|
if enc := q.Encode(); enc != "" {
|
|
embedURL += "?" + enc
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, ai.SessionResponse{
|
|
SessionID: created.ID,
|
|
EmbedURL: embedURL,
|
|
TokenSecret: created.TokenSecret,
|
|
Temporary: temporary,
|
|
})
|
|
}
|
|
|
|
func (h *Handler) SyncChat(w http.ResponseWriter, r *http.Request) {
|
|
claims := middleware.ClaimsFromContext(r.Context())
|
|
if claims == nil {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
policy, _ := ai.LoadAssistantPolicy(r.Context(), h.db)
|
|
if !policy.ChatSyncEnabled {
|
|
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "chat sync disabled", nil)
|
|
return
|
|
}
|
|
var record ai.ChatRecord
|
|
if err := json.NewDecoder(r.Body).Decode(&record); err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid json", nil)
|
|
return
|
|
}
|
|
if err := h.chats.Sync(r.Context(), claims.Email, record); err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusBadGateway, apiresponse.CodeInternal, err.Error(), nil)
|
|
return
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
|
|
}
|
|
|
|
func (h *Handler) GetChat(w http.ResponseWriter, r *http.Request) {
|
|
claims := middleware.ClaimsFromContext(r.Context())
|
|
if claims == nil {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
chatID := chi.URLParam(r, "chatID")
|
|
record, err := h.chats.Get(r.Context(), claims.Email, chatID)
|
|
if err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, "chat not found", nil)
|
|
return
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, record)
|
|
}
|
|
|
|
func (h *Handler) DeleteChat(w http.ResponseWriter, r *http.Request) {
|
|
claims := middleware.ClaimsFromContext(r.Context())
|
|
if claims == nil {
|
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthUnauthorized, "unauthorized", nil)
|
|
return
|
|
}
|
|
chatID := chi.URLParam(r, "chatID")
|
|
if err := h.chats.Delete(r.Context(), claims.Email, chatID); err != nil {
|
|
apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, "chat not found", nil)
|
|
return
|
|
}
|
|
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
|
|
}
|
|
|
|
func (h *Handler) resolveAIAccess(r *http.Request) (externalUserID string, useOrgSettings bool, ok bool) {
|
|
if token := bearerToken(r); token != "" {
|
|
if h.cfg != nil && h.cfg.AIGatewayAPIKey != "" && token == h.cfg.AIGatewayAPIKey {
|
|
return "", true, true
|
|
}
|
|
if h.verify != nil && h.verify.Ready() {
|
|
claims, err := h.verify.Verify(r.Context(), token)
|
|
if err == nil && claims != nil && strings.TrimSpace(claims.Sub) != "" {
|
|
return claims.Sub, false, true
|
|
}
|
|
}
|
|
}
|
|
if claims := middleware.ClaimsFromContext(r.Context()); claims != nil && strings.TrimSpace(claims.Sub) != "" {
|
|
return claims.Sub, false, true
|
|
}
|
|
return "", false, false
|
|
}
|
|
|
|
// bearerToken extracts the credential from an "Authorization: Bearer <token>"
// header. The auth-scheme is matched case-insensitively, as RFC 7235 §2.1
// requires, and surrounding whitespace is trimmed. It returns "" when the
// header is absent, uses another scheme, or carries no token.
func bearerToken(r *http.Request) string {
	header := strings.TrimSpace(r.Header.Get("Authorization"))
	scheme, credentials, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	return strings.TrimSpace(credentials)
}
|
|
|
|
func (h *Handler) resolveClaims(r *http.Request) (*auth.Claims, bool) {
|
|
if header := strings.TrimSpace(r.Header.Get("Authorization")); strings.HasPrefix(header, "Bearer ") {
|
|
token := strings.TrimSpace(strings.TrimPrefix(header, "Bearer "))
|
|
if h.verify != nil && h.verify.Ready() {
|
|
claims, err := h.verify.Verify(r.Context(), token)
|
|
if err == nil {
|
|
return claims, true
|
|
}
|
|
}
|
|
}
|
|
if cookie, err := r.Cookie(sessionAccessCookie); err == nil {
|
|
token := strings.TrimSpace(cookie.Value)
|
|
if token != "" && h.verify != nil && h.verify.Ready() {
|
|
claims, err := h.verify.Verify(r.Context(), token)
|
|
if err == nil {
|
|
return claims, true
|
|
}
|
|
}
|
|
}
|
|
return nil, false
|
|
}
|