ultisuite-backend/internal/ai/providers.go
R3D347HR4Y 3978622050
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
refactor(ai): update AI gateway and cost management features
- Refactored AI gateway to utilize new cost management structures for usage tracking.
- Replaced deprecated token extraction methods with a unified cost parsing approach.
- Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses.
- Added new metering functionality to record AI usage and costs effectively.
- Updated tests to reflect changes in usage parsing and cost calculations.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
2026-06-16 10:46:33 +02:00

261 lines
6.8 KiB
Go

package ai
import (
"context"
"encoding/json"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/llm"
)
// orgSettingsSingletonID is the primary key of the single org_settings row
// that every query in this file reads (WHERE id = $1).
const orgSettingsSingletonID = 1
// orgLLMPolicy is the decoded form of the org_settings.settings->'llm' JSON
// object and describes the org-wide LLM provider policy.
type orgLLMPolicy struct {
// DefaultProviderID is the org's default provider. It is also used as the
// contact-discovery provider when converted to llm.Settings, because the
// policy has no separate field for that.
DefaultProviderID string `json:"default_provider_id"`
// Providers is the org-configured provider list.
Providers []llm.Provider `json:"providers"`
// EnforceOrgProviders, when true and Providers is non-empty, makes the org
// providers take precedence over any providers a user configured.
EnforceOrgProviders bool `json:"enforce_org_providers"`
// AllowUserOverride only matters while providers are enforced. It lets the
// user's default provider and contact-discovery settings replace the org's.
AllowUserOverride bool `json:"allow_user_override"`
// ContactDiscoveryModel is the model to use for contact discovery; optional.
ContactDiscoveryModel string `json:"contact_discovery_model,omitempty"`
}
// LoadEffectiveLLMSettings resolves the LLM settings that apply to the user
// identified by externalUserID. It combines the org-level policy stored in
// org_settings with the user's own LLM preferences.
//
// Precedence:
//   - If the org enforces its providers (EnforceOrgProviders with at least one
//     provider), the org providers are used. When AllowUserOverride is also
//     set, the user's default provider, contact-discovery model and
//     contact-discovery provider replace the org values. A user-supplied
//     provider ID is only accepted when it names one of the org providers,
//     because the user's own provider list is discarded in this mode.
//   - Otherwise the user's settings are used if they define any providers.
//   - Otherwise the org providers are used if any exist.
//   - Otherwise the (possibly empty) user settings are returned.
//
// It returns an error if db is nil or if either settings source cannot be
// loaded or decoded.
func LoadEffectiveLLMSettings(ctx context.Context, db *pgxpool.Pool, externalUserID string) (llm.Settings, error) {
	if db == nil {
		return llm.Settings{}, fmt.Errorf("database unavailable")
	}
	org, err := loadOrgLLMPolicy(ctx, db)
	if err != nil {
		return llm.Settings{}, err
	}
	user, err := loadUserLLMSettings(ctx, db, externalUserID)
	if err != nil {
		return llm.Settings{}, err
	}

	if org.EnforceOrgProviders && len(org.Providers) > 0 {
		merged := orgToSettings(org)
		if !org.AllowUserOverride {
			return merged, nil
		}

		// Only org providers survive in this mode. A user-chosen provider ID
		// therefore only makes sense if it refers to one of them. Otherwise
		// keep the org value instead of returning a dangling reference.
		isOrgProvider := func(id string) bool {
			for _, p := range merged.Providers {
				if p.ID == id {
					return true
				}
			}
			return false
		}

		if id := strings.TrimSpace(user.DefaultProviderID); id != "" && isOrgProvider(id) {
			merged.DefaultProviderID = id
		}
		if strings.TrimSpace(user.ContactDiscoveryModel) != "" {
			merged.ContactDiscoveryModel = user.ContactDiscoveryModel
		}
		if id := strings.TrimSpace(user.ContactDiscoveryProvider); id != "" && isOrgProvider(id) {
			merged.ContactDiscoveryProvider = id
		}
		return merged, nil
	}

	if len(user.Providers) > 0 {
		return user, nil
	}
	if len(org.Providers) > 0 {
		return orgToSettings(org), nil
	}
	return user, nil
}
// orgToSettings converts an org LLM policy into llm.Settings.
//
// The org policy has no dedicated contact-discovery provider, so the org's
// default provider is used for contact discovery as well. The Providers
// slice is shared with org, not copied. All other llm.Settings fields are
// left at their zero values.
func orgToSettings(org orgLLMPolicy) llm.Settings {
	var s llm.Settings
	s.DefaultProviderID = org.DefaultProviderID
	s.ContactDiscoveryProvider = org.DefaultProviderID
	s.Providers = org.Providers
	s.ContactDiscoveryModel = org.ContactDiscoveryModel
	return s
}
// LoadOrgLLMSettings returns the org-level LLM provider configuration. It
// loads the stored org policy and converts it to llm.Settings via
// orgToSettings. If loading fails, zero-value settings are returned together
// with the unmodified error.
func LoadOrgLLMSettings(ctx context.Context, db *pgxpool.Pool) (llm.Settings, error) {
	var settings llm.Settings
	policy, err := loadOrgLLMPolicy(ctx, db)
	if err == nil {
		settings = orgToSettings(policy)
	}
	return settings, err
}
// loadOrgLLMPolicy reads and decodes the org-wide LLM policy from
// org_settings.settings->'llm' on the singleton row.
//
// A missing row, a missing or SQL-NULL 'llm' value, or a JSON null all yield
// the zero policy with no error. A nil db returns an error rather than
// dereferencing a nil pool, because exported callers such as
// LoadOrgLLMSettings (and ResolveDefaultModel through it) do not check db
// themselves. Query and decode failures are wrapped with context and can
// still be inspected with errors.Is / errors.As.
func loadOrgLLMPolicy(ctx context.Context, db *pgxpool.Pool) (orgLLMPolicy, error) {
	if db == nil {
		return orgLLMPolicy{}, fmt.Errorf("database unavailable")
	}
	var raw []byte
	err := db.QueryRow(ctx, `
SELECT settings->'llm' FROM org_settings WHERE id = $1
`, orgSettingsSingletonID).Scan(&raw)
	if err != nil {
		if err == pgx.ErrNoRows {
			return orgLLMPolicy{}, nil
		}
		return orgLLMPolicy{}, fmt.Errorf("loading org llm policy: %w", err)
	}
	if len(raw) == 0 || string(raw) == "null" {
		return orgLLMPolicy{}, nil
	}
	var out orgLLMPolicy
	if err := json.Unmarshal(raw, &out); err != nil {
		return orgLLMPolicy{}, fmt.Errorf("decoding org llm policy: %w", err)
	}
	return out, nil
}
// loadUserLLMSettings reads the 'llm' section of the preferences that belong
// to the user with the given external ID.
//
// The settings table is LEFT JOINed and the value COALESCEd to '{}'. As a
// result, a user with no settings row or no 'llm' key gets zero-value
// settings. An unknown user (no matching row) also yields zero-value settings
// and no error. Query and JSON decoding errors are returned unchanged.
func loadUserLLMSettings(ctx context.Context, db *pgxpool.Pool, externalUserID string) (llm.Settings, error) {
	const query = `
SELECT COALESCE(s.preferences->'llm', '{}'::jsonb)
FROM users u
LEFT JOIN settings s ON s.user_id = u.id
WHERE u.external_id = $1
`
	var payload []byte
	switch err := db.QueryRow(ctx, query, externalUserID).Scan(&payload); {
	case err == pgx.ErrNoRows:
		return llm.Settings{}, nil
	case err != nil:
		return llm.Settings{}, err
	}

	var settings llm.Settings
	if len(payload) == 0 {
		return settings, nil
	}
	if err := json.Unmarshal(payload, &settings); err != nil {
		return llm.Settings{}, err
	}
	return settings, nil
}
// IsAssistantEnabled loads the assistant policy and reports whether the AI
// assistant is enabled. It is enabled when the stored policy enables it,
// when deployEnabled is true, or when the "ai-assistant" plugin is enabled in
// org settings. The plugin lookup is skipped if an earlier condition holds.
//
// Loading is best-effort. LoadAssistantPolicy returns its built-in defaults
// alongside any error, so those defaults are kept. Replacing them with a
// zero-value policy would drop PublicPath, EnabledTools, ChatNCPath and the
// chat-sync default for callers that still use the returned policy.
func IsAssistantEnabled(ctx context.Context, db *pgxpool.Pool, deployEnabled bool) (AssistantPolicy, bool) {
	// The error is intentionally ignored; see the function comment.
	policy, _ := LoadAssistantPolicy(ctx, db)
	enabled := policy.Enabled || deployEnabled || isPluginEnabled(ctx, db, "ai-assistant")
	return policy, enabled
}
// isPluginEnabled reports whether the plugin with the given ID is marked
// enabled in org_settings.settings->'plugins'. That value is decoded as an
// array of objects with "id" and "enabled" keys.
//
// The check is deliberately best-effort. A nil db, a query error or missing
// row, an absent/null 'plugins' value, or malformed JSON all report false. If
// the ID is listed more than once, the first entry decides.
func isPluginEnabled(ctx context.Context, db *pgxpool.Pool, pluginID string) bool {
	if db == nil {
		return false
	}
	var payload []byte
	if err := db.QueryRow(ctx, `
SELECT settings->'plugins' FROM org_settings WHERE id = $1
`, orgSettingsSingletonID).Scan(&payload); err != nil || len(payload) == 0 || string(payload) == "null" {
		return false
	}

	var entries []struct {
		ID      string `json:"id"`
		Enabled bool   `json:"enabled"`
	}
	if json.Unmarshal(payload, &entries) != nil {
		return false
	}
	for i := range entries {
		if entries[i].ID == pluginID {
			return entries[i].Enabled
		}
	}
	return false
}
// LoadAssistantPolicy returns the org-wide AI assistant policy stored in
// org_settings.settings->'ai_assistant', layered over built-in defaults.
//
// The stored JSON is decoded on top of a copy of the defaults, so any key the
// stored object omits keeps its default value. This matters for booleans such
// as ChatSyncEnabled (default true). Decoding into a zero-value struct cannot
// distinguish an absent key from false, so policies saved without that key
// would read as chat sync disabled. Explicitly empty PublicPath / ChatNCPath
// values, and an empty or null EnabledTools list, also fall back to the
// defaults.
//
// A nil db, a missing row, or a missing/null value returns the defaults with
// no error. On query or decode errors the defaults are returned together with
// the error, so best-effort callers still get a usable policy.
func LoadAssistantPolicy(ctx context.Context, db *pgxpool.Pool) (AssistantPolicy, error) {
	defaults := AssistantPolicy{
		Enabled:               false,
		PublicPath:            "/ai",
		EmbedDefaultTemporary: false,
		EnabledTools:          []string{"mail", "drive", "contacts", "agenda", "search", "web_search"},
		ChatSyncEnabled:       true,
		ChatNCPath:            "/.ultimail/ai/chats",
	}
	if db == nil {
		return defaults, nil
	}
	var raw []byte
	err := db.QueryRow(ctx, `
SELECT settings->'ai_assistant' FROM org_settings WHERE id = $1
`, orgSettingsSingletonID).Scan(&raw)
	if err != nil {
		if err == pgx.ErrNoRows {
			return defaults, nil
		}
		return defaults, err
	}
	if len(raw) == 0 || string(raw) == "null" {
		return defaults, nil
	}

	stored := defaults
	// json.Unmarshal reuses a non-nil slice's backing array. Start from nil so
	// decoding enabled_tools cannot overwrite defaults.EnabledTools, which is
	// still returned on the error path below.
	stored.EnabledTools = nil
	if err := json.Unmarshal(raw, &stored); err != nil {
		return defaults, err
	}
	if stored.PublicPath == "" {
		stored.PublicPath = defaults.PublicPath
	}
	if stored.ChatNCPath == "" {
		stored.ChatNCPath = defaults.ChatNCPath
	}
	if len(stored.EnabledTools) == 0 {
		stored.EnabledTools = defaults.EnabledTools
	}
	return stored, nil
}
// ResolveDefaultModel picks the model the assistant should use by default.
//
// Order of preference:
//  1. policy.DefaultModel, if non-blank.
//  2. The DefaultModel of the org's default provider (the first provider
//     whose ID matches DefaultProviderID), if non-blank.
//  3. The DefaultModel of the first org provider that has a non-blank one.
//
// Returned values are whitespace-trimmed. It returns "" when nothing matches
// or when the org LLM settings cannot be loaded.
func ResolveDefaultModel(ctx context.Context, db *pgxpool.Pool, policy AssistantPolicy) string {
	if explicit := strings.TrimSpace(policy.DefaultModel); explicit != "" {
		return explicit
	}
	settings, err := LoadOrgLLMSettings(ctx, db)
	if err != nil {
		return ""
	}
	providers := settings.Providers

	if wanted := strings.TrimSpace(settings.DefaultProviderID); wanted != "" {
		for i := range providers {
			if providers[i].ID != wanted {
				continue
			}
			if m := strings.TrimSpace(providers[i].DefaultModel); m != "" {
				return m
			}
			// The designated provider has no model; fall through to the scan below.
			break
		}
	}

	for i := range providers {
		if m := strings.TrimSpace(providers[i].DefaultModel); m != "" {
			return m
		}
	}
	return ""
}
// LoadQuotaLimits returns the org-wide LLM usage quotas stored in
// org_settings.settings->'usage_quotas'. It starts from built-in defaults of
// 75 requests per day and 2,000,000 tokens per month.
//
// The keys "llm_requests_per_day" and "llm_tokens_per_month" override their
// defaults only when they hold a JSON number that is at least 1 and fits the
// target integer type. Anything else (strings, zero, negatives, fractions
// below 1, out-of-range numbers) is ignored. The lower bound stops a value
// such as 0.5 from truncating to a quota of 0. The upper bound avoids
// Go's implementation-dependent result for out-of-range float-to-integer
// conversions.
//
// A nil db, a missing row, or a missing/null value returns the defaults with
// no error. On query or decode errors the defaults are returned with the error.
func LoadQuotaLimits(ctx context.Context, db *pgxpool.Pool) (QuotaLimits, error) {
	defaults := QuotaLimits{RequestsPerDay: 75, TokensPerMonth: 2_000_000}
	if db == nil {
		return defaults, nil
	}
	var raw []byte
	err := db.QueryRow(ctx, `
SELECT settings->'usage_quotas' FROM org_settings WHERE id = $1
`, orgSettingsSingletonID).Scan(&raw)
	if err != nil {
		if err == pgx.ErrNoRows {
			return defaults, nil
		}
		return defaults, err
	}
	if len(raw) == 0 || string(raw) == "null" {
		return defaults, nil
	}
	var stored map[string]any
	if err := json.Unmarshal(raw, &stored); err != nil {
		return defaults, err
	}

	// Largest value of the platform's int; bounds the float64 -> int conversion.
	const maxInt = int(^uint(0) >> 1)
	if v, ok := stored["llm_requests_per_day"].(float64); ok && v >= 1 && v < float64(maxInt) {
		defaults.RequestsPerDay = int(v)
	}
	// 1<<63 is exactly representable as float64, and every float64 below it fits in int64.
	if v, ok := stored["llm_tokens_per_month"].(float64); ok && v >= 1 && v < 1<<63 {
		defaults.TokensPerMonth = int64(v)
	}
	return defaults, nil
}