diff --git a/internal/api/auth/complete.go b/internal/api/auth/complete.go new file mode 100644 index 0000000..c5069a5 --- /dev/null +++ b/internal/api/auth/complete.go @@ -0,0 +1,87 @@ +package authapi + +import ( + "encoding/json" + "errors" + "net/http" + "net/url" + "strings" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/authentik" +) + +type flowCompleteRequest struct { + ReturnTo string `json:"returnTo"` +} + +type flowCompleteResponse struct { + RedirectURL string `json:"redirectUrl"` +} + +func (h *Handler) CompleteFlow(w http.ResponseWriter, r *http.Request) { + sessionID := readFlowSessionCookie(r) + if sessionID == "" { + apiresponse.WriteError(w, r, http.StatusUnauthorized, "flow_session_missing", "flow session cookie required", nil) + return + } + + var req flowCompleteRequest + if r.Body != nil { + _ = json.NewDecoder(r.Body).Decode(&req) + } + returnTo := strings.TrimSpace(req.ReturnTo) + if returnTo == "" { + returnTo = "/mail/inbox" + } + if !strings.HasPrefix(returnTo, "/") || strings.HasPrefix(returnTo, "//") { + apiresponse.WriteError(w, r, http.StatusBadRequest, "invalid_return_to", "returnTo must be a relative path", nil) + return + } + + slug := FlowAuthentication + cookies, err := h.flows.CompleteSession(r.Context(), sessionID, slug) + if err != nil { + if errors.Is(err, authentik.ErrFlowSessionNotFound) { + clearFlowSessionCookie(w) + apiresponse.WriteError(w, r, http.StatusGone, "flow_session_expired", "flow session expired; restart the flow", nil) + return + } + if errors.Is(err, authentik.ErrFlowSessionNotCompleted) { + apiresponse.WriteError(w, r, http.StatusConflict, "flow_not_completed", "authentication flow is not finished yet", nil) + return + } + if errors.Is(err, authentik.ErrFlowSessionSlugMismatch) { + apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil) + return + } + apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_complete_failed", err.Error(), nil) + return + } + + clearFlowSessionCookie(w) + setBrowserAuthentikCookies(w, cookies) + + loginURL := buildLoginRedirectURL(h.appURL, returnTo) + apiresponse.WriteJSON(w, http.StatusOK, flowCompleteResponse{RedirectURL: loginURL}) +} + +func buildLoginRedirectURL(appURL, returnTo string) string { + base := strings.TrimRight(strings.TrimSpace(appURL), "/") + if base == "" { + base = "http://localhost:3004" + } + params := url.Values{} + params.Set("returnTo", returnTo) + return base + "/api/auth/login?" + params.Encode() +} + +func setBrowserAuthentikCookies(w http.ResponseWriter, stored []authentik.SerializedCookie) { + for _, c := range authentik.BrowserAuthentikCookies(stored) { + http.SetCookie(w, c) + } +} + +func forwardFlowCookies(w http.ResponseWriter, stored []authentik.SerializedCookie) { + setBrowserAuthentikCookies(w, stored) +} diff --git a/internal/api/auth/handlers.go b/internal/api/auth/handlers.go index 2b50f48..f2ea343 100644 --- a/internal/api/auth/handlers.go +++ b/internal/api/auth/handlers.go @@ -8,31 +8,38 @@ import ( "time" "github.com/go-chi/chi/v5" + "github.com/redis/go-redis/v9" "github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/authentik" ) const ( - FlowEnrollment = "ulti-enrollment" - FlowRecovery = "ulti-recovery" - flowSessionCookie = "ulti_flow_session" - flowSessionMaxAge = 20 * time.Minute + FlowEnrollment = "ulti-enrollment" + FlowRecovery = "ulti-recovery" + FlowAuthentication = "default-authentication-flow" + flowSessionCookie = "ulti_flow_session" + flowSessionMaxAge = 20 * time.Minute ) var allowedFlowSlugs = map[string]struct{}{ - FlowEnrollment: {}, - FlowRecovery: {}, + FlowEnrollment: {}, + FlowRecovery: {}, + FlowAuthentication: {}, } // Handler exposes public Authentik flow executor endpoints for custom auth UI. type Handler struct { - flows *authentik.FlowSessionStore + flows *authentik.FlowSessionStore + limiter *FlowRateLimiter + appURL string } -func NewHandler(baseURL string) *Handler { +func NewHandler(baseURL, appURL string, rdb *redis.Client) *Handler { return &Handler{ - flows: authentik.NewFlowSessionStore(baseURL), + flows: authentik.NewFlowSessionStore(baseURL, rdb), + limiter: NewFlowRateLimiter(rdb), + appURL: appURL, } } @@ -40,6 +47,7 @@ func (h *Handler) Routes() chi.Router { r := chi.NewRouter() r.Post("/flows/{slug}/start", h.StartFlow) r.Post("/flows/{slug}/respond", h.RespondFlow) + r.Post("/flows/complete", h.CompleteFlow) return r } @@ -55,6 +63,11 @@ type flowRespondRequest struct { } func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) { + if !h.limiter.Allow(r) { + apiresponse.WriteError(w, r, http.StatusTooManyRequests, apiresponse.CodeRateLimited, "too many flow start attempts", nil) + return + } + slug, ok := validateFlowSlug(w, r, chi.URLParam(r, "slug")) if !ok { return @@ -69,6 +82,9 @@ func (h *Handler) StartFlow(w http.ResponseWriter, r *http.Request) { done, denied := authentik.FlowDone(challenge) if !done { setFlowSessionCookie(w, sessionID) + if cookies, err := h.flows.SessionCookies(r.Context(), sessionID); err == nil { + forwardFlowCookies(w, cookies) + } } else { clearFlowSessionCookie(w) } @@ -114,14 +130,30 @@ func (h *Handler) RespondFlow(w http.ResponseWriter, r *http.Request) { apiresponse.WriteError(w, r, http.StatusConflict, "flow_session_mismatch", "flow slug does not match active session", nil) return } + if errors.Is(err, authentik.ErrFlowSessionAlreadyCompleted) { + apiresponse.WriteError(w, r, http.StatusConflict, "flow_already_completed", "flow already completed", nil) + return + } apiresponse.WriteError(w, r, http.StatusBadGateway, "flow_respond_failed", err.Error(), nil) return } done, denied := authentik.FlowDone(challenge) if done { - clearFlowSessionCookie(w) + if slug == FlowAuthentication { + if denied { + clearFlowSessionCookie(w) + h.flows.Delete(r.Context(), sessionID) + } + // Successful auth: keep session cookie until /flows/complete. + } else { + clearFlowSessionCookie(w) + h.flows.Delete(r.Context(), sessionID) + } + } else if cookies, err := h.flows.SessionCookies(r.Context(), sessionID); err == nil { + forwardFlowCookies(w, cookies) } + writeFlowJSON(w, http.StatusOK, flowStartResponse{ SessionID: sessionID, Challenge: challenge, diff --git a/internal/api/auth/handlers_test.go b/internal/api/auth/handlers_test.go new file mode 100644 index 0000000..0934256 --- /dev/null +++ b/internal/api/auth/handlers_test.go @@ -0,0 +1,74 @@ +package authapi + +import ( + "bytes" + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" +) + +func TestValidateFlowSlugAllowed(t *testing.T) { + t.Parallel() + for _, slug := range []string{FlowEnrollment, FlowRecovery, FlowAuthentication} { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/flows/"+slug+"/start", nil) + got, ok := validateFlowSlug(rec, req, slug) + if !ok || got != slug { + t.Fatalf("slug %q: ok=%v got=%q code=%d", slug, ok, got, rec.Code) + } + } +} + +func TestValidateFlowSlugRejected(t *testing.T) { + t.Parallel() + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/flows/default-invalidation-flow/start", nil) + _, ok := validateFlowSlug(rec, req, "default-invalidation-flow") + if ok { + t.Fatal("expected slug to be rejected") + } + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d", rec.Code) + } +} + +func TestRespondFlowMissingCookie(t *testing.T) { + t.Parallel() + h := NewHandler("http://127.0.0.1:1", "http://localhost:3004", nil) + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/flows/ulti-enrollment/respond", bytes.NewReader([]byte(`{"payload":{"component":"x"}}`))) + rctx := chi.NewRouteContext() + rctx.URLParams.Add("slug", FlowEnrollment) + req = req.WithContext(context.WithValue(req.Context(), chi.RouteCtxKey, rctx)) + h.RespondFlow(rec, req) + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } +} + +func TestBuildLoginRedirectURL(t *testing.T) { + t.Parallel() + got := buildLoginRedirectURL("http://localhost:3004", "/mail/inbox") + want := "http://localhost:3004/api/auth/login?returnTo=%2Fmail%2Finbox" + if got != want { + t.Fatalf("got %q want %q", got, want) + } +} + +func TestFlowRateLimiterMemory(t *testing.T) { + t.Parallel() + lim := NewFlowRateLimiter(nil) + req := httptest.NewRequest(http.MethodPost, "/start", nil) + req.RemoteAddr = "203.0.113.1:1234" + for i := 0; i < flowRateLimitMax; i++ { + if !lim.Allow(req) { + t.Fatalf("attempt %d should be allowed", i+1) + } + } + if lim.Allow(req) { + t.Fatal("expected rate limit block") + } +} diff --git a/internal/api/auth/rate_limit.go b/internal/api/auth/rate_limit.go new file mode 100644 index 0000000..370f5e7 --- /dev/null +++ b/internal/api/auth/rate_limit.go @@ -0,0 +1,94 @@ +package authapi + +import ( + "net" + "net/http" + "strings" + "sync" + "time" + + "github.com/redis/go-redis/v9" +) + +const ( + flowRateLimitMax = 10 + flowRateLimitWindow = time.Minute + flowRateKeyPrefix = "auth_flow_rate:" +) + +// FlowRateLimiter limits flow start requests per client IP. +type FlowRateLimiter struct { + rdb *redis.Client + mu sync.Mutex + // fallback when KeyDB unavailable + hits map[string][]time.Time +} + +func NewFlowRateLimiter(rdb *redis.Client) *FlowRateLimiter { + return &FlowRateLimiter{ + rdb: rdb, + hits: make(map[string][]time.Time), + } +} + +func (l *FlowRateLimiter) Allow(r *http.Request) bool { + ip := clientIP(r) + if ip == "" { + return true + } + if l.rdb != nil { + return l.allowRedis(r, ip) + } + return l.allowMemory(ip) +} + +func (l *FlowRateLimiter) allowRedis(r *http.Request, ip string) bool { + ctx := r.Context() + key := flowRateKeyPrefix + ip + count, err := l.rdb.Incr(ctx, key).Result() + if err != nil { + return l.allowMemory(ip) + } + if count == 1 { + _ = l.rdb.Expire(ctx, key, flowRateLimitWindow).Err() + } + return count <= flowRateLimitMax +} + +func (l *FlowRateLimiter) allowMemory(ip string) bool { + now := time.Now() + cutoff := now.Add(-flowRateLimitWindow) + l.mu.Lock() + defer l.mu.Unlock() + prev := l.hits[ip] + next := prev[:0] + for _, t := range prev { + if t.After(cutoff) { + next = append(next, t) + } + } + if len(next) >= flowRateLimitMax { + l.hits[ip] = next + return false + } + next = append(next, now) + l.hits[ip] = next + return true +} + +func clientIP(r *http.Request) string { + if xff := strings.TrimSpace(r.Header.Get("X-Forwarded-For")); xff != "" { + parts := strings.Split(xff, ",") + if len(parts) > 0 { + return strings.TrimSpace(parts[0]) + } + } + if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" { + return xrip + } + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return strings.TrimSpace(r.RemoteAddr) + } + return host +} diff --git a/internal/authentik/flow_cookies.go b/internal/authentik/flow_cookies.go new file mode 100644 index 0000000..15b85a8 --- /dev/null +++ b/internal/authentik/flow_cookies.go @@ -0,0 +1,125 @@ +package authentik + +import ( + "net/http" + "strings" + "time" +) + +const authentikCookiePath = "/auth/" + +// SerializedCookie stores an HTTP cookie for flow session persistence. +type SerializedCookie struct { + Name string `json:"name"` + Value string `json:"value"` + Path string `json:"path,omitempty"` + Domain string `json:"domain,omitempty"` + MaxAge int `json:"maxAge,omitempty"` + Expires time.Time `json:"expires,omitempty"` + Secure bool `json:"secure,omitempty"` + HTTPOnly bool `json:"httpOnly,omitempty"` + SameSite http.SameSite `json:"-"` + SameSiteName string `json:"sameSite,omitempty"` +} + +func (fe *FlowExecutor) ExportCookies() []SerializedCookie { + if fe.client.Jar == nil { + return nil + } + raw := fe.client.Jar.Cookies(fe.cookieURL()) + out := make([]SerializedCookie, 0, len(raw)) + for _, c := range raw { + out = append(out, serializeCookie(c)) + } + return out +} + +func (fe *FlowExecutor) ImportCookies(stored []SerializedCookie) { + if fe.client.Jar == nil || len(stored) == 0 { + return + } + cookies := make([]*http.Cookie, 0, len(stored)) + for _, sc := range stored { + cookies = append(cookies, deserializeCookie(sc)) + } + fe.client.Jar.SetCookies(fe.cookieURL(), cookies) +} + +func serializeCookie(c *http.Cookie) SerializedCookie { + sc := SerializedCookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + MaxAge: c.MaxAge, + Expires: c.Expires, + Secure: c.Secure, + HTTPOnly: c.HttpOnly, + SameSite: c.SameSite, + } + switch c.SameSite { + case http.SameSiteDefaultMode: + sc.SameSiteName = "Default" + case http.SameSiteLaxMode: + sc.SameSiteName = "Lax" + case http.SameSiteStrictMode: + sc.SameSiteName = "Strict" + case http.SameSiteNoneMode: + sc.SameSiteName = "None" + } + return sc +} + +func deserializeCookie(sc SerializedCookie) *http.Cookie { + c := &http.Cookie{ + Name: sc.Name, + Value: sc.Value, + Path: sc.Path, + Domain: sc.Domain, + MaxAge: sc.MaxAge, + Expires: sc.Expires, + Secure: sc.Secure, + HttpOnly: sc.HTTPOnly, + } + switch strings.ToLower(sc.SameSiteName) { + case "lax": + c.SameSite = http.SameSiteLaxMode + case "strict": + c.SameSite = http.SameSiteStrictMode + case "none": + c.SameSite = http.SameSiteNoneMode + default: + c.SameSite = http.SameSiteLaxMode + } + return c +} + +// BrowserAuthentikCookies rewrites stored cookies for the public Authentik path. +func BrowserAuthentikCookies(stored []SerializedCookie) []*http.Cookie { + out := make([]*http.Cookie, 0, len(stored)) + for _, sc := range stored { + c := deserializeCookie(sc) + c.Path = authentikCookiePath + c.Domain = "" + out = append(out, c) + } + return out +} + +// RestoreFlowExecutor rebuilds an executor from persisted cookies. +func RestoreFlowExecutor(baseURL, slug string, stored []SerializedCookie) (*FlowExecutor, error) { + fe, err := NewFlowExecutor(baseURL, slug) + if err != nil { + return nil, err + } + fe.ImportCookies(stored) + return fe, nil +} + +// CloneCookieJar copies cookies between executors (tests). +func CloneCookieJar(from, to *FlowExecutor) { + if from == nil || to == nil { + return + } + to.ImportCookies(from.ExportCookies()) +} \ No newline at end of file diff --git a/internal/authentik/flow_errors.go b/internal/authentik/flow_errors.go index 36b09ef..f614388 100644 --- a/internal/authentik/flow_errors.go +++ b/internal/authentik/flow_errors.go @@ -3,6 +3,8 @@ package authentik import "errors" var ( - ErrFlowSessionNotFound = errors.New("flow session not found") - ErrFlowSessionSlugMismatch = errors.New("flow session slug mismatch") + ErrFlowSessionNotFound = errors.New("flow session not found") + ErrFlowSessionSlugMismatch = errors.New("flow session slug mismatch") + ErrFlowSessionNotCompleted = errors.New("flow session not completed") + ErrFlowSessionAlreadyCompleted = errors.New("flow session already completed") ) diff --git a/internal/authentik/flow_executor.go b/internal/authentik/flow_executor.go index 0dfd3aa..328aad4 100644 --- a/internal/authentik/flow_executor.go +++ b/internal/authentik/flow_executor.go @@ -106,12 +106,23 @@ func (fe *FlowExecutor) PostResponse(ctx context.Context, query string, payload return fe.doChallenge(req) } -func (fe *FlowExecutor) csrfToken() string { +func (fe *FlowExecutor) cookieURL() *url.URL { u, err := url.Parse(fe.baseURL) - if err != nil { - return "" + if err != nil || u.Host == "" { + return &url.URL{Scheme: "http", Host: "localhost", Path: "/auth/"} } - for _, c := range fe.client.Jar.Cookies(u) { + if u.Path == "" || u.Path == "/" { + u.Path = "/auth/" + } else if !strings.HasSuffix(u.Path, "/") { + u.Path += "/" + } + u.RawQuery = "" + u.Fragment = "" + return u +} + +func (fe *FlowExecutor) csrfToken() string { + for _, c := range fe.client.Jar.Cookies(fe.cookieURL()) { if c.Name == "authentik_csrf" { return c.Value } diff --git a/internal/authentik/flow_executor_test.go b/internal/authentik/flow_executor_test.go index 88a7e90..ce13aca 100644 --- a/internal/authentik/flow_executor_test.go +++ b/internal/authentik/flow_executor_test.go @@ -2,13 +2,12 @@ package authentik import ( "context" - "net/http/httptest" "testing" + "time" ) func TestFlowExecutorGetChallenge(t *testing.T) { t.Parallel() - // Smoke test structure only — integration tests hit real Authentik. fe, err := NewFlowExecutor("http://localhost:9000", "test-flow") if err != nil { t.Fatal(err) @@ -36,10 +35,52 @@ func TestFlowDone(t *testing.T) { func TestFlowSessionStoreLifecycle(t *testing.T) { t.Parallel() - store := NewFlowSessionStore("http://127.0.0.1:1") + store := NewFlowSessionStore("http://127.0.0.1:1", nil) _, err := store.Respond(context.Background(), "missing", "ulti-enrollment", "", map[string]any{"component": "x"}) if err != ErrFlowSessionNotFound { t.Fatalf("expected ErrFlowSessionNotFound, got %v", err) } - _ = httptest.NewRecorder() } + +func TestCookieRoundTrip(t *testing.T) { + t.Parallel() + fe, err := NewFlowExecutor("http://localhost/auth", "test-flow") + if err != nil { + t.Fatal(err) + } + fe.ImportCookies([]SerializedCookie{ + {Name: "authentik_csrf", Value: "abc123", Path: "/auth/", SameSiteName: "Lax", HTTPOnly: true}, + {Name: "authentik_session", Value: "sess", Path: "/auth/", SameSiteName: "Lax", HTTPOnly: true}, + }) + exported := fe.ExportCookies() + if len(exported) != 2 { + t.Fatalf("exported %d cookies", len(exported)) + } + fe2, err := RestoreFlowExecutor("http://localhost/auth", "test-flow", exported) + if err != nil { + t.Fatal(err) + } + if fe2.csrfToken() != "abc123" { + t.Fatalf("csrf = %q", fe2.csrfToken()) + } + browser := BrowserAuthentikCookies(exported) + if len(browser) != 2 || browser[0].Path != "/auth/" { + t.Fatalf("browser cookies: %+v", browser) + } +} + +func TestFlowSessionCompleteRequiresDone(t *testing.T) { + t.Parallel() + store := NewFlowSessionStore("http://127.0.0.1:1", nil) + store.mu.Lock() + store.items["sess1"] = &flowSessionEntry{ + slug: "default-authentication-flow", + createdAt: time.Now(), + completed: false, + } + store.mu.Unlock() + _, err := store.CompleteSession(context.Background(), "sess1", "default-authentication-flow") + if err != ErrFlowSessionNotCompleted { + t.Fatalf("expected ErrFlowSessionNotCompleted, got %v", err) + } +} \ No newline at end of file diff --git a/internal/authentik/flow_session_store.go b/internal/authentik/flow_session_store.go index 40c499b..a443a64 100644 --- a/internal/authentik/flow_session_store.go +++ b/internal/authentik/flow_session_store.go @@ -4,30 +4,47 @@ import ( "context" "crypto/rand" "encoding/hex" + "encoding/json" + "errors" "sync" "time" + + "github.com/redis/go-redis/v9" ) -const defaultFlowSessionTTL = 20 * time.Minute +const ( + defaultFlowSessionTTL = 20 * time.Minute + flowSessionKeyPrefix = "auth_flow_session:" +) -type flowSessionEntry struct { - executor *FlowExecutor - slug string - createdAt time.Time +type storedFlowSession struct { + Slug string `json:"slug"` + Cookies []SerializedCookie `json:"cookies"` + CreatedAt time.Time `json:"createdAt"` + Completed bool `json:"completed,omitempty"` } -// FlowSessionStore keeps in-memory Authentik flow executor sessions. +type flowSessionEntry struct { + slug string + cookies []SerializedCookie + createdAt time.Time + completed bool +} + +// FlowSessionStore keeps Authentik flow executor sessions (memory + optional KeyDB). type FlowSessionStore struct { mu sync.Mutex baseURL string ttl time.Duration + rdb *redis.Client items map[string]*flowSessionEntry } -func NewFlowSessionStore(baseURL string) *FlowSessionStore { +func NewFlowSessionStore(baseURL string, rdb *redis.Client) *FlowSessionStore { return &FlowSessionStore{ baseURL: baseURL, ttl: defaultFlowSessionTTL, + rdb: rdb, items: make(map[string]*flowSessionEntry), } } @@ -46,64 +63,151 @@ func (s *FlowSessionStore) Start(ctx context.Context, slug, query string) (sessi return "", nil, err } done, _ := FlowDone(challenge) - s.mu.Lock() - defer s.mu.Unlock() - s.cleanupLocked(time.Now()) if !done { - s.items[id] = &flowSessionEntry{ - executor: executor, + entry := &flowSessionEntry{ slug: slug, + cookies: executor.ExportCookies(), createdAt: time.Now(), } + if err := s.save(ctx, id, entry); err != nil { + return "", nil, err + } } return id, challenge, nil } func (s *FlowSessionStore) Respond(ctx context.Context, sessionID, slug, query string, payload map[string]any) (FlowChallenge, error) { - entry, err := s.get(sessionID) + entry, err := s.load(ctx, sessionID) if err != nil { return nil, err } if entry.slug != slug { return nil, ErrFlowSessionSlugMismatch } - challenge, err := entry.executor.PostResponse(ctx, query, payload) + if entry.completed { + return nil, ErrFlowSessionAlreadyCompleted + } + executor, err := RestoreFlowExecutor(s.baseURL, slug, entry.cookies) if err != nil { return nil, err } - done, _ := FlowDone(challenge) - if done { - s.deleteLocked(sessionID) + challenge, err := executor.PostResponse(ctx, query, payload) + if err != nil { + return nil, err + } + entry.cookies = executor.ExportCookies() + done, denied := FlowDone(challenge) + if done && !denied { + entry.completed = true + if err := s.save(ctx, sessionID, entry); err != nil { + return nil, err + } + return challenge, nil + } + if err := s.save(ctx, sessionID, entry); err != nil { + return nil, err } return challenge, nil } -func (s *FlowSessionStore) Delete(sessionID string) { - s.mu.Lock() - defer s.mu.Unlock() - s.deleteLocked(sessionID) +// CompleteSession returns cookies from a completed flow and removes the session. +func (s *FlowSessionStore) CompleteSession(ctx context.Context, sessionID, slug string) ([]SerializedCookie, error) { + entry, err := s.load(ctx, sessionID) + if err != nil { + return nil, err + } + if entry.slug != slug { + return nil, ErrFlowSessionSlugMismatch + } + if !entry.completed { + return nil, ErrFlowSessionNotCompleted + } + cookies := entry.cookies + s.delete(ctx, sessionID) + return cookies, nil } -func (s *FlowSessionStore) deleteLocked(sessionID string) { - delete(s.items, sessionID) +// SessionCookies returns current Authentik cookies for an active session. +func (s *FlowSessionStore) SessionCookies(ctx context.Context, sessionID string) ([]SerializedCookie, error) { + entry, err := s.load(ctx, sessionID) + if err != nil { + return nil, err + } + return entry.cookies, nil } -func (s *FlowSessionStore) get(sessionID string) (*flowSessionEntry, error) { +func (s *FlowSessionStore) Delete(ctx context.Context, sessionID string) { + s.delete(ctx, sessionID) +} + +func (s *FlowSessionStore) save(ctx context.Context, sessionID string, entry *flowSessionEntry) error { + s.mu.Lock() + s.items[sessionID] = entry + s.mu.Unlock() + if s.rdb == nil { + return nil + } + stored := storedFlowSession{ + Slug: entry.slug, + Cookies: entry.cookies, + CreatedAt: entry.createdAt, + Completed: entry.completed, + } + raw, err := json.Marshal(stored) + if err != nil { + return err + } + return s.rdb.Set(ctx, flowSessionKeyPrefix+sessionID, raw, s.ttl).Err() +} + +func (s *FlowSessionStore) load(ctx context.Context, sessionID string) (*flowSessionEntry, error) { + if s.rdb != nil { + raw, err := s.rdb.Get(ctx, flowSessionKeyPrefix+sessionID).Bytes() + if err != nil { + if errors.Is(err, redis.Nil) { + return nil, ErrFlowSessionNotFound + } + return nil, err + } + var stored storedFlowSession + if err := json.Unmarshal(raw, &stored); err != nil { + return nil, err + } + if time.Since(stored.CreatedAt) > s.ttl { + s.delete(ctx, sessionID) + return nil, ErrFlowSessionNotFound + } + entry := &flowSessionEntry{ + slug: stored.Slug, + cookies: stored.Cookies, + createdAt: stored.CreatedAt, + completed: stored.Completed, + } + s.mu.Lock() + s.items[sessionID] = entry + s.mu.Unlock() + return entry, nil + } + s.mu.Lock() - defer s.mu.Unlock() - s.cleanupLocked(time.Now()) entry, ok := s.items[sessionID] + s.mu.Unlock() if !ok { return nil, ErrFlowSessionNotFound } + if time.Since(entry.createdAt) > s.ttl { + s.delete(ctx, sessionID) + return nil, ErrFlowSessionNotFound + } return entry, nil } -func (s *FlowSessionStore) cleanupLocked(now time.Time) { - for id, entry := range s.items { - if now.Sub(entry.createdAt) > s.ttl { - delete(s.items, id) - } +func (s *FlowSessionStore) delete(ctx context.Context, sessionID string) { + s.mu.Lock() + delete(s.items, sessionID) + s.mu.Unlock() + if s.rdb != nil { + _ = s.rdb.Del(ctx, flowSessionKeyPrefix+sessionID).Err() } } diff --git a/internal/server/bootstrap.go b/internal/server/bootstrap.go index e15add7..6d47112 100644 --- a/internal/server/bootstrap.go +++ b/internal/server/bootstrap.go @@ -368,7 +368,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { r.Get("/api/v1/mail/addresses/check", mailHandler.CheckAddressAvailability) r.Get("/api/v1/migration/invite", migrationHandler.GetInvite) r.Post("/internal/provision/user", provisionHandler.ProvisionUser) - r.Mount("/api/v1/auth", authapi.NewHandler(cfg.AuthentikAPIURL).Routes()) + r.Mount("/api/v1/auth", authapi.NewHandler(cfg.AuthentikAPIURL, cfg.MailAppURL, rdb).Routes()) var driveHandler *drive.Handler var driveSvc *drive.Service