package middleware import ( "context" "log/slog" "net/http" "strings" "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/apitokens" "github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/permission" "github.com/ultisuite/ulti-backend/internal/securityaudit" "github.com/ultisuite/ulti-backend/internal/users" ) type ctxKey string const ( claimsKey ctxKey = "claims" apiTokenKey ctxKey = "api_token" ) func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { header := r.Header.Get("Authorization") if header == "" { apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthMissingAuthorization, "missing authorization header", nil) if audit != nil { audit.Log(r.Context(), "anonymous", securityaudit.ActionTokenRejected, map[string]any{ "reason": "missing_authorization_header", "path": r.URL.Path, "method": r.Method, }) } return } token, found := strings.CutPrefix(header, "Bearer ") if !found { apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthInvalidAuthorization, "invalid authorization header", nil) if audit != nil { audit.Log(r.Context(), "anonymous", securityaudit.ActionTokenRejected, map[string]any{ "reason": "invalid_authorization_header", "path": r.URL.Path, "method": r.Method, }) } return } token = strings.TrimSpace(token) if strings.HasPrefix(token, apitokens.TokenPrefix()) { if db == nil { apiresponse.WriteError(w, r, http.StatusServiceUnavailable, apiresponse.CodeAuthUnavailable, "authentication unavailable", nil) return } apiAuth, err := apitokens.Authenticate(r.Context(), db, token) if err != nil { apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthInvalidToken, "invalid api token", nil) if audit != nil { audit.Log(r.Context(), "anonymous", securityaudit.ActionTokenRejected, map[string]any{ "reason": "api_token_verification_failed", "path": r.URL.Path, "method": r.Method, }) } return } if isApiTokenManagementRoute(r.URL.Path) && !apitokens.HasPermission(apiAuth, "automation.api_tokens", true) { apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "api token management requires super admin permission", nil) return } claims := &auth.Claims{ Sub: apiAuth.ExternalID, Email: apiAuth.Email, Name: apiAuth.Name, } if audit != nil { audit.Log(r.Context(), claims.Sub, securityaudit.ActionLogin, map[string]any{ "email": claims.Email, "path": r.URL.Path, "method": r.Method, "api_token": apiAuth.TokenID, "auth_mode": "api_token", }) } ctx := context.WithValue(r.Context(), claimsKey, claims) ctx = context.WithValue(ctx, apiTokenKey, apiAuth) next.ServeHTTP(w, r.WithContext(ctx)) return } if verifier == nil || !verifier.Ready() { apiresponse.WriteError(w, r, http.StatusServiceUnavailable, apiresponse.CodeAuthUnavailable, "authentication unavailable", nil) if audit != nil { audit.Log(r.Context(), "system", securityaudit.ActionTokenRejected, map[string]any{ "reason": "verifier_unavailable", "path": r.URL.Path, "method": r.Method, }) } return } claims, err := verifier.Verify(r.Context(), token) if err != nil { apiresponse.WriteError(w, r, http.StatusUnauthorized, apiresponse.CodeAuthInvalidToken, "invalid token", nil) if audit != nil { audit.Log(r.Context(), "anonymous", securityaudit.ActionTokenRejected, map[string]any{ "reason": "token_verification_failed", "path": r.URL.Path, "method": r.Method, }) } return } if db != nil { if _, err := users.EnsureUser(r.Context(), db, claims); err != nil { slog.Error("provision user", "sub", claims.Sub, "error", err) apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to provision user", nil) return } if err := users.ApplyAccountGroups(r.Context(), db, claims); err != nil { slog.Error("apply account groups", "sub", claims.Sub, "error", err) apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to read user privileges", nil) return } var disabled bool if err := db.QueryRow(r.Context(), ` SELECT status = 'disabled' FROM users WHERE external_id = $1 `, claims.Sub).Scan(&disabled); err != nil { slog.Error("read user status", "sub", claims.Sub, "error", err) apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to read user status", nil) return } if disabled { apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "account disabled", nil) if audit != nil { audit.Log(r.Context(), claims.Sub, securityaudit.ActionTokenRejected, map[string]any{ "reason": "account_disabled", "path": r.URL.Path, "method": r.Method, }) } return } } else { claims.Groups = permission.WithSuiteDefaults(claims.Groups) } if audit != nil { audit.Log(r.Context(), claims.Sub, securityaudit.ActionLogin, map[string]any{ "email": claims.Email, "path": r.URL.Path, "method": r.Method, }) } ctx := context.WithValue(r.Context(), claimsKey, claims) next.ServeHTTP(w, r.WithContext(ctx)) }) } } func ClaimsFromContext(ctx context.Context) *auth.Claims { claims, _ := ctx.Value(claimsKey).(*auth.Claims) return claims } func ApiTokenFromContext(ctx context.Context) *apitokens.AuthContext { authCtx, _ := ctx.Value(apiTokenKey).(*apitokens.AuthContext) return authCtx } func isApiTokenManagementRoute(path string) bool { return strings.Contains(path, "/api-tokens") }