package ai

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/ultisuite/ulti-backend/internal/llm"
)

// Gateway proxies OpenAI-compatible model-listing and chat-completion
// requests to the LLM providers loaded from the database, checks requested
// models against the assistant policy, and records token usage per external
// user through a QuotaService.
type Gateway struct {
	db     *pgxpool.Pool
	client *http.Client
	quota  *QuotaService
}

// NewGateway builds a Gateway backed by db.
//
// The HTTP client intentionally has no overall Timeout: http.Client.Timeout
// also covers reading the response body, which would cut long streaming
// responses short. Upstream calls are bounded by the request context instead.
func NewGateway(db *pgxpool.Pool) *Gateway {
	return &Gateway{
		db: db,
		client: &http.Client{
			Timeout: 0,
		},
		quota: NewQuotaService(db),
	}
}

// chatCompletionRequest holds the request fields the gateway itself reads.
// ProxyChatCompletions forwards the caller's original JSON object upstream
// (rewriting only "model"), so fields not listed here are still passed on.
type chatCompletionRequest struct {
	Model       string            `json:"model"`
	Messages    []llm.ChatMessage `json:"messages"`
	Temperature *float64          `json:"temperature,omitempty"`
	Stream      bool              `json:"stream,omitempty"`
	Tools       []any             `json:"tools,omitempty"`
}

// usagePayload is the token-usage object of a chat completion response.
type usagePayload struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

// chatCompletionResponse mirrors an OpenAI-compatible chat completion
// response or streaming chunk (Delta is set on chunks).
type chatCompletionResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []struct {
		Index        int              `json:"index"`
		Message      llm.ChatMessage  `json:"message"`
		FinishReason string           `json:"finish_reason"`
		Delta        *llm.ChatMessage `json:"delta,omitempty"`
	} `json:"choices"`
	Usage *usagePayload `json:"usage,omitempty"`
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}

// ListModels returns the models available to externalUserID, based on that
// user's effective LLM settings and filtered through the policy model catalog.
func (g *Gateway) ListModels(ctx context.Context, externalUserID string) ([]map[string]any, error) {
	settings, err := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
	if err != nil {
		return nil, err
	}
	return g.listModelsFromSettings(ctx, settings)
}

// ListOrgModels returns the models available under the organisation-level
// LLM settings, filtered through the policy model catalog.
func (g *Gateway) ListOrgModels(ctx context.Context) ([]map[string]any, error) {
	settings, err := LoadOrgLLMSettings(ctx, g.db)
	if err != nil {
		return nil, err
	}
	return g.listModelsFromSettings(ctx, settings)
}

// listModelsFromSettings queries every provider in settings for its models,
// de-duplicates them by ID (first provider wins), and returns OpenAI-style
// model entries ("id", "object", "owned_by") after ApplyModelCatalog.
//
// If no provider yields any model, the first provider's DefaultModel (when
// non-blank) is advertised instead. The returned error is always nil.
func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Settings) ([]map[string]any, error) {
	// Best effort: a policy load error is ignored and the zero-value policy's
	// catalog is applied.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	client := llm.NewClient()

	seen := make(map[string]struct{})
	// Non-nil so an empty result encodes as [] rather than null.
	out := make([]map[string]any, 0)
	for _, provider := range settings.Providers {
		models, err := client.ListModels(ctx, provider)
		if err != nil {
			// One unreachable provider must not hide the others' models.
			continue
		}
		for _, modelID := range models {
			if _, dup := seen[modelID]; dup {
				continue
			}
			seen[modelID] = struct{}{}
			out = append(out, map[string]any{
				"id":       modelID,
				"object":   "model",
				"owned_by": provider.Name,
			})
		}
	}

	if len(out) == 0 && len(settings.Providers) > 0 {
		p := settings.Providers[0]
		if model := strings.TrimSpace(p.DefaultModel); model != "" {
			out = append(out, map[string]any{
				"id":       model,
				"object":   "model",
				"owned_by": p.Name,
			})
		}
	}
	return ApplyModelCatalog(out, policy.Models), nil
}

// ProxyChatCompletions relays an OpenAI-compatible chat completion request
// body to the provider chosen for its model and writes the upstream status
// and body to w. Streaming requests ("stream": true) are relayed chunk by
// chunk via proxyStream.
//
// Before any upstream call it checks the caller's quota and the policy model
// allow-list; those failures, invalid JSON, settings/provider resolution
// errors and transport errors are returned without writing to w. Upstream
// HTTP error statuses are relayed to w and reported as a nil error. Token
// usage of successful responses is recorded against externalUserID.
func (g *Gateway) ProxyChatCompletions(ctx context.Context, externalUserID string, body []byte, w http.ResponseWriter) error {
	if err := g.quota.AssertAvailable(ctx, externalUserID); err != nil {
		return err
	}

	var req chatCompletionRequest
	if err := json.Unmarshal(body, &req); err != nil {
		return fmt.Errorf("invalid chat completion request: %w", err)
	}
	// Keep the caller's full JSON object (max_tokens, tool_choice,
	// response_format, stream_options, ...) so it reaches the provider
	// intact; re-encoding chatCompletionRequest would silently drop them.
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil {
		return fmt.Errorf("invalid chat completion request: %w", err)
	}
	if fields == nil {
		// A literal `null` body decodes without error into a nil map.
		fields = make(map[string]json.RawMessage)
	}

	// NOTE(review): a policy load error is ignored and the zero-value policy
	// is checked instead; confirm IsModelAllowed fails closed in that case.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	if !IsModelAllowed(req.Model, policy.Models) {
		return fmt.Errorf("model %q is not allowed", strings.TrimSpace(req.Model))
	}

	settings, err := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
	if err != nil {
		return err
	}
	provider, model, err := resolveProviderForModel(settings, req.Model)
	if err != nil {
		return err
	}
	// Always send the resolved model: when the caller named a provider ID it
	// is mapped to that provider's default model, and the raw ID is not a
	// model name the upstream would recognise.
	if model != "" {
		req.Model = model
	}
	modelJSON, err := json.Marshal(req.Model)
	if err != nil {
		return err
	}
	fields["model"] = modelJSON

	upstreamBody, err := json.Marshal(fields)
	if err != nil {
		return err
	}

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(upstreamBody))
	if err != nil {
		return err
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	if apiKey := strings.TrimSpace(provider.APIKey); apiKey != "" {
		upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	}

	resp, err := g.client.Do(upstreamReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if req.Stream {
		return g.proxyStream(ctx, externalUserID, w, resp)
	}

	// Non-streaming responses are buffered (capped at 8 MiB) so usage can be
	// parsed after relaying them.
	payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
	if err != nil {
		return err
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(resp.StatusCode)
	_, _ = w.Write(payload)
	if resp.StatusCode >= 400 {
		return nil
	}
	// Best effort: the response is already written, so a recording failure
	// cannot be reported to the caller.
	_ = g.quota.Record(ctx, externalUserID, extractUsageTokens(payload))
	return nil
}

// proxyStream copies the upstream response body to w line by line, flushing
// after each line, and records token usage for externalUserID once the
// upstream body ends with a status below 400.
//
// It returns an error if w cannot flush, if reading the upstream body fails,
// or if ctx is cancelled; in those cases no usage is recorded.
// NOTE(review): tokens already streamed before such an early return are not
// billed.
func (g *Gateway) proxyStream(ctx context.Context, externalUserID string, w http.ResponseWriter, resp *http.Response) error {
	flusher, ok := w.(http.Flusher)
	if !ok {
		return errors.New("streaming not supported")
	}

	contentType := "text/event-stream"
	if resp.StatusCode >= 400 {
		// An upstream error may be a plain JSON body rather than an event
		// stream; relay the upstream media type so clients can parse it.
		if ct := resp.Header.Get("Content-Type"); ct != "" {
			contentType = ct
		}
	}
	w.Header().Set("Content-Type", contentType)
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.WriteHeader(resp.StatusCode)

	reader := bufio.NewReader(resp.Body)
	var usageTokens int64
	for {
		line, err := reader.ReadString('\n')
		if len(line) > 0 {
			_, _ = io.WriteString(w, line)
			flusher.Flush()
			if strings.HasPrefix(line, "data:") {
				data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
				if data != "[DONE]" {
					// Usage normally arrives once, in the final chunk, but some
					// OpenAI-compatible servers repeat it as running totals;
					// keeping the maximum is correct for both, summing is not.
					if n := extractStreamUsageTokens([]byte(data)); n > usageTokens {
						usageTokens = n
					}
				}
			}
		}
		if err != nil {
			if errors.Is(err, io.EOF) {
				break
			}
			return err
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
	}

	if resp.StatusCode < 400 {
		// Charge at least one token so a successful stream that carried no
		// usage data is still counted.
		if usageTokens == 0 {
			usageTokens = 1
		}
		_ = g.quota.Record(ctx, externalUserID, usageTokens)
	}
	return nil
}

// resolveProviderForModel chooses the upstream provider for a requested model.
//
// If the trimmed model equals a provider's ID, that provider is returned with
// its trimmed DefaultModel. Otherwise the provider for
// settings.DefaultProviderID (via llm.ResolveProvider) is returned, with the
// trimmed requested model, or llm.ResolveProvider's model when none was given.
func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider, string, error) {
	model = strings.TrimSpace(model)
	if model != "" {
		for _, p := range settings.Providers {
			if p.ID == model {
				return p, strings.TrimSpace(p.DefaultModel), nil
			}
		}
	}

	provider, resolvedModel, err := llm.ResolveProvider(settings, settings.DefaultProviderID)
	if err != nil {
		return llm.Provider{}, "", err
	}
	if model != "" {
		resolvedModel = model
	}
	return provider, resolvedModel, nil
}

// extractUsageTokens returns the token count to charge for a non-streaming
// response body: usage.total_tokens if positive, else usage.completion_tokens
// if positive, else 1 (also when the body is not valid JSON).
//
// Only the "usage" object is decoded, so an unexpected shape elsewhere in the
// response (e.g. in "choices") does not hide the usage numbers.
func extractUsageTokens(payload []byte) int64 {
	var parsed struct {
		Usage *usagePayload `json:"usage"`
	}
	if err := json.Unmarshal(payload, &parsed); err != nil {
		return 1
	}
	switch u := parsed.Usage; {
	case u != nil && u.TotalTokens > 0:
		return int64(u.TotalTokens)
	case u != nil && u.CompletionTokens > 0:
		return int64(u.CompletionTokens)
	default:
		return 1
	}
}

// extractStreamUsageTokens returns the usage reported by one streaming chunk
// (the JSON after "data:"): usage.total_tokens if positive, else
// usage.completion_tokens if positive, else 0 (also for invalid JSON).
// Only the "usage" object is decoded.
func extractStreamUsageTokens(payload []byte) int64 {
	var chunk struct {
		Usage *usagePayload `json:"usage"`
	}
	if err := json.Unmarshal(payload, &chunk); err != nil {
		return 0
	}
	switch u := chunk.Usage; {
	case u != nil && u.TotalTokens > 0:
		return int64(u.TotalTokens)
	case u != nil && u.CompletionTokens > 0:
		return int64(u.CompletionTokens)
	default:
		return 0
	}
}

// NowUnix returns the current time in Unix seconds.
func NowUnix() int64 {
	return time.Now().Unix()
}