- Added a new handler for completing authentication flows, including session validation and cookie management.
- Implemented flow rate limiting to restrict the number of flow-start requests per client IP.
- Enhanced flow session management with Redis support for persistent session storage.
- Updated existing handlers to integrate the new flow-completion logic and error handling for the various session states.
- Introduced unit tests for the new flow-completion and rate-limiting functionality to ensure reliability.
221 lines
5.3 KiB
Go
221 lines
5.3 KiB
Go
package authentik
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"errors"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
const (
	// flowSessionKeyPrefix namespaces flow-session keys in Redis/KeyDB; the
	// session ID is appended to it.
	flowSessionKeyPrefix = "auth_flow_session:"

	// defaultFlowSessionTTL is the maximum age of a flow session, measured
	// from its creation time.
	defaultFlowSessionTTL = 20 * time.Minute
)
|
|
|
|
type storedFlowSession struct {
|
|
Slug string `json:"slug"`
|
|
Cookies []SerializedCookie `json:"cookies"`
|
|
CreatedAt time.Time `json:"createdAt"`
|
|
Completed bool `json:"completed,omitempty"`
|
|
}
|
|
|
|
type flowSessionEntry struct {
|
|
slug string
|
|
cookies []SerializedCookie
|
|
createdAt time.Time
|
|
completed bool
|
|
}
|
|
|
|
// FlowSessionStore keeps Authentik flow executor sessions (memory + optional KeyDB).
|
|
type FlowSessionStore struct {
|
|
mu sync.Mutex
|
|
baseURL string
|
|
ttl time.Duration
|
|
rdb *redis.Client
|
|
items map[string]*flowSessionEntry
|
|
}
|
|
|
|
func NewFlowSessionStore(baseURL string, rdb *redis.Client) *FlowSessionStore {
|
|
return &FlowSessionStore{
|
|
baseURL: baseURL,
|
|
ttl: defaultFlowSessionTTL,
|
|
rdb: rdb,
|
|
items: make(map[string]*flowSessionEntry),
|
|
}
|
|
}
|
|
|
|
func (s *FlowSessionStore) Start(ctx context.Context, slug, query string) (sessionID string, challenge FlowChallenge, err error) {
|
|
executor, err := NewFlowExecutor(s.baseURL, slug)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
challenge, err = executor.GetChallenge(ctx, query)
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
id, err := randomSessionID()
|
|
if err != nil {
|
|
return "", nil, err
|
|
}
|
|
done, _ := FlowDone(challenge)
|
|
if !done {
|
|
entry := &flowSessionEntry{
|
|
slug: slug,
|
|
cookies: executor.ExportCookies(),
|
|
createdAt: time.Now(),
|
|
}
|
|
if err := s.save(ctx, id, entry); err != nil {
|
|
return "", nil, err
|
|
}
|
|
}
|
|
return id, challenge, nil
|
|
}
|
|
|
|
func (s *FlowSessionStore) Respond(ctx context.Context, sessionID, slug, query string, payload map[string]any) (FlowChallenge, error) {
|
|
entry, err := s.load(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if entry.slug != slug {
|
|
return nil, ErrFlowSessionSlugMismatch
|
|
}
|
|
if entry.completed {
|
|
return nil, ErrFlowSessionAlreadyCompleted
|
|
}
|
|
executor, err := RestoreFlowExecutor(s.baseURL, slug, entry.cookies)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
challenge, err := executor.PostResponse(ctx, query, payload)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
entry.cookies = executor.ExportCookies()
|
|
done, denied := FlowDone(challenge)
|
|
if done && !denied {
|
|
entry.completed = true
|
|
if err := s.save(ctx, sessionID, entry); err != nil {
|
|
return nil, err
|
|
}
|
|
return challenge, nil
|
|
}
|
|
if err := s.save(ctx, sessionID, entry); err != nil {
|
|
return nil, err
|
|
}
|
|
return challenge, nil
|
|
}
|
|
|
|
// CompleteSession returns cookies from a completed flow and removes the session.
|
|
func (s *FlowSessionStore) CompleteSession(ctx context.Context, sessionID, slug string) ([]SerializedCookie, error) {
|
|
entry, err := s.load(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if entry.slug != slug {
|
|
return nil, ErrFlowSessionSlugMismatch
|
|
}
|
|
if !entry.completed {
|
|
return nil, ErrFlowSessionNotCompleted
|
|
}
|
|
cookies := entry.cookies
|
|
s.delete(ctx, sessionID)
|
|
return cookies, nil
|
|
}
|
|
|
|
// SessionCookies returns current Authentik cookies for an active session.
|
|
func (s *FlowSessionStore) SessionCookies(ctx context.Context, sessionID string) ([]SerializedCookie, error) {
|
|
entry, err := s.load(ctx, sessionID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return entry.cookies, nil
|
|
}
|
|
|
|
func (s *FlowSessionStore) Delete(ctx context.Context, sessionID string) {
|
|
s.delete(ctx, sessionID)
|
|
}
|
|
|
|
func (s *FlowSessionStore) save(ctx context.Context, sessionID string, entry *flowSessionEntry) error {
|
|
s.mu.Lock()
|
|
s.items[sessionID] = entry
|
|
s.mu.Unlock()
|
|
if s.rdb == nil {
|
|
return nil
|
|
}
|
|
stored := storedFlowSession{
|
|
Slug: entry.slug,
|
|
Cookies: entry.cookies,
|
|
CreatedAt: entry.createdAt,
|
|
Completed: entry.completed,
|
|
}
|
|
raw, err := json.Marshal(stored)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.rdb.Set(ctx, flowSessionKeyPrefix+sessionID, raw, s.ttl).Err()
|
|
}
|
|
|
|
func (s *FlowSessionStore) load(ctx context.Context, sessionID string) (*flowSessionEntry, error) {
|
|
if s.rdb != nil {
|
|
raw, err := s.rdb.Get(ctx, flowSessionKeyPrefix+sessionID).Bytes()
|
|
if err != nil {
|
|
if errors.Is(err, redis.Nil) {
|
|
return nil, ErrFlowSessionNotFound
|
|
}
|
|
return nil, err
|
|
}
|
|
var stored storedFlowSession
|
|
if err := json.Unmarshal(raw, &stored); err != nil {
|
|
return nil, err
|
|
}
|
|
if time.Since(stored.CreatedAt) > s.ttl {
|
|
s.delete(ctx, sessionID)
|
|
return nil, ErrFlowSessionNotFound
|
|
}
|
|
entry := &flowSessionEntry{
|
|
slug: stored.Slug,
|
|
cookies: stored.Cookies,
|
|
createdAt: stored.CreatedAt,
|
|
completed: stored.Completed,
|
|
}
|
|
s.mu.Lock()
|
|
s.items[sessionID] = entry
|
|
s.mu.Unlock()
|
|
return entry, nil
|
|
}
|
|
|
|
s.mu.Lock()
|
|
entry, ok := s.items[sessionID]
|
|
s.mu.Unlock()
|
|
if !ok {
|
|
return nil, ErrFlowSessionNotFound
|
|
}
|
|
if time.Since(entry.createdAt) > s.ttl {
|
|
s.delete(ctx, sessionID)
|
|
return nil, ErrFlowSessionNotFound
|
|
}
|
|
return entry, nil
|
|
}
|
|
|
|
func (s *FlowSessionStore) delete(ctx context.Context, sessionID string) {
|
|
s.mu.Lock()
|
|
delete(s.items, sessionID)
|
|
s.mu.Unlock()
|
|
if s.rdb != nil {
|
|
_ = s.rdb.Del(ctx, flowSessionKeyPrefix+sessionID).Err()
|
|
}
|
|
}
|
|
|
|
// randomSessionID returns 24 bytes from crypto/rand, hex-encoded into a
// 48-character lowercase string. It fails only if the system random source
// fails.
func randomSessionID() (string, error) {
	var raw [24]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", err
	}
	return hex.EncodeToString(raw[:]), nil
}
|