feat(auth): implement flow completion and rate limiting for authentication flows
- Added a new handler for completing authentication flows, including session validation and cookie management. - Implemented flow rate limiting to restrict the number of flow start requests per client IP. - Enhanced flow session management with Redis support for persistent session storage. - Updated existing handlers to integrate the new flow completion logic and error handling for various session states. - Introduced unit tests for the new flow completion and rate limiting functionalities to ensure reliability.
This commit is contained in:
parent
e4549f29b2
commit
8bbc539d77
87
internal/api/auth/complete.go
Normal file
87
internal/api/auth/complete.go
Normal file
@ -0,0 +1,87 @@
|
|||||||
|
package authapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||||
|
)
|
||||||
|
|
||||||
|
type flowCompleteRequest struct {
|
||||||
|
ReturnTo string `json:"returnTo"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type flowCompleteResponse struct {
|
||||||
|
RedirectURL string `json:"redirectUrl"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) CompleteFlow(w http.ResponseWriter, r *http.Request) {
|
||||||
|
sessionID := readFlowSessionCookie(r)
|
||||||
|
if sessionID == "" {
|
||||||
|
apiresponse.WriteError(w, r, http.StatusUnauthorized, "flow_session_missing", "flow session cookie required", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req flowCompleteRequest
|
||||||
|
if r.Body != nil {
|
||||||
|
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||||
|
}
|
||||||
|
returnTo := strings.TrimSpace(req.ReturnTo)
|
||||||
|
if returnTo == "" {
|
||||||
|
returnTo = "/mail/inbox"
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(returnTo, "/") || strings.HasPrefix(returnTo, "//") {
|
||||||
|
apiresponse.WriteError(w, r, http.StatusBadRequest, "invalid_return_to", "returnTo must be a relative path", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slug := FlowAuthentication
|
||||||
|
cookies, err := h.flows.CompleteSession(r.Context(), sessionID, slug)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, authentik.ErrFlowSessionNotFound) {
|
||||||
|
clearFlowSessionCookie(w)
|
||||||
|
apiresponse.WriteError(w, r, http.StatusGone, "flow_session_expired", "flow session expired; restart the flow", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, authentik.ErrFlowSessionNotCompleted) {
|
||||||
|
apiresponse.WriteError(w, r, http.StatusConflict, "flow_not_completed", "authentication flow is not finished yet", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, authentik.ErrFlowSessionSlugMismatch) {
|
||||||
|
apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_complete_failed", err.Error(), nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
clearFlowSessionCookie(w)
|
||||||
|
setBrowserAuthentikCookies(w, cookies)
|
||||||
|
|
||||||
|
loginURL := buildLoginRedirectURL(h.appURL, returnTo)
|
||||||
|
apiresponse.WriteJSON(w, http.StatusOK, flowCompleteResponse{RedirectURL: loginURL})
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildLoginRedirectURL(appURL, returnTo string) string {
|
||||||
|
base := strings.TrimRight(strings.TrimSpace(appURL), "/")
|
||||||
|
if base == "" {
|
||||||
|
base = "http://localhost:3004"
|
||||||
|
}
|
||||||
|
params := url.Values{}
|
||||||
|
params.Set("returnTo", returnTo)
|
||||||
|
return base + "/api/auth/login?" + params.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
func setBrowserAuthentikCookies(w http.ResponseWriter, stored []authentik.SerializedCookie) {
|
||||||
|
for _, c := range authentik.BrowserAuthentikCookies(stored) {
|
||||||
|
http.SetCookie(w, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func forwardFlowCookies(w http.ResponseWriter, stored []authentik.SerializedCookie) {
|
||||||
|
setBrowserAuthentikCookies(w, stored)
|
||||||
|
}
|
||||||
@ -8,31 +8,38 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
|
||||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
FlowEnrollment = "ulti-enrollment"
|
FlowEnrollment = "ulti-enrollment"
|
||||||
FlowRecovery = "ulti-recovery"
|
FlowRecovery = "ulti-recovery"
|
||||||
flowSessionCookie = "ulti_flow_session"
|
FlowAuthentication = "default-authentication-flow"
|
||||||
flowSessionMaxAge = 20 * time.Minute
|
flowSessionCookie = "ulti_flow_session"
|
||||||
|
flowSessionMaxAge = 20 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
var allowedFlowSlugs = map[string]struct{}{
|
var allowedFlowSlugs = map[string]struct{}{
|
||||||
FlowEnrollment: {},
|
FlowEnrollment: {},
|
||||||
FlowRecovery: {},
|
FlowRecovery: {},
|
||||||
|
FlowAuthentication: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler exposes public Authentik flow executor endpoints for custom auth UI.
|
// Handler exposes public Authentik flow executor endpoints for custom auth UI.
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
flows *authentik.FlowSessionStore
|
flows *authentik.FlowSessionStore
|
||||||
|
limiter *FlowRateLimiter
|
||||||
|
appURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(baseURL string) *Handler {
|
func NewHandler(baseURL, appURL string, rdb *redis.Client) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
flows: authentik.NewFlowSessionStore(baseURL),
|
flows: authentik.NewFlowSessionStore(baseURL, rdb),
|
||||||
|
limiter: NewFlowRateLimiter(rdb),
|
||||||
|
appURL: appURL,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -40,6 +47,7 @@ func (h *Handler) Routes() chi.Router {
|
|||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Post("/flows/{slug}/start", h.StartFlow)
|
r.Post("/flows/{slug}/start", h.StartFlow)
|
||||||
r.Post("/flows/{slug}/respond", h.RespondFlow)
|
r.Post("/flows/{slug}/respond", h.RespondFlow)
|
||||||
|
r.Post("/flows/complete", h.CompleteFlow)
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -55,6 +63,11 @@ type flowRespondRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !h.limiter.Allow(r) {
|
||||||
|
apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, "too many flow start attempts", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
slug, ok := validateFlowSlug(w, r, chi.URLParam(r, "slug"))
|
slug, ok := validateFlowSlug(w, r, chi.URLParam(r, "slug"))
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
@ -69,6 +82,9 @@ func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) {
|
|||||||
done, denied := authentik.FlowDone(challenge)
|
done, denied := authentik.FlowDone(challenge)
|
||||||
if !done {
|
if !done {
|
||||||
setFlowSessionCookie(w, sessionID)
|
setFlowSessionCookie(w, sessionID)
|
||||||
|
if cookies, err := h.flows.SessionCookies(r.Context(), sessionID); err == nil {
|
||||||
|
forwardFlowCookies(w, cookies)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
clearFlowSessionCookie(w)
|
clearFlowSessionCookie(w)
|
||||||
}
|
}
|
||||||
@ -114,14 +130,30 @@ func (h *Handler) RespondFlow(w http.ResponseWriter, r *http.Request) {
|
|||||||
apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil)
|
apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if errors.Is(err, authentik.ErrFlowSessionAlreadyCompleted) {
|
||||||
|
apiresponse.WriteError(w, r, http.StatusConflict, "flow_already_completed", "flow already completed", nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_respond_failed", err.Error(), nil)
|
apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_respond_failed", err.Error(), nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
done, denied := authentik.FlowDone(challenge)
|
done, denied := authentik.FlowDone(challenge)
|
||||||
if done {
|
if done {
|
||||||
clearFlowSessionCookie(w)
|
if slug == FlowAuthentication {
|
||||||
|
if denied {
|
||||||
|
clearFlowSessionCookie(w)
|
||||||
|
h.flows.Delete(r.Context(), sessionID)
|
||||||
|
}
|
||||||
|
// Successful auth: keep session cookie until /flows/complete.
|
||||||
|
} else {
|
||||||
|
clearFlowSessionCookie(w)
|
||||||
|
h.flows.Delete(r.Context(), sessionID)
|
||||||
|
}
|
||||||
|
} else if cookies, err := h.flows.SessionCookies(r.Context(), sessionID); err == nil {
|
||||||
|
forwardFlowCookies(w, cookies)
|
||||||
}
|
}
|
||||||
|
|
||||||
writeFlowJSON(w, http.StatusOK, flowStartResponse{
|
writeFlowJSON(w, http.StatusOK, flowStartResponse{
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
Challenge: challenge,
|
Challenge: challenge,
|
||||||
|
|||||||
74
internal/api/auth/handlers_test.go
Normal file
74
internal/api/auth/handlers_test.go
Normal file
@ -0,0 +1,74 @@
|
|||||||
|
package authapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestValidateFlowSlugAllowed(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
for _, slug := range []string{FlowEnrollment, FlowRecovery, FlowAuthentication} {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/flows/"+slug+"/start", nil)
|
||||||
|
got, ok := validateFlowSlug(rec, req, slug)
|
||||||
|
if !ok || got != slug {
|
||||||
|
t.Fatalf("slug %q: ok=%v got=%q code=%d", slug, ok, got, rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateFlowSlugRejected(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/flows/default-invalidation-flow/start", nil)
|
||||||
|
_, ok := validateFlowSlug(rec, req, "default-invalidation-flow")
|
||||||
|
if ok {
|
||||||
|
t.Fatal("expected slug to be rejected")
|
||||||
|
}
|
||||||
|
if rec.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status = %d", rec.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRespondFlowMissingCookie(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
h := NewHandler("http://127.0.0.1:1", "http://localhost:3004", nil)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/flows/ulti-enrollment/respond", bytes.NewReader([]byte(`{"payload":{"component":"x"}}`)))
|
||||||
|
rctx := chi.NewRouteContext()
|
||||||
|
rctx.URLParams.Add("slug", FlowEnrollment)
|
||||||
|
req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx))
|
||||||
|
h.RespondFlow(rec, req)
|
||||||
|
if rec.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildLoginRedirectURL(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
got := buildLoginRedirectURL("http://localhost:3004", "/mail/inbox")
|
||||||
|
want := "http://localhost:3004/api/auth/login?returnTo=%2Fmail%2Finbox"
|
||||||
|
if got != want {
|
||||||
|
t.Fatalf("got %q want %q", got, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlowRateLimiterMemory(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
lim := NewFlowRateLimiter(nil)
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/start", nil)
|
||||||
|
req.RemoteAddr = "203.0.113.1:1234"
|
||||||
|
for i := 0; i < flowRateLimitMax; i++ {
|
||||||
|
if !lim.Allow(req) {
|
||||||
|
t.Fatalf("attempt %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if lim.Allow(req) {
|
||||||
|
t.Fatal("expected rate limit block")
|
||||||
|
}
|
||||||
|
}
|
||||||
94
internal/api/auth/rate_limit.go
Normal file
94
internal/api/auth/rate_limit.go
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
package authapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
flowRateLimitMax = 10
|
||||||
|
flowRateLimitWindow = time.Minute
|
||||||
|
flowRateKeyPrefix = "auth_flow_rate:"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FlowRateLimiter limits flow start requests per client IP.
|
||||||
|
type FlowRateLimiter struct {
|
||||||
|
rdb *redis.Client
|
||||||
|
mu sync.Mutex
|
||||||
|
// fallback when KeyDB unavailable
|
||||||
|
hits map[string][]time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFlowRateLimiter(rdb *redis.Client) *FlowRateLimiter {
|
||||||
|
return &FlowRateLimiter{
|
||||||
|
rdb: rdb,
|
||||||
|
hits: make(map[string][]time.Time),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *FlowRateLimiter) Allow(r *http.Request) bool {
|
||||||
|
ip := clientIP(r)
|
||||||
|
if ip == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if l.rdb != nil {
|
||||||
|
return l.allowRedis(r, ip)
|
||||||
|
}
|
||||||
|
return l.allowMemory(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *FlowRateLimiter) allowRedis(r *http.Request, ip string) bool {
|
||||||
|
ctx := r.Context()
|
||||||
|
key := flowRateKeyPrefix + ip
|
||||||
|
count, err := l.rdb.Incr(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return l.allowMemory(ip)
|
||||||
|
}
|
||||||
|
if count == 1 {
|
||||||
|
_ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err()
|
||||||
|
}
|
||||||
|
return count <= flowRateLimitMax
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *FlowRateLimiter) allowMemory(ip string) bool {
|
||||||
|
now := time.Now()
|
||||||
|
cutoff := now.Add(-flowRateLimitWindow)
|
||||||
|
l.mu.Lock()
|
||||||
|
defer l.mu.Unlock()
|
||||||
|
prev := l.hits[ip]
|
||||||
|
next := prev[:0]
|
||||||
|
for _, t := range prev {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
next = append(next, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(next) >= flowRateLimitMax {
|
||||||
|
l.hits[ip] = next
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
next = append(next, now)
|
||||||
|
l.hits[ip] = next
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func clientIP(r *http.Request) string {
|
||||||
|
if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" {
|
||||||
|
parts := strings.Split(xff, ",")
|
||||||
|
if len(parts) > 0 {
|
||||||
|
return strings.TrimSpace(parts[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" {
|
||||||
|
return xrip
|
||||||
|
}
|
||||||
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err != nil {
|
||||||
|
return strings.TrimSpace(r.RemoteAddr)
|
||||||
|
}
|
||||||
|
return host
|
||||||
|
}
|
||||||
125
internal/authentik/flow_cookies.go
Normal file
125
internal/authentik/flow_cookies.go
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
package authentik
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const authentikCookiePath = "/auth/"
|
||||||
|
|
||||||
|
// SerializedCookie stores an HTTP cookie for flow session persistence.
|
||||||
|
type SerializedCookie struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Value string `json:"value"`
|
||||||
|
Path string `json:"path,omitempty"`
|
||||||
|
Domain string `json:"domain,omitempty"`
|
||||||
|
MaxAge int `json:"maxAge,omitempty"`
|
||||||
|
Expires time.Time `json:"expires,omitempty"`
|
||||||
|
Secure bool `json:"secure,omitempty"`
|
||||||
|
HTTPOnly bool `json:"httpOnly,omitempty"`
|
||||||
|
SameSite http.SameSite `json:"-"`
|
||||||
|
SameSiteName string `json:"sameSite,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fe *FlowExecutor) ExportCookies() []SerializedCookie {
|
||||||
|
if fe.client.Jar == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw := fe.client.Jar.Cookies(fe.cookieURL())
|
||||||
|
out := make([]SerializedCookie, 0, len(raw))
|
||||||
|
for _, c := range raw {
|
||||||
|
out = append(out, serializeCookie(c))
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fe *FlowExecutor) ImportCookies(stored []SerializedCookie) {
|
||||||
|
if fe.client.Jar == nil || len(stored) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cookies := make([]*http.Cookie, 0, len(stored))
|
||||||
|
for _, sc := range stored {
|
||||||
|
cookies = append(cookies, deserializeCookie(sc))
|
||||||
|
}
|
||||||
|
fe.client.Jar.SetCookies(fe.cookieURL(), cookies)
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeCookie(c *http.Cookie) SerializedCookie {
|
||||||
|
sc := SerializedCookie{
|
||||||
|
Name: c.Name,
|
||||||
|
Value: c.Value,
|
||||||
|
Path: c.Path,
|
||||||
|
Domain: c.Domain,
|
||||||
|
MaxAge: c.MaxAge,
|
||||||
|
Expires: c.Expires,
|
||||||
|
Secure: c.Secure,
|
||||||
|
HTTPOnly: c.HttpOnly,
|
||||||
|
SameSite: c.SameSite,
|
||||||
|
}
|
||||||
|
switch c.SameSite {
|
||||||
|
case http.SameSiteDefaultMode:
|
||||||
|
sc.SameSiteName = "Default"
|
||||||
|
case http.SameSiteLaxMode:
|
||||||
|
sc.SameSiteName = "Lax"
|
||||||
|
case http.SameSiteStrictMode:
|
||||||
|
sc.SameSiteName = "Strict"
|
||||||
|
case http.SameSiteNoneMode:
|
||||||
|
sc.SameSiteName = "None"
|
||||||
|
}
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
|
||||||
|
func deserializeCookie(sc SerializedCookie) *http.Cookie {
|
||||||
|
c := &http.Cookie{
|
||||||
|
Name: sc.Name,
|
||||||
|
Value: sc.Value,
|
||||||
|
Path: sc.Path,
|
||||||
|
Domain: sc.Domain,
|
||||||
|
MaxAge: sc.MaxAge,
|
||||||
|
Expires: sc.Expires,
|
||||||
|
Secure: sc.Secure,
|
||||||
|
HttpOnly: sc.HTTPOnly,
|
||||||
|
}
|
||||||
|
switch strings.ToLower(sc.SameSiteName) {
|
||||||
|
case "lax":
|
||||||
|
c.SameSite = http.SameSiteLaxMode
|
||||||
|
case "strict":
|
||||||
|
c.SameSite = http.SameSiteStrictMode
|
||||||
|
case "none":
|
||||||
|
c.SameSite = http.SameSiteNoneMode
|
||||||
|
default:
|
||||||
|
c.SameSite = http.SameSiteLaxMode
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// BrowserAuthentikCookies rewrites stored cookies for the public Authentik path.
|
||||||
|
func BrowserAuthentikCookies(stored []SerializedCookie) []*http.Cookie {
|
||||||
|
out := make([]*http.Cookie, 0, len(stored))
|
||||||
|
for _, sc := range stored {
|
||||||
|
c := deserializeCookie(sc)
|
||||||
|
c.Path = authentikCookiePath
|
||||||
|
c.Domain = ""
|
||||||
|
out = append(out, c)
|
||||||
|
}
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestoreFlowExecutor rebuilds an executor from persisted cookies.
|
||||||
|
func RestoreFlowExecutor(baseURL, slug string, stored []SerializedCookie) (*FlowExecutor, error) {
|
||||||
|
fe, err := NewFlowExecutor(baseURL, slug)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fe.ImportCookies(stored)
|
||||||
|
return fe, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloneCookieJar copies cookies between executors (tests).
|
||||||
|
func CloneCookieJar(from, to *FlowExecutor) {
|
||||||
|
if from == nil || to == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
to.ImportCookies(from.ExportCookies())
|
||||||
|
}
|
||||||
@ -3,6 +3,8 @@ package authentik
|
|||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrFlowSessionNotFound = errors.New("flow session not found")
|
ErrFlowSessionNotFound = errors.New("flow session not found")
|
||||||
ErrFlowSessionSlugMismatch = errors.New("flow session slug mismatch")
|
ErrFlowSessionSlugMismatch = errors.New("flow session slug mismatch")
|
||||||
|
ErrFlowSessionNotCompleted = errors.New("flow session not completed")
|
||||||
|
ErrFlowSessionAlreadyCompleted = errors.New("flow session already completed")
|
||||||
)
|
)
|
||||||
|
|||||||
@ -106,12 +106,23 @@ func (fe *FlowExecutor) PostResponse(ctx context.Context, query string, payload
|
|||||||
return fe.doChallenge(req)
|
return fe.doChallenge(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fe *FlowExecutor) csrfToken() string {
|
func (fe *FlowExecutor) cookieURL() *url.URL {
|
||||||
u, err := url.Parse(fe.baseURL)
|
u, err := url.Parse(fe.baseURL)
|
||||||
if err != nil {
|
if err != nil || u.Host == "" {
|
||||||
return ""
|
return &url.URL{Scheme: "http", Host: "localhost", Path: "/auth/"}
|
||||||
}
|
}
|
||||||
for _, c := range fe.client.Jar.Cookies(u) {
|
if u.Path == "" || u.Path == "/" {
|
||||||
|
u.Path = "/auth/"
|
||||||
|
} else if !strings.HasSuffix(u.Path, "/") {
|
||||||
|
u.Path += "/"
|
||||||
|
}
|
||||||
|
u.RawQuery = ""
|
||||||
|
u.Fragment = ""
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fe *FlowExecutor) csrfToken() string {
|
||||||
|
for _, c := range fe.client.Jar.Cookies(fe.cookieURL()) {
|
||||||
if c.Name == "authentik_csrf" {
|
if c.Name == "authentik_csrf" {
|
||||||
return c.Value
|
return c.Value
|
||||||
}
|
}
|
||||||
|
|||||||
@ -2,13 +2,12 @@ package authentik
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestFlowExecutorGetChallenge(t *testing.T) {
|
func TestFlowExecutorGetChallenge(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
// Smoke test structure only — integration tests hit real Authentik.
|
|
||||||
fe, err := NewFlowExecutor("http://localhost:9000", "test-flow")
|
fe, err := NewFlowExecutor("http://localhost:9000", "test-flow")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@ -36,10 +35,52 @@ func TestFlowDone(t *testing.T) {
|
|||||||
|
|
||||||
func TestFlowSessionStoreLifecycle(t *testing.T) {
|
func TestFlowSessionStoreLifecycle(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
store := NewFlowSessionStore("http://127.0.0.1:1")
|
store := NewFlowSessionStore("http://127.0.0.1:1", nil)
|
||||||
_, err := store.Respond(context.Background(), "missing", "ulti-enrollment", "", map[string]any{"component": "x"})
|
_, err := store.Respond(context.Background(), "missing", "ulti-enrollment", "", map[string]any{"component": "x"})
|
||||||
if err != ErrFlowSessionNotFound {
|
if err != ErrFlowSessionNotFound {
|
||||||
t.Fatalf("expected ErrFlowSessionNotFound, got %v", err)
|
t.Fatalf("expected ErrFlowSessionNotFound, got %v", err)
|
||||||
}
|
}
|
||||||
_ = httptest.NewRecorder()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCookieRoundTrip(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
fe, err := NewFlowExecutor("http://localhost/auth", "test-flow")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
fe.ImportCookies([]SerializedCookie{
|
||||||
|
{Name: "authentik_csrf", Value: "abc123", Path: "/auth/", SameSiteName: "Lax", HTTPOnly: true},
|
||||||
|
{Name: "authentik_session", Value: "sess", Path: "/auth/", SameSiteName: "Lax", HTTPOnly: true},
|
||||||
|
})
|
||||||
|
exported := fe.ExportCookies()
|
||||||
|
if len(exported) != 2 {
|
||||||
|
t.Fatalf("exported %d cookies", len(exported))
|
||||||
|
}
|
||||||
|
fe2, err := RestoreFlowExecutor("http://localhost/auth", "test-flow", exported)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if fe2.csrfToken() != "abc123" {
|
||||||
|
t.Fatalf("csrf = %q", fe2.csrfToken())
|
||||||
|
}
|
||||||
|
browser := BrowserAuthentikCookies(exported)
|
||||||
|
if len(browser) != 2 || browser[0].Path != "/auth/" {
|
||||||
|
t.Fatalf("browser cookies: %+v", browser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFlowSessionCompleteRequiresDone(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
store := NewFlowSessionStore("http://127.0.0.1:1", nil)
|
||||||
|
store.mu.Lock()
|
||||||
|
store.items["sess1"] = &flowSessionEntry{
|
||||||
|
slug: "default-authentication-flow",
|
||||||
|
createdAt: time.Now(),
|
||||||
|
completed: false,
|
||||||
|
}
|
||||||
|
store.mu.Unlock()
|
||||||
|
_, err := store.CompleteSession(context.Background(), "sess1", "default-authentication-flow")
|
||||||
|
if err != ErrFlowSessionNotCompleted {
|
||||||
|
t.Fatalf("expected ErrFlowSessionNotCompleted, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -4,30 +4,47 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/redis/go-redis/v9"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultFlowSessionTTL = 20 * time.Minute
|
const (
|
||||||
|
defaultFlowSessionTTL = 20 * time.Minute
|
||||||
|
flowSessionKeyPrefix = "auth_flow_session:"
|
||||||
|
)
|
||||||
|
|
||||||
type flowSessionEntry struct {
|
type storedFlowSession struct {
|
||||||
executor *FlowExecutor
|
Slug string `json:"slug"`
|
||||||
slug string
|
Cookies []SerializedCookie `json:"cookies"`
|
||||||
createdAt time.Time
|
CreatedAt time.Time `json:"createdAt"`
|
||||||
|
Completed bool `json:"completed,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FlowSessionStore keeps in-memory Authentik flow executor sessions.
|
type flowSessionEntry struct {
|
||||||
|
slug string
|
||||||
|
cookies []SerializedCookie
|
||||||
|
createdAt time.Time
|
||||||
|
completed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// FlowSessionStore keeps Authentik flow executor sessions (memory + optional KeyDB).
|
||||||
type FlowSessionStore struct {
|
type FlowSessionStore struct {
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
baseURL string
|
baseURL string
|
||||||
ttl time.Duration
|
ttl time.Duration
|
||||||
|
rdb *redis.Client
|
||||||
items map[string]*flowSessionEntry
|
items map[string]*flowSessionEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewFlowSessionStore(baseURL string) *FlowSessionStore {
|
func NewFlowSessionStore(baseURL string, rdb *redis.Client) *FlowSessionStore {
|
||||||
return &FlowSessionStore{
|
return &FlowSessionStore{
|
||||||
baseURL: baseURL,
|
baseURL: baseURL,
|
||||||
ttl: defaultFlowSessionTTL,
|
ttl: defaultFlowSessionTTL,
|
||||||
|
rdb: rdb,
|
||||||
items: make(map[string]*flowSessionEntry),
|
items: make(map[string]*flowSessionEntry),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -46,64 +63,151 @@ func (s *FlowSessionStore) Start(ctx context.Context, slug, query string) (sessi
|
|||||||
return "", nil, err
|
return "", nil, err
|
||||||
}
|
}
|
||||||
done, _ := FlowDone(challenge)
|
done, _ := FlowDone(challenge)
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.cleanupLocked(time.Now())
|
|
||||||
if !done {
|
if !done {
|
||||||
s.items[id] = &flowSessionEntry{
|
entry := &flowSessionEntry{
|
||||||
executor: executor,
|
|
||||||
slug: slug,
|
slug: slug,
|
||||||
|
cookies: executor.ExportCookies(),
|
||||||
createdAt: time.Now(),
|
createdAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
if err := s.save(ctx, id, entry); err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return id, challenge, nil
|
return id, challenge, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FlowSessionStore) Respond(ctx context.Context, sessionID, slug, query string, payload map[string]any) (FlowChallenge, error) {
|
func (s *FlowSessionStore) Respond(ctx context.Context, sessionID, slug, query string, payload map[string]any) (FlowChallenge, error) {
|
||||||
entry, err := s.get(sessionID)
|
entry, err := s.load(ctx, sessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if entry.slug != slug {
|
if entry.slug != slug {
|
||||||
return nil, ErrFlowSessionSlugMismatch
|
return nil, ErrFlowSessionSlugMismatch
|
||||||
}
|
}
|
||||||
challenge, err := entry.executor.PostResponse(ctx, query, payload)
|
if entry.completed {
|
||||||
|
return nil, ErrFlowSessionAlreadyCompleted
|
||||||
|
}
|
||||||
|
executor, err := RestoreFlowExecutor(s.baseURL, slug, entry.cookies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
done, _ := FlowDone(challenge)
|
challenge, err := executor.PostResponse(ctx, query, payload)
|
||||||
if done {
|
if err != nil {
|
||||||
s.deleteLocked(sessionID)
|
return nil, err
|
||||||
|
}
|
||||||
|
entry.cookies = executor.ExportCookies()
|
||||||
|
done, denied := FlowDone(challenge)
|
||||||
|
if done && !denied {
|
||||||
|
entry.completed = true
|
||||||
|
if err := s.save(ctx, sessionID, entry); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return challenge, nil
|
||||||
|
}
|
||||||
|
if err := s.save(ctx, sessionID, entry); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
return challenge, nil
|
return challenge, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FlowSessionStore) Delete(sessionID string) {
|
// CompleteSession returns cookies from a completed flow and removes the session.
|
||||||
s.mu.Lock()
|
func (s *FlowSessionStore) CompleteSession(ctx context.Context, sessionID, slug string) ([]SerializedCookie, error) {
|
||||||
defer s.mu.Unlock()
|
entry, err := s.load(ctx, sessionID)
|
||||||
s.deleteLocked(sessionID)
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if entry.slug != slug {
|
||||||
|
return nil, ErrFlowSessionSlugMismatch
|
||||||
|
}
|
||||||
|
if !entry.completed {
|
||||||
|
return nil, ErrFlowSessionNotCompleted
|
||||||
|
}
|
||||||
|
cookies := entry.cookies
|
||||||
|
s.delete(ctx, sessionID)
|
||||||
|
return cookies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FlowSessionStore) deleteLocked(sessionID string) {
|
// SessionCookies returns current Authentik cookies for an active session.
|
||||||
delete(s.items, sessionID)
|
func (s *FlowSessionStore) SessionCookies(ctx context.Context, sessionID string) ([]SerializedCookie, error) {
|
||||||
|
entry, err := s.load(ctx, sessionID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return entry.cookies, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FlowSessionStore) get(sessionID string) (*flowSessionEntry, error) {
|
func (s *FlowSessionStore) Delete(ctx context.Context, sessionID string) {
|
||||||
|
s.delete(ctx, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FlowSessionStore) save(ctx context.Context, sessionID string, entry *flowSessionEntry) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
s.items[sessionID] = entry
|
||||||
|
s.mu.Unlock()
|
||||||
|
if s.rdb == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
stored := storedFlowSession{
|
||||||
|
Slug: entry.slug,
|
||||||
|
Cookies: entry.cookies,
|
||||||
|
CreatedAt: entry.createdAt,
|
||||||
|
Completed: entry.completed,
|
||||||
|
}
|
||||||
|
raw, err := json.Marshal(stored)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.rdb.Set(ctx, flowSessionKeyPrefix+sessionID, raw, s.ttl).Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FlowSessionStore) load(ctx context.Context, sessionID string) (*flowSessionEntry, error) {
|
||||||
|
if s.rdb != nil {
|
||||||
|
raw, err := s.rdb.Get(ctx, flowSessionKeyPrefix+sessionID).Bytes()
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, redis.Nil) {
|
||||||
|
return nil, ErrFlowSessionNotFound
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var stored storedFlowSession
|
||||||
|
if err := json.Unmarshal(raw, &stored); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if time.Since(stored.CreatedAt) > s.ttl {
|
||||||
|
s.delete(ctx, sessionID)
|
||||||
|
return nil, ErrFlowSessionNotFound
|
||||||
|
}
|
||||||
|
entry := &flowSessionEntry{
|
||||||
|
slug: stored.Slug,
|
||||||
|
cookies: stored.Cookies,
|
||||||
|
createdAt: stored.CreatedAt,
|
||||||
|
completed: stored.Completed,
|
||||||
|
}
|
||||||
|
s.mu.Lock()
|
||||||
|
s.items[sessionID] = entry
|
||||||
|
s.mu.Unlock()
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
|
||||||
s.cleanupLocked(time.Now())
|
|
||||||
entry, ok := s.items[sessionID]
|
entry, ok := s.items[sessionID]
|
||||||
|
s.mu.Unlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, ErrFlowSessionNotFound
|
return nil, ErrFlowSessionNotFound
|
||||||
}
|
}
|
||||||
|
if time.Since(entry.createdAt) > s.ttl {
|
||||||
|
s.delete(ctx, sessionID)
|
||||||
|
return nil, ErrFlowSessionNotFound
|
||||||
|
}
|
||||||
return entry, nil
|
return entry, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *FlowSessionStore) cleanupLocked(now time.Time) {
|
func (s *FlowSessionStore) delete(ctx context.Context, sessionID string) {
|
||||||
for id, entry := range s.items {
|
s.mu.Lock()
|
||||||
if now.Sub(entry.createdAt) > s.ttl {
|
delete(s.items, sessionID)
|
||||||
delete(s.items, id)
|
s.mu.Unlock()
|
||||||
}
|
if s.rdb != nil {
|
||||||
|
_ = s.rdb.Del(ctx, flowSessionKeyPrefix+sessionID).Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -368,7 +368,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
|||||||
r.Get("/api/v1/mail/addresses/check", mailHandler.CheckAddressAvailability)
|
r.Get("/api/v1/mail/addresses/check", mailHandler.CheckAddressAvailability)
|
||||||
r.Get("/api/v1/migration/invite", migrationHandler.GetInvite)
|
r.Get("/api/v1/migration/invite", migrationHandler.GetInvite)
|
||||||
r.Post("/internal/provision/user", provisionHandler.ProvisionUser)
|
r.Post("/internal/provision/user", provisionHandler.ProvisionUser)
|
||||||
r.Mount("/api/v1/auth", authapi.NewHandler(cfg.AuthentikAPIURL).Routes())
|
r.Mount("/api/v1/auth", authapi.NewHandler(cfg.AuthentikAPIURL, cfg.MailAppURL, rdb).Routes())
|
||||||
|
|
||||||
var driveHandler *drive.Handler
|
var driveHandler *drive.Handler
|
||||||
var driveSvc *drive.Service
|
var driveSvc *drive.Service
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user