package authapi

import (
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/redis/go-redis/v9"
)

const (
	// flowRateLimitMax is the number of requests one client IP may make
	// within a single flowRateLimitWindow.
	flowRateLimitMax = 10
	// flowRateLimitWindow is the length of one rate-limit window.
	flowRateLimitWindow = time.Minute
	// flowRateKeyPrefix is prepended to the client IP to form the Redis
	// counter key.
	flowRateKeyPrefix = "auth_flow_rate:"
)

// FlowRateLimiter limits flow start requests per client IP.
//
// When a Redis client is configured, counting is done with a per-IP counter
// key in Redis. When no client is configured, or the Redis increment fails,
// an in-process list of recent hit timestamps per IP is used instead.
type FlowRateLimiter struct {
	rdb *redis.Client

	mu        sync.Mutex             // guards hits and lastSweep
	hits      map[string][]time.Time // fallback when KeyDB unavailable
	lastSweep time.Time              // last time stale entries were purged from hits
}

// NewFlowRateLimiter returns a limiter backed by rdb. A nil rdb is allowed;
// the limiter then uses only the in-memory fallback.
func NewFlowRateLimiter(rdb *redis.Client) *FlowRateLimiter {
	return &FlowRateLimiter{
		rdb:  rdb,
		hits: make(map[string][]time.Time),
	}
}

// Allow reports whether the request may proceed under the per-IP limit.
// Requests for which no client IP can be determined are always allowed.
func (l *FlowRateLimiter) Allow(r *http.Request) bool {
	ip := clientIP(r)
	if ip == "" {
		return true
	}
	if l.rdb != nil {
		return l.allowRedis(r, ip)
	}
	return l.allowMemory(ip)
}

// allowRedis counts the request in Redis using the request's context and
// reports whether the counter is still within flowRateLimitMax. If the
// increment fails, the decision is delegated to the in-memory fallback.
func (l *FlowRateLimiter) allowRedis(r *http.Request, ip string) bool {
	ctx := r.Context()
	key := flowRateKeyPrefix + ip

	count, err := l.rdb.Incr(ctx, key).Result()
	if err != nil {
		return l.allowMemory(ip)
	}

	switch {
	case count == 1:
		// First hit of a new window: start the window clock. This is
		// best-effort; a failure here is repaired on the denial path below.
		_ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err()
	case count > flowRateLimitMax:
		// INCR and EXPIRE are not atomic. If the initial EXPIRE was lost the
		// key would never expire and this IP would be blocked forever, so
		// re-arm the expiry when the key has none. go-redis reports "key
		// exists without expiry" as a raw duration of -1.
		if ttl, ttlErr := l.rdb.TTL(ctx, key).Result(); ttlErr == nil && ttl == -1 {
			_ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err()
		}
	}
	return count <= flowRateLimitMax
}

// allowMemory applies a sliding-window limit using timestamps held in
// process memory. It reports false when the IP already has flowRateLimitMax
// hits newer than one window; otherwise it records the hit and reports true.
func (l *FlowRateLimiter) allowMemory(ip string) bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	// Read the clock under the lock so each IP's timestamps are appended in
	// non-decreasing order, which sweepLocked relies on.
	now := time.Now()
	cutoff := now.Add(-flowRateLimitWindow)

	if l.hits == nil {
		// Tolerate a zero-value FlowRateLimiter.
		l.hits = make(map[string][]time.Time)
	}
	l.sweepLocked(now, cutoff)

	// Filter in place, keeping only hits inside the current window.
	prev := l.hits[ip]
	recent := prev[:0]
	for _, t := range prev {
		if t.After(cutoff) {
			recent = append(recent, t)
		}
	}

	if len(recent) >= flowRateLimitMax {
		l.hits[ip] = recent
		return false
	}
	l.hits[ip] = append(recent, now)
	return true
}

// sweepLocked removes map entries whose newest hit is outside the window so
// that IPs which stop sending requests do not stay in memory forever. It
// runs at most once per flowRateLimitWindow. The caller must hold l.mu.
func (l *FlowRateLimiter) sweepLocked(now, cutoff time.Time) {
	if now.Sub(l.lastSweep) < flowRateLimitWindow {
		return
	}
	l.lastSweep = now
	for ip, ts := range l.hits {
		if len(ts) == 0 || !ts[len(ts)-1].After(cutoff) {
			delete(l.hits, ip)
		}
	}
}

// clientIP returns the address used as the rate-limit key for r: the first
// non-blank entry of X-Forwarded-For, else X-Real-IP, else the host part of
// RemoteAddr (or RemoteAddr itself when it has no port).
//
// NOTE(review): the forwarding headers are supplied by the client unless a
// trusted proxy overwrites them. If this service can be reached directly, a
// caller can pick its own rate-limit key. Confirm the deployment strips or
// sets these headers.
func clientIP(r *http.Request) string {
	// A blank first element (for example a header value of ",") must not
	// yield an empty IP, because Allow treats an empty IP as "always allow".
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if first = strings.TrimSpace(first); first != "" {
		return first
	}
	if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
		return xrip
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return strings.TrimSpace(r.RemoteAddr)
	}
	return host
}