refactor(ai): update AI gateway and cost management features
- Refactored the AI gateway to use the new cost management structures for usage tracking.
- Replaced deprecated token extraction methods with a unified cost parsing approach.
- Enhanced usage fallback mechanisms and introduced detailed usage metrics in responses.
- Added new metering functionality to record AI usage and costs.
- Updated tests to reflect changes in usage parsing and cost calculations.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
This commit is contained in:
parent
71b716edba
commit
3978622050
18
internal/ai/cost/fingerprint.go
Normal file
18
internal/ai/cost/fingerprint.go
Normal file
@ -0,0 +1,18 @@
|
||||
package cost
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// KeyFingerprint derives a short, stable identifier for a provider/API-key
// pair so usage can be grouped per key without storing the key itself.
//
// Both arguments are whitespace-trimmed first. When the trimmed key is empty
// the result is "<providerID>:none"; otherwise it is the hex encoding of the
// first 8 bytes of SHA-256("<providerID>:<apiKey>"), i.e. 16 hex characters.
func KeyFingerprint(providerID, apiKey string) string {
	provider := strings.TrimSpace(providerID)
	key := strings.TrimSpace(apiKey)
	if len(key) == 0 {
		return provider + ":none"
	}

	hasher := sha256.New()
	hasher.Write([]byte(provider))
	hasher.Write([]byte(":"))
	hasher.Write([]byte(key))
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest[:8])
}
|
||||
149
internal/ai/cost/meter.go
Normal file
149
internal/ai/cost/meter.go
Normal file
@ -0,0 +1,149 @@
|
||||
package cost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Billing scopes stored with each usage event.
const (
	// BillingScopeOrg books the cost of a request to the organization totals.
	BillingScopeOrg = "org"
	// BillingScopeUser books the cost of a request to the user totals.
	BillingScopeUser = "user"
)
|
||||
|
||||
type RecordInput struct {
|
||||
ExternalUserID string
|
||||
Feature string
|
||||
ModelID string
|
||||
ProviderID string
|
||||
BillingScope string
|
||||
ProviderKeyFingerprint string
|
||||
Usage UsageDetail
|
||||
RequestID string
|
||||
}
|
||||
|
||||
type Meter struct {
|
||||
db *pgxpool.Pool
|
||||
pricing *PricingStore
|
||||
}
|
||||
|
||||
func NewMeter(db *pgxpool.Pool) *Meter {
|
||||
return &Meter{db: db, pricing: NewPricingStore(db)}
|
||||
}
|
||||
|
||||
func (m *Meter) RecordUsage(ctx context.Context, in RecordInput) error {
|
||||
if m.db == nil {
|
||||
return nil
|
||||
}
|
||||
userID, err := m.resolveUserID(ctx, in.ExternalUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
price, found := m.pricing.LookupPrice(ctx, in.ModelID)
|
||||
cost, estimated := ComputeCostMicroEUR(in.Usage, price, found)
|
||||
|
||||
tx, err := m.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO ai_usage_events (
|
||||
user_id, feature, model_id, provider_id, billing_scope,
|
||||
provider_key_fingerprint, prompt_tokens, completion_tokens,
|
||||
cached_input_tokens, reasoning_tokens, cost_micro_eur, estimated, request_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
`, userID, in.Feature, in.ModelID, in.ProviderID, in.BillingScope,
|
||||
in.ProviderKeyFingerprint, in.Usage.PromptTokens, in.Usage.CompletionTokens,
|
||||
in.Usage.CachedInputTokens, in.Usage.ReasoningTokens, cost, estimated, nullStr(in.RequestID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
tokens := int64(in.Usage.TotalTokens)
|
||||
orgCost := int64(0)
|
||||
userCost := int64(0)
|
||||
if in.BillingScope == BillingScopeOrg {
|
||||
orgCost = cost
|
||||
} else {
|
||||
userCost = cost
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens, cost_micro_eur_org, cost_micro_eur_user)
|
||||
VALUES ($1, $2, 1, $3, $4, $5)
|
||||
ON CONFLICT (user_id, usage_date) DO UPDATE SET
|
||||
requests = ai_usage_daily.requests + 1,
|
||||
tokens = ai_usage_daily.tokens + EXCLUDED.tokens,
|
||||
cost_micro_eur_org = ai_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
||||
cost_micro_eur_user = ai_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
|
||||
`, userID, today, tokens, orgCost, userCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO ai_usage_monthly (user_id, usage_month, tokens, cost_micro_eur_org, cost_micro_eur_user)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (user_id, usage_month) DO UPDATE SET
|
||||
tokens = ai_usage_monthly.tokens + EXCLUDED.tokens,
|
||||
cost_micro_eur_org = ai_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
||||
cost_micro_eur_user = ai_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
|
||||
`, userID, month, tokens, orgCost, userCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO ai_org_usage_daily (usage_date, cost_micro_eur_org, cost_micro_eur_user, requests)
|
||||
VALUES ($1, $2, $3, 1)
|
||||
ON CONFLICT (usage_date) DO UPDATE SET
|
||||
cost_micro_eur_org = ai_org_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
||||
cost_micro_eur_user = ai_org_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user,
|
||||
requests = ai_org_usage_daily.requests + 1
|
||||
`, today, orgCost, userCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO ai_org_usage_monthly (usage_month, cost_micro_eur_org, cost_micro_eur_user)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (usage_month) DO UPDATE SET
|
||||
cost_micro_eur_org = ai_org_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
||||
cost_micro_eur_user = ai_org_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
|
||||
`, month, orgCost, userCost)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (m *Meter) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
|
||||
var userID string
|
||||
err := m.db.QueryRow(ctx, `
|
||||
SELECT id::text FROM users WHERE external_id = $1
|
||||
`, externalUserID).Scan(&userID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", fmt.Errorf("user not found")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// nullStr converts an empty string to nil, so it is written as SQL NULL, and
// any other string to a pointer to its value.
func nullStr(s string) *string {
	if len(s) > 0 {
		return &s
	}
	return nil
}
|
||||
65
internal/ai/cost/parse.go
Normal file
65
internal/ai/cost/parse.go
Normal file
@ -0,0 +1,65 @@
|
||||
package cost
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// usagePayload mirrors the "usage" object of an OpenAI-compatible chat
// completion response. The two detail objects are optional and stay nil when
// the provider omits them.
type usagePayload struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`

	// PromptTokensDetails carries the cached share of the prompt tokens.
	PromptTokensDetails *struct {
		CachedTokens int `json:"cached_tokens"`
	} `json:"prompt_tokens_details,omitempty"`

	// CompletionTokensDetails carries the reasoning share of the completion tokens.
	CompletionTokensDetails *struct {
		ReasoningTokens int `json:"reasoning_tokens"`
	} `json:"completion_tokens_details,omitempty"`
}
|
||||
|
||||
type chatCompletionResponse struct {
|
||||
Usage *usagePayload `json:"usage,omitempty"`
|
||||
}
|
||||
|
||||
// ParseUsage extracts token details from an OpenAI-compatible chat completion payload.
|
||||
func ParseUsage(payload []byte) UsageDetail {
|
||||
var parsed chatCompletionResponse
|
||||
if err := json.Unmarshal(payload, &parsed); err != nil || parsed.Usage == nil {
|
||||
return UsageDetail{TotalTokens: 1}
|
||||
}
|
||||
return usageFromPayload(parsed.Usage)
|
||||
}
|
||||
|
||||
func usageFromPayload(u *usagePayload) UsageDetail {
|
||||
if u == nil {
|
||||
return UsageDetail{TotalTokens: 1}
|
||||
}
|
||||
detail := UsageDetail{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptTokensDetails != nil {
|
||||
detail.CachedInputTokens = u.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil {
|
||||
detail.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.PromptTokens + detail.CompletionTokens
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = 1
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
// MergeStreamUsage keeps the latest non-zero usage from streaming chunks.
|
||||
func MergeStreamUsage(acc UsageDetail, chunk []byte) UsageDetail {
|
||||
var parsed chatCompletionResponse
|
||||
if err := json.Unmarshal(chunk, &parsed); err != nil || parsed.Usage == nil {
|
||||
return acc
|
||||
}
|
||||
next := usageFromPayload(parsed.Usage)
|
||||
if next.TotalTokens == 0 && next.PromptTokens == 0 && next.CompletionTokens == 0 {
|
||||
return acc
|
||||
}
|
||||
return next
|
||||
}
|
||||
72
internal/ai/cost/parse_test.go
Normal file
72
internal/ai/cost/parse_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package cost
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseUsageOpenAI(t *testing.T) {
|
||||
payload := []byte(`{
|
||||
"usage": {
|
||||
"prompt_tokens": 1200,
|
||||
"completion_tokens": 340,
|
||||
"total_tokens": 1540,
|
||||
"prompt_tokens_details": {"cached_tokens": 800},
|
||||
"completion_tokens_details": {"reasoning_tokens": 50}
|
||||
}
|
||||
}`)
|
||||
u := ParseUsage(payload)
|
||||
if u.PromptTokens != 1200 || u.CompletionTokens != 340 {
|
||||
t.Fatalf("unexpected tokens: %+v", u)
|
||||
}
|
||||
if u.CachedInputTokens != 800 || u.ReasoningTokens != 50 {
|
||||
t.Fatalf("unexpected details: %+v", u)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseUsageFallback(t *testing.T) {
|
||||
u := ParseUsage([]byte(`{"choices":[]}`))
|
||||
if u.TotalTokens != 1 {
|
||||
t.Fatalf("expected fallback 1, got %d", u.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeCostMicroEUR(t *testing.T) {
|
||||
price := ModelPrice{
|
||||
InputMicroEURPerMTok: 1_000_000,
|
||||
OutputMicroEURPerMTok: 2_000_000,
|
||||
}
|
||||
usage := UsageDetail{
|
||||
PromptTokens: 1000,
|
||||
CompletionTokens: 500,
|
||||
CachedInputTokens: 200,
|
||||
TotalTokens: 1500,
|
||||
}
|
||||
cost, estimated := ComputeCostMicroEUR(usage, price, true)
|
||||
if estimated {
|
||||
t.Fatal("should not be estimated when price found")
|
||||
}
|
||||
// uncached 800 * 1 + cached 200 * 0.5 + output 500 * 2 = 800+100+1000 = 1900 micro (cached uses half input)
|
||||
if cost < 1800 || cost > 2000 {
|
||||
t.Fatalf("unexpected cost %d", cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeCostUnknownModel(t *testing.T) {
|
||||
usage := UsageDetail{TotalTokens: 1_000_000}
|
||||
cost, estimated := ComputeCostMicroEUR(usage, ModelPrice{}, false)
|
||||
if !estimated {
|
||||
t.Fatal("expected estimated")
|
||||
}
|
||||
if cost != fallbackInputMicroEURPerMTok {
|
||||
t.Fatalf("expected fallback cost %d, got %d", fallbackInputMicroEURPerMTok, cost)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMergeStreamUsage(t *testing.T) {
|
||||
acc := UsageDetail{}
|
||||
chunk1 := []byte(`{"usage":{"prompt_tokens":10,"completion_tokens":0,"total_tokens":10}}`)
|
||||
chunk2 := []byte(`{"usage":{"prompt_tokens":100,"completion_tokens":40,"total_tokens":140}}`)
|
||||
acc = MergeStreamUsage(acc, chunk1)
|
||||
acc = MergeStreamUsage(acc, chunk2)
|
||||
if acc.TotalTokens != 140 {
|
||||
t.Fatalf("expected 140, got %d", acc.TotalTokens)
|
||||
}
|
||||
}
|
||||
421
internal/ai/cost/policy.go
Normal file
421
internal/ai/cost/policy.go
Normal file
@ -0,0 +1,421 @@
|
||||
package cost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// ErrCostLimitExceeded is the sentinel wrapped by AssertAvailable when a
// daily or monthly cost limit has been reached; match it with errors.Is.
var ErrCostLimitExceeded = errors.New("llm cost limit exceeded")
|
||||
|
||||
// EffectiveLimits is the resolved cost policy that applies to one user.
type EffectiveLimits struct {
	// DailyLimitMicroEUR is the daily cap in micro-EUR; nil means no cap is set.
	DailyLimitMicroEUR *int64
	// MonthlyLimitMicroEUR is the monthly cap in micro-EUR; nil means no cap is set.
	MonthlyLimitMicroEUR *int64
	// WarnThresholdPct is the warning threshold as a percentage; the code in
	// this package defaults it to 80.
	WarnThresholdPct int
}
|
||||
|
||||
type SpendStatus struct {
|
||||
CostUsedTodayMicroEUR int64 `json:"cost_used_today_micro_eur"`
|
||||
CostLimitTodayMicroEUR *int64 `json:"cost_limit_today_micro_eur,omitempty"`
|
||||
CostUsedMonthMicroEUR int64 `json:"cost_used_month_micro_eur"`
|
||||
CostLimitMonthMicroEUR *int64 `json:"cost_limit_month_micro_eur,omitempty"`
|
||||
CostRemainingTodayMicroEUR *int64 `json:"cost_remaining_today_micro_eur,omitempty"`
|
||||
CostRemainingMonthMicroEUR *int64 `json:"cost_remaining_month_micro_eur,omitempty"`
|
||||
WarnThresholdPct int `json:"warn_threshold_pct"`
|
||||
Currency string `json:"currency"`
|
||||
BillingScopeOrg bool `json:"billing_scope_org"`
|
||||
ByProviderKeys []ProviderKeySpend `json:"by_provider_keys,omitempty"`
|
||||
|
||||
// Legacy fields (deprecated)
|
||||
RequestsUsedToday int `json:"requests_used_today"`
|
||||
RequestsLimit int `json:"requests_limit"`
|
||||
TokensUsedMonth int64 `json:"tokens_used_month"`
|
||||
TokensLimit int64 `json:"tokens_limit"`
|
||||
RequestsRemaining int `json:"requests_remaining"`
|
||||
TokensRemaining int64 `json:"tokens_remaining"`
|
||||
}
|
||||
|
||||
// ProviderKeySpend is one row of the per-key spend breakdown: the cost
// attributed to a provider key fingerprint within a billing scope.
type ProviderKeySpend struct {
	Fingerprint string `json:"fingerprint"`
	// Label is a display string built from the provider ID and the tail of
	// the fingerprint.
	Label string `json:"label"`
	// CostMonthMicroEUR and CostMonthEUR hold the same amount in two units.
	CostMonthMicroEUR int64   `json:"cost_month_micro_eur"`
	CostMonthEUR      float64 `json:"cost_month_eur"`
	BillingScope      string  `json:"billing_scope"`
}
|
||||
|
||||
type PolicyService struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPolicyService(db *pgxpool.Pool) *PolicyService {
|
||||
return &PolicyService{db: db}
|
||||
}
|
||||
|
||||
// policyCandidate is one policy considered while resolving a user's
// effective limits.
type policyCandidate struct {
	daily   *int64 // daily cap in micro-EUR; nil when unset
	monthly *int64 // monthly cap in micro-EUR; nil when unset
	warn    int    // warning threshold in percent
	// priority ranks the policy's origin: 0 = org, 10 = group, 20 = user.
	priority int
}
|
||||
|
||||
func (s *PolicyService) ResolveEffectiveLimits(ctx context.Context, userID string) (EffectiveLimits, error) {
|
||||
limits := EffectiveLimits{WarnThresholdPct: 80}
|
||||
if s.db == nil {
|
||||
return limits, nil
|
||||
}
|
||||
|
||||
var candidates []policyCandidate
|
||||
|
||||
// Org policy (priority 0)
|
||||
org, err := s.loadPolicy(ctx, "org", "")
|
||||
if err == nil && org != nil {
|
||||
candidates = append(candidates, policyCandidate{
|
||||
daily: org.daily, monthly: org.monthly, warn: org.warn, priority: 0,
|
||||
})
|
||||
}
|
||||
|
||||
// Group policies (priority 10) — most restrictive wins among groups
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct
|
||||
FROM ai_cost_policies p
|
||||
JOIN user_group_members ugm ON ugm.group_id = p.scope_id
|
||||
WHERE p.scope_type = 'group' AND ugm.user_id = $1::uuid
|
||||
`, userID)
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var c policyCandidate
|
||||
c.priority = 10
|
||||
if err := rows.Scan(&c.daily, &c.monthly, &c.warn); err != nil {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, c)
|
||||
}
|
||||
}
|
||||
|
||||
// User policy (priority 20)
|
||||
userPol, err := s.loadPolicy(ctx, "user", userID)
|
||||
if err == nil && userPol != nil {
|
||||
candidates = append(candidates, policyCandidate{
|
||||
daily: userPol.daily, monthly: userPol.monthly, warn: userPol.warn, priority: 20,
|
||||
})
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return limits, nil
|
||||
}
|
||||
|
||||
// User override wins if present; else merge group+org with most restrictive
|
||||
var userCandidate *policyCandidate
|
||||
var others []policyCandidate
|
||||
for i := range candidates {
|
||||
if candidates[i].priority >= 20 {
|
||||
userCandidate = &candidates[i]
|
||||
} else {
|
||||
others = append(others, candidates[i])
|
||||
}
|
||||
}
|
||||
|
||||
var merged policyCandidate
|
||||
if userCandidate != nil {
|
||||
merged = *userCandidate
|
||||
} else {
|
||||
merged = mergeMostRestrictive(others)
|
||||
}
|
||||
|
||||
limits.DailyLimitMicroEUR = merged.daily
|
||||
limits.MonthlyLimitMicroEUR = merged.monthly
|
||||
if merged.warn > 0 {
|
||||
limits.WarnThresholdPct = merged.warn
|
||||
}
|
||||
return limits, nil
|
||||
}
|
||||
|
||||
// policyRow holds the limit columns of a single ai_cost_policies row.
type policyRow struct {
	daily   *int64 // daily_limit_micro_eur; nil when the column is NULL
	monthly *int64 // monthly_limit_micro_eur; nil when the column is NULL
	warn    int    // warn_threshold_pct
}
|
||||
|
||||
func (s *PolicyService) loadPolicy(ctx context.Context, scopeType, scopeID string) (*policyRow, error) {
|
||||
var daily, monthly *int64
|
||||
var warn int
|
||||
var err error
|
||||
if scopeType == "org" {
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
|
||||
FROM ai_cost_policies WHERE scope_type = 'org' LIMIT 1
|
||||
`).Scan(&daily, &monthly, &warn)
|
||||
} else {
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct
|
||||
FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
|
||||
`, scopeType, scopeID).Scan(&daily, &monthly, &warn)
|
||||
}
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &policyRow{daily: daily, monthly: monthly, warn: warn}, nil
|
||||
}
|
||||
|
||||
func mergeMostRestrictive(candidates []policyCandidate) policyCandidate {
|
||||
if len(candidates) == 0 {
|
||||
return policyCandidate{warn: 80}
|
||||
}
|
||||
out := candidates[0]
|
||||
for _, c := range candidates[1:] {
|
||||
if c.daily != nil && (out.daily == nil || *c.daily < *out.daily) {
|
||||
out.daily = c.daily
|
||||
}
|
||||
if c.monthly != nil && (out.monthly == nil || *c.monthly < *out.monthly) {
|
||||
out.monthly = c.monthly
|
||||
}
|
||||
if c.warn > 0 && c.warn < out.warn {
|
||||
out.warn = c.warn
|
||||
}
|
||||
}
|
||||
if out.warn == 0 {
|
||||
out.warn = 80
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *PolicyService) GetStatus(ctx context.Context, externalUserID string, orgBilling bool) (SpendStatus, error) {
|
||||
userID, err := s.resolveUserID(ctx, externalUserID)
|
||||
if err != nil {
|
||||
return SpendStatus{}, err
|
||||
}
|
||||
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
var dailyOrg, dailyUser int64
|
||||
var requestsToday int
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(requests, 0)
|
||||
FROM ai_usage_daily WHERE user_id = $1::uuid AND usage_date = $2
|
||||
`, userID, today).Scan(&dailyOrg, &dailyUser, &requestsToday)
|
||||
|
||||
var monthlyOrg, monthlyUser, tokensMonth int64
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0), COALESCE(tokens, 0)
|
||||
FROM ai_usage_monthly WHERE user_id = $1::uuid AND usage_month = $2
|
||||
`, userID, month).Scan(&monthlyOrg, &monthlyUser, &tokensMonth)
|
||||
|
||||
limits, _ := s.ResolveEffectiveLimits(ctx, userID)
|
||||
|
||||
status := SpendStatus{
|
||||
Currency: "EUR",
|
||||
BillingScopeOrg: orgBilling,
|
||||
WarnThresholdPct: limits.WarnThresholdPct,
|
||||
RequestsUsedToday: requestsToday,
|
||||
TokensUsedMonth: tokensMonth,
|
||||
}
|
||||
|
||||
if orgBilling {
|
||||
status.CostUsedTodayMicroEUR = dailyOrg
|
||||
status.CostUsedMonthMicroEUR = monthlyOrg
|
||||
status.CostLimitTodayMicroEUR = limits.DailyLimitMicroEUR
|
||||
status.CostLimitMonthMicroEUR = limits.MonthlyLimitMicroEUR
|
||||
} else {
|
||||
status.CostUsedTodayMicroEUR = dailyUser
|
||||
status.CostUsedMonthMicroEUR = monthlyUser
|
||||
}
|
||||
|
||||
status.CostRemainingTodayMicroEUR = remaining(limits.DailyLimitMicroEUR, status.CostUsedTodayMicroEUR)
|
||||
status.CostRemainingMonthMicroEUR = remaining(limits.MonthlyLimitMicroEUR, status.CostUsedMonthMicroEUR)
|
||||
|
||||
keys, _ := s.providerKeyBreakdown(ctx, userID, month)
|
||||
if keys == nil {
|
||||
keys = []ProviderKeySpend{}
|
||||
}
|
||||
status.ByProviderKeys = keys
|
||||
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *PolicyService) AssertAvailable(ctx context.Context, externalUserID string, billingScope string) error {
|
||||
if billingScope == BillingScopeUser {
|
||||
return nil
|
||||
}
|
||||
userID, err := s.resolveUserID(ctx, externalUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
limits, err := s.ResolveEffectiveLimits(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
var dailyOrg, monthlyOrg int64
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_daily
|
||||
WHERE user_id = $1::uuid AND usage_date = $2
|
||||
`, userID, today).Scan(&dailyOrg)
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_usage_monthly
|
||||
WHERE user_id = $1::uuid AND usage_month = $2
|
||||
`, userID, month).Scan(&monthlyOrg)
|
||||
|
||||
if limits.DailyLimitMicroEUR != nil && *limits.DailyLimitMicroEUR > 0 && dailyOrg >= *limits.DailyLimitMicroEUR {
|
||||
return fmt.Errorf("%w: daily cost limit reached", ErrCostLimitExceeded)
|
||||
}
|
||||
if limits.MonthlyLimitMicroEUR != nil && *limits.MonthlyLimitMicroEUR > 0 && monthlyOrg >= *limits.MonthlyLimitMicroEUR {
|
||||
return fmt.Errorf("%w: monthly cost limit reached", ErrCostLimitExceeded)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PolicyService) providerKeyBreakdown(ctx context.Context, userID string, month time.Time) ([]ProviderKeySpend, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT provider_key_fingerprint, billing_scope,
|
||||
COALESCE(SUM(cost_micro_eur), 0),
|
||||
MAX(provider_id)
|
||||
FROM ai_usage_events
|
||||
WHERE user_id = $1::uuid AND created_at >= $2
|
||||
GROUP BY provider_key_fingerprint, billing_scope
|
||||
ORDER BY SUM(cost_micro_eur) DESC
|
||||
LIMIT 20
|
||||
`, userID, month)
|
||||
if err != nil {
|
||||
return []ProviderKeySpend{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
out := make([]ProviderKeySpend, 0)
|
||||
for rows.Next() {
|
||||
var item ProviderKeySpend
|
||||
var providerID string
|
||||
if err := rows.Scan(&item.Fingerprint, &item.BillingScope, &item.CostMonthMicroEUR, &providerID); err != nil {
|
||||
continue
|
||||
}
|
||||
item.CostMonthEUR = MicroEURToEUR(item.CostMonthMicroEUR)
|
||||
suffix := item.Fingerprint
|
||||
if len(suffix) > 8 {
|
||||
suffix = suffix[len(suffix)-8:]
|
||||
}
|
||||
item.Label = providerID + " ···" + suffix
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *PolicyService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
|
||||
var userID string
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id::text FROM users WHERE external_id = $1
|
||||
`, externalUserID).Scan(&userID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", fmt.Errorf("user not found")
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
// remaining returns how much of limit is left after used has been spent,
// floored at zero. It returns nil when there is no effective limit, i.e.
// limit is nil or not positive.
func remaining(limit *int64, used int64) *int64 {
	if limit == nil {
		return nil
	}
	ceiling := *limit
	if ceiling <= 0 {
		return nil
	}

	left := ceiling - used
	if left < 0 {
		left = 0
	}
	return &left
}
|
||||
|
||||
// UpsertOrgPolicy updates the org-level cost policy.
|
||||
func (s *PolicyService) UpsertOrgPolicy(ctx context.Context, daily, monthly *int64, warnPct int) error {
|
||||
if warnPct <= 0 {
|
||||
warnPct = 80
|
||||
}
|
||||
var exists bool
|
||||
if err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org')`).Scan(&exists); err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
_, err := s.db.Exec(ctx, `
|
||||
UPDATE ai_cost_policies SET
|
||||
daily_limit_micro_eur = $1,
|
||||
monthly_limit_micro_eur = $2,
|
||||
warn_threshold_pct = $3,
|
||||
updated_at = NOW()
|
||||
WHERE scope_type = 'org'
|
||||
`, daily, monthly, warnPct)
|
||||
return err
|
||||
}
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
|
||||
VALUES ('org', NULL, $1, $2, $3, 0)
|
||||
`, daily, monthly, warnPct)
|
||||
return err
|
||||
}
|
||||
|
||||
// UpsertScopePolicy updates group or user policy.
|
||||
func (s *PolicyService) UpsertScopePolicy(ctx context.Context, scopeType, scopeID string, daily, monthly *int64, warnPct int) error {
|
||||
if warnPct <= 0 {
|
||||
warnPct = 80
|
||||
}
|
||||
priority := 10
|
||||
if scopeType == "user" {
|
||||
priority = 20
|
||||
}
|
||||
var exists bool
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT EXISTS(SELECT 1 FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid)
|
||||
`, scopeType, scopeID).Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE ai_cost_policies SET
|
||||
daily_limit_micro_eur = $1,
|
||||
monthly_limit_micro_eur = $2,
|
||||
warn_threshold_pct = $3,
|
||||
updated_at = NOW()
|
||||
WHERE scope_type = $4 AND scope_id = $5::uuid
|
||||
`, daily, monthly, warnPct, scopeType, scopeID)
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
|
||||
VALUES ($1, $2::uuid, $3, $4, $5, $6)
|
||||
`, scopeType, scopeID, daily, monthly, warnPct, priority)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *PolicyService) DeleteScopePolicy(ctx context.Context, scopeType, scopeID string) error {
|
||||
if scopeType == "org" {
|
||||
return fmt.Errorf("cannot delete org policy")
|
||||
}
|
||||
_, err := s.db.Exec(ctx, `
|
||||
DELETE FROM ai_cost_policies WHERE scope_type = $1 AND scope_id = $2::uuid
|
||||
`, scopeType, scopeID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *PolicyService) GetOrgPolicy(ctx context.Context) (EffectiveLimits, error) {
|
||||
row, err := s.loadPolicy(ctx, "org", "")
|
||||
if err != nil {
|
||||
return EffectiveLimits{WarnThresholdPct: 80}, err
|
||||
}
|
||||
if row == nil {
|
||||
return EffectiveLimits{WarnThresholdPct: 80}, nil
|
||||
}
|
||||
return EffectiveLimits{
|
||||
DailyLimitMicroEUR: row.daily,
|
||||
MonthlyLimitMicroEUR: row.monthly,
|
||||
WarnThresholdPct: row.warn,
|
||||
}, nil
|
||||
}
|
||||
31
internal/ai/cost/policy_test.go
Normal file
31
internal/ai/cost/policy_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package cost
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMergeMostRestrictive(t *testing.T) {
|
||||
d10 := int64(10_000_000)
|
||||
d5 := int64(5_000_000)
|
||||
m100 := int64(100_000_000)
|
||||
m50 := int64(50_000_000)
|
||||
merged := mergeMostRestrictive([]policyCandidate{
|
||||
{daily: &d10, monthly: &m100, warn: 80},
|
||||
{daily: &d5, monthly: &m50, warn: 70},
|
||||
})
|
||||
if merged.daily == nil || *merged.daily != d5 {
|
||||
t.Fatalf("expected daily 5M, got %v", merged.daily)
|
||||
}
|
||||
if merged.monthly == nil || *merged.monthly != m50 {
|
||||
t.Fatalf("expected monthly 50M, got %v", merged.monthly)
|
||||
}
|
||||
if merged.warn != 70 {
|
||||
t.Fatalf("expected warn 70, got %d", merged.warn)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMicroEURToEUR(t *testing.T) {
|
||||
if v := MicroEURToEUR(1_500_000); v != 1.5 {
|
||||
t.Fatalf("expected 1.5, got %f", v)
|
||||
}
|
||||
}
|
||||
206
internal/ai/cost/pricing.go
Normal file
206
internal/ai/cost/pricing.go
Normal file
@ -0,0 +1,206 @@
|
||||
package cost
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// ModelPrice is the price list of one model. All rates are in micro-EUR per
// one million tokens.
type ModelPrice struct {
	ModelID      string
	ProviderType string

	// InputMicroEURPerMTok and OutputMicroEURPerMTok are the base rates.
	InputMicroEURPerMTok  int64
	OutputMicroEURPerMTok int64

	// CachedInputMicroEURPerMTok and ReasoningMicroEURPerMTok are optional;
	// zero means "not set" and a substitute rate is used when computing cost.
	CachedInputMicroEURPerMTok int64
	ReasoningMicroEURPerMTok   int64
}
|
||||
|
||||
// Fallback rates, in micro-EUR per million tokens, applied when a model has
// no stored price.
const (
	// fallbackInputMicroEURPerMTok approximates the gpt-4o-mini input rate.
	fallbackInputMicroEURPerMTok int64 = 140000
	// fallbackOutputMicroEURPerMTok is the matching output rate.
	fallbackOutputMicroEURPerMTok int64 = 560000
)
|
||||
|
||||
type PricingStore struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewPricingStore(db *pgxpool.Pool) *PricingStore {
|
||||
return &PricingStore{db: db}
|
||||
}
|
||||
|
||||
func (s *PricingStore) LookupPrice(ctx context.Context, modelID string) (ModelPrice, bool) {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" || s.db == nil {
|
||||
return ModelPrice{}, false
|
||||
}
|
||||
var p ModelPrice
|
||||
var cached, reasoning *int64
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT model_id, provider_type,
|
||||
input_micro_eur_per_mtok,
|
||||
cached_input_micro_eur_per_mtok,
|
||||
output_micro_eur_per_mtok,
|
||||
reasoning_micro_eur_per_mtok
|
||||
FROM ai_model_pricing
|
||||
WHERE model_id = $1 AND effective_from <= CURRENT_DATE
|
||||
ORDER BY effective_from DESC
|
||||
LIMIT 1
|
||||
`, modelID).Scan(
|
||||
&p.ModelID, &p.ProviderType,
|
||||
&p.InputMicroEURPerMTok, &cached,
|
||||
&p.OutputMicroEURPerMTok, &reasoning,
|
||||
)
|
||||
if err != nil {
|
||||
if err != pgx.ErrNoRows {
|
||||
// prefix match for bedrock/azure model ids
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT model_id, provider_type,
|
||||
input_micro_eur_per_mtok,
|
||||
cached_input_micro_eur_per_mtok,
|
||||
output_micro_eur_per_mtok,
|
||||
reasoning_micro_eur_per_mtok
|
||||
FROM ai_model_pricing
|
||||
WHERE $1 LIKE model_id || '%' AND effective_from <= CURRENT_DATE
|
||||
ORDER BY LENGTH(model_id) DESC, effective_from DESC
|
||||
LIMIT 1
|
||||
`, modelID).Scan(
|
||||
&p.ModelID, &p.ProviderType,
|
||||
&p.InputMicroEURPerMTok, &cached,
|
||||
&p.OutputMicroEURPerMTok, &reasoning,
|
||||
)
|
||||
}
|
||||
if err != nil {
|
||||
return ModelPrice{}, false
|
||||
}
|
||||
}
|
||||
if cached != nil {
|
||||
p.CachedInputMicroEURPerMTok = *cached
|
||||
}
|
||||
if reasoning != nil {
|
||||
p.ReasoningMicroEURPerMTok = *reasoning
|
||||
}
|
||||
return p, true
|
||||
}
|
||||
|
||||
// ComputeCostMicroEUR calculates estimated cost from usage and model pricing.
|
||||
func ComputeCostMicroEUR(usage UsageDetail, price ModelPrice, found bool) (microEUR int64, estimated bool) {
|
||||
if !found {
|
||||
price = ModelPrice{
|
||||
InputMicroEURPerMTok: fallbackInputMicroEURPerMTok,
|
||||
OutputMicroEURPerMTok: fallbackOutputMicroEURPerMTok,
|
||||
}
|
||||
estimated = true
|
||||
if usage.PromptTokens == 0 && usage.CompletionTokens == 0 {
|
||||
microEUR = int64(usage.TotalTokens) * fallbackInputMicroEURPerMTok / 1_000_000
|
||||
if microEUR == 0 && usage.TotalTokens > 0 {
|
||||
microEUR = 1
|
||||
}
|
||||
return microEUR, true
|
||||
}
|
||||
}
|
||||
|
||||
cachedRate := price.CachedInputMicroEURPerMTok
|
||||
if cachedRate == 0 {
|
||||
cachedRate = price.InputMicroEURPerMTok / 2
|
||||
}
|
||||
reasoningRate := price.ReasoningMicroEURPerMTok
|
||||
if reasoningRate == 0 {
|
||||
reasoningRate = price.OutputMicroEURPerMTok
|
||||
}
|
||||
|
||||
uncached := usage.UncachedInputTokens()
|
||||
microEUR += int64(uncached) * price.InputMicroEURPerMTok / 1_000_000
|
||||
microEUR += int64(usage.CachedInputTokens) * cachedRate / 1_000_000
|
||||
completion := usage.CompletionTokens - usage.ReasoningTokens
|
||||
if completion < 0 {
|
||||
completion = usage.CompletionTokens
|
||||
}
|
||||
microEUR += int64(completion) * price.OutputMicroEURPerMTok / 1_000_000
|
||||
microEUR += int64(usage.ReasoningTokens) * reasoningRate / 1_000_000
|
||||
|
||||
if microEUR == 0 && usage.TotalTokens > 0 {
|
||||
microEUR = 1
|
||||
}
|
||||
return microEUR, estimated
|
||||
}
|
||||
|
||||
// UpsertModelPrice stores or updates pricing for a model (effective today).
|
||||
func (s *PricingStore) UpsertModelPrice(ctx context.Context, p ModelPrice) error {
|
||||
if s.db == nil {
|
||||
return nil
|
||||
}
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
var cached, reasoning *int64
|
||||
if p.CachedInputMicroEURPerMTok > 0 {
|
||||
v := p.CachedInputMicroEURPerMTok
|
||||
cached = &v
|
||||
}
|
||||
if p.ReasoningMicroEURPerMTok > 0 {
|
||||
v := p.ReasoningMicroEURPerMTok
|
||||
reasoning = &v
|
||||
}
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO ai_model_pricing (
|
||||
model_id, provider_type,
|
||||
input_micro_eur_per_mtok, cached_input_micro_eur_per_mtok,
|
||||
output_micro_eur_per_mtok, reasoning_micro_eur_per_mtok,
|
||||
effective_from, source
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, 'manual')
|
||||
ON CONFLICT (model_id, effective_from) DO UPDATE SET
|
||||
provider_type = EXCLUDED.provider_type,
|
||||
input_micro_eur_per_mtok = EXCLUDED.input_micro_eur_per_mtok,
|
||||
cached_input_micro_eur_per_mtok = EXCLUDED.cached_input_micro_eur_per_mtok,
|
||||
output_micro_eur_per_mtok = EXCLUDED.output_micro_eur_per_mtok,
|
||||
reasoning_micro_eur_per_mtok = EXCLUDED.reasoning_micro_eur_per_mtok,
|
||||
source = EXCLUDED.source
|
||||
`, p.ModelID, p.ProviderType, p.InputMicroEURPerMTok, cached,
|
||||
p.OutputMicroEURPerMTok, reasoning, today)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *PricingStore) ListPrices(ctx context.Context) ([]ModelPrice, error) {
|
||||
if s.db == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT DISTINCT ON (model_id)
|
||||
model_id, provider_type,
|
||||
input_micro_eur_per_mtok,
|
||||
cached_input_micro_eur_per_mtok,
|
||||
output_micro_eur_per_mtok,
|
||||
reasoning_micro_eur_per_mtok
|
||||
FROM ai_model_pricing
|
||||
WHERE effective_from <= CURRENT_DATE
|
||||
ORDER BY model_id, effective_from DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []ModelPrice
|
||||
for rows.Next() {
|
||||
var p ModelPrice
|
||||
var cached, reasoning *int64
|
||||
if err := rows.Scan(&p.ModelID, &p.ProviderType,
|
||||
&p.InputMicroEURPerMTok, &cached,
|
||||
&p.OutputMicroEURPerMTok, &reasoning); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cached != nil {
|
||||
p.CachedInputMicroEURPerMTok = *cached
|
||||
}
|
||||
if reasoning != nil {
|
||||
p.ReasoningMicroEURPerMTok = *reasoning
|
||||
}
|
||||
out = append(out, p)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
// MicroEURToEUR converts an amount in micro-EUR (millionths of a euro) to a
// floating-point EUR value for API responses.
func MicroEURToEUR(micro int64) float64 {
	return float64(micro) / 1_000_000
}
|
||||
18
internal/ai/cost/usage.go
Normal file
18
internal/ai/cost/usage.go
Normal file
@ -0,0 +1,18 @@
|
||||
package cost
|
||||
|
||||
// UsageDetail holds token counts from a provider response.
type UsageDetail struct {
	PromptTokens      int
	CompletionTokens  int
	CachedInputTokens int
	ReasoningTokens   int
	TotalTokens       int
}

// UncachedInputTokens returns the prompt tokens that were not served from the
// provider's cache, i.e. PromptTokens minus CachedInputTokens.
//
// If the cached count exceeds the prompt count the subtraction would be
// negative; in that case the full PromptTokens value is returned instead.
func (u UsageDetail) UncachedInputTokens() int {
	uncached := u.PromptTokens - u.CachedInputTokens
	if uncached < 0 {
		return u.PromptTokens
	}
	return uncached
}
|
||||
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai/cost"
|
||||
"github.com/ultisuite/ulti-backend/internal/llm"
|
||||
)
|
||||
|
||||
@ -40,12 +41,6 @@ type chatCompletionRequest struct {
|
||||
Tools []any `json:"tools,omitempty"`
|
||||
}
|
||||
|
||||
type usagePayload struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
}
|
||||
|
||||
type chatCompletionResponse struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
@ -57,7 +52,17 @@ type chatCompletionResponse struct {
|
||||
FinishReason string `json:"finish_reason"`
|
||||
Delta *llm.ChatMessage `json:"delta,omitempty"`
|
||||
} `json:"choices"`
|
||||
Usage *usagePayload `json:"usage,omitempty"`
|
||||
Usage *struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
} `json:"completion_tokens_details,omitempty"`
|
||||
} `json:"usage,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
@ -116,12 +121,6 @@ func (g *Gateway) listModelsFromSettings(ctx context.Context, settings llm.Setti
|
||||
}
|
||||
|
||||
func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID string, useOrgSettings bool, body []byte, w http.ResponseWriter) error {
|
||||
if strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
if err := g.quota.AssertAvailable(ctx, quotaExternalUserID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
var modelProbe struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
@ -151,6 +150,14 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID
|
||||
return err
|
||||
}
|
||||
|
||||
billingScope := ResolveBillingScope(ctx, g.db, quotaExternalUserID, provider, useOrgSettings)
|
||||
|
||||
if strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
if err := g.quota.AssertAvailable(ctx, quotaExternalUserID, provider, useOrgSettings); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
upstreamBody, err := repairChatCompletionBody(body)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -182,7 +189,7 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID
|
||||
defer resp.Body.Close()
|
||||
|
||||
if stream {
|
||||
return g.proxyStream(ctx, quotaExternalUserID, w, resp)
|
||||
return g.proxyStream(ctx, quotaExternalUserID, model, provider, billingScope, w, resp)
|
||||
}
|
||||
payload, err := io.ReadAll(io.LimitReader(resp.Body, 8<<20))
|
||||
if err != nil {
|
||||
@ -195,13 +202,21 @@ func (g *Gateway) ProxyChatCompletions(ctx context.Context, quotaExternalUserID
|
||||
return nil
|
||||
}
|
||||
if strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
tokens := extractUsageTokens(payload)
|
||||
_ = g.quota.Record(ctx, quotaExternalUserID, tokens)
|
||||
usage := cost.ParseUsage(payload)
|
||||
_ = g.quota.RecordUsage(ctx, cost.RecordInput{
|
||||
ExternalUserID: quotaExternalUserID,
|
||||
Feature: "gateway",
|
||||
ModelID: model,
|
||||
ProviderID: provider.ID,
|
||||
BillingScope: billingScope,
|
||||
ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
|
||||
Usage: usage,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w http.ResponseWriter, resp *http.Response) error {
|
||||
func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID, model string, provider llm.Provider, billingScope string, w http.ResponseWriter, resp *http.Response) error {
|
||||
rc := http.NewResponseController(w)
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
@ -210,7 +225,7 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w
|
||||
w.WriteHeader(resp.StatusCode)
|
||||
|
||||
reader := bufio.NewReader(resp.Body)
|
||||
var totalTokens int64
|
||||
var usage cost.UsageDetail
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
if len(line) > 0 {
|
||||
@ -219,7 +234,8 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w
|
||||
return fmt.Errorf("streaming not supported: %w", err)
|
||||
}
|
||||
if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
|
||||
totalTokens += extractStreamUsageTokens([]byte(strings.TrimPrefix(strings.TrimSpace(line), "data: ")))
|
||||
chunk := []byte(strings.TrimPrefix(strings.TrimSpace(line), "data: "))
|
||||
usage = cost.MergeStreamUsage(usage, chunk)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@ -235,10 +251,18 @@ func (g *Gateway) proxyStream(ctx context.Context, quotaExternalUserID string, w
|
||||
}
|
||||
}
|
||||
if resp.StatusCode < 400 && strings.TrimSpace(quotaExternalUserID) != "" {
|
||||
if totalTokens == 0 {
|
||||
totalTokens = 1
|
||||
if usage.TotalTokens == 0 {
|
||||
usage.TotalTokens = 1
|
||||
}
|
||||
_ = g.quota.Record(ctx, quotaExternalUserID, totalTokens)
|
||||
_ = g.quota.RecordUsage(ctx, cost.RecordInput{
|
||||
ExternalUserID: quotaExternalUserID,
|
||||
Feature: "gateway",
|
||||
ModelID: model,
|
||||
ProviderID: provider.ID,
|
||||
BillingScope: billingScope,
|
||||
ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
|
||||
Usage: usage,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -262,31 +286,6 @@ func resolveProviderForModel(settings llm.Settings, model string) (llm.Provider,
|
||||
return provider, resolvedModel, nil
|
||||
}
|
||||
|
||||
func extractUsageTokens(payload []byte) int64 {
|
||||
var parsed chatCompletionResponse
|
||||
if err := json.Unmarshal(payload, &parsed); err != nil {
|
||||
return 1
|
||||
}
|
||||
if parsed.Usage != nil && parsed.Usage.TotalTokens > 0 {
|
||||
return int64(parsed.Usage.TotalTokens)
|
||||
}
|
||||
if parsed.Usage != nil && parsed.Usage.CompletionTokens > 0 {
|
||||
return int64(parsed.Usage.CompletionTokens)
|
||||
}
|
||||
return 1
|
||||
}
|
||||
|
||||
func extractStreamUsageTokens(payload []byte) int64 {
|
||||
var parsed chatCompletionResponse
|
||||
if err := json.Unmarshal(payload, &parsed); err != nil {
|
||||
return 0
|
||||
}
|
||||
if parsed.Usage != nil && parsed.Usage.TotalTokens > 0 {
|
||||
return int64(parsed.Usage.TotalTokens)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func NowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
|
||||
@ -1,16 +1,22 @@
|
||||
package ai
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
func TestExtractUsageTokens(t *testing.T) {
|
||||
"github.com/ultisuite/ulti-backend/internal/ai/cost"
|
||||
)
|
||||
|
||||
func TestParseUsageViaCost(t *testing.T) {
|
||||
payload := []byte(`{"usage":{"total_tokens":42,"completion_tokens":10}}`)
|
||||
if got := extractUsageTokens(payload); got != 42 {
|
||||
t.Fatalf("extractUsageTokens() = %d, want 42", got)
|
||||
u := cost.ParseUsage(payload)
|
||||
if u.TotalTokens != 42 {
|
||||
t.Fatalf("ParseUsage() = %d, want 42", u.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractUsageTokensFallback(t *testing.T) {
|
||||
if got := extractUsageTokens([]byte(`{"choices":[]}`)); got != 1 {
|
||||
t.Fatalf("expected fallback token count 1, got %d", got)
|
||||
func TestParseUsageFallback(t *testing.T) {
|
||||
u := cost.ParseUsage([]byte(`{"choices":[]}`))
|
||||
if u.TotalTokens != 1 {
|
||||
t.Fatalf("expected fallback token count 1, got %d", u.TotalTokens)
|
||||
}
|
||||
}
|
||||
|
||||
34
internal/ai/metering.go
Normal file
34
internal/ai/metering.go
Normal file
@ -0,0 +1,34 @@
|
||||
package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai/cost"
|
||||
"github.com/ultisuite/ulti-backend/internal/llm"
|
||||
)
|
||||
|
||||
// RecordFeatureUsage meters an LLM call from a non-gateway feature.
|
||||
func RecordFeatureUsage(ctx context.Context, db *pgxpool.Pool, externalUserID, feature, modelID string, provider llm.Provider, usage llm.UsageDetail) {
|
||||
if db == nil || externalUserID == "" {
|
||||
return
|
||||
}
|
||||
q := NewQuotaService(db)
|
||||
scope := ResolveBillingScope(ctx, db, externalUserID, provider, false)
|
||||
_ = q.RecordUsage(ctx, cost.RecordInput{
|
||||
ExternalUserID: externalUserID,
|
||||
Feature: feature,
|
||||
ModelID: modelID,
|
||||
ProviderID: provider.ID,
|
||||
BillingScope: scope,
|
||||
ProviderKeyFingerprint: cost.KeyFingerprint(provider.ID, provider.APIKey),
|
||||
Usage: cost.UsageDetail{
|
||||
PromptTokens: usage.PromptTokens,
|
||||
CompletionTokens: usage.CompletionTokens,
|
||||
CachedInputTokens: usage.CachedInputTokens,
|
||||
ReasoningTokens: usage.ReasoningTokens,
|
||||
TotalTokens: usage.TotalTokens,
|
||||
},
|
||||
})
|
||||
}
|
||||
@ -229,7 +229,7 @@ func ResolveDefaultModel(ctx context.Context, db *pgxpool.Pool, policy Assistant
|
||||
}
|
||||
|
||||
func LoadQuotaLimits(ctx context.Context, db *pgxpool.Pool) (QuotaLimits, error) {
|
||||
defaults := QuotaLimits{RequestsPerDay: 100, TokensPerMonth: 500_000}
|
||||
defaults := QuotaLimits{RequestsPerDay: 75, TokensPerMonth: 2_000_000}
|
||||
if db == nil {
|
||||
return defaults, nil
|
||||
}
|
||||
|
||||
@ -2,123 +2,95 @@ package ai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai/cost"
|
||||
"github.com/ultisuite/ulti-backend/internal/llm"
|
||||
)
|
||||
|
||||
var ErrQuotaExceeded = errors.New("llm quota exceeded")
|
||||
|
||||
// QuotaService wraps cost policy and metering for backward compatibility.
|
||||
type QuotaService struct {
|
||||
db *pgxpool.Pool
|
||||
db *pgxpool.Pool
|
||||
policy *cost.PolicyService
|
||||
meter *cost.Meter
|
||||
}
|
||||
|
||||
func NewQuotaService(db *pgxpool.Pool) *QuotaService {
|
||||
return &QuotaService{db: db}
|
||||
return &QuotaService{
|
||||
db: db,
|
||||
policy: cost.NewPolicyService(db),
|
||||
meter: cost.NewMeter(db),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *QuotaService) Check(ctx context.Context, externalUserID string) (QuotaStatus, error) {
|
||||
limits, err := LoadQuotaLimits(ctx, s.db)
|
||||
func (s *QuotaService) Check(ctx context.Context, externalUserID string) (SpendStatus, error) {
|
||||
orgBilling, _ := s.usesOrgBilling(ctx, externalUserID)
|
||||
status, err := s.policy.GetStatus(ctx, externalUserID, orgBilling)
|
||||
if err != nil {
|
||||
return QuotaStatus{}, err
|
||||
}
|
||||
userID, err := s.resolveUserID(ctx, externalUserID)
|
||||
if err != nil {
|
||||
return QuotaStatus{}, err
|
||||
}
|
||||
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
var requestsToday int
|
||||
var tokensMonth int64
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(requests, 0) FROM ai_usage_daily
|
||||
WHERE user_id = $1 AND usage_date = $2
|
||||
`, userID, today).Scan(&requestsToday)
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(tokens, 0) FROM ai_usage_monthly
|
||||
WHERE user_id = $1 AND usage_month = $2
|
||||
`, userID, month).Scan(&tokensMonth)
|
||||
|
||||
status := QuotaStatus{
|
||||
RequestsUsedToday: requestsToday,
|
||||
RequestsLimit: limits.RequestsPerDay,
|
||||
TokensUsedMonth: tokensMonth,
|
||||
TokensLimit: limits.TokensPerMonth,
|
||||
}
|
||||
if limits.RequestsPerDay > 0 {
|
||||
status.RequestsRemaining = limits.RequestsPerDay - requestsToday
|
||||
if status.RequestsRemaining < 0 {
|
||||
status.RequestsRemaining = 0
|
||||
}
|
||||
}
|
||||
if limits.TokensPerMonth > 0 {
|
||||
status.TokensRemaining = limits.TokensPerMonth - tokensMonth
|
||||
if status.TokensRemaining < 0 {
|
||||
status.TokensRemaining = 0
|
||||
}
|
||||
return SpendStatus{}, err
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string) error {
|
||||
status, err := s.Check(ctx, externalUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if status.RequestsLimit > 0 && status.RequestsUsedToday >= status.RequestsLimit {
|
||||
return fmt.Errorf("%w: daily request limit reached", ErrQuotaExceeded)
|
||||
}
|
||||
if status.TokensLimit > 0 && status.TokensUsedMonth >= status.TokensLimit {
|
||||
return fmt.Errorf("%w: monthly token limit reached", ErrQuotaExceeded)
|
||||
}
|
||||
return nil
|
||||
func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string, provider llm.Provider, useOrgSettings bool) error {
|
||||
scope := ResolveBillingScope(ctx, s.db, externalUserID, provider, useOrgSettings)
|
||||
return s.policy.AssertAvailable(ctx, externalUserID, scope)
|
||||
}
|
||||
|
||||
func (s *QuotaService) Record(ctx context.Context, externalUserID string, tokens int64) error {
|
||||
if tokens < 0 {
|
||||
tokens = 0
|
||||
}
|
||||
userID, err := s.resolveUserID(ctx, externalUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens)
|
||||
VALUES ($1, $2, 1, $3)
|
||||
ON CONFLICT (user_id, usage_date) DO UPDATE SET
|
||||
requests = ai_usage_daily.requests + 1,
|
||||
tokens = ai_usage_daily.tokens + EXCLUDED.tokens
|
||||
`, userID, today, tokens)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO ai_usage_monthly (user_id, usage_month, tokens)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, usage_month) DO UPDATE SET
|
||||
tokens = ai_usage_monthly.tokens + EXCLUDED.tokens
|
||||
`, userID, month, tokens)
|
||||
return err
|
||||
func (s *QuotaService) RecordUsage(ctx context.Context, in cost.RecordInput) error {
|
||||
return s.meter.RecordUsage(ctx, in)
|
||||
}
|
||||
|
||||
func (s *QuotaService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
|
||||
var userID string
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id::text FROM users WHERE external_id = $1
|
||||
`, externalUserID).Scan(&userID)
|
||||
func (s *QuotaService) usesOrgBilling(ctx context.Context, externalUserID string) (bool, error) {
|
||||
settings, err := LoadEffectiveLLMSettings(ctx, s.db, externalUserID)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return "", fmt.Errorf("user not found")
|
||||
return true, err
|
||||
}
|
||||
org, err := loadOrgLLMPolicy(ctx, s.db)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
if org.EnforceOrgProviders || len(settings.Providers) == 0 {
|
||||
return true, nil
|
||||
}
|
||||
user, err := loadUserLLMSettings(ctx, s.db, externalUserID)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
return len(user.Providers) == 0, nil
|
||||
}
|
||||
|
||||
// ResolveBillingScope determines whether usage is billed to org or user's own key.
|
||||
func ResolveBillingScope(ctx context.Context, db *pgxpool.Pool, externalUserID string, provider llm.Provider, useOrgSettings bool) string {
|
||||
if useOrgSettings {
|
||||
return cost.BillingScopeOrg
|
||||
}
|
||||
org, err := loadOrgLLMPolicy(ctx, db)
|
||||
if err != nil || org.EnforceOrgProviders {
|
||||
return cost.BillingScopeOrg
|
||||
}
|
||||
user, err := loadUserLLMSettings(ctx, db, externalUserID)
|
||||
if err != nil || len(user.Providers) == 0 {
|
||||
return cost.BillingScopeOrg
|
||||
}
|
||||
apiKey := strings.TrimSpace(provider.APIKey)
|
||||
for _, op := range org.Providers {
|
||||
if op.ID == provider.ID && strings.TrimSpace(op.APIKey) == apiKey && apiKey != "" {
|
||||
return cost.BillingScopeOrg
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return userID, nil
|
||||
for _, up := range user.Providers {
|
||||
if up.ID == provider.ID && strings.TrimSpace(up.APIKey) != "" {
|
||||
return cost.BillingScopeUser
|
||||
}
|
||||
}
|
||||
return cost.BillingScopeOrg
|
||||
}
|
||||
|
||||
// SpendStatus is the user-facing quota/spend response.
|
||||
type SpendStatus = cost.SpendStatus
|
||||
|
||||
// ErrQuotaExceeded aliases cost limit error for backward compatibility.
|
||||
var ErrQuotaExceeded = cost.ErrCostLimitExceeded
|
||||
|
||||
@ -2,20 +2,16 @@ package ai
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestQuotaStatusRemaining(t *testing.T) {
|
||||
status := QuotaStatus{
|
||||
RequestsUsedToday: 40,
|
||||
RequestsLimit: 100,
|
||||
TokensUsedMonth: 100_000,
|
||||
TokensLimit: 500_000,
|
||||
func TestSpendStatusCostRemaining(t *testing.T) {
|
||||
limit := int64(10_000_000)
|
||||
status := SpendStatus{
|
||||
CostUsedTodayMicroEUR: 4_000_000,
|
||||
CostLimitTodayMicroEUR: &limit,
|
||||
Currency: "EUR",
|
||||
}
|
||||
status.RequestsRemaining = status.RequestsLimit - status.RequestsUsedToday
|
||||
status.TokensRemaining = status.TokensLimit - status.TokensUsedMonth
|
||||
if status.RequestsRemaining != 60 {
|
||||
t.Fatalf("requests remaining = %d", status.RequestsRemaining)
|
||||
}
|
||||
if status.TokensRemaining != 400_000 {
|
||||
t.Fatalf("tokens remaining = %d", status.TokensRemaining)
|
||||
remaining := *status.CostLimitTodayMicroEUR - status.CostUsedTodayMicroEUR
|
||||
if remaining != 6_000_000 {
|
||||
t.Fatalf("cost remaining = %d", remaining)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -21,19 +21,14 @@ type AssistantPolicy struct {
|
||||
}
|
||||
|
||||
type QuotaLimits struct {
|
||||
DailyLimitMicroEUR *int64 `json:"llm_daily_cost_limit_micro_eur,omitempty"`
|
||||
MonthlyLimitMicroEUR *int64 `json:"llm_monthly_cost_limit_micro_eur,omitempty"`
|
||||
WarnThresholdPct int `json:"llm_cost_warn_threshold_pct,omitempty"`
|
||||
// Deprecated legacy fields
|
||||
RequestsPerDay int `json:"llm_requests_per_day"`
|
||||
TokensPerMonth int64 `json:"llm_tokens_per_month"`
|
||||
}
|
||||
|
||||
type QuotaStatus struct {
|
||||
RequestsUsedToday int `json:"requests_used_today"`
|
||||
RequestsLimit int `json:"requests_limit"`
|
||||
TokensUsedMonth int64 `json:"tokens_used_month"`
|
||||
TokensLimit int64 `json:"tokens_limit"`
|
||||
RequestsRemaining int `json:"requests_remaining"`
|
||||
TokensRemaining int64 `json:"tokens_remaining"`
|
||||
}
|
||||
|
||||
type ChatMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content,omitempty"`
|
||||
|
||||
384
internal/api/admin/ai_usage.go
Normal file
384
internal/api/admin/ai_usage.go
Normal file
@ -0,0 +1,384 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai/cost"
|
||||
)
|
||||
|
||||
type AIUsageSummary struct {
|
||||
CostTodayMicroEUR int64 `json:"cost_today_micro_eur"`
|
||||
CostMonthMicroEUR int64 `json:"cost_month_micro_eur"`
|
||||
CostTodayEUR float64 `json:"cost_today_eur"`
|
||||
CostMonthEUR float64 `json:"cost_month_eur"`
|
||||
Currency string `json:"currency"`
|
||||
DailySeries []AIUsageDayPoint `json:"daily_series"`
|
||||
TopUsers []AIUsageTopUser `json:"top_users"`
|
||||
TopModels []AIUsageTopModel `json:"top_models"`
|
||||
OrgPolicy cost.EffectiveLimits `json:"org_policy"`
|
||||
}
|
||||
|
||||
// AIUsageDayPoint is one day of the usage series: org-billed and user-billed
// cost plus the request count. Date is formatted as YYYY-MM-DD.
type AIUsageDayPoint struct {
	Date             string  `json:"date"`
	CostOrgMicroEUR  int64   `json:"cost_org_micro_eur"`
	CostUserMicroEUR int64   `json:"cost_user_micro_eur"`
	CostOrgEUR       float64 `json:"cost_org_eur"`
	Requests         int     `json:"requests"`
}
|
||||
|
||||
// AIUsageTopUser is one entry of the top-users ranking: the user's identity
// and their org-billed cost in micro-EUR and EUR.
type AIUsageTopUser struct {
	UserID          string  `json:"user_id"`
	Email           string  `json:"email"`
	DisplayName     string  `json:"display_name"`
	CostOrgMicroEUR int64   `json:"cost_org_micro_eur"`
	CostOrgEUR      float64 `json:"cost_org_eur"`
}
|
||||
|
||||
// AIUsageTopModel is one entry of the top-models ranking: the model ID, its
// accumulated cost in micro-EUR and EUR, and the number of requests.
type AIUsageTopModel struct {
	ModelID      string  `json:"model_id"`
	CostMicroEUR int64   `json:"cost_micro_eur"`
	CostEUR      float64 `json:"cost_eur"`
	RequestCount int     `json:"request_count"`
}
|
||||
|
||||
type AIUserUsageDetail struct {
|
||||
UserID string `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Summary AIUsageSummary `json:"summary"`
|
||||
Events []AIUsageEventItem `json:"events"`
|
||||
EventsTotal int `json:"events_total"`
|
||||
}
|
||||
|
||||
// AIUsageEventItem is a single metered AI request as exposed by the admin
// API. CreatedAt is an RFC 3339 timestamp; Estimated mirrors the stored
// per-event estimated flag.
type AIUsageEventItem struct {
	ID                string  `json:"id"`
	CreatedAt         string  `json:"created_at"`
	Feature           string  `json:"feature"`
	ModelID           string  `json:"model_id"`
	ProviderID        string  `json:"provider_id"`
	BillingScope      string  `json:"billing_scope"`
	CostMicroEUR      int64   `json:"cost_micro_eur"`
	CostEUR           float64 `json:"cost_eur"`
	PromptTokens      int     `json:"prompt_tokens"`
	CompletionTokens  int     `json:"completion_tokens"`
	CachedInputTokens int     `json:"cached_input_tokens"`
	Estimated         bool    `json:"estimated"`
}
|
||||
|
||||
// AIPricingEntry is one model's pricing as exchanged with the admin API.
// Rates are per million tokens. The cached-input and reasoning rates are
// omitted from JSON when zero; the EUR fields carry the input and output
// rates converted from micro-EUR.
type AIPricingEntry struct {
	ModelID                    string  `json:"model_id"`
	ProviderType               string  `json:"provider_type"`
	InputMicroEURPerMTok       int64   `json:"input_micro_eur_per_mtok"`
	CachedInputMicroEURPerMTok int64   `json:"cached_input_micro_eur_per_mtok,omitempty"`
	OutputMicroEURPerMTok      int64   `json:"output_micro_eur_per_mtok"`
	ReasoningMicroEURPerMTok   int64   `json:"reasoning_micro_eur_per_mtok,omitempty"`
	InputEURPerMTok            float64 `json:"input_eur_per_mtok"`
	OutputEURPerMTok           float64 `json:"output_eur_per_mtok"`
}
|
||||
|
||||
// AICostPolicyEntry describes one cost policy row for the admin API: the scope
// it applies to, optional daily and monthly limits (in micro-EUR and EUR), and
// the warning threshold percentage. Nil limits are omitted from JSON.
type AICostPolicyEntry struct {
	ScopeType            string   `json:"scope_type"`
	ScopeID              *string  `json:"scope_id,omitempty"`
	ScopeName            string   `json:"scope_name,omitempty"`
	DailyLimitMicroEUR   *int64   `json:"daily_limit_micro_eur,omitempty"`
	MonthlyLimitMicroEUR *int64   `json:"monthly_limit_micro_eur,omitempty"`
	DailyLimitEUR        *float64 `json:"daily_limit_eur,omitempty"`
	MonthlyLimitEUR      *float64 `json:"monthly_limit_eur,omitempty"`
	WarnThresholdPct     int      `json:"warn_threshold_pct"`
}
|
||||
|
||||
type putAIPricingRequest struct {
|
||||
Prices []AIPricingEntry `json:"prices"`
|
||||
}
|
||||
|
||||
// putAICostPolicyRequest is the request body for updating a cost policy.
// Limits may be supplied in EUR or in micro-EUR; every limit is optional.
type putAICostPolicyRequest struct {
	DailyLimitEUR        *float64 `json:"daily_limit_eur"`
	MonthlyLimitEUR      *float64 `json:"monthly_limit_eur"`
	WarnThresholdPct     int      `json:"warn_threshold_pct"`
	DailyLimitMicroEUR   *int64   `json:"daily_limit_micro_eur,omitempty"`
	MonthlyLimitMicroEUR *int64   `json:"monthly_limit_micro_eur,omitempty"`
}
|
||||
|
||||
func (s *Service) GetAIUsageSummary(ctx context.Context, billingScope string) (AIUsageSummary, error) {
|
||||
today := time.Now().UTC().Truncate(24 * time.Hour)
|
||||
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
out := AIUsageSummary{
|
||||
Currency: "EUR",
|
||||
DailySeries: []AIUsageDayPoint{},
|
||||
TopUsers: []AIUsageTopUser{},
|
||||
TopModels: []AIUsageTopModel{},
|
||||
}
|
||||
policySvc := cost.NewPolicyService(s.db)
|
||||
orgPolicy, _ := policySvc.GetOrgPolicy(ctx)
|
||||
out.OrgPolicy = orgPolicy
|
||||
|
||||
var dailyOrg, dailyUser int64
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(cost_micro_eur_org, 0), COALESCE(cost_micro_eur_user, 0)
|
||||
FROM ai_org_usage_daily WHERE usage_date = $1
|
||||
`, today).Scan(&dailyOrg, &dailyUser)
|
||||
|
||||
var monthlyOrg int64
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(cost_micro_eur_org, 0) FROM ai_org_usage_monthly WHERE usage_month = $1
|
||||
`, month).Scan(&monthlyOrg)
|
||||
|
||||
if billingScope == "user" {
|
||||
out.CostTodayMicroEUR = dailyUser
|
||||
} else {
|
||||
out.CostTodayMicroEUR = dailyOrg
|
||||
}
|
||||
out.CostMonthMicroEUR = monthlyOrg
|
||||
out.CostTodayEUR = cost.MicroEURToEUR(out.CostTodayMicroEUR)
|
||||
out.CostMonthEUR = cost.MicroEURToEUR(out.CostMonthMicroEUR)
|
||||
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT usage_date, cost_micro_eur_org, cost_micro_eur_user, requests
|
||||
FROM ai_org_usage_daily
|
||||
WHERE usage_date >= $1
|
||||
ORDER BY usage_date ASC
|
||||
`, today.AddDate(0, 0, -29))
|
||||
if err == nil {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var pt AIUsageDayPoint
|
||||
var d time.Time
|
||||
if err := rows.Scan(&d, &pt.CostOrgMicroEUR, &pt.CostUserMicroEUR, &pt.Requests); err != nil {
|
||||
continue
|
||||
}
|
||||
pt.Date = d.Format("2006-01-02")
|
||||
pt.CostOrgEUR = cost.MicroEURToEUR(pt.CostOrgMicroEUR)
|
||||
out.DailySeries = append(out.DailySeries, pt)
|
||||
}
|
||||
}
|
||||
|
||||
topUserRows, err := s.db.Query(ctx, `
|
||||
SELECT u.id::text, COALESCE(u.email, ''), COALESCE(u.display_name, ''),
|
||||
COALESCE(SUM(e.cost_micro_eur), 0)
|
||||
FROM ai_usage_events e
|
||||
JOIN users u ON u.id = e.user_id
|
||||
WHERE e.created_at >= $1 AND e.billing_scope = 'org'
|
||||
GROUP BY u.id, u.email, u.display_name
|
||||
ORDER BY SUM(e.cost_micro_eur) DESC
|
||||
LIMIT 10
|
||||
`, month)
|
||||
if err == nil {
|
||||
defer topUserRows.Close()
|
||||
for topUserRows.Next() {
|
||||
var u AIUsageTopUser
|
||||
if err := topUserRows.Scan(&u.UserID, &u.Email, &u.DisplayName, &u.CostOrgMicroEUR); err != nil {
|
||||
continue
|
||||
}
|
||||
u.CostOrgEUR = cost.MicroEURToEUR(u.CostOrgMicroEUR)
|
||||
out.TopUsers = append(out.TopUsers, u)
|
||||
}
|
||||
}
|
||||
|
||||
topModelRows, err := s.db.Query(ctx, `
|
||||
SELECT model_id, COALESCE(SUM(cost_micro_eur), 0), COUNT(*)
|
||||
FROM ai_usage_events
|
||||
WHERE created_at >= $1 AND billing_scope = 'org'
|
||||
GROUP BY model_id
|
||||
ORDER BY SUM(cost_micro_eur) DESC
|
||||
LIMIT 10
|
||||
`, month)
|
||||
if err == nil {
|
||||
defer topModelRows.Close()
|
||||
for topModelRows.Next() {
|
||||
var m AIUsageTopModel
|
||||
if err := topModelRows.Scan(&m.ModelID, &m.CostMicroEUR, &m.RequestCount); err != nil {
|
||||
continue
|
||||
}
|
||||
m.CostEUR = cost.MicroEURToEUR(m.CostMicroEUR)
|
||||
out.TopModels = append(out.TopModels, m)
|
||||
}
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetUserAIUsage(ctx context.Context, userID string, limit, offset int) (AIUserUsageDetail, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
var detail AIUserUsageDetail
|
||||
detail.Events = []AIUsageEventItem{}
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id::text, COALESCE(email, ''), COALESCE(display_name, '')
|
||||
FROM users WHERE id = $1::uuid
|
||||
`, userID).Scan(&detail.UserID, &detail.Email, &detail.DisplayName)
|
||||
if err != nil {
|
||||
return detail, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
summary, _ := s.GetAIUsageSummary(ctx, "org")
|
||||
detail.Summary = summary
|
||||
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM ai_usage_events WHERE user_id = $1::uuid
|
||||
`, userID).Scan(&detail.EventsTotal)
|
||||
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id::text, created_at, feature, model_id, provider_id, billing_scope,
|
||||
cost_micro_eur, prompt_tokens, completion_tokens, cached_input_tokens, estimated
|
||||
FROM ai_usage_events
|
||||
WHERE user_id = $1::uuid
|
||||
ORDER BY created_at DESC
|
||||
LIMIT $2 OFFSET $3
|
||||
`, userID, limit, offset)
|
||||
if err != nil {
|
||||
return detail, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var ev AIUsageEventItem
|
||||
var created time.Time
|
||||
if err := rows.Scan(&ev.ID, &created, &ev.Feature, &ev.ModelID, &ev.ProviderID,
|
||||
&ev.BillingScope, &ev.CostMicroEUR, &ev.PromptTokens, &ev.CompletionTokens,
|
||||
&ev.CachedInputTokens, &ev.Estimated); err != nil {
|
||||
continue
|
||||
}
|
||||
ev.CreatedAt = created.UTC().Format(time.RFC3339)
|
||||
ev.CostEUR = cost.MicroEURToEUR(ev.CostMicroEUR)
|
||||
detail.Events = append(detail.Events, ev)
|
||||
}
|
||||
return detail, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) ListAIPricing(ctx context.Context) ([]AIPricingEntry, error) {
|
||||
store := cost.NewPricingStore(s.db)
|
||||
prices, err := store.ListPrices(ctx)
|
||||
if err != nil {
|
||||
return []AIPricingEntry{}, err
|
||||
}
|
||||
out := make([]AIPricingEntry, 0, len(prices))
|
||||
for _, p := range prices {
|
||||
out = append(out, AIPricingEntry{
|
||||
ModelID: p.ModelID,
|
||||
ProviderType: p.ProviderType,
|
||||
InputMicroEURPerMTok: p.InputMicroEURPerMTok,
|
||||
CachedInputMicroEURPerMTok: p.CachedInputMicroEURPerMTok,
|
||||
OutputMicroEURPerMTok: p.OutputMicroEURPerMTok,
|
||||
ReasoningMicroEURPerMTok: p.ReasoningMicroEURPerMTok,
|
||||
InputEURPerMTok: cost.MicroEURToEUR(p.InputMicroEURPerMTok),
|
||||
OutputEURPerMTok: cost.MicroEURToEUR(p.OutputMicroEURPerMTok),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *Service) PutAIPricing(ctx context.Context, actorSub string, req putAIPricingRequest) ([]AIPricingEntry, error) {
|
||||
store := cost.NewPricingStore(s.db)
|
||||
for _, p := range req.Prices {
|
||||
if err := store.UpsertModelPrice(ctx, cost.ModelPrice{
|
||||
ModelID: p.ModelID,
|
||||
ProviderType: p.ProviderType,
|
||||
InputMicroEURPerMTok: p.InputMicroEURPerMTok,
|
||||
CachedInputMicroEURPerMTok: p.CachedInputMicroEURPerMTok,
|
||||
OutputMicroEURPerMTok: p.OutputMicroEURPerMTok,
|
||||
ReasoningMicroEURPerMTok: p.ReasoningMicroEURPerMTok,
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "update_ai_pricing", map[string]any{"count": len(req.Prices)})
|
||||
return s.ListAIPricing(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) ListAICostPolicies(ctx context.Context) ([]AICostPolicyEntry, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT p.scope_type, p.scope_id::text, p.daily_limit_micro_eur, p.monthly_limit_micro_eur, p.warn_threshold_pct,
|
||||
COALESCE(g.name, u.email, 'Organisation')
|
||||
FROM ai_cost_policies p
|
||||
LEFT JOIN user_groups g ON p.scope_type = 'group' AND g.id = p.scope_id
|
||||
LEFT JOIN users u ON p.scope_type = 'user' AND u.id = p.scope_id
|
||||
ORDER BY p.priority ASC, p.scope_type
|
||||
`)
|
||||
if err != nil {
|
||||
return []AICostPolicyEntry{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out = make([]AICostPolicyEntry, 0)
|
||||
for rows.Next() {
|
||||
var e AICostPolicyEntry
|
||||
var scopeID *string
|
||||
if err := rows.Scan(&e.ScopeType, &scopeID, &e.DailyLimitMicroEUR, &e.MonthlyLimitMicroEUR,
|
||||
&e.WarnThresholdPct, &e.ScopeName); err != nil {
|
||||
continue
|
||||
}
|
||||
e.ScopeID = scopeID
|
||||
if e.DailyLimitMicroEUR != nil {
|
||||
v := cost.MicroEURToEUR(*e.DailyLimitMicroEUR)
|
||||
e.DailyLimitEUR = &v
|
||||
}
|
||||
if e.MonthlyLimitMicroEUR != nil {
|
||||
v := cost.MicroEURToEUR(*e.MonthlyLimitMicroEUR)
|
||||
e.MonthlyLimitEUR = &v
|
||||
}
|
||||
out = append(out, e)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) PutUserAICostPolicy(ctx context.Context, actorSub, userID string, req putAICostPolicyRequest) error {
|
||||
daily, monthly := resolveCostLimits(req)
|
||||
policy := cost.NewPolicyService(s.db)
|
||||
if err := policy.UpsertScopePolicy(ctx, "user", userID, daily, monthly, req.WarnThresholdPct); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "set_user_ai_cost_policy", map[string]any{"user_id": userID})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) PutGroupAICostPolicy(ctx context.Context, actorSub, groupID string, req putAICostPolicyRequest) error {
|
||||
daily, monthly := resolveCostLimits(req)
|
||||
policy := cost.NewPolicyService(s.db)
|
||||
if err := policy.UpsertScopePolicy(ctx, "group", groupID, daily, monthly, req.WarnThresholdPct); err != nil {
|
||||
return err
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "set_group_ai_cost_policy", map[string]any{"group_id": groupID})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) syncUsageQuotasToCostPolicy(ctx context.Context, usageQuotas map[string]any) error {
|
||||
if usageQuotas == nil {
|
||||
return nil
|
||||
}
|
||||
var daily, monthly *int64
|
||||
warnPct := 80
|
||||
|
||||
if v, ok := usageQuotas["llm_daily_cost_limit_eur"].(float64); ok && v > 0 {
|
||||
m := int64(v * 1_000_000)
|
||||
daily = &m
|
||||
}
|
||||
if v, ok := usageQuotas["llm_monthly_cost_limit_eur"].(float64); ok && v > 0 {
|
||||
m := int64(v * 1_000_000)
|
||||
monthly = &m
|
||||
}
|
||||
if v, ok := usageQuotas["llm_cost_warn_threshold_pct"].(float64); ok && v > 0 {
|
||||
warnPct = int(v)
|
||||
}
|
||||
|
||||
policy := cost.NewPolicyService(s.db)
|
||||
return policy.UpsertOrgPolicy(ctx, daily, monthly, warnPct)
|
||||
}
|
||||
|
||||
func resolveCostLimits(req putAICostPolicyRequest) (*int64, *int64) {
|
||||
var daily, monthly *int64
|
||||
if req.DailyLimitMicroEUR != nil {
|
||||
daily = req.DailyLimitMicroEUR
|
||||
} else if req.DailyLimitEUR != nil && *req.DailyLimitEUR > 0 {
|
||||
v := int64(*req.DailyLimitEUR * 1_000_000)
|
||||
daily = &v
|
||||
}
|
||||
if req.MonthlyLimitMicroEUR != nil {
|
||||
monthly = req.MonthlyLimitMicroEUR
|
||||
} else if req.MonthlyLimitEUR != nil && *req.MonthlyLimitEUR > 0 {
|
||||
v := int64(*req.MonthlyLimitEUR * 1_000_000)
|
||||
monthly = &v
|
||||
}
|
||||
return daily, monthly
|
||||
}
|
||||
116
internal/api/admin/ai_usage_handlers.go
Normal file
116
internal/api/admin/ai_usage_handlers.go
Normal file
@ -0,0 +1,116 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
)
|
||||
|
||||
func (h *Handler) GetAIUsage(w http.ResponseWriter, r *http.Request) {
|
||||
scope := r.URL.Query().Get("scope")
|
||||
if scope == "" {
|
||||
scope = "org"
|
||||
}
|
||||
summary, err := h.svc.GetAIUsageSummary(r.Context(), scope)
|
||||
if err != nil {
|
||||
h.logger.Error("get ai usage", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, summary)
|
||||
}
|
||||
|
||||
func (h *Handler) GetUserAIUsage(w http.ResponseWriter, r *http.Request) {
|
||||
userID := chi.URLParam(r, "userID")
|
||||
limit, _ := strconv.Atoi(r.URL.Query().Get("limit"))
|
||||
offset, _ := strconv.Atoi(r.URL.Query().Get("offset"))
|
||||
detail, err := h.svc.GetUserAIUsage(r.Context(), userID, limit, offset)
|
||||
if err != nil {
|
||||
if err.Error() == "user not found" {
|
||||
apiresponse.WriteError(w, r, http.StatusNotFound, apiresponse.CodeNotFound, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
h.logger.Error("get user ai usage", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, detail)
|
||||
}
|
||||
|
||||
func (h *Handler) GetAIPricing(w http.ResponseWriter, r *http.Request) {
|
||||
prices, err := h.svc.ListAIPricing(r.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("list ai pricing", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"prices": prices})
|
||||
}
|
||||
|
||||
func (h *Handler) PutAIPricing(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
var req putAIPricingRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
prices, err := h.svc.PutAIPricing(r.Context(), claims.Sub, req)
|
||||
if err != nil {
|
||||
h.logger.Error("put ai pricing", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"prices": prices})
|
||||
}
|
||||
|
||||
func (h *Handler) GetAICostPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
policies, err := h.svc.ListAICostPolicies(r.Context())
|
||||
if err != nil {
|
||||
h.logger.Error("list ai cost policies", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"policies": policies})
|
||||
}
|
||||
|
||||
func (h *Handler) PutUserAICostPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
userID := chi.URLParam(r, "userID")
|
||||
if verr := validateUserID(userID); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
var req putAICostPolicyRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if err := h.svc.PutUserAICostPolicy(r.Context(), claims.Sub, userID, req); err != nil {
|
||||
h.logger.Error("put user ai cost policy", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (h *Handler) PutGroupAICostPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
groupID := chi.URLParam(r, "groupID")
|
||||
if verr := validateGroupID(groupID); verr != nil {
|
||||
apivalidate.WriteValidationError(w, r, verr)
|
||||
return
|
||||
}
|
||||
var req putAICostPolicyRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxQuotaRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
if err := h.svc.PutGroupAICostPolicy(r.Context(), claims.Sub, groupID, req); err != nil {
|
||||
h.logger.Error("put group ai cost policy", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
@ -79,6 +79,14 @@ func (h *Handler) Routes() chi.Router {
|
||||
r.With(write).Put("/org/settings", h.PutOrgSettings)
|
||||
r.With(read).Post("/org/llm/discover-models", h.DiscoverOrgLLMModels)
|
||||
|
||||
r.With(read).Get("/ai/usage", h.GetAIUsage)
|
||||
r.With(read).Get("/ai/usage/users/{userID}", h.GetUserAIUsage)
|
||||
r.With(read).Get("/ai/pricing", h.GetAIPricing)
|
||||
r.With(write).Put("/ai/pricing", h.PutAIPricing)
|
||||
r.With(read).Get("/ai/policies", h.GetAICostPolicies)
|
||||
r.With(write).Put("/users/{userID}/ai-policy", h.PutUserAICostPolicy)
|
||||
r.With(write).Put("/user-groups/{groupID}/ai-policy", h.PutGroupAICostPolicy)
|
||||
|
||||
r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI)
|
||||
r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider)
|
||||
r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider)
|
||||
|
||||
@ -45,11 +45,14 @@ func defaultOrgPolicy() map[string]any {
|
||||
"warn_threshold_pct": 90,
|
||||
},
|
||||
"usage_quotas": map[string]any{
|
||||
"llm_requests_per_day": 100,
|
||||
"llm_tokens_per_month": 500000,
|
||||
"search_requests_per_day": 50,
|
||||
"max_api_tokens_per_user": 10,
|
||||
"max_webhooks_per_user": 20,
|
||||
"llm_daily_cost_limit_eur": 2,
|
||||
"llm_monthly_cost_limit_eur": 35,
|
||||
"llm_cost_warn_threshold_pct": 80,
|
||||
"llm_requests_per_day": 75,
|
||||
"llm_tokens_per_month": 2_000_000,
|
||||
"search_requests_per_day": 20,
|
||||
"max_api_tokens_per_user": 5,
|
||||
"max_webhooks_per_user": 5,
|
||||
},
|
||||
"file_policies": map[string]any{
|
||||
"max_upload_mib": 512,
|
||||
@ -909,6 +912,12 @@ func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map
|
||||
}
|
||||
}
|
||||
|
||||
if usageQuotas, ok := merged["usage_quotas"].(map[string]any); ok {
|
||||
if err := s.syncUsageQuotasToCostPolicy(ctx, usageQuotas); err != nil {
|
||||
s.logger.Warn("sync ai cost policy failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetOrgSettings(ctx)
|
||||
}
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai"
|
||||
"github.com/ultisuite/ulti-backend/internal/llm"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||
@ -104,11 +105,14 @@ func (p *TranscriptProcessor) runPostActions(
|
||||
actions := policy.PostActions
|
||||
|
||||
if actions.LLMEnabled {
|
||||
summary, err := p.summarize(ctx, policy, rawTranscript)
|
||||
summary, provider, model, usage, err := p.summarize(ctx, policy, rawTranscript)
|
||||
if err != nil {
|
||||
p.logger.Warn("llm summary failed", "error", err, "job_id", jobID)
|
||||
} else if strings.TrimSpace(summary) != "" {
|
||||
finalText = summary
|
||||
if extID, err := ai.ResolveExternalIDByEmail(ctx, p.db, in.OrganizerEmail); err == nil && extID != "" {
|
||||
ai.RecordFeatureUsage(ctx, p.db, extID, "ultimeet", model, provider, usage)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,16 +149,20 @@ func (p *TranscriptProcessor) runPostActions(
|
||||
return err
|
||||
}
|
||||
|
||||
func (p *TranscriptProcessor) summarize(ctx context.Context, policy orgpolicy.MeetPolicy, transcript string) (string, error) {
|
||||
func (p *TranscriptProcessor) summarize(ctx context.Context, policy orgpolicy.MeetPolicy, transcript string) (string, llm.Provider, string, llm.UsageDetail, error) {
|
||||
provider, model, err := p.resolveLLM(ctx, policy.PostActions.LLMProviderID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", llm.Provider{}, "", llm.UsageDetail{}, err
|
||||
}
|
||||
prompt := strings.TrimSpace(policy.PostActions.LLMPrompt)
|
||||
if prompt == "" {
|
||||
prompt = "Résume cette réunion en français."
|
||||
}
|
||||
return p.llm.Complete(ctx, provider, model, prompt, transcript)
|
||||
result, err := p.llm.CompleteWithUsage(ctx, provider, model, prompt, transcript)
|
||||
if err != nil {
|
||||
return "", provider, model, llm.UsageDetail{}, err
|
||||
}
|
||||
return result.Content, provider, result.Model, result.Usage, nil
|
||||
}
|
||||
|
||||
func (p *TranscriptProcessor) resolveLLM(ctx context.Context, providerID string) (llm.Provider, string, error) {
|
||||
|
||||
@ -7,6 +7,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai"
|
||||
"github.com/ultisuite/ulti-backend/internal/llm"
|
||||
)
|
||||
|
||||
@ -60,7 +63,7 @@ func parseEnrichedData(raw string) (*EnrichedContactData, error) {
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry, timeout time.Duration) (*EnrichedContactData, error) {
|
||||
func enrichWithLLMTimeout(ctx context.Context, db *pgxpool.Pool, externalUserID string, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry, timeout time.Duration) (*EnrichedContactData, error) {
|
||||
enrichCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
@ -69,7 +72,7 @@ func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.
|
||||
err error
|
||||
}, 1)
|
||||
go func() {
|
||||
data, err := enrichWithLLM(enrichCtx, client, settings, email, displayName, signatures)
|
||||
data, err := enrichWithLLM(enrichCtx, db, externalUserID, client, settings, email, displayName, signatures)
|
||||
resultCh <- struct {
|
||||
data *EnrichedContactData
|
||||
err error
|
||||
@ -89,7 +92,7 @@ func enrichWithLLMTimeout(ctx context.Context, client *llm.Client, settings llm.
|
||||
}
|
||||
}
|
||||
|
||||
func enrichWithLLM(ctx context.Context, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry) (*EnrichedContactData, error) {
|
||||
func enrichWithLLM(ctx context.Context, db *pgxpool.Pool, externalUserID string, client *llm.Client, settings llm.Settings, email, displayName string, signatures []SignatureEntry) (*EnrichedContactData, error) {
|
||||
if client == nil || len(signatures) == 0 {
|
||||
return nil, fmt.Errorf("no signatures to enrich")
|
||||
}
|
||||
@ -98,11 +101,12 @@ func enrichWithLLM(ctx context.Context, client *llm.Client, settings llm.Setting
|
||||
return nil, err
|
||||
}
|
||||
prompt := buildEnrichPrompt(email, displayName, signatures)
|
||||
raw, err := client.Complete(ctx, provider, model, enrichSystemPrompt, prompt)
|
||||
result, err := client.CompleteWithUsage(ctx, provider, model, enrichSystemPrompt, prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseEnrichedData(raw)
|
||||
ai.RecordFeatureUsage(ctx, db, externalUserID, "contact_discovery", result.Model, provider, result.Usage)
|
||||
return parseEnrichedData(result.Content)
|
||||
}
|
||||
|
||||
func enrichedDataToSuggestions(userID, profileID string, data *EnrichedContactData) []Suggestion {
|
||||
|
||||
@ -121,7 +121,7 @@ func (s *Service) runProfileEnrichment(externalUserID, profileID, ncUserID, book
|
||||
}
|
||||
|
||||
enriched, enrichErr := enrichWithLLMTimeout(
|
||||
ctx, s.llm, llmSettings,
|
||||
ctx, s.db, externalUserID, s.llm, llmSettings,
|
||||
profile.PrimaryEmail, profile.DisplayName, sigs, llmEnrichTimeout,
|
||||
)
|
||||
if enrichErr != nil {
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/ai"
|
||||
"github.com/ultisuite/ulti-backend/internal/llm"
|
||||
"github.com/ultisuite/ulti-backend/internal/websearch"
|
||||
)
|
||||
@ -120,11 +121,12 @@ func (s *Service) ImproveContact(ctx context.Context, externalUserID string, inp
|
||||
|
||||
searchSection := s.fetchContactSearchResults(improveCtx, externalUserID, input)
|
||||
prompt := buildImproveContactPrompt(input, searchSection)
|
||||
raw, err := s.llm.Complete(improveCtx, provider, model, improveContactSystemPrompt, prompt)
|
||||
raw, err := s.llm.CompleteWithUsage(improveCtx, provider, model, improveContactSystemPrompt, prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data, err := parseEnrichedData(raw)
|
||||
ai.RecordFeatureUsage(ctx, s.db, externalUserID, "contact_discovery", raw.Model, provider, raw.Usage)
|
||||
data, err := parseEnrichedData(raw.Content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse improved contact: %w", err)
|
||||
}
|
||||
|
||||
@ -334,7 +334,7 @@ func (s *Service) executeScan(ctx context.Context, externalUserID, ncUserID, boo
|
||||
|
||||
heartbeatCtx, heartbeatCancel := context.WithCancel(context.Background())
|
||||
go s.enrichHeartbeat(heartbeatCtx, scanID, externalUserID, messagesScanned, enrichDone, totalMessages, enrichTotal)
|
||||
enriched, enrichErr := enrichWithLLMTimeout(ctx, s.llm, llmSettings, email, agg.DisplayName, sigEntries, llmEnrichTimeout)
|
||||
enriched, enrichErr := enrichWithLLMTimeout(ctx, s.db, externalUserID, s.llm, llmSettings, email, agg.DisplayName, sigEntries, llmEnrichTimeout)
|
||||
heartbeatCancel()
|
||||
if enrichErr != nil {
|
||||
s.logger.Warn("llm enrichment failed", "email", email, "error", enrichErr)
|
||||
|
||||
@ -44,6 +44,17 @@ type chatResponse struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"choices"`
|
||||
Usage *struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
} `json:"completion_tokens_details,omitempty"`
|
||||
} `json:"usage,omitempty"`
|
||||
Error *struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error,omitempty"`
|
||||
@ -67,16 +78,40 @@ func NewClient() *Client {
|
||||
}
|
||||
|
||||
func (c *Client) Complete(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (string, error) {
|
||||
result, err := c.CompleteWithUsage(ctx, provider, model, systemPrompt, userPrompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return result.Content, nil
|
||||
}
|
||||
|
||||
// CompletionResult holds LLM output and usage metadata.
|
||||
type CompletionResult struct {
|
||||
Content string
|
||||
Model string
|
||||
Usage UsageDetail
|
||||
}
|
||||
|
||||
// UsageDetail is the llm package's own token-usage record, mirroring
// ai/cost.UsageDetail so that consumers of this package can read usage
// figures from a completion result.
type UsageDetail struct {
	// Prompt and completion token counts.
	PromptTokens, CompletionTokens int
	// Cached prompt tokens and reasoning tokens, when the response details them.
	CachedInputTokens, ReasoningTokens int
	// TotalTokens is the overall token count for the call.
	TotalTokens int
}
|
||||
|
||||
func (c *Client) CompleteWithUsage(ctx context.Context, provider Provider, model, systemPrompt, userPrompt string) (CompletionResult, error) {
|
||||
baseURL := strings.TrimRight(strings.TrimSpace(provider.BaseURL), "/")
|
||||
if baseURL == "" {
|
||||
return "", fmt.Errorf("llm provider base_url is required")
|
||||
return CompletionResult{}, fmt.Errorf("llm provider base_url is required")
|
||||
}
|
||||
model = strings.TrimSpace(model)
|
||||
if model == "" {
|
||||
model = strings.TrimSpace(provider.DefaultModel)
|
||||
}
|
||||
if model == "" {
|
||||
return "", fmt.Errorf("llm model is required")
|
||||
return CompletionResult{}, fmt.Errorf("llm model is required")
|
||||
}
|
||||
|
||||
reqBody := chatRequest{
|
||||
@ -89,13 +124,13 @@ func (c *Client) Complete(ctx context.Context, provider Provider, model, systemP
|
||||
}
|
||||
payload, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return CompletionResult{}, err
|
||||
}
|
||||
|
||||
url := baseURL + "/chat/completions"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return CompletionResult{}, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if strings.TrimSpace(provider.APIKey) != "" {
|
||||
@ -104,29 +139,68 @@ func (c *Client) Complete(ctx context.Context, provider Provider, model, systemP
|
||||
|
||||
resp, err := c.http.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return CompletionResult{}, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
if err != nil {
|
||||
return "", err
|
||||
return CompletionResult{}, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return "", fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body))
|
||||
return CompletionResult{}, fmt.Errorf("llm request failed (%d): %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
var parsed chatResponse
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", err
|
||||
return CompletionResult{}, err
|
||||
}
|
||||
if parsed.Error != nil && parsed.Error.Message != "" {
|
||||
return "", fmt.Errorf("llm error: %s", parsed.Error.Message)
|
||||
return CompletionResult{}, fmt.Errorf("llm error: %s", parsed.Error.Message)
|
||||
}
|
||||
if len(parsed.Choices) == 0 {
|
||||
return "", fmt.Errorf("llm returned no choices")
|
||||
return CompletionResult{}, fmt.Errorf("llm returned no choices")
|
||||
}
|
||||
return strings.TrimSpace(parsed.Choices[0].Message.Content), nil
|
||||
usage := parseUsageFromResponse(parsed.Usage)
|
||||
return CompletionResult{
|
||||
Content: strings.TrimSpace(parsed.Choices[0].Message.Content),
|
||||
Model: model,
|
||||
Usage: usage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func parseUsageFromResponse(u *struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
PromptTokensDetails *struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"prompt_tokens_details,omitempty"`
|
||||
CompletionTokensDetails *struct {
|
||||
ReasoningTokens int `json:"reasoning_tokens"`
|
||||
} `json:"completion_tokens_details,omitempty"`
|
||||
}) UsageDetail {
|
||||
if u == nil {
|
||||
return UsageDetail{TotalTokens: 1}
|
||||
}
|
||||
d := UsageDetail{
|
||||
PromptTokens: u.PromptTokens,
|
||||
CompletionTokens: u.CompletionTokens,
|
||||
TotalTokens: u.TotalTokens,
|
||||
}
|
||||
if u.PromptTokensDetails != nil {
|
||||
d.CachedInputTokens = u.PromptTokensDetails.CachedTokens
|
||||
}
|
||||
if u.CompletionTokensDetails != nil {
|
||||
d.ReasoningTokens = u.CompletionTokensDetails.ReasoningTokens
|
||||
}
|
||||
if d.TotalTokens == 0 {
|
||||
d.TotalTokens = d.PromptTokens + d.CompletionTokens
|
||||
}
|
||||
if d.TotalTokens == 0 {
|
||||
d.TotalTokens = 1
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func (c *Client) ListModels(ctx context.Context, provider Provider) ([]string, error) {
|
||||
|
||||
9
migrations/000052_ai_cost_metering.down.sql
Normal file
9
migrations/000052_ai_cost_metering.down.sql
Normal file
@ -0,0 +1,9 @@
|
||||
-- Down migration for AI cost metering: removes the policy and org aggregate
-- tables, drops the cost columns added to the existing rollup tables, then
-- removes the usage ledger and the pricing table.

-- Tables introduced by the up migration (policies and org-wide aggregates).
DROP TABLE IF EXISTS ai_cost_policies;
DROP TABLE IF EXISTS ai_org_usage_monthly;
DROP TABLE IF EXISTS ai_org_usage_daily;

-- Cost columns added to the pre-existing rollup tables.
ALTER TABLE ai_usage_monthly
    DROP COLUMN IF EXISTS cost_micro_eur_org,
    DROP COLUMN IF EXISTS cost_micro_eur_user;
ALTER TABLE ai_usage_daily
    DROP COLUMN IF EXISTS cost_micro_eur_org,
    DROP COLUMN IF EXISTS cost_micro_eur_user;

-- Usage ledger and model pricing.
DROP TABLE IF EXISTS ai_usage_events;
DROP TABLE IF EXISTS ai_model_pricing;
|
||||
96
migrations/000052_ai_cost_metering.up.sql
Normal file
96
migrations/000052_ai_cost_metering.up.sql
Normal file
@ -0,0 +1,96 @@
|
||||
-- AI cost metering schema: model pricing, a per-request usage ledger, cost
-- columns on the existing daily/monthly rollups, org-wide aggregates and
-- cost limit policies. All monetary amounts are stored as BIGINT micro-EUR.

-- Model pricing (micro-EUR per 1M tokens)
-- One row per model and effective date; cached-input and reasoning rates
-- are optional (nullable).
CREATE TABLE IF NOT EXISTS ai_model_pricing (
    model_id TEXT NOT NULL,
    provider_type TEXT NOT NULL DEFAULT 'generic',
    input_micro_eur_per_mtok BIGINT NOT NULL,
    cached_input_micro_eur_per_mtok BIGINT,
    output_micro_eur_per_mtok BIGINT NOT NULL,
    reasoning_micro_eur_per_mtok BIGINT,
    effective_from DATE NOT NULL DEFAULT CURRENT_DATE,
    source TEXT NOT NULL DEFAULT 'manual',
    PRIMARY KEY (model_id, effective_from)
);

-- Detailed usage ledger
-- Rows are removed together with their user (ON DELETE CASCADE).
CREATE TABLE IF NOT EXISTS ai_usage_events (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    feature TEXT NOT NULL,
    model_id TEXT NOT NULL,
    provider_id TEXT NOT NULL,
    billing_scope TEXT NOT NULL CHECK (billing_scope IN ('org', 'user')),
    provider_key_fingerprint TEXT NOT NULL DEFAULT '',
    prompt_tokens INT NOT NULL DEFAULT 0,
    completion_tokens INT NOT NULL DEFAULT 0,
    cached_input_tokens INT NOT NULL DEFAULT 0,
    reasoning_tokens INT NOT NULL DEFAULT 0,
    cost_micro_eur BIGINT NOT NULL DEFAULT 0,
    estimated BOOLEAN NOT NULL DEFAULT false,
    request_id TEXT
);

-- Indexes for per-user history, per-scope/key grouping and time-ordered scans.
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_user_time ON ai_usage_events(user_id, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_scope_key ON ai_usage_events(billing_scope, provider_key_fingerprint, created_at DESC);
CREATE INDEX IF NOT EXISTS idx_ai_usage_events_created ON ai_usage_events(created_at DESC);

-- Extend daily rollups
ALTER TABLE ai_usage_daily
    ADD COLUMN IF NOT EXISTS cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    ADD COLUMN IF NOT EXISTS cost_micro_eur_user BIGINT NOT NULL DEFAULT 0;

-- Extend monthly rollups
ALTER TABLE ai_usage_monthly
    ADD COLUMN IF NOT EXISTS cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    ADD COLUMN IF NOT EXISTS cost_micro_eur_user BIGINT NOT NULL DEFAULT 0;

-- Org-wide aggregates
-- One row per calendar date / per month key.
CREATE TABLE IF NOT EXISTS ai_org_usage_daily (
    usage_date DATE NOT NULL PRIMARY KEY,
    cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    cost_micro_eur_user BIGINT NOT NULL DEFAULT 0,
    requests INT NOT NULL DEFAULT 0
);

CREATE TABLE IF NOT EXISTS ai_org_usage_monthly (
    usage_month DATE NOT NULL PRIMARY KEY,
    cost_micro_eur_org BIGINT NOT NULL DEFAULT 0,
    cost_micro_eur_user BIGINT NOT NULL DEFAULT 0
);

-- Cost limit policies (org / group / user)
-- scope_id is nullable; the seeded org policy below uses NULL.
CREATE TABLE IF NOT EXISTS ai_cost_policies (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    scope_type TEXT NOT NULL CHECK (scope_type IN ('org', 'group', 'user')),
    scope_id UUID,
    daily_limit_micro_eur BIGINT,
    monthly_limit_micro_eur BIGINT,
    warn_threshold_pct INT NOT NULL DEFAULT 80,
    priority INT NOT NULL DEFAULT 0,
    created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
    updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);

-- Partial unique indexes: at most one org policy, and at most one policy per
-- group id and per user id.
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_org ON ai_cost_policies(scope_type) WHERE scope_type = 'org';
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_group ON ai_cost_policies(scope_type, scope_id) WHERE scope_type = 'group';
CREATE UNIQUE INDEX IF NOT EXISTS idx_ai_cost_policies_user ON ai_cost_policies(scope_type, scope_id) WHERE scope_type = 'user';

-- Default org policy: ~10 EUR/day, ~100 EUR/month
-- Inserted only when no org policy row exists yet.
INSERT INTO ai_cost_policies (scope_type, scope_id, daily_limit_micro_eur, monthly_limit_micro_eur, warn_threshold_pct, priority)
SELECT 'org', NULL, 10000000, 100000000, 80, 0
WHERE NOT EXISTS (SELECT 1 FROM ai_cost_policies WHERE scope_type = 'org');

-- Seed common model pricing (approximate public rates in EUR)
-- reasoning_micro_eur_per_mtok and effective_from are not listed, so they take
-- their defaults (NULL and CURRENT_DATE). Rows that collide with an existing
-- (model_id, effective_from) are skipped.
INSERT INTO ai_model_pricing (model_id, provider_type, input_micro_eur_per_mtok, cached_input_micro_eur_per_mtok, output_micro_eur_per_mtok, source) VALUES
    ('gpt-4o-mini', 'openai', 140000, 70000, 560000, 'seed'),
    ('gpt-4o', 'openai', 2300000, 1150000, 9200000, 'seed'),
    ('gpt-4.1-mini', 'openai', 360000, 90000, 1440000, 'seed'),
    ('gpt-4.1', 'openai', 1800000, 450000, 7200000, 'seed'),
    ('o3-mini', 'openai', 990000, 250000, 3960000, 'seed'),
    ('claude-sonnet-4-6', 'anthropic', 2700000, 270000, 13500000, 'seed'),
    ('claude-haiku-4-5', 'anthropic', 900000, 90000, 4500000, 'seed'),
    ('mistral-small-latest', 'mistral', 180000, 90000, 540000, 'seed'),
    ('mistral-large-latest', 'mistral', 1800000, 450000, 5400000, 'seed'),
    ('gemini-2.0-flash', 'google_gemini', 90000, 23000, 360000, 'seed'),
    ('gemini-2.5-pro', 'google_gemini', 1100000, 280000, 9000000, 'seed')
ON CONFLICT DO NOTHING;
|
||||
42
migrations/000053_ai_quota_defaults_pme.down.sql
Normal file
42
migrations/000053_ai_quota_defaults_pme.down.sql
Normal file
@ -0,0 +1,42 @@
|
||||
-- Down migration: restore the org cost policy to 10000000 / 100000000
-- micro-EUR (daily / monthly). Only org rows whose limits are still exactly
-- 2000000 / 35000000 are matched, so limits changed to other values are kept.
UPDATE ai_cost_policies
|
||||
SET daily_limit_micro_eur = 10000000,
|
||||
monthly_limit_micro_eur = 100000000,
|
||||
updated_at = NOW()
|
||||
WHERE scope_type = 'org'
|
||||
AND daily_limit_micro_eur = 2000000
|
||||
AND monthly_limit_micro_eur = 35000000;
|
||||
|
||||
-- Down migration: put org_settings.usage_quotas back to the previous factory
-- values for the row with id = 1.
--
-- The row is only touched while both LLM cost limits still read 2 (daily) and
-- 35 (monthly). That guard can only match when usage_quotas is a JSON object,
-- so merging the overrides into it with || and writing the object back yields
-- the same document as setting each key one by one. Keys not listed below are
-- preserved.
UPDATE org_settings
SET settings = jsonb_set(
        COALESCE(settings, '{}'::jsonb),
        '{usage_quotas}',
        (settings->'usage_quotas') || '{
            "llm_daily_cost_limit_eur": 10,
            "llm_monthly_cost_limit_eur": 100,
            "llm_requests_per_day": 100,
            "llm_tokens_per_month": 500000,
            "search_requests_per_day": 50,
            "max_api_tokens_per_user": 10,
            "max_webhooks_per_user": 20
        }'::jsonb
    ),
    updated_at = NOW()
WHERE id = 1
  AND COALESCE((settings->'usage_quotas'->>'llm_daily_cost_limit_eur')::numeric, 0) = 2
  AND COALESCE((settings->'usage_quotas'->>'llm_monthly_cost_limit_eur')::numeric, 0) = 35;
|
||||
56
migrations/000053_ai_quota_defaults_pme.up.sql
Normal file
56
migrations/000053_ai_quota_defaults_pme.up.sql
Normal file
@ -0,0 +1,56 @@
|
||||
-- Align factory AI quotas with reasonable SME defaults (per user, org keys).
|
||||
-- Only touch rows still at the previous factory values.
|
||||
|
||||
-- Lower the org cost policy to 2000000 / 35000000 micro-EUR (daily / monthly).
-- Only org rows whose limits are still exactly 10000000 / 100000000 are
-- matched; rows with any other limits are left as they are.
UPDATE ai_cost_policies
|
||||
SET daily_limit_micro_eur = 2000000,
|
||||
monthly_limit_micro_eur = 35000000,
|
||||
updated_at = NOW()
|
||||
WHERE scope_type = 'org'
|
||||
AND daily_limit_micro_eur = 10000000
|
||||
AND monthly_limit_micro_eur = 100000000;
|
||||
|
||||
-- Apply the new factory quotas to org_settings.usage_quotas (row id = 1).
--
-- The row is updated only when usage_quotas is absent or JSON null, or when
-- every quota still equals its previous factory value (a missing key counts
-- as the previous factory value).
--
-- jsonb_set() returns its target unchanged when an earlier step of the path
-- does not exist, so writing to '{usage_quotas,<key>}' does nothing while
-- usage_quotas is missing or null -- exactly the rows the first two WHERE
-- branches select. The usage_quotas object is therefore rebuilt explicitly
-- (starting from '{}' when it is not already an object), merged with the new
-- values, and written back as a whole. Keys not listed below are preserved.
UPDATE org_settings
SET settings = jsonb_set(
        COALESCE(settings, '{}'::jsonb),
        '{usage_quotas}',
        (
            CASE
                WHEN jsonb_typeof(settings->'usage_quotas') = 'object'
                    THEN settings->'usage_quotas'
                ELSE '{}'::jsonb
            END
        ) || '{
            "llm_daily_cost_limit_eur": 2,
            "llm_monthly_cost_limit_eur": 35,
            "llm_requests_per_day": 75,
            "llm_tokens_per_month": 2000000,
            "search_requests_per_day": 20,
            "max_api_tokens_per_user": 5,
            "max_webhooks_per_user": 5
        }'::jsonb
    ),
    updated_at = NOW()
WHERE id = 1
  AND (
        settings->'usage_quotas' IS NULL
     OR settings->'usage_quotas' = 'null'::jsonb
     OR (
            COALESCE((settings->'usage_quotas'->>'llm_daily_cost_limit_eur')::numeric, 10) = 10
        AND COALESCE((settings->'usage_quotas'->>'llm_monthly_cost_limit_eur')::numeric, 100) = 100
        AND COALESCE((settings->'usage_quotas'->>'llm_requests_per_day')::numeric, 100) = 100
        AND COALESCE((settings->'usage_quotas'->>'llm_tokens_per_month')::numeric, 500000) = 500000
        AND COALESCE((settings->'usage_quotas'->>'search_requests_per_day')::numeric, 50) = 50
        AND COALESCE((settings->'usage_quotas'->>'max_api_tokens_per_user')::numeric, 10) = 10
        AND COALESCE((settings->'usage_quotas'->>'max_webhooks_per_user')::numeric, 20) = 20
     )
  );
|
||||
Loading…
Reference in New Issue
Block a user