- Refactored the AI gateway to use the new cost-management structures for usage tracking.
- Replaced the deprecated token-extraction methods with a unified cost-parsing approach.
- Improved the usage fallback mechanisms and added detailed usage metrics to responses.
- Added metering functionality to record AI usage and costs.
- Updated tests to reflect the changes to usage parsing and cost calculation.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
292 lines
8.0 KiB
Go
292 lines
8.0 KiB
Go
package ai
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/ai/cost"
|
|
"github.com/ultisuite/ulti-backend/internal/llm"
|
|
)
|
|
|
|
type Gateway struct {
|
|
db *pgxpool.Pool
|
|
client *http.Client
|
|
quota *QuotaService
|
|
}
|
|
|
|
// NewGateway returns a Gateway backed by db, with a QuotaService built on the
// same pool.
//
// The upstream HTTP client has no overall timeout (http.Client.Timeout is
// zero). This appears intentional: that timeout also covers reading the
// response body, so it would cut long streamed completions short. Upstream
// calls are instead tied to the request context passed to
// ProxyChatCompletions.
func NewGateway(db *pgxpool.Pool) *Gateway {
	gw := &Gateway{db: db}
	gw.client = &http.Client{} // zero Timeout on purpose, see above
	gw.quota = NewQuotaService(db)
	return gw
}
|
|
|
|
type chatCompletionRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []llm.ChatMessage `json:"messages"`
|
|
Temperature *float64 `json:"temperature,omitempty"`
|
|
Stream bool `json:"stream,omitempty"`
|
|
Tools []any `json:"tools,omitempty"`
|
|
}
|
|
|
|
type chatCompletionResponse struct {
|
|
ID string `json:"id"`
|
|
Object string `json:"object"`
|
|
Created int64 `json:"created"`
|
|
Model string `json:"model"`
|
|
Choices []struct {
|
|
Index int `json:"index"`
|
|
Message llm.ChatMessage `json:"message"`
|
|
FinishReason string `json:"finish_reason"`
|
|
Delta *llm.ChatMessage `json:"delta,omitempty"`
|
|
} `json:"choices"`
|
|
Usage *struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
PromptTokensDetails *struct {
|
|
CachedTokens int `json:"cached_tokens"`
|
|
} `json:"prompt_tokens_details,omitempty"`
|
|
CompletionTokensDetails *struct {
|
|
ReasoningTokens int `json:"reasoning_tokens"`
|
|
} `json:"completion_tokens_details,omitempty"`
|
|
} `json:"usage,omitempty"`
|
|
Error *struct {
|
|
Message string `json:"message"`
|
|
} `json:"error,omitempty"`
|
|
}
|
|
|
|
// ListModels returns the model list visible to the user identified by
// externalUserID, built from that user's effective LLM settings by
// listModelsFromSettings. An error loading the settings is returned unchanged.
func (g *Gateway) ListModels(ctx context.Context, externalUserID string) ([]map[string]any, error) {
	effective, loadErr := LoadEffectiveLLMSettings(ctx, g.db, externalUserID)
	if loadErr != nil {
		return nil, loadErr
	}
	return g.listModelsFromSettings(ctx, effective)
}
|
|
|
|
// ListOrgModels returns the model list derived from the organisation-level LLM
// settings, via listModelsFromSettings. An error loading those settings is
// returned unchanged.
func (g *Gateway) ListOrgModels(ctx context.Context) ([]map[string]any, error) {
	orgSettings, loadErr := LoadOrgLLMSettings(ctx, g.db)
	if loadErr == nil {
		return g.listModelsFromSettings(ctx, orgSettings)
	}
	return nil, loadErr
}
|
|
|
|
// listModelsFromSettings asks every provider in settings for its models. It
// returns one entry {"id": <model>, "object": "model", "owned_by": <provider
// name>} per distinct model ID, in provider order. When several providers offer
// the same ID, the first one wins.
//
// A provider whose listing fails is skipped. If no model was found at all, the
// first provider's trimmed DefaultModel (when non-blank) is offered as a
// fallback entry. The result is finally passed through ApplyModelCatalog
// together with the assistant policy's model list. The returned error is always
// nil.
//
// NOTE(review): an error from LoadAssistantPolicy is ignored, and whatever policy
// value accompanies it is used. Confirm that ApplyModelCatalog handles that
// safely.
func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Settings) ([]map[string]any, error) {
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	lister := llm.NewClient()

	known := make(map[string]struct{})
	catalog := make([]map[string]any, 0) // non-nil so it encodes as [] rather than null
	for _, p := range settings.Providers {
		ids, listErr := lister.ListModels(ctx, p)
		if listErr != nil {
			// Best effort: one unreachable provider must not hide the others.
			continue
		}
		for _, id := range ids {
			if _, dup := known[id]; dup {
				continue
			}
			known[id] = struct{}{}
			catalog = append(catalog, map[string]any{"id": id, "object": "model", "owned_by": p.Name})
		}
	}

	if len(catalog) == 0 && len(settings.Providers) > 0 {
		first := settings.Providers[0]
		if fallbackModel := strings.TrimSpace(first.DefaultModel); fallbackModel != "" {
			catalog = append(catalog, map[string]any{"id": fallbackModel, "object": "model", "owned_by": first.Name})
		}
	}

	return ApplyModelCatalog(catalog, policy.Models), nil
}
|
|
|
|
// ProxyChatCompletions forwards a chat completion request body to the upstream
// provider that serves the requested model and relays the upstream response to
// w.
//
//   - quotaExternalUserID identifies the user whose quota is checked before the
//     call and charged afterwards. When it is blank, neither happens.
//   - useOrgSettings selects the organisation's LLM settings instead of the
//     user's effective settings.
//   - body is the client's JSON request. It is repaired with
//     repairChatCompletionBody. Its "model" is replaced by the resolved model
//     when the two differ, e.g. when the client named a provider ID or sent a
//     padded or empty model.
//
// Errors from parsing, policy, settings, provider resolution, quota, request
// construction or transport are returned before anything is written to w.
// Successful streaming responses are relayed by proxyStream. Any other upstream
// response, including upstream HTTP errors for streaming requests, is read
// into memory (capped at 8 MiB, larger bodies are truncated). It is then
// written to w as application/json with the upstream status, and nil is
// returned. If that body cannot be read, the read error is returned and
// nothing is written.
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
	var modelProbe struct {
		Model string `json:"model"`
	}
	if err := json.Unmarshal(body, &modelProbe); err != nil {
		return fmt.Errorf("invalid chat completion request: %w", err)
	}
	requestedModel := strings.TrimSpace(modelProbe.Model)

	// NOTE(review): a policy load error is ignored, and the check is made against
	// the requested model rather than the resolved one (which differs when the
	// request is empty or names a provider ID). Confirm this is intended.
	policy, _ := LoadAssistantPolicy(ctx, g.db)
	if !IsModelAllowed(modelProbe.Model, policy.Models) {
		return fmt.Errorf("model %q is not allowed", requestedModel)
	}

	var (
		settings llm.Settings
		err      error
	)
	if useOrgSettings {
		settings, err = LoadOrgLLMSettings(ctx, g.db)
	} else {
		settings, err = LoadEffectiveLLMSettings(ctx, g.db, quotaExternalUserID)
	}
	if err != nil {
		return err
	}
	provider, model, err := resolveProviderForModel(settings, modelProbe.Model)
	if err != nil {
		return err
	}

	billingScope := ResolveBillingScope(ctx, g.db, quotaExternalUserID, provider, useOrgSettings)

	chargeable := strings.TrimSpace(quotaExternalUserID) != ""
	if chargeable {
		if err := g.quota.AssertAvailable(ctx, quotaExternalUserID, provider, useOrgSettings); err != nil {
			return err
		}
	}

	upstreamBody, err := repairChatCompletionBody(body)
	if err != nil {
		return err
	}
	stream := chatCompletionStreamRequested(upstreamBody)
	// Send upstream the model that is also billed. This fills in a missing
	// model, replaces a provider ID with that provider's default model, and
	// drops surrounding whitespace.
	if requestedModel == "" || (model != "" && model != modelProbe.Model) {
		upstreamBody = setChatCompletionModel(upstreamBody, model)
	}

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(upstreamBody))
	if err != nil {
		return err
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	if apiKey := strings.TrimSpace(provider.APIKey); apiKey != "" {
		upstreamReq.Header.Set("Authorization", "Bearer "+apiKey)
	}

	resp, err := g.client.Do(upstreamReq)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// Only a successful upstream reply is relayed as an event stream. An error
	// reply takes the buffered path below so it is not labelled
	// text/event-stream.
	if stream && resp.StatusCode < http.StatusBadRequest {
		return g.proxyStream(ctx, quotaExternalUserID, model, provider, billingScope, w, resp)
	}

	payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) // 8 MiB cap
	if err != nil {
		return err
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(resp.StatusCode)
	_, _ = w.Write(payload)
	if resp.StatusCode >= http.StatusBadRequest || !chargeable {
		return nil
	}

	// The response has already been sent, so the client may have disconnected
	// and cancelled ctx. Detach from that cancellation (keeping ctx's values)
	// so completed upstream work is still charged, but bound the write.
	recordCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
	defer cancel()
	// Best effort: there is no way left to report a recording failure to the
	// client.
	_ = g.quota.RecordUsage(recordCtx, cost.RecordInput{
		ExternalUserID:         quotaExternalUserID,
		Feature:                "gateway",
		ModelID:                model,
		ProviderID:             provider.ID,
		BillingScope:           billingScope,
		ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
		Usage:                  cost.ParseUsage(payload),
	})
	return nil
}

// setChatCompletionModel returns body with its top-level "model" field set to
// model. The other fields are carried over as raw JSON, so numbers such as a
// 64-bit seed are not rounded through float64. If body is not a JSON object
// (including a literal null), body is returned unchanged.
func setChatCompletionModel(body []byte, model string) []byte {
	var fields map[string]json.RawMessage
	if err := json.Unmarshal(body, &fields); err != nil || fields == nil {
		return body
	}
	encodedModel, err := json.Marshal(model)
	if err != nil {
		return body
	}
	fields["model"] = encodedModel
	patched, err := json.Marshal(fields)
	if err != nil {
		return body
	}
	return patched
}
|
|
|
|
// proxyStream relays an upstream server-sent-events response to w line by line.
// It flushes after every line so the client sees output as it arrives, and
// accumulates usage from the "data:" payloads with cost.MergeStreamUsage
// (the terminal "[DONE]" marker is skipped).
//
// When resp.StatusCode is below 400 and quotaExternalUserID is non-blank, usage
// is recorded once, on every exit path. This includes a relay that ends early
// because ctx was cancelled (e.g. the client went away) or the upstream read
// failed, since the upstream has already done the work. If no usage was
// reported, one token is recorded so the request still counts against the
// quota.
//
// It returns nil on a clean end of stream, ctx.Err() on cancellation, the read
// error on an upstream failure, or a wrapped flush error. The caller owns and
// closes resp.Body.
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID, model string, provider llm.Provider, billingScope string, w http.ResponseWriter, resp *http.Response) error {
	rc := http.NewResponseController(w)
	header := w.Header()
	header.Set("Content-Type", "text/event-stream")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	// Asks an nginx-style reverse proxy not to buffer the stream.
	header.Set("X-Accel-Buffering", "no")
	w.WriteHeader(resp.StatusCode)

	var usage cost.UsageDetail
	if resp.StatusCode < http.StatusBadRequest && strings.TrimSpace(quotaExternalUserID) != "" {
		defer func() {
			if usage.TotalTokens == 0 {
				usage.TotalTokens = 1
			}
			// ctx may already be cancelled here. Keep its values but not its
			// cancellation, and bound the recording call.
			recordCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
			defer cancel()
			// Best effort: the stream has already been (partly) delivered.
			_ = g.quota.RecordUsage(recordCtx, cost.RecordInput{
				ExternalUserID:         quotaExternalUserID,
				Feature:                "gateway",
				ModelID:                model,
				ProviderID:             provider.ID,
				BillingScope:           billingScope,
				ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
				Usage:                  usage,
			})
		}()
	}

	reader := bufio.NewReader(resp.Body)
	for {
		// ReadString may return a final unterminated line together with an error,
		// so the line is relayed before the error is inspected.
		line, readErr := reader.ReadString('\n')
		if line != "" {
			_, _ = w.Write([]byte(line))
			if err := rc.Flush(); err != nil {
				// NOTE(review): Flush also fails when the client has disconnected,
				// not only when the writer cannot flush; the message is kept as-is.
				return fmt.Errorf("streaming not supported: %w", err)
			}
			// SSE allows an optional single space after "data:".
			if data, ok := strings.CutPrefix(line, "data:"); ok {
				data = strings.TrimSpace(data)
				if data != "" && data != "[DONE]" {
					usage = cost.MergeStreamUsage(usage, []byte(data))
				}
			}
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}
	}
}
|
|
|
|
// resolveProviderForModel picks the provider and model for a request that asked
// for model.
//
// If the trimmed model equals a provider's ID, that provider is returned with
// its trimmed DefaultModel, which may be blank. Otherwise the provider that
// llm.ResolveProvider selects for settings.DefaultProviderID is used. It is
// paired with the trimmed requested model, or, when none was requested, with
// the model llm.ResolveProvider returned. An error from llm.ResolveProvider is
// returned unchanged, with a zero provider and empty model.
func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider, string, error) {
	requested := strings.TrimSpace(model)
	if requested != "" {
		for i := range settings.Providers {
			if settings.Providers[i].ID != requested {
				continue
			}
			chosen := settings.Providers[i]
			return chosen, strings.TrimSpace(chosen.DefaultModel), nil
		}
	}

	fallback, defaultModel, err := llm.ResolveProvider(settings, settings.DefaultProviderID)
	if err != nil {
		return llm.Provider{}, "", err
	}
	if requested == "" {
		return fallback, defaultModel, nil
	}
	return fallback, requested, nil
}
|
|
|
|
// NowUnix returns the current wall-clock time as whole seconds since the Unix
// epoch (1970-01-01T00:00:00Z).
func NowUnix() int64 {
	now := time.Now()
	return now.Unix()
}
|