- Refactored AI gateway to utilize new cost management structures for usage tracking. - Replaced deprecated token extraction methods with a unified cost parsing approach. - Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses. - Added new metering functionality to record AI usage and costs effectively. - Updated tests to reflect changes in usage parsing and cost calculations. - Introduced new API endpoints for retrieving AI usage summaries and pricing information.
422 lines
12 KiB
Go
422 lines
12 KiB
Go
package cost
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// ErrCostLimitExceeded is the sentinel error wrapped by AssertAvailable when a
// daily or monthly cost limit has been reached. Match it with errors.Is.
var ErrCostLimitExceeded = errors.New("llm cost limit exceeded")
|
|
|
|
// EffectiveLimits is the cost policy that applies once the stored policies
// have been resolved for a subject.
type EffectiveLimits struct {
	// DailyLimitMicroEUR and MonthlyLimitMicroEUR cap spend per day and per
	// month. A nil pointer means no limit is configured for that period.
	DailyLimitMicroEUR, MonthlyLimitMicroEUR *int64

	// WarnThresholdPct is the warning threshold in percent. The service
	// falls back to 80 when no policy supplies a positive value.
	WarnThresholdPct int
}
|
|
|
|
type SpendStatus struct {
|
|
CostUsedTodayMicroEUR int64 `json:"cost_used_today_micro_eur"`
|
|
CostLimitTodayMicroEUR *int64 `json:"cost_limit_today_micro_eur,omitempty"`
|
|
CostUsedMonthMicroEUR int64 `json:"cost_used_month_micro_eur"`
|
|
CostLimitMonthMicroEUR *int64 `json:"cost_limit_month_micro_eur,omitempty"`
|
|
CostRemainingTodayMicroEUR *int64 `json:"cost_remaining_today_micro_eur,omitempty"`
|
|
CostRemainingMonthMicroEUR *int64 `json:"cost_remaining_month_micro_eur,omitempty"`
|
|
WarnThresholdPct int `json:"warn_threshold_pct"`
|
|
Currency string `json:"currency"`
|
|
BillingScopeOrg bool `json:"billing_scope_org"`
|
|
ByProviderKeys []ProviderKeySpend `json:"by_provider_keys,omitempty"`
|
|
|
|
// Legacy fields (deprecated)
|
|
RequestsUsedToday int `json:"requests_used_today"`
|
|
RequestsLimit int `json:"requests_limit"`
|
|
TokensUsedMonth int64 `json:"tokens_used_month"`
|
|
TokensLimit int64 `json:"tokens_limit"`
|
|
RequestsRemaining int `json:"requests_remaining"`
|
|
TokensRemaining int64 `json:"tokens_remaining"`
|
|
}
|
|
|
|
// ProviderKeySpend is one row of the per-provider-key spend breakdown.
type ProviderKeySpend struct {
	// Fingerprint identifies the provider key the usage was recorded against.
	Fingerprint string `json:"fingerprint"`
	// Label is a display string built from the provider ID and the tail of
	// the fingerprint.
	Label string `json:"label"`
	// CostMonthMicroEUR is the summed cost in micro-EUR.
	CostMonthMicroEUR int64 `json:"cost_month_micro_eur"`
	// CostMonthEUR is CostMonthMicroEUR converted to EUR.
	CostMonthEUR float64 `json:"cost_month_eur"`
	// BillingScope is the billing scope recorded on the usage events.
	BillingScope string `json:"billing_scope"`
}
|
|
|
|
type PolicyService struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewPolicyService(db *pgxpool.Pool) *PolicyService {
|
|
return &PolicyService{db: db}
|
|
}
|
|
|
|
// policyCandidate is one stored policy under consideration while resolving
// the effective limits for a user.
type policyCandidate struct {
	// daily and monthly are the limits in micro-EUR; nil means unset.
	daily, monthly *int64
	// warn is the warning threshold in percent. priority ranks the scope the
	// policy came from (0 = org, 10 = group, 20 = user).
	warn, priority int
}
|
|
|
|
func (s *PolicyService) ResolveEffectiveLimits(ctx context.Context, userID string) (EffectiveLimits, error) {
|
|
limits := EffectiveLimits{WarnThresholdPct: 80}
|
|
if s.db == nil {
|
|
return limits, nil
|
|
}
|
|
|
|
var candidates []policyCandidate
|
|
|
|
// Org policy (priority 0)
|
|
org, err := s.loadPolicy(ctx, "org", "")
|
|
if err == nil && org != nil {
|
|
candidates = append(candidates, policyCandidate{
|
|
daily: org.daily, monthly: org.monthly, warn: org.warn, priority: 0,
|
|
})
|
|
}
|
|
|
|
// Group policies (priority 10) — most restrictive wins among groups
|
|
rows, err := s.db.Query(ctx, `
|
|
SELECT p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct
|
|
FROM ai_cost_policies p
|
|
JOIN user_group_members ugm ON ugm.group_id = p.scope_id
|
|
WHERE p.scope_type = 'group' AND ugm.user_id = $1::uuid
|
|
`, userID)
|
|
if err == nil {
|
|
defer rows.Close()
|
|
for rows.Next() {
|
|
var c policyCandidate
|
|
c.priority = 10
|
|
if err := rows.Scan(&c.daily, &c.monthly, &c.warn); err != nil {
|
|
continue
|
|
}
|
|
candidates = append(candidates, c)
|
|
}
|
|
}
|
|
|
|
// User policy (priority 20)
|
|
userPol, err := s.loadPolicy(ctx, "user", userID)
|
|
if err == nil && userPol != nil {
|
|
candidates = append(candidates, policyCandidate{
|
|
daily: userPol.daily, monthly: userPol.monthly, warn: userPol.warn, priority: 20,
|
|
})
|
|
}
|
|
|
|
if len(candidates) == 0 {
|
|
return limits, nil
|
|
}
|
|
|
|
// User override wins if present; else merge group+org with most restrictive
|
|
var userCandidate *policyCandidate
|
|
var others []policyCandidate
|
|
for i := range candidates {
|
|
if candidates[i].priority >= 20 {
|
|
userCandidate = &candidates[i]
|
|
} else {
|
|
others = append(others, candidates[i])
|
|
}
|
|
}
|
|
|
|
var merged policyCandidate
|
|
if userCandidate != nil {
|
|
merged = *userCandidate
|
|
} else {
|
|
merged = mergeMostRestrictive(others)
|
|
}
|
|
|
|
limits.DailyLimitMicroEUR = merged.daily
|
|
limits.MonthlyLimitMicroEUR = merged.monthly
|
|
if merged.warn > 0 {
|
|
limits.WarnThresholdPct = merged.warn
|
|
}
|
|
return limits, nil
|
|
}
|
|
|
|
// policyRow holds the limit columns of a single ai_cost_policies row as read
// by loadPolicy.
type policyRow struct {
	// daily and monthly are the limits in micro-EUR; nil when the column is
	// NULL.
	daily, monthly *int64
	// warn is the warning threshold in percent.
	warn int
}
|
|
|
|
func (s *PolicyService) loadPolicy(ctx context.Context, scopeType, scopeID string) (*policyRow, error) {
|
|
var daily, monthly *int64
|
|
var warn int
|
|
var err error
|
|
if scopeType == "org" {
|
|
err = s.db.QueryRow(ctx, `
|
|
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
|
|
FROM ai_cost_policies WHERE scope_type = 'org' LIMIT 1
|
|
`).Scan(&daily, &monthly, &warn)
|
|
} else {
|
|
err = s.db.QueryRow(ctx, `
|
|
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
|
|
FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
|
|
`, scopeType, scopeID).Scan(&daily, &monthly, &warn)
|
|
}
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
return &policyRow{daily: daily, monthly: monthly, warn: warn}, nil
|
|
}
|
|
|
|
func mergeMostRestrictive(candidates []policyCandidate) policyCandidate {
|
|
if len(candidates) == 0 {
|
|
return policyCandidate{warn: 80}
|
|
}
|
|
out := candidates[0]
|
|
for _, c := range candidates[1:] {
|
|
if c.daily != nil && (out.daily == nil || *c.daily < *out.daily) {
|
|
out.daily = c.daily
|
|
}
|
|
if c.monthly != nil && (out.monthly == nil || *c.monthly < *out.monthly) {
|
|
out.monthly = c.monthly
|
|
}
|
|
if c.warn > 0 && c.warn < out.warn {
|
|
out.warn = c.warn
|
|
}
|
|
}
|
|
if out.warn == 0 {
|
|
out.warn = 80
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (s *PolicyService) GetStatus(ctx context.Context, externalUserID string, orgBilling bool) (SpendStatus, error) {
|
|
userID, err := s.resolveUserID(ctx, externalUserID)
|
|
if err != nil {
|
|
return SpendStatus{}, err
|
|
}
|
|
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
|
|
|
var dailyOrg, dailyUser int64
|
|
var requestsToday int
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(requests, 0)
|
|
FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2
|
|
`, userID, today).Scan(&dailyOrg, &dailyUser, &requestsToday)
|
|
|
|
var monthlyOrg, monthlyUser, tokensMonth int64
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(tokens, 0)
|
|
FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2
|
|
`, userID, month).Scan(&monthlyOrg, &monthlyUser, &tokensMonth)
|
|
|
|
limits, _ := s.ResolveEffectiveLimits(ctx, userID)
|
|
|
|
status := SpendStatus{
|
|
Currency: "EUR",
|
|
BillingScopeOrg: orgBilling,
|
|
WarnThresholdPct: limits.WarnThresholdPct,
|
|
RequestsUsedToday: requestsToday,
|
|
TokensUsedMonth: tokensMonth,
|
|
}
|
|
|
|
if orgBilling {
|
|
status.CostUsedTodayMicroEUR = dailyOrg
|
|
status.CostUsedMonthMicroEUR = monthlyOrg
|
|
status.CostLimitTodayMicroEUR = limits.DailyLimitMicroEUR
|
|
status.CostLimitMonthMicroEUR = limits.MonthlyLimitMicroEUR
|
|
} else {
|
|
status.CostUsedTodayMicroEUR = dailyUser
|
|
status.CostUsedMonthMicroEUR = monthlyUser
|
|
}
|
|
|
|
status.CostRemainingTodayMicroEUR = remaining(limits.DailyLimitMicroEUR, status.CostUsedTodayMicroEUR)
|
|
status.CostRemainingMonthMicroEUR = remaining(limits.MonthlyLimitMicroEUR, status.CostUsedMonthMicroEUR)
|
|
|
|
keys, _ := s.providerKeyBreakdown(ctx, userID, month)
|
|
if keys == nil {
|
|
keys = []ProviderKeySpend{}
|
|
}
|
|
status.ByProviderKeys = keys
|
|
|
|
return status, nil
|
|
}
|
|
|
|
func (s *PolicyService) AssertAvailable(ctx context.Context, externalUserID string, billingScope string) error {
|
|
if billingScope == BillingScopeUser {
|
|
return nil
|
|
}
|
|
userID, err := s.resolveUserID(ctx, externalUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
limits, err := s.ResolveEffectiveLimits(ctx, userID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
|
|
|
var dailyOrg, monthlyOrg int64
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_daily
|
|
WHERE user_id = $1::uuid AND usage_date = $2
|
|
`, userID, today).Scan(&dailyOrg)
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_monthly
|
|
WHERE user_id = $1::uuid AND usage_month = $2
|
|
`, userID, month).Scan(&monthlyOrg)
|
|
|
|
if limits.DailyLimitMicroEUR != nil && *limits.DailyLimitMicroEUR > 0 && dailyOrg >= *limits.DailyLimitMicroEUR {
|
|
return fmt.Errorf("%w: daily cost limit reached", ErrCostLimitExceeded)
|
|
}
|
|
if limits.MonthlyLimitMicroEUR != nil && *limits.MonthlyLimitMicroEUR > 0 && monthlyOrg >= *limits.MonthlyLimitMicroEUR {
|
|
return fmt.Errorf("%w: monthly cost limit reached", ErrCostLimitExceeded)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *PolicyService) providerKeyBreakdown(ctx context.Context, userID string, month time.Time) ([]ProviderKeySpend, error) {
|
|
rows, err := s.db.Query(ctx, `
|
|
SELECT provider_key_fingerprint, billing_scope,
|
|
COALESCE(SUM(cost_micro_eur), 0),
|
|
MAX(provider_id)
|
|
FROM ai_usage_events
|
|
WHERE user_id = $1::uuid AND created_at >= $2
|
|
GROUP BY provider_key_fingerprint, billing_scope
|
|
ORDER BY SUM(cost_micro_eur) DESC
|
|
LIMIT 20
|
|
`, userID, month)
|
|
if err != nil {
|
|
return []ProviderKeySpend{}, err
|
|
}
|
|
defer rows.Close()
|
|
out := make([]ProviderKeySpend, 0)
|
|
for rows.Next() {
|
|
var item ProviderKeySpend
|
|
var providerID string
|
|
if err := rows.Scan(&item.Fingerprint, &item.BillingScope, &item.CostMonthMicroEUR, &providerID); err != nil {
|
|
continue
|
|
}
|
|
item.CostMonthEUR = MicroEURToEUR(item.CostMonthMicroEUR)
|
|
suffix := item.Fingerprint
|
|
if len(suffix) > 8 {
|
|
suffix = suffix[len(suffix)-8:]
|
|
}
|
|
item.Label = providerID + " ···" + suffix
|
|
out = append(out, item)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func (s *PolicyService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
|
|
var userID string
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT id::text FROM users WHERE external_id = $1
|
|
`, externalUserID).Scan(&userID)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return "", fmt.Errorf("user not found")
|
|
}
|
|
return "", err
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
// remaining returns how much of limit is left after used has been spent,
// clamped at zero. It returns nil when limit is nil or not positive, meaning
// there is no limit to measure against. The result is a fresh value; limit is
// never modified.
func remaining(limit *int64, used int64) *int64 {
	if limit == nil {
		return nil
	}
	ceiling := *limit
	if ceiling <= 0 {
		return nil
	}

	left := new(int64)
	if diff := ceiling - used; diff >= 0 {
		*left = diff
	}
	return left
}
|
|
|
|
// UpsertOrgPolicy updates the org-level cost policy.
|
|
func (s *PolicyService) UpsertOrgPolicy(ctx context.Context, daily, monthly *int64, warnPct int) error {
|
|
if warnPct <= 0 {
|
|
warnPct = 80
|
|
}
|
|
var exists bool
|
|
if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org')`).Scan(&exists); err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
_, err := s.db.Exec(ctx, `
|
|
UPDATE ai_cost_policies SET
|
|
daily_limit_micro_eur = $1,
|
|
monthly_limit_micro_eur = $2,
|
|
warn_threshold_pct = $3,
|
|
updated_at = NOW()
|
|
WHERE scope_type = 'org'
|
|
`, daily, monthly, warnPct)
|
|
return err
|
|
}
|
|
_, err := s.db.Exec(ctx, `
|
|
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
|
|
VALUES ('org', NULL, $1, $2, $3, 0)
|
|
`, daily, monthly, warnPct)
|
|
return err
|
|
}
|
|
|
|
// UpsertScopePolicy updates group or user policy.
|
|
func (s *PolicyService) UpsertScopePolicy(ctx context.Context, scopeType, scopeID string, daily, monthly *int64, warnPct int) error {
|
|
if warnPct <= 0 {
|
|
warnPct = 80
|
|
}
|
|
priority := 10
|
|
if scopeType == "user" {
|
|
priority = 20
|
|
}
|
|
var exists bool
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid)
|
|
`, scopeType, scopeID).Scan(&exists)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if exists {
|
|
_, err = s.db.Exec(ctx, `
|
|
UPDATE ai_cost_policies SET
|
|
daily_limit_micro_eur = $1,
|
|
monthly_limit_micro_eur = $2,
|
|
warn_threshold_pct = $3,
|
|
updated_at = NOW()
|
|
WHERE scope_type = $4 AND scope_id = $5::uuid
|
|
`, daily, monthly, warnPct, scopeType, scopeID)
|
|
return err
|
|
}
|
|
_, err = s.db.Exec(ctx, `
|
|
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
|
|
VALUES ($1, $2::uuid, $3, $4, $5, $6)
|
|
`, scopeType, scopeID, daily, monthly, warnPct, priority)
|
|
return err
|
|
}
|
|
|
|
func (s *PolicyService) DeleteScopePolicy(ctx context.Context, scopeType, scopeID string) error {
|
|
if scopeType == "org" {
|
|
return fmt.Errorf("cannot delete org policy")
|
|
}
|
|
_, err := s.db.Exec(ctx, `
|
|
DELETE FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
|
|
`, scopeType, scopeID)
|
|
return err
|
|
}
|
|
|
|
func (s *PolicyService) GetOrgPolicy(ctx context.Context) (EffectiveLimits, error) {
|
|
row, err := s.loadPolicy(ctx, "org", "")
|
|
if err != nil {
|
|
return EffectiveLimits{WarnThresholdPct: 80}, err
|
|
}
|
|
if row == nil {
|
|
return EffectiveLimits{WarnThresholdPct: 80}, nil
|
|
}
|
|
return EffectiveLimits{
|
|
DailyLimitMicroEUR: row.daily,
|
|
MonthlyLimitMicroEUR: row.monthly,
|
|
WarnThresholdPct: row.warn,
|
|
}, nil
|
|
}
|