package ai

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/ultisuite/ulti-backend/internal/ai/cost"
	"github.com/ultisuite/ulti-backend/internal/llm"
)

// Gateway proxies chat-completion and model-listing requests to the
// configured LLM providers, checking the assistant model policy and the
// caller's quota before forwarding and recording usage afterwards.
type Gateway struct {
	db     *pgxpool.Pool
	client *http.Client
	quota  *QuotaService
}

// NewGateway returns a Gateway backed by db.
//
// The upstream HTTP client deliberately has no overall Timeout: a streamed
// completion may stay open for a long time, and a client-level timeout would
// cut it off mid-stream. Upstream requests are bounded by the caller's
// context instead (see ProxyChatCompletions).
func NewGateway(db *pgxpool.Pool) *Gateway {
	return &Gateway{
		db: db,
		client: &http.Client{
			Timeout: 0,
		},
		quota: NewQuotaService(db),
	}
}

// chatCompletionRequest is the request body shape for an upstream
// /chat/completions call. It is not referenced in this file.
type chatCompletionRequest struct {
	Model       string            `json:"model"`
	Messages    []llm.ChatMessage `json:"messages"`
	Temperature *float64          `json:"temperature,omitempty"`
	Stream      bool              `json:"stream,omitempty"`
	Tools       []any             `json:"tools,omitempty"`
}

// chatCompletionResponse is the response body shape of an upstream
// /chat/completions call, including optional token-usage details and an
// error object. It is not referenced in this file.
type chatCompletionResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index        int              `json:"index"`
		Message      llm.ChatMessage  `json:"message"`
		FinishReason string           `json:"finish_reason"`
		Delta        *llm.ChatMessage `json:"delta,omitempty"`
	} `json:"choices"`
	Usage *struct {
		PromptTokens        int `json:"prompt_tokens"`
		CompletionTokens    int `json:"completion_tokens"`
		TotalTokens         int `json:"total_tokens"`
		PromptTokensDetails *struct {
			CachedTokens int `json:"cached_tokens"`
		} `json:"prompt_tokens_details,omitempty"`
		CompletionTokensDetails *struct {
			ReasoningTokens int `json:"reasoning_tokens"`
		} `json:"completion_tokens_details,omitempty"`
	} `json:"usage,omitempty"`
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}

// ListModels lists the models available under externalUserID's effective
// LLM settings.
func (g *Gateway) ListModels(ctx context.Context, externalUserID string) ([]map[string]any, error) {
	settings, err := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
	if err != nil {
		return nil, err
	}
	return g.listModelsFromSettings(ctx, settings)
}

// ListOrgModels lists the models available under the organization-level LLM
// settings.
func (g *Gateway) ListOrgModels(ctx context.Context) ([]map[string]any, error) {
	settings, err := LoadOrgLLMSettings(ctx, g.db)
	if err != nil {
		return nil, err
	}
	return g.listModelsFromSettings(ctx, settings)
}

// listModelsFromSettings asks every configured provider for its models and
// returns one entry per distinct model ID ("id", "object", "owned_by"); when
// two providers report the same ID, the first provider wins.
//
// Listing is best effort: a provider whose listing fails is skipped. If no
// model could be listed at all, the first provider's DefaultModel (when set)
// is offered as a fallback. The result is passed through ApplyModelCatalog
// with the assistant policy's model list. The returned slice is never nil,
// so it encodes as a JSON array even when empty.
func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Settings) ([]map[string]any, error) {
	// NOTE(review): a policy load error is ignored and the zero-value policy
	// is used for the catalog.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	client := llm.NewClient()
	seen := make(map[string]struct{})
	out := make([]map[string]any, 0)
	for _, provider := range settings.Providers {
		models, err := client.ListModels(ctx, provider)
		if err != nil {
			continue
		}
		for _, modelID := range models {
			if _, ok := seen[modelID]; ok {
				continue
			}
			seen[modelID] = struct{}{}
			out = append(out, map[string]any{
				"id":       modelID,
				"object":   "model",
				"owned_by": provider.Name,
			})
		}
	}
	if len(out) == 0 && len(settings.Providers) > 0 {
		p := settings.Providers[0]
		if model := strings.TrimSpace(p.DefaultModel); model != "" {
			out = append(out, map[string]any{
				"id":       model,
				"object":   "model",
				"owned_by": p.Name,
			})
		}
	}
	return ApplyModelCatalog(out, policy.Models), nil
}

// ProxyChatCompletions forwards a raw chat-completion request body to the
// provider that serves the requested model and relays the reply to w.
//
// quotaExternalUserID identifies the user whose quota is checked before the
// call and charged afterwards; when it is blank, quota is neither checked nor
// recorded. useOrgSettings selects organization-level provider settings
// instead of the user's effective settings.
//
// Validation, policy, settings, quota and transport failures are returned
// before anything is written to w. An upstream HTTP error status is relayed
// to the client as-is and nil is returned. For streamed requests, errors that
// occur after the response has started are returned but can no longer change
// the status sent to the client.
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
	var modelProbe struct {
		Model string `json:"model"`
	}
	if err := json.Unmarshal(body, &modelProbe); err != nil {
		return fmt.Errorf("invalid chat completion request: %w", err)
	}

	// NOTE(review): a policy load error is ignored, so the zero-value policy
	// is checked; whether that allows everything depends on IsModelAllowed.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	if !IsModelAllowed(modelProbe.Model, policy.Models) {
		return fmt.Errorf("model %q is not allowed", strings.TrimSpace(modelProbe.Model))
	}

	var (
		settings llm.Settings
		err      error
	)
	if useOrgSettings {
		settings, err = LoadOrgLLMSettings(ctx, g.db)
	} else {
		settings, err = LoadEffectiveLLMSettings(ctx, g.db, quotaExternalUserID)
	}
	if err != nil {
		return err
	}

	provider, model, err := resolveProviderForModel(settings, modelProbe.Model)
	if err != nil {
		return err
	}
	billingScope := ResolveBillingScope(ctx, g.db, quotaExternalUserID, provider, useOrgSettings)
	trackQuota := strings.TrimSpace(quotaExternalUserID) != ""
	if trackQuota {
		if err := g.quota.AssertAvailable(ctx, quotaExternalUserID, provider, useOrgSettings); err != nil {
			return err
		}
	}

	upstreamBody, err := repairChatCompletionBody(body)
	if err != nil {
		return err
	}
	stream := chatCompletionStreamRequested(upstreamBody)

	// The upstream must receive the model that was actually resolved. When
	// the client omitted the model, padded it with whitespace, or named a
	// provider ID instead of a model, resolveProviderForModel returns a model
	// string the original body does not contain.
	if strings.TrimSpace(modelProbe.Model) == "" || model != modelProbe.Model {
		upstreamBody = withChatCompletionModel(upstreamBody, model)
	}

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(upstreamBody))
	if err != nil {
		return err
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	if apiKey := strings.TrimSpace(provider.APIKey); apiKey != "" {
		upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	}

	resp, err := g.client.Do(upstreamReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if stream {
		return g.proxyStream(ctx, quotaExternalUserID, model, provider, billingScope, w, resp)
	}

	// Buffer at most maxPayloadBytes. Reading one extra byte detects an
	// oversized reply so that truncated (invalid) JSON is never relayed.
	const maxPayloadBytes = 8 << 20
	payload, err := io.ReadAll(io.LimitReader(resp.Body, maxPayloadBytes+1))
	if err != nil {
		return err
	}
	if len(payload) > maxPayloadBytes {
		return fmt.Errorf("upstream response exceeds %d bytes", maxPayloadBytes)
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(resp.StatusCode)
	_, _ = w.Write(payload)
	if resp.StatusCode >= 400 {
		return nil
	}

	if trackQuota {
		// Best effort: the reply has already been delivered, so a failure to
		// record usage is deliberately not surfaced to the client.
		_ = g.quota.RecordUsage(ctx, cost.RecordInput{
			ExternalUserID:         quotaExternalUserID,
			Feature:                "gateway",
			ModelID:                model,
			ProviderID:             provider.ID,
			BillingScope:           billingScope,
			ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
			Usage:                  cost.ParseUsage(payload),
		})
	}
	return nil
}

// withChatCompletionModel returns body with its top-level "model" field set
// to model. The other fields are carried over as raw JSON, so their values
// are not re-encoded (large integers, for example, keep full precision). If
// body is not a JSON object, or re-encoding fails, body is returned unchanged.
func withChatCompletionModel(body []byte, model string) []byte {
	var req map[string]json.RawMessage
	// A JSON null decodes into a nil map; writing to it would panic.
	if err := json.Unmarshal(body, &req); err != nil || req == nil {
		return body
	}
	encodedModel, err := json.Marshal(model)
	if err != nil {
		return body
	}
	req["model"] = encodedModel
	patched, err := json.Marshal(req)
	if err != nil {
		return body
	}
	return patched
}

// proxyStream relays an upstream server-sent-event stream to w line by line,
// flushing after every line. It accumulates token usage from `data:` payloads
// and, for a successful stream with a non-blank quotaExternalUserID, records
// that usage (at least one token) once the upstream body reaches EOF.
//
// A cancelled ctx, a write or flush failure, or an upstream read error ends
// the relay early and is returned; no usage is recorded in that case.
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID, model string, provider llm.Provider, billingScope string, w http.ResponseWriter, resp *http.Response) error {
	rc := http.NewResponseController(w)

	contentType := "text/event-stream"
	if resp.StatusCode >= 400 {
		// An upstream failure may be a plain error document rather than an
		// event stream; keep its content type so the client parses it.
		if ct := resp.Header.Get("Content-Type"); ct != "" {
			contentType = ct
		}
	}
	header := w.Header()
	header.Set("Content-Type", contentType)
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	header.Set("X-Accel-Buffering", "no")
	w.WriteHeader(resp.StatusCode)

	reader := bufio.NewReader(resp.Body)
	var usage cost.UsageDetail
	for {
		line, readErr := reader.ReadString('\n')
		if line != "" {
			if _, err := io.WriteString(w, line); err != nil {
				return fmt.Errorf("writing stream chunk: %w", err)
			}
			if err := rc.Flush(); err != nil {
				if errors.Is(err, http.ErrNotSupported) {
					return fmt.Errorf("streaming not supported: %w", err)
				}
				// Typically the client went away mid-stream.
				return fmt.Errorf("flushing stream chunk: %w", err)
			}
			// Compare the payload exactly: a content chunk that merely
			// contains the text "[DONE]" must still be scanned for usage.
			if data, ok := strings.CutPrefix(strings.TrimSpace(line), "data:"); ok {
				data = strings.TrimSpace(data)
				if data != "" && data != "[DONE]" {
					usage = cost.MergeStreamUsage(usage, []byte(data))
				}
			}
		}
		if readErr != nil {
			if errors.Is(readErr, io.EOF) {
				break
			}
			return readErr
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
	}

	if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" {
		// A successful stream is always charged at least one token, even if
		// the upstream never reported usage.
		if usage.TotalTokens == 0 {
			usage.TotalTokens = 1
		}
		// Best effort: the stream has already been delivered.
		_ = g.quota.RecordUsage(ctx, cost.RecordInput{
			ExternalUserID:         quotaExternalUserID,
			Feature:                "gateway",
			ModelID:                model,
			ProviderID:             provider.ID,
			BillingScope:           billingScope,
			ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
			Usage:                  usage,
		})
	}
	return nil
}

// resolveProviderForModel picks the provider and concrete model for a
// request. A (trimmed) model string equal to a provider ID selects that
// provider and its trimmed DefaultModel. Otherwise the settings' default
// provider is used, and a non-empty model string overrides the model that
// llm.ResolveProvider chose.
func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider, string, error) {
	model = strings.TrimSpace(model)
	if model != "" {
		for _, p := range settings.Providers {
			if p.ID == model {
				return p, strings.TrimSpace(p.DefaultModel), nil
			}
		}
	}
	provider, resolvedModel, err := llm.ResolveProvider(settings, settings.DefaultProviderID)
	if err != nil {
		return llm.Provider{}, "", err
	}
	if model != "" {
		resolvedModel = model
	}
	return provider, resolvedModel, nil
}

// NowUnix returns the current time as Unix seconds.
func NowUnix() int64 {
	return time.Now().Unix()
}