ultisuite-backend/internal/llm/client.go
R3D347HR4Y 3978622050
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
refactor(ai): update AI gateway and cost management features
- Refactored AI gateway to utilize new cost management structures for usage tracking.
- Replaced deprecated token extraction methods with a unified cost parsing approach.
- Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses.
- Added new metering functionality to record AI usage and costs effectively.
- Updated tests to reflect changes in usage parsing and cost calculations.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
2026-06-16 10:46:33 +02:00

284 lines
7.5 KiB
Go

package llm
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// Configuration and request types. JSON tags and field order determine the
// encoded form of these values, so keep both stable when editing.
type (
	// Provider describes one configured LLM endpoint.
	Provider struct {
		ID           string `json:"id"`
		Name         string `json:"name"`
		Type         string `json:"type,omitempty"`
		BaseURL      string `json:"base_url"`          // API root; "/chat/completions" and "/models" are appended to it
		APIKey       string `json:"api_key,omitempty"` // sent as a Bearer token when non-blank
		DefaultModel string `json:"default_model"`     // used when the caller does not name a model
	}

	// Settings lists the configured providers plus the default provider and
	// the optional contact-discovery overrides consulted by ResolveProvider.
	Settings struct {
		DefaultProviderID        string     `json:"default_provider_id"`
		Providers                []Provider `json:"providers"`
		ContactDiscoveryModel    string     `json:"contact_discovery_model,omitempty"`
		ContactDiscoveryProvider string     `json:"contact_discovery_provider_id,omitempty"`
	}

	// ChatMessage is a single role/content entry of a chat request.
	ChatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// chatRequest is the JSON body POSTed to "/chat/completions".
	chatRequest struct {
		Model       string        `json:"model"`
		Messages    []ChatMessage `json:"messages"`
		Temperature float64       `json:"temperature"`
	}
)
// Response bodies decoded from the provider. The anonymous Usage struct type
// (field order and tags included) must remain identical to the parameter type
// of parseUsageFromResponse, which receives chatResponse.Usage directly.
type (
	// chatResponse is the decoded body of a "/chat/completions" response.
	chatResponse struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		// Usage stays nil when the body has no "usage" object.
		Usage *struct {
			PromptTokens        int `json:"prompt_tokens"`
			CompletionTokens    int `json:"completion_tokens"`
			TotalTokens         int `json:"total_tokens"`
			PromptTokensDetails *struct {
				CachedTokens int `json:"cached_tokens"`
			} `json:"prompt_tokens_details,omitempty"`
			CompletionTokensDetails *struct {
				ReasoningTokens int `json:"reasoning_tokens"`
			} `json:"completion_tokens_details,omitempty"`
		} `json:"usage,omitempty"`
		// Error holds an "error" object embedded in the body, if any.
		Error *struct {
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}

	// modelsResponse is the decoded body of a "/models" response.
	modelsResponse struct {
		Data []struct {
			ID string `json:"id"`
		} `json:"data"`
		// Error holds an "error" object embedded in the body, if any.
		Error *struct {
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
)
// Client sends chat-completion and model-listing requests to LLM providers.
type Client struct {
	http *http.Client // shared by every request made through this Client
}

// NewClient returns a Client whose underlying http.Client limits each request,
// including reading the response body, to 90 seconds.
func NewClient() *Client {
	const requestTimeout = 90 * time.Second
	hc := &http.Client{Timeout: requestTimeout}
	return &Client{http: hc}
}
// Complete calls CompleteWithUsage with the same arguments and returns only
// the completion text, discarding the model and usage metadata. When an error
// occurs the returned text is "".
func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) {
	var text string
	res, err := c.CompleteWithUsage(ctx, provider, model, systemPrompt, userPrompt)
	if err == nil {
		text = res.Content
	}
	return text, err
}
type (
	// CompletionResult holds LLM output and usage metadata.
	CompletionResult struct {
		Content string      // text of the first returned choice
		Model   string      // model name that was requested
		Usage   UsageDetail // token counts for the call
	}

	// UsageDetail mirrors ai/cost.UsageDetail for llm package consumers.
	UsageDetail struct {
		PromptTokens      int
		CompletionTokens  int
		CachedInputTokens int // from usage.prompt_tokens_details.cached_tokens
		ReasoningTokens   int // from usage.completion_tokens_details.reasoning_tokens
		TotalTokens       int
	}
)
// CompleteWithUsage sends a single-turn chat completion request to
// provider.BaseURL + "/chat/completions". The request has one system message,
// one user message and temperature 0.2. It returns the trimmed content of the
// first choice together with the token usage parsed from the response.
//
// model is trimmed; when it is empty, provider.DefaultModel is used instead.
// The resolved name is echoed back in CompletionResult.Model. A non-blank
// provider.APIKey is sent as a Bearer token.
//
// An error is returned when:
//   - the base URL or model is missing;
//   - the request cannot be sent;
//   - the HTTP status is >= 400;
//   - the body exceeds 1 MiB or is not valid JSON;
//   - the body carries a non-empty "error" message;
//   - there are no choices.
//
// Encoding, transport, read and decode errors are wrapped with %w, so callers
// can still use errors.Is / errors.As on the underlying cause.
func (c *Client) CompleteWithUsage(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (CompletionResult, error) {
	// maxResponseBytes caps how much of the response body is buffered.
	const maxResponseBytes = 1 << 20
	// maxErrorBodyBytes caps how much of an error body is copied into the
	// returned error, so a large HTML/JSON error page does not flood logs.
	const maxErrorBodyBytes = 2048

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	if baseURL == "" {
		return CompletionResult{}, fmt.Errorf("llm provider base_url is required")
	}
	model = strings.TrimSpace(model)
	if model == "" {
		model = strings.TrimSpace(provider.DefaultModel)
	}
	if model == "" {
		return CompletionResult{}, fmt.Errorf("llm model is required")
	}

	payload, err := json.Marshal(chatRequest{
		Model: model,
		Messages: []ChatMessage{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: userPrompt},
		},
		Temperature: 0.2,
	})
	if err != nil {
		return CompletionResult{}, fmt.Errorf("encode llm request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(payload))
	if err != nil {
		return CompletionResult{}, fmt.Errorf("build llm request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if key := strings.TrimSpace(provider.APIKey); key != "" {
		req.Header.Set("Authorization", "Bearer "+key)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return CompletionResult{}, fmt.Errorf("llm request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap. A larger body is then reported explicitly
	// instead of being silently cut into invalid JSON.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return CompletionResult{}, fmt.Errorf("read llm response: %w", err)
	}
	if resp.StatusCode >= 400 {
		snippet := strings.TrimSpace(string(body))
		if len(snippet) > maxErrorBodyBytes {
			// A byte cut may split a multi-byte rune; ToValidUTF8 drops the fragment.
			snippet = strings.ToValidUTF8(snippet[:maxErrorBodyBytes], "") + "..."
		}
		return CompletionResult{}, fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, snippet)
	}
	if len(body) > maxResponseBytes {
		return CompletionResult{}, fmt.Errorf("llm response exceeds %d bytes", maxResponseBytes)
	}

	var parsed chatResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return CompletionResult{}, fmt.Errorf("decode llm response: %w", err)
	}
	// Some providers report failures inside a 2xx body.
	if parsed.Error != nil && parsed.Error.Message != "" {
		return CompletionResult{}, fmt.Errorf("llm error: %s", parsed.Error.Message)
	}
	if len(parsed.Choices) == 0 {
		return CompletionResult{}, fmt.Errorf("llm returned no choices")
	}
	return CompletionResult{
		Content: strings.TrimSpace(parsed.Choices[0].Message.Content),
		Model:   model,
		Usage:   parseUsageFromResponse(parsed.Usage),
	}, nil
}
// parseUsageFromResponse converts the provider's optional "usage" object into
// a UsageDetail.
//
// Counts are copied as-is, and the cached/reasoning counts are taken from the
// nested details objects when present. If total_tokens is 0, it is recomputed
// as prompt + completion tokens. A nil usage, or one whose total is still 0,
// yields TotalTokens == 1.
// NOTE(review): the 1 looks like a deliberate "usage unknown" floor so that
// downstream metering never sees zero; confirm with the cost/metering callers.
//
// The parameter type (field order and tags included) must stay identical to
// chatResponse.Usage, which is passed here directly.
func parseUsageFromResponse(u *struct {
	PromptTokens        int `json:"prompt_tokens"`
	CompletionTokens    int `json:"completion_tokens"`
	TotalTokens         int `json:"total_tokens"`
	PromptTokensDetails *struct {
		CachedTokens int `json:"cached_tokens"`
	} `json:"prompt_tokens_details,omitempty"`
	CompletionTokensDetails *struct {
		ReasoningTokens int `json:"reasoning_tokens"`
	} `json:"completion_tokens_details,omitempty"`
}) UsageDetail {
	var out UsageDetail
	if u != nil {
		out.PromptTokens = u.PromptTokens
		out.CompletionTokens = u.CompletionTokens
		out.TotalTokens = u.TotalTokens
		if pd := u.PromptTokensDetails; pd != nil {
			out.CachedInputTokens = pd.CachedTokens
		}
		if cd := u.CompletionTokensDetails; cd != nil {
			out.ReasoningTokens = cd.ReasoningTokens
		}
		if out.TotalTokens == 0 {
			out.TotalTokens = out.PromptTokens + out.CompletionTokens
		}
	}
	// Floor applies to both the nil case and an all-zero usage object.
	if out.TotalTokens == 0 {
		out.TotalTokens = 1
	}
	return out
}
// ListModels GETs provider.BaseURL + "/models" and returns the model IDs from
// the response's "data" array. IDs keep response order, are trimmed, and
// empty or duplicate IDs are dropped. A non-blank provider.APIKey is sent as a
// Bearer token.
//
// An error is returned when:
//   - the base URL is missing;
//   - the request cannot be sent;
//   - the HTTP status is >= 400;
//   - the body exceeds 1 MiB or is not valid JSON;
//   - the body carries a non-empty "error" message.
//
// Transport, read and decode errors are wrapped with %w.
func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) {
	// maxResponseBytes caps how much of the response body is buffered.
	const maxResponseBytes = 1 << 20
	// maxErrorBodyBytes caps how much of an error body is copied into the error.
	const maxErrorBodyBytes = 2048

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	if baseURL == "" {
		return nil, fmt.Errorf("llm provider base_url is required")
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil)
	if err != nil {
		return nil, fmt.Errorf("build llm models request: %w", err)
	}
	if key := strings.TrimSpace(provider.APIKey); key != "" {
		req.Header.Set("Authorization", "Bearer "+key)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("llm models request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap. A larger body is then reported explicitly
	// instead of being silently cut into invalid JSON.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return nil, fmt.Errorf("read llm models response: %w", err)
	}
	if resp.StatusCode >= 400 {
		snippet := strings.TrimSpace(string(body))
		if len(snippet) > maxErrorBodyBytes {
			// A byte cut may split a multi-byte rune; ToValidUTF8 drops the fragment.
			snippet = strings.ToValidUTF8(snippet[:maxErrorBodyBytes], "") + "..."
		}
		return nil, fmt.Errorf("llm models request failed (%d): %s", resp.StatusCode, snippet)
	}
	if len(body) > maxResponseBytes {
		return nil, fmt.Errorf("llm models response exceeds %d bytes", maxResponseBytes)
	}

	var parsed modelsResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("decode llm models response: %w", err)
	}
	if parsed.Error != nil && parsed.Error.Message != "" {
		return nil, fmt.Errorf("llm error: %s", parsed.Error.Message)
	}

	models := make([]string, 0, len(parsed.Data))
	seen := make(map[string]struct{}, len(parsed.Data))
	for _, item := range parsed.Data {
		id := strings.TrimSpace(item.ID)
		if id == "" {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		models = append(models, id)
	}
	return models, nil
}
// ResolveProvider chooses a provider and model, preferring the
// contact-discovery settings.
//
// The provider ID is the first non-blank value among providerID,
// settings.ContactDiscoveryProvider and settings.DefaultProviderID. It is
// matched against each provider's ID with surrounding whitespace ignored, and
// the first match wins. If nothing matches, the first configured provider is
// used as a fallback.
//
// The model is settings.ContactDiscoveryModel when non-blank, otherwise the
// chosen provider's DefaultModel. Both are trimmed, and the model may be
// empty. An error is returned only when no providers are configured.
func ResolveProvider(settings Settings, providerID string) (Provider, string, error) {
	if len(settings.Providers) == 0 {
		return Provider{}, "", fmt.Errorf("no llm provider configured")
	}

	// Trim the argument like the settings values. Otherwise " openai " misses
	// its provider and silently falls back to Providers[0], and a blank
	// argument never reaches the configured defaults.
	id := strings.TrimSpace(providerID)
	if id == "" {
		id = strings.TrimSpace(settings.ContactDiscoveryProvider)
	}
	if id == "" {
		id = strings.TrimSpace(settings.DefaultProviderID)
	}

	chosen := settings.Providers[0] // fallback when no ID matches
	for _, p := range settings.Providers {
		if strings.TrimSpace(p.ID) == id {
			chosen = p
			break
		}
	}

	model := strings.TrimSpace(settings.ContactDiscoveryModel)
	if model == "" {
		model = strings.TrimSpace(chosen.DefaultModel)
	}
	return chosen, model, nil
}