125 lines
3.4 KiB
Go
125 lines
3.4 KiB
Go
package ai
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// ErrQuotaExceeded is the sentinel error for an exhausted LLM quota.
// QuotaService.AssertAvailable returns it wrapped (with %w) together with a
// note on which limit was hit, so callers should test for it with errors.Is.
var ErrQuotaExceeded = errors.New("llm quota exceeded")
|
|
|
|
type QuotaService struct {
|
|
db *pgxpool.Pool
|
|
}
|
|
|
|
func NewQuotaService(db *pgxpool.Pool) *QuotaService {
|
|
return &QuotaService{db: db}
|
|
}
|
|
|
|
func (s *QuotaService) Check(ctx context.Context, externalUserID string) (QuotaStatus, error) {
|
|
limits, err := LoadQuotaLimits(ctx, s.db)
|
|
if err != nil {
|
|
return QuotaStatus{}, err
|
|
}
|
|
userID, err := s.resolveUserID(ctx, externalUserID)
|
|
if err != nil {
|
|
return QuotaStatus{}, err
|
|
}
|
|
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
|
|
|
var requestsToday int
|
|
var tokensMonth int64
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT COALESCE(requests, 0) FROM ai_usage_daily
|
|
WHERE user_id = $1 AND usage_date = $2
|
|
`, userID, today).Scan(&requestsToday)
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT COALESCE(tokens, 0) FROM ai_usage_monthly
|
|
WHERE user_id = $1 AND usage_month = $2
|
|
`, userID, month).Scan(&tokensMonth)
|
|
|
|
status := QuotaStatus{
|
|
RequestsUsedToday: requestsToday,
|
|
RequestsLimit: limits.RequestsPerDay,
|
|
TokensUsedMonth: tokensMonth,
|
|
TokensLimit: limits.TokensPerMonth,
|
|
}
|
|
if limits.RequestsPerDay > 0 {
|
|
status.RequestsRemaining = limits.RequestsPerDay - requestsToday
|
|
if status.RequestsRemaining < 0 {
|
|
status.RequestsRemaining = 0
|
|
}
|
|
}
|
|
if limits.TokensPerMonth > 0 {
|
|
status.TokensRemaining = limits.TokensPerMonth - tokensMonth
|
|
if status.TokensRemaining < 0 {
|
|
status.TokensRemaining = 0
|
|
}
|
|
}
|
|
return status, nil
|
|
}
|
|
|
|
func (s *QuotaService) AssertAvailable(ctx context.Context, externalUserID string) error {
|
|
status, err := s.Check(ctx, externalUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if status.RequestsLimit > 0 && status.RequestsUsedToday >= status.RequestsLimit {
|
|
return fmt.Errorf("%w: daily request limit reached", ErrQuotaExceeded)
|
|
}
|
|
if status.TokensLimit > 0 && status.TokensUsedMonth >= status.TokensLimit {
|
|
return fmt.Errorf("%w: monthly token limit reached", ErrQuotaExceeded)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *QuotaService) Record(ctx context.Context, externalUserID string, tokens int64) error {
|
|
if tokens < 0 {
|
|
tokens = 0
|
|
}
|
|
userID, err := s.resolveUserID(ctx, externalUserID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
today := time.Now().UTC().Truncate(24 * time.Hour)
|
|
month := time.Date(today.Year(), today.Month(), 1, 0, 0, 0, 0, time.UTC)
|
|
|
|
_, err = s.db.Exec(ctx, `
|
|
INSERT INTO ai_usage_daily (user_id, usage_date, requests, tokens)
|
|
VALUES ($1, $2, 1, $3)
|
|
ON CONFLICT (user_id, usage_date) DO UPDATE SET
|
|
requests = ai_usage_daily.requests + 1,
|
|
tokens = ai_usage_daily.tokens + EXCLUDED.tokens
|
|
`, userID, today, tokens)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.db.Exec(ctx, `
|
|
INSERT INTO ai_usage_monthly (user_id, usage_month, tokens)
|
|
VALUES ($1, $2, $3)
|
|
ON CONFLICT (user_id, usage_month) DO UPDATE SET
|
|
tokens = ai_usage_monthly.tokens + EXCLUDED.tokens
|
|
`, userID, month, tokens)
|
|
return err
|
|
}
|
|
|
|
func (s *QuotaService) resolveUserID(ctx context.Context, externalUserID string) (string, error) {
|
|
var userID string
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT id::text FROM users WHERE external_id = $1
|
|
`, externalUserID).Scan(&userID)
|
|
if err != nil {
|
|
if err == pgx.ErrNoRows {
|
|
return "", fmt.Errorf("user not found")
|
|
}
|
|
return "", err
|
|
}
|
|
return userID, nil
|
|
}
|