- Updated .env.example to include new configuration options for the UltiAI branding and API endpoints. - Enhanced Nginx configuration to support new API routes for the MCP and WebSocket connections. - Introduced sub-filters for branding adjustments in Nginx responses. - Added new JavaScript patch for API endpoint adjustments. - Implemented tests for new API functionalities and improved error handling in the AI gateway.
210 lines
5.2 KiB
Go
210 lines
5.2 KiB
Go
package llm
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
type (
	// Provider describes one configured LLM endpoint and how to authenticate
	// against it.
	Provider struct {
		// ID is the key used to select this provider (see ResolveProvider).
		ID string `json:"id"`
		// Name is a display name.
		Name string `json:"name"`
		// Type is optional; it is omitted from JSON when empty.
		Type string `json:"type,omitempty"`
		// BaseURL is the API root; the client appends "/chat/completions"
		// and "/models" to it after trimming trailing slashes.
		BaseURL string `json:"base_url"`
		// APIKey, when non-empty, is sent as a Bearer token.
		APIKey string `json:"api_key,omitempty"`
		// DefaultModel is used when no explicit model is requested.
		DefaultModel string `json:"default_model"`
	}

	// Settings is the LLM configuration: the list of providers, which one is
	// the default, and optional contact-discovery overrides for provider and
	// model.
	Settings struct {
		DefaultProviderID string     `json:"default_provider_id"`
		Providers         []Provider `json:"providers"`
		// ContactDiscoveryModel, when set, overrides the provider's default model.
		ContactDiscoveryModel string `json:"contact_discovery_model,omitempty"`
		// ContactDiscoveryProvider is serialized under the JSON key
		// "contact_discovery_provider_id".
		ContactDiscoveryProvider string `json:"contact_discovery_provider_id,omitempty"`
	}
)
|
|
|
|
type (
	// ChatMessage is one entry in a chat completion request: the speaker's
	// role (Complete sends "system" and "user") and the message text.
	ChatMessage struct {
		Role    string `json:"role"`
		Content string `json:"content"`
	}

	// chatRequest is the JSON body POSTed to "<base_url>/chat/completions".
	chatRequest struct {
		Model       string        `json:"model"`
		Messages    []ChatMessage `json:"messages"`
		Temperature float64       `json:"temperature"`
	}

	// chatResponse holds only the parts of a chat completion response this
	// package reads: each choice's message content, and an optional error
	// object carrying a message.
	chatResponse struct {
		Choices []struct {
			Message struct {
				Content string `json:"content"`
			} `json:"message"`
		} `json:"choices"`
		Error *struct {
			Message string `json:"message"`
		} `json:"error,omitempty"`
	}
)
|
|
|
|
// modelsResponse holds only the parts of a "<base_url>/models" response this
// package reads: the listed model IDs, and an optional error object carrying
// a message.
type modelsResponse struct {
	// Data lists the available models; only each entry's ID is decoded.
	Data []struct {
		ID string `json:"id"`
	} `json:"data"`

	// Error is non-nil when the body contains an "error" object.
	Error *struct {
		Message string `json:"message"`
	} `json:"error,omitempty"`
}
|
|
|
|
type Client struct {
|
|
http *http.Client
|
|
}
|
|
|
|
func NewClient() *Client {
|
|
return &Client{http: &http.Client{Timeout: 90 * time.Second}}
|
|
}
|
|
|
|
// Complete sends a single-turn chat completion (one system message followed by
// one user message) to provider's "<base_url>/chat/completions" endpoint. It
// returns the first choice's message content with surrounding whitespace
// trimmed.
//
// model overrides provider.DefaultModel, and whichever one is used must be
// non-empty after trimming. provider.APIKey, when set, is sent as a Bearer
// token. Temperature is fixed at 0.2.
//
// An error is returned when:
//   - base_url or the model is missing;
//   - the request cannot be built or sent;
//   - the status is >= 400 (the message includes a bounded snippet of the body);
//   - the body exceeds 1 MiB or is not valid JSON;
//   - the body carries an error message;
//   - the response contains no choices.
//
// Transport errors are wrapped with %w, so context cancellation stays
// detectable via errors.Is.
func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) {
	// maxResponseBytes caps how much of the response body is read.
	const maxResponseBytes = 1 << 20
	// maxErrorBodyBytes caps how much of a failed response is echoed into the error.
	const maxErrorBodyBytes = 512

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	if baseURL == "" {
		return "", fmt.Errorf("llm provider base_url is required")
	}
	model = strings.TrimSpace(model)
	if model == "" {
		model = strings.TrimSpace(provider.DefaultModel)
	}
	if model == "" {
		return "", fmt.Errorf("llm model is required")
	}

	payload, err := json.Marshal(chatRequest{
		Model: model,
		Messages: []ChatMessage{
			{Role: "system", Content: systemPrompt},
			{Role: "user", Content: userPrompt},
		},
		Temperature: 0.2,
	})
	if err != nil {
		return "", fmt.Errorf("llm encode request: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, baseURL+"/chat/completions", bytes.NewReader(payload))
	if err != nil {
		return "", fmt.Errorf("llm build request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	if key := strings.TrimSpace(provider.APIKey); key != "" {
		req.Header.Set("Authorization", "Bearer "+key)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return "", fmt.Errorf("llm request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversized body is reported as such
	// instead of being silently truncated into invalid JSON.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return "", fmt.Errorf("llm read response: %w", err)
	}
	if resp.StatusCode >= 400 {
		snippet := strings.TrimSpace(string(body))
		if len(snippet) > maxErrorBodyBytes {
			// Drop any rune split by the byte cut so the message stays valid UTF-8.
			snippet = strings.ToValidUTF8(snippet[:maxErrorBodyBytes], "") + "..."
		}
		return "", fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, snippet)
	}
	if len(body) > maxResponseBytes {
		return "", fmt.Errorf("llm response exceeds %d bytes", maxResponseBytes)
	}

	var parsed chatResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return "", fmt.Errorf("llm decode response (status %d): %w", resp.StatusCode, err)
	}
	// An error object can accompany a non-error status, so check it explicitly.
	if parsed.Error != nil && parsed.Error.Message != "" {
		return "", fmt.Errorf("llm error: %s", parsed.Error.Message)
	}
	if len(parsed.Choices) == 0 {
		return "", fmt.Errorf("llm returned no choices")
	}
	return strings.TrimSpace(parsed.Choices[0].Message.Content), nil
}
|
|
|
|
// ListModels fetches "<base_url>/models" from provider and returns the listed
// model IDs. IDs are whitespace-trimmed, and empty or duplicate IDs are dropped,
// keeping first-seen order.
//
// provider.APIKey, when set, is sent as a Bearer token.
//
// An error is returned when:
//   - base_url is missing;
//   - the request cannot be built or sent;
//   - the status is >= 400 (the message includes a bounded snippet of the body);
//   - the body exceeds 1 MiB or is not valid JSON;
//   - the body carries an error message.
//
// Transport errors are wrapped with %w.
func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) {
	// maxResponseBytes caps how much of the response body is read.
	const maxResponseBytes = 1 << 20
	// maxErrorBodyBytes caps how much of a failed response is echoed into the error.
	const maxErrorBodyBytes = 512

	baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
	if baseURL == "" {
		return nil, fmt.Errorf("llm provider base_url is required")
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, baseURL+"/models", nil)
	if err != nil {
		return nil, fmt.Errorf("llm build models request: %w", err)
	}
	if key := strings.TrimSpace(provider.APIKey); key != "" {
		req.Header.Set("Authorization", "Bearer "+key)
	}

	resp, err := c.http.Do(req)
	if err != nil {
		return nil, fmt.Errorf("llm models request: %w", err)
	}
	defer resp.Body.Close()

	// Read one byte past the cap so an oversized body is reported as such
	// instead of being silently truncated into invalid JSON.
	body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
	if err != nil {
		return nil, fmt.Errorf("llm read models response: %w", err)
	}
	if resp.StatusCode >= 400 {
		snippet := strings.TrimSpace(string(body))
		if len(snippet) > maxErrorBodyBytes {
			// Drop any rune split by the byte cut so the message stays valid UTF-8.
			snippet = strings.ToValidUTF8(snippet[:maxErrorBodyBytes], "") + "..."
		}
		return nil, fmt.Errorf("llm models request failed (%d): %s", resp.StatusCode, snippet)
	}
	if len(body) > maxResponseBytes {
		return nil, fmt.Errorf("llm models response exceeds %d bytes", maxResponseBytes)
	}

	var parsed modelsResponse
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("llm decode models response (status %d): %w", resp.StatusCode, err)
	}
	// An error object can accompany a non-error status, so check it explicitly.
	if parsed.Error != nil && parsed.Error.Message != "" {
		return nil, fmt.Errorf("llm error: %s", parsed.Error.Message)
	}

	models := make([]string, 0, len(parsed.Data))
	seen := make(map[string]struct{}, len(parsed.Data))
	for _, item := range parsed.Data {
		id := strings.TrimSpace(item.ID)
		if id == "" {
			continue
		}
		if _, dup := seen[id]; dup {
			continue
		}
		seen[id] = struct{}{}
		models = append(models, id)
	}
	return models, nil
}
|
|
|
|
// ResolveProvider picks the provider and model to use.
//
// The provider ID is the first non-empty, whitespace-trimmed value of:
// providerID, settings.ContactDiscoveryProvider, settings.DefaultProviderID.
// The provider whose (trimmed) ID matches is returned. If none matches, the
// first entry of settings.Providers is used.
//
// The returned model is settings.ContactDiscoveryModel when set, otherwise the
// chosen provider's DefaultModel. Both are trimmed.
//
// An error is returned only when settings.Providers is empty.
//
// NOTE(review): ContactDiscoveryModel is applied to whichever provider is
// resolved, even when providerID selects a provider other than
// ContactDiscoveryProvider — confirm that is intended.
func ResolveProvider(settings Settings, providerID string) (Provider, string, error) {
	if len(settings.Providers) == 0 {
		return Provider{}, "", fmt.Errorf("no llm provider configured")
	}

	// Trim the argument like the settings values, so a blank or padded ID
	// does not bypass the fallback chain.
	wanted := strings.TrimSpace(providerID)
	if wanted == "" {
		wanted = strings.TrimSpace(settings.ContactDiscoveryProvider)
	}
	if wanted == "" {
		wanted = strings.TrimSpace(settings.DefaultProviderID)
	}

	// Default to the first configured provider when no ID matches.
	chosen := settings.Providers[0]
	for _, p := range settings.Providers {
		if strings.TrimSpace(p.ID) == wanted {
			chosen = p
			break
		}
	}

	model := strings.TrimSpace(settings.ContactDiscoveryModel)
	if model == "" {
		model = strings.TrimSpace(chosen.DefaultModel)
	}
	return chosen, model, nil
}
|