ultisuite-backend/internal/ai/gateway.go
R3D347HR4Y 621b0099d6
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
feat(deploy): enhance Nginx configuration and API integration for UltiAI
- Updated .env.example to include new configuration options for the UltiAI branding and API endpoints.
- Enhanced Nginx configuration to support new API routes for the MCP and WebSocket connections.
- Introduced sub-filters for branding adjustments in Nginx responses.
- Added new JavaScript patch for API endpoint adjustments.
- Implemented tests for new API functionalities and improved error handling in the AI gateway.
2026-06-15 00:22:23 +02:00

293 lines
7.5 KiB
Go

package ai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/llm"
)
// Gateway bundles the dependencies shared by the AI gateway methods: a
// PostgreSQL connection pool, an HTTP client for upstream requests, and the
// quota service.
type Gateway struct {
	// db is handed to the settings and policy loaders.
	db *pgxpool.Pool
	// client sends the upstream chat-completion requests.
	client *http.Client
	// quota is checked before, and updated after, proxied requests.
	quota *QuotaService
}
// NewGateway builds a Gateway backed by db, wiring in a fresh HTTP client and
// a QuotaService created from the same pool.
func NewGateway(db *pgxpool.Pool) *Gateway {
	// A zero Timeout means the client itself never times a request out;
	// upstream calls are bounded only by the context attached to each request.
	// NOTE(review): presumably left unlimited so long-lived streamed
	// completions are not cut off — confirm.
	upstream := &http.Client{Timeout: 0}

	return &Gateway{
		db:     db,
		client: upstream,
		quota:  NewQuotaService(db),
	}
}
// chatCompletionRequest describes the JSON fields of a chat-completions
// request body. Optional fields are omitted from the encoded JSON when unset.
type chatCompletionRequest struct {
	Model    string            `json:"model"`
	Messages []llm.ChatMessage `json:"messages"`
	// Temperature is a pointer so that an absent value can be told apart
	// from an explicit 0.
	Temperature *float64 `json:"temperature,omitempty"`
	Stream      bool     `json:"stream,omitempty"`
	Tools       []any    `json:"tools,omitempty"`
}
// usagePayload mirrors the "usage" object of a chat-completion response:
// token counts for the prompt, the completion, and their reported total.
type usagePayload struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
// chatCompletionResponse describes the JSON fields of a chat-completion
// response. The same shape covers streamed chunks, whose choices carry a
// Delta, and error responses, which carry Error.
type chatCompletionResponse struct {
	ID      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`

	// Choices lists the generated alternatives.
	Choices []struct {
		Index        int              `json:"index"`
		Message      llm.ChatMessage  `json:"message"`
		FinishReason string           `json:"finish_reason"`
		Delta        *llm.ChatMessage `json:"delta,omitempty"`
	} `json:"choices"`

	// Usage is nil when the response carries no (or a null) "usage" object.
	Usage *usagePayload `json:"usage,omitempty"`

	// Error is nil unless the response carries an "error" object.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
// ListModels returns the model catalog built from the effective LLM settings
// loaded for externalUserID. A failure to load those settings is returned
// unchanged.
func (g *Gateway) ListModels(ctx context.Context, externalUserID string) ([]map[string]any, error) {
	userSettings, err := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
	if err != nil {
		return nil, err
	}

	return g.listModelsFromSettings(ctx, userSettings)
}
// ListOrgModels returns the model catalog built from the organization-level
// LLM settings. A failure to load those settings is returned unchanged.
func (g *Gateway) ListOrgModels(ctx context.Context) ([]map[string]any, error) {
	orgSettings, err := LoadOrgLLMSettings(ctx, g.db)
	if err != nil {
		return nil, err
	}

	return g.listModelsFromSettings(ctx, orgSettings)
}
// listModelsFromSettings asks every provider in settings for its models and
// returns them as {"id", "object", "owned_by"} entries, passed through
// ApplyModelCatalog together with the assistant policy's model list.
//
// A model ID is listed once, attributed to the first provider that reported
// it. Providers whose listing call fails are skipped. When no provider
// yields anything, the first provider's non-blank default model is offered
// as the only entry. The returned error is always nil.
func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Settings) ([]map[string]any, error) {
	// NOTE(review): the policy load error is discarded; what policy.Models
	// holds after a failed load is not visible from here.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	upstream := llm.NewClient()

	// Starts as an empty, non-nil slice.
	catalog := make([]map[string]any, 0)
	known := make(map[string]struct{})

	for _, provider := range settings.Providers {
		ids, err := upstream.ListModels(ctx, provider)
		if err != nil {
			// Best effort: one failing provider must not hide the others.
			continue
		}
		for _, id := range ids {
			if _, duplicate := known[id]; duplicate {
				continue
			}
			known[id] = struct{}{}
			catalog = append(catalog, map[string]any{
				"id":       id,
				"object":   "model",
				"owned_by": provider.Name,
			})
		}
	}

	// Nothing could be listed: fall back to the first provider's default model.
	if len(catalog) == 0 && len(settings.Providers) > 0 {
		first := settings.Providers[0]
		if fallback := strings.TrimSpace(first.DefaultModel); fallback != "" {
			catalog = append(catalog, map[string]any{
				"id":       fallback,
				"object":   "model",
				"owned_by": first.Name,
			})
		}
	}

	return ApplyModelCatalog(catalog, policy.Models), nil
}
// ProxyChatCompletions forwards a chat-completions request to the upstream
// provider selected for the requested model and relays the response to w.
//
// Parameters:
//   - quotaExternalUserID: user whose quota is checked before, and charged
//     after, the call. When blank, both the check and the recording are skipped.
//   - useOrgSettings: when true the organization LLM settings are used,
//     otherwise the effective settings for quotaExternalUserID.
//   - body: the raw JSON request body received from the client.
//   - w: receives the upstream status code and body (or the SSE stream).
//
// An error is returned, with nothing written to w, when the quota check
// fails, the body is not valid JSON, the model is not allowed, the settings
// or provider cannot be resolved, or the upstream request cannot be sent.
// Upstream responses with status >= 400 are relayed to w as-is, are not
// charged against the quota, and yield a nil error.
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
	metered := strings.TrimSpace(quotaExternalUserID) != ""
	if metered {
		if err := g.quota.AssertAvailable(ctx, quotaExternalUserID); err != nil {
			return err
		}
	}

	// Decode only the "model" field; the rest of the body is passed through.
	var modelProbe struct {
		Model string `json:"model"`
	}
	if err := json.Unmarshal(body, &modelProbe); err != nil {
		return fmt.Errorf("invalid chat completion request: %w", err)
	}
	requestedModel := strings.TrimSpace(modelProbe.Model)

	// NOTE(review): the policy load error is discarded; what policy.Models
	// holds after a failed load is not visible from here.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	if !IsModelAllowed(modelProbe.Model, policy.Models) {
		return fmt.Errorf("model %q is not allowed", requestedModel)
	}

	var (
		settings llm.Settings
		err      error
	)
	if useOrgSettings {
		settings, err = LoadOrgLLMSettings(ctx, g.db)
	} else {
		settings, err = LoadEffectiveLLMSettings(ctx, g.db, quotaExternalUserID)
	}
	if err != nil {
		return err
	}

	provider, model, err := resolveProviderForModel(settings, modelProbe.Model)
	if err != nil {
		return err
	}

	upstreamBody, err := repairChatCompletionBody(body)
	if err != nil {
		return err
	}
	stream := chatCompletionStreamRequested(upstreamBody)

	// Send the resolved model upstream whenever it differs from what the
	// client asked for: either no model was given, or the "model" field named
	// a provider and resolved to that provider's default model. Previously
	// only the blank case was rewritten, so a provider ID reached the
	// upstream API as if it were a model name.
	if requestedModel == "" || (model != "" && model != requestedModel) {
		// RawMessage keeps every other field's bytes untouched (no float64
		// coercion of numbers). The nil check covers a body that decodes to
		// JSON null, where writing into the map would panic.
		var fields map[string]json.RawMessage
		if err := json.Unmarshal(upstreamBody, &fields); err == nil && fields != nil {
			if encoded, err := json.Marshal(model); err == nil {
				fields["model"] = encoded
				if patched, err := json.Marshal(fields); err == nil {
					upstreamBody = patched
				}
			}
		}
	}

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(upstreamBody))
	if err != nil {
		return err
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	if apiKey := strings.TrimSpace(provider.APIKey); apiKey != "" {
		upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	}

	resp, err := g.client.Do(upstreamReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if stream {
		return g.proxyStream(ctx, quotaExternalUserID, w, resp)
	}

	// Non-streaming: buffer at most 8 MiB of the upstream body.
	// NOTE(review): a larger body is silently truncated at that limit.
	payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
	if err != nil {
		return err
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(resp.StatusCode)
	// The status line is already sent, so a failed write cannot be reported.
	_, _ = w.Write(payload)

	if resp.StatusCode >= 400 {
		return nil
	}
	if metered {
		// Best effort: a failure to record usage does not fail the request.
		_ = g.quota.Record(ctx, quotaExternalUserID, extractUsageTokens(payload))
	}
	return nil
}
// proxyStream relays an upstream server-sent-events response to w line by
// line, flushing after every line, and charges the reported token usage to
// quotaExternalUserID once the upstream body reaches EOF.
//
// Usage reported in stream chunks is a running total for the whole request.
// Some providers attach it to the last chunk only, others repeat a cumulative
// figure on every chunk, so the largest total seen is charged rather than the
// sum of all chunks (summing over-bills the second kind of provider).
//
// Nothing is charged when the upstream status is >= 400, when
// quotaExternalUserID is blank, or when the stream ends early through a read
// error, a flush error, or a cancelled ctx. When no usage was reported at
// all, 1 token is charged.
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w http.ResponseWriter, resp *http.Response) error {
	rc := http.NewResponseController(w)

	header := w.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	// Asks an Nginx-style reverse proxy not to buffer the response.
	header.Set("X-Accel-Buffering", "no")
	w.WriteHeader(resp.StatusCode)

	reader := bufio.NewReader(resp.Body)
	var usageTokens int64
	for {
		line, err := reader.ReadString('\n')
		if len(line) > 0 {
			// Write errors are ignored (best effort).
			_, _ = w.Write([]byte(line))
			if err := rc.Flush(); err != nil {
				return fmt.Errorf("streaming not supported: %w", err)
			}
			// SSE allows "data:" with or without a space after the colon.
			if strings.HasPrefix(line, "data:") {
				data := strings.TrimSpace(strings.TrimPrefix(line, "data:"))
				// "[DONE]" is the end-of-stream sentinel, not a JSON chunk.
				if data != "[DONE]" {
					if reported := extractStreamUsageTokens([]byte(data)); reported > usageTokens {
						usageTokens = reported
					}
				}
			}
		}
		if err != nil {
			if err == io.EOF {
				break
			}
			return err
		}
		// Stop relaying as soon as the request context is cancelled.
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
	}

	if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" {
		if usageTokens == 0 {
			// No usage reported: still count the request as one token.
			usageTokens = 1
		}
		// Best effort: a failure to record usage does not fail the request.
		_ = g.quota.Record(ctx, quotaExternalUserID, usageTokens)
	}
	return nil
}
// resolveProviderForModel picks the provider, and the model name to use with
// it, for a requested model string.
//
// Resolution order:
//  1. If the trimmed model equals a provider's ID, that provider is returned
//     together with its trimmed default model.
//  2. Otherwise the provider resolved by llm.ResolveProvider for
//     settings.DefaultProviderID is returned, with the trimmed requested
//     model or, when that is blank, the model llm.ResolveProvider supplied.
//
// An error from llm.ResolveProvider is returned with zero-valued results.
func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider, string, error) {
	requested := strings.TrimSpace(model)

	if requested != "" {
		for _, candidate := range settings.Providers {
			if candidate.ID == requested {
				return candidate, strings.TrimSpace(candidate.DefaultModel), nil
			}
		}
	}

	provider, fallbackModel, err := llm.ResolveProvider(settings, settings.DefaultProviderID)
	if err != nil {
		return llm.Provider{}, "", err
	}
	if requested == "" {
		return provider, fallbackModel, nil
	}
	return provider, requested, nil
}
// extractUsageTokens returns the number of tokens to charge for a
// non-streaming chat-completion response body.
//
// It returns usage.total_tokens when positive, otherwise
// usage.completion_tokens when positive, otherwise 1. The result is also 1
// when payload is not valid JSON or has no usage object, so every successful
// call is charged at least one token.
//
// Only the "usage" object is decoded. A type mismatch in an unrelated field
// (for example message content sent as an array) therefore no longer
// discards an otherwise valid usage report.
func extractUsageTokens(payload []byte) int64 {
	var parsed struct {
		Usage *struct {
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(payload, &parsed); err != nil || parsed.Usage == nil {
		return 1
	}

	switch {
	case parsed.Usage.TotalTokens > 0:
		return int64(parsed.Usage.TotalTokens)
	case parsed.Usage.CompletionTokens > 0:
		return int64(parsed.Usage.CompletionTokens)
	default:
		return 1
	}
}
// extractStreamUsageTokens returns usage.total_tokens from a single streamed
// chat-completion chunk, or 0 when the chunk is not valid JSON, carries no
// usage object, or reports a non-positive total.
//
// Only the "usage" object is decoded. A type mismatch in an unrelated field
// of the chunk therefore no longer hides a valid usage report.
func extractStreamUsageTokens(payload []byte) int64 {
	var parsed struct {
		Usage *struct {
			TotalTokens int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(payload, &parsed); err != nil {
		return 0
	}
	if parsed.Usage == nil || parsed.Usage.TotalTokens <= 0 {
		return 0
	}
	return int64(parsed.Usage.TotalTokens)
}
// NowUnix returns the current wall-clock time as Unix seconds.
func NowUnix() int64 {
	now := time.Now()
	return now.Unix()
}