- Updated .env.example to include new configuration options for the UltiAI branding and API endpoints.
- Enhanced Nginx configuration to support new API routes for the MCP and WebSocket connections.
- Introduced sub-filters for branding adjustments in Nginx responses.
- Added a new JavaScript patch for API endpoint adjustments.
- Implemented tests for new API functionalities and improved error handling in the AI gateway.
293 lines
7.5 KiB
Go
293 lines
7.5 KiB
Go
package ai
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/llm"
|
|
)
|
|
|
|
type Gateway struct {
|
|
db *pgxpool.Pool
|
|
client *http.Client
|
|
quota *QuotaService
|
|
}
|
|
|
|
func NewGateway(db *pgxpool.Pool) *Gateway {
|
|
return &Gateway{
|
|
db: db,
|
|
client: &http.Client{
|
|
Timeout: 0,
|
|
},
|
|
quota: NewQuotaService(db),
|
|
}
|
|
}
|
|
|
|
type chatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []llm.ChatMessage `json:"messages"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
Stream bool `json:"stream,omitempty"`
|
|
Tools []any `json:"tools,omitempty"`
|
|
}
|
|
|
|
// usagePayload holds the token counters found under the "usage" key of a chat
// completion response.
type usagePayload struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}
|
|
|
|
type chatCompletionResponse struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []struct {
|
|
Index int `json:"index"`
|
|
Message llm.ChatMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
Delta *llm.ChatMessage `json:"delta,omitempty"`
|
|
} `json:"choices"`
|
|
Usage *usagePayload `json:"usage,omitempty"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
func (g *Gateway) ListModels(ctx context.Context, externalUserID string) ([]map[string]any, error) {
|
|
settings, err := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return g.listModelsFromSettings(ctx, settings)
|
|
}
|
|
|
|
func (g *Gateway) ListOrgModels(ctx context.Context) ([]map[string]any, error) {
|
|
settings, err := LoadOrgLLMSettings(ctx, g.db)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return g.listModelsFromSettings(ctx, settings)
|
|
}
|
|
|
|
func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Settings) ([]map[string]any, error) {
|
|
policy, _ := LoadAssistantPolicy(ctx, g.db)
|
|
client := llm.NewClient()
|
|
seen := make(map[string]struct{})
|
|
out := make([]map[string]any, 0)
|
|
for _, provider := range settings.Providers {
|
|
models, err := client.ListModels(ctx, provider)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
for _, modelID := range models {
|
|
if _, ok := seen[modelID]; ok {
|
|
continue
|
|
}
|
|
seen[modelID] = struct{}{}
|
|
out = append(out, map[string]any{
|
|
"id": modelID,
|
|
"object": "model",
|
|
"owned_by": provider.Name,
|
|
})
|
|
}
|
|
}
|
|
if len(out) == 0 && len(settings.Providers) > 0 {
|
|
p := settings.Providers[0]
|
|
model := strings.TrimSpace(p.DefaultModel)
|
|
if model != "" {
|
|
out = append(out, map[string]any{
|
|
"id": model,
|
|
"object": "model",
|
|
"owned_by": p.Name,
|
|
})
|
|
}
|
|
}
|
|
return ApplyModelCatalog(out, policy.Models), nil
|
|
}
|
|
|
|
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
|
|
if strings.TrimSpace(quotaExternalUserID) != "" {
|
|
if err := g.quota.AssertAvailable(ctx, quotaExternalUserID); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
var modelProbe struct {
|
|
Model string `json:"model"`
|
|
}
|
|
if err := json.Unmarshal(body, &modelProbe); err != nil {
|
|
return fmt.Errorf("invalid chat completion request: %w", err)
|
|
}
|
|
|
|
policy, _ := LoadAssistantPolicy(ctx, g.db)
|
|
if !IsModelAllowed(modelProbe.Model, policy.Models) {
|
|
return fmt.Errorf("model %q is not allowed", strings.TrimSpace(modelProbe.Model))
|
|
}
|
|
|
|
var (
|
|
settings llm.Settings
|
|
err error
|
|
)
|
|
if useOrgSettings {
|
|
settings, err = LoadOrgLLMSettings(ctx, g.db)
|
|
} else {
|
|
settings, err = LoadEffectiveLLMSettings(ctx, g.db, quotaExternalUserID)
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
provider, model, err := resolveProviderForModel(settings, modelProbe.Model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
upstreamBody, err := repairChatCompletionBody(body)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stream := chatCompletionStreamRequested(upstreamBody)
|
|
if strings.TrimSpace(modelProbe.Model) == "" {
|
|
var req map[string]any
|
|
if err := json.Unmarshal(upstreamBody, &req); err == nil {
|
|
req["model"] = model
|
|
if patched, err := json.Marshal(req); err == nil {
|
|
upstreamBody = patched
|
|
}
|
|
}
|
|
}
|
|
baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
|
|
upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(upstreamBody))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
upstreamReq.Header.Set("Content-Type", "application/json")
|
|
if strings.TrimSpace(provider.APIKey) != "" {
|
|
upstreamReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(provider.APIKey))
|
|
}
|
|
|
|
resp, err := g.client.Do(upstreamReq)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if stream {
|
|
return g.proxyStream(ctx, quotaExternalUserID, w, resp)
|
|
}
|
|
payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(resp.StatusCode)
|
|
_, _ = w.Write(payload)
|
|
if resp.StatusCode >= 400 {
|
|
return nil
|
|
}
|
|
if strings.TrimSpace(quotaExternalUserID) != "" {
|
|
tokens := extractUsageTokens(payload)
|
|
_ = g.quota.Record(ctx, quotaExternalUserID, tokens)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w http.ResponseWriter, resp *http.Response) error {
|
|
rc := http.NewResponseController(w)
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no")
|
|
w.WriteHeader(resp.StatusCode)
|
|
|
|
reader := bufio.NewReader(resp.Body)
|
|
var totalTokens int64
|
|
for {
|
|
line, err := reader.ReadString('\n')
|
|
if len(line) > 0 {
|
|
_, _ = w.Write([]byte(line))
|
|
if err := rc.Flush(); err != nil {
|
|
return fmt.Errorf("streaming not supported: %w", err)
|
|
}
|
|
if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
|
|
totalTokens += extractStreamUsageTokens([]byte(strings.TrimPrefix(strings.TrimSpace(line), "data: ")))
|
|
}
|
|
}
|
|
if err != nil {
|
|
if err == io.EOF {
|
|
break
|
|
}
|
|
return err
|
|
}
|
|
select {
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
default:
|
|
}
|
|
}
|
|
if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" {
|
|
if totalTokens == 0 {
|
|
totalTokens = 1
|
|
}
|
|
_ = g.quota.Record(ctx, quotaExternalUserID, totalTokens)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider, string, error) {
|
|
model = strings.TrimSpace(model)
|
|
if model != "" {
|
|
for _, p := range settings.Providers {
|
|
if p.ID == model {
|
|
return p, strings.TrimSpace(p.DefaultModel), nil
|
|
}
|
|
}
|
|
}
|
|
provider, resolvedModel, err := llm.ResolveProvider(settings, settings.DefaultProviderID)
|
|
if err != nil {
|
|
return llm.Provider{}, "", err
|
|
}
|
|
if model != "" {
|
|
resolvedModel = model
|
|
}
|
|
return provider, resolvedModel, nil
|
|
}
|
|
|
|
// extractUsageTokens returns the token count to record for a complete
// (non-streamed) chat completion response body.
//
// It prefers usage.total_tokens, falls back to usage.completion_tokens, and
// returns 1 when the payload cannot be decoded or carries no positive count.
//
// Only the "usage" object is decoded. Decoding the whole response would make
// a type mismatch in an unrelated field (for example inside "choices") surface
// as an error and discard usage figures that were parsed correctly.
func extractUsageTokens(payload []byte) int64 {
	var parsed struct {
		Usage *struct {
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(payload, &parsed); err != nil || parsed.Usage == nil {
		return 1
	}
	switch usage := parsed.Usage; {
	case usage.TotalTokens > 0:
		return int64(usage.TotalTokens)
	case usage.CompletionTokens > 0:
		return int64(usage.CompletionTokens)
	default:
		return 1
	}
}
|
|
|
|
// extractStreamUsageTokens returns usage.total_tokens from one streamed chunk
// payload, or 0 when the payload cannot be decoded or carries no positive
// total.
//
// Only the "usage" object is decoded. Decoding the whole chunk would make a
// type mismatch in an unrelated field (for example inside "choices") surface
// as an error and discard a usage figure that was parsed correctly.
func extractStreamUsageTokens(payload []byte) int64 {
	var parsed struct {
		Usage *struct {
			TotalTokens int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(payload, &parsed); err != nil {
		return 0
	}
	if parsed.Usage == nil || parsed.Usage.TotalTokens <= 0 {
		return 0
	}
	return int64(parsed.Usage.TotalTokens)
}
|
|
|
|
// NowUnix reports the current time as whole seconds since the Unix epoch.
func NowUnix() int64 {
	now := time.Now()
	return now.Unix()
}
|