- Added a new handler for completing authentication flows, including session validation and cookie management.
- Implemented flow rate limiting to restrict the number of flow-start requests per client IP.
- Enhanced flow session management with Redis support for persistent session storage.
- Updated existing handlers to integrate the new flow-completion logic and error handling for various session states.
- Introduced unit tests for the new flow-completion and rate-limiting functionality to ensure reliability.
95 lines
1.9 KiB
Go
95 lines
1.9 KiB
Go
package authapi
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
// Parameters of the flow-start rate limiter.
const (
	// flowRateLimitMax is the number of flow starts admitted per client IP
	// within one flowRateLimitWindow.
	flowRateLimitMax = 10

	// flowRateLimitWindow is the length of the counting window.
	flowRateLimitWindow = time.Minute

	// flowRateKeyPrefix prefixes the client IP to form the Redis/KeyDB
	// counter key.
	flowRateKeyPrefix = "auth_flow_rate:"
)
|
|
|
|
// FlowRateLimiter limits flow start requests per client IP.
|
|
type FlowRateLimiter struct {
|
|
rdb *redis.Client
|
|
mu sync.Mutex
|
|
// fallback when KeyDB unavailable
|
|
hits map[string][]time.Time
|
|
}
|
|
|
|
func NewFlowRateLimiter(rdb *redis.Client) *FlowRateLimiter {
|
|
return &FlowRateLimiter{
|
|
rdb: rdb,
|
|
hits: make(map[string][]time.Time),
|
|
}
|
|
}
|
|
|
|
func (l *FlowRateLimiter) Allow(r *http.Request) bool {
|
|
ip := clientIP(r)
|
|
if ip == "" {
|
|
return true
|
|
}
|
|
if l.rdb != nil {
|
|
return l.allowRedis(r, ip)
|
|
}
|
|
return l.allowMemory(ip)
|
|
}
|
|
|
|
func (l *FlowRateLimiter) allowRedis(r *http.Request, ip string) bool {
|
|
ctx := r.Context()
|
|
key := flowRateKeyPrefix + ip
|
|
count, err := l.rdb.Incr(ctx, key).Result()
|
|
if err != nil {
|
|
return l.allowMemory(ip)
|
|
}
|
|
if count == 1 {
|
|
_ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err()
|
|
}
|
|
return count <= flowRateLimitMax
|
|
}
|
|
|
|
func (l *FlowRateLimiter) allowMemory(ip string) bool {
|
|
now := time.Now()
|
|
cutoff := now.Add(-flowRateLimitWindow)
|
|
l.mu.Lock()
|
|
defer l.mu.Unlock()
|
|
prev := l.hits[ip]
|
|
next := prev[:0]
|
|
for _, t := range prev {
|
|
if t.After(cutoff) {
|
|
next = append(next, t)
|
|
}
|
|
}
|
|
if len(next) >= flowRateLimitMax {
|
|
l.hits[ip] = next
|
|
return false
|
|
}
|
|
next = append(next, now)
|
|
l.hits[ip] = next
|
|
return true
|
|
}
|
|
|
|
// clientIP returns the client address used as the rate-limit key.
//
// Sources are tried in order:
//  1. the first entry of X-Forwarded-For;
//  2. X-Real-IP;
//  3. the host part of r.RemoteAddr.
//
// A header value is used only if it parses as an IP address, optionally with
// a port. Otherwise the next source is tried, so an empty or malformed header
// (e.g. "X-Forwarded-For: ,") cannot produce an empty key, which Allow would
// treat as "always allow". Header-derived IPs are normalized with
// net.IP.String so that equivalent spellings share one key.
//
// If RemoteAddr has no port, it is returned trimmed but otherwise unchanged.
//
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them. A client reaching this server directly can choose
// any key and evade the limit. Confirm the deployment strips or overwrites
// these headers.
func clientIP(r *http.Request) string {
	first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ",")
	if ip := parseHeaderIP(first); ip != "" {
		return ip
	}
	if ip := parseHeaderIP(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return strings.TrimSpace(r.RemoteAddr)
	}
	return host
}

// parseHeaderIP extracts a normalized IP from a forwarding-header value. It
// accepts a bare address ("203.0.113.7", "2001:db8::1") or one with a port
// ("203.0.113.7:8080", "[2001:db8::1]:8080"), and returns "" when v is not a
// valid address.
func parseHeaderIP(v string) string {
	v = strings.TrimSpace(v)
	if ip := net.ParseIP(v); ip != nil {
		return ip.String()
	}
	if host, _, err := net.SplitHostPort(v); err == nil {
		if ip := net.ParseIP(host); ip != nil {
			return ip.String()
		}
	}
	return ""
}
|