ultisuite-backend/internal/api/auth/rate_limit.go
R3D347HR4Y 8bbc539d77 feat(auth): implement flow completion and rate limiting for authentication flows
- Added a new handler for completing authentication flows, including session validation and cookie management.
- Implemented flow rate limiting to restrict the number of flow start requests per client IP.
- Enhanced flow session management with Redis support for persistent session storage.
- Updated existing handlers to integrate the new flow completion logic and error handling for various session states.
- Introduced unit tests for the new flow completion and rate limiting functionalities to ensure reliability.
2026-06-20 01:09:42 +02:00

95 lines
1.9 KiB
Go

package authapi
import (
"net"
"net/http"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// Parameters for throttling flow-start requests per client IP: at most
// flowRateLimitMax requests are admitted per flowRateLimitWindow.
const (
	// flowRateLimitMax is the number of flow starts admitted per window.
	flowRateLimitMax = 10
	// flowRateLimitWindow is the length of the counting window.
	flowRateLimitWindow = 60 * time.Second
	// flowRateKeyPrefix namespaces the per-IP counters stored in Redis/KeyDB.
	flowRateKeyPrefix = "auth_flow_rate:"
)
// FlowRateLimiter limits flow start requests per client IP. Counters are
// kept in Redis/KeyDB when a client is configured; otherwise, or when a
// Redis call fails, an in-process per-IP log of request times is used.
type FlowRateLimiter struct {
	rdb *redis.Client // shared counter store; nil means in-process only

	// In-process fallback state, used when rdb is nil or a Redis call
	// fails: timestamps of admitted requests per IP, guarded by mu.
	mu   sync.Mutex
	hits map[string][]time.Time
}
// NewFlowRateLimiter returns a limiter that counts in rdb. rdb may be nil,
// in which case only the in-process limiter is used.
func NewFlowRateLimiter(rdb *redis.Client) *FlowRateLimiter {
	l := new(FlowRateLimiter)
	l.rdb = rdb
	l.hits = map[string][]time.Time{}
	return l
}
// Allow reports whether r may start a new flow, counting the request
// against its client IP. A request whose client IP cannot be determined is
// always allowed. The Redis counter is used when a client is configured;
// otherwise the in-process limiter decides.
func (l *FlowRateLimiter) Allow(r *http.Request) bool {
	switch ip := clientIP(r); {
	case ip == "":
		// No key to count against, so the request is let through.
		return true
	case l.rdb != nil:
		return l.allowRedis(r, ip)
	default:
		return l.allowMemory(ip)
	}
}
// allowRedis counts r in a fixed-window counter stored in Redis under
// flowRateKeyPrefix+ip and reports whether the count is still within
// flowRateLimitMax. If the counting round trip fails, it falls back to the
// in-process limiter.
func (l *FlowRateLimiter) allowRedis(r *http.Request, ip string) bool {
	ctx := r.Context()
	key := flowRateKeyPrefix + ip

	// INCR and TTL go out in a single MULTI/EXEC round trip so we learn both
	// the new count and whether the key already carries an expiry.
	pipe := l.rdb.TxPipeline()
	incr := pipe.Incr(ctx, key)
	ttl := pipe.TTL(ctx, key)
	if _, err := pipe.Exec(ctx); err != nil {
		return l.allowMemory(ip)
	}

	// Attach the window whenever the key has no expiry, not only on the first
	// hit. If the EXPIRE after the first INCR was lost (error, or a crash
	// between the two commands), a TTL-less key would otherwise keep counting
	// forever and permanently block this IP once it passed the limit.
	if ttl.Val() < 0 {
		// Best effort: if this fails, the TTL is still missing on the next
		// request and the expiry is retried then.
		_ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err()
	}

	return incr.Val() <= flowRateLimitMax
}
// allowMemory is the in-process limiter used when no Redis client is
// configured or a Redis call fails. For each IP it keeps the timestamps of
// admitted requests within the last flowRateLimitWindow. It admits a new
// request only while fewer than flowRateLimitMax are recorded (a sliding
// window). Rejected requests are not recorded.
func (l *FlowRateLimiter) allowMemory(ip string) bool {
	// Maximum number of map entries inspected for eviction per call.
	const evictSample = 8

	l.mu.Lock()
	defer l.mu.Unlock()

	// Read the clock while holding the lock so each IP's log is appended in
	// time order even when callers race.
	now := time.Now()
	cutoff := now.Add(-flowRateLimitWindow)

	// Opportunistically evict a few entries whose hits have all aged out.
	// Without this, every IP ever seen keeps a map entry forever and memory
	// grows without bound (cheap to trigger with header-supplied IPs). Go
	// does not fix map iteration order (the gc runtime randomizes it), so
	// successive calls sample different entries while each call does only a
	// bounded amount of work. Deleting during range is permitted by the spec.
	inspected := 0
	for key, ts := range l.hits {
		if inspected == evictSample {
			break
		}
		inspected++
		stale := true
		for _, t := range ts {
			if t.After(cutoff) {
				stale = false
				break
			}
		}
		if stale {
			delete(l.hits, key)
		}
	}

	prev := l.hits[ip]
	// Filter in place, keeping only hits still inside the window.
	next := prev[:0]
	for _, t := range prev {
		if t.After(cutoff) {
			next = append(next, t)
		}
	}
	if len(next) >= flowRateLimitMax {
		l.hits[ip] = next
		return false
	}
	l.hits[ip] = append(next, now)
	return true
}
// clientIP returns the key used to rate-limit r, or "" if none can be
// determined.
//
// Sources are tried in order:
//   - the first entry of X-Forwarded-For
//   - X-Real-IP
//   - the host part of r.RemoteAddr
//
// Header values are accepted only if they parse as an IP address. They are
// returned in canonical form (e.g. "::ffff:1.2.3.4" becomes "1.2.3.4"), so
// equivalent spellings share one bucket. A RemoteAddr host that is not an
// IP is returned trimmed but otherwise unchanged.
//
// NOTE(review): both headers are client-controlled unless a trusted proxy
// strips or overwrites them. The first X-Forwarded-For entry in particular is
// whatever the client sent. Confirm the deployment rewrites these headers;
// otherwise a client can pick its own rate-limit bucket.
func clientIP(r *http.Request) string {
	// canonical trims s and returns its canonical IP form, or "" if s is not
	// an IP address.
	canonical := func(s string) string {
		ip := net.ParseIP(strings.TrimSpace(s))
		if ip == nil {
			return ""
		}
		return ip.String()
	}

	// Only the first (client-most) hop is considered. An empty or malformed
	// entry falls through to the next source instead of yielding "", which
	// Allow would let through unthrottled.
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := canonical(first); ip != "" {
		return ip
	}
	if ip := canonical(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// No port (or otherwise unsplittable): use the whole address.
		host = r.RemoteAddr
	}
	if ip := canonical(host); ip != "" {
		return ip
	}
	return strings.TrimSpace(host)
}