feat(auth): implement flow completion and rate limiting for authentication flows
- Added a new handler for completing authentication flows, including session validation and cookie management. - Implemented flow rate limiting to restrict the number of flow start requests per client IP. - Enhanced flow session management with Redis support for persistent session storage. - Updated existing handlers to integrate the new flow completion logic and error handling for various session states. - Introduced unit tests for the new flow completion and rate limiting functionalities to ensure reliability.
This commit is contained in:
parent
e4549f29b2
commit
8bbc539d77
87
internal/api/auth/complete.go
Normal file
87
internal/api/auth/complete.go
Normal file
@ -0,0 +1,87 @@
|
||||
package authapi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
)
|
||||
|
||||
type flowCompleteRequest struct {
|
||||
ReturnTo string `json:"returnTo"`
|
||||
}
|
||||
|
||||
type flowCompleteResponse struct {
|
||||
RedirectURL string `json:"redirectUrl"`
|
||||
}
|
||||
|
||||
func (h *Handler) CompleteFlow(w http.ResponseWriter, r *http.Request) {
|
||||
sessionID := readFlowSessionCookie(r)
|
||||
if sessionID == "" {
|
||||
apiresponse.WriteError(w, r, http.StatusUnauthorized, "flow_session_missing", "flow session cookie required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
var req flowCompleteRequest
|
||||
if r.Body != nil {
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
}
|
||||
returnTo := strings.TrimSpace(req.ReturnTo)
|
||||
if returnTo == "" {
|
||||
returnTo = "/mail/inbox"
|
||||
}
|
||||
if !strings.HasPrefix(returnTo, "/") || strings.HasPrefix(returnTo, "//") {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "invalid_return_to", "returnTo must be a relative path", nil)
|
||||
return
|
||||
}
|
||||
|
||||
slug := FlowAuthentication
|
||||
cookies, err := h.flows.CompleteSession(r.Context(), sessionID, slug)
|
||||
if err != nil {
|
||||
if errors.Is(err, authentik.ErrFlowSessionNotFound) {
|
||||
clearFlowSessionCookie(w)
|
||||
apiresponse.WriteError(w, r, http.StatusGone, "flow_session_expired", "flow session expired; restart the flow", nil)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, authentik.ErrFlowSessionNotCompleted) {
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "flow_not_completed", "authentication flow is not finished yet", nil)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, authentik.ErrFlowSessionSlugMismatch) {
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_complete_failed", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
clearFlowSessionCookie(w)
|
||||
setBrowserAuthentikCookies(w, cookies)
|
||||
|
||||
loginURL := buildLoginRedirectURL(h.appURL, returnTo)
|
||||
apiresponse.WriteJSON(w, http.StatusOK, flowCompleteResponse{RedirectURL: loginURL})
|
||||
}
|
||||
|
||||
func buildLoginRedirectURL(appURL, returnTo string) string {
|
||||
base := strings.TrimRight(strings.TrimSpace(appURL), "/")
|
||||
if base == "" {
|
||||
base = "http://localhost:3004"
|
||||
}
|
||||
params := url.Values{}
|
||||
params.Set("returnTo", returnTo)
|
||||
return base + "/api/auth/login?" + params.Encode()
|
||||
}
|
||||
|
||||
func setBrowserAuthentikCookies(w http.ResponseWriter, stored []authentik.SerializedCookie) {
|
||||
for _, c := range authentik.BrowserAuthentikCookies(stored) {
|
||||
http.SetCookie(w, c)
|
||||
}
|
||||
}
|
||||
|
||||
func forwardFlowCookies(w http.ResponseWriter, stored []authentik.SerializedCookie) {
|
||||
setBrowserAuthentikCookies(w, stored)
|
||||
}
|
||||
@ -8,31 +8,38 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
)
|
||||
|
||||
const (
|
||||
FlowEnrollment = "ulti-enrollment"
|
||||
FlowRecovery = "ulti-recovery"
|
||||
flowSessionCookie = "ulti_flow_session"
|
||||
flowSessionMaxAge = 20 * time.Minute
|
||||
FlowEnrollment = "ulti-enrollment"
|
||||
FlowRecovery = "ulti-recovery"
|
||||
FlowAuthentication = "default-authentication-flow"
|
||||
flowSessionCookie = "ulti_flow_session"
|
||||
flowSessionMaxAge = 20 * time.Minute
|
||||
)
|
||||
|
||||
var allowedFlowSlugs = map[string]struct{}{
|
||||
FlowEnrollment: {},
|
||||
FlowRecovery: {},
|
||||
FlowEnrollment: {},
|
||||
FlowRecovery: {},
|
||||
FlowAuthentication: {},
|
||||
}
|
||||
|
||||
// Handler exposes public Authentik flow executor endpoints for custom auth UI.
|
||||
type Handler struct {
|
||||
flows *authentik.FlowSessionStore
|
||||
flows *authentik.FlowSessionStore
|
||||
limiter *FlowRateLimiter
|
||||
appURL string
|
||||
}
|
||||
|
||||
func NewHandler(baseURL string) *Handler {
|
||||
func NewHandler(baseURL, appURL string, rdb *redis.Client) *Handler {
|
||||
return &Handler{
|
||||
flows: authentik.NewFlowSessionStore(baseURL),
|
||||
flows: authentik.NewFlowSessionStore(baseURL, rdb),
|
||||
limiter: NewFlowRateLimiter(rdb),
|
||||
appURL: appURL,
|
||||
}
|
||||
}
|
||||
|
||||
@ -40,6 +47,7 @@ func (h *Handler) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Post("/flows/{slug}/start", h.StartFlow)
|
||||
r.Post("/flows/{slug}/respond", h.RespondFlow)
|
||||
r.Post("/flows/complete", h.CompleteFlow)
|
||||
return r
|
||||
}
|
||||
|
||||
@ -55,6 +63,11 @@ type flowRespondRequest struct {
|
||||
}
|
||||
|
||||
func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) {
|
||||
if !h.limiter.Allow(r) {
|
||||
apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, "too many flow start attempts", nil)
|
||||
return
|
||||
}
|
||||
|
||||
slug, ok := validateFlowSlug(w, r, chi.URLParam(r, "slug"))
|
||||
if !ok {
|
||||
return
|
||||
@ -69,6 +82,9 @@ func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) {
|
||||
done, denied := authentik.FlowDone(challenge)
|
||||
if !done {
|
||||
setFlowSessionCookie(w, sessionID)
|
||||
if cookies, err := h.flows.SessionCookies(r.Context(), sessionID); err == nil {
|
||||
forwardFlowCookies(w, cookies)
|
||||
}
|
||||
} else {
|
||||
clearFlowSessionCookie(w)
|
||||
}
|
||||
@ -114,14 +130,30 @@ func (h *Handler) RespondFlow(w http.ResponseWriter, r *http.Request) {
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, authentik.ErrFlowSessionAlreadyCompleted) {
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "flow_already_completed", "flow already completed", nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_respond_failed", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
done, denied := authentik.FlowDone(challenge)
|
||||
if done {
|
||||
clearFlowSessionCookie(w)
|
||||
if slug == FlowAuthentication {
|
||||
if denied {
|
||||
clearFlowSessionCookie(w)
|
||||
h.flows.Delete(r.Context(), sessionID)
|
||||
}
|
||||
// Successful auth: keep session cookie until /flows/complete.
|
||||
} else {
|
||||
clearFlowSessionCookie(w)
|
||||
h.flows.Delete(r.Context(), sessionID)
|
||||
}
|
||||
} else if cookies, err := h.flows.SessionCookies(r.Context(), sessionID); err == nil {
|
||||
forwardFlowCookies(w, cookies)
|
||||
}
|
||||
|
||||
writeFlowJSON(w, http.StatusOK, flowStartResponse{
|
||||
SessionID: sessionID,
|
||||
Challenge: challenge,
|
||||
|
||||
74
internal/api/auth/handlers_test.go
Normal file
74
internal/api/auth/handlers_test.go
Normal file
@ -0,0 +1,74 @@
|
||||
package authapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
)
|
||||
|
||||
func TestValidateFlowSlugAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, slug := range []string{FlowEnrollment, FlowRecovery, FlowAuthentication} {
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/flows/"+slug+"/start", nil)
|
||||
got, ok := validateFlowSlug(rec, req, slug)
|
||||
if !ok || got != slug {
|
||||
t.Fatalf("slug %q: ok=%v got=%q code=%d", slug, ok, got, rec.Code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFlowSlugRejected(t *testing.T) {
|
||||
t.Parallel()
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/flows/default-invalidation-flow/start", nil)
|
||||
_, ok := validateFlowSlug(rec, req, "default-invalidation-flow")
|
||||
if ok {
|
||||
t.Fatal("expected slug to be rejected")
|
||||
}
|
||||
if rec.Code != http.StatusNotFound {
|
||||
t.Fatalf("status = %d", rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRespondFlowMissingCookie(t *testing.T) {
|
||||
t.Parallel()
|
||||
h := NewHandler("http://127.0.0.1:1", "http://localhost:3004", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/flows/ulti-enrollment/respond", bytes.NewReader([]byte(`{"payload":{"component":"x"}}`)))
|
||||
rctx := chi.NewRouteContext()
|
||||
rctx.URLParams.Add("slug", FlowEnrollment)
|
||||
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||
h.RespondFlow(rec, req)
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildLoginRedirectURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := buildLoginRedirectURL("http://localhost:3004", "/mail/inbox")
|
||||
want := "http://localhost:3004/api/auth/login?returnTo=%2Fmail%2Finbox"
|
||||
if got != want {
|
||||
t.Fatalf("got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlowRateLimiterMemory(t *testing.T) {
|
||||
t.Parallel()
|
||||
lim := NewFlowRateLimiter(nil)
|
||||
req := httptest.NewRequest(http.MethodPost, "/start", nil)
|
||||
req.RemoteAddr = "203.0.113.1:1234"
|
||||
for i := 0; i < flowRateLimitMax; i++ {
|
||||
if !lim.Allow(req) {
|
||||
t.Fatalf("attempt %d should be allowed", i+1)
|
||||
}
|
||||
}
|
||||
if lim.Allow(req) {
|
||||
t.Fatal("expected rate limit block")
|
||||
}
|
||||
}
|
||||
94
internal/api/auth/rate_limit.go
Normal file
94
internal/api/auth/rate_limit.go
Normal file
@ -0,0 +1,94 @@
|
||||
package authapi
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
flowRateLimitMax = 10
|
||||
flowRateLimitWindow = time.Minute
|
||||
flowRateKeyPrefix = "auth_flow_rate:"
|
||||
)
|
||||
|
||||
// FlowRateLimiter limits flow start requests per client IP.
|
||||
type FlowRateLimiter struct {
|
||||
rdb *redis.Client
|
||||
mu sync.Mutex
|
||||
// fallback when KeyDB unavailable
|
||||
hits map[string][]time.Time
|
||||
}
|
||||
|
||||
func NewFlowRateLimiter(rdb *redis.Client) *FlowRateLimiter {
|
||||
return &FlowRateLimiter{
|
||||
rdb: rdb,
|
||||
hits: make(map[string][]time.Time),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *FlowRateLimiter) Allow(r *http.Request) bool {
|
||||
ip := clientIP(r)
|
||||
if ip == "" {
|
||||
return true
|
||||
}
|
||||
if l.rdb != nil {
|
||||
return l.allowRedis(r, ip)
|
||||
}
|
||||
return l.allowMemory(ip)
|
||||
}
|
||||
|
||||
func (l *FlowRateLimiter) allowRedis(r *http.Request, ip string) bool {
|
||||
ctx := r.Context()
|
||||
key := flowRateKeyPrefix + ip
|
||||
count, err := l.rdb.Incr(ctx, key).Result()
|
||||
if err != nil {
|
||||
return l.allowMemory(ip)
|
||||
}
|
||||
if count == 1 {
|
||||
_ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err()
|
||||
}
|
||||
return count <= flowRateLimitMax
|
||||
}
|
||||
|
||||
func (l *FlowRateLimiter) allowMemory(ip string) bool {
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-flowRateLimitWindow)
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
prev := l.hits[ip]
|
||||
next := prev[:0]
|
||||
for _, t := range prev {
|
||||
if t.After(cutoff) {
|
||||
next = append(next, t)
|
||||
}
|
||||
}
|
||||
if len(next) >= flowRateLimitMax {
|
||||
l.hits[ip] = next
|
||||
return false
|
||||
}
|
||||
next = append(next, now)
|
||||
l.hits[ip] = next
|
||||
return true
|
||||
}
|
||||
|
||||
func clientIP(r *http.Request) string {
|
||||
if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" {
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
}
|
||||
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
|
||||
return xrip
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err != nil {
|
||||
return strings.TrimSpace(r.RemoteAddr)
|
||||
}
|
||||
return host
|
||||
}
|
||||
125
internal/authentik/flow_cookies.go
Normal file
125
internal/authentik/flow_cookies.go
Normal file
@ -0,0 +1,125 @@
|
||||
package authentik
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const authentikCookiePath = "/auth/"
|
||||
|
||||
// SerializedCookie stores an HTTP cookie for flow session persistence.
|
||||
type SerializedCookie struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
Path string `json:"path,omitempty"`
|
||||
Domain string `json:"domain,omitempty"`
|
||||
MaxAge int `json:"maxAge,omitempty"`
|
||||
Expires time.Time `json:"expires,omitempty"`
|
||||
Secure bool `json:"secure,omitempty"`
|
||||
HTTPOnly bool `json:"httpOnly,omitempty"`
|
||||
SameSite http.SameSite `json:"-"`
|
||||
SameSiteName string `json:"sameSite,omitempty"`
|
||||
}
|
||||
|
||||
func (fe *FlowExecutor) ExportCookies() []SerializedCookie {
|
||||
if fe.client.Jar == nil {
|
||||
return nil
|
||||
}
|
||||
raw := fe.client.Jar.Cookies(fe.cookieURL())
|
||||
out := make([]SerializedCookie, 0, len(raw))
|
||||
for _, c := range raw {
|
||||
out = append(out, serializeCookie(c))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (fe *FlowExecutor) ImportCookies(stored []SerializedCookie) {
|
||||
if fe.client.Jar == nil || len(stored) == 0 {
|
||||
return
|
||||
}
|
||||
cookies := make([]*http.Cookie, 0, len(stored))
|
||||
for _, sc := range stored {
|
||||
cookies = append(cookies, deserializeCookie(sc))
|
||||
}
|
||||
fe.client.Jar.SetCookies(fe.cookieURL(), cookies)
|
||||
}
|
||||
|
||||
func serializeCookie(c *http.Cookie) SerializedCookie {
|
||||
sc := SerializedCookie{
|
||||
Name: c.Name,
|
||||
Value: c.Value,
|
||||
Path: c.Path,
|
||||
Domain: c.Domain,
|
||||
MaxAge: c.MaxAge,
|
||||
Expires: c.Expires,
|
||||
Secure: c.Secure,
|
||||
HTTPOnly: c.HttpOnly,
|
||||
SameSite: c.SameSite,
|
||||
}
|
||||
switch c.SameSite {
|
||||
case http.SameSiteDefaultMode:
|
||||
sc.SameSiteName = "Default"
|
||||
case http.SameSiteLaxMode:
|
||||
sc.SameSiteName = "Lax"
|
||||
case http.SameSiteStrictMode:
|
||||
sc.SameSiteName = "Strict"
|
||||
case http.SameSiteNoneMode:
|
||||
sc.SameSiteName = "None"
|
||||
}
|
||||
return sc
|
||||
}
|
||||
|
||||
func deserializeCookie(sc SerializedCookie) *http.Cookie {
|
||||
c := &http.Cookie{
|
||||
Name: sc.Name,
|
||||
Value: sc.Value,
|
||||
Path: sc.Path,
|
||||
Domain: sc.Domain,
|
||||
MaxAge: sc.MaxAge,
|
||||
Expires: sc.Expires,
|
||||
Secure: sc.Secure,
|
||||
HttpOnly: sc.HTTPOnly,
|
||||
}
|
||||
switch strings.ToLower(sc.SameSiteName) {
|
||||
case "lax":
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
case "strict":
|
||||
c.SameSite = http.SameSiteStrictMode
|
||||
case "none":
|
||||
c.SameSite = http.SameSiteNoneMode
|
||||
default:
|
||||
c.SameSite = http.SameSiteLaxMode
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
// BrowserAuthentikCookies rewrites stored cookies for the public Authentik path.
|
||||
func BrowserAuthentikCookies(stored []SerializedCookie) []*http.Cookie {
|
||||
out := make([]*http.Cookie, 0, len(stored))
|
||||
for _, sc := range stored {
|
||||
c := deserializeCookie(sc)
|
||||
c.Path = authentikCookiePath
|
||||
c.Domain = ""
|
||||
out = append(out, c)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// RestoreFlowExecutor rebuilds an executor from persisted cookies.
|
||||
func RestoreFlowExecutor(baseURL, slug string, stored []SerializedCookie) (*FlowExecutor, error) {
|
||||
fe, err := NewFlowExecutor(baseURL, slug)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fe.ImportCookies(stored)
|
||||
return fe, nil
|
||||
}
|
||||
|
||||
// CloneCookieJar copies cookies between executors (tests).
|
||||
func CloneCookieJar(from, to *FlowExecutor) {
|
||||
if from == nil || to == nil {
|
||||
return
|
||||
}
|
||||
to.ImportCookies(from.ExportCookies())
|
||||
}
|
||||
@ -3,6 +3,8 @@ package authentik
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrFlowSessionNotFound = errors.New("flow session not found")
|
||||
ErrFlowSessionSlugMismatch = errors.New("flow session slug mismatch")
|
||||
ErrFlowSessionNotFound = errors.New("flow session not found")
|
||||
ErrFlowSessionSlugMismatch = errors.New("flow session slug mismatch")
|
||||
ErrFlowSessionNotCompleted = errors.New("flow session not completed")
|
||||
ErrFlowSessionAlreadyCompleted = errors.New("flow session already completed")
|
||||
)
|
||||
|
||||
@ -106,12 +106,23 @@ func (fe *FlowExecutor) PostResponse(ctx context.Context, query string, payload
|
||||
return fe.doChallenge(req)
|
||||
}
|
||||
|
||||
func (fe *FlowExecutor) csrfToken() string {
|
||||
func (fe *FlowExecutor) cookieURL() *url.URL {
|
||||
u, err := url.Parse(fe.baseURL)
|
||||
if err != nil {
|
||||
return ""
|
||||
if err != nil || u.Host == "" {
|
||||
return &url.URL{Scheme: "http", Host: "localhost", Path: "/auth/"}
|
||||
}
|
||||
for _, c := range fe.client.Jar.Cookies(u) {
|
||||
if u.Path == "" || u.Path == "/" {
|
||||
u.Path = "/auth/"
|
||||
} else if !strings.HasSuffix(u.Path, "/") {
|
||||
u.Path += "/"
|
||||
}
|
||||
u.RawQuery = ""
|
||||
u.Fragment = ""
|
||||
return u
|
||||
}
|
||||
|
||||
func (fe *FlowExecutor) csrfToken() string {
|
||||
for _, c := range fe.client.Jar.Cookies(fe.cookieURL()) {
|
||||
if c.Name == "authentik_csrf" {
|
||||
return c.Value
|
||||
}
|
||||
|
||||
@ -2,13 +2,12 @@ package authentik
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestFlowExecutorGetChallenge(t *testing.T) {
|
||||
t.Parallel()
|
||||
// Smoke test structure only — integration tests hit real Authentik.
|
||||
fe, err := NewFlowExecutor("http://localhost:9000", "test-flow")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@ -36,10 +35,52 @@ func TestFlowDone(t *testing.T) {
|
||||
|
||||
func TestFlowSessionStoreLifecycle(t *testing.T) {
|
||||
t.Parallel()
|
||||
store := NewFlowSessionStore("http://127.0.0.1:1")
|
||||
store := NewFlowSessionStore("http://127.0.0.1:1", nil)
|
||||
_, err := store.Respond(context.Background(), "missing", "ulti-enrollment", "", map[string]any{"component": "x"})
|
||||
if err != ErrFlowSessionNotFound {
|
||||
t.Fatalf("expected ErrFlowSessionNotFound, got %v", err)
|
||||
}
|
||||
_ = httptest.NewRecorder()
|
||||
}
|
||||
|
||||
func TestCookieRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
fe, err := NewFlowExecutor("http://localhost/auth", "test-flow")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
fe.ImportCookies([]SerializedCookie{
|
||||
{Name: "authentik_csrf", Value: "abc123", Path: "/auth/", SameSiteName: "Lax", HTTPOnly: true},
|
||||
{Name: "authentik_session", Value: "sess", Path: "/auth/", SameSiteName: "Lax", HTTPOnly: true},
|
||||
})
|
||||
exported := fe.ExportCookies()
|
||||
if len(exported) != 2 {
|
||||
t.Fatalf("exported %d cookies", len(exported))
|
||||
}
|
||||
fe2, err := RestoreFlowExecutor("http://localhost/auth", "test-flow", exported)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if fe2.csrfToken() != "abc123" {
|
||||
t.Fatalf("csrf = %q", fe2.csrfToken())
|
||||
}
|
||||
browser := BrowserAuthentikCookies(exported)
|
||||
if len(browser) != 2 || browser[0].Path != "/auth/" {
|
||||
t.Fatalf("browser cookies: %+v", browser)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlowSessionCompleteRequiresDone(t *testing.T) {
|
||||
t.Parallel()
|
||||
store := NewFlowSessionStore("http://127.0.0.1:1", nil)
|
||||
store.mu.Lock()
|
||||
store.items["sess1"] = &flowSessionEntry{
|
||||
slug: "default-authentication-flow",
|
||||
createdAt: time.Now(),
|
||||
completed: false,
|
||||
}
|
||||
store.mu.Unlock()
|
||||
_, err := store.CompleteSession(context.Background(), "sess1", "default-authentication-flow")
|
||||
if err != ErrFlowSessionNotCompleted {
|
||||
t.Fatalf("expected ErrFlowSessionNotCompleted, got %v", err)
|
||||
}
|
||||
}
|
||||
@ -4,30 +4,47 @@ import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const defaultFlowSessionTTL = 20 * time.Minute
|
||||
const (
|
||||
defaultFlowSessionTTL = 20 * time.Minute
|
||||
flowSessionKeyPrefix = "auth_flow_session:"
|
||||
)
|
||||
|
||||
type flowSessionEntry struct {
|
||||
executor *FlowExecutor
|
||||
slug string
|
||||
createdAt time.Time
|
||||
type storedFlowSession struct {
|
||||
Slug string `json:"slug"`
|
||||
Cookies []SerializedCookie `json:"cookies"`
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
Completed bool `json:"completed,omitempty"`
|
||||
}
|
||||
|
||||
// FlowSessionStore keeps in-memory Authentik flow executor sessions.
|
||||
type flowSessionEntry struct {
|
||||
slug string
|
||||
cookies []SerializedCookie
|
||||
createdAt time.Time
|
||||
completed bool
|
||||
}
|
||||
|
||||
// FlowSessionStore keeps Authentik flow executor sessions (memory + optional KeyDB).
|
||||
type FlowSessionStore struct {
|
||||
mu sync.Mutex
|
||||
baseURL string
|
||||
ttl time.Duration
|
||||
rdb *redis.Client
|
||||
items map[string]*flowSessionEntry
|
||||
}
|
||||
|
||||
func NewFlowSessionStore(baseURL string) *FlowSessionStore {
|
||||
func NewFlowSessionStore(baseURL string, rdb *redis.Client) *FlowSessionStore {
|
||||
return &FlowSessionStore{
|
||||
baseURL: baseURL,
|
||||
ttl: defaultFlowSessionTTL,
|
||||
rdb: rdb,
|
||||
items: make(map[string]*flowSessionEntry),
|
||||
}
|
||||
}
|
||||
@ -46,64 +63,151 @@ func (s *FlowSessionStore) Start(ctx context.Context, slug, query string) (sessi
|
||||
return "", nil, err
|
||||
}
|
||||
done, _ := FlowDone(challenge)
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cleanupLocked(time.Now())
|
||||
if !done {
|
||||
s.items[id] = &flowSessionEntry{
|
||||
executor: executor,
|
||||
entry := &flowSessionEntry{
|
||||
slug: slug,
|
||||
cookies: executor.ExportCookies(),
|
||||
createdAt: time.Now(),
|
||||
}
|
||||
if err := s.save(ctx, id, entry); err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
}
|
||||
return id, challenge, nil
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) Respond(ctx context.Context, sessionID, slug, query string, payload map[string]any) (FlowChallenge, error) {
|
||||
entry, err := s.get(sessionID)
|
||||
entry, err := s.load(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry.slug != slug {
|
||||
return nil, ErrFlowSessionSlugMismatch
|
||||
}
|
||||
challenge, err := entry.executor.PostResponse(ctx, query, payload)
|
||||
if entry.completed {
|
||||
return nil, ErrFlowSessionAlreadyCompleted
|
||||
}
|
||||
executor, err := RestoreFlowExecutor(s.baseURL, slug, entry.cookies)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
done, _ := FlowDone(challenge)
|
||||
if done {
|
||||
s.deleteLocked(sessionID)
|
||||
challenge, err := executor.PostResponse(ctx, query, payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry.cookies = executor.ExportCookies()
|
||||
done, denied := FlowDone(challenge)
|
||||
if done && !denied {
|
||||
entry.completed = true
|
||||
if err := s.save(ctx, sessionID, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return challenge, nil
|
||||
}
|
||||
if err := s.save(ctx, sessionID, entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return challenge, nil
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) Delete(sessionID string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.deleteLocked(sessionID)
|
||||
// CompleteSession returns cookies from a completed flow and removes the session.
|
||||
func (s *FlowSessionStore) CompleteSession(ctx context.Context, sessionID, slug string) ([]SerializedCookie, error) {
|
||||
entry, err := s.load(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry.slug != slug {
|
||||
return nil, ErrFlowSessionSlugMismatch
|
||||
}
|
||||
if !entry.completed {
|
||||
return nil, ErrFlowSessionNotCompleted
|
||||
}
|
||||
cookies := entry.cookies
|
||||
s.delete(ctx, sessionID)
|
||||
return cookies, nil
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) deleteLocked(sessionID string) {
|
||||
delete(s.items, sessionID)
|
||||
// SessionCookies returns current Authentik cookies for an active session.
|
||||
func (s *FlowSessionStore) SessionCookies(ctx context.Context, sessionID string) ([]SerializedCookie, error) {
|
||||
entry, err := s.load(ctx, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entry.cookies, nil
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) get(sessionID string) (*flowSessionEntry, error) {
|
||||
func (s *FlowSessionStore) Delete(ctx context.Context, sessionID string) {
|
||||
s.delete(ctx, sessionID)
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) save(ctx context.Context, sessionID string, entry *flowSessionEntry) error {
|
||||
s.mu.Lock()
|
||||
s.items[sessionID] = entry
|
||||
s.mu.Unlock()
|
||||
if s.rdb == nil {
|
||||
return nil
|
||||
}
|
||||
stored := storedFlowSession{
|
||||
Slug: entry.slug,
|
||||
Cookies: entry.cookies,
|
||||
CreatedAt: entry.createdAt,
|
||||
Completed: entry.completed,
|
||||
}
|
||||
raw, err := json.Marshal(stored)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.rdb.Set(ctx, flowSessionKeyPrefix+sessionID, raw, s.ttl).Err()
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) load(ctx context.Context, sessionID string) (*flowSessionEntry, error) {
|
||||
if s.rdb != nil {
|
||||
raw, err := s.rdb.Get(ctx, flowSessionKeyPrefix+sessionID).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, ErrFlowSessionNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
var stored storedFlowSession
|
||||
if err := json.Unmarshal(raw, &stored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if time.Since(stored.CreatedAt) > s.ttl {
|
||||
s.delete(ctx, sessionID)
|
||||
return nil, ErrFlowSessionNotFound
|
||||
}
|
||||
entry := &flowSessionEntry{
|
||||
slug: stored.Slug,
|
||||
cookies: stored.Cookies,
|
||||
createdAt: stored.CreatedAt,
|
||||
completed: stored.Completed,
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.items[sessionID] = entry
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.cleanupLocked(time.Now())
|
||||
entry, ok := s.items[sessionID]
|
||||
s.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, ErrFlowSessionNotFound
|
||||
}
|
||||
if time.Since(entry.createdAt) > s.ttl {
|
||||
s.delete(ctx, sessionID)
|
||||
return nil, ErrFlowSessionNotFound
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *FlowSessionStore) cleanupLocked(now time.Time) {
|
||||
for id, entry := range s.items {
|
||||
if now.Sub(entry.createdAt) > s.ttl {
|
||||
delete(s.items, id)
|
||||
}
|
||||
func (s *FlowSessionStore) delete(ctx context.Context, sessionID string) {
|
||||
s.mu.Lock()
|
||||
delete(s.items, sessionID)
|
||||
s.mu.Unlock()
|
||||
if s.rdb != nil {
|
||||
_ = s.rdb.Del(ctx, flowSessionKeyPrefix+sessionID).Err()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -368,7 +368,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
r.Get("/api/v1/mail/addresses/check", mailHandler.CheckAddressAvailability)
|
||||
r.Get("/api/v1/migration/invite", migrationHandler.GetInvite)
|
||||
r.Post("/internal/provision/user", provisionHandler.ProvisionUser)
|
||||
r.Mount("/api/v1/auth", authapi.NewHandler(cfg.AuthentikAPIURL).Routes())
|
||||
r.Mount("/api/v1/auth", authapi.NewHandler(cfg.AuthentikAPIURL, cfg.MailAppURL, rdb).Routes())
|
||||
|
||||
var driveHandler *drive.Handler
|
||||
var driveSvc *drive.Service
|
||||
|
||||
Loading…
Reference in New Issue
Block a user