refactor(ai): update AI gateway and cost management features
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run

- Refactored AI gateway to utilize new cost management structures for usage tracking.
- Replaced deprecated token extraction methods with a unified cost parsing approach.
- Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses.
- Added new metering functionality to record AI usage and costs effectively.
- Updated tests to reflect changes in usage parsing and cost calculations.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
This commit is contained in:
R3D347HR4Y 2026-06-16 10:46:33 +02:00
parent 71b716edba
commit 3978622050
29 changed files with 1993 additions and 203 deletions

View File

@ -0,0 +1,18 @@
package cost
import (
"crypto/sha256"
"encoding/hex"
"strings"
)
// KeyFingerprint derives a short, stable identifier for grouping usage by
// provider API key without storing the key itself.
//
// Both arguments are whitespace-trimmed first. With an empty key the result is
// "<providerID>:none"; otherwise it is the hex encoding of the first 8 bytes
// of SHA-256("<providerID>:<apiKey>"), i.e. 16 hex characters.
func KeyFingerprint(providerID, apiKey string) string {
	provider := strings.TrimSpace(providerID)
	key := strings.TrimSpace(apiKey)
	if key == "" {
		return provider + ":none"
	}

	digest := sha256.Sum256([]byte(provider + ":" + key))
	return hex.EncodeToString(digest[:8])
}

149
internal/ai/cost/meter.go Normal file
View File

@ -0,0 +1,149 @@
package cost
import (
"context"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Billing scope values carried in RecordInput.BillingScope. They decide
// whether a request's cost is accumulated in the org or the user cost columns.
const (
BillingScopeOrg = "org"
BillingScopeUser = "user"
)
// RecordInput describes one AI request to be metered by Meter.RecordUsage.
type RecordInput struct {
// ExternalUserID is matched against users.external_id to find the internal user.
ExternalUserID string
// Feature is stored verbatim in the usage event's feature column.
Feature string
// ModelID is stored on the event and used to look up pricing.
ModelID string
// ProviderID is stored verbatim on the usage event.
ProviderID string
// BillingScope selects the cost column: BillingScopeOrg adds to the org
// column, any other value adds to the user column.
BillingScope string
// ProviderKeyFingerprint is stored verbatim on the usage event.
ProviderKeyFingerprint string
// Usage holds the token counts that are priced and stored.
Usage UsageDetail
// RequestID is stored as NULL when empty.
RequestID string
}
// Meter records AI usage events and their cost into the database.
type Meter struct {
// db is the connection pool; RecordUsage is a no-op when it is nil.
db *pgxpool.Pool
// pricing resolves per-model rates used to compute the cost of an event.
pricing *PricingStore
}
// NewMeter builds a Meter on top of db, together with a PricingStore that
// shares the same pool.
func NewMeter(db *pgxpool.Pool) *Meter {
	meter := new(Meter)
	meter.db = db
	meter.pricing = NewPricingStore(db)
	return meter
}
// RecordUsage persists one AI request's usage and cost.
//
// It resolves in.ExternalUserID to the internal user ID, prices in.Usage with
// the rate looked up for in.ModelID (ComputeCostMicroEUR yields the cost and
// the "estimated" flag), and then, inside a single transaction, inserts an
// ai_usage_events row and upserts the per-user and org-wide daily and monthly
// aggregate rows. The cost is added to the org column when in.BillingScope is
// BillingScopeOrg and to the user column otherwise.
//
// It returns nil without doing anything when the Meter has no database, and
// returns the first error encountered otherwise (nothing is committed then).
func (m *Meter) RecordUsage(ctx context.Context, in RecordInput) error {
if m.db == nil {
return nil
}
userID, err := m.resolveUserID(ctx, in.ExternalUserID)
if err != nil {
return err
}
// A missing price is not an error: ComputeCostMicroEUR receives found=false.
price, found := m.pricing.LookupPrice(ctx, in.ModelID)
cost, estimated := ComputeCostMicroEUR(in.Usage, price, found)
tx, err := m.db.Begin(ctx)
if err != nil {
return err
}
// Undo partial writes on every early return below; the result is ignored.
defer tx.Rollback(ctx)
// 1. Raw event row.
_, err = tx.Exec(ctx, `
INSERT INTO ai_usage_events (
user_id, feature, model_id, provider_id, billing_scope,
provider_key_fingerprint, prompt_tokens, completion_tokens,
cached_input_tokens, reasoning_tokens, cost_micro_eur, estimated, request_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
`, userID, in.Feature, in.ModelID, in.ProviderID, in.BillingScope,
in.ProviderKeyFingerprint, in.Usage.PromptTokens, in.Usage.CompletionTokens,
in.Usage.CachedInputTokens, in.Usage.ReasoningTokens, cost, estimated, nullStr(in.RequestID))
if err != nil {
return err
}
// Aggregation buckets: the current UTC day and the first day of its month.
today := time.Now().UTC().Truncate(24 * time.Hour)
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
tokens := int64(in.Usage.TotalTokens)
// Exactly one of orgCost/userCost carries the cost; the other stays zero.
orgCost := int64(0)
userCost := int64(0)
if in.BillingScope == BillingScopeOrg {
orgCost = cost
} else {
userCost = cost
}
// 2. Per-user daily aggregate.
_, err = tx.Exec(ctx, `
INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens, cost_micro_eur_org, cost_micro_eur_user)
VALUES ($1, $2, 1, $3, $4, $5)
ON CONFLICT (user_id, usage_date) DO UPDATE SET
requests = ai_usage_daily.requests + 1,
tokens = ai_usage_daily.tokens + EXCLUDED.tokens,
cost_micro_eur_org = ai_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
cost_micro_eur_user = ai_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
`, userID, today, tokens, orgCost, userCost)
if err != nil {
return err
}
// 3. Per-user monthly aggregate.
_, err = tx.Exec(ctx, `
INSERT INTO ai_usage_monthly (user_id, usage_month, tokens, cost_micro_eur_org, cost_micro_eur_user)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (user_id, usage_month) DO UPDATE SET
tokens = ai_usage_monthly.tokens + EXCLUDED.tokens,
cost_micro_eur_org = ai_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
cost_micro_eur_user = ai_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
`, userID, month, tokens, orgCost, userCost)
if err != nil {
return err
}
// 4. Org-wide daily aggregate (one row per date, across all users).
_, err = tx.Exec(ctx, `
INSERT INTO ai_org_usage_daily (usage_date, cost_micro_eur_org, cost_micro_eur_user, requests)
VALUES ($1, $2, $3, 1)
ON CONFLICT (usage_date) DO UPDATE SET
cost_micro_eur_org = ai_org_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
cost_micro_eur_user = ai_org_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user,
requests = ai_org_usage_daily.requests + 1
`, today, orgCost, userCost)
if err != nil {
return err
}
// 5. Org-wide monthly aggregate.
_, err = tx.Exec(ctx, `
INSERT INTO ai_org_usage_monthly (usage_month, cost_micro_eur_org, cost_micro_eur_user)
VALUES ($1, $2, $3)
ON CONFLICT (usage_month) DO UPDATE SET
cost_micro_eur_org = ai_org_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
cost_micro_eur_user = ai_org_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
`, month, orgCost, userCost)
if err != nil {
return err
}
return tx.Commit(ctx)
}
// resolveUserID maps an external user identifier to the internal users.id,
// returned as text. It returns a "user not found" error when no row matches
// and passes any other query error through unchanged.
func (m *Meter) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
var userID string
err := m.db.QueryRow(ctx, `
SELECT id::text FROM users WHERE external_id = $1
`, externalUserID).Scan(&userID)
if err != nil {
if err == pgx.ErrNoRows {
return "", fmt.Errorf("user not found")
}
return "", err
}
return userID, nil
}
// nullStr converts an empty string to a nil pointer (used so the value is
// written as SQL NULL) and any other string to a pointer to a copy of it.
func nullStr(s string) *string {
	if len(s) > 0 {
		return &s
	}
	return nil
}

65
internal/ai/cost/parse.go Normal file
View File

@ -0,0 +1,65 @@
package cost
import "encoding/json"
// usagePayload mirrors the "usage" object of an OpenAI-compatible chat
// completion response. The two detail objects are optional and stay nil when
// the payload does not contain them.
type usagePayload struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
// PromptTokensDetails carries the cached share of the prompt tokens.
PromptTokensDetails *struct {
CachedTokens int `json:"cached_tokens"`
} `json:"prompt_tokens_details,omitempty"`
// CompletionTokensDetails carries the reasoning token count.
CompletionTokensDetails *struct {
ReasoningTokens int `json:"reasoning_tokens"`
} `json:"completion_tokens_details,omitempty"`
}
// chatCompletionResponse decodes only the "usage" member of a chat completion
// payload; every other member is ignored. Usage is nil when the member is absent.
type chatCompletionResponse struct {
Usage *usagePayload `json:"usage,omitempty"`
}
// ParseUsage reads the token usage out of an OpenAI-compatible chat completion
// payload. When the payload is not valid JSON or carries no "usage" object,
// it falls back to a UsageDetail whose only non-zero field is TotalTokens = 1.
func ParseUsage(payload []byte) UsageDetail {
	var resp chatCompletionResponse
	err := json.Unmarshal(payload, &resp)
	if err == nil && resp.Usage != nil {
		return usageFromPayload(resp.Usage)
	}
	return UsageDetail{TotalTokens: 1}
}
// usageFromPayload converts a decoded usage object into a UsageDetail.
//
// Cached and reasoning counts are copied only when their detail objects are
// present. TotalTokens is taken from the payload when non-zero, otherwise from
// prompt + completion, and is finally forced to 1 if it is still zero, so the
// result never reports a zero total. A nil payload yields TotalTokens = 1.
func usageFromPayload(u *usagePayload) UsageDetail {
	if u == nil {
		return UsageDetail{TotalTokens: 1}
	}

	var out UsageDetail
	out.PromptTokens = u.PromptTokens
	out.CompletionTokens = u.CompletionTokens
	if cached := u.PromptTokensDetails; cached != nil {
		out.CachedInputTokens = cached.CachedTokens
	}
	if reasoning := u.CompletionTokensDetails; reasoning != nil {
		out.ReasoningTokens = reasoning.ReasoningTokens
	}

	switch sum := out.PromptTokens + out.CompletionTokens; {
	case u.TotalTokens != 0:
		out.TotalTokens = u.TotalTokens
	case sum != 0:
		out.TotalTokens = sum
	default:
		out.TotalTokens = 1
	}
	return out
}
// MergeStreamUsage folds one streaming chunk into the accumulated usage and
// returns the usage to keep: the chunk's usage when it reports any tokens,
// otherwise acc unchanged.
//
// acc is returned as-is when the chunk is not valid JSON, has no "usage"
// object, or its usage reports zero prompt, completion and total tokens.
//
// The zero check is made on the raw payload on purpose: usageFromPayload never
// returns a zero TotalTokens (it substitutes 1), so checking its result would
// let an all-zero usage chunk overwrite real accumulated usage.
func MergeStreamUsage(acc UsageDetail, chunk []byte) UsageDetail {
	var parsed chatCompletionResponse
	if err := json.Unmarshal(chunk, &parsed); err != nil || parsed.Usage == nil {
		return acc
	}
	raw := parsed.Usage
	if raw.TotalTokens == 0 && raw.PromptTokens == 0 && raw.CompletionTokens == 0 {
		return acc
	}
	return usageFromPayload(raw)
}

View File

@ -0,0 +1,72 @@
package cost
import "testing"
// TestParseUsageOpenAI feeds ParseUsage a usage object that includes both
// detail sections and checks the prompt, completion, cached and reasoning
// counts that come back.
func TestParseUsageOpenAI(t *testing.T) {
	const body = `{
"usage": {
"prompt_tokens": 1200,
"completion_tokens": 340,
"total_tokens": 1540,
"prompt_tokens_details": {"cached_tokens": 800},
"completion_tokens_details": {"reasoning_tokens": 50}
}
}`
	got := ParseUsage([]byte(body))

	tokensOK := got.PromptTokens == 1200 && got.CompletionTokens == 340
	if !tokensOK {
		t.Fatalf("unexpected tokens: %+v", got)
	}
	detailsOK := got.CachedInputTokens == 800 && got.ReasoningTokens == 50
	if !detailsOK {
		t.Fatalf("unexpected details: %+v", got)
	}
}
// TestParseUsageFallback checks that a payload without a "usage" object is
// reported as a single token.
func TestParseUsageFallback(t *testing.T) {
	payload := []byte(`{"choices":[]}`)
	if got := ParseUsage(payload).TotalTokens; got != 1 {
		t.Fatalf("expected fallback 1, got %d", got)
	}
}
// TestComputeCostMicroEUR prices a usage record against a known model price
// and checks the exact result. The computation is pure integer arithmetic, so
// the expected value is pinned rather than accepted within a range.
func TestComputeCostMicroEUR(t *testing.T) {
	price := ModelPrice{
		InputMicroEURPerMTok:  1_000_000,
		OutputMicroEURPerMTok: 2_000_000,
	}
	usage := UsageDetail{
		PromptTokens:      1000,
		CompletionTokens:  500,
		CachedInputTokens: 200,
		TotalTokens:       1500,
	}
	cost, estimated := ComputeCostMicroEUR(usage, price, true)
	if estimated {
		t.Fatal("should not be estimated when price found")
	}
	// No cached rate is set, so cached input is billed at half the input rate:
	// uncached 800*1 + cached 200*0.5 + output 500*2 = 800 + 100 + 1000.
	const want = 1900
	if cost != want {
		t.Fatalf("unexpected cost %d, want %d", cost, want)
	}
}
func TestComputeCostUnknownModel(t *testing.T) {
usage := UsageDetail{TotalTokens: 1_000_000}
cost, estimated := ComputeCostMicroEUR(usage, ModelPrice{}, false)
if !estimated {
t.Fatal("expected estimated")
}
if cost != fallbackInputMicroEURPerMTok {
t.Fatalf("expected fallback cost %d, got %d", fallbackInputMicroEURPerMTok, cost)
}
}
// TestMergeStreamUsage applies two usage-bearing chunks in order and checks
// that the later one is what remains.
func TestMergeStreamUsage(t *testing.T) {
	chunks := [][]byte{
		[]byte(`{"usage":{"prompt_tokens":10,"completion_tokens":0,"total_tokens":10}}`),
		[]byte(`{"usage":{"prompt_tokens":100,"completion_tokens":40,"total_tokens":140}}`),
	}
	var merged UsageDetail
	for _, chunk := range chunks {
		merged = MergeStreamUsage(merged, chunk)
	}
	if merged.TotalTokens != 140 {
		t.Fatalf("expected 140, got %d", merged.TotalTokens)
	}
}

421
internal/ai/cost/policy.go Normal file
View File

@ -0,0 +1,421 @@
package cost
import (
"context"
"errors"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// ErrCostLimitExceeded is the sentinel wrapped by AssertAvailable when a daily
// or monthly cost limit has been reached; match it with errors.Is.
var ErrCostLimitExceeded = errors.New("llm cost limit exceeded")
// EffectiveLimits is the resolved cost policy for one user.
type EffectiveLimits struct {
// DailyLimitMicroEUR is the daily cap in micro-EUR; nil means no cap is set.
DailyLimitMicroEUR *int64
// MonthlyLimitMicroEUR is the monthly cap in micro-EUR; nil means no cap is set.
MonthlyLimitMicroEUR *int64
// WarnThresholdPct is the warning threshold percentage (80 by default).
WarnThresholdPct int
}
// SpendStatus is the JSON shape returned by PolicyService.GetStatus: the
// user's cost for the current UTC day and month, the applicable limits, and a
// per-provider-key breakdown. All cost amounts are in micro-EUR.
type SpendStatus struct {
CostUsedTodayMicroEUR int64 `json:"cost_used_today_micro_eur"`
CostLimitTodayMicroEUR *int64 `json:"cost_limit_today_micro_eur,omitempty"`
CostUsedMonthMicroEUR int64 `json:"cost_used_month_micro_eur"`
CostLimitMonthMicroEUR *int64 `json:"cost_limit_month_micro_eur,omitempty"`
// Remaining amounts are omitted when there is no positive limit.
CostRemainingTodayMicroEUR *int64 `json:"cost_remaining_today_micro_eur,omitempty"`
CostRemainingMonthMicroEUR *int64 `json:"cost_remaining_month_micro_eur,omitempty"`
WarnThresholdPct int `json:"warn_threshold_pct"`
Currency string `json:"currency"`
// BillingScopeOrg reports whether the figures above are the org-billed ones.
BillingScopeOrg bool `json:"billing_scope_org"`
ByProviderKeys []ProviderKeySpend `json:"by_provider_keys,omitempty"`
// Legacy fields (deprecated)
// Within this package only RequestsUsedToday and TokensUsedMonth are
// populated (by GetStatus); the remaining legacy fields keep their zero value.
RequestsUsedToday int `json:"requests_used_today"`
RequestsLimit int `json:"requests_limit"`
TokensUsedMonth int64 `json:"tokens_used_month"`
TokensLimit int64 `json:"tokens_limit"`
RequestsRemaining int `json:"requests_remaining"`
TokensRemaining int64 `json:"tokens_remaining"`
}
// ProviderKeySpend is one row of the per-API-key cost breakdown, grouped by
// key fingerprint and billing scope.
type ProviderKeySpend struct {
Fingerprint string `json:"fingerprint"`
// Label is a display string built from the provider ID and the last
// characters of the fingerprint.
Label string `json:"label"`
CostMonthMicroEUR int64 `json:"cost_month_micro_eur"`
// CostMonthEUR is CostMonthMicroEUR converted to EUR.
CostMonthEUR float64 `json:"cost_month_eur"`
BillingScope string `json:"billing_scope"`
}
// PolicyService reads and writes cost policies (ai_cost_policies) and reports
// spend against them.
type PolicyService struct {
db *pgxpool.Pool
}
// NewPolicyService returns a PolicyService backed by the given pool.
func NewPolicyService(db *pgxpool.Pool) *PolicyService {
	svc := new(PolicyService)
	svc.db = db
	return svc
}
// policyCandidate is one policy considered while resolving a user's limits.
type policyCandidate struct {
// daily and monthly are limits in micro-EUR; nil means the policy sets none.
daily *int64
monthly *int64
// warn is the warning threshold percentage.
warn int
// priority identifies the policy level: 0 org, 10 group, 20 user.
priority int
}
// ResolveEffectiveLimits computes the cost limits that apply to the user with
// internal ID userID (a UUID in text form).
//
// Policies are read from ai_cost_policies at three levels: the org policy,
// the policies of every group the user belongs to, and the user's own policy.
// A user-level policy, when present, is used as-is (including nil limits);
// otherwise the org and group policies are merged with mergeMostRestrictive.
// When no policy exists, or when the service has no database, the defaults
// are returned: no limits and a warn threshold of 80.
//
// Any database error is returned together with those defaults. Reporting the
// failure, instead of skipping the policy that could not be loaded, keeps a
// transient database problem from silently lifting spend limits.
func (s *PolicyService) ResolveEffectiveLimits(ctx context.Context, userID string) (EffectiveLimits, error) {
	limits := EffectiveLimits{WarnThresholdPct: 80}
	if s.db == nil {
		return limits, nil
	}

	// Org (priority 0) and group (priority 10) policies are only merged when
	// the user has no policy of their own.
	var shared []policyCandidate

	org, err := s.loadPolicy(ctx, "org", "")
	if err != nil {
		return limits, fmt.Errorf("loading org cost policy: %w", err)
	}
	if org != nil {
		shared = append(shared, policyCandidate{
			daily: org.daily, monthly: org.monthly, warn: org.warn, priority: 0,
		})
	}

	groups, err := s.loadGroupPolicies(ctx, userID)
	if err != nil {
		return limits, fmt.Errorf("loading group cost policies: %w", err)
	}
	shared = append(shared, groups...)

	userPol, err := s.loadPolicy(ctx, "user", userID)
	if err != nil {
		return limits, fmt.Errorf("loading user cost policy: %w", err)
	}

	var merged policyCandidate
	switch {
	case userPol != nil:
		// User override (priority 20) wins outright.
		merged = policyCandidate{
			daily: userPol.daily, monthly: userPol.monthly, warn: userPol.warn, priority: 20,
		}
	case len(shared) == 0:
		return limits, nil
	default:
		merged = mergeMostRestrictive(shared)
	}

	limits.DailyLimitMicroEUR = merged.daily
	limits.MonthlyLimitMicroEUR = merged.monthly
	if merged.warn > 0 {
		limits.WarnThresholdPct = merged.warn
	}
	return limits, nil
}

// loadGroupPolicies returns one candidate (priority 10) for every group-scoped
// policy attached to a group the user is a member of.
func (s *PolicyService) loadGroupPolicies(ctx context.Context, userID string) ([]policyCandidate, error) {
	rows, err := s.db.Query(ctx, `
		SELECT p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct
		FROM ai_cost_policies p
		JOIN user_group_members ugm ON ugm.group_id = p.scope_id
		WHERE p.scope_type = 'group' AND ugm.user_id = $1::uuid
	`, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []policyCandidate
	for rows.Next() {
		c := policyCandidate{priority: 10}
		if err := rows.Scan(&c.daily, &c.monthly, &c.warn); err != nil {
			return nil, err
		}
		out = append(out, c)
	}
	return out, rows.Err()
}
// policyRow holds the three policy columns read by loadPolicy; nil limits
// correspond to NULL columns.
type policyRow struct {
daily *int64
monthly *int64
warn int
}
// loadPolicy reads a single policy row from ai_cost_policies.
//
// For scopeType "org" the scopeID is ignored and the first org row is read;
// for any other scope type the row is matched on scope_type and scope_id
// (scopeID is cast to uuid by the query). It returns (nil, nil) when no row
// exists and (nil, err) on any other query error.
func (s *PolicyService) loadPolicy(ctx context.Context, scopeType, scopeID string) (*policyRow, error) {
var daily, monthly *int64
var warn int
var err error
if scopeType == "org" {
err = s.db.QueryRow(ctx, `
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
FROM ai_cost_policies WHERE scope_type = 'org' LIMIT 1
`).Scan(&daily, &monthly, &warn)
} else {
err = s.db.QueryRow(ctx, `
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
`, scopeType, scopeID).Scan(&daily, &monthly, &warn)
}
if err != nil {
// A missing policy is a normal outcome, not an error.
if err == pgx.ErrNoRows {
return nil, nil
}
return nil, err
}
return &policyRow{daily: daily, monthly: monthly, warn: warn}, nil
}
// mergeMostRestrictive combines several policies into one by taking, field by
// field, the most restrictive value: the smallest non-nil daily limit, the
// smallest non-nil monthly limit, and the smallest positive warn threshold.
//
// A nil limit or a non-positive warn threshold means "not set" and never
// overrides a set value, regardless of the order of the candidates. If no
// candidate sets a warn threshold (or the slice is empty) it defaults to 80.
// The priority of the result is that of the first candidate.
func mergeMostRestrictive(candidates []policyCandidate) policyCandidate {
	if len(candidates) == 0 {
		return policyCandidate{warn: 80}
	}
	out := candidates[0]
	for _, c := range candidates[1:] {
		if c.daily != nil && (out.daily == nil || *c.daily < *out.daily) {
			out.daily = c.daily
		}
		if c.monthly != nil && (out.monthly == nil || *c.monthly < *out.monthly) {
			out.monthly = c.monthly
		}
		// Treat an unset threshold like a nil limit: any set value replaces it.
		if c.warn > 0 && (out.warn <= 0 || c.warn < out.warn) {
			out.warn = c.warn
		}
	}
	if out.warn <= 0 {
		out.warn = 80
	}
	return out
}
// GetStatus reports the spend of the user identified by externalUserID for the
// current UTC day and month, together with the limits resolved for that user.
//
// orgBilling selects which cost columns are reported: the org-billed ones
// (with the resolved limits) when true, the user-billed ones (without limit
// fields) when false. The only error returned is a failure to resolve the user.
func (s *PolicyService) GetStatus(ctx context.Context, externalUserID string, orgBilling bool) (SpendStatus, error) {
userID, err := s.resolveUserID(ctx, externalUserID)
if err != nil {
return SpendStatus{}, err
}
today := time.Now().UTC().Truncate(24 * time.Hour)
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
var dailyOrg, dailyUser int64
var requestsToday int
// Query errors are deliberately discarded here and below: a missing
// aggregate row (no usage yet) leaves the zero values in place. Note that
// other database errors are hidden the same way.
_ = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(requests, 0)
FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2
`, userID, today).Scan(&dailyOrg, &dailyUser, &requestsToday)
var monthlyOrg, monthlyUser, tokensMonth int64
_ = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(tokens, 0)
FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2
`, userID, month).Scan(&monthlyOrg, &monthlyUser, &tokensMonth)
limits, _ := s.ResolveEffectiveLimits(ctx, userID)
status := SpendStatus{
Currency: "EUR",
BillingScopeOrg: orgBilling,
WarnThresholdPct: limits.WarnThresholdPct,
RequestsUsedToday: requestsToday,
TokensUsedMonth: tokensMonth,
}
if orgBilling {
status.CostUsedTodayMicroEUR = dailyOrg
status.CostUsedMonthMicroEUR = monthlyOrg
status.CostLimitTodayMicroEUR = limits.DailyLimitMicroEUR
status.CostLimitMonthMicroEUR = limits.MonthlyLimitMicroEUR
} else {
status.CostUsedTodayMicroEUR = dailyUser
status.CostUsedMonthMicroEUR = monthlyUser
}
// NOTE(review): the remaining amounts are derived from the resolved limits
// even when orgBilling is false, although no limit fields are reported in
// that case and AssertAvailable does not enforce limits for user billing —
// confirm whether they should be omitted for user-billed status.
status.CostRemainingTodayMicroEUR = remaining(limits.DailyLimitMicroEUR, status.CostUsedTodayMicroEUR)
status.CostRemainingMonthMicroEUR = remaining(limits.MonthlyLimitMicroEUR, status.CostUsedMonthMicroEUR)
keys, _ := s.providerKeyBreakdown(ctx, userID, month)
// Guarantee a non-nil slice for the response.
if keys == nil {
keys = []ProviderKeySpend{}
}
status.ByProviderKeys = keys
return status, nil
}
// AssertAvailable checks whether the user may make another org-billed request.
//
// Requests billed to the user (billingScope == BillingScopeUser) are never
// limited and return nil immediately. Otherwise the user's org-billed cost for
// the current UTC day and month is compared with the resolved limits; a limit
// is enforced only when it is set and greater than zero. When a limit has been
// reached the returned error wraps ErrCostLimitExceeded. Failures to resolve
// the user or the limits are returned as-is.
func (s *PolicyService) AssertAvailable(ctx context.Context, externalUserID string, billingScope string) error {
if billingScope == BillingScopeUser {
return nil
}
userID, err := s.resolveUserID(ctx, externalUserID)
if err != nil {
return err
}
limits, err := s.ResolveEffectiveLimits(ctx, userID)
if err != nil {
return err
}
today := time.Now().UTC().Truncate(24 * time.Hour)
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
var dailyOrg, monthlyOrg int64
// Errors from the two usage lookups are discarded: with no aggregate row
// yet the cost stays 0. Any other query error is also read as "no usage".
_ = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_daily
WHERE user_id = $1::uuid AND usage_date = $2
`, userID, today).Scan(&dailyOrg)
_ = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_monthly
WHERE user_id = $1::uuid AND usage_month = $2
`, userID, month).Scan(&monthlyOrg)
// The limit counts as reached once usage is equal to or above it.
if limits.DailyLimitMicroEUR != nil && *limits.DailyLimitMicroEUR > 0 && dailyOrg >= *limits.DailyLimitMicroEUR {
return fmt.Errorf("%w: daily cost limit reached", ErrCostLimitExceeded)
}
if limits.MonthlyLimitMicroEUR != nil && *limits.MonthlyLimitMicroEUR > 0 && monthlyOrg >= *limits.MonthlyLimitMicroEUR {
return fmt.Errorf("%w: monthly cost limit reached", ErrCostLimitExceeded)
}
return nil
}
// providerKeyBreakdown sums the user's event costs per (key fingerprint,
// billing scope) for events created at or after month, and returns at most the
// 20 most expensive groups in descending cost order.
//
// The returned slice is never nil. Rows that fail to scan are skipped.
func (s *PolicyService) providerKeyBreakdown(ctx context.Context, userID string, month time.Time) ([]ProviderKeySpend, error) {
rows, err := s.db.Query(ctx, `
SELECT provider_key_fingerprint, billing_scope,
COALESCE(SUM(cost_micro_eur), 0),
MAX(provider_id)
FROM ai_usage_events
WHERE user_id = $1::uuid AND created_at >= $2
GROUP BY provider_key_fingerprint, billing_scope
ORDER BY SUM(cost_micro_eur) DESC
LIMIT 20
`, userID, month)
if err != nil {
return []ProviderKeySpend{}, err
}
defer rows.Close()
out := make([]ProviderKeySpend, 0)
for rows.Next() {
var item ProviderKeySpend
var providerID string
if err := rows.Scan(&item.Fingerprint, &item.BillingScope, &item.CostMonthMicroEUR, &providerID); err != nil {
continue
}
item.CostMonthEUR = MicroEURToEUR(item.CostMonthMicroEUR)
// The label shows the provider ID plus at most the last 8 characters of
// the fingerprint.
suffix := item.Fingerprint
if len(suffix) > 8 {
suffix = suffix[len(suffix)-8:]
}
item.Label = providerID + " ···" + suffix
out = append(out, item)
}
return out, rows.Err()
}
// resolveUserID maps an external user identifier to the internal users.id,
// returned as text. It returns a "user not found" error when no row matches
// and passes any other query error through unchanged.
func (s *PolicyService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
var userID string
err := s.db.QueryRow(ctx, `
SELECT id::text FROM users WHERE external_id = $1
`, externalUserID).Scan(&userID)
if err != nil {
if err == pgx.ErrNoRows {
return "", fmt.Errorf("user not found")
}
return "", err
}
return userID, nil
}
// remaining returns how much of limit is left after used has been spent,
// floored at zero. It returns nil when there is no limit to measure against,
// i.e. limit is nil or not positive.
func remaining(limit *int64, used int64) *int64 {
	if limit == nil {
		return nil
	}
	ceiling := *limit
	if ceiling <= 0 {
		return nil
	}

	left := ceiling - used
	if left < 0 {
		left = 0
	}
	return &left
}
// UpsertOrgPolicy creates or updates the org-level cost policy.
//
// daily and monthly are limits in micro-EUR (nil is stored as NULL). A
// non-positive warnPct is replaced by 80. The existing org row is updated when
// one exists; otherwise a new row with a NULL scope_id and priority 0 is
// inserted.
//
// NOTE(review): the existence check and the write are separate statements
// outside a transaction, so two concurrent callers could both take the insert
// path — confirm whether a unique constraint guards against duplicate org rows.
func (s *PolicyService) UpsertOrgPolicy(ctx context.Context, daily, monthly *int64, warnPct int) error {
if warnPct <= 0 {
warnPct = 80
}
var exists bool
if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org')`).Scan(&exists); err != nil {
return err
}
if exists {
_, err := s.db.Exec(ctx, `
UPDATE ai_cost_policies SET
daily_limit_micro_eur = $1,
monthly_limit_micro_eur = $2,
warn_threshold_pct = $3,
updated_at = NOW()
WHERE scope_type = 'org'
`, daily, monthly, warnPct)
return err
}
_, err := s.db.Exec(ctx, `
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
VALUES ('org', NULL, $1, $2, $3, 0)
`, daily, monthly, warnPct)
return err
}
// UpsertScopePolicy creates or updates the cost policy of a single scope,
// identified by scopeType and scopeID (cast to uuid by the queries).
//
// daily and monthly are limits in micro-EUR (nil is stored as NULL). A
// non-positive warnPct is replaced by 80. New rows get priority 20 when
// scopeType is "user" and 10 for any other scope type; the priority of an
// existing row is left untouched.
//
// NOTE(review): the existence check and the write are separate statements
// outside a transaction — confirm whether a unique constraint on
// (scope_type, scope_id) guards against concurrent duplicate inserts.
func (s *PolicyService) UpsertScopePolicy(ctx context.Context, scopeType, scopeID string, daily, monthly *int64, warnPct int) error {
if warnPct <= 0 {
warnPct = 80
}
priority := 10
if scopeType == "user" {
priority = 20
}
var exists bool
err := s.db.QueryRow(ctx, `
SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid)
`, scopeType, scopeID).Scan(&exists)
if err != nil {
return err
}
if exists {
_, err = s.db.Exec(ctx, `
UPDATE ai_cost_policies SET
daily_limit_micro_eur = $1,
monthly_limit_micro_eur = $2,
warn_threshold_pct = $3,
updated_at = NOW()
WHERE scope_type = $4 AND scope_id = $5::uuid
`, daily, monthly, warnPct, scopeType, scopeID)
return err
}
_, err = s.db.Exec(ctx, `
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
VALUES ($1, $2::uuid, $3, $4, $5, $6)
`, scopeType, scopeID, daily, monthly, warnPct, priority)
return err
}
// DeleteScopePolicy removes the policy row matching scopeType and scopeID
// (cast to uuid by the query). The org policy cannot be deleted: scopeType
// "org" is rejected with an error before any query runs. Deleting a policy
// that does not exist is not reported as an error.
func (s *PolicyService) DeleteScopePolicy(ctx context.Context, scopeType, scopeID string) error {
if scopeType == "org" {
return fmt.Errorf("cannot delete org policy")
}
_, err := s.db.Exec(ctx, `
DELETE FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
`, scopeType, scopeID)
return err
}
// GetOrgPolicy returns the stored org-level policy. When no org policy exists,
// or when loading it fails, the defaults (no limits, warn threshold 80) are
// returned, accompanied by the load error if there was one.
func (s *PolicyService) GetOrgPolicy(ctx context.Context) (EffectiveLimits, error) {
	result := EffectiveLimits{WarnThresholdPct: 80}

	row, err := s.loadPolicy(ctx, "org", "")
	if err != nil || row == nil {
		return result, err
	}

	result.DailyLimitMicroEUR = row.daily
	result.MonthlyLimitMicroEUR = row.monthly
	result.WarnThresholdPct = row.warn
	return result, nil
}

View File

@ -0,0 +1,31 @@
package cost
import (
"testing"
)
func TestMergeMostRestrictive(t *testing.T) {
d10 := int64(10_000_000)
d5 := int64(5_000_000)
m100 := int64(100_000_000)
m50 := int64(50_000_000)
merged := mergeMostRestrictive([]policyCandidate{
{daily: &d10, monthly: &m100, warn: 80},
{daily: &d5, monthly: &m50, warn: 70},
})
if merged.daily == nil || *merged.daily != d5 {
t.Fatalf("expected daily 5M, got %v", merged.daily)
}
if merged.monthly == nil || *merged.monthly != m50 {
t.Fatalf("expected monthly 50M, got %v", merged.monthly)
}
if merged.warn != 70 {
t.Fatalf("expected warn 70, got %d", merged.warn)
}
}
// TestMicroEURToEUR checks the conversion of 1.5 million micro-EUR to 1.5 EUR.
func TestMicroEURToEUR(t *testing.T) {
	got := MicroEURToEUR(1_500_000)
	if got != 1.5 {
		t.Fatalf("expected 1.5, got %f", got)
	}
}

206
internal/ai/cost/pricing.go Normal file
View File

@ -0,0 +1,206 @@
package cost
import (
"context"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// ModelPrice holds per-million-token rates in micro-EUR.
//
// CachedInputMicroEURPerMTok and ReasoningMicroEURPerMTok are optional: a zero
// value means "not set", in which case ComputeCostMicroEUR bills cached input
// at half the input rate and reasoning tokens at the output rate.
type ModelPrice struct {
ModelID string
ProviderType string
InputMicroEURPerMTok int64
CachedInputMicroEURPerMTok int64
OutputMicroEURPerMTok int64
ReasoningMicroEURPerMTok int64
}
// Fallback rates, in micro-EUR per million tokens, used by ComputeCostMicroEUR
// when no price is known for a model. The original note describes the input
// rate as approximately the gpt-4o-mini input rate.
const fallbackInputMicroEURPerMTok int64 = 140000
const fallbackOutputMicroEURPerMTok int64 = 560000
// PricingStore reads and writes model prices in the ai_model_pricing table.
// Its methods tolerate a nil pool: lookups report "not found" and writes are
// no-ops.
type PricingStore struct {
db *pgxpool.Pool
}
// NewPricingStore returns a PricingStore backed by the given pool.
func NewPricingStore(db *pgxpool.Pool) *PricingStore {
	store := new(PricingStore)
	store.db = db
	return store
}
// LookupPrice returns the price currently in effect for modelID and whether
// one was found.
//
// The model ID is whitespace-trimmed. An exact match on model_id is tried
// first, taking the most recent row whose effective_from is not in the future.
// If there is no exact match, the lookup falls back to the longest stored
// model_id that is a prefix of modelID (for bedrock/azure style IDs that carry
// a suffix). The returned ModelPrice.ModelID is the ID of the matched pricing
// row, which in the prefix case differs from the requested one.
//
// It returns (ModelPrice{}, false) when modelID is empty, the store has no
// database, nothing matches, or a query fails; errors are not surfaced.
func (s *PricingStore) LookupPrice(ctx context.Context, modelID string) (ModelPrice, bool) {
	modelID = strings.TrimSpace(modelID)
	if modelID == "" || s.db == nil {
		return ModelPrice{}, false
	}

	var p ModelPrice
	var cached, reasoning *int64
	err := s.db.QueryRow(ctx, `
		SELECT model_id, provider_type,
		input_micro_eur_per_mtok,
		cached_input_micro_eur_per_mtok,
		output_micro_eur_per_mtok,
		reasoning_micro_eur_per_mtok
		FROM ai_model_pricing
		WHERE model_id = $1 AND effective_from <= CURRENT_DATE
		ORDER BY effective_from DESC
		LIMIT 1
	`, modelID).Scan(
		&p.ModelID, &p.ProviderType,
		&p.InputMicroEURPerMTok, &cached,
		&p.OutputMicroEURPerMTok, &reasoning,
	)
	// Only a clean "no exact match" triggers the prefix fallback; any other
	// failure is a database problem that a second query will not fix.
	if err == pgx.ErrNoRows {
		// prefix match for bedrock/azure model ids
		err = s.db.QueryRow(ctx, `
			SELECT model_id, provider_type,
			input_micro_eur_per_mtok,
			cached_input_micro_eur_per_mtok,
			output_micro_eur_per_mtok,
			reasoning_micro_eur_per_mtok
			FROM ai_model_pricing
			WHERE $1 LIKE model_id || '%' AND effective_from <= CURRENT_DATE
			ORDER BY LENGTH(model_id) DESC, effective_from DESC
			LIMIT 1
		`, modelID).Scan(
			&p.ModelID, &p.ProviderType,
			&p.InputMicroEURPerMTok, &cached,
			&p.OutputMicroEURPerMTok, &reasoning,
		)
	}
	if err != nil {
		return ModelPrice{}, false
	}

	// NULL optional rates stay zero, which ComputeCostMicroEUR reads as unset.
	if cached != nil {
		p.CachedInputMicroEURPerMTok = *cached
	}
	if reasoning != nil {
		p.ReasoningMicroEURPerMTok = *reasoning
	}
	return p, true
}
// ComputeCostMicroEUR prices a usage record in micro-EUR.
//
// With found == true the given price is applied and estimated is false. With
// found == false the fallback rates are applied and estimated is true; if the
// usage then has neither prompt nor completion tokens, the total token count
// is billed at the fallback input rate.
//
// Uncached input, cached input, non-reasoning completion and reasoning tokens
// are billed separately, each rounded down. An unset (zero) cached rate
// defaults to half the input rate and an unset reasoning rate to the output
// rate. If reasoning tokens exceed completion tokens, the full completion
// count is billed at the output rate in addition to the reasoning tokens.
// Any usage with a positive total costs at least 1 micro-EUR.
func ComputeCostMicroEUR(usage UsageDetail, price ModelPrice, found bool) (microEUR int64, estimated bool) {
	const perMTok = 1_000_000

	if !found {
		if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
			// Only a total is known: bill it flat at the fallback input rate.
			flat := int64(usage.TotalTokens) * fallbackInputMicroEURPerMTok / perMTok
			if flat == 0 && usage.TotalTokens > 0 {
				flat = 1
			}
			return flat, true
		}
		estimated = true
		price = ModelPrice{
			InputMicroEURPerMTok:  fallbackInputMicroEURPerMTok,
			OutputMicroEURPerMTok: fallbackOutputMicroEURPerMTok,
		}
	}

	cachedRate, reasoningRate := price.CachedInputMicroEURPerMTok, price.ReasoningMicroEURPerMTok
	if cachedRate == 0 {
		cachedRate = price.InputMicroEURPerMTok / 2
	}
	if reasoningRate == 0 {
		reasoningRate = price.OutputMicroEURPerMTok
	}

	plainCompletion := usage.CompletionTokens - usage.ReasoningTokens
	if plainCompletion < 0 {
		plainCompletion = usage.CompletionTokens
	}

	inputCost := int64(usage.UncachedInputTokens()) * price.InputMicroEURPerMTok / perMTok
	cachedCost := int64(usage.CachedInputTokens) * cachedRate / perMTok
	outputCost := int64(plainCompletion) * price.OutputMicroEURPerMTok / perMTok
	reasoningCost := int64(usage.ReasoningTokens) * reasoningRate / perMTok

	microEUR = inputCost + cachedCost + outputCost + reasoningCost
	if microEUR == 0 && usage.TotalTokens > 0 {
		microEUR = 1
	}
	return microEUR, estimated
}
// UpsertModelPrice stores or updates pricing for a model (effective today).
//
// The row is keyed by (model_id, effective_from), with effective_from set to
// the current UTC day, so a second call on the same day overwrites the first.
// Non-positive cached and reasoning rates are stored as NULL. The source
// column is always set to 'manual'. It is a no-op returning nil when the store
// has no database.
func (s *PricingStore) UpsertModelPrice(ctx context.Context, p ModelPrice) error {
if s.db == nil {
return nil
}
today := time.Now().UTC().Truncate(24 * time.Hour)
// Optional rates: leave the pointers nil (SQL NULL) unless a positive rate is given.
var cached, reasoning *int64
if p.CachedInputMicroEURPerMTok > 0 {
v := p.CachedInputMicroEURPerMTok
cached = &v
}
if p.ReasoningMicroEURPerMTok > 0 {
v := p.ReasoningMicroEURPerMTok
reasoning = &v
}
_, err := s.db.Exec(ctx, `
INSERT INTO ai_model_pricing (
model_id, provider_type,
input_micro_eur_per_mtok, cached_input_micro_eur_per_mtok,
output_micro_eur_per_mtok, reasoning_micro_eur_per_mtok,
effective_from, source
) VALUES ($1, $2, $3, $4, $5, $6, $7, 'manual')
ON CONFLICT (model_id, effective_from) DO UPDATE SET
provider_type = EXCLUDED.provider_type,
input_micro_eur_per_mtok = EXCLUDED.input_micro_eur_per_mtok,
cached_input_micro_eur_per_mtok = EXCLUDED.cached_input_micro_eur_per_mtok,
output_micro_eur_per_mtok = EXCLUDED.output_micro_eur_per_mtok,
reasoning_micro_eur_per_mtok = EXCLUDED.reasoning_micro_eur_per_mtok,
source = EXCLUDED.source
`, p.ModelID, p.ProviderType, p.InputMicroEURPerMTok, cached,
p.OutputMicroEURPerMTok, reasoning, today)
return err
}
// ListPrices returns the price currently in effect for every model: per
// model_id, the row with the latest effective_from that is not in the future,
// ordered by model_id. NULL cached and reasoning rates are returned as zero.
//
// It returns (nil, nil) when the store has no database, and a nil slice when
// there are no rows. The first query or scan error aborts the listing.
func (s *PricingStore) ListPrices(ctx context.Context) ([]ModelPrice, error) {
if s.db == nil {
return nil, nil
}
rows, err := s.db.Query(ctx, `
SELECT DISTINCT ON (model_id)
model_id, provider_type,
input_micro_eur_per_mtok,
cached_input_micro_eur_per_mtok,
output_micro_eur_per_mtok,
reasoning_micro_eur_per_mtok
FROM ai_model_pricing
WHERE effective_from <= CURRENT_DATE
ORDER BY model_id, effective_from DESC
`)
if err != nil {
return nil, err
}
defer rows.Close()
var out []ModelPrice
for rows.Next() {
var p ModelPrice
// Nullable columns are scanned through pointers.
var cached, reasoning *int64
if err := rows.Scan(&p.ModelID, &p.ProviderType,
&p.InputMicroEURPerMTok, &cached,
&p.OutputMicroEURPerMTok, &reasoning); err != nil {
return nil, err
}
if cached != nil {
p.CachedInputMicroEURPerMTok = *cached
}
if reasoning != nil {
p.ReasoningMicroEURPerMTok = *reasoning
}
out = append(out, p)
}
return out, rows.Err()
}
// MicroEURToEUR expresses an amount given in micro-EUR (millionths of a euro)
// as a floating-point EUR value, for use in API responses.
func MicroEURToEUR(micro int64) float64 {
	const microPerEUR = 1_000_000
	eur := float64(micro) / microPerEUR
	return eur
}

18
internal/ai/cost/usage.go Normal file
View File

@ -0,0 +1,18 @@
package cost
// UsageDetail holds token counts from a provider response.
type UsageDetail struct {
// PromptTokens is the input token count, including any cached tokens.
PromptTokens int
CompletionTokens int
// CachedInputTokens is the cached share of PromptTokens.
CachedInputTokens int
ReasoningTokens int
TotalTokens int
}
// UncachedInputTokens returns the prompt tokens that were not served from
// cache, i.e. PromptTokens minus CachedInputTokens. When the cached count is
// larger than the prompt count, the prompt count is returned unchanged.
func (u UsageDetail) UncachedInputTokens() int {
	if u.CachedInputTokens > u.PromptTokens {
		return u.PromptTokens
	}
	return u.PromptTokens - u.CachedInputTokens
}

View File

@ -13,6 +13,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/ai/cost"
"github.com/ultisuite/ulti-backend/internal/llm" "github.com/ultisuite/ulti-backend/internal/llm"
) )
@ -40,12 +41,6 @@ type chatCompletionRequest struct {
Tools []any `json:"tools,omitempty"` Tools []any `json:"tools,omitempty"`
} }
type usagePayload struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type chatCompletionResponse struct { type chatCompletionResponse struct {
ID string `json:"id"` ID string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
@ -57,7 +52,17 @@ type chatCompletionResponse struct {
FinishReason string `json:"finish_reason"` FinishReason string `json:"finish_reason"`
Delta *llm.ChatMessage `json:"delta,omitempty"` Delta *llm.ChatMessage `json:"delta,omitempty"`
} `json:"choices"` } `json:"choices"`
Usage *usagePayload `json:"usage,omitempty"` Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *struct {
CachedTokens int `json:"cached_tokens"`
} `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *struct {
ReasoningTokens int `json:"reasoning_tokens"`
} `json:"completion_tokens_details,omitempty"`
} `json:"usage,omitempty"`
Error *struct { Error *struct {
Message string `json:"message"` Message string `json:"message"`
} `json:"error,omitempty"` } `json:"error,omitempty"`
@ -116,12 +121,6 @@ func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Setti
} }
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error { func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
if strings.TrimSpace(quotaExternalUserID) != "" {
if err := g.quota.AssertAvailable(ctx, quotaExternalUserID); err != nil {
return err
}
}
var modelProbe struct { var modelProbe struct {
Model string `json:"model"` Model string `json:"model"`
} }
@ -151,6 +150,14 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID
return err return err
} }
billingScope := ResolveBillingScope(ctx, g.db, quotaExternalUserID, provider, useOrgSettings)
if strings.TrimSpace(quotaExternalUserID) != "" {
if err := g.quota.AssertAvailable(ctx, quotaExternalUserID, provider, useOrgSettings); err != nil {
return err
}
}
upstreamBody, err := repairChatCompletionBody(body) upstreamBody, err := repairChatCompletionBody(body)
if err != nil { if err != nil {
return err return err
@ -182,7 +189,7 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID
defer resp.Body.Close() defer resp.Body.Close()
if stream { if stream {
return g.proxyStream(ctx, quotaExternalUserID, w, resp) return g.proxyStream(ctx, quotaExternalUserID, model, provider, billingScope, w, resp)
} }
payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
if err != nil { if err != nil {
@ -195,13 +202,21 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID
return nil return nil
} }
if strings.TrimSpace(quotaExternalUserID) != "" { if strings.TrimSpace(quotaExternalUserID) != "" {
tokens := extractUsageTokens(payload) usage := cost.ParseUsage(payload)
_ = g.quota.Record(ctx, quotaExternalUserID, tokens) _ = g.quota.RecordUsage(ctx, cost.RecordInput{
ExternalUserID: quotaExternalUserID,
Feature: "gateway",
ModelID: model,
ProviderID: provider.ID,
BillingScope: billingScope,
ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
Usage: usage,
})
} }
return nil return nil
} }
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w http.ResponseWriter, resp *http.Response) error { func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID, model string, provider llm.Provider, billingScope string, w http.ResponseWriter, resp *http.Response) error {
rc := http.NewResponseController(w) rc := http.NewResponseController(w)
w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Cache-Control", "no-cache")
@ -210,7 +225,7 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w
w.WriteHeader(resp.StatusCode) w.WriteHeader(resp.StatusCode)
reader := bufio.NewReader(resp.Body) reader := bufio.NewReader(resp.Body)
var totalTokens int64 var usage cost.UsageDetail
for { for {
line, err := reader.ReadString('\n') line, err := reader.ReadString('\n')
if len(line) > 0 { if len(line) > 0 {
@ -219,7 +234,8 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w
return fmt.Errorf("streaming not supported: %w", err) return fmt.Errorf("streaming not supported: %w", err)
} }
if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") { if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
totalTokens += extractStreamUsageTokens([]byte(strings.TrimPrefix(strings.TrimSpace(line), "data: "))) chunk := []byte(strings.TrimPrefix(strings.TrimSpace(line), "data: "))
usage = cost.MergeStreamUsage(usage, chunk)
} }
} }
if err != nil { if err != nil {
@ -235,10 +251,18 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w
} }
} }
if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" { if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" {
if totalTokens == 0 { if usage.TotalTokens == 0 {
totalTokens = 1 usage.TotalTokens = 1
} }
_ = g.quota.Record(ctx, quotaExternalUserID, totalTokens) _ = g.quota.RecordUsage(ctx, cost.RecordInput{
ExternalUserID: quotaExternalUserID,
Feature: "gateway",
ModelID: model,
ProviderID: provider.ID,
BillingScope: billingScope,
ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
Usage: usage,
})
} }
return nil return nil
} }
@ -262,31 +286,6 @@ func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider,
return provider, resolvedModel, nil return provider, resolvedModel, nil
} }
func extractUsageTokens(payload []byte) int64 {
var parsed chatCompletionResponse
if err := json.Unmarshal(payload, &parsed); err != nil {
return 1
}
if parsed.Usage != nil && parsed.Usage.TotalTokens > 0 {
return int64(parsed.Usage.TotalTokens)
}
if parsed.Usage != nil && parsed.Usage.CompletionTokens > 0 {
return int64(parsed.Usage.CompletionTokens)
}
return 1
}
func extractStreamUsageTokens(payload []byte) int64 {
var parsed chatCompletionResponse
if err := json.Unmarshal(payload, &parsed); err != nil {
return 0
}
if parsed.Usage != nil && parsed.Usage.TotalTokens > 0 {
return int64(parsed.Usage.TotalTokens)
}
return 0
}
func NowUnix() int64 { func NowUnix() int64 {
return time.Now().Unix() return time.Now().Unix()
} }

View File

@ -1,16 +1,22 @@
package ai package ai
import "testing" import (
"testing"
func TestExtractUsageTokens(t *testing.T) { "github.com/ultisuite/ulti-backend/internal/ai/cost"
)
func TestParseUsageViaCost(t *testing.T) {
payload := []byte(`{"usage":{"total_tokens":42,"completion_tokens":10}}`) payload := []byte(`{"usage":{"total_tokens":42,"completion_tokens":10}}`)
if got := extractUsageTokens(payload); got != 42 { u := cost.ParseUsage(payload)
t.Fatalf("extractUsageTokens() = %d, want 42", got) if u.TotalTokens != 42 {
t.Fatalf("ParseUsage() = %d, want 42", u.TotalTokens)
} }
} }
func TestExtractUsageTokensFallback(t *testing.T) { func TestParseUsageFallback(t *testing.T) {
if got := extractUsageTokens([]byte(`{"choices":[]}`)); got != 1 { u := cost.ParseUsage([]byte(`{"choices":[]}`))
t.Fatalf("expected fallback token count 1, got %d", got) if u.TotalTokens != 1 {
t.Fatalf("expected fallback token count 1, got %d", u.TotalTokens)
} }
} }

34
internal/ai/metering.go Normal file
View File

@ -0,0 +1,34 @@
package ai
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/ai/cost"
"github.com/ultisuite/ulti-backend/internal/llm"
)
// RecordFeatureUsage meters one LLM call made by a feature other than the
// gateway.
//
// It is a no-op when db is nil or externalUserID is empty. The billing scope
// is resolved with useOrgSettings=false, and the provider key is recorded only
// as a fingerprint. Metering is best effort: the error returned by
// RecordUsage is intentionally discarded so a metering failure never affects
// the calling feature.
func RecordFeatureUsage(ctx context.Context, db *pgxpool.Pool, externalUserID, feature, modelID string, provider llm.Provider, usage llm.UsageDetail) {
	if db == nil {
		return
	}
	if externalUserID == "" {
		return
	}

	svc := NewQuotaService(db)

	// Copy the llm-level counters field by field into the cost package's type.
	detail := cost.UsageDetail{
		PromptTokens:      usage.PromptTokens,
		CompletionTokens:  usage.CompletionTokens,
		CachedInputTokens: usage.CachedInputTokens,
		ReasoningTokens:   usage.ReasoningTokens,
		TotalTokens:       usage.TotalTokens,
	}

	input := cost.RecordInput{
		ExternalUserID:         externalUserID,
		Feature:                feature,
		ModelID:                modelID,
		ProviderID:             provider.ID,
		BillingScope:           ResolveBillingScope(ctx, db, externalUserID, provider, false),
		ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
		Usage:                  detail,
	}
	_ = svc.RecordUsage(ctx, input)
}

View File

@ -229,7 +229,7 @@ func ResolveDefaultModel(ctx context.Context, db *pgxpool.Pool, policy Assistant
} }
func LoadQuotaLimits(ctx context.Context, db *pgxpool.Pool) (QuotaLimits, error) { func LoadQuotaLimits(ctx context.Context, db *pgxpool.Pool) (QuotaLimits, error) {
defaults := QuotaLimits{RequestsPerDay: 100, TokensPerMonth: 500_000} defaults := QuotaLimits{RequestsPerDay: 75, TokensPerMonth: 2_000_000}
if db == nil { if db == nil {
return defaults, nil return defaults, nil
} }

View File

@ -2,123 +2,95 @@ package ai
import ( import (
"context" "context"
"errors" "strings"
"fmt"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/ai/cost"
"github.com/ultisuite/ulti-backend/internal/llm"
) )
var ErrQuotaExceeded = errors.New("llm quota exceeded") // QuotaService wraps cost policy and metering for backward compatibility.
type QuotaService struct { type QuotaService struct {
db *pgxpool.Pool db *pgxpool.Pool
policy *cost.PolicyService
meter *cost.Meter
} }
func NewQuotaService(db *pgxpool.Pool) *QuotaService { func NewQuotaService(db *pgxpool.Pool) *QuotaService {
return &QuotaService{db: db} return &QuotaService{
db: db,
policy: cost.NewPolicyService(db),
meter: cost.NewMeter(db),
}
} }
func (s *QuotaService) Check(ctx context.Context, externalUserID string) (QuotaStatus, error) { func (s *QuotaService) Check(ctx context.Context, externalUserID string) (SpendStatus, error) {
limits, err := LoadQuotaLimits(ctx, s.db) orgBilling, _ := s.usesOrgBilling(ctx, externalUserID)
status, err := s.policy.GetStatus(ctx, externalUserID, orgBilling)
if err != nil { if err != nil {
return QuotaStatus{}, err return SpendStatus{}, err
}
userID, err := s.resolveUserID(ctx, externalUserID)
if err != nil {
return QuotaStatus{}, err
}
today := time.Now().UTC().Truncate(24 * time.Hour)
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
var requestsToday int
var tokensMonth int64
_ = s.db.QueryRow(ctx, `
SELECT COALESCE(requests, 0) FROM ai_usage_daily
WHERE user_id = $1 AND usage_date = $2
`, userID, today).Scan(&requestsToday)
_ = s.db.QueryRow(ctx, `
SELECT COALESCE(tokens, 0) FROM ai_usage_monthly
WHERE user_id = $1 AND usage_month = $2
`, userID, month).Scan(&tokensMonth)
status := QuotaStatus{
RequestsUsedToday: requestsToday,
RequestsLimit: limits.RequestsPerDay,
TokensUsedMonth: tokensMonth,
TokensLimit: limits.TokensPerMonth,
}
if limits.RequestsPerDay > 0 {
status.RequestsRemaining = limits.RequestsPerDay - requestsToday
if status.RequestsRemaining < 0 {
status.RequestsRemaining = 0
}
}
if limits.TokensPerMonth > 0 {
status.TokensRemaining = limits.TokensPerMonth - tokensMonth
if status.TokensRemaining < 0 {
status.TokensRemaining = 0
}
} }
return status, nil return status, nil
} }
func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string) error { func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string, provider llm.Provider, useOrgSettings bool) error {
status, err := s.Check(ctx, externalUserID) scope := ResolveBillingScope(ctx, s.db, externalUserID, provider, useOrgSettings)
if err != nil { return s.policy.AssertAvailable(ctx, externalUserID, scope)
return err
}
if status.RequestsLimit > 0 && status.RequestsUsedToday >= status.RequestsLimit {
return fmt.Errorf("%w: daily request limit reached", ErrQuotaExceeded)
}
if status.TokensLimit > 0 && status.TokensUsedMonth >= status.TokensLimit {
return fmt.Errorf("%w: monthly token limit reached", ErrQuotaExceeded)
}
return nil
} }
func (s *QuotaService) Record(ctx context.Context, externalUserID string, tokens int64) error { func (s *QuotaService) RecordUsage(ctx context.Context, in cost.RecordInput) error {
if tokens < 0 { return s.meter.RecordUsage(ctx, in)
tokens = 0
}
userID, err := s.resolveUserID(ctx, externalUserID)
if err != nil {
return err
}
today := time.Now().UTC().Truncate(24 * time.Hour)
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
_, err = s.db.Exec(ctx, `
INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens)
VALUES ($1, $2, 1, $3)
ON CONFLICT (user_id, usage_date) DO UPDATE SET
requests = ai_usage_daily.requests + 1,
tokens = ai_usage_daily.tokens + EXCLUDED.tokens
`, userID, today, tokens)
if err != nil {
return err
}
_, err = s.db.Exec(ctx, `
INSERT INTO ai_usage_monthly (user_id, usage_month, tokens)
VALUES ($1, $2, $3)
ON CONFLICT (user_id, usage_month) DO UPDATE SET
tokens = ai_usage_monthly.tokens + EXCLUDED.tokens
`, userID, month, tokens)
return err
} }
func (s *QuotaService) resolveUserID(ctx context.Context, externalUserID string) (string, error) { func (s *QuotaService) usesOrgBilling(ctx context.Context, externalUserID string) (bool, error) {
var userID string settings, err := LoadEffectiveLLMSettings(ctx, s.db, externalUserID)
err := s.db.QueryRow(ctx, `
SELECT id::text FROM users WHERE external_id = $1
`, externalUserID).Scan(&userID)
if err != nil { if err != nil {
if err == pgx.ErrNoRows { return true, err
return "", fmt.Errorf("user not found")
} }
return "", err org, err := loadOrgLLMPolicy(ctx, s.db)
if err != nil {
return true, err
} }
return userID, nil if org.EnforceOrgProviders || len(settings.Providers) == 0 {
return true, nil
} }
user, err := loadUserLLMSettings(ctx, s.db, externalUserID)
if err != nil {
return true, err
}
return len(user.Providers) == 0, nil
}
// ResolveBillingScope reports whether a call through provider is billed to the
// organisation (cost.BillingScopeOrg) or to the user's own key
// (cost.BillingScopeUser).
//
// The organisation is billed when any of the following holds:
//   - useOrgSettings is true;
//   - the org LLM policy cannot be loaded or enforces org providers;
//   - the user's LLM settings cannot be loaded or list no providers;
//   - the provider's non-empty API key equals the key of the org provider
//     with the same ID;
//   - no user provider with the same ID carries a non-empty API key.
//
// Only in the remaining case (a user provider with the same ID and a non-empty
// key exists) is the user billed.
func ResolveBillingScope(ctx context.Context, db *pgxpool.Pool, externalUserID string, provider llm.Provider, useOrgSettings bool) string {
	if useOrgSettings {
		return cost.BillingScopeOrg
	}

	orgPolicy, err := loadOrgLLMPolicy(ctx, db)
	if err != nil {
		return cost.BillingScopeOrg
	}
	if orgPolicy.EnforceOrgProviders {
		return cost.BillingScopeOrg
	}

	userSettings, err := loadUserLLMSettings(ctx, db, externalUserID)
	if err != nil {
		return cost.BillingScopeOrg
	}
	if len(userSettings.Providers) == 0 {
		return cost.BillingScopeOrg
	}

	// An empty key can never identify the org's key, so the org scan is
	// skipped entirely in that case.
	key := strings.TrimSpace(provider.APIKey)
	if key != "" {
		for _, candidate := range orgPolicy.Providers {
			if candidate.ID == provider.ID && strings.TrimSpace(candidate.APIKey) == key {
				return cost.BillingScopeOrg
			}
		}
	}

	for _, candidate := range userSettings.Providers {
		if candidate.ID != provider.ID {
			continue
		}
		if strings.TrimSpace(candidate.APIKey) != "" {
			return cost.BillingScopeUser
		}
	}
	return cost.BillingScopeOrg
}
// SpendStatus is the user-facing quota/spend response.
type SpendStatus = cost.SpendStatus
// ErrQuotaExceeded aliases cost limit error for backward compatibility.
var ErrQuotaExceeded = cost.ErrCostLimitExceeded

View File

@ -2,20 +2,16 @@ package ai
import "testing" import "testing"
func TestQuotaStatusRemaining(t *testing.T) { func TestSpendStatusCostRemaining(t *testing.T) {
status := QuotaStatus{ limit := int64(10_000_000)
RequestsUsedToday: 40, status := SpendStatus{
RequestsLimit: 100, CostUsedTodayMicroEUR: 4_000_000,
TokensUsedMonth: 100_000, CostLimitTodayMicroEUR: &limit,
TokensLimit: 500_000, Currency: "EUR",
} }
status.RequestsRemaining = status.RequestsLimit - status.RequestsUsedToday remaining := *status.CostLimitTodayMicroEUR - status.CostUsedTodayMicroEUR
status.TokensRemaining = status.TokensLimit - status.TokensUsedMonth if remaining != 6_000_000 {
if status.RequestsRemaining != 60 { t.Fatalf("cost remaining = %d", remaining)
t.Fatalf("requests remaining = %d", status.RequestsRemaining)
}
if status.TokensRemaining != 400_000 {
t.Fatalf("tokens remaining = %d", status.TokensRemaining)
} }
} }

View File

@ -21,19 +21,14 @@ type AssistantPolicy struct {
} }
type QuotaLimits struct { type QuotaLimits struct {
DailyLimitMicroEUR *int64 `json:"llm_daily_cost_limit_micro_eur,omitempty"`
MonthlyLimitMicroEUR *int64 `json:"llm_monthly_cost_limit_micro_eur,omitempty"`
WarnThresholdPct int `json:"llm_cost_warn_threshold_pct,omitempty"`
// Deprecated legacy fields
RequestsPerDay int `json:"llm_requests_per_day"` RequestsPerDay int `json:"llm_requests_per_day"`
TokensPerMonth int64 `json:"llm_tokens_per_month"` TokensPerMonth int64 `json:"llm_tokens_per_month"`
} }
type QuotaStatus struct {
RequestsUsedToday int `json:"requests_used_today"`
RequestsLimit int `json:"requests_limit"`
TokensUsedMonth int64 `json:"tokens_used_month"`
TokensLimit int64 `json:"tokens_limit"`
RequestsRemaining int `json:"requests_remaining"`
TokensRemaining int64 `json:"tokens_remaining"`
}
type ChatMessage struct { type ChatMessage struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content,omitempty"` Content string `json:"content,omitempty"`

View File

@ -0,0 +1,384 @@
package admin
import (
	"context"
	"fmt"
	"math"
	"time"

	"github.com/ultisuite/ulti-backend/internal/ai/cost"
)
// AIUsageSummary is the JSON payload built by Service.GetAIUsageSummary: cost
// totals for today and the current month, a daily series, the top users and
// models, and the organisation cost policy.
type AIUsageSummary struct {
// Cost totals in micro-EUR (1 EUR = 1,000,000 micro-EUR).
CostTodayMicroEUR int64 `json:"cost_today_micro_eur"`
CostMonthMicroEUR int64 `json:"cost_month_micro_eur"`
// The same totals converted with cost.MicroEURToEUR.
CostTodayEUR float64 `json:"cost_today_eur"`
CostMonthEUR float64 `json:"cost_month_eur"`
// Currency is always set to "EUR" by GetAIUsageSummary.
Currency string `json:"currency"`
// The slices are initialised empty, so they serialise as [] rather than null.
DailySeries []AIUsageDayPoint `json:"daily_series"`
TopUsers []AIUsageTopUser `json:"top_users"`
TopModels []AIUsageTopModel `json:"top_models"`
// OrgPolicy is whatever PolicyService.GetOrgPolicy returned.
OrgPolicy cost.EffectiveLimits `json:"org_policy"`
}
// AIUsageDayPoint is one day of the usage series, read from a row of
// ai_org_usage_daily.
type AIUsageDayPoint struct {
// Date is the usage date formatted as YYYY-MM-DD.
Date string `json:"date"`
CostOrgMicroEUR int64 `json:"cost_org_micro_eur"`
CostUserMicroEUR int64 `json:"cost_user_micro_eur"`
// CostOrgEUR is CostOrgMicroEUR converted to EUR; there is no EUR
// counterpart for the user-billed cost.
CostOrgEUR float64 `json:"cost_org_eur"`
Requests int `json:"requests"`
}
// AIUsageTopUser is one entry of the top-users ranking: a user together with
// the summed cost of their org-billed usage events.
type AIUsageTopUser struct {
UserID string `json:"user_id"`
// Email and DisplayName are empty strings when NULL in the users table.
Email string `json:"email"`
DisplayName string `json:"display_name"`
CostOrgMicroEUR int64 `json:"cost_org_micro_eur"`
// CostOrgEUR is CostOrgMicroEUR converted to EUR.
CostOrgEUR float64 `json:"cost_org_eur"`
}
// AIUsageTopModel is one entry of the top-models ranking: a model together
// with the summed cost and the number of its org-billed usage events.
type AIUsageTopModel struct {
ModelID string `json:"model_id"`
CostMicroEUR int64 `json:"cost_micro_eur"`
// CostEUR is CostMicroEUR converted to EUR.
CostEUR float64 `json:"cost_eur"`
// RequestCount is the number of usage events counted for the model.
RequestCount int `json:"request_count"`
}
// AIUserUsageDetail is the JSON payload built by Service.GetUserAIUsage: the
// user's identity, a usage summary and one page of the user's usage events.
type AIUserUsageDetail struct {
UserID string `json:"user_id"`
Email string `json:"email"`
DisplayName string `json:"display_name"`
// Summary is filled from GetAIUsageSummary(ctx, "org"), i.e. it is the
// organisation-wide summary rather than a per-user one.
Summary AIUsageSummary `json:"summary"`
// Events holds the requested page, newest first.
Events []AIUsageEventItem `json:"events"`
// EventsTotal is the count of all usage events of the user, not just the page.
EventsTotal int `json:"events_total"`
}
// AIUsageEventItem is a single row of ai_usage_events as exposed by the admin
// API.
type AIUsageEventItem struct {
ID string `json:"id"`
// CreatedAt is the event timestamp in UTC, formatted as RFC 3339.
CreatedAt string `json:"created_at"`
Feature string `json:"feature"`
ModelID string `json:"model_id"`
ProviderID string `json:"provider_id"`
BillingScope string `json:"billing_scope"`
CostMicroEUR int64 `json:"cost_micro_eur"`
// CostEUR is CostMicroEUR converted to EUR.
CostEUR float64 `json:"cost_eur"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
CachedInputTokens int `json:"cached_input_tokens"`
// Estimated mirrors the "estimated" column.
// NOTE(review): presumably marks costs computed from fallback usage or
// pricing — confirm against the metering code.
Estimated bool `json:"estimated"`
}
// AIPricingEntry is the admin API representation of one model price. It is
// used both in responses (ListAIPricing) and in PUT request bodies
// (putAIPricingRequest).
// NOTE(review): "PerMTok" presumably means per million tokens — confirm.
type AIPricingEntry struct {
ModelID string `json:"model_id"`
ProviderType string `json:"provider_type"`
// Prices in micro-EUR; these are the only price fields PutAIPricing stores.
InputMicroEURPerMTok int64 `json:"input_micro_eur_per_mtok"`
CachedInputMicroEURPerMTok int64 `json:"cached_input_micro_eur_per_mtok,omitempty"`
OutputMicroEURPerMTok int64 `json:"output_micro_eur_per_mtok"`
ReasoningMicroEURPerMTok int64 `json:"reasoning_micro_eur_per_mtok,omitempty"`
// EUR conversions of the input/output prices, filled by ListAIPricing and
// not read by PutAIPricing.
InputEURPerMTok float64 `json:"input_eur_per_mtok"`
OutputEURPerMTok float64 `json:"output_eur_per_mtok"`
}
// AICostPolicyEntry is one row of ai_cost_policies as returned by
// ListAICostPolicies.
type AICostPolicyEntry struct {
ScopeType string `json:"scope_type"`
// ScopeID is nil when the row has no scope_id.
ScopeID *string `json:"scope_id,omitempty"`
// ScopeName is the group name for group policies, the user's email for
// user policies, and "Organisation" when neither join matches.
ScopeName string `json:"scope_name,omitempty"`
// Limits in micro-EUR; nil when the column is NULL.
DailyLimitMicroEUR *int64 `json:"daily_limit_micro_eur,omitempty"`
MonthlyLimitMicroEUR *int64 `json:"monthly_limit_micro_eur,omitempty"`
// EUR conversions, set only when the matching micro-EUR limit is non-nil.
DailyLimitEUR *float64 `json:"daily_limit_eur,omitempty"`
MonthlyLimitEUR *float64 `json:"monthly_limit_eur,omitempty"`
WarnThresholdPct int `json:"warn_threshold_pct"`
}
// putAIPricingRequest is the JSON body accepted by the PutAIPricing handler;
// every listed price is upserted by Service.PutAIPricing.
type putAIPricingRequest struct {
Prices []AIPricingEntry `json:"prices"`
}
// putAICostPolicyRequest is the JSON body accepted by the user and group
// AI cost policy handlers. Each limit can be given in EUR or in micro-EUR;
// resolveCostLimits prefers the micro-EUR field when both are present.
type putAICostPolicyRequest struct {
// EUR limits; only values greater than zero are applied.
DailyLimitEUR *float64 `json:"daily_limit_eur"`
MonthlyLimitEUR *float64 `json:"monthly_limit_eur"`
WarnThresholdPct int `json:"warn_threshold_pct"`
// Micro-EUR limits; used as-is whenever they are set.
DailyLimitMicroEUR *int64 `json:"daily_limit_micro_eur,omitempty"`
MonthlyLimitMicroEUR *int64 `json:"monthly_limit_micro_eur,omitempty"`
}
// GetAIUsageSummary assembles the AI usage dashboard data.
//
// billingScope selects which daily column feeds CostTodayMicroEUR: "user"
// picks the user-billed cost, any other value the org-billed cost. The month
// total always comes from the org column of ai_org_usage_monthly, and the
// top-user/top-model rankings always cover org-billed events since the start
// of the current UTC month. The daily series spans the last 30 UTC days
// including today.
//
// Every lookup is best effort: a failed policy load, query or row scan simply
// leaves the corresponding part of the summary at its zero/empty value, and
// the returned error is always nil.
func (s *Service) GetAIUsageSummary(ctx context.Context, billingScope string) (AIUsageSummary, error) {
	dayStart := time.Now().UTC().Truncate(24 * time.Hour)
	monthStart := time.Date(dayStart.Year(), dayStart.Month(), 1, 0, 0, 0, 0, time.UTC)

	summary := AIUsageSummary{
		Currency:    "EUR",
		DailySeries: []AIUsageDayPoint{},
		TopUsers:    []AIUsageTopUser{},
		TopModels:   []AIUsageTopModel{},
	}
	summary.OrgPolicy, _ = cost.NewPolicyService(s.db).GetOrgPolicy(ctx)

	// Totals: a missing row (or any other scan failure) leaves them at zero.
	var orgToday, userToday int64
	_ = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0)
FROM ai_org_usage_daily WHERE usage_date = $1
`, dayStart).Scan(&orgToday, &userToday)

	var orgMonth int64
	_ = s.db.QueryRow(ctx, `
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_org_usage_monthly WHERE usage_month = $1
`, monthStart).Scan(&orgMonth)

	summary.CostTodayMicroEUR = orgToday
	if billingScope == "user" {
		summary.CostTodayMicroEUR = userToday
	}
	summary.CostMonthMicroEUR = orgMonth
	summary.CostTodayEUR = cost.MicroEURToEUR(summary.CostTodayMicroEUR)
	summary.CostMonthEUR = cost.MicroEURToEUR(summary.CostMonthMicroEUR)

	// Daily series for the trailing 30 days.
	if seriesRows, err := s.db.Query(ctx, `
SELECT usage_date, cost_micro_eur_org, cost_micro_eur_user, requests
FROM ai_org_usage_daily
WHERE usage_date >= $1
ORDER BY usage_date ASC
`, dayStart.AddDate(0, 0, -29)); err == nil {
		defer seriesRows.Close()
		for seriesRows.Next() {
			var (
				point AIUsageDayPoint
				day   time.Time
			)
			if scanErr := seriesRows.Scan(&day, &point.CostOrgMicroEUR, &point.CostUserMicroEUR, &point.Requests); scanErr == nil {
				point.Date = day.Format("2006-01-02")
				point.CostOrgEUR = cost.MicroEURToEUR(point.CostOrgMicroEUR)
				summary.DailySeries = append(summary.DailySeries, point)
			}
		}
	}

	// Ten users with the highest org-billed cost this month.
	if userRows, err := s.db.Query(ctx, `
SELECT u.id::text, COALESCE(u.email, ''), COALESCE(u.display_name, ''),
COALESCE(SUM(e.cost_micro_eur), 0)
FROM ai_usage_events e
JOIN users u ON u.id = e.user_id
WHERE e.created_at >= $1 AND e.billing_scope = 'org'
GROUP BY u.id, u.email, u.display_name
ORDER BY SUM(e.cost_micro_eur) DESC
LIMIT 10
`, monthStart); err == nil {
		defer userRows.Close()
		for userRows.Next() {
			var entry AIUsageTopUser
			if scanErr := userRows.Scan(&entry.UserID, &entry.Email, &entry.DisplayName, &entry.CostOrgMicroEUR); scanErr == nil {
				entry.CostOrgEUR = cost.MicroEURToEUR(entry.CostOrgMicroEUR)
				summary.TopUsers = append(summary.TopUsers, entry)
			}
		}
	}

	// Ten models with the highest org-billed cost this month.
	if modelRows, err := s.db.Query(ctx, `
SELECT model_id, COALESCE(SUM(cost_micro_eur), 0), COUNT(*)
FROM ai_usage_events
WHERE created_at >= $1 AND billing_scope = 'org'
GROUP BY model_id
ORDER BY SUM(cost_micro_eur) DESC
LIMIT 10
`, monthStart); err == nil {
		defer modelRows.Close()
		for modelRows.Next() {
			var entry AIUsageTopModel
			if scanErr := modelRows.Scan(&entry.ModelID, &entry.CostMicroEUR, &entry.RequestCount); scanErr == nil {
				entry.CostEUR = cost.MicroEURToEUR(entry.CostMicroEUR)
				summary.TopModels = append(summary.TopModels, entry)
			}
		}
	}

	return summary, nil
}
// GetUserAIUsage returns the identity of the user with the given UUID, a usage
// summary and one page of that user's usage events, newest first.
//
// limit defaults to 50 when it is zero or negative. A negative offset is
// treated as 0: both values come straight from the query string, and
// PostgreSQL rejects a negative OFFSET, which previously surfaced as an
// internal error.
//
// Any failure of the user lookup (including a malformed UUID) is reported as
// the error "user not found"; the handler matches on that exact message.
// The summary is the organisation-wide one (GetAIUsageSummary with scope
// "org"), and failures to load it or the total event count are ignored.
// Rows that fail to scan are skipped.
func (s *Service) GetUserAIUsage(ctx context.Context, userID string, limit, offset int) (AIUserUsageDetail, error) {
	if limit <= 0 {
		limit = 50
	}
	if offset < 0 {
		offset = 0
	}

	var detail AIUserUsageDetail
	detail.Events = []AIUsageEventItem{} // serialise as [] even when empty

	err := s.db.QueryRow(ctx, `
SELECT id::text, COALESCE(email, ''), COALESCE(display_name, '')
FROM users WHERE id = $1::uuid
`, userID).Scan(&detail.UserID, &detail.Email, &detail.DisplayName)
	if err != nil {
		return detail, fmt.Errorf("user not found")
	}

	summary, _ := s.GetAIUsageSummary(ctx, "org")
	detail.Summary = summary

	_ = s.db.QueryRow(ctx, `
SELECT COUNT(*) FROM ai_usage_events WHERE user_id = $1::uuid
`, userID).Scan(&detail.EventsTotal)

	rows, err := s.db.Query(ctx, `
SELECT id::text, created_at, feature, model_id, provider_id, billing_scope,
cost_micro_eur, prompt_tokens, completion_tokens, cached_input_tokens, estimated
FROM ai_usage_events
WHERE user_id = $1::uuid
ORDER BY created_at DESC
LIMIT $2 OFFSET $3
`, userID, limit, offset)
	if err != nil {
		return detail, err
	}
	defer rows.Close()

	for rows.Next() {
		var ev AIUsageEventItem
		var created time.Time
		if err := rows.Scan(&ev.ID, &created, &ev.Feature, &ev.ModelID, &ev.ProviderID,
			&ev.BillingScope, &ev.CostMicroEUR, &ev.PromptTokens, &ev.CompletionTokens,
			&ev.CachedInputTokens, &ev.Estimated); err != nil {
			continue
		}
		ev.CreatedAt = created.UTC().Format(time.RFC3339)
		ev.CostEUR = cost.MicroEURToEUR(ev.CostMicroEUR)
		detail.Events = append(detail.Events, ev)
	}
	return detail, rows.Err()
}
// ListAIPricing returns every stored model price, adding EUR conversions of
// the input and output prices next to the micro-EUR values. On a store error
// it returns an empty, non-nil slice together with the error.
func (s *Service) ListAIPricing(ctx context.Context) ([]AIPricingEntry, error) {
	stored, err := cost.NewPricingStore(s.db).ListPrices(ctx)
	if err != nil {
		return []AIPricingEntry{}, err
	}

	entries := make([]AIPricingEntry, len(stored))
	for i, price := range stored {
		entries[i] = AIPricingEntry{
			ModelID:                    price.ModelID,
			ProviderType:               price.ProviderType,
			InputMicroEURPerMTok:       price.InputMicroEURPerMTok,
			CachedInputMicroEURPerMTok: price.CachedInputMicroEURPerMTok,
			OutputMicroEURPerMTok:      price.OutputMicroEURPerMTok,
			ReasoningMicroEURPerMTok:   price.ReasoningMicroEURPerMTok,
			InputEURPerMTok:            cost.MicroEURToEUR(price.InputMicroEURPerMTok),
			OutputEURPerMTok:           cost.MicroEURToEUR(price.OutputMicroEURPerMTok),
		}
	}
	return entries, nil
}
// PutAIPricing upserts each price of the request in order, writes an audit
// entry with the number of prices, and returns the full price list.
//
// Only the micro-EUR fields of each entry are stored. The upserts are issued
// one by one: the first failure aborts the loop and is returned, and prices
// processed before it stay written.
func (s *Service) PutAIPricing(ctx context.Context, actorSub string, req putAIPricingRequest) ([]AIPricingEntry, error) {
	store := cost.NewPricingStore(s.db)

	for i := range req.Prices {
		entry := req.Prices[i]
		price := cost.ModelPrice{
			ModelID:                    entry.ModelID,
			ProviderType:               entry.ProviderType,
			InputMicroEURPerMTok:       entry.InputMicroEURPerMTok,
			CachedInputMicroEURPerMTok: entry.CachedInputMicroEURPerMTok,
			OutputMicroEURPerMTok:      entry.OutputMicroEURPerMTok,
			ReasoningMicroEURPerMTok:   entry.ReasoningMicroEURPerMTok,
		}
		if err := store.UpsertModelPrice(ctx, price); err != nil {
			return nil, err
		}
	}

	auditDetails := map[string]any{"count": len(req.Prices)}
	s.logAudit(ctx, actorSub, "update_ai_pricing", auditDetails)
	return s.ListAIPricing(ctx)
}
// ListAICostPolicies returns all rows of ai_cost_policies ordered by priority
// and scope type. Each entry carries a display name (group name, user email
// or "Organisation") and, for every limit that is set, its EUR conversion.
// Rows that fail to scan are skipped; a query error yields an empty, non-nil
// slice together with the error.
func (s *Service) ListAICostPolicies(ctx context.Context) ([]AICostPolicyEntry, error) {
	rows, err := s.db.Query(ctx, `
SELECT p.scope_type, p.scope_id::text, p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct,
COALESCE(g.name, u.email, 'Organisation')
FROM ai_cost_policies p
LEFT JOIN user_groups g ON p.scope_type = 'group' AND g.id = p.scope_id
LEFT JOIN users u ON p.scope_type = 'user' AND u.id = p.scope_id
ORDER BY p.priority ASC, p.scope_type
`)
	if err != nil {
		return []AICostPolicyEntry{}, err
	}
	defer rows.Close()

	// toEUR converts an optional micro-EUR limit, keeping nil as nil.
	toEUR := func(micro *int64) *float64 {
		if micro == nil {
			return nil
		}
		eur := cost.MicroEURToEUR(*micro)
		return &eur
	}

	entries := []AICostPolicyEntry{}
	for rows.Next() {
		var (
			entry AICostPolicyEntry
			id    *string
		)
		scanErr := rows.Scan(&entry.ScopeType, &id, &entry.DailyLimitMicroEUR, &entry.MonthlyLimitMicroEUR,
			&entry.WarnThresholdPct, &entry.ScopeName)
		if scanErr != nil {
			continue
		}
		entry.ScopeID = id
		entry.DailyLimitEUR = toEUR(entry.DailyLimitMicroEUR)
		entry.MonthlyLimitEUR = toEUR(entry.MonthlyLimitMicroEUR)
		entries = append(entries, entry)
	}
	return entries, rows.Err()
}
// PutUserAICostPolicy stores the cost policy of a single user ("user" scope)
// and, once the upsert succeeded, writes a "set_user_ai_cost_policy" audit
// entry. The limits are taken from req via resolveCostLimits.
func (s *Service) PutUserAICostPolicy(ctx context.Context, actorSub, userID string, req putAICostPolicyRequest) error {
	dailyLimit, monthlyLimit := resolveCostLimits(req)

	err := cost.NewPolicyService(s.db).UpsertScopePolicy(ctx, "user", userID, dailyLimit, monthlyLimit, req.WarnThresholdPct)
	if err != nil {
		return err
	}

	auditDetails := map[string]any{"user_id": userID}
	s.logAudit(ctx, actorSub, "set_user_ai_cost_policy", auditDetails)
	return nil
}
// PutGroupAICostPolicy stores the cost policy of a user group ("group" scope)
// and, once the upsert succeeded, writes a "set_group_ai_cost_policy" audit
// entry. The limits are taken from req via resolveCostLimits.
func (s *Service) PutGroupAICostPolicy(ctx context.Context, actorSub, groupID string, req putAICostPolicyRequest) error {
	dailyLimit, monthlyLimit := resolveCostLimits(req)

	err := cost.NewPolicyService(s.db).UpsertScopePolicy(ctx, "group", groupID, dailyLimit, monthlyLimit, req.WarnThresholdPct)
	if err != nil {
		return err
	}

	auditDetails := map[string]any{"group_id": groupID}
	s.logAudit(ctx, actorSub, "set_group_ai_cost_policy", auditDetails)
	return nil
}
// syncUsageQuotasToCostPolicy mirrors the cost-related keys of the org
// "usage_quotas" settings into the organisation AI cost policy.
//
// Keys read:
//   - llm_daily_cost_limit_eur    -> daily limit (micro-EUR)
//   - llm_monthly_cost_limit_eur  -> monthly limit (micro-EUR)
//   - llm_cost_warn_threshold_pct -> warning threshold, default 80
//
// A key that is missing, non-numeric or not greater than zero is treated as
// unset: the matching limit is passed as nil and the threshold keeps its
// default. A nil map is a no-op.
//
// Values are accepted both as float64 (the type produced by JSON decoding)
// and as Go integer kinds, because settings built in Go code store such
// numbers as plain ints; a bare float64 assertion would silently ignore them.
// EUR amounts are rounded to the nearest micro-EUR, since truncating the
// floating-point product can land one micro-EUR low (8.2 EUR -> 8199999).
func (s *Service) syncUsageQuotasToCostPolicy(ctx context.Context, usageQuotas map[string]any) error {
	if usageQuotas == nil {
		return nil
	}

	// positive returns the numeric value stored under key and whether it is
	// present and greater than zero.
	positive := func(key string) (float64, bool) {
		var v float64
		switch n := usageQuotas[key].(type) {
		case float64:
			v = n
		case float32:
			v = float64(n)
		case int:
			v = float64(n)
		case int32:
			v = float64(n)
		case int64:
			v = float64(n)
		default:
			return 0, false
		}
		return v, v > 0
	}

	var daily, monthly *int64
	warnPct := 80

	if eur, ok := positive("llm_daily_cost_limit_eur"); ok {
		micro := int64(math.Round(eur * 1_000_000))
		daily = &micro
	}
	if eur, ok := positive("llm_monthly_cost_limit_eur"); ok {
		micro := int64(math.Round(eur * 1_000_000))
		monthly = &micro
	}
	if pct, ok := positive("llm_cost_warn_threshold_pct"); ok {
		warnPct = int(pct)
	}

	policy := cost.NewPolicyService(s.db)
	return policy.UpsertOrgPolicy(ctx, daily, monthly, warnPct)
}
// resolveCostLimits derives the daily and monthly limits, in micro-EUR, from
// a policy request.
//
// For each limit the explicit micro-EUR field wins and is returned as-is
// (same pointer). Otherwise a EUR value greater than zero is converted to
// micro-EUR. When neither applies the limit is nil.
//
// The EUR conversion rounds to the nearest micro-EUR. Plain truncation of the
// floating-point product can end up one micro-EUR low, e.g. 8.2 EUR would
// become 8199999 instead of 8200000.
func resolveCostLimits(req putAICostPolicyRequest) (*int64, *int64) {
	resolve := func(micro *int64, eur *float64) *int64 {
		if micro != nil {
			return micro
		}
		if eur == nil || !(*eur > 0) {
			return nil
		}
		v := int64(math.Round(*eur * 1_000_000))
		return &v
	}

	daily := resolve(req.DailyLimitMicroEUR, req.DailyLimitEUR)
	monthly := resolve(req.MonthlyLimitMicroEUR, req.MonthlyLimitEUR)
	return daily, monthly
}

View File

@ -0,0 +1,116 @@
package admin
import (
"net/http"
"strconv"
"github.com/go-chi/chi/v5"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
"github.com/ultisuite/ulti-backend/internal/api/middleware"
)
// GetAIUsage serves the AI usage summary. The optional "scope" query
// parameter is forwarded as the billing scope and defaults to "org". A
// service error is logged and answered with 500.
func (h *Handler) GetAIUsage(w http.ResponseWriter, r *http.Request) {
	billingScope := "org"
	if requested := r.URL.Query().Get("scope"); requested != "" {
		billingScope = requested
	}

	summary, err := h.svc.GetAIUsageSummary(r.Context(), billingScope)
	if err != nil {
		h.logger.Error("get ai usage", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, summary)
}
// GetUserAIUsage serves the usage detail of the user named by the "userID"
// URL parameter. "limit" and "offset" are read from the query string;
// values that do not parse as integers count as 0. The service error
// "user not found" maps to 404; any other error is logged and answered
// with 500.
func (h *Handler) GetUserAIUsage(w http.ResponseWriter, r *http.Request) {
	userID := chi.URLParam(r, "userID")

	query := r.URL.Query()
	limit, _ := strconv.Atoi(query.Get("limit"))
	offset, _ := strconv.Atoi(query.Get("offset"))

	detail, err := h.svc.GetUserAIUsage(r.Context(), userID, limit, offset)
	switch {
	case err == nil:
		apiresponse.WriteJSON(w, http.StatusOK, detail)
	case err.Error() == "user not found":
		apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, err.Error(), nil)
	default:
		h.logger.Error("get user ai usage", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
	}
}
// GetAIPricing serves the stored model prices as {"prices": [...]}. A service
// error is logged and answered with 500.
func (h *Handler) GetAIPricing(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	prices, err := h.svc.ListAIPricing(ctx)
	if err != nil {
		h.logger.Error("list ai pricing", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}

	payload := map[string]any{"prices": prices}
	apiresponse.WriteJSON(w, http.StatusOK, payload)
}
// PutAIPricing decodes a putAIPricingRequest body, stores the prices on behalf
// of the caller (claims subject) and answers with the resulting price list as
// {"prices": [...]}. When decoding fails the handler returns without writing
// anything itself. A service error is logged and answered with 500.
func (h *Handler) PutAIPricing(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	claims := middleware.ClaimsFromContext(ctx)

	var body putAIPricingRequest
	if decodeErr := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &body); decodeErr != nil {
		return
	}

	prices, err := h.svc.PutAIPricing(ctx, claims.Sub, body)
	if err != nil {
		h.logger.Error("put ai pricing", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}

	payload := map[string]any{"prices": prices}
	apiresponse.WriteJSON(w, http.StatusOK, payload)
}
// GetAICostPolicies serves all AI cost policies as {"policies": [...]}. A
// service error is logged and answered with 500.
func (h *Handler) GetAICostPolicies(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	policies, err := h.svc.ListAICostPolicies(ctx)
	if err != nil {
		h.logger.Error("list ai cost policies", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}

	payload := map[string]any{"policies": policies}
	apiresponse.WriteJSON(w, http.StatusOK, payload)
}
// PutUserAICostPolicy sets the AI cost policy of the user named by the
// "userID" URL parameter. The ID is validated first (a failure is answered
// through apivalidate.WriteValidationError), then the putAICostPolicyRequest
// body is decoded; when decoding fails the handler returns without writing
// anything itself. A service error is logged and answered with 500; success
// yields {"ok": true}.
func (h *Handler) PutUserAICostPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	claims := middleware.ClaimsFromContext(ctx)

	targetID := chi.URLParam(r, "userID")
	if invalid := validateUserID(targetID); invalid != nil {
		apivalidate.WriteValidationError(w, r, invalid)
		return
	}

	var body putAICostPolicyRequest
	if decodeErr := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &body); decodeErr != nil {
		return
	}

	err := h.svc.PutUserAICostPolicy(ctx, claims.Sub, targetID, body)
	if err != nil {
		h.logger.Error("put user ai cost policy", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
}
// PutGroupAICostPolicy sets the AI cost policy of the user group named by the
// "groupID" URL parameter. The ID is validated first (a failure is answered
// through apivalidate.WriteValidationError), then the putAICostPolicyRequest
// body is decoded; when decoding fails the handler returns without writing
// anything itself. A service error is logged and answered with 500; success
// yields {"ok": true}.
func (h *Handler) PutGroupAICostPolicy(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	claims := middleware.ClaimsFromContext(ctx)

	targetID := chi.URLParam(r, "groupID")
	if invalid := validateGroupID(targetID); invalid != nil {
		apivalidate.WriteValidationError(w, r, invalid)
		return
	}

	var body putAICostPolicyRequest
	if decodeErr := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &body); decodeErr != nil {
		return
	}

	err := h.svc.PutGroupAICostPolicy(ctx, claims.Sub, targetID, body)
	if err != nil {
		h.logger.Error("put group ai cost policy", "error", err)
		apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
}

View File

@ -79,6 +79,14 @@ func (h *Handler) Routes() chi.Router {
r.With(write).Put("/org/settings", h.PutOrgSettings) r.With(write).Put("/org/settings", h.PutOrgSettings)
r.With(read).Post("/org/llm/discover-models", h.DiscoverOrgLLMModels) r.With(read).Post("/org/llm/discover-models", h.DiscoverOrgLLMModels)
r.With(read).Get("/ai/usage", h.GetAIUsage)
r.With(read).Get("/ai/usage/users/{userID}", h.GetUserAIUsage)
r.With(read).Get("/ai/pricing", h.GetAIPricing)
r.With(write).Put("/ai/pricing", h.PutAIPricing)
r.With(read).Get("/ai/policies", h.GetAICostPolicies)
r.With(write).Put("/users/{userID}/ai-policy", h.PutUserAICostPolicy)
r.With(write).Put("/user-groups/{groupID}/ai-policy", h.PutGroupAICostPolicy)
r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI) r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI)
r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider) r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider)
r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider) r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider)

View File

@ -45,11 +45,14 @@ func defaultOrgPolicy() map[string]any {
"warn_threshold_pct": 90, "warn_threshold_pct": 90,
}, },
"usage_quotas": map[string]any{ "usage_quotas": map[string]any{
"llm_requests_per_day": 100, "llm_daily_cost_limit_eur": 2,
"llm_tokens_per_month": 500000, "llm_monthly_cost_limit_eur": 35,
"search_requests_per_day": 50, "llm_cost_warn_threshold_pct": 80,
"max_api_tokens_per_user": 10, "llm_requests_per_day": 75,
"max_webhooks_per_user": 20, "llm_tokens_per_month": 2_000_000,
"search_requests_per_day": 20,
"max_api_tokens_per_user": 5,
"max_webhooks_per_user": 5,
}, },
"file_policies": map[string]any{ "file_policies": map[string]any{
"max_upload_mib": 512, "max_upload_mib": 512,
@ -909,6 +912,12 @@ func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map
} }
} }
if usageQuotas, ok := merged["usage_quotas"].(map[string]any); ok {
if err := s.syncUsageQuotasToCostPolicy(ctx, usageQuotas); err != nil {
s.logger.Warn("sync ai cost policy failed", "error", err)
}
}
return s.GetOrgSettings(ctx) return s.GetOrgSettings(ctx)
} }

View File

@ -13,6 +13,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/ai"
"github.com/ultisuite/ulti-backend/internal/llm" "github.com/ultisuite/ulti-backend/internal/llm"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
"github.com/ultisuite/ulti-backend/internal/orgpolicy" "github.com/ultisuite/ulti-backend/internal/orgpolicy"
@ -104,11 +105,14 @@ func (p *TranscriptProcessor) runPostActions(
actions := policy.PostActions actions := policy.PostActions
if actions.LLMEnabled { if actions.LLMEnabled {
summary, err := p.summarize(ctx, policy, rawTranscript) summary, provider, model, usage, err := p.summarize(ctx, policy, rawTranscript)
if err != nil { if err != nil {
p.logger.Warn("llm summary failed", "error", err, "job_id", jobID) p.logger.Warn("llm summary failed", "error", err, "job_id", jobID)
} else if strings.TrimSpace(summary) != "" { } else if strings.TrimSpace(summary) != "" {
finalText = summary finalText = summary
if extID, err := ai.ResolveExternalIDByEmail(ctx, p.db, in.OrganizerEmail); err == nil && extID != "" {
ai.RecordFeatureUsage(ctx, p.db, extID, "ultimeet", model, provider, usage)
}
} }
} }
@ -145,16 +149,20 @@ func (p *TranscriptProcessor) runPostActions(
return err return err
} }
func (p *TranscriptProcessor) summarize(ctx context.Context, policy orgpolicy.MeetPolicy, transcript string) (string, error) { func (p *TranscriptProcessor) summarize(ctx context.Context, policy orgpolicy.MeetPolicy, transcript string) (string, llm.Provider, string, llm.UsageDetail, error) {
provider, model, err := p.resolveLLM(ctx, policy.PostActions.LLMProviderID) provider, model, err := p.resolveLLM(ctx, policy.PostActions.LLMProviderID)
if err != nil { if err != nil {
return "", err return "", llm.Provider{}, "", llm.UsageDetail{}, err
} }
prompt := strings.TrimSpace(policy.PostActions.LLMPrompt) prompt := strings.TrimSpace(policy.PostActions.LLMPrompt)
if prompt == "" { if prompt == "" {
prompt = "Résume cette réunion en français." prompt = "Résume cette réunion en français."
} }
return p.llm.Complete(ctx, provider, model, prompt, transcript) result, err := p.llm.CompleteWithUsage(ctx, provider, model, prompt, transcript)
if err != nil {
return "", provider, model, llm.UsageDetail{}, err
}
return result.Content, provider, result.Model, result.Usage, nil
} }
func (p *TranscriptProcessor) resolveLLM(ctx context.Context, providerID string) (llm.Provider, string, error) { func (p *TranscriptProcessor) resolveLLM(ctx context.Context, providerID string) (llm.Provider, string, error) {

View File

@ -7,6 +7,9 @@ import (
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/ai"
"github.com/ultisuite/ulti-backend/internal/llm" "github.com/ultisuite/ulti-backend/internal/llm"
) )
@ -60,7 +63,7 @@ func parseEnrichedData(raw string) (*EnrichedContactData, error) {
return &data, nil return &data, nil
} }
func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry, timeout time.Duration) (*EnrichedContactData, error) { func enrichWithLLMTimeout(ctx context.Context, db *pgxpool.Pool, externalUserID string, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry, timeout time.Duration) (*EnrichedContactData, error) {
enrichCtx, cancel := context.WithTimeout(ctx, timeout) enrichCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel() defer cancel()
@ -69,7 +72,7 @@ func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.
err error err error
}, 1) }, 1)
go func() { go func() {
data, err := enrichWithLLM(enrichCtx, client, settings, email, displayName, signatures) data, err := enrichWithLLM(enrichCtx, db, externalUserID, client, settings, email, displayName, signatures)
resultCh <- struct { resultCh <- struct {
data *EnrichedContactData data *EnrichedContactData
err error err error
@ -89,7 +92,7 @@ func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.
} }
} }
func enrichWithLLM(ctx context.Context, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry) (*EnrichedContactData, error) { func enrichWithLLM(ctx context.Context, db *pgxpool.Pool, externalUserID string, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry) (*EnrichedContactData, error) {
if client == nil || len(signatures) == 0 { if client == nil || len(signatures) == 0 {
return nil, fmt.Errorf("no signatures to enrich") return nil, fmt.Errorf("no signatures to enrich")
} }
@ -98,11 +101,12 @@ func enrichWithLLM(ctx context.Context, client *llm.Client, settings llm.Setting
return nil, err return nil, err
} }
prompt := buildEnrichPrompt(email, displayName, signatures) prompt := buildEnrichPrompt(email, displayName, signatures)
raw, err := client.Complete(ctx, provider, model, enrichSystemPrompt, prompt) result, err := client.CompleteWithUsage(ctx, provider, model, enrichSystemPrompt, prompt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return parseEnrichedData(raw) ai.RecordFeatureUsage(ctx, db, externalUserID, "contact_discovery", result.Model, provider, result.Usage)
return parseEnrichedData(result.Content)
} }
func enrichedDataToSuggestions(userID, profileID string, data *EnrichedContactData) []Suggestion { func enrichedDataToSuggestions(userID, profileID string, data *EnrichedContactData) []Suggestion {

View File

@ -121,7 +121,7 @@ func (s *Service) runProfileEnrichment(externalUserID, profileID, ncUserID, book
} }
enriched, enrichErr := enrichWithLLMTimeout( enriched, enrichErr := enrichWithLLMTimeout(
ctx, s.llm, llmSettings, ctx, s.db, externalUserID, s.llm, llmSettings,
profile.PrimaryEmail, profile.DisplayName, sigs, llmEnrichTimeout, profile.PrimaryEmail, profile.DisplayName, sigs, llmEnrichTimeout,
) )
if enrichErr != nil { if enrichErr != nil {

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/ultisuite/ulti-backend/internal/ai"
"github.com/ultisuite/ulti-backend/internal/llm" "github.com/ultisuite/ulti-backend/internal/llm"
"github.com/ultisuite/ulti-backend/internal/websearch" "github.com/ultisuite/ulti-backend/internal/websearch"
) )
@ -120,11 +121,12 @@ func (s *Service) ImproveContact(ctx context.Context, externalUserID string, inp
searchSection := s.fetchContactSearchResults(improveCtx, externalUserID, input) searchSection := s.fetchContactSearchResults(improveCtx, externalUserID, input)
prompt := buildImproveContactPrompt(input, searchSection) prompt := buildImproveContactPrompt(input, searchSection)
raw, err := s.llm.Complete(improveCtx, provider, model, improveContactSystemPrompt, prompt) raw, err := s.llm.CompleteWithUsage(improveCtx, provider, model, improveContactSystemPrompt, prompt)
if err != nil { if err != nil {
return nil, err return nil, err
} }
data, err := parseEnrichedData(raw) ai.RecordFeatureUsage(ctx, s.db, externalUserID, "contact_discovery", raw.Model, provider, raw.Usage)
data, err := parseEnrichedData(raw.Content)
if err != nil { if err != nil {
return nil, fmt.Errorf("parse improved contact: %w", err) return nil, fmt.Errorf("parse improved contact: %w", err)
} }

View File

@ -334,7 +334,7 @@ func (s *Service) executeScan(ctx context.Context, externalUserID, ncUserID, boo
heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background()) heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background())
go s.enrichHeartbeat(heartbeatCtx, scanID, externalUserID, messagesScanned, enrichDone, totalMessages, enrichTotal) go s.enrichHeartbeat(heartbeatCtx, scanID, externalUserID, messagesScanned, enrichDone, totalMessages, enrichTotal)
enriched, enrichErr := enrichWithLLMTimeout(ctx, s.llm, llmSettings, email, agg.DisplayName, sigEntries, llmEnrichTimeout) enriched, enrichErr := enrichWithLLMTimeout(ctx, s.db, externalUserID, s.llm, llmSettings, email, agg.DisplayName, sigEntries, llmEnrichTimeout)
heartbeatCancel() heartbeatCancel()
if enrichErr != nil { if enrichErr != nil {
s.logger.Warn("llm enrichment failed", "email", email, "error", enrichErr) s.logger.Warn("llm enrichment failed", "email", email, "error", enrichErr)

View File

@ -44,6 +44,17 @@ type chatResponse struct {
Content string `json:"content"` Content string `json:"content"`
} `json:"message"` } `json:"message"`
} `json:"choices"` } `json:"choices"`
Usage *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *struct {
CachedTokens int `json:"cached_tokens"`
} `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *struct {
ReasoningTokens int `json:"reasoning_tokens"`
} `json:"completion_tokens_details,omitempty"`
} `json:"usage,omitempty"`
Error *struct { Error *struct {
Message string `json:"message"` Message string `json:"message"`
} `json:"error,omitempty"` } `json:"error,omitempty"`
@ -67,16 +78,40 @@ func NewClient() *Client {
} }
func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) { func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) {
result, err := c.CompleteWithUsage(ctx, provider, model, systemPrompt, userPrompt)
if err != nil {
return "", err
}
return result.Content, nil
}
// CompletionResult holds LLM output and usage metadata.
// It is the value CompleteWithUsage returns on success.
type CompletionResult struct {
// Content is the message content of the first returned choice, with
// surrounding whitespace trimmed.
Content string
// Model is the model name resolved by CompleteWithUsage (the trimmed model
// argument, or the provider's DefaultModel when that is blank); it is not
// parsed from the response body.
Model string
// Usage holds the token counts derived from the response's "usage" object.
Usage UsageDetail
}
// UsageDetail mirrors ai/cost.UsageDetail for llm package consumers.
// Values are filled in by parseUsageFromResponse from the chat-completions
// "usage" object.
type UsageDetail struct {
// PromptTokens is taken from usage.prompt_tokens.
PromptTokens int
// CompletionTokens is taken from usage.completion_tokens.
CompletionTokens int
// CachedInputTokens is taken from usage.prompt_tokens_details.cached_tokens
// when that object is present, otherwise it stays 0.
CachedInputTokens int
// ReasoningTokens is taken from usage.completion_tokens_details.reasoning_tokens
// when that object is present, otherwise it stays 0.
ReasoningTokens int
// TotalTokens is usage.total_tokens; parseUsageFromResponse replaces a zero
// value with PromptTokens+CompletionTokens and, failing that, with 1.
TotalTokens int
}
func (c *Client) CompleteWithUsage(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (CompletionResult, error) {
baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/") baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
if baseURL == "" { if baseURL == "" {
return "", fmt.Errorf("llm provider base_url is required") return CompletionResult{}, fmt.Errorf("llm provider base_url is required")
} }
model = strings.TrimSpace(model) model = strings.TrimSpace(model)
if model == "" { if model == "" {
model = strings.TrimSpace(provider.DefaultModel) model = strings.TrimSpace(provider.DefaultModel)
} }
if model == "" { if model == "" {
return "", fmt.Errorf("llm model is required") return CompletionResult{}, fmt.Errorf("llm model is required")
} }
reqBody := chatRequest{ reqBody := chatRequest{
@ -89,13 +124,13 @@ func (c *Client) Complete(ctx context.Context, provider Provider, model, systemP
} }
payload, err := json.Marshal(reqBody) payload, err := json.Marshal(reqBody)
if err != nil { if err != nil {
return "", err return CompletionResult{}, err
} }
url := baseURL + "/chat/completions" url := baseURL + "/chat/completions"
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if err != nil { if err != nil {
return "", err return CompletionResult{}, err
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
if strings.TrimSpace(provider.APIKey) != "" { if strings.TrimSpace(provider.APIKey) != "" {
@ -104,29 +139,68 @@ func (c *Client) Complete(ctx context.Context, provider Provider, model, systemP
resp, err := c.http.Do(req) resp, err := c.http.Do(req)
if err != nil { if err != nil {
return "", err return CompletionResult{}, err
} }
defer resp.Body.Close() defer resp.Body.Close()
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if err != nil { if err != nil {
return "", err return CompletionResult{}, err
} }
if resp.StatusCode >= 400 { if resp.StatusCode >= 400 {
return "", fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body)) return CompletionResult{}, fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body))
} }
var parsed chatResponse var parsed chatResponse
if err := json.Unmarshal(body, &parsed); err != nil { if err := json.Unmarshal(body, &parsed); err != nil {
return "", err return CompletionResult{}, err
} }
if parsed.Error != nil && parsed.Error.Message != "" { if parsed.Error != nil && parsed.Error.Message != "" {
return "", fmt.Errorf("llm error: %s", parsed.Error.Message) return CompletionResult{}, fmt.Errorf("llm error: %s", parsed.Error.Message)
} }
if len(parsed.Choices) == 0 { if len(parsed.Choices) == 0 {
return "", fmt.Errorf("llm returned no choices") return CompletionResult{}, fmt.Errorf("llm returned no choices")
} }
return strings.TrimSpace(parsed.Choices[0].Message.Content), nil usage := parseUsageFromResponse(parsed.Usage)
return CompletionResult{
Content: strings.TrimSpace(parsed.Choices[0].Message.Content),
Model: model,
Usage: usage,
}, nil
}
// parseUsageFromResponse converts the optional "usage" object of a chat
// completion response into a UsageDetail.
//
// The parameter type is spelled out as an anonymous struct with the same
// fields and tags as the Usage field of chatResponse, so that field can be
// passed directly.
//
// A nil u (no usage reported) yields UsageDetail{TotalTokens: 1}. Otherwise
// the token counts are copied over, the nested detail objects are read only
// when present, and a zero TotalTokens is replaced first by
// PromptTokens+CompletionTokens and then, if still zero, by 1.
// NOTE(review): the floor of 1 presumably makes every completed call register
// as non-zero usage — confirm against the metering code.
func parseUsageFromResponse(u *struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *struct {
CachedTokens int `json:"cached_tokens"`
} `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *struct {
ReasoningTokens int `json:"reasoning_tokens"`
} `json:"completion_tokens_details,omitempty"`
}) UsageDetail {
if u == nil {
// No usage block in the response: fall back to a minimal non-zero total.
return UsageDetail{TotalTokens: 1}
}
d := UsageDetail{
PromptTokens: u.PromptTokens,
CompletionTokens: u.CompletionTokens,
TotalTokens: u.TotalTokens,
}
// The detail objects are pointers and may be absent; guard each one.
if u.PromptTokensDetails != nil {
d.CachedInputTokens = u.PromptTokensDetails.CachedTokens
}
if u.CompletionTokensDetails != nil {
d.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens
}
// Total missing or zero: derive it from the prompt and completion counts.
if d.TotalTokens == 0 {
d.TotalTokens = d.PromptTokens + d.CompletionTokens
}
// Still zero (all counts were zero): use the same floor as the nil case.
if d.TotalTokens == 0 {
d.TotalTokens = 1
}
return d
} }
func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) { func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) {

View File

@ -0,0 +1,9 @@
-- Down migration for the AI cost-management schema.
-- Drops the policy, org-aggregate, usage-event and pricing tables, and removes
-- the cost columns from ai_usage_daily / ai_usage_monthly (the rollup tables
-- themselves are kept). Every statement uses IF EXISTS, so re-running the
-- script does not fail on objects that are already gone.
DROP TABLE IF EXISTS ai_cost_policies;
DROP TABLE IF EXISTS ai_org_usage_monthly;
DROP TABLE IF EXISTS ai_org_usage_daily;
ALTER TABLE ai_usage_monthly DROP COLUMN IF EXISTS cost_micro_eur_org;
ALTER TABLE ai_usage_monthly DROP COLUMN IF EXISTS cost_micro_eur_user;
ALTER TABLE ai_usage_daily DROP COLUMN IF EXISTS cost_micro_eur_org;
ALTER TABLE ai_usage_daily DROP COLUMN IF EXISTS cost_micro_eur_user;
DROP TABLE IF EXISTS ai_usage_events;
DROP TABLE IF EXISTS ai_model_pricing;

View File

@ -0,0 +1,96 @@
-- Model pricing (micro-EUR per 1M tokens)
-- One row per (model_id, effective_from): a model can carry several price rows
-- that differ by the date they take effect. provider_type is not part of the key.
-- The cached-input and reasoning rates are nullable.
-- NOTE(review): how a NULL rate is priced (e.g. fallback to the input/output
-- rate) is decided by the application code, not by this schema — confirm there.
CREATE TABLE IF NOT EXISTS ai_model_pricing (
model_id TEXT NOT NULL,
provider_type TEXT NOT NULL DEFAULT 'generic',
input_micro_eur_per_mtok BIGINT NOT NULL,
cached_input_micro_eur_per_mtok BIGINT,
output_micro_eur_per_mtok BIGINT NOT NULL,
reasoning_micro_eur_per_mtok BIGINT,
effective_from DATE NOT NULL DEFAULT CURRENT_DATE,
source TEXT NOT NULL DEFAULT 'manual',
PRIMARY KEY (model_id, effective_from)
);
-- Detailed usage ledger
-- One row per recorded AI call. Rows belong to a user and are removed with it
-- (ON DELETE CASCADE). billing_scope is restricted to 'org' or 'user';
-- provider_key_fingerprint defaults to the empty string and request_id is optional.
-- Token counters default to 0 and cost_micro_eur is the cost of the event in micro-EUR.
CREATE TABLE IF NOT EXISTS ai_usage_events (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
feature TEXT NOT NULL,
model_id TEXT NOT NULL,
provider_id TEXT NOT NULL,
billing_scope TEXT NOT NULL CHECK (billing_scope IN ('org', 'user')),
provider_key_fingerprint TEXT NOT NULL DEFAULT '',
prompt_tokens INT NOT NULL DEFAULT 0,
completion_tokens INT NOT NULL DEFAULT 0,
cached_input_tokens INT NOT NULL DEFAULT 0,
reasoning_tokens INT NOT NULL DEFAULT 0,
cost_micro_eur BIGINT NOT NULL DEFAULT 0,
estimated BOOLEAN NOT NULL DEFAULT false,
request_id TEXT
);
-- Indexes for newest-first lookups: per user, per billing scope + key fingerprint, and globally.
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_user_time ON ai_usage_events(user_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_scope_key ON ai_usage_events(billing_scope, provider_key_fingerprint, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_created ON ai_usage_events(created_at DESC);
-- Extend daily rollups
-- Adds separate cost totals (micro-EUR) for org-billed and user-billed usage;
-- existing rows get 0 through the column default.
ALTER TABLE ai_usage_daily
ADD COLUMN IF NOT EXISTS cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS cost_micro_eur_user BIGINT NOT NULL DEFAULT 0;
-- Extend monthly rollups
-- Same two cost columns as on the daily rollup.
ALTER TABLE ai_usage_monthly
ADD COLUMN IF NOT EXISTS cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
ADD COLUMN IF NOT EXISTS cost_micro_eur_user BIGINT NOT NULL DEFAULT 0;
-- Org-wide aggregates
-- One row per calendar day (usage_date is the primary key), with a request counter.
CREATE TABLE IF NOT EXISTS ai_org_usage_daily (
usage_date DATE NOT NULL PRIMARY KEY,
cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
cost_micro_eur_user BIGINT NOT NULL DEFAULT 0,
requests INT NOT NULL DEFAULT 0
);
-- One row per usage_month; unlike the daily table there is no requests column.
CREATE TABLE IF NOT EXISTS ai_org_usage_monthly (
usage_month DATE NOT NULL PRIMARY KEY,
cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
cost_micro_eur_user BIGINT NOT NULL DEFAULT 0
);
-- Cost limit policies (org / group / user)
-- Limits are expressed in micro-EUR and are nullable; warn_threshold_pct defaults to 80.
-- scope_id is nullable: the org-wide row seeded below stores NULL.
CREATE TABLE IF NOT EXISTS ai_cost_policies (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
scope_type TEXT NOT NULL CHECK (scope_type IN ('org', 'group', 'user')),
scope_id UUID,
daily_limit_micro_eur BIGINT,
monthly_limit_micro_eur BIGINT,
warn_threshold_pct INT NOT NULL DEFAULT 80,
priority INT NOT NULL DEFAULT 0,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
-- Partial unique indexes: at most one 'org' row, and one row per scope_id for 'group' and for 'user'.
-- NOTE(review): unique indexes treat NULLs as distinct by default, so 'group'/'user'
-- rows with a NULL scope_id would not be deduplicated — confirm writers always set scope_id.
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_org ON ai_cost_policies(scope_type) WHERE scope_type = 'org';
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_group ON ai_cost_policies(scope_type, scope_id) WHERE scope_type = 'group';
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_user ON ai_cost_policies(scope_type, scope_id) WHERE scope_type = 'user';
-- Default org policy: ~10 EUR/day, ~100 EUR/month
-- (10000000 and 100000000 micro-EUR). Inserted only when no 'org' row exists yet.
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
SELECT 'org', NULL, 10000000, 100000000, 80, 0
WHERE NOT EXISTS (SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org');
-- Seed common model pricing (approximate public rates in EUR)
-- Values are micro-EUR per 1M tokens, e.g. 140000 = 0.14 EUR per 1M input tokens.
-- effective_from is not listed, so each row takes the column default (CURRENT_DATE
-- at migration time); reasoning_micro_eur_per_mtok is not listed either and stays NULL.
-- ON CONFLICT DO NOTHING skips any row that collides with an existing one.
INSERT INTO ai_model_pricing (model_id, provider_type, input_micro_eur_per_mtok, cached_input_micro_eur_per_mtok, output_micro_eur_per_mtok, source) VALUES
('gpt-4o-mini', 'openai', 140000, 70000, 560000, 'seed'),
('gpt-4o', 'openai', 2300000, 1150000, 9200000, 'seed'),
('gpt-4.1-mini', 'openai', 360000, 90000, 1440000, 'seed'),
('gpt-4.1', 'openai', 1800000, 450000, 7200000, 'seed'),
('o3-mini', 'openai', 990000, 250000, 3960000, 'seed'),
('claude-sonnet-4-6', 'anthropic', 2700000, 270000, 13500000, 'seed'),
('claude-haiku-4-5', 'anthropic', 900000, 90000, 4500000, 'seed'),
('mistral-small-latest', 'mistral', 180000, 90000, 540000, 'seed'),
('mistral-large-latest', 'mistral', 1800000, 450000, 5400000, 'seed'),
('gemini-2.0-flash', 'google_gemini', 90000, 23000, 360000, 'seed'),
('gemini-2.5-pro', 'google_gemini', 1100000, 280000, 9000000, 'seed')
ON CONFLICT DO NOTHING;

View File

@ -0,0 +1,42 @@
-- Down migration: put the org AI quotas back to 10 EUR/day and 100 EUR/month.
-- The org cost policy is reverted only while it still holds exactly
-- 2000000 / 35000000 micro-EUR (2 EUR/day, 35 EUR/month); other values are left alone.
UPDATE ai_cost_policies
SET daily_limit_micro_eur = 10000000,
monthly_limit_micro_eur = 100000000,
updated_at = NOW()
WHERE scope_type = 'org'
AND daily_limit_micro_eur = 2000000
AND monthly_limit_micro_eur = 35000000;
-- Reset the usage_quotas keys stored in org_settings row id = 1.
-- The guard below checks only the two cost-limit keys (2 and 35); when it matches,
-- the other five quota keys are overwritten whatever their current value is.
UPDATE org_settings
SET settings = jsonb_set(
jsonb_set(
jsonb_set(
jsonb_set(
jsonb_set(
jsonb_set(
jsonb_set(
COALESCE(settings, '{}'::jsonb),
'{usage_quotas,llm_daily_cost_limit_eur}',
'10'::jsonb
),
'{usage_quotas,llm_monthly_cost_limit_eur}',
'100'::jsonb
),
'{usage_quotas,llm_requests_per_day}',
'100'::jsonb
),
'{usage_quotas,llm_tokens_per_month}',
'500000'::jsonb
),
'{usage_quotas,search_requests_per_day}',
'50'::jsonb
),
'{usage_quotas,max_api_tokens_per_user}',
'10'::jsonb
),
'{usage_quotas,max_webhooks_per_user}',
'20'::jsonb
),
updated_at = NOW()
WHERE id = 1
AND COALESCE((settings->'usage_quotas'->>'llm_daily_cost_limit_eur')::numeric, 0) = 2
AND COALESCE((settings->'usage_quotas'->>'llm_monthly_cost_limit_eur')::numeric, 0) = 35;

View File

@ -0,0 +1,56 @@
-- Align factory AI quotas with reasonable SME defaults (per user, org keys).
-- Only touch rows still at the previous factory values.

-- Org cost policy: 10 EUR/day, 100 EUR/month -> 2 EUR/day, 35 EUR/month (micro-EUR).
UPDATE ai_cost_policies
SET daily_limit_micro_eur = 2000000,
    monthly_limit_micro_eur = 35000000,
    updated_at = NOW()
WHERE scope_type = 'org'
  AND daily_limit_micro_eur = 10000000
  AND monthly_limit_micro_eur = 100000000;

-- Mirror the new defaults into org_settings.settings->'usage_quotas'.
--
-- The WHERE clause deliberately also selects rows where usage_quotas is missing
-- or JSON null. jsonb_set() returns its target unchanged when an earlier step of
-- the path does not exist, so two-step paths such as
-- '{usage_quotas,llm_daily_cost_limit_eur}' would write nothing for those rows.
-- The usage_quotas object is therefore built explicitly: the existing object
-- (or an empty one when it is absent or not an object) is merged with the new
-- values and stored under the single-step path '{usage_quotas}', which
-- jsonb_set() does create. Keys not listed here are preserved by the merge.
UPDATE org_settings
SET settings = jsonb_set(
        COALESCE(settings, '{}'::jsonb),
        '{usage_quotas}',
        (
            CASE
                WHEN jsonb_typeof(settings->'usage_quotas') = 'object'
                    THEN settings->'usage_quotas'
                ELSE '{}'::jsonb
            END
        ) || jsonb_build_object(
            'llm_daily_cost_limit_eur', 2,
            'llm_monthly_cost_limit_eur', 35,
            'llm_requests_per_day', 75,
            'llm_tokens_per_month', 2000000,
            'search_requests_per_day', 20,
            'max_api_tokens_per_user', 5,
            'max_webhooks_per_user', 5
        )
    ),
    updated_at = NOW()
WHERE id = 1
  AND (
    settings->'usage_quotas' IS NULL
    OR settings->'usage_quotas' = 'null'::jsonb
    OR (
      -- Every quota must still be at (or default to) the previous factory value.
      COALESCE((settings->'usage_quotas'->>'llm_daily_cost_limit_eur')::numeric, 10) = 10
      AND COALESCE((settings->'usage_quotas'->>'llm_monthly_cost_limit_eur')::numeric, 100) = 100
      AND COALESCE((settings->'usage_quotas'->>'llm_requests_per_day')::numeric, 100) = 100
      AND COALESCE((settings->'usage_quotas'->>'llm_tokens_per_month')::numeric, 500000) = 500000
      AND COALESCE((settings->'usage_quotas'->>'search_requests_per_day')::numeric, 50) = 50
      AND COALESCE((settings->'usage_quotas'->>'max_api_tokens_per_user')::numeric, 10) = 10
      AND COALESCE((settings->'usage_quotas'->>'max_webhooks_per_user')::numeric, 20) = 20
    )
  );