package cost import ( "context" "errors" "fmt" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) var ErrCostLimitExceeded = errors.New("llm cost limit exceeded") type EffectiveLimits struct { DailyLimitMicroEUR *int64 MonthlyLimitMicroEUR *int64 WarnThresholdPct int } type SpendStatus struct { CostUsedTodayMicroEUR int64 `json:"cost_used_today_micro_eur"` CostLimitTodayMicroEUR *int64 `json:"cost_limit_today_micro_eur,omitempty"` CostUsedMonthMicroEUR int64 `json:"cost_used_month_micro_eur"` CostLimitMonthMicroEUR *int64 `json:"cost_limit_month_micro_eur,omitempty"` CostRemainingTodayMicroEUR *int64 `json:"cost_remaining_today_micro_eur,omitempty"` CostRemainingMonthMicroEUR *int64 `json:"cost_remaining_month_micro_eur,omitempty"` WarnThresholdPct int `json:"warn_threshold_pct"` Currency string `json:"currency"` BillingScopeOrg bool `json:"billing_scope_org"` ByProviderKeys []ProviderKeySpend `json:"by_provider_keys,omitempty"` // Legacy fields (deprecated) RequestsUsedToday int `json:"requests_used_today"` RequestsLimit int `json:"requests_limit"` TokensUsedMonth int64 `json:"tokens_used_month"` TokensLimit int64 `json:"tokens_limit"` RequestsRemaining int `json:"requests_remaining"` TokensRemaining int64 `json:"tokens_remaining"` } type ProviderKeySpend struct { Fingerprint string `json:"fingerprint"` Label string `json:"label"` CostMonthMicroEUR int64 `json:"cost_month_micro_eur"` CostMonthEUR float64 `json:"cost_month_eur"` BillingScope string `json:"billing_scope"` } type PolicyService struct { db *pgxpool.Pool } func NewPolicyService(db *pgxpool.Pool) *PolicyService { return &PolicyService{db: db} } type policyCandidate struct { daily *int64 monthly *int64 warn int priority int } func (s *PolicyService) ResolveEffectiveLimits(ctx context.Context, userID string) (EffectiveLimits, error) { limits := EffectiveLimits{WarnThresholdPct: 80} if s.db == nil { return limits, nil } var candidates []policyCandidate // Org policy (priority 0) org, err := s.loadPolicy(ctx, "org", "") if err == nil && org != nil { candidates = append(candidates, policyCandidate{ daily: org.daily, monthly: org.monthly, warn: org.warn, priority: 0, }) } // Group policies (priority 10) — most restrictive wins among groups rows, err := s.db.Query(ctx, ` SELECT p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct FROM ai_cost_policies p JOIN user_group_members ugm ON ugm.group_id = p.scope_id WHERE p.scope_type = 'group' AND ugm.user_id = $1::uuid `, userID) if err == nil { defer rows.Close() for rows.Next() { var c policyCandidate c.priority = 10 if err := rows.Scan(&c.daily, &c.monthly, &c.warn); err != nil { continue } candidates = append(candidates, c) } } // User policy (priority 20) userPol, err := s.loadPolicy(ctx, "user", userID) if err == nil && userPol != nil { candidates = append(candidates, policyCandidate{ daily: userPol.daily, monthly: userPol.monthly, warn: userPol.warn, priority: 20, }) } if len(candidates) == 0 { return limits, nil } // User override wins if present; else merge group+org with most restrictive var userCandidate *policyCandidate var others []policyCandidate for i := range candidates { if candidates[i].priority >= 20 { userCandidate = &candidates[i] } else { others = append(others, candidates[i]) } } var merged policyCandidate if userCandidate != nil { merged = *userCandidate } else { merged = mergeMostRestrictive(others) } limits.DailyLimitMicroEUR = merged.daily limits.MonthlyLimitMicroEUR = merged.monthly if merged.warn > 0 { limits.WarnThresholdPct = merged.warn } return limits, nil } type policyRow struct { daily *int64 monthly *int64 warn int } func (s *PolicyService) loadPolicy(ctx context.Context, scopeType, scopeID string) (*policyRow, error) { var daily, monthly *int64 var warn int var err error if scopeType == "org" { err = s.db.QueryRow(ctx, ` SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct FROM ai_cost_policies WHERE scope_type = 'org' LIMIT 1 `).Scan(&daily, &monthly, &warn) } else { err = s.db.QueryRow(ctx, ` SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid `, scopeType, scopeID).Scan(&daily, &monthly, &warn) } if err != nil { if err == pgx.ErrNoRows { return nil, nil } return nil, err } return &policyRow{daily: daily, monthly: monthly, warn: warn}, nil } func mergeMostRestrictive(candidates []policyCandidate) policyCandidate { if len(candidates) == 0 { return policyCandidate{warn: 80} } out := candidates[0] for _, c := range candidates[1:] { if c.daily != nil && (out.daily == nil || *c.daily < *out.daily) { out.daily = c.daily } if c.monthly != nil && (out.monthly == nil || *c.monthly < *out.monthly) { out.monthly = c.monthly } if c.warn > 0 && c.warn < out.warn { out.warn = c.warn } } if out.warn == 0 { out.warn = 80 } return out } func (s *PolicyService) GetStatus(ctx context.Context, externalUserID string, orgBilling bool) (SpendStatus, error) { userID, err := s.resolveUserID(ctx, externalUserID) if err != nil { return SpendStatus{}, err } today := time.Now().UTC().Truncate(24 * time.Hour) month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) var dailyOrg, dailyUser int64 var requestsToday int _ = s.db.QueryRow(ctx, ` SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(requests, 0) FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2 `, userID, today).Scan(&dailyOrg, &dailyUser, &requestsToday) var monthlyOrg, monthlyUser, tokensMonth int64 _ = s.db.QueryRow(ctx, ` SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(tokens, 0) FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2 `, userID, month).Scan(&monthlyOrg, &monthlyUser, &tokensMonth) limits, _ := s.ResolveEffectiveLimits(ctx, userID) status := SpendStatus{ Currency: "EUR", BillingScopeOrg: orgBilling, WarnThresholdPct: limits.WarnThresholdPct, RequestsUsedToday: requestsToday, TokensUsedMonth: tokensMonth, } if orgBilling { status.CostUsedTodayMicroEUR = dailyOrg status.CostUsedMonthMicroEUR = monthlyOrg status.CostLimitTodayMicroEUR = limits.DailyLimitMicroEUR status.CostLimitMonthMicroEUR = limits.MonthlyLimitMicroEUR } else { status.CostUsedTodayMicroEUR = dailyUser status.CostUsedMonthMicroEUR = monthlyUser } status.CostRemainingTodayMicroEUR = remaining(limits.DailyLimitMicroEUR, status.CostUsedTodayMicroEUR) status.CostRemainingMonthMicroEUR = remaining(limits.MonthlyLimitMicroEUR, status.CostUsedMonthMicroEUR) keys, _ := s.providerKeyBreakdown(ctx, userID, month) if keys == nil { keys = []ProviderKeySpend{} } status.ByProviderKeys = keys return status, nil } func (s *PolicyService) AssertAvailable(ctx context.Context, externalUserID string, billingScope string) error { if billingScope == BillingScopeUser { return nil } userID, err := s.resolveUserID(ctx, externalUserID) if err != nil { return err } limits, err := s.ResolveEffectiveLimits(ctx, userID) if err != nil { return err } today := time.Now().UTC().Truncate(24 * time.Hour) month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC) var dailyOrg, monthlyOrg int64 _ = s.db.QueryRow(ctx, ` SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2 `, userID, today).Scan(&dailyOrg) _ = s.db.QueryRow(ctx, ` SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2 `, userID, month).Scan(&monthlyOrg) if limits.DailyLimitMicroEUR != nil && *limits.DailyLimitMicroEUR > 0 && dailyOrg >= *limits.DailyLimitMicroEUR { return fmt.Errorf("%w: daily cost limit reached", ErrCostLimitExceeded) } if limits.MonthlyLimitMicroEUR != nil && *limits.MonthlyLimitMicroEUR > 0 && monthlyOrg >= *limits.MonthlyLimitMicroEUR { return fmt.Errorf("%w: monthly cost limit reached", ErrCostLimitExceeded) } return nil } func (s *PolicyService) providerKeyBreakdown(ctx context.Context, userID string, month time.Time) ([]ProviderKeySpend, error) { rows, err := s.db.Query(ctx, ` SELECT provider_key_fingerprint, billing_scope, COALESCE(SUM(cost_micro_eur), 0), MAX(provider_id) FROM ai_usage_events WHERE user_id = $1::uuid AND created_at >= $2 GROUP BY provider_key_fingerprint, billing_scope ORDER BY SUM(cost_micro_eur) DESC LIMIT 20 `, userID, month) if err != nil { return []ProviderKeySpend{}, err } defer rows.Close() out := make([]ProviderKeySpend, 0) for rows.Next() { var item ProviderKeySpend var providerID string if err := rows.Scan(&item.Fingerprint, &item.BillingScope, &item.CostMonthMicroEUR, &providerID); err != nil { continue } item.CostMonthEUR = MicroEURToEUR(item.CostMonthMicroEUR) suffix := item.Fingerprint if len(suffix) > 8 { suffix = suffix[len(suffix)-8:] } item.Label = providerID + " ···" + suffix out = append(out, item) } return out, rows.Err() } func (s *PolicyService) resolveUserID(ctx context.Context, externalUserID string) (string, error) { var userID string err := s.db.QueryRow(ctx, ` SELECT id::text FROM users WHERE external_id = $1 `, externalUserID).Scan(&userID) if err != nil { if err == pgx.ErrNoRows { return "", fmt.Errorf("user not found") } return "", err } return userID, nil } func remaining(limit *int64, used int64) *int64 { if limit == nil || *limit <= 0 { return nil } r := *limit - used if r < 0 { r = 0 } return &r } // UpsertOrgPolicy updates the org-level cost policy. func (s *PolicyService) UpsertOrgPolicy(ctx context.Context, daily, monthly *int64, warnPct int) error { if warnPct <= 0 { warnPct = 80 } var exists bool if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org')`).Scan(&exists); err != nil { return err } if exists { _, err := s.db.Exec(ctx, ` UPDATE ai_cost_policies SET daily_limit_micro_eur = $1, monthly_limit_micro_eur = $2, warn_threshold_pct = $3, updated_at = NOW() WHERE scope_type = 'org' `, daily, monthly, warnPct) return err } _, err := s.db.Exec(ctx, ` INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority) VALUES ('org', NULL, $1, $2, $3, 0) `, daily, monthly, warnPct) return err } // UpsertScopePolicy updates group or user policy. func (s *PolicyService) UpsertScopePolicy(ctx context.Context, scopeType, scopeID string, daily, monthly *int64, warnPct int) error { if warnPct <= 0 { warnPct = 80 } priority := 10 if scopeType == "user" { priority = 20 } var exists bool err := s.db.QueryRow(ctx, ` SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid) `, scopeType, scopeID).Scan(&exists) if err != nil { return err } if exists { _, err = s.db.Exec(ctx, ` UPDATE ai_cost_policies SET daily_limit_micro_eur = $1, monthly_limit_micro_eur = $2, warn_threshold_pct = $3, updated_at = NOW() WHERE scope_type = $4 AND scope_id = $5::uuid `, daily, monthly, warnPct, scopeType, scopeID) return err } _, err = s.db.Exec(ctx, ` INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority) VALUES ($1, $2::uuid, $3, $4, $5, $6) `, scopeType, scopeID, daily, monthly, warnPct, priority) return err } func (s *PolicyService) DeleteScopePolicy(ctx context.Context, scopeType, scopeID string) error { if scopeType == "org" { return fmt.Errorf("cannot delete org policy") } _, err := s.db.Exec(ctx, ` DELETE FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid `, scopeType, scopeID) return err } func (s *PolicyService) GetOrgPolicy(ctx context.Context) (EffectiveLimits, error) { row, err := s.loadPolicy(ctx, "org", "") if err != nil { return EffectiveLimits{WarnThresholdPct: 80}, err } if row == nil { return EffectiveLimits{WarnThresholdPct: 80}, nil } return EffectiveLimits{ DailyLimitMicroEUR: row.daily, MonthlyLimitMicroEUR: row.monthly, WarnThresholdPct: row.warn, }, nil }