ultisuite-backend/internal/ai/cost/policy.go
R3D347HR4Y 3978622050
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
refactor(ai): update AI gateway and cost management features
- Refactored AI gateway to utilize new cost management structures for usage tracking.
- Replaced deprecated token extraction methods with a unified cost parsing approach.
- Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses.
- Added new metering functionality to record AI usage and costs effectively.
- Updated tests to reflect changes in usage parsing and cost calculations.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
2026-06-16 10:46:33 +02:00

422 lines
12 KiB
Go

package cost
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// ErrCostLimitExceeded is returned, wrapped with a "daily"/"monthly" detail,
// by AssertAvailable when org-billed spend has reached the resolved cost
// limit. Match it with errors.Is.
var ErrCostLimitExceeded = errors.New("llm cost limit exceeded")
// EffectiveLimits is the cost policy that applies to one user after org,
// group and user policies have been resolved. Amounts are in micro-euros
// (presumably 1e-6 EUR, per MicroEURToEUR — TODO confirm).
type EffectiveLimits struct {
// DailyLimitMicroEUR caps spend per UTC day; nil or a non-positive value
// means no daily limit is enforced.
DailyLimitMicroEUR *int64
// MonthlyLimitMicroEUR caps spend per UTC month; nil or a non-positive value
// means no monthly limit is enforced.
MonthlyLimitMicroEUR *int64
// WarnThresholdPct is the percentage of a limit at which clients are
// presumably warned; resolution defaults it to 80.
WarnThresholdPct int
}
// SpendStatus is the spend report produced by GetStatus. Cost amounts are in
// micro-euros (see MicroEURToEUR). Pointer fields are nil when no limit
// applies and are then omitted from the JSON output.
type SpendStatus struct {
// CostUsedTodayMicroEUR is spend for the current UTC day: org-billed cost
// when BillingScopeOrg is true, user-billed cost otherwise.
CostUsedTodayMicroEUR int64 `json:"cost_used_today_micro_eur"`
// CostLimitTodayMicroEUR is the resolved daily limit; GetStatus only sets it
// for org billing.
CostLimitTodayMicroEUR *int64 `json:"cost_limit_today_micro_eur,omitempty"`
// CostUsedMonthMicroEUR is spend for the current UTC month, same billing
// selection as CostUsedTodayMicroEUR.
CostUsedMonthMicroEUR int64 `json:"cost_used_month_micro_eur"`
// CostLimitMonthMicroEUR is the resolved monthly limit; GetStatus only sets
// it for org billing.
CostLimitMonthMicroEUR *int64 `json:"cost_limit_month_micro_eur,omitempty"`
// CostRemainingTodayMicroEUR is the daily limit minus today's spend, floored
// at 0; nil when there is no positive daily limit.
CostRemainingTodayMicroEUR *int64 `json:"cost_remaining_today_micro_eur,omitempty"`
// CostRemainingMonthMicroEUR is the monthly limit minus this month's spend,
// floored at 0; nil when there is no positive monthly limit.
CostRemainingMonthMicroEUR *int64 `json:"cost_remaining_month_micro_eur,omitempty"`
// WarnThresholdPct is copied from the resolved EffectiveLimits.
WarnThresholdPct int `json:"warn_threshold_pct"`
// Currency is always "EUR" as filled by GetStatus.
Currency string `json:"currency"`
// BillingScopeOrg echoes the orgBilling argument of GetStatus.
BillingScopeOrg bool `json:"billing_scope_org"`
// ByProviderKeys is the month-to-date spend per provider key (best-effort).
ByProviderKeys []ProviderKeySpend `json:"by_provider_keys,omitempty"`
// Legacy request/token fields (deprecated; presumably kept for older
// clients). GetStatus only fills RequestsUsedToday and TokensUsedMonth;
// the limit and remaining fields are left at zero in this file.
RequestsUsedToday int `json:"requests_used_today"`
RequestsLimit int `json:"requests_limit"`
TokensUsedMonth int64 `json:"tokens_used_month"`
TokensLimit int64 `json:"tokens_limit"`
RequestsRemaining int `json:"requests_remaining"`
TokensRemaining int64 `json:"tokens_remaining"`
}
// ProviderKeySpend is the month-to-date spend attributed to one provider key
// fingerprint and billing scope, as built by providerKeyBreakdown.
type ProviderKeySpend struct {
// Fingerprint identifies the provider API key (presumably a hash rather
// than the raw key — confirm).
Fingerprint string `json:"fingerprint"`
// Label is "<provider_id> ···<last 8 characters of Fingerprint>", for display.
Label string `json:"label"`
// CostMonthMicroEUR is the summed cost_micro_eur of matching usage events.
CostMonthMicroEUR int64 `json:"cost_month_micro_eur"`
// CostMonthEUR is CostMonthMicroEUR converted with MicroEURToEUR.
CostMonthEUR float64 `json:"cost_month_eur"`
// BillingScope is the billing_scope column of the grouped usage events.
BillingScope string `json:"billing_scope"`
}
// PolicyService manages AI cost policies at org, group and user scope,
// reports spend against them (GetStatus) and enforces them
// (AssertAvailable). All state lives in the database.
type PolicyService struct {
// db is the Postgres pool used for policy and usage queries.
db *pgxpool.Pool
}
// NewPolicyService builds a PolicyService on top of the given connection pool.
// A nil pool is stored as-is: ResolveEffectiveLimits then reports default
// limits, while the remaining methods query the pool directly.
func NewPolicyService(db *pgxpool.Pool) *PolicyService {
	svc := new(PolicyService)
	svc.db = db
	return svc
}
// policyCandidate is one policy considered while resolving a user's limits.
// priority records the scope it came from: 0 = org, 10 = group, 20 = user
// (the same values UpsertOrgPolicy/UpsertScopePolicy store).
type policyCandidate struct {
daily *int64
monthly *int64
// warn is the warn threshold percentage; 0 means "not set".
warn int
priority int
}
// ResolveEffectiveLimits returns the cost limits that apply to the user with
// internal ID userID (users.id as text).
//
// Resolution:
//   - a user-scoped policy, if one exists, is used on its own;
//   - otherwise the org policy and the policies of every group the user is a
//     member of are merged with mergeMostRestrictive.
//
// WarnThresholdPct is 80 unless the winning policy sets a positive value.
// Without a database the defaults are returned. Lookup errors are no longer
// swallowed: they are returned together with the default limits, so a
// database failure cannot silently disable enforcement in AssertAvailable,
// while best-effort callers still receive usable values.
func (s *PolicyService) ResolveEffectiveLimits(ctx context.Context, userID string) (EffectiveLimits, error) {
	limits := EffectiveLimits{WarnThresholdPct: 80}
	if s.db == nil {
		return limits, nil
	}

	// A user override replaces org/group policies entirely, so look it up
	// first and skip the other queries when it exists.
	userPol, err := s.loadPolicy(ctx, "user", userID)
	if err != nil {
		return limits, err
	}

	var merged policyCandidate
	if userPol != nil {
		merged = policyCandidate{
			daily: userPol.daily, monthly: userPol.monthly, warn: userPol.warn, priority: 20,
		}
	} else {
		var candidates []policyCandidate

		// Org policy (priority 0).
		orgPol, err := s.loadPolicy(ctx, "org", "")
		if err != nil {
			return limits, err
		}
		if orgPol != nil {
			candidates = append(candidates, policyCandidate{
				daily: orgPol.daily, monthly: orgPol.monthly, warn: orgPol.warn, priority: 0,
			})
		}

		// Group policies (priority 10) for every group the user belongs to.
		rows, err := s.db.Query(ctx, `
SELECT p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct
FROM ai_cost_policies p
JOIN user_group_members ugm ON ugm.group_id = p.scope_id
WHERE p.scope_type = 'group' AND ugm.user_id = $1::uuid
`, userID)
		if err != nil {
			return limits, fmt.Errorf("query group cost policies: %w", err)
		}
		defer rows.Close()
		for rows.Next() {
			c := policyCandidate{priority: 10}
			if err := rows.Scan(&c.daily, &c.monthly, &c.warn); err != nil {
				return limits, fmt.Errorf("scan group cost policy: %w", err)
			}
			candidates = append(candidates, c)
		}
		if err := rows.Err(); err != nil {
			return limits, fmt.Errorf("read group cost policies: %w", err)
		}

		// Most restrictive wins across org + groups; an empty list yields no
		// limits and the default threshold.
		merged = mergeMostRestrictive(candidates)
	}

	limits.DailyLimitMicroEUR = merged.daily
	limits.MonthlyLimitMicroEUR = merged.monthly
	if merged.warn > 0 {
		limits.WarnThresholdPct = merged.warn
	}
	return limits, nil
}
// policyRow holds the limit columns of a single ai_cost_policies row as read
// by loadPolicy; a nil limit corresponds to a NULL column.
type policyRow struct {
daily *int64
monthly *int64
warn int
}
// loadPolicy reads the cost policy stored for (scopeType, scopeID).
//
// The org policy is stored with a NULL scope_id (see UpsertOrgPolicy), so for
// scopeType "org" scopeID is ignored and the first org row is returned.
// Other scope types are looked up by scope_type and scope_id (a UUID string).
//
// It returns (nil, nil) when no matching policy exists, and a wrapped error
// for any other database failure.
func (s *PolicyService) loadPolicy(ctx context.Context, scopeType, scopeID string) (*policyRow, error) {
	var row pgx.Row
	if scopeType == "org" {
		row = s.db.QueryRow(ctx, `
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
FROM ai_cost_policies WHERE scope_type = 'org' LIMIT 1
`)
	} else {
		row = s.db.QueryRow(ctx, `
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
`, scopeType, scopeID)
	}

	var p policyRow
	if err := row.Scan(&p.daily, &p.monthly, &p.warn); err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			// Absence of a policy is a normal outcome, not an error.
			return nil, nil
		}
		return nil, fmt.Errorf("load %s cost policy: %w", scopeType, err)
	}
	return &p, nil
}
// mergeMostRestrictive folds policy candidates into one, keeping the
// tightest value of each field:
//   - daily / monthly: the smallest positive limit. A nil or non-positive
//     limit means "no limit" (remaining and AssertAvailable do not enforce
//     limits <= 0), so it never overrides a real limit from another
//     candidate. It is only kept when no candidate has a positive limit.
//   - warn: the smallest positive threshold, or 80 when none is set.
//
// The priority of the result is that of candidates[0]. An empty input yields
// no limits and the default 80% warn threshold.
func mergeMostRestrictive(candidates []policyCandidate) policyCandidate {
	if len(candidates) == 0 {
		return policyCandidate{warn: 80}
	}

	// tighter returns the more restrictive of the current and next limit,
	// treating nil and non-positive values as "unlimited".
	tighter := func(cur, next *int64) *int64 {
		switch {
		case next == nil:
			return cur
		case cur == nil:
			return next
		case *next <= 0:
			// "No limit" must not replace an existing limit.
			return cur
		case *cur <= 0 || *next < *cur:
			return next
		default:
			return cur
		}
	}

	out := candidates[0]
	for _, c := range candidates[1:] {
		out.daily = tighter(out.daily, c.daily)
		out.monthly = tighter(out.monthly, c.monthly)
		// An unset (<= 0) threshold on the seed must not block later ones.
		if c.warn > 0 && (out.warn <= 0 || c.warn < out.warn) {
			out.warn = c.warn
		}
	}
	if out.warn <= 0 {
		out.warn = 80
	}
	return out
}
// GetStatus reports today's and this month's LLM spend for the user
// identified by externalUserID, the limits that apply, and a month-to-date
// breakdown per provider key.
//
// orgBilling selects what is reported:
//   - true: org-billed cost together with the resolved cost limits;
//   - false: user-billed cost with no limits.
//
// Limits are only enforced for org-billed usage (AssertAvailable skips
// BillingScopeUser). The remaining-budget fields are therefore derived from
// the reported limits, so they are nil for user billing.
//
// Days and months are UTC. A missing usage row counts as zero spend. Other
// database errors from the usage and limit lookups are returned. The
// provider-key breakdown is best-effort.
func (s *PolicyService) GetStatus(ctx context.Context, externalUserID string, orgBilling bool) (SpendStatus, error) {
	userID, err := s.resolveUserID(ctx, externalUserID)
	if err != nil {
		return SpendStatus{}, err
	}

	today := time.Now().UTC().Truncate(24 * time.Hour)
	month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)

	var dailyOrg, dailyUser int64
	var requestsToday int
	err = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(requests, 0)
FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2
`, userID, today).Scan(&dailyOrg, &dailyUser, &requestsToday)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return SpendStatus{}, fmt.Errorf("load daily usage: %w", err)
	}

	var monthlyOrg, monthlyUser, tokensMonth int64
	err = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(tokens, 0)
FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2
`, userID, month).Scan(&monthlyOrg, &monthlyUser, &tokensMonth)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return SpendStatus{}, fmt.Errorf("load monthly usage: %w", err)
	}

	limits, err := s.ResolveEffectiveLimits(ctx, userID)
	if err != nil {
		return SpendStatus{}, fmt.Errorf("resolve cost limits: %w", err)
	}

	status := SpendStatus{
		Currency:          "EUR",
		BillingScopeOrg:   orgBilling,
		WarnThresholdPct:  limits.WarnThresholdPct,
		RequestsUsedToday: requestsToday,
		TokensUsedMonth:   tokensMonth,
	}
	if orgBilling {
		status.CostUsedTodayMicroEUR = dailyOrg
		status.CostUsedMonthMicroEUR = monthlyOrg
		status.CostLimitTodayMicroEUR = limits.DailyLimitMicroEUR
		status.CostLimitMonthMicroEUR = limits.MonthlyLimitMicroEUR
	} else {
		status.CostUsedTodayMicroEUR = dailyUser
		status.CostUsedMonthMicroEUR = monthlyUser
	}
	// Use the reported limits, not the resolved ones, so user-billed status
	// never shows a budget that AssertAvailable would not enforce.
	status.CostRemainingTodayMicroEUR = remaining(status.CostLimitTodayMicroEUR, status.CostUsedTodayMicroEUR)
	status.CostRemainingMonthMicroEUR = remaining(status.CostLimitMonthMicroEUR, status.CostUsedMonthMicroEUR)

	// The breakdown is informational only: its error is deliberately ignored
	// and a nil result is normalized to an empty slice.
	keys, _ := s.providerKeyBreakdown(ctx, userID, month)
	if keys == nil {
		keys = []ProviderKeySpend{}
	}
	status.ByProviderKeys = keys
	return status, nil
}
// AssertAvailable checks whether the user identified by externalUserID is
// still within the cost limits for the given billing scope.
//
// Behavior:
//   - billingScope == BillingScopeUser is never checked.
//   - Otherwise, the user's org-billed spend for the current UTC day and month
//     is compared against the resolved limits. An error wrapping
//     ErrCostLimitExceeded is returned once spend has reached a limit.
//   - Limits that are nil or non-positive are not enforced.
//
// A missing usage row counts as zero spend. Any other lookup error is
// returned, so a database failure does not silently bypass the limits.
func (s *PolicyService) AssertAvailable(ctx context.Context, externalUserID string, billingScope string) error {
	if billingScope == BillingScopeUser {
		return nil
	}
	userID, err := s.resolveUserID(ctx, externalUserID)
	if err != nil {
		return err
	}
	limits, err := s.ResolveEffectiveLimits(ctx, userID)
	if err != nil {
		return err
	}

	today := time.Now().UTC().Truncate(24 * time.Hour)
	month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)

	var dailyOrg, monthlyOrg int64
	err = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_daily
WHERE user_id = $1::uuid AND usage_date = $2
`, userID, today).Scan(&dailyOrg)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return fmt.Errorf("load daily org usage: %w", err)
	}
	err = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_monthly
WHERE user_id = $1::uuid AND usage_month = $2
`, userID, month).Scan(&monthlyOrg)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return fmt.Errorf("load monthly org usage: %w", err)
	}

	if limit := limits.DailyLimitMicroEUR; limit != nil && *limit > 0 && dailyOrg >= *limit {
		return fmt.Errorf("%w: daily cost limit reached", ErrCostLimitExceeded)
	}
	if limit := limits.MonthlyLimitMicroEUR; limit != nil && *limit > 0 && monthlyOrg >= *limit {
		return fmt.Errorf("%w: monthly cost limit reached", ErrCostLimitExceeded)
	}
	return nil
}
// providerKeyBreakdown sums the user's usage-event cost since month, grouped by
// provider key fingerprint and billing scope. It returns at most 20 entries,
// most expensive first.
//
// Rows that fail to scan are skipped. The error, if any, is reported through
// rows.Err() together with whatever entries were collected. When the query
// itself fails, an empty (non-nil) slice and the error are returned.
func (s *PolicyService) providerKeyBreakdown(ctx context.Context, userID string, month time.Time) ([]ProviderKeySpend, error) {
	rows, err := s.db.Query(ctx, `
SELECT provider_key_fingerprint, billing_scope,
COALESCE(SUM(cost_micro_eur), 0),
MAX(provider_id)
FROM ai_usage_events
WHERE user_id = $1::uuid AND created_at >= $2
GROUP BY provider_key_fingerprint, billing_scope
ORDER BY SUM(cost_micro_eur) DESC
LIMIT 20
`, userID, month)
	if err != nil {
		return []ProviderKeySpend{}, err
	}
	defer rows.Close()

	spends := make([]ProviderKeySpend, 0)
	for rows.Next() {
		var (
			entry      ProviderKeySpend
			providerID string
		)
		if scanErr := rows.Scan(&entry.Fingerprint, &entry.BillingScope, &entry.CostMonthMicroEUR, &providerID); scanErr != nil {
			continue
		}

		// Show only the tail of the fingerprint in the human-readable label.
		shortFP := entry.Fingerprint
		if n := len(shortFP); n > 8 {
			shortFP = shortFP[n-8:]
		}

		entry.CostMonthEUR = MicroEURToEUR(entry.CostMonthMicroEUR)
		entry.Label = providerID + " ···" + shortFP
		spends = append(spends, entry)
	}
	return spends, rows.Err()
}
// resolveUserID maps an external user identifier (users.external_id) to the
// internal users.id, returned as text.
//
// It returns an error with the message "user not found" when no user matches.
// Other database errors are returned unchanged.
func (s *PolicyService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
	var userID string
	err := s.db.QueryRow(ctx, `
SELECT id::text FROM users WHERE external_id = $1
`, externalUserID).Scan(&userID)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return "", errors.New("user not found")
	case err != nil:
		return "", err
	}
	return userID, nil
}
// remaining returns how much of limit is left after used has been spent,
// never below zero. A nil or non-positive limit means "unlimited" and yields
// nil.
func remaining(limit *int64, used int64) *int64 {
	if limit != nil && *limit > 0 {
		left := *limit - used
		if left < 0 {
			// Overspent: report an exhausted budget rather than a negative one.
			left = 0
		}
		return &left
	}
	return nil
}
// UpsertOrgPolicy creates or overwrites the org-wide cost policy.
//
// daily and monthly are the new limits; a nil pointer is written as SQL NULL,
// i.e. no limit. A non-positive warnPct is stored as 80. An existing org row
// is updated in place; otherwise a new row with a NULL scope_id and priority
// 0 is inserted.
//
// NOTE(review): the existence check and the write are separate statements
// without a transaction, so two concurrent first-time calls could both insert.
func (s *PolicyService) UpsertOrgPolicy(ctx context.Context, daily, monthly *int64, warnPct int) error {
	if warnPct <= 0 {
		warnPct = 80
	}

	var exists bool
	err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org')`).Scan(&exists)
	if err != nil {
		return err
	}

	if !exists {
		_, err = s.db.Exec(ctx, `
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
VALUES ('org', NULL, $1, $2, $3, 0)
`, daily, monthly, warnPct)
		return err
	}

	_, err = s.db.Exec(ctx, `
UPDATE ai_cost_policies SET
daily_limit_micro_eur = $1,
monthly_limit_micro_eur = $2,
warn_threshold_pct = $3,
updated_at = NOW()
WHERE scope_type = 'org'
`, daily, monthly, warnPct)
	return err
}
// UpsertScopePolicy creates or overwrites the cost policy of one group or user.
//
// scopeType must be "group" (stored with priority 10) or "user" (priority 20).
// scopeID is the UUID of that group or user. Any other scope type is rejected:
// the org policy is managed by UpsertOrgPolicy, and writing an "org" row with a
// scope_id here would create a second org policy for loadPolicy to choose
// from arbitrarily.
//
// A nil daily/monthly limit is written as SQL NULL (no limit). A non-positive
// warnPct is stored as 80.
func (s *PolicyService) UpsertScopePolicy(ctx context.Context, scopeType, scopeID string, daily, monthly *int64, warnPct int) error {
	var priority int
	switch scopeType {
	case "group":
		priority = 10
	case "user":
		priority = 20
	default:
		return fmt.Errorf("unsupported cost policy scope type %q", scopeType)
	}
	if warnPct <= 0 {
		warnPct = 80
	}

	// NOTE(review): check-then-write is not atomic; concurrent first-time
	// upserts for the same scope could both insert.
	var exists bool
	err := s.db.QueryRow(ctx, `
SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid)
`, scopeType, scopeID).Scan(&exists)
	if err != nil {
		return err
	}
	if exists {
		_, err = s.db.Exec(ctx, `
UPDATE ai_cost_policies SET
daily_limit_micro_eur = $1,
monthly_limit_micro_eur = $2,
warn_threshold_pct = $3,
updated_at = NOW()
WHERE scope_type = $4 AND scope_id = $5::uuid
`, daily, monthly, warnPct, scopeType, scopeID)
		return err
	}
	_, err = s.db.Exec(ctx, `
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
VALUES ($1, $2::uuid, $3, $4, $5, $6)
`, scopeType, scopeID, daily, monthly, warnPct, priority)
	return err
}
// DeleteScopePolicy removes the policy stored for (scopeType, scopeID).
// Deleting a policy that does not exist is not an error. The org-wide policy
// cannot be deleted; it can only be overwritten via UpsertOrgPolicy.
func (s *PolicyService) DeleteScopePolicy(ctx context.Context, scopeType, scopeID string) error {
	if scopeType != "org" {
		_, err := s.db.Exec(ctx, `
DELETE FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
`, scopeType, scopeID)
		return err
	}
	return fmt.Errorf("cannot delete org policy")
}
// GetOrgPolicy returns the stored org-wide cost policy, without group or user
// overrides. The warn threshold is returned exactly as stored.
//
// When no org policy exists, the result has no limits and an 80% warn
// threshold. The same defaults accompany a lookup error.
func (s *PolicyService) GetOrgPolicy(ctx context.Context) (EffectiveLimits, error) {
	defaults := EffectiveLimits{WarnThresholdPct: 80}

	row, err := s.loadPolicy(ctx, "org", "")
	switch {
	case err != nil:
		return defaults, err
	case row == nil:
		return defaults, nil
	}

	return EffectiveLimits{
		DailyLimitMicroEUR:   row.daily,
		MonthlyLimitMicroEUR: row.monthly,
		WarnThresholdPct:     row.warn,
	}, nil
}