- Refactored the AI gateway to use the new cost-management structures for usage tracking.
- Replaced deprecated token-extraction methods with a unified cost-parsing approach.
- Enhanced usage fallback mechanisms and added detailed usage metrics to responses.
- Added metering functionality to record AI usage and costs.
- Updated tests to reflect the changes in usage parsing and cost calculation.
- Introduced new API endpoints for retrieving AI usage summaries and pricing information.
150 lines
4.3 KiB
Go
150 lines
4.3 KiB
Go
package cost
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)
|
|
|
|
// Billing scopes for RecordInput.BillingScope. RecordUsage attributes the
// cost of a BillingScopeOrg request to the cost_micro_eur_org rollup columns
// and the cost of any other scope to the cost_micro_eur_user columns.
const (
	// BillingScopeOrg marks a request whose cost is rolled up as org cost.
	BillingScopeOrg = "org"

	// BillingScopeUser marks a request whose cost is rolled up as user cost.
	BillingScopeUser = "user"
)
|
|
|
|
type RecordInput struct {
|
|
ExternalUserID string
|
|
Feature string
|
|
ModelID string
|
|
ProviderID string
|
|
BillingScope string
|
|
ProviderKeyFingerprint string
|
|
Usage UsageDetail
|
|
RequestID string
|
|
}
|
|
|
|
type Meter struct {
|
|
db *pgxpool.Pool
|
|
pricing *PricingStore
|
|
}
|
|
|
|
func NewMeter(db *pgxpool.Pool) *Meter {
|
|
return &Meter{db: db, pricing: NewPricingStore(db)}
|
|
}
|
|
|
|
func (m *Meter) RecordUsage(ctx context.Context, in RecordInput) error {
|
|
if m.db == nil {
|
|
return nil
|
|
}
|
|
userID, err := m.resolveUserID(ctx, in.ExternalUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
price, found := m.pricing.LookupPrice(ctx, in.ModelID)
|
|
cost, estimated := ComputeCostMicroEUR(in.Usage, price, found)
|
|
|
|
tx, err := m.db.Begin(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO ai_usage_events (
|
|
user_id, feature, model_id, provider_id, billing_scope,
|
|
provider_key_fingerprint, prompt_tokens, completion_tokens,
|
|
cached_input_tokens, reasoning_tokens, cost_micro_eur, estimated, request_id
|
|
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
|
`, userID, in.Feature, in.ModelID, in.ProviderID, in.BillingScope,
|
|
in.ProviderKeyFingerprint, in.Usage.PromptTokens, in.Usage.CompletionTokens,
|
|
in.Usage.CachedInputTokens, in.Usage.ReasoningTokens, cost, estimated, nullStr(in.RequestID))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
|
tokens := int64(in.Usage.TotalTokens)
|
|
orgCost := int64(0)
|
|
userCost := int64(0)
|
|
if in.BillingScope == BillingScopeOrg {
|
|
orgCost = cost
|
|
} else {
|
|
userCost = cost
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens, cost_micro_eur_org, cost_micro_eur_user)
|
|
VALUES ($1, $2, 1, $3, $4, $5)
|
|
ON CONFLICT (user_id, usage_date) DO UPDATE SET
|
|
requests = ai_usage_daily.requests + 1,
|
|
tokens = ai_usage_daily.tokens + EXCLUDED.tokens,
|
|
cost_micro_eur_org = ai_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
|
cost_micro_eur_user = ai_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
|
|
`, userID, today, tokens, orgCost, userCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO ai_usage_monthly (user_id, usage_month, tokens, cost_micro_eur_org, cost_micro_eur_user)
|
|
VALUES ($1, $2, $3, $4, $5)
|
|
ON CONFLICT (user_id, usage_month) DO UPDATE SET
|
|
tokens = ai_usage_monthly.tokens + EXCLUDED.tokens,
|
|
cost_micro_eur_org = ai_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
|
cost_micro_eur_user = ai_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
|
|
`, userID, month, tokens, orgCost, userCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO ai_org_usage_daily (usage_date, cost_micro_eur_org, cost_micro_eur_user, requests)
|
|
VALUES ($1, $2, $3, 1)
|
|
ON CONFLICT (usage_date) DO UPDATE SET
|
|
cost_micro_eur_org = ai_org_usage_daily.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
|
cost_micro_eur_user = ai_org_usage_daily.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user,
|
|
requests = ai_org_usage_daily.requests + 1
|
|
`, today, orgCost, userCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO ai_org_usage_monthly (usage_month, cost_micro_eur_org, cost_micro_eur_user)
|
|
VALUES ($1, $2, $3)
|
|
ON CONFLICT (usage_month) DO UPDATE SET
|
|
cost_micro_eur_org = ai_org_usage_monthly.cost_micro_eur_org + EXCLUDED.cost_micro_eur_org,
|
|
cost_micro_eur_user = ai_org_usage_monthly.cost_micro_eur_user + EXCLUDED.cost_micro_eur_user
|
|
`, month, orgCost, userCost)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
func (m *Meter) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
|
|
var userID string
|
|
err := m.db.QueryRow(ctx, `
|
|
SELECT id::text FROM users WHERE external_id = $1
|
|
`, externalUserID).Scan(&userID)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return "", fmt.Errorf("user not found")
|
|
}
|
|
return "", err
|
|
}
|
|
return userID, nil
|
|
}
|
|
|
|
// nullStr turns an optional string into a nullable query argument. An empty
// string becomes a nil pointer, so the column receives NULL rather than "".
// Any other value becomes a pointer to a copy of s.
func nullStr(s string) *string {
	if s != "" {
		return &s
	}
	return nil
}
|