- Refactored AI gateway to utilize new cost management structures for usage tracking.
- Replaced deprecated token extraction methods with a unified cost parsing approach.
- Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses.
- Added new metering functionality to record AI usage and costs effectively.
- Updated tests to reflect changes in usage parsing and cost calculations.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
284 lines
7.5 KiB
Go
284 lines
7.5 KiB
Go
package llm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Provider describes one configured LLM endpoint in its JSON settings form.
type Provider struct {
	// ID is the identifier ResolveProvider compares against when selecting
	// a provider.
	ID string `json:"id"`

	// Name is a display label; nothing in this file reads it.
	Name string `json:"name"`

	// Type is an optional provider kind; nothing in this file reads it.
	Type string `json:"type,omitempty"`

	// BaseURL is the API root. Client methods trim surrounding whitespace
	// and trailing slashes, then append "/chat/completions" or "/models".
	BaseURL string `json:"base_url"`

	// APIKey, when non-blank, is sent as an "Authorization: Bearer" header.
	APIKey string `json:"api_key,omitempty"`

	// DefaultModel is the model used when the caller does not name one.
	DefaultModel string `json:"default_model"`
}
|
|
|
|
type Settings struct {
|
|
DefaultProviderID string `json:"default_provider_id"`
|
|
Providers []Provider `json:"providers"`
|
|
ContactDiscoveryModel string `json:"contact_discovery_model,omitempty"`
|
|
ContactDiscoveryProvider string `json:"contact_discovery_provider_id,omitempty"`
|
|
}
|
|
|
|
// ChatMessage is a single entry in the "messages" array of a chat
// completion request.
type ChatMessage struct {
	// Role is the speaker; this file sends "system" and "user".
	Role string `json:"role"`

	// Content is the message text. It is always serialised, even when empty.
	Content string `json:"content"`
}
|
|
|
|
type chatRequest struct {
|
|
Model string `json:"model"`
|
|
Messages []ChatMessage `json:"messages"`
|
|
Temperature float64 `json:"temperature"`
|
|
}
|
|
|
|
// chatResponse is the subset of the "/chat/completions" response body that
// this package decodes.
type chatResponse struct {
	// Choices holds the generated messages; only the first one is used by
	// CompleteWithUsage.
	Choices []struct {
		Message struct {
			Content string `json:"content"`
		} `json:"message"`
	} `json:"choices"`

	// Usage carries token accounting and stays nil when the response omits
	// the "usage" object. Its shape must stay identical to the parameter
	// type of parseUsageFromResponse, which receives this pointer directly.
	Usage *struct {
		PromptTokens        int `json:"prompt_tokens"`
		CompletionTokens    int `json:"completion_tokens"`
		TotalTokens         int `json:"total_tokens"`
		PromptTokensDetails *struct {
			CachedTokens int `json:"cached_tokens"`
		} `json:"prompt_tokens_details,omitempty"`
		CompletionTokensDetails *struct {
			ReasoningTokens int `json:"reasoning_tokens"`
		} `json:"completion_tokens_details,omitempty"`
	} `json:"usage,omitempty"`

	// Error is set when the body contains an "error" object; a non-empty
	// Message is reported to the caller as a failure.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
|
|
|
|
// modelsResponse is the subset of the "/models" response body that
// ListModels decodes.
type modelsResponse struct {
	// Data lists the models advertised by the provider.
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`

	// Error is set when the body contains an "error" object; a non-empty
	// Message is reported to the caller as a failure.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
|
|
|
|
type Client struct {
|
|
http *http.Client
|
|
}
|
|
|
|
func NewClient() *Client {
|
|
return &Client{http: &http.Client{Timeout: 90 * time.Second}}
|
|
}
|
|
|
|
func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) {
|
|
result, err := c.CompleteWithUsage(ctx, provider, model, systemPrompt, userPrompt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return result.Content, nil
|
|
}
|
|
|
|
// CompletionResult holds LLM output and usage metadata.
|
|
type CompletionResult struct {
|
|
Content string
|
|
Model string
|
|
Usage UsageDetail
|
|
}
|
|
|
|
// UsageDetail mirrors ai/cost.UsageDetail for llm package consumers.
//
// All fields are token counts taken from the provider's "usage" object.
type UsageDetail struct {
	// PromptTokens is the reported "prompt_tokens" value.
	PromptTokens int

	// CompletionTokens is the reported "completion_tokens" value.
	CompletionTokens int

	// CachedInputTokens is "prompt_tokens_details.cached_tokens", or 0 when
	// the provider omits it.
	CachedInputTokens int

	// ReasoningTokens is "completion_tokens_details.reasoning_tokens", or 0
	// when the provider omits it.
	ReasoningTokens int

	// TotalTokens is the reported "total_tokens" value, subject to the
	// fallbacks applied by parseUsageFromResponse.
	TotalTokens int
}
|
|
|
|
func (c *Client) CompleteWithUsage(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (CompletionResult, error) {
|
|
baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
|
|
if baseURL == "" {
|
|
return CompletionResult{}, fmt.Errorf("llm provider base_url is required")
|
|
}
|
|
model = strings.TrimSpace(model)
|
|
if model == "" {
|
|
model = strings.TrimSpace(provider.DefaultModel)
|
|
}
|
|
if model == "" {
|
|
return CompletionResult{}, fmt.Errorf("llm model is required")
|
|
}
|
|
|
|
reqBody := chatRequest{
|
|
Model: model,
|
|
Messages: []ChatMessage{
|
|
{Role: "system", Content: systemPrompt},
|
|
{Role: "user", Content: userPrompt},
|
|
},
|
|
Temperature: 0.2,
|
|
}
|
|
payload, err := json.Marshal(reqBody)
|
|
if err != nil {
|
|
return CompletionResult{}, err
|
|
}
|
|
|
|
url := baseURL + "/chat/completions"
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
|
if err != nil {
|
|
return CompletionResult{}, err
|
|
}
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if strings.TrimSpace(provider.APIKey) != "" {
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(provider.APIKey))
|
|
}
|
|
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return CompletionResult{}, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return CompletionResult{}, err
|
|
}
|
|
if resp.StatusCode >= 400 {
|
|
return CompletionResult{}, fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var parsed chatResponse
|
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
|
return CompletionResult{}, err
|
|
}
|
|
if parsed.Error != nil && parsed.Error.Message != "" {
|
|
return CompletionResult{}, fmt.Errorf("llm error: %s", parsed.Error.Message)
|
|
}
|
|
if len(parsed.Choices) == 0 {
|
|
return CompletionResult{}, fmt.Errorf("llm returned no choices")
|
|
}
|
|
usage := parseUsageFromResponse(parsed.Usage)
|
|
return CompletionResult{
|
|
Content: strings.TrimSpace(parsed.Choices[0].Message.Content),
|
|
Model: model,
|
|
Usage: usage,
|
|
}, nil
|
|
}
|
|
|
|
func parseUsageFromResponse(u *struct {
|
|
PromptTokens int `json:"prompt_tokens"`
|
|
CompletionTokens int `json:"completion_tokens"`
|
|
TotalTokens int `json:"total_tokens"`
|
|
PromptTokensDetails *struct {
|
|
CachedTokens int `json:"cached_tokens"`
|
|
} `json:"prompt_tokens_details,omitempty"`
|
|
CompletionTokensDetails *struct {
|
|
ReasoningTokens int `json:"reasoning_tokens"`
|
|
} `json:"completion_tokens_details,omitempty"`
|
|
}) UsageDetail {
|
|
if u == nil {
|
|
return UsageDetail{TotalTokens: 1}
|
|
}
|
|
d := UsageDetail{
|
|
PromptTokens: u.PromptTokens,
|
|
CompletionTokens: u.CompletionTokens,
|
|
TotalTokens: u.TotalTokens,
|
|
}
|
|
if u.PromptTokensDetails != nil {
|
|
d.CachedInputTokens = u.PromptTokensDetails.CachedTokens
|
|
}
|
|
if u.CompletionTokensDetails != nil {
|
|
d.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens
|
|
}
|
|
if d.TotalTokens == 0 {
|
|
d.TotalTokens = d.PromptTokens + d.CompletionTokens
|
|
}
|
|
if d.TotalTokens == 0 {
|
|
d.TotalTokens = 1
|
|
}
|
|
return d
|
|
}
|
|
|
|
func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) {
|
|
baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("llm provider base_url is required")
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if strings.TrimSpace(provider.APIKey) != "" {
|
|
req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(provider.APIKey))
|
|
}
|
|
|
|
resp, err := c.http.Do(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if resp.StatusCode >= 400 {
|
|
return nil, fmt.Errorf("llm models request failed (%d): %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
var parsed modelsResponse
|
|
if err := json.Unmarshal(body, &parsed); err != nil {
|
|
return nil, err
|
|
}
|
|
if parsed.Error != nil && parsed.Error.Message != "" {
|
|
return nil, fmt.Errorf("llm error: %s", parsed.Error.Message)
|
|
}
|
|
|
|
models := make([]string, 0, len(parsed.Data))
|
|
seen := make(map[string]struct{}, len(parsed.Data))
|
|
for _, item := range parsed.Data {
|
|
id := strings.TrimSpace(item.ID)
|
|
if id == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[id]; ok {
|
|
continue
|
|
}
|
|
seen[id] = struct{}{}
|
|
models = append(models, id)
|
|
}
|
|
return models, nil
|
|
}
|
|
|
|
func ResolveProvider(settings Settings, providerID string) (Provider, string, error) {
|
|
if providerID == "" {
|
|
providerID = strings.TrimSpace(settings.ContactDiscoveryProvider)
|
|
}
|
|
if providerID == "" {
|
|
providerID = strings.TrimSpace(settings.DefaultProviderID)
|
|
}
|
|
for _, p := range settings.Providers {
|
|
if p.ID == providerID {
|
|
model := strings.TrimSpace(settings.ContactDiscoveryModel)
|
|
if model == "" {
|
|
model = strings.TrimSpace(p.DefaultModel)
|
|
}
|
|
return p, model, nil
|
|
}
|
|
}
|
|
if len(settings.Providers) > 0 {
|
|
p := settings.Providers[0]
|
|
model := strings.TrimSpace(settings.ContactDiscoveryModel)
|
|
if model == "" {
|
|
model = strings.TrimSpace(p.DefaultModel)
|
|
}
|
|
return p, model, nil
|
|
}
|
|
return Provider{}, "", fmt.Errorf("no llm provider configured")
|
|
}
|