diff --git a/internal/ai/cost/fingerprint.go b/internal/ai/cost/fingerprint.go new file mode 100644 index 0000000..51dcb23 --- /dev/null +++ b/internal/ai/cost/fingerprint.go @@ -0,0 +1,18 @@ +package cost + +import ( + "crypto/sha256" + "encoding/hex" + "strings" +) + +// KeyFingerprint returns a stable short hash for grouping usage by API key. +func KeyFingerprint(providerID, apiKey string) string { + providerID = strings.TrimSpace(providerID) + apiKey = strings.TrimSpace(apiKey) + if apiKey == "" { + return providerID + ":none" + } + sum := sha256.Sum256([]byte(providerID + ":" + apiKey)) + return hex.EncodeToString(sum[:8]) +} diff --git a/internal/ai/cost/meter.go b/internal/ai/cost/meter.go new file mode 100644 index 0000000..83da293 --- /dev/null +++ b/internal/ai/cost/meter.go @@ -0,0 +1,149 @@ +package cost + +import ( + "context" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +const ( + BillingScopeOrg = "org" + BillingScopeUser = "user" +) + +type RecordInput struct { + ExternalUserID string + Feature string + ModelID string + ProviderID string + BillingScope string + ProviderKeyFingerprint string + Usage UsageDetail + RequestID string +} + +type Meter struct { + db *pgxpool.Pool + pricing *PricingStore +} + +func NewMeter(db *pgxpool.Pool) *Meter { + return &Meter{db: db, pricing: NewPricingStore(db)} +} + +func (m *Meter) RecordUsage(ctx context.Context, in RecordInput) error { + if m.db == nil { + return nil + } + userID, err := m.resolveUserID(ctx, in.ExternalUserID) + if err != nil { + return err + } + + price, found := m.pricing.LookupPrice(ctx, in.ModelID) + cost, estimated := ComputeCostMicroEUR(in.Usage, price, found) + + tx, err := m.db.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + _, err = tx.Exec(ctx, ` + INSERT INTO ai_usage_events ( + user_id, feature, model_id, provider_id, billing_scope, + provider_key_fingerprint, prompt_tokens, completion_tokens, + 
cached_input_tokens, reasoning_tokens, cost_micro_eur, estimated, request_id + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + `, userID, in.Feature, in.ModelID, in.ProviderID, in.BillingScope, + in.ProviderKeyFingerprint, in.Usage.PromptTokens, in.Usage.CompletionTokens, + in.Usage.CachedInputTokens, in.Usage.ReasoningTokens, cost, estimated, nullStr(in.RequestID)) + if err != nil { + return err + } + + today := time.Now().UTC().Truncate(24 * time.Hour) + month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) + tokens := int64(in.Usage.TotalTokens) + orgCost := int64(0) + userCost := int64(0) + if in.BillingScope == BillingScopeOrg { + orgCost = cost + } else { + userCost = cost + } + + _, err = tx.Exec(ctx, ` + INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens, cost_micro_eur_org, cost_micro_eur_user) + VALUES ($1, $2, 1, $3, $4, $5) + ON CONFLICT (user_id, usage_date) DO UPDATE SET + requests = ai_usage_daily.requests + 1, + tokens = ai_usage_daily.tokens + EXCLUDED.tokens, + cost_micro_eur_org = ai_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org, + cost_micro_eur_user = ai_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user + `, userID, today, tokens, orgCost, userCost) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, ` + INSERT INTO ai_usage_monthly (user_id, usage_month, tokens, cost_micro_eur_org, cost_micro_eur_user) + VALUES ($1, $2, $3, $4, $5) + ON CONFLICT (user_id, usage_month) DO UPDATE SET + tokens = ai_usage_monthly.tokens + EXCLUDED.tokens, + cost_micro_eur_org = ai_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org, + cost_micro_eur_user = ai_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user + `, userID, month, tokens, orgCost, userCost) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, ` + INSERT INTO ai_org_usage_daily (usage_date, cost_micro_eur_org, cost_micro_eur_user, requests) + VALUES ($1, $2, $3, 1) + ON CONFLICT 
(usage_date) DO UPDATE SET + cost_micro_eur_org = ai_org_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org, + cost_micro_eur_user = ai_org_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user, + requests = ai_org_usage_daily.requests + 1 + `, today, orgCost, userCost) + if err != nil { + return err + } + + _, err = tx.Exec(ctx, ` + INSERT INTO ai_org_usage_monthly (usage_month, cost_micro_eur_org, cost_micro_eur_user) + VALUES ($1, $2, $3) + ON CONFLICT (usage_month) DO UPDATE SET + cost_micro_eur_org = ai_org_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org, + cost_micro_eur_user = ai_org_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user + `, month, orgCost, userCost) + if err != nil { + return err + } + + return tx.Commit(ctx) +} + +func (m *Meter) resolveUserID(ctx context.Context, externalUserID string) (string, error) { + var userID string + err := m.db.QueryRow(ctx, ` + SELECT id::text FROM users WHERE external_id = $1 + `, externalUserID).Scan(&userID) + if err != nil { + if err == pgx.ErrNoRows { + return "", fmt.Errorf("user not found") + } + return "", err + } + return userID, nil +} + +func nullStr(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/internal/ai/cost/parse.go b/internal/ai/cost/parse.go new file mode 100644 index 0000000..c41a763 --- /dev/null +++ b/internal/ai/cost/parse.go @@ -0,0 +1,65 @@ +package cost + +import "encoding/json" + +type usagePayload struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int `json:"cached_tokens"` + } `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *struct { + ReasoningTokens int `json:"reasoning_tokens"` + } `json:"completion_tokens_details,omitempty"` +} + +type chatCompletionResponse struct { + Usage *usagePayload `json:"usage,omitempty"` +} + +// ParseUsage extracts token 
details from an OpenAI-compatible chat completion payload. +func ParseUsage(payload []byte) UsageDetail { + var parsed chatCompletionResponse + if err := json.Unmarshal(payload, &parsed); err != nil || parsed.Usage == nil { + return UsageDetail{TotalTokens: 1} + } + return usageFromPayload(parsed.Usage) +} + +func usageFromPayload(u *usagePayload) UsageDetail { + if u == nil { + return UsageDetail{TotalTokens: 1} + } + detail := UsageDetail{ + PromptTokens: u.PromptTokens, + CompletionTokens: u.CompletionTokens, + TotalTokens: u.TotalTokens, + } + if u.PromptTokensDetails != nil { + detail.CachedInputTokens = u.PromptTokensDetails.CachedTokens + } + if u.CompletionTokensDetails != nil { + detail.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens + } + if detail.TotalTokens == 0 { + detail.TotalTokens = detail.PromptTokens + detail.CompletionTokens + } + if detail.TotalTokens == 0 { + detail.TotalTokens = 1 + } + return detail +} + +// MergeStreamUsage keeps the latest non-zero usage from streaming chunks. 
+func MergeStreamUsage(acc UsageDetail, chunk []byte) UsageDetail { + var parsed chatCompletionResponse + if err := json.Unmarshal(chunk, &parsed); err != nil || parsed.Usage == nil { + return acc + } + next := usageFromPayload(parsed.Usage) + if next.TotalTokens == 0 && next.PromptTokens == 0 && next.CompletionTokens == 0 { + return acc + } + return next +} diff --git a/internal/ai/cost/parse_test.go b/internal/ai/cost/parse_test.go new file mode 100644 index 0000000..c46de87 --- /dev/null +++ b/internal/ai/cost/parse_test.go @@ -0,0 +1,72 @@ +package cost + +import "testing" + +func TestParseUsageOpenAI(t *testing.T) { + payload := []byte(`{ + "usage": { + "prompt_tokens": 1200, + "completion_tokens": 340, + "total_tokens": 1540, + "prompt_tokens_details": {"cached_tokens": 800}, + "completion_tokens_details": {"reasoning_tokens": 50} + } + }`) + u := ParseUsage(payload) + if u.PromptTokens != 1200 || u.CompletionTokens != 340 { + t.Fatalf("unexpected tokens: %+v", u) + } + if u.CachedInputTokens != 800 || u.ReasoningTokens != 50 { + t.Fatalf("unexpected details: %+v", u) + } +} + +func TestParseUsageFallback(t *testing.T) { + u := ParseUsage([]byte(`{"choices":[]}`)) + if u.TotalTokens != 1 { + t.Fatalf("expected fallback 1, got %d", u.TotalTokens) + } +} + +func TestComputeCostMicroEUR(t *testing.T) { + price := ModelPrice{ + InputMicroEURPerMTok: 1_000_000, + OutputMicroEURPerMTok: 2_000_000, + } + usage := UsageDetail{ + PromptTokens: 1000, + CompletionTokens: 500, + CachedInputTokens: 200, + TotalTokens: 1500, + } + cost, estimated := ComputeCostMicroEUR(usage, price, true) + if estimated { + t.Fatal("should not be estimated when price found") + } + // uncached 800 * 1 + cached 200 * 0.5 + output 500 * 2 = 800+100+1000 = 1900 micro (cached uses half input) + if cost < 1800 || cost > 2000 { + t.Fatalf("unexpected cost %d", cost) + } +} + +func TestComputeCostUnknownModel(t *testing.T) { + usage := UsageDetail{TotalTokens: 1_000_000} + cost, estimated := 
ComputeCostMicroEUR(usage, ModelPrice{}, false) + if !estimated { + t.Fatal("expected estimated") + } + if cost != fallbackInputMicroEURPerMTok { + t.Fatalf("expected fallback cost %d, got %d", fallbackInputMicroEURPerMTok, cost) + } +} + +func TestMergeStreamUsage(t *testing.T) { + acc := UsageDetail{} + chunk1 := []byte(`{"usage":{"prompt_tokens":10,"completion_tokens":0,"total_tokens":10}}`) + chunk2 := []byte(`{"usage":{"prompt_tokens":100,"completion_tokens":40,"total_tokens":140}}`) + acc = MergeStreamUsage(acc, chunk1) + acc = MergeStreamUsage(acc, chunk2) + if acc.TotalTokens != 140 { + t.Fatalf("expected 140, got %d", acc.TotalTokens) + } +} diff --git a/internal/ai/cost/policy.go b/internal/ai/cost/policy.go new file mode 100644 index 0000000..993b76d --- /dev/null +++ b/internal/ai/cost/policy.go @@ -0,0 +1,421 @@ +package cost + +import ( + "context" + "errors" + "fmt" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +var ErrCostLimitExceeded = errors.New("llm cost limit exceeded") + +type EffectiveLimits struct { + DailyLimitMicroEUR *int64 + MonthlyLimitMicroEUR *int64 + WarnThresholdPct int +} + +type SpendStatus struct { + CostUsedTodayMicroEUR int64 `json:"cost_used_today_micro_eur"` + CostLimitTodayMicroEUR *int64 `json:"cost_limit_today_micro_eur,omitempty"` + CostUsedMonthMicroEUR int64 `json:"cost_used_month_micro_eur"` + CostLimitMonthMicroEUR *int64 `json:"cost_limit_month_micro_eur,omitempty"` + CostRemainingTodayMicroEUR *int64 `json:"cost_remaining_today_micro_eur,omitempty"` + CostRemainingMonthMicroEUR *int64 `json:"cost_remaining_month_micro_eur,omitempty"` + WarnThresholdPct int `json:"warn_threshold_pct"` + Currency string `json:"currency"` + BillingScopeOrg bool `json:"billing_scope_org"` + ByProviderKeys []ProviderKeySpend `json:"by_provider_keys,omitempty"` + + // Legacy fields (deprecated) + RequestsUsedToday int `json:"requests_used_today"` + RequestsLimit int `json:"requests_limit"` + 
TokensUsedMonth int64 `json:"tokens_used_month"` + TokensLimit int64 `json:"tokens_limit"` + RequestsRemaining int `json:"requests_remaining"` + TokensRemaining int64 `json:"tokens_remaining"` +} + +type ProviderKeySpend struct { + Fingerprint string `json:"fingerprint"` + Label string `json:"label"` + CostMonthMicroEUR int64 `json:"cost_month_micro_eur"` + CostMonthEUR float64 `json:"cost_month_eur"` + BillingScope string `json:"billing_scope"` +} + +type PolicyService struct { + db *pgxpool.Pool +} + +func NewPolicyService(db *pgxpool.Pool) *PolicyService { + return &PolicyService{db: db} +} + +type policyCandidate struct { + daily *int64 + monthly *int64 + warn int + priority int +} + +func (s *PolicyService) ResolveEffectiveLimits(ctx context.Context, userID string) (EffectiveLimits, error) { + limits := EffectiveLimits{WarnThresholdPct: 80} + if s.db == nil { + return limits, nil + } + + var candidates []policyCandidate + + // Org policy (priority 0) + org, err := s.loadPolicy(ctx, "org", "") + if err == nil && org != nil { + candidates = append(candidates, policyCandidate{ + daily: org.daily, monthly: org.monthly, warn: org.warn, priority: 0, + }) + } + + // Group policies (priority 10) — most restrictive wins among groups + rows, err := s.db.Query(ctx, ` + SELECT p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct + FROM ai_cost_policies p + JOIN user_group_members ugm ON ugm.group_id = p.scope_id + WHERE p.scope_type = 'group' AND ugm.user_id = $1::uuid + `, userID) + if err == nil { + defer rows.Close() + for rows.Next() { + var c policyCandidate + c.priority = 10 + if err := rows.Scan(&c.daily, &c.monthly, &c.warn); err != nil { + continue + } + candidates = append(candidates, c) + } + } + + // User policy (priority 20) + userPol, err := s.loadPolicy(ctx, "user", userID) + if err == nil && userPol != nil { + candidates = append(candidates, policyCandidate{ + daily: userPol.daily, monthly: userPol.monthly, warn: userPol.warn, priority: 
20, + }) + } + + if len(candidates) == 0 { + return limits, nil + } + + // User override wins if present; else merge group+org with most restrictive + var userCandidate *policyCandidate + var others []policyCandidate + for i := range candidates { + if candidates[i].priority >= 20 { + userCandidate = &candidates[i] + } else { + others = append(others, candidates[i]) + } + } + + var merged policyCandidate + if userCandidate != nil { + merged = *userCandidate + } else { + merged = mergeMostRestrictive(others) + } + + limits.DailyLimitMicroEUR = merged.daily + limits.MonthlyLimitMicroEUR = merged.monthly + if merged.warn > 0 { + limits.WarnThresholdPct = merged.warn + } + return limits, nil +} + +type policyRow struct { + daily *int64 + monthly *int64 + warn int +} + +func (s *PolicyService) loadPolicy(ctx context.Context, scopeType, scopeID string) (*policyRow, error) { + var daily, monthly *int64 + var warn int + var err error + if scopeType == "org" { + err = s.db.QueryRow(ctx, ` + SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct + FROM ai_cost_policies WHERE scope_type = 'org' LIMIT 1 + `).Scan(&daily, &monthly, &warn) + } else { + err = s.db.QueryRow(ctx, ` + SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct + FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid + `, scopeType, scopeID).Scan(&daily, &monthly, &warn) + } + if err != nil { + if err == pgx.ErrNoRows { + return nil, nil + } + return nil, err + } + return &policyRow{daily: daily, monthly: monthly, warn: warn}, nil +} + +func mergeMostRestrictive(candidates []policyCandidate) policyCandidate { + if len(candidates) == 0 { + return policyCandidate{warn: 80} + } + out := candidates[0] + for _, c := range candidates[1:] { + if c.daily != nil && (out.daily == nil || *c.daily < *out.daily) { + out.daily = c.daily + } + if c.monthly != nil && (out.monthly == nil || *c.monthly < *out.monthly) { + out.monthly = c.monthly + } + if c.warn > 0 && 
c.warn < out.warn { + out.warn = c.warn + } + } + if out.warn == 0 { + out.warn = 80 + } + return out +} + +func (s *PolicyService) GetStatus(ctx context.Context, externalUserID string, orgBilling bool) (SpendStatus, error) { + userID, err := s.resolveUserID(ctx, externalUserID) + if err != nil { + return SpendStatus{}, err + } + + today := time.Now().UTC().Truncate(24 * time.Hour) + month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) + + var dailyOrg, dailyUser int64 + var requestsToday int + _ = s.db.QueryRow(ctx, ` + SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(requests, 0) + FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2 + `, userID, today).Scan(&dailyOrg, &dailyUser, &requestsToday) + + var monthlyOrg, monthlyUser, tokensMonth int64 + _ = s.db.QueryRow(ctx, ` + SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(tokens, 0) + FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2 + `, userID, month).Scan(&monthlyOrg, &monthlyUser, &tokensMonth) + + limits, _ := s.ResolveEffectiveLimits(ctx, userID) + + status := SpendStatus{ + Currency: "EUR", + BillingScopeOrg: orgBilling, + WarnThresholdPct: limits.WarnThresholdPct, + RequestsUsedToday: requestsToday, + TokensUsedMonth: tokensMonth, + } + + if orgBilling { + status.CostUsedTodayMicroEUR = dailyOrg + status.CostUsedMonthMicroEUR = monthlyOrg + status.CostLimitTodayMicroEUR = limits.DailyLimitMicroEUR + status.CostLimitMonthMicroEUR = limits.MonthlyLimitMicroEUR + } else { + status.CostUsedTodayMicroEUR = dailyUser + status.CostUsedMonthMicroEUR = monthlyUser + } + + status.CostRemainingTodayMicroEUR = remaining(limits.DailyLimitMicroEUR, status.CostUsedTodayMicroEUR) + status.CostRemainingMonthMicroEUR = remaining(limits.MonthlyLimitMicroEUR, status.CostUsedMonthMicroEUR) + + keys, _ := s.providerKeyBreakdown(ctx, userID, month) + if keys == nil { + keys = []ProviderKeySpend{} + } + 
status.ByProviderKeys = keys + + return status, nil +} + +func (s *PolicyService) AssertAvailable(ctx context.Context, externalUserID string, billingScope string) error { + if billingScope == BillingScopeUser { + return nil + } + userID, err := s.resolveUserID(ctx, externalUserID) + if err != nil { + return err + } + limits, err := s.ResolveEffectiveLimits(ctx, userID) + if err != nil { + return err + } + + today := time.Now().UTC().Truncate(24 * time.Hour) + month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) + + var dailyOrg, monthlyOrg int64 + _ = s.db.QueryRow(ctx, ` + SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_daily + WHERE user_id = $1::uuid AND usage_date = $2 + `, userID, today).Scan(&dailyOrg) + _ = s.db.QueryRow(ctx, ` + SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_monthly + WHERE user_id = $1::uuid AND usage_month = $2 + `, userID, month).Scan(&monthlyOrg) + + if limits.DailyLimitMicroEUR != nil && *limits.DailyLimitMicroEUR > 0 && dailyOrg >= *limits.DailyLimitMicroEUR { + return fmt.Errorf("%w: daily cost limit reached", ErrCostLimitExceeded) + } + if limits.MonthlyLimitMicroEUR != nil && *limits.MonthlyLimitMicroEUR > 0 && monthlyOrg >= *limits.MonthlyLimitMicroEUR { + return fmt.Errorf("%w: monthly cost limit reached", ErrCostLimitExceeded) + } + return nil +} + +func (s *PolicyService) providerKeyBreakdown(ctx context.Context, userID string, month time.Time) ([]ProviderKeySpend, error) { + rows, err := s.db.Query(ctx, ` + SELECT provider_key_fingerprint, billing_scope, + COALESCE(SUM(cost_micro_eur), 0), + MAX(provider_id) + FROM ai_usage_events + WHERE user_id = $1::uuid AND created_at >= $2 + GROUP BY provider_key_fingerprint, billing_scope + ORDER BY SUM(cost_micro_eur) DESC + LIMIT 20 + `, userID, month) + if err != nil { + return []ProviderKeySpend{}, err + } + defer rows.Close() + out := make([]ProviderKeySpend, 0) + for rows.Next() { + var item ProviderKeySpend + var providerID string + if err := 
rows.Scan(&item.Fingerprint, &item.BillingScope, &item.CostMonthMicroEUR, &providerID); err != nil { + continue + } + item.CostMonthEUR = MicroEURToEUR(item.CostMonthMicroEUR) + suffix := item.Fingerprint + if len(suffix) > 8 { + suffix = suffix[len(suffix)-8:] + } + item.Label = providerID + " ···" + suffix + out = append(out, item) + } + return out, rows.Err() +} + +func (s *PolicyService) resolveUserID(ctx context.Context, externalUserID string) (string, error) { + var userID string + err := s.db.QueryRow(ctx, ` + SELECT id::text FROM users WHERE external_id = $1 + `, externalUserID).Scan(&userID) + if err != nil { + if err == pgx.ErrNoRows { + return "", fmt.Errorf("user not found") + } + return "", err + } + return userID, nil +} + +func remaining(limit *int64, used int64) *int64 { + if limit == nil || *limit <= 0 { + return nil + } + r := *limit - used + if r < 0 { + r = 0 + } + return &r +} + +// UpsertOrgPolicy updates the org-level cost policy. +func (s *PolicyService) UpsertOrgPolicy(ctx context.Context, daily, monthly *int64, warnPct int) error { + if warnPct <= 0 { + warnPct = 80 + } + var exists bool + if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org')`).Scan(&exists); err != nil { + return err + } + if exists { + _, err := s.db.Exec(ctx, ` + UPDATE ai_cost_policies SET + daily_limit_micro_eur = $1, + monthly_limit_micro_eur = $2, + warn_threshold_pct = $3, + updated_at = NOW() + WHERE scope_type = 'org' + `, daily, monthly, warnPct) + return err + } + _, err := s.db.Exec(ctx, ` + INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority) + VALUES ('org', NULL, $1, $2, $3, 0) + `, daily, monthly, warnPct) + return err +} + +// UpsertScopePolicy updates group or user policy. 
+func (s *PolicyService) UpsertScopePolicy(ctx context.Context, scopeType, scopeID string, daily, monthly *int64, warnPct int) error { + if warnPct <= 0 { + warnPct = 80 + } + priority := 10 + if scopeType == "user" { + priority = 20 + } + var exists bool + err := s.db.QueryRow(ctx, ` + SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid) + `, scopeType, scopeID).Scan(&exists) + if err != nil { + return err + } + if exists { + _, err = s.db.Exec(ctx, ` + UPDATE ai_cost_policies SET + daily_limit_micro_eur = $1, + monthly_limit_micro_eur = $2, + warn_threshold_pct = $3, + updated_at = NOW() + WHERE scope_type = $4 AND scope_id = $5::uuid + `, daily, monthly, warnPct, scopeType, scopeID) + return err + } + _, err = s.db.Exec(ctx, ` + INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority) + VALUES ($1, $2::uuid, $3, $4, $5, $6) + `, scopeType, scopeID, daily, monthly, warnPct, priority) + return err +} + +func (s *PolicyService) DeleteScopePolicy(ctx context.Context, scopeType, scopeID string) error { + if scopeType == "org" { + return fmt.Errorf("cannot delete org policy") + } + _, err := s.db.Exec(ctx, ` + DELETE FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid + `, scopeType, scopeID) + return err +} + +func (s *PolicyService) GetOrgPolicy(ctx context.Context) (EffectiveLimits, error) { + row, err := s.loadPolicy(ctx, "org", "") + if err != nil { + return EffectiveLimits{WarnThresholdPct: 80}, err + } + if row == nil { + return EffectiveLimits{WarnThresholdPct: 80}, nil + } + return EffectiveLimits{ + DailyLimitMicroEUR: row.daily, + MonthlyLimitMicroEUR: row.monthly, + WarnThresholdPct: row.warn, + }, nil +} diff --git a/internal/ai/cost/policy_test.go b/internal/ai/cost/policy_test.go new file mode 100644 index 0000000..3b8e187 --- /dev/null +++ b/internal/ai/cost/policy_test.go @@ -0,0 +1,31 @@ +package cost + +import ( + "testing" 
+) + +func TestMergeMostRestrictive(t *testing.T) { + d10 := int64(10_000_000) + d5 := int64(5_000_000) + m100 := int64(100_000_000) + m50 := int64(50_000_000) + merged := mergeMostRestrictive([]policyCandidate{ + {daily: &d10, monthly: &m100, warn: 80}, + {daily: &d5, monthly: &m50, warn: 70}, + }) + if merged.daily == nil || *merged.daily != d5 { + t.Fatalf("expected daily 5M, got %v", merged.daily) + } + if merged.monthly == nil || *merged.monthly != m50 { + t.Fatalf("expected monthly 50M, got %v", merged.monthly) + } + if merged.warn != 70 { + t.Fatalf("expected warn 70, got %d", merged.warn) + } +} + +func TestMicroEURToEUR(t *testing.T) { + if v := MicroEURToEUR(1_500_000); v != 1.5 { + t.Fatalf("expected 1.5, got %f", v) + } +} diff --git a/internal/ai/cost/pricing.go b/internal/ai/cost/pricing.go new file mode 100644 index 0000000..a7e1425 --- /dev/null +++ b/internal/ai/cost/pricing.go @@ -0,0 +1,206 @@ +package cost + +import ( + "context" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// ModelPrice holds per-million-token rates in micro-EUR. +type ModelPrice struct { + ModelID string + ProviderType string + InputMicroEURPerMTok int64 + CachedInputMicroEURPerMTok int64 + OutputMicroEURPerMTok int64 + ReasoningMicroEURPerMTok int64 +} + +// Default fallback when model is unknown (~gpt-4o-mini input rate). 
+const fallbackInputMicroEURPerMTok int64 = 140000 +const fallbackOutputMicroEURPerMTok int64 = 560000 + +type PricingStore struct { + db *pgxpool.Pool +} + +func NewPricingStore(db *pgxpool.Pool) *PricingStore { + return &PricingStore{db: db} +} + +func (s *PricingStore) LookupPrice(ctx context.Context, modelID string) (ModelPrice, bool) { + modelID = strings.TrimSpace(modelID) + if modelID == "" || s.db == nil { + return ModelPrice{}, false + } + var p ModelPrice + var cached, reasoning *int64 + err := s.db.QueryRow(ctx, ` + SELECT model_id, provider_type, + input_micro_eur_per_mtok, + cached_input_micro_eur_per_mtok, + output_micro_eur_per_mtok, + reasoning_micro_eur_per_mtok + FROM ai_model_pricing + WHERE model_id = $1 AND effective_from <= CURRENT_DATE + ORDER BY effective_from DESC + LIMIT 1 + `, modelID).Scan( + &p.ModelID, &p.ProviderType, + &p.InputMicroEURPerMTok, &cached, + &p.OutputMicroEURPerMTok, &reasoning, + ) + if err != nil { + if err != pgx.ErrNoRows { + // prefix match for bedrock/azure model ids + err = s.db.QueryRow(ctx, ` + SELECT model_id, provider_type, + input_micro_eur_per_mtok, + cached_input_micro_eur_per_mtok, + output_micro_eur_per_mtok, + reasoning_micro_eur_per_mtok + FROM ai_model_pricing + WHERE $1 LIKE model_id || '%' AND effective_from <= CURRENT_DATE + ORDER BY LENGTH(model_id) DESC, effective_from DESC + LIMIT 1 + `, modelID).Scan( + &p.ModelID, &p.ProviderType, + &p.InputMicroEURPerMTok, &cached, + &p.OutputMicroEURPerMTok, &reasoning, + ) + } + if err != nil { + return ModelPrice{}, false + } + } + if cached != nil { + p.CachedInputMicroEURPerMTok = *cached + } + if reasoning != nil { + p.ReasoningMicroEURPerMTok = *reasoning + } + return p, true +} + +// ComputeCostMicroEUR calculates estimated cost from usage and model pricing. 
+func ComputeCostMicroEUR(usage UsageDetail, price ModelPrice, found bool) (microEUR int64, estimated bool) { + if !found { + price = ModelPrice{ + InputMicroEURPerMTok: fallbackInputMicroEURPerMTok, + OutputMicroEURPerMTok: fallbackOutputMicroEURPerMTok, + } + estimated = true + if usage.PromptTokens == 0 && usage.CompletionTokens == 0 { + microEUR = int64(usage.TotalTokens) * fallbackInputMicroEURPerMTok / 1_000_000 + if microEUR == 0 && usage.TotalTokens > 0 { + microEUR = 1 + } + return microEUR, true + } + } + + cachedRate := price.CachedInputMicroEURPerMTok + if cachedRate == 0 { + cachedRate = price.InputMicroEURPerMTok / 2 + } + reasoningRate := price.ReasoningMicroEURPerMTok + if reasoningRate == 0 { + reasoningRate = price.OutputMicroEURPerMTok + } + + uncached := usage.UncachedInputTokens() + microEUR += int64(uncached) * price.InputMicroEURPerMTok / 1_000_000 + microEUR += int64(usage.CachedInputTokens) * cachedRate / 1_000_000 + completion := usage.CompletionTokens - usage.ReasoningTokens + if completion < 0 { + completion = usage.CompletionTokens + } + microEUR += int64(completion) * price.OutputMicroEURPerMTok / 1_000_000 + microEUR += int64(usage.ReasoningTokens) * reasoningRate / 1_000_000 + + if microEUR == 0 && usage.TotalTokens > 0 { + microEUR = 1 + } + return microEUR, estimated +} + +// UpsertModelPrice stores or updates pricing for a model (effective today). 
+func (s *PricingStore) UpsertModelPrice(ctx context.Context, p ModelPrice) error { + if s.db == nil { + return nil + } + today := time.Now().UTC().Truncate(24 * time.Hour) + var cached, reasoning *int64 + if p.CachedInputMicroEURPerMTok > 0 { + v := p.CachedInputMicroEURPerMTok + cached = &v + } + if p.ReasoningMicroEURPerMTok > 0 { + v := p.ReasoningMicroEURPerMTok + reasoning = &v + } + _, err := s.db.Exec(ctx, ` + INSERT INTO ai_model_pricing ( + model_id, provider_type, + input_micro_eur_per_mtok, cached_input_micro_eur_per_mtok, + output_micro_eur_per_mtok, reasoning_micro_eur_per_mtok, + effective_from, source + ) VALUES ($1, $2, $3, $4, $5, $6, $7, 'manual') + ON CONFLICT (model_id, effective_from) DO UPDATE SET + provider_type = EXCLUDED.provider_type, + input_micro_eur_per_mtok = EXCLUDED.input_micro_eur_per_mtok, + cached_input_micro_eur_per_mtok = EXCLUDED.cached_input_micro_eur_per_mtok, + output_micro_eur_per_mtok = EXCLUDED.output_micro_eur_per_mtok, + reasoning_micro_eur_per_mtok = EXCLUDED.reasoning_micro_eur_per_mtok, + source = EXCLUDED.source + `, p.ModelID, p.ProviderType, p.InputMicroEURPerMTok, cached, + p.OutputMicroEURPerMTok, reasoning, today) + return err +} + +func (s *PricingStore) ListPrices(ctx context.Context) ([]ModelPrice, error) { + if s.db == nil { + return nil, nil + } + rows, err := s.db.Query(ctx, ` + SELECT DISTINCT ON (model_id) + model_id, provider_type, + input_micro_eur_per_mtok, + cached_input_micro_eur_per_mtok, + output_micro_eur_per_mtok, + reasoning_micro_eur_per_mtok + FROM ai_model_pricing + WHERE effective_from <= CURRENT_DATE + ORDER BY model_id, effective_from DESC + `) + if err != nil { + return nil, err + } + defer rows.Close() + var out []ModelPrice + for rows.Next() { + var p ModelPrice + var cached, reasoning *int64 + if err := rows.Scan(&p.ModelID, &p.ProviderType, + &p.InputMicroEURPerMTok, &cached, + &p.OutputMicroEURPerMTok, &reasoning); err != nil { + return nil, err + } + if cached != nil { + 
p.CachedInputMicroEURPerMTok = *cached + } + if reasoning != nil { + p.ReasoningMicroEURPerMTok = *reasoning + } + out = append(out, p) + } + return out, rows.Err() +} + +// MicroEURToEUR converts micro-EUR to float EUR for API responses. +func MicroEURToEUR(micro int64) float64 { + return float64(micro) / 1_000_000 +} diff --git a/internal/ai/cost/usage.go b/internal/ai/cost/usage.go new file mode 100644 index 0000000..967ff26 --- /dev/null +++ b/internal/ai/cost/usage.go @@ -0,0 +1,18 @@ +package cost + +// UsageDetail holds token counts from a provider response. +type UsageDetail struct { + PromptTokens int + CompletionTokens int + CachedInputTokens int + ReasoningTokens int + TotalTokens int +} + +func (u UsageDetail) UncachedInputTokens() int { + uncached := u.PromptTokens - u.CachedInputTokens + if uncached < 0 { + return u.PromptTokens + } + return uncached +} diff --git a/internal/ai/gateway.go b/internal/ai/gateway.go index 03b275c..eb435f1 100644 --- a/internal/ai/gateway.go +++ b/internal/ai/gateway.go @@ -13,6 +13,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" + "github.com/ultisuite/ulti-backend/internal/ai/cost" "github.com/ultisuite/ulti-backend/internal/llm" ) @@ -40,12 +41,6 @@ type chatCompletionRequest struct { Tools []any `json:"tools,omitempty"` } -type usagePayload struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` -} - type chatCompletionResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -57,7 +52,17 @@ type chatCompletionResponse struct { FinishReason string `json:"finish_reason"` Delta *llm.ChatMessage `json:"delta,omitempty"` } `json:"choices"` - Usage *usagePayload `json:"usage,omitempty"` + Usage *struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int `json:"cached_tokens"` + } 
`json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *struct { + ReasoningTokens int `json:"reasoning_tokens"` + } `json:"completion_tokens_details,omitempty"` + } `json:"usage,omitempty"` Error *struct { Message string `json:"message"` } `json:"error,omitempty"` @@ -116,12 +121,6 @@ func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Setti } func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error { - if strings.TrimSpace(quotaExternalUserID) != "" { - if err := g.quota.AssertAvailable(ctx, quotaExternalUserID); err != nil { - return err - } - } - var modelProbe struct { Model string `json:"model"` } @@ -151,6 +150,14 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID return err } + billingScope := ResolveBillingScope(ctx, g.db, quotaExternalUserID, provider, useOrgSettings) + + if strings.TrimSpace(quotaExternalUserID) != "" { + if err := g.quota.AssertAvailable(ctx, quotaExternalUserID, provider, useOrgSettings); err != nil { + return err + } + } + upstreamBody, err := repairChatCompletionBody(body) if err != nil { return err @@ -182,7 +189,7 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID defer resp.Body.Close() if stream { - return g.proxyStream(ctx, quotaExternalUserID, w, resp) + return g.proxyStream(ctx, quotaExternalUserID, model, provider, billingScope, w, resp) } payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) if err != nil { @@ -195,13 +202,21 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID return nil } if strings.TrimSpace(quotaExternalUserID) != "" { - tokens := extractUsageTokens(payload) - _ = g.quota.Record(ctx, quotaExternalUserID, tokens) + usage := cost.ParseUsage(payload) + _ = g.quota.RecordUsage(ctx, cost.RecordInput{ + ExternalUserID: quotaExternalUserID, + Feature: "gateway", + ModelID: model, + 
ProviderID: provider.ID, + BillingScope: billingScope, + ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey), + Usage: usage, + }) } return nil } -func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w http.ResponseWriter, resp *http.Response) error { +func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID, model string, provider llm.Provider, billingScope string, w http.ResponseWriter, resp *http.Response) error { rc := http.NewResponseController(w) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") @@ -210,7 +225,7 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w w.WriteHeader(resp.StatusCode) reader := bufio.NewReader(resp.Body) - var totalTokens int64 + var usage cost.UsageDetail for { line, err := reader.ReadString('\n') if len(line) > 0 { @@ -219,7 +234,8 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w return fmt.Errorf("streaming not supported: %w", err) } if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") { - totalTokens += extractStreamUsageTokens([]byte(strings.TrimPrefix(strings.TrimSpace(line), "data: "))) + chunk := []byte(strings.TrimPrefix(strings.TrimSpace(line), "data: ")) + usage = cost.MergeStreamUsage(usage, chunk) } } if err != nil { @@ -235,10 +251,18 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w } } if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" { - if totalTokens == 0 { - totalTokens = 1 + if usage.TotalTokens == 0 { + usage.TotalTokens = 1 } - _ = g.quota.Record(ctx, quotaExternalUserID, totalTokens) + _ = g.quota.RecordUsage(ctx, cost.RecordInput{ + ExternalUserID: quotaExternalUserID, + Feature: "gateway", + ModelID: model, + ProviderID: provider.ID, + BillingScope: billingScope, + ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey), + Usage: usage, + }) 
} return nil } @@ -262,31 +286,6 @@ func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider, return provider, resolvedModel, nil } -func extractUsageTokens(payload []byte) int64 { - var parsed chatCompletionResponse - if err := json.Unmarshal(payload, &parsed); err != nil { - return 1 - } - if parsed.Usage != nil && parsed.Usage.TotalTokens > 0 { - return int64(parsed.Usage.TotalTokens) - } - if parsed.Usage != nil && parsed.Usage.CompletionTokens > 0 { - return int64(parsed.Usage.CompletionTokens) - } - return 1 -} - -func extractStreamUsageTokens(payload []byte) int64 { - var parsed chatCompletionResponse - if err := json.Unmarshal(payload, &parsed); err != nil { - return 0 - } - if parsed.Usage != nil && parsed.Usage.TotalTokens > 0 { - return int64(parsed.Usage.TotalTokens) - } - return 0 -} - func NowUnix() int64 { return time.Now().Unix() } diff --git a/internal/ai/gateway_test.go b/internal/ai/gateway_test.go index b5a04d9..199c2f7 100644 --- a/internal/ai/gateway_test.go +++ b/internal/ai/gateway_test.go @@ -1,16 +1,22 @@ package ai -import "testing" +import ( + "testing" -func TestExtractUsageTokens(t *testing.T) { + "github.com/ultisuite/ulti-backend/internal/ai/cost" +) + +func TestParseUsageViaCost(t *testing.T) { payload := []byte(`{"usage":{"total_tokens":42,"completion_tokens":10}}`) - if got := extractUsageTokens(payload); got != 42 { - t.Fatalf("extractUsageTokens() = %d, want 42", got) + u := cost.ParseUsage(payload) + if u.TotalTokens != 42 { + t.Fatalf("ParseUsage() = %d, want 42", u.TotalTokens) } } -func TestExtractUsageTokensFallback(t *testing.T) { - if got := extractUsageTokens([]byte(`{"choices":[]}`)); got != 1 { - t.Fatalf("expected fallback token count 1, got %d", got) +func TestParseUsageFallback(t *testing.T) { + u := cost.ParseUsage([]byte(`{"choices":[]}`)) + if u.TotalTokens != 1 { + t.Fatalf("expected fallback token count 1, got %d", u.TotalTokens) } } diff --git a/internal/ai/metering.go 
b/internal/ai/metering.go new file mode 100644 index 0000000..4b46137 --- /dev/null +++ b/internal/ai/metering.go @@ -0,0 +1,34 @@ +package ai + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/ultisuite/ulti-backend/internal/ai/cost" + "github.com/ultisuite/ulti-backend/internal/llm" +) + +// RecordFeatureUsage meters an LLM call from a non-gateway feature. +func RecordFeatureUsage(ctx context.Context, db *pgxpool.Pool, externalUserID, feature, modelID string, provider llm.Provider, usage llm.UsageDetail) { + if db == nil || externalUserID == "" { + return + } + q := NewQuotaService(db) + scope := ResolveBillingScope(ctx, db, externalUserID, provider, false) + _ = q.RecordUsage(ctx, cost.RecordInput{ + ExternalUserID: externalUserID, + Feature: feature, + ModelID: modelID, + ProviderID: provider.ID, + BillingScope: scope, + ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey), + Usage: cost.UsageDetail{ + PromptTokens: usage.PromptTokens, + CompletionTokens: usage.CompletionTokens, + CachedInputTokens: usage.CachedInputTokens, + ReasoningTokens: usage.ReasoningTokens, + TotalTokens: usage.TotalTokens, + }, + }) +} diff --git a/internal/ai/providers.go b/internal/ai/providers.go index f1c4bb1..7837aea 100644 --- a/internal/ai/providers.go +++ b/internal/ai/providers.go @@ -229,7 +229,7 @@ func ResolveDefaultModel(ctx context.Context, db *pgxpool.Pool, policy Assistant } func LoadQuotaLimits(ctx context.Context, db *pgxpool.Pool) (QuotaLimits, error) { - defaults := QuotaLimits{RequestsPerDay: 100, TokensPerMonth: 500_000} + defaults := QuotaLimits{RequestsPerDay: 75, TokensPerMonth: 2_000_000} if db == nil { return defaults, nil } diff --git a/internal/ai/quota.go b/internal/ai/quota.go index 2729a9a..41a40ff 100644 --- a/internal/ai/quota.go +++ b/internal/ai/quota.go @@ -2,123 +2,95 @@ package ai import ( "context" - "errors" - "fmt" - "time" + "strings" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" 
+ + "github.com/ultisuite/ulti-backend/internal/ai/cost" + "github.com/ultisuite/ulti-backend/internal/llm" ) -var ErrQuotaExceeded = errors.New("llm quota exceeded") - +// QuotaService wraps cost policy and metering for backward compatibility. type QuotaService struct { - db *pgxpool.Pool + db *pgxpool.Pool + policy *cost.PolicyService + meter *cost.Meter } func NewQuotaService(db *pgxpool.Pool) *QuotaService { - return &QuotaService{db: db} + return &QuotaService{ + db: db, + policy: cost.NewPolicyService(db), + meter: cost.NewMeter(db), + } } -func (s *QuotaService) Check(ctx context.Context, externalUserID string) (QuotaStatus, error) { - limits, err := LoadQuotaLimits(ctx, s.db) +func (s *QuotaService) Check(ctx context.Context, externalUserID string) (SpendStatus, error) { + orgBilling, _ := s.usesOrgBilling(ctx, externalUserID) + status, err := s.policy.GetStatus(ctx, externalUserID, orgBilling) if err != nil { - return QuotaStatus{}, err - } - userID, err := s.resolveUserID(ctx, externalUserID) - if err != nil { - return QuotaStatus{}, err - } - - today := time.Now().UTC().Truncate(24 * time.Hour) - month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) - - var requestsToday int - var tokensMonth int64 - _ = s.db.QueryRow(ctx, ` - SELECT COALESCE(requests, 0) FROM ai_usage_daily - WHERE user_id = $1 AND usage_date = $2 - `, userID, today).Scan(&requestsToday) - _ = s.db.QueryRow(ctx, ` - SELECT COALESCE(tokens, 0) FROM ai_usage_monthly - WHERE user_id = $1 AND usage_month = $2 - `, userID, month).Scan(&tokensMonth) - - status := QuotaStatus{ - RequestsUsedToday: requestsToday, - RequestsLimit: limits.RequestsPerDay, - TokensUsedMonth: tokensMonth, - TokensLimit: limits.TokensPerMonth, - } - if limits.RequestsPerDay > 0 { - status.RequestsRemaining = limits.RequestsPerDay - requestsToday - if status.RequestsRemaining < 0 { - status.RequestsRemaining = 0 - } - } - if limits.TokensPerMonth > 0 { - status.TokensRemaining = 
limits.TokensPerMonth - tokensMonth - if status.TokensRemaining < 0 { - status.TokensRemaining = 0 - } + return SpendStatus{}, err } return status, nil } -func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string) error { - status, err := s.Check(ctx, externalUserID) - if err != nil { - return err - } - if status.RequestsLimit > 0 && status.RequestsUsedToday >= status.RequestsLimit { - return fmt.Errorf("%w: daily request limit reached", ErrQuotaExceeded) - } - if status.TokensLimit > 0 && status.TokensUsedMonth >= status.TokensLimit { - return fmt.Errorf("%w: monthly token limit reached", ErrQuotaExceeded) - } - return nil +func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string, provider llm.Provider, useOrgSettings bool) error { + scope := ResolveBillingScope(ctx, s.db, externalUserID, provider, useOrgSettings) + return s.policy.AssertAvailable(ctx, externalUserID, scope) } -func (s *QuotaService) Record(ctx context.Context, externalUserID string, tokens int64) error { - if tokens < 0 { - tokens = 0 - } - userID, err := s.resolveUserID(ctx, externalUserID) - if err != nil { - return err - } - today := time.Now().UTC().Truncate(24 * time.Hour) - month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) - - _, err = s.db.Exec(ctx, ` - INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens) - VALUES ($1, $2, 1, $3) - ON CONFLICT (user_id, usage_date) DO UPDATE SET - requests = ai_usage_daily.requests + 1, - tokens = ai_usage_daily.tokens + EXCLUDED.tokens - `, userID, today, tokens) - if err != nil { - return err - } - _, err = s.db.Exec(ctx, ` - INSERT INTO ai_usage_monthly (user_id, usage_month, tokens) - VALUES ($1, $2, $3) - ON CONFLICT (user_id, usage_month) DO UPDATE SET - tokens = ai_usage_monthly.tokens + EXCLUDED.tokens - `, userID, month, tokens) - return err +func (s *QuotaService) RecordUsage(ctx context.Context, in cost.RecordInput) error { + return s.meter.RecordUsage(ctx, in) 
} -func (s *QuotaService) resolveUserID(ctx context.Context, externalUserID string) (string, error) { - var userID string - err := s.db.QueryRow(ctx, ` - SELECT id::text FROM users WHERE external_id = $1 - `, externalUserID).Scan(&userID) +func (s *QuotaService) usesOrgBilling(ctx context.Context, externalUserID string) (bool, error) { + settings, err := LoadEffectiveLLMSettings(ctx, s.db, externalUserID) if err != nil { - if err == pgx.ErrNoRows { - return "", fmt.Errorf("user not found") + return true, err + } + org, err := loadOrgLLMPolicy(ctx, s.db) + if err != nil { + return true, err + } + if org.EnforceOrgProviders || len(settings.Providers) == 0 { + return true, nil + } + user, err := loadUserLLMSettings(ctx, s.db, externalUserID) + if err != nil { + return true, err + } + return len(user.Providers) == 0, nil +} + +// ResolveBillingScope determines whether usage is billed to org or user's own key. +func ResolveBillingScope(ctx context.Context, db *pgxpool.Pool, externalUserID string, provider llm.Provider, useOrgSettings bool) string { + if useOrgSettings { + return cost.BillingScopeOrg + } + org, err := loadOrgLLMPolicy(ctx, db) + if err != nil || org.EnforceOrgProviders { + return cost.BillingScopeOrg + } + user, err := loadUserLLMSettings(ctx, db, externalUserID) + if err != nil || len(user.Providers) == 0 { + return cost.BillingScopeOrg + } + apiKey := strings.TrimSpace(provider.APIKey) + for _, op := range org.Providers { + if op.ID == provider.ID && strings.TrimSpace(op.APIKey) == apiKey && apiKey != "" { + return cost.BillingScopeOrg } - return "", err } - return userID, nil + for _, up := range user.Providers { + if up.ID == provider.ID && strings.TrimSpace(up.APIKey) != "" { + return cost.BillingScopeUser + } + } + return cost.BillingScopeOrg } + +// SpendStatus is the user-facing quota/spend response. +type SpendStatus = cost.SpendStatus + +// ErrQuotaExceeded aliases cost limit error for backward compatibility. 
+var ErrQuotaExceeded = cost.ErrCostLimitExceeded diff --git a/internal/ai/quota_test.go b/internal/ai/quota_test.go index dd79692..83802b1 100644 --- a/internal/ai/quota_test.go +++ b/internal/ai/quota_test.go @@ -2,20 +2,16 @@ package ai import "testing" -func TestQuotaStatusRemaining(t *testing.T) { - status := QuotaStatus{ - RequestsUsedToday: 40, - RequestsLimit: 100, - TokensUsedMonth: 100_000, - TokensLimit: 500_000, +func TestSpendStatusCostRemaining(t *testing.T) { + limit := int64(10_000_000) + status := SpendStatus{ + CostUsedTodayMicroEUR: 4_000_000, + CostLimitTodayMicroEUR: &limit, + Currency: "EUR", } - status.RequestsRemaining = status.RequestsLimit - status.RequestsUsedToday - status.TokensRemaining = status.TokensLimit - status.TokensUsedMonth - if status.RequestsRemaining != 60 { - t.Fatalf("requests remaining = %d", status.RequestsRemaining) - } - if status.TokensRemaining != 400_000 { - t.Fatalf("tokens remaining = %d", status.TokensRemaining) + remaining := *status.CostLimitTodayMicroEUR - status.CostUsedTodayMicroEUR + if remaining != 6_000_000 { + t.Fatalf("cost remaining = %d", remaining) } } diff --git a/internal/ai/types.go b/internal/ai/types.go index 01b6058..2f802a7 100644 --- a/internal/ai/types.go +++ b/internal/ai/types.go @@ -21,19 +21,14 @@ type AssistantPolicy struct { } type QuotaLimits struct { + DailyLimitMicroEUR *int64 `json:"llm_daily_cost_limit_micro_eur,omitempty"` + MonthlyLimitMicroEUR *int64 `json:"llm_monthly_cost_limit_micro_eur,omitempty"` + WarnThresholdPct int `json:"llm_cost_warn_threshold_pct,omitempty"` + // Deprecated legacy fields RequestsPerDay int `json:"llm_requests_per_day"` TokensPerMonth int64 `json:"llm_tokens_per_month"` } -type QuotaStatus struct { - RequestsUsedToday int `json:"requests_used_today"` - RequestsLimit int `json:"requests_limit"` - TokensUsedMonth int64 `json:"tokens_used_month"` - TokensLimit int64 `json:"tokens_limit"` - RequestsRemaining int `json:"requests_remaining"` - 
TokensRemaining int64 `json:"tokens_remaining"` -} - type ChatMessage struct { Role string `json:"role"` Content string `json:"content,omitempty"` diff --git a/internal/api/admin/ai_usage.go b/internal/api/admin/ai_usage.go new file mode 100644 index 0000000..59c4c3d --- /dev/null +++ b/internal/api/admin/ai_usage.go @@ -0,0 +1,384 @@ +package admin + +import ( + "context" + "fmt" + "time" + + "github.com/ultisuite/ulti-backend/internal/ai/cost" +) + +type AIUsageSummary struct { + CostTodayMicroEUR int64 `json:"cost_today_micro_eur"` + CostMonthMicroEUR int64 `json:"cost_month_micro_eur"` + CostTodayEUR float64 `json:"cost_today_eur"` + CostMonthEUR float64 `json:"cost_month_eur"` + Currency string `json:"currency"` + DailySeries []AIUsageDayPoint `json:"daily_series"` + TopUsers []AIUsageTopUser `json:"top_users"` + TopModels []AIUsageTopModel `json:"top_models"` + OrgPolicy cost.EffectiveLimits `json:"org_policy"` +} + +type AIUsageDayPoint struct { + Date string `json:"date"` + CostOrgMicroEUR int64 `json:"cost_org_micro_eur"` + CostUserMicroEUR int64 `json:"cost_user_micro_eur"` + CostOrgEUR float64 `json:"cost_org_eur"` + Requests int `json:"requests"` +} + +type AIUsageTopUser struct { + UserID string `json:"user_id"` + Email string `json:"email"` + DisplayName string `json:"display_name"` + CostOrgMicroEUR int64 `json:"cost_org_micro_eur"` + CostOrgEUR float64 `json:"cost_org_eur"` +} + +type AIUsageTopModel struct { + ModelID string `json:"model_id"` + CostMicroEUR int64 `json:"cost_micro_eur"` + CostEUR float64 `json:"cost_eur"` + RequestCount int `json:"request_count"` +} + +type AIUserUsageDetail struct { + UserID string `json:"user_id"` + Email string `json:"email"` + DisplayName string `json:"display_name"` + Summary AIUsageSummary `json:"summary"` + Events []AIUsageEventItem `json:"events"` + EventsTotal int `json:"events_total"` +} + +type AIUsageEventItem struct { + ID string `json:"id"` + CreatedAt string `json:"created_at"` + Feature string 
`json:"feature"` + ModelID string `json:"model_id"` + ProviderID string `json:"provider_id"` + BillingScope string `json:"billing_scope"` + CostMicroEUR int64 `json:"cost_micro_eur"` + CostEUR float64 `json:"cost_eur"` + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + CachedInputTokens int `json:"cached_input_tokens"` + Estimated bool `json:"estimated"` +} + +type AIPricingEntry struct { + ModelID string `json:"model_id"` + ProviderType string `json:"provider_type"` + InputMicroEURPerMTok int64 `json:"input_micro_eur_per_mtok"` + CachedInputMicroEURPerMTok int64 `json:"cached_input_micro_eur_per_mtok,omitempty"` + OutputMicroEURPerMTok int64 `json:"output_micro_eur_per_mtok"` + ReasoningMicroEURPerMTok int64 `json:"reasoning_micro_eur_per_mtok,omitempty"` + InputEURPerMTok float64 `json:"input_eur_per_mtok"` + OutputEURPerMTok float64 `json:"output_eur_per_mtok"` +} + +type AICostPolicyEntry struct { + ScopeType string `json:"scope_type"` + ScopeID *string `json:"scope_id,omitempty"` + ScopeName string `json:"scope_name,omitempty"` + DailyLimitMicroEUR *int64 `json:"daily_limit_micro_eur,omitempty"` + MonthlyLimitMicroEUR *int64 `json:"monthly_limit_micro_eur,omitempty"` + DailyLimitEUR *float64 `json:"daily_limit_eur,omitempty"` + MonthlyLimitEUR *float64 `json:"monthly_limit_eur,omitempty"` + WarnThresholdPct int `json:"warn_threshold_pct"` +} + +type putAIPricingRequest struct { + Prices []AIPricingEntry `json:"prices"` +} + +type putAICostPolicyRequest struct { + DailyLimitEUR *float64 `json:"daily_limit_eur"` + MonthlyLimitEUR *float64 `json:"monthly_limit_eur"` + WarnThresholdPct int `json:"warn_threshold_pct"` + DailyLimitMicroEUR *int64 `json:"daily_limit_micro_eur,omitempty"` + MonthlyLimitMicroEUR *int64 `json:"monthly_limit_micro_eur,omitempty"` +} + +func (s *Service) GetAIUsageSummary(ctx context.Context, billingScope string) (AIUsageSummary, error) { + today := time.Now().UTC().Truncate(24 * time.Hour) + 
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) + + out := AIUsageSummary{ + Currency: "EUR", + DailySeries: []AIUsageDayPoint{}, + TopUsers: []AIUsageTopUser{}, + TopModels: []AIUsageTopModel{}, + } + policySvc := cost.NewPolicyService(s.db) + orgPolicy, _ := policySvc.GetOrgPolicy(ctx) + out.OrgPolicy = orgPolicy + + var dailyOrg, dailyUser int64 + _ = s.db.QueryRow(ctx, ` + SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0) + FROM ai_org_usage_daily WHERE usage_date = $1 + `, today).Scan(&dailyOrg, &dailyUser) + + var monthlyOrg int64 + _ = s.db.QueryRow(ctx, ` + SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_org_usage_monthly WHERE usage_month = $1 + `, month).Scan(&monthlyOrg) + + if billingScope == "user" { + out.CostTodayMicroEUR = dailyUser + } else { + out.CostTodayMicroEUR = dailyOrg + } + out.CostMonthMicroEUR = monthlyOrg + out.CostTodayEUR = cost.MicroEURToEUR(out.CostTodayMicroEUR) + out.CostMonthEUR = cost.MicroEURToEUR(out.CostMonthMicroEUR) + + rows, err := s.db.Query(ctx, ` + SELECT usage_date, cost_micro_eur_org, cost_micro_eur_user, requests + FROM ai_org_usage_daily + WHERE usage_date >= $1 + ORDER BY usage_date ASC + `, today.AddDate(0, 0, -29)) + if err == nil { + defer rows.Close() + for rows.Next() { + var pt AIUsageDayPoint + var d time.Time + if err := rows.Scan(&d, &pt.CostOrgMicroEUR, &pt.CostUserMicroEUR, &pt.Requests); err != nil { + continue + } + pt.Date = d.Format("2006-01-02") + pt.CostOrgEUR = cost.MicroEURToEUR(pt.CostOrgMicroEUR) + out.DailySeries = append(out.DailySeries, pt) + } + } + + topUserRows, err := s.db.Query(ctx, ` + SELECT u.id::text, COALESCE(u.email, ''), COALESCE(u.display_name, ''), + COALESCE(SUM(e.cost_micro_eur), 0) + FROM ai_usage_events e + JOIN users u ON u.id = e.user_id + WHERE e.created_at >= $1 AND e.billing_scope = 'org' + GROUP BY u.id, u.email, u.display_name + ORDER BY SUM(e.cost_micro_eur) DESC + LIMIT 10 + `, month) + if err == nil { + defer 
topUserRows.Close() + for topUserRows.Next() { + var u AIUsageTopUser + if err := topUserRows.Scan(&u.UserID, &u.Email, &u.DisplayName, &u.CostOrgMicroEUR); err != nil { + continue + } + u.CostOrgEUR = cost.MicroEURToEUR(u.CostOrgMicroEUR) + out.TopUsers = append(out.TopUsers, u) + } + } + + topModelRows, err := s.db.Query(ctx, ` + SELECT model_id, COALESCE(SUM(cost_micro_eur), 0), COUNT(*) + FROM ai_usage_events + WHERE created_at >= $1 AND billing_scope = 'org' + GROUP BY model_id + ORDER BY SUM(cost_micro_eur) DESC + LIMIT 10 + `, month) + if err == nil { + defer topModelRows.Close() + for topModelRows.Next() { + var m AIUsageTopModel + if err := topModelRows.Scan(&m.ModelID, &m.CostMicroEUR, &m.RequestCount); err != nil { + continue + } + m.CostEUR = cost.MicroEURToEUR(m.CostMicroEUR) + out.TopModels = append(out.TopModels, m) + } + } + + return out, nil +} + +func (s *Service) GetUserAIUsage(ctx context.Context, userID string, limit, offset int) (AIUserUsageDetail, error) { + if limit <= 0 { + limit = 50 + } + var detail AIUserUsageDetail + detail.Events = []AIUsageEventItem{} + err := s.db.QueryRow(ctx, ` + SELECT id::text, COALESCE(email, ''), COALESCE(display_name, '') + FROM users WHERE id = $1::uuid + `, userID).Scan(&detail.UserID, &detail.Email, &detail.DisplayName) + if err != nil { + return detail, fmt.Errorf("user not found") + } + + summary, _ := s.GetAIUsageSummary(ctx, "org") + detail.Summary = summary + + _ = s.db.QueryRow(ctx, ` + SELECT COUNT(*) FROM ai_usage_events WHERE user_id = $1::uuid + `, userID).Scan(&detail.EventsTotal) + + rows, err := s.db.Query(ctx, ` + SELECT id::text, created_at, feature, model_id, provider_id, billing_scope, + cost_micro_eur, prompt_tokens, completion_tokens, cached_input_tokens, estimated + FROM ai_usage_events + WHERE user_id = $1::uuid + ORDER BY created_at DESC + LIMIT $2 OFFSET $3 + `, userID, limit, offset) + if err != nil { + return detail, err + } + defer rows.Close() + for rows.Next() { + var ev 
AIUsageEventItem + var created time.Time + if err := rows.Scan(&ev.ID, &created, &ev.Feature, &ev.ModelID, &ev.ProviderID, + &ev.BillingScope, &ev.CostMicroEUR, &ev.PromptTokens, &ev.CompletionTokens, + &ev.CachedInputTokens, &ev.Estimated); err != nil { + continue + } + ev.CreatedAt = created.UTC().Format(time.RFC3339) + ev.CostEUR = cost.MicroEURToEUR(ev.CostMicroEUR) + detail.Events = append(detail.Events, ev) + } + return detail, rows.Err() +} + +func (s *Service) ListAIPricing(ctx context.Context) ([]AIPricingEntry, error) { + store := cost.NewPricingStore(s.db) + prices, err := store.ListPrices(ctx) + if err != nil { + return []AIPricingEntry{}, err + } + out := make([]AIPricingEntry, 0, len(prices)) + for _, p := range prices { + out = append(out, AIPricingEntry{ + ModelID: p.ModelID, + ProviderType: p.ProviderType, + InputMicroEURPerMTok: p.InputMicroEURPerMTok, + CachedInputMicroEURPerMTok: p.CachedInputMicroEURPerMTok, + OutputMicroEURPerMTok: p.OutputMicroEURPerMTok, + ReasoningMicroEURPerMTok: p.ReasoningMicroEURPerMTok, + InputEURPerMTok: cost.MicroEURToEUR(p.InputMicroEURPerMTok), + OutputEURPerMTok: cost.MicroEURToEUR(p.OutputMicroEURPerMTok), + }) + } + return out, nil +} + +func (s *Service) PutAIPricing(ctx context.Context, actorSub string, req putAIPricingRequest) ([]AIPricingEntry, error) { + store := cost.NewPricingStore(s.db) + for _, p := range req.Prices { + if err := store.UpsertModelPrice(ctx, cost.ModelPrice{ + ModelID: p.ModelID, + ProviderType: p.ProviderType, + InputMicroEURPerMTok: p.InputMicroEURPerMTok, + CachedInputMicroEURPerMTok: p.CachedInputMicroEURPerMTok, + OutputMicroEURPerMTok: p.OutputMicroEURPerMTok, + ReasoningMicroEURPerMTok: p.ReasoningMicroEURPerMTok, + }); err != nil { + return nil, err + } + } + s.logAudit(ctx, actorSub, "update_ai_pricing", map[string]any{"count": len(req.Prices)}) + return s.ListAIPricing(ctx) +} + +func (s *Service) ListAICostPolicies(ctx context.Context) ([]AICostPolicyEntry, error) { + rows, 
err := s.db.Query(ctx, ` + SELECT p.scope_type, p.scope_id::text, p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct, + COALESCE(g.name, u.email, 'Organisation') + FROM ai_cost_policies p + LEFT JOIN user_groups g ON p.scope_type = 'group' AND g.id = p.scope_id + LEFT JOIN users u ON p.scope_type = 'user' AND u.id = p.scope_id + ORDER BY p.priority ASC, p.scope_type + `) + if err != nil { + return []AICostPolicyEntry{}, err + } + defer rows.Close() + var out = make([]AICostPolicyEntry, 0) + for rows.Next() { + var e AICostPolicyEntry + var scopeID *string + if err := rows.Scan(&e.ScopeType, &scopeID, &e.DailyLimitMicroEUR, &e.MonthlyLimitMicroEUR, + &e.WarnThresholdPct, &e.ScopeName); err != nil { + continue + } + e.ScopeID = scopeID + if e.DailyLimitMicroEUR != nil { + v := cost.MicroEURToEUR(*e.DailyLimitMicroEUR) + e.DailyLimitEUR = &v + } + if e.MonthlyLimitMicroEUR != nil { + v := cost.MicroEURToEUR(*e.MonthlyLimitMicroEUR) + e.MonthlyLimitEUR = &v + } + out = append(out, e) + } + return out, rows.Err() +} + +func (s *Service) PutUserAICostPolicy(ctx context.Context, actorSub, userID string, req putAICostPolicyRequest) error { + daily, monthly := resolveCostLimits(req) + policy := cost.NewPolicyService(s.db) + if err := policy.UpsertScopePolicy(ctx, "user", userID, daily, monthly, req.WarnThresholdPct); err != nil { + return err + } + s.logAudit(ctx, actorSub, "set_user_ai_cost_policy", map[string]any{"user_id": userID}) + return nil +} + +func (s *Service) PutGroupAICostPolicy(ctx context.Context, actorSub, groupID string, req putAICostPolicyRequest) error { + daily, monthly := resolveCostLimits(req) + policy := cost.NewPolicyService(s.db) + if err := policy.UpsertScopePolicy(ctx, "group", groupID, daily, monthly, req.WarnThresholdPct); err != nil { + return err + } + s.logAudit(ctx, actorSub, "set_group_ai_cost_policy", map[string]any{"group_id": groupID}) + return nil +} + +func (s *Service) syncUsageQuotasToCostPolicy(ctx 
context.Context, usageQuotas map[string]any) error { + if usageQuotas == nil { + return nil + } + var daily, monthly *int64 + warnPct := 80 + + if v, ok := usageQuotas["llm_daily_cost_limit_eur"].(float64); ok && v > 0 { + m := int64(v * 1_000_000) + daily = &m + } + if v, ok := usageQuotas["llm_monthly_cost_limit_eur"].(float64); ok && v > 0 { + m := int64(v * 1_000_000) + monthly = &m + } + if v, ok := usageQuotas["llm_cost_warn_threshold_pct"].(float64); ok && v > 0 { + warnPct = int(v) + } + + policy := cost.NewPolicyService(s.db) + return policy.UpsertOrgPolicy(ctx, daily, monthly, warnPct) +} + +func resolveCostLimits(req putAICostPolicyRequest) (*int64, *int64) { + var daily, monthly *int64 + if req.DailyLimitMicroEUR != nil { + daily = req.DailyLimitMicroEUR + } else if req.DailyLimitEUR != nil && *req.DailyLimitEUR > 0 { + v := int64(*req.DailyLimitEUR * 1_000_000) + daily = &v + } + if req.MonthlyLimitMicroEUR != nil { + monthly = req.MonthlyLimitMicroEUR + } else if req.MonthlyLimitEUR != nil && *req.MonthlyLimitEUR > 0 { + v := int64(*req.MonthlyLimitEUR * 1_000_000) + monthly = &v + } + return daily, monthly +} diff --git a/internal/api/admin/ai_usage_handlers.go b/internal/api/admin/ai_usage_handlers.go new file mode 100644 index 0000000..a56be87 --- /dev/null +++ b/internal/api/admin/ai_usage_handlers.go @@ -0,0 +1,116 @@ +package admin + +import ( + "net/http" + "strconv" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" +) + +func (h *Handler) GetAIUsage(w http.ResponseWriter, r *http.Request) { + scope := r.URL.Query().Get("scope") + if scope == "" { + scope = "org" + } + summary, err := h.svc.GetAIUsageSummary(r.Context(), scope) + if err != nil { + h.logger.Error("get ai usage", "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, 
apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, summary) +} + +func (h *Handler) GetUserAIUsage(w http.ResponseWriter, r *http.Request) { + userID := chi.URLParam(r, "userID") + limit, _ := strconv.Atoi(r.URL.Query().Get("limit")) + offset, _ := strconv.Atoi(r.URL.Query().Get("offset")) + detail, err := h.svc.GetUserAIUsage(r.Context(), userID, limit, offset) + if err != nil { + if err.Error() == "user not found" { + apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, err.Error(), nil) + return + } + h.logger.Error("get user ai usage", "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, detail) +} + +func (h *Handler) GetAIPricing(w http.ResponseWriter, r *http.Request) { + prices, err := h.svc.ListAIPricing(r.Context()) + if err != nil { + h.logger.Error("list ai pricing", "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"prices": prices}) +} + +func (h *Handler) PutAIPricing(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + var req putAIPricingRequest + if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil { + return + } + prices, err := h.svc.PutAIPricing(r.Context(), claims.Sub, req) + if err != nil { + h.logger.Error("put ai pricing", "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"prices": prices}) +} + +func (h *Handler) GetAICostPolicies(w http.ResponseWriter, r *http.Request) { + policies, err := h.svc.ListAICostPolicies(r.Context()) + if err != nil { + h.logger.Error("list ai cost policies", "error", err) + 
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"policies": policies}) +} + +func (h *Handler) PutUserAICostPolicy(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + userID := chi.URLParam(r, "userID") + if verr := validateUserID(userID); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + var req putAICostPolicyRequest + if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil { + return + } + if err := h.svc.PutUserAICostPolicy(r.Context(), claims.Sub, userID, req); err != nil { + h.logger.Error("put user ai cost policy", "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true}) +} + +func (h *Handler) PutGroupAICostPolicy(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + groupID := chi.URLParam(r, "groupID") + if verr := validateGroupID(groupID); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + var req putAICostPolicyRequest + if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil { + return + } + if err := h.svc.PutGroupAICostPolicy(r.Context(), claims.Sub, groupID, req); err != nil { + h.logger.Error("put group ai cost policy", "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true}) +} diff --git a/internal/api/admin/handlers.go b/internal/api/admin/handlers.go index c19cae5..9616e6b 100644 --- a/internal/api/admin/handlers.go +++ b/internal/api/admin/handlers.go @@ -79,6 +79,14 @@ func (h *Handler) Routes() chi.Router { r.With(write).Put("/org/settings", 
h.PutOrgSettings) r.With(read).Post("/org/llm/discover-models", h.DiscoverOrgLLMModels) + r.With(read).Get("/ai/usage", h.GetAIUsage) + r.With(read).Get("/ai/usage/users/{userID}", h.GetUserAIUsage) + r.With(read).Get("/ai/pricing", h.GetAIPricing) + r.With(write).Put("/ai/pricing", h.PutAIPricing) + r.With(read).Get("/ai/policies", h.GetAICostPolicies) + r.With(write).Put("/users/{userID}/ai-policy", h.PutUserAICostPolicy) + r.With(write).Put("/user-groups/{groupID}/ai-policy", h.PutGroupAICostPolicy) + r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI) r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider) r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider) diff --git a/internal/api/admin/org_settings.go b/internal/api/admin/org_settings.go index a5ec95c..eaeb924 100644 --- a/internal/api/admin/org_settings.go +++ b/internal/api/admin/org_settings.go @@ -45,11 +45,14 @@ func defaultOrgPolicy() map[string]any { "warn_threshold_pct": 90, }, "usage_quotas": map[string]any{ - "llm_requests_per_day": 100, - "llm_tokens_per_month": 500000, - "search_requests_per_day": 50, - "max_api_tokens_per_user": 10, - "max_webhooks_per_user": 20, + "llm_daily_cost_limit_eur": 2, + "llm_monthly_cost_limit_eur": 35, + "llm_cost_warn_threshold_pct": 80, + "llm_requests_per_day": 75, + "llm_tokens_per_month": 2_000_000, + "search_requests_per_day": 20, + "max_api_tokens_per_user": 5, + "max_webhooks_per_user": 5, }, "file_policies": map[string]any{ "max_upload_mib": 512, @@ -909,6 +912,12 @@ func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map } } + if usageQuotas, ok := merged["usage_quotas"].(map[string]any); ok { + if err := s.syncUsageQuotasToCostPolicy(ctx, usageQuotas); err != nil { + s.logger.Warn("sync ai cost policy failed", "error", err) + } + } + return s.GetOrgSettings(ctx) } diff --git a/internal/api/meet/transcript_processor.go 
b/internal/api/meet/transcript_processor.go index c3205a3..399295c 100644 --- a/internal/api/meet/transcript_processor.go +++ b/internal/api/meet/transcript_processor.go @@ -13,6 +13,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" + "github.com/ultisuite/ulti-backend/internal/ai" "github.com/ultisuite/ulti-backend/internal/llm" "github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/orgpolicy" @@ -104,11 +105,14 @@ func (p *TranscriptProcessor) runPostActions( actions := policy.PostActions if actions.LLMEnabled { - summary, err := p.summarize(ctx, policy, rawTranscript) + summary, provider, model, usage, err := p.summarize(ctx, policy, rawTranscript) if err != nil { p.logger.Warn("llm summary failed", "error", err, "job_id", jobID) } else if strings.TrimSpace(summary) != "" { finalText = summary + if extID, err := ai.ResolveExternalIDByEmail(ctx, p.db, in.OrganizerEmail); err == nil && extID != "" { + ai.RecordFeatureUsage(ctx, p.db, extID, "ultimeet", model, provider, usage) + } } } @@ -145,16 +149,20 @@ func (p *TranscriptProcessor) runPostActions( return err } -func (p *TranscriptProcessor) summarize(ctx context.Context, policy orgpolicy.MeetPolicy, transcript string) (string, error) { +func (p *TranscriptProcessor) summarize(ctx context.Context, policy orgpolicy.MeetPolicy, transcript string) (string, llm.Provider, string, llm.UsageDetail, error) { provider, model, err := p.resolveLLM(ctx, policy.PostActions.LLMProviderID) if err != nil { - return "", err + return "", llm.Provider{}, "", llm.UsageDetail{}, err } prompt := strings.TrimSpace(policy.PostActions.LLMPrompt) if prompt == "" { prompt = "Résume cette réunion en français." 
} - return p.llm.Complete(ctx, provider, model, prompt, transcript) + result, err := p.llm.CompleteWithUsage(ctx, provider, model, prompt, transcript) + if err != nil { + return "", provider, model, llm.UsageDetail{}, err + } + return result.Content, provider, result.Model, result.Usage, nil } func (p *TranscriptProcessor) resolveLLM(ctx context.Context, providerID string) (llm.Provider, string, error) { diff --git a/internal/contacts/discovery/enrich.go b/internal/contacts/discovery/enrich.go index 8df3647..e44e1c3 100644 --- a/internal/contacts/discovery/enrich.go +++ b/internal/contacts/discovery/enrich.go @@ -7,6 +7,9 @@ import ( "strings" "time" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/ultisuite/ulti-backend/internal/ai" "github.com/ultisuite/ulti-backend/internal/llm" ) @@ -60,7 +63,7 @@ func parseEnrichedData(raw string) (*EnrichedContactData, error) { return &data, nil } -func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry, timeout time.Duration) (*EnrichedContactData, error) { +func enrichWithLLMTimeout(ctx context.Context, db *pgxpool.Pool, externalUserID string, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry, timeout time.Duration) (*EnrichedContactData, error) { enrichCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -69,7 +72,7 @@ func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm. err error }, 1) go func() { - data, err := enrichWithLLM(enrichCtx, client, settings, email, displayName, signatures) + data, err := enrichWithLLM(enrichCtx, db, externalUserID, client, settings, email, displayName, signatures) resultCh <- struct { data *EnrichedContactData err error @@ -89,7 +92,7 @@ func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm. 
} } -func enrichWithLLM(ctx context.Context, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry) (*EnrichedContactData, error) { +func enrichWithLLM(ctx context.Context, db *pgxpool.Pool, externalUserID string, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry) (*EnrichedContactData, error) { if client == nil || len(signatures) == 0 { return nil, fmt.Errorf("no signatures to enrich") } @@ -98,11 +101,12 @@ func enrichWithLLM(ctx context.Context, client *llm.Client, settings llm.Setting return nil, err } prompt := buildEnrichPrompt(email, displayName, signatures) - raw, err := client.Complete(ctx, provider, model, enrichSystemPrompt, prompt) + result, err := client.CompleteWithUsage(ctx, provider, model, enrichSystemPrompt, prompt) if err != nil { return nil, err } - return parseEnrichedData(raw) + ai.RecordFeatureUsage(ctx, db, externalUserID, "contact_discovery", result.Model, provider, result.Usage) + return parseEnrichedData(result.Content) } func enrichedDataToSuggestions(userID, profileID string, data *EnrichedContactData) []Suggestion { diff --git a/internal/contacts/discovery/enrich_job.go b/internal/contacts/discovery/enrich_job.go index 0fe2145..75a5dcb 100644 --- a/internal/contacts/discovery/enrich_job.go +++ b/internal/contacts/discovery/enrich_job.go @@ -121,7 +121,7 @@ func (s *Service) runProfileEnrichment(externalUserID, profileID, ncUserID, book } enriched, enrichErr := enrichWithLLMTimeout( - ctx, s.llm, llmSettings, + ctx, s.db, externalUserID, s.llm, llmSettings, profile.PrimaryEmail, profile.DisplayName, sigs, llmEnrichTimeout, ) if enrichErr != nil { diff --git a/internal/contacts/discovery/improve_contact.go b/internal/contacts/discovery/improve_contact.go index 06e9e56..8badaf0 100644 --- a/internal/contacts/discovery/improve_contact.go +++ b/internal/contacts/discovery/improve_contact.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" + 
"github.com/ultisuite/ulti-backend/internal/ai" "github.com/ultisuite/ulti-backend/internal/llm" "github.com/ultisuite/ulti-backend/internal/websearch" ) @@ -120,11 +121,12 @@ func (s *Service) ImproveContact(ctx context.Context, externalUserID string, inp searchSection := s.fetchContactSearchResults(improveCtx, externalUserID, input) prompt := buildImproveContactPrompt(input, searchSection) - raw, err := s.llm.Complete(improveCtx, provider, model, improveContactSystemPrompt, prompt) + raw, err := s.llm.CompleteWithUsage(improveCtx, provider, model, improveContactSystemPrompt, prompt) if err != nil { return nil, err } - data, err := parseEnrichedData(raw) + ai.RecordFeatureUsage(ctx, s.db, externalUserID, "contact_discovery", raw.Model, provider, raw.Usage) + data, err := parseEnrichedData(raw.Content) if err != nil { return nil, fmt.Errorf("parse improved contact: %w", err) } diff --git a/internal/contacts/discovery/service.go b/internal/contacts/discovery/service.go index 914b20d..afaf90b 100644 --- a/internal/contacts/discovery/service.go +++ b/internal/contacts/discovery/service.go @@ -334,7 +334,7 @@ func (s *Service) executeScan(ctx context.Context, externalUserID, ncUserID, boo heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background()) go s.enrichHeartbeat(heartbeatCtx, scanID, externalUserID, messagesScanned, enrichDone, totalMessages, enrichTotal) - enriched, enrichErr := enrichWithLLMTimeout(ctx, s.llm, llmSettings, email, agg.DisplayName, sigEntries, llmEnrichTimeout) + enriched, enrichErr := enrichWithLLMTimeout(ctx, s.db, externalUserID, s.llm, llmSettings, email, agg.DisplayName, sigEntries, llmEnrichTimeout) heartbeatCancel() if enrichErr != nil { s.logger.Warn("llm enrichment failed", "email", email, "error", enrichErr) diff --git a/internal/llm/client.go b/internal/llm/client.go index dbba24d..40cf397 100644 --- a/internal/llm/client.go +++ b/internal/llm/client.go @@ -44,6 +44,17 @@ type chatResponse struct { Content string 
`json:"content"` } `json:"message"` } `json:"choices"` + Usage *struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int `json:"cached_tokens"` + } `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *struct { + ReasoningTokens int `json:"reasoning_tokens"` + } `json:"completion_tokens_details,omitempty"` + } `json:"usage,omitempty"` Error *struct { Message string `json:"message"` } `json:"error,omitempty"` @@ -67,16 +78,40 @@ func NewClient() *Client { } func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) { + result, err := c.CompleteWithUsage(ctx, provider, model, systemPrompt, userPrompt) + if err != nil { + return "", err + } + return result.Content, nil +} + +// CompletionResult holds LLM output and usage metadata. +type CompletionResult struct { + Content string + Model string + Usage UsageDetail +} + +// UsageDetail mirrors ai/cost.UsageDetail for llm package consumers. 
+type UsageDetail struct { + PromptTokens int + CompletionTokens int + CachedInputTokens int + ReasoningTokens int + TotalTokens int +} + +func (c *Client) CompleteWithUsage(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (CompletionResult, error) { baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/") if baseURL == "" { - return "", fmt.Errorf("llm provider base_url is required") + return CompletionResult{}, fmt.Errorf("llm provider base_url is required") } model = strings.TrimSpace(model) if model == "" { model = strings.TrimSpace(provider.DefaultModel) } if model == "" { - return "", fmt.Errorf("llm model is required") + return CompletionResult{}, fmt.Errorf("llm model is required") } reqBody := chatRequest{ @@ -89,13 +124,13 @@ func (c *Client) Complete(ctx context.Context, provider Provider, model, systemP } payload, err := json.Marshal(reqBody) if err != nil { - return "", err + return CompletionResult{}, err } url := baseURL + "/chat/completions" req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) if err != nil { - return "", err + return CompletionResult{}, err } req.Header.Set("Content-Type", "application/json") if strings.TrimSpace(provider.APIKey) != "" { @@ -104,29 +139,68 @@ func (c *Client) Complete(ctx context.Context, provider Provider, model, systemP resp, err := c.http.Do(req) if err != nil { - return "", err + return CompletionResult{}, err } defer resp.Body.Close() body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) if err != nil { - return "", err + return CompletionResult{}, err } if resp.StatusCode >= 400 { - return "", fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body)) + return CompletionResult{}, fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body)) } var parsed chatResponse if err := json.Unmarshal(body, &parsed); err != nil { - return "", err + return CompletionResult{}, err } if parsed.Error != nil && 
parsed.Error.Message != "" { - return "", fmt.Errorf("llm error: %s", parsed.Error.Message) + return CompletionResult{}, fmt.Errorf("llm error: %s", parsed.Error.Message) } if len(parsed.Choices) == 0 { - return "", fmt.Errorf("llm returned no choices") + return CompletionResult{}, fmt.Errorf("llm returned no choices") } - return strings.TrimSpace(parsed.Choices[0].Message.Content), nil + usage := parseUsageFromResponse(parsed.Usage) + return CompletionResult{ + Content: strings.TrimSpace(parsed.Choices[0].Message.Content), + Model: model, + Usage: usage, + }, nil +} + +func parseUsageFromResponse(u *struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *struct { + CachedTokens int `json:"cached_tokens"` + } `json:"prompt_tokens_details,omitempty"` + CompletionTokensDetails *struct { + ReasoningTokens int `json:"reasoning_tokens"` + } `json:"completion_tokens_details,omitempty"` +}) UsageDetail { + if u == nil { + return UsageDetail{TotalTokens: 1} + } + d := UsageDetail{ + PromptTokens: u.PromptTokens, + CompletionTokens: u.CompletionTokens, + TotalTokens: u.TotalTokens, + } + if u.PromptTokensDetails != nil { + d.CachedInputTokens = u.PromptTokensDetails.CachedTokens + } + if u.CompletionTokensDetails != nil { + d.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens + } + if d.TotalTokens == 0 { + d.TotalTokens = d.PromptTokens + d.CompletionTokens + } + if d.TotalTokens == 0 { + d.TotalTokens = 1 + } + return d } func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) { diff --git a/migrations/000052_ai_cost_metering.down.sql b/migrations/000052_ai_cost_metering.down.sql new file mode 100644 index 0000000..36078b8 --- /dev/null +++ b/migrations/000052_ai_cost_metering.down.sql @@ -0,0 +1,9 @@ +DROP TABLE IF EXISTS ai_cost_policies; +DROP TABLE IF EXISTS ai_org_usage_monthly; +DROP TABLE IF EXISTS ai_org_usage_daily; 
-- Migration 000052 (up): AI cost metering.
--
-- Adds model pricing, a per-request usage ledger, cost columns on the existing
-- per-user rollups, org-wide rollups, and cost-limit policies.
-- All money amounts are stored as BIGINT micro-EUR (1 EUR = 1,000,000).
-- Every statement is guarded (IF NOT EXISTS / NOT EXISTS / ON CONFLICT).

-- Model pricing (micro-EUR per 1M tokens).
-- One row per model and effective date; the two nullable rate columns are
-- optional. NOTE(review): presumably a NULL rate means "no dedicated rate for
-- this token class" — confirm against the pricing lookup code.
CREATE TABLE IF NOT EXISTS ai_model_pricing (
    model_id TEXT NOT NULL,
    provider_type TEXT NOT NULL DEFAULT 'generic',
    input_micro_eur_per_mtok BIGINT NOT NULL,
    cached_input_micro_eur_per_mtok BIGINT,
    output_micro_eur_per_mtok BIGINT NOT NULL,
    reasoning_micro_eur_per_mtok BIGINT,
    effective_from DATE NOT NULL DEFAULT CURRENT_DATE,
    source TEXT NOT NULL DEFAULT 'manual',
    PRIMARY KEY (model_id, effective_from)
);

-- Detailed usage ledger: one row per metered LLM request.
-- Rows are removed together with their user (ON DELETE CASCADE).
-- billing_scope is restricted to 'org' or 'user'.
CREATE TABLE IF NOT EXISTS ai_usage_events (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    feature TEXT NOT NULL,
    model_id TEXT NOT NULL,
    provider_id TEXT NOT NULL,
    billing_scope TEXT NOT NULL CHECK (billing_scope IN ('org', 'user')),
    provider_key_fingerprint TEXT NOT NULL DEFAULT '',
    prompt_tokens INT NOT NULL DEFAULT 0,
    completion_tokens INT NOT NULL DEFAULT 0,
    cached_input_tokens INT NOT NULL DEFAULT 0,
    reasoning_tokens INT NOT NULL DEFAULT 0,
    cost_micro_eur BIGINT NOT NULL DEFAULT 0,
    estimated BOOLEAN NOT NULL DEFAULT false,
    request_id TEXT
);

-- Ledger access paths: per user by time, per billing scope + key by time,
-- and globally by time (all newest first).
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_user_time ON ai_usage_events(user_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_scope_key ON ai_usage_events(billing_scope, provider_key_fingerprint, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_created ON ai_usage_events(created_at DESC);

-- Extend daily rollups with cost split by billing scope.
-- NOTE(review): ai_usage_daily is not created here; it is assumed to exist
-- from an earlier migration.
ALTER TABLE ai_usage_daily
    ADD COLUMN IF NOT EXISTS cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    ADD COLUMN IF NOT EXISTS cost_micro_eur_user BIGINT NOT NULL DEFAULT 0;

-- Extend monthly rollups with cost split by billing scope.
-- NOTE(review): ai_usage_monthly is likewise assumed to exist already.
ALTER TABLE ai_usage_monthly
    ADD COLUMN IF NOT EXISTS cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    ADD COLUMN IF NOT EXISTS cost_micro_eur_user BIGINT NOT NULL DEFAULT 0;

-- Org-wide aggregates: one row per calendar day.
CREATE TABLE IF NOT EXISTS ai_org_usage_daily (
    usage_date DATE NOT NULL PRIMARY KEY,
    cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    cost_micro_eur_user BIGINT NOT NULL DEFAULT 0,
    requests INT NOT NULL DEFAULT 0
);

-- Org-wide aggregates: one row per month (keyed by a DATE).
CREATE TABLE IF NOT EXISTS ai_org_usage_monthly (
    usage_month DATE NOT NULL PRIMARY KEY,
    cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    cost_micro_eur_user BIGINT NOT NULL DEFAULT 0
);

-- Cost limit policies (org / group / user).
-- scope_id is nullable; the default org row below is inserted with NULL.
-- NOTE(review): both limit columns are nullable — presumably NULL means
-- "no limit at this level"; confirm against the policy resolver.
CREATE TABLE IF NOT EXISTS ai_cost_policies (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    scope_type TEXT NOT NULL CHECK (scope_type IN ('org', 'group', 'user')),
    scope_id UUID,
    daily_limit_micro_eur BIGINT,
    monthly_limit_micro_eur BIGINT,
    warn_threshold_pct INT NOT NULL DEFAULT 80,
    priority INT NOT NULL DEFAULT 0,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Partial unique indexes: at most one 'org' policy, and at most one policy
-- per (scope_type, scope_id) for 'group' and for 'user'.
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_org ON ai_cost_policies(scope_type) WHERE scope_type = 'org';
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_group ON ai_cost_policies(scope_type, scope_id) WHERE scope_type = 'group';
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_user ON ai_cost_policies(scope_type, scope_id) WHERE scope_type = 'user';

-- Default org policy: 10 EUR/day, 100 EUR/month, warn at 80 %.
-- Inserted only when no 'org' policy exists yet.
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
SELECT 'org', NULL, 10000000, 100000000, 80, 0
WHERE NOT EXISTS (SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org');

-- Seed common model pricing (approximate public rates in EUR).
-- Columns: model, provider type, input rate, cached-input rate, output rate,
-- source. effective_from is left to its CURRENT_DATE default, so a seed row
-- is skipped only when the same model already has a row for that date.
INSERT INTO ai_model_pricing (model_id, provider_type, input_micro_eur_per_mtok, cached_input_micro_eur_per_mtok, output_micro_eur_per_mtok, source) VALUES
    ('gpt-4o-mini', 'openai', 140000, 70000, 560000, 'seed'),
    ('gpt-4o', 'openai', 2300000, 1150000, 9200000, 'seed'),
    ('gpt-4.1-mini', 'openai', 360000, 90000, 1440000, 'seed'),
    ('gpt-4.1', 'openai', 1800000, 450000, 7200000, 'seed'),
    ('o3-mini', 'openai', 990000, 250000, 3960000, 'seed'),
    ('claude-sonnet-4-6', 'anthropic', 2700000, 270000, 13500000, 'seed'),
    ('claude-haiku-4-5', 'anthropic', 900000, 90000, 4500000, 'seed'),
    ('mistral-small-latest', 'mistral', 180000, 90000, 540000, 'seed'),
    ('mistral-large-latest', 'mistral', 1800000, 450000, 5400000, 'seed'),
    ('gemini-2.0-flash', 'google_gemini', 90000, 23000, 360000, 'seed'),
    ('gemini-2.5-pro', 'google_gemini', 1100000, 280000, 9000000, 'seed')
ON CONFLICT DO NOTHING;
-- Migration 000053 (up): align factory AI quotas with reasonable SME defaults
-- (per user, org keys).
-- Only touch rows still at the previous factory values, so limits that an
-- administrator has customised are left alone.

-- Org cost policy: 10 EUR/day, 100 EUR/month  ->  2 EUR/day, 35 EUR/month
-- (amounts are micro-EUR).
UPDATE ai_cost_policies
SET daily_limit_micro_eur = 2000000,
    monthly_limit_micro_eur = 35000000,
    updated_at = NOW()
WHERE scope_type = 'org'
  AND daily_limit_micro_eur = 10000000
  AND monthly_limit_micro_eur = 100000000;

-- Org settings JSON: rewrite the seven usage_quotas keys to the new defaults.
UPDATE org_settings
SET settings = jsonb_set(
    jsonb_set(
        jsonb_set(
            jsonb_set(
                jsonb_set(
                    jsonb_set(
                        jsonb_set(
                            -- jsonb_set only creates the LAST element of a
                            -- path; if "usage_quotas" itself is missing (or is
                            -- not an object) every nested call below would
                            -- return its input unchanged. The WHERE clause
                            -- explicitly selects such rows, so make sure the
                            -- parent object exists before writing into it.
                            jsonb_set(
                                COALESCE(settings, '{}'::jsonb),
                                '{usage_quotas}',
                                CASE
                                    WHEN jsonb_typeof(settings->'usage_quotas') = 'object'
                                        THEN settings->'usage_quotas'
                                    ELSE '{}'::jsonb
                                END
                            ),
                            '{usage_quotas,llm_daily_cost_limit_eur}',
                            '2'::jsonb
                        ),
                        '{usage_quotas,llm_monthly_cost_limit_eur}',
                        '35'::jsonb
                    ),
                    '{usage_quotas,llm_requests_per_day}',
                    '75'::jsonb
                ),
                '{usage_quotas,llm_tokens_per_month}',
                '2000000'::jsonb
            ),
            '{usage_quotas,search_requests_per_day}',
            '20'::jsonb
        ),
        '{usage_quotas,max_api_tokens_per_user}',
        '5'::jsonb
    ),
    '{usage_quotas,max_webhooks_per_user}',
    '5'::jsonb
),
updated_at = NOW()
WHERE id = 1
  AND (
    -- No quotas stored at all ...
    settings->'usage_quotas' IS NULL
    OR settings->'usage_quotas' = 'null'::jsonb
    -- ... or every quota is still at (or absent and thus defaulting to) the
    -- previous factory value.
    OR (
      COALESCE((settings->'usage_quotas'->>'llm_daily_cost_limit_eur')::numeric, 10) = 10
      AND COALESCE((settings->'usage_quotas'->>'llm_monthly_cost_limit_eur')::numeric, 100) = 100
      AND COALESCE((settings->'usage_quotas'->>'llm_requests_per_day')::numeric, 100) = 100
      AND COALESCE((settings->'usage_quotas'->>'llm_tokens_per_month')::numeric, 500000) = 500000
      AND COALESCE((settings->'usage_quotas'->>'search_requests_per_day')::numeric, 50) = 50
      AND COALESCE((settings->'usage_quotas'->>'max_api_tokens_per_user')::numeric, 10) = 10
      AND COALESCE((settings->'usage_quotas'->>'max_webhooks_per_user')::numeric, 20) = 20
    )
  );