package ai

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)

// ErrQuotaExceeded is the sentinel error wrapped by AssertAvailable when a
// user has reached the daily request limit or the monthly token limit.
// Match it with errors.Is.
var ErrQuotaExceeded = errors.New("llm quota exceeded")

// QuotaService reads and records per-user usage counters stored in the
// ai_usage_daily and ai_usage_monthly tables.
type QuotaService struct {
	db *pgxpool.Pool
}

// NewQuotaService returns a QuotaService that runs its queries on db.
func NewQuotaService(db *pgxpool.Pool) *QuotaService {
	return &QuotaService{db: db}
}

// usagePeriodStarts returns the UTC midnight that starts now's calendar day
// and the UTC midnight that starts now's calendar month. These values are the
// keys used for ai_usage_daily.usage_date and ai_usage_monthly.usage_month.
func usagePeriodStarts(now time.Time) (day, month time.Time) {
	now = now.UTC()
	day = time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC)
	month = time.Date(now.Year(), now.Month(), 1, 0, 0, 0, 0, time.UTC)
	return day, month
}

// Check reports the current usage and configured limits for the user
// identified by externalUserID.
//
// A missing usage row is treated as zero usage. Any other database error is
// returned to the caller rather than being read as "no usage", so that a
// failing database cannot silently disable quota enforcement.
//
// The Remaining fields are only computed for limits greater than zero and are
// clamped at zero; for a non-positive limit they are left at their zero value.
func (s *QuotaService) Check(ctx context.Context, externalUserID string) (QuotaStatus, error) {
	limits, err := LoadQuotaLimits(ctx, s.db)
	if err != nil {
		return QuotaStatus{}, err
	}

	userID, err := s.resolveUserID(ctx, externalUserID)
	if err != nil {
		return QuotaStatus{}, err
	}

	day, month := usagePeriodStarts(time.Now())

	var requestsToday int
	err = s.db.QueryRow(ctx, `
		SELECT COALESCE(requests, 0)
		FROM ai_usage_daily
		WHERE user_id = $1 AND usage_date = $2
	`, userID, day).Scan(&requestsToday)
	// No row simply means nothing has been recorded for today yet.
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return QuotaStatus{}, fmt.Errorf("reading daily usage: %w", err)
	}

	var tokensMonth int64
	err = s.db.QueryRow(ctx, `
		SELECT COALESCE(tokens, 0)
		FROM ai_usage_monthly
		WHERE user_id = $1 AND usage_month = $2
	`, userID, month).Scan(&tokensMonth)
	// No row simply means nothing has been recorded for this month yet.
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return QuotaStatus{}, fmt.Errorf("reading monthly usage: %w", err)
	}

	status := QuotaStatus{
		RequestsUsedToday: requestsToday,
		RequestsLimit:     limits.RequestsPerDay,
		TokensUsedMonth:   tokensMonth,
		TokensLimit:       limits.TokensPerMonth,
	}

	if limits.RequestsPerDay > 0 {
		status.RequestsRemaining = limits.RequestsPerDay - requestsToday
		if status.RequestsRemaining < 0 {
			status.RequestsRemaining = 0
		}
	}

	if limits.TokensPerMonth > 0 {
		status.TokensRemaining = limits.TokensPerMonth - tokensMonth
		if status.TokensRemaining < 0 {
			status.TokensRemaining = 0
		}
	}

	return status, nil
}

// AssertAvailable returns nil when the user may issue another request.
//
// It returns an error wrapping ErrQuotaExceeded when a positive daily request
// limit or a positive monthly token limit has been reached, and returns the
// error from Check unchanged when the quota status cannot be determined.
// Limits that are zero or negative are not enforced.
func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string) error {
	status, err := s.Check(ctx, externalUserID)
	if err != nil {
		return err
	}

	if status.RequestsLimit > 0 && status.RequestsUsedToday >= status.RequestsLimit {
		return fmt.Errorf("%w: daily request limit reached", ErrQuotaExceeded)
	}

	if status.TokensLimit > 0 && status.TokensUsedMonth >= status.TokensLimit {
		return fmt.Errorf("%w: monthly token limit reached", ErrQuotaExceeded)
	}

	return nil
}

// Record counts one request and tokens consumed tokens against the user
// identified by externalUserID. A negative tokens value is recorded as zero.
//
// The daily and monthly counters are updated inside a single transaction so
// that a failure on either statement leaves both counters unchanged instead of
// letting them drift apart.
func (s *QuotaService) Record(ctx context.Context, externalUserID string, tokens int64) error {
	if tokens < 0 {
		tokens = 0
	}

	userID, err := s.resolveUserID(ctx, externalUserID)
	if err != nil {
		return err
	}

	day, month := usagePeriodStarts(time.Now())

	// BeginFunc commits when the callback returns nil and rolls back otherwise.
	return pgx.BeginFunc(ctx, s.db, func(tx pgx.Tx) error {
		if _, err := tx.Exec(ctx, `
			INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens)
			VALUES ($1, $2, 1, $3)
			ON CONFLICT (user_id, usage_date)
			DO UPDATE SET
				requests = ai_usage_daily.requests + 1,
				tokens = ai_usage_daily.tokens + EXCLUDED.tokens
		`, userID, day, tokens); err != nil {
			return fmt.Errorf("recording daily usage: %w", err)
		}

		if _, err := tx.Exec(ctx, `
			INSERT INTO ai_usage_monthly (user_id, usage_month, tokens)
			VALUES ($1, $2, $3)
			ON CONFLICT (user_id, usage_month)
			DO UPDATE SET
				tokens = ai_usage_monthly.tokens + EXCLUDED.tokens
		`, userID, month, tokens); err != nil {
			return fmt.Errorf("recording monthly usage: %w", err)
		}

		return nil
	})
}

// resolveUserID looks up users.id (as text) for the row whose external_id
// equals externalUserID. It returns a "user not found" error when no such row
// exists and passes any other query error through unchanged.
func (s *QuotaService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
	var userID string
	err := s.db.QueryRow(ctx, `
		SELECT id::text
		FROM users
		WHERE external_id = $1
	`, externalUserID).Scan(&userID)
	if err != nil {
		// errors.Is also matches ErrNoRows when it arrives wrapped.
		if errors.Is(err, pgx.ErrNoRows) {
			return "", errors.New("user not found")
		}
		return "", err
	}
	return userID, nil
}