From d3c930cac6df629239bed434e38975493dfe4303 Mon Sep 17 00:00:00 2001 From: R3D347HR4Y Date: Tue, 9 Jun 2026 09:36:38 +0200 Subject: [PATCH] feat(identity-providers): add management for identity providers in admin API - Introduced new endpoints for managing identity providers, including retrieval of redirect URIs and testing/syncing providers. - Enhanced organization settings to include identity provider configurations, allowing for self-enrollment and domain restrictions. - Implemented caching for access policies and added validation for identity provider secrets. - Added integration tests to ensure proper functionality of identity provider management and policy enforcement. --- internal/api/admin/handlers.go | 4 + internal/api/admin/identity_providers.go | 98 ++++ internal/api/admin/org_settings.go | 167 +++++- internal/api/apiresponse/codes.go | 1 + internal/api/middleware/auth.go | 20 +- internal/api/office/public_handlers.go | 3 +- internal/auth/oidc.go | 12 + internal/authentik/presets.go | 82 +++ internal/authentik/presets_test.go | 46 ++ internal/authentik/source_sync.go | 492 ++++++++++++++++++ internal/authentik/sources.go | 322 ++++++++++++ .../admin/org_settings_test.go | 72 +++ internal/nextcloud/public_share.go | 18 + internal/orgpolicy/auth.go | 144 +++++ internal/orgpolicy/auth_test.go | 72 +++ internal/orgpolicy/loader.go | 81 ++- internal/server/bootstrap.go | 4 +- 17 files changed, 1627 insertions(+), 11 deletions(-) create mode 100644 internal/api/admin/identity_providers.go create mode 100644 internal/authentik/presets.go create mode 100644 internal/authentik/presets_test.go create mode 100644 internal/authentik/source_sync.go create mode 100644 internal/authentik/sources.go create mode 100644 internal/orgpolicy/auth.go create mode 100644 internal/orgpolicy/auth_test.go diff --git a/internal/api/admin/handlers.go b/internal/api/admin/handlers.go index 393c9f1..e1fd9e2 100644 --- a/internal/api/admin/handlers.go +++ b/internal/api/admin/handlers.go @@ -60,6 +60,10 @@ func (h *Handler) Routes() chi.Router { r.With(read).Get("/org/settings", h.GetOrgSettings) r.With(write).Put("/org/settings", h.PutOrgSettings) + r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI) + r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider) + r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider) + return r } diff --git a/internal/api/admin/identity_providers.go b/internal/api/admin/identity_providers.go new file mode 100644 index 0000000..82f0702 --- /dev/null +++ b/internal/api/admin/identity_providers.go @@ -0,0 +1,98 @@ +package admin + +import ( + "context" + "fmt" + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/authentik" +) + +func (s *Service) SyncIdentityProviders(ctx context.Context, actorSub, providerID string) (map[string]any, error) { + stored, _, _, err := s.loadOrgPolicyRaw(ctx) + if err != nil { + return nil, err + } + idpSection, _ := stored["identity_providers"].(map[string]any) + if idpSection == nil { + return nil, fmt.Errorf("identity_providers not configured") + } + providers, _ := idpSection["providers"].([]any) + found := false + for _, item := range providers { + pm, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := pm["id"].(string); id == providerID { + found = true + break + } + } + if !found { + return nil, fmt.Errorf("provider not found") + } + + syncer := authentik.NewSourceSyncer(s.db, s.cfg) + if err := syncer.SyncSection(ctx, idpSection); err != nil { + return nil, err + } + s.logAudit(ctx, actorSub, "sync_identity_provider", map[string]any{ + "provider_id": providerID, + }) + return s.GetOrgSettings(ctx) +} + +func (h *Handler) GetIdentityProviderRedirectURI(w http.ResponseWriter, r *http.Request) { + slug := chi.URLParam(r, "slug") + if slug == "" { + apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError( + apivalidate.FieldDetail{Field: "slug", Message: "required"}, + )) + return + } + syncer := authentik.NewSourceSyncer(h.svc.db, h.svc.cfg) + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{ + "slug": slug, + "redirect_uri": syncer.RedirectURI(slug), + }) +} + +func (h *Handler) TestIdentityProvider(w http.ResponseWriter, r *http.Request) { + providerID := chi.URLParam(r, "providerID") + if providerID == "" { + apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError( + apivalidate.FieldDetail{Field: "providerID", Message: "required"}, + )) + return + } + syncer := authentik.NewSourceSyncer(h.svc.db, h.svc.cfg) + if err := syncer.TestProvider(r.Context(), providerID); err != nil { + apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true}) +} + +func (h *Handler) SyncIdentityProvider(w http.ResponseWriter, r *http.Request) { + providerID := chi.URLParam(r, "providerID") + if providerID == "" { + apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError( + apivalidate.FieldDetail{Field: "providerID", Message: "required"}, + )) + return + } + claims := middleware.ClaimsFromContext(r.Context()) + payload, err := h.svc.SyncIdentityProviders(r.Context(), claims.Sub, providerID) + if err != nil { + h.logger.Error("sync identity provider", "error", err, "provider_id", providerID) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, payload) +} diff --git a/internal/api/admin/org_settings.go b/internal/api/admin/org_settings.go index c6e8570..f251635 100644 --- a/internal/api/admin/org_settings.go +++ b/internal/api/admin/org_settings.go @@ -9,6 +9,7 @@ import ( "github.com/jackc/pgx/v5" "github.com/ultisuite/ulti-backend/internal/config" + "github.com/ultisuite/ulti-backend/internal/authentik" ) const orgSettingsSingletonID = 1 @@ -24,6 +25,11 @@ func defaultOrgPolicy() map[string]any { "allow_password_fallback": false, "default_groups": "ulti-users", }, + "identity_providers": map[string]any{ + "allow_self_enrollment": true, + "default_login_source": "", + "providers": []any{}, + }, "two_factor": map[string]any{ "required_for_all": false, "required_for_admins": true, @@ -204,6 +210,9 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any { mergeSearchProviderSecrets(existing, patchWS, merged) } } + if patchIDP, ok := patch["identity_providers"].(map[string]any); ok { + mergeIdentityProviderSecrets(existing, patchIDP, merged) + } return merged } @@ -291,6 +300,65 @@ func mergeSearchProviderSecrets(existing map[string]any, patchWS map[string]any, merged["search"] = mergedSearch } +func mergeIdentityProviderSecrets(existing, patchIDP, merged map[string]any) { + patchProviders, _ := patchIDP["providers"].([]any) + if len(patchProviders) == 0 { + return + } + existingIDP, _ := existing["identity_providers"].(map[string]any) + existingProviders, _ := existingIDP["providers"].([]any) + mergedIDP, _ := merged["identity_providers"].(map[string]any) + if mergedIDP == nil { + return + } + mergedProviders := make([]any, len(patchProviders)) + for i, item := range patchProviders { + pm, ok := item.(map[string]any) + if !ok { + mergedProviders[i] = item + continue + } + id, _ := pm["id"].(string) + for _, secretPath := range []struct{ section, key string }{ + {"oauth", "client_secret"}, + {"ldap", "bind_password"}, + {"saml", "signing_cert"}, + } { + mergeProviderSecretField(pm, existingProviders, id, secretPath.section, secretPath.key) + } + mergedProviders[i] = pm + } + mergedIDP["providers"] = mergedProviders + merged["identity_providers"] = mergedIDP +} + +func mergeProviderSecretField(pm map[string]any, existingProviders []any, id, section, key string) { + nested, _ := pm[section].(map[string]any) + if nested == nil { + return + } + val, _ := nested[key].(string) + if strings.TrimSpace(val) != "" || id == "" { + return + } + for _, ep := range existingProviders { + em, ok := ep.(map[string]any) + if !ok { + continue + } + if eid, _ := em["id"].(string); eid != id { + continue + } + esec, _ := em[section].(map[string]any) + if esec == nil { + continue + } + if ek, ok := esec[key].(string); ok && ek != "" { + nested[key] = ek + } + } +} + func maskOrgPolicy(policy map[string]any) map[string]any { cloned := deepCloneMap(policy) maskStringField(cloned, "nextcloud", "admin_password") @@ -331,9 +399,41 @@ func maskOrgPolicy(policy map[string]any) map[string]any { } } } + maskIdentityProviderSecrets(cloned) return cloned } +func maskIdentityProviderSecrets(policy map[string]any) { + idp, ok := policy["identity_providers"].(map[string]any) + if !ok { + return + } + providers, ok := idp["providers"].([]any) + if !ok { + return + } + for i, item := range providers { + pm, ok := item.(map[string]any) + if !ok { + continue + } + for _, section := range []string{"oauth", "ldap", "saml"} { + nested, ok := pm[section].(map[string]any) + if !ok { + continue + } + for _, key := range []string{"client_secret", "bind_password", "signing_cert"} { + if v, _ := nested[key].(string); strings.TrimSpace(v) != "" { + nested[key] = "" + } + } + pm[section] = nested + } + providers[i] = pm + } + idp["providers"] = providers +} + func maskStringField(m map[string]any, section, key string) { sec, ok := m[section].(map[string]any) if !ok { @@ -361,7 +461,7 @@ func secretConfigured(policy map[string]any, section, key string) bool { } func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string]any { - return map[string]any{ + secrets := map[string]any{ "nextcloud_admin_password": map[string]any{ "configured": secretConfigured(policy, "nextcloud", "admin_password") || strings.TrimSpace(cfg.NCAdminPass) != "", }, @@ -381,10 +481,57 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string "configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "", }, } + if idpSecrets := buildIdentityProviderSecretsStatus(policy); len(idpSecrets) > 0 { + secrets["identity_providers"] = idpSecrets + } + return secrets +} + +func buildIdentityProviderSecretsStatus(policy map[string]any) map[string]any { + idp, ok := policy["identity_providers"].(map[string]any) + if !ok { + return nil + } + providers, ok := idp["providers"].([]any) + if !ok { + return nil + } + out := make(map[string]any) + for _, item := range providers { + pm, ok := item.(map[string]any) + if !ok { + continue + } + id, _ := pm["id"].(string) + if id == "" { + continue + } + entry := map[string]any{} + if oauth, ok := pm["oauth"].(map[string]any); ok { + if v, _ := oauth["client_secret"].(string); strings.TrimSpace(v) != "" { + entry["oauth_client_secret"] = map[string]any{"configured": true} + } + } + if ldap, ok := pm["ldap"].(map[string]any); ok { + if v, _ := ldap["bind_password"].(string); strings.TrimSpace(v) != "" { + entry["ldap_bind_password"] = map[string]any{"configured": true} + } + } + if saml, ok := pm["saml"].(map[string]any); ok { + if v, _ := saml["signing_cert"].(string); strings.TrimSpace(v) != "" { + entry["saml_signing_cert"] = map[string]any{"configured": true} + } + } + if len(entry) > 0 { + out[id] = entry + } + } + return out } func buildOrgEffective(cfg *config.Config) map[string]any { authentikEnabled := strings.TrimSpace(cfg.AuthentikAPIToken) != "" || strings.TrimSpace(cfg.OIDCIssuer) != "" + publicBase := authentik.AuthentikPublicBaseURL(cfg.AuthentikAPIURL, cfg.AuthentikPublicHTTPS) return map[string]any{ "authentik": map[string]any{ "enabled": authentikEnabled, @@ -392,6 +539,10 @@ func buildOrgEffective(cfg *config.Config) map[string]any { "client_id": cfg.OIDCClientID, "issuer": cfg.OIDCIssuer, }, + "identity_providers": map[string]any{ + "authentik_public_url": publicBase, + "oauth_redirect_template": publicBase + "/source/oauth/callback/{slug}/", + }, "nextcloud": map[string]any{ "enabled": cfg.NextcloudEnabled, "base_url": firstNonEmpty(cfg.NextcloudPublicURL, cfg.NextcloudURL), @@ -445,11 +596,11 @@ func (s *Service) GetOrgSettings(ctx context.Context) (map[string]any, error) { } func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map[string]any) (map[string]any, error) { - stored, _, _, err := s.loadOrgPolicyRaw(ctx) + storedBefore, _, _, err := s.loadOrgPolicyRaw(ctx) if err != nil { return nil, err } - merged := mergeOrgSecrets(stored, patch) + merged := mergeOrgSecrets(storedBefore, patch) raw, err := json.Marshal(merged) if err != nil { return nil, err @@ -468,6 +619,16 @@ func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map s.logAudit(ctx, actorSub, "update_org_settings", map[string]any{ "sections": mapKeys(patch), }) + + if _, ok := patch["identity_providers"]; ok { + removed := authentik.RemovedIdentityProviders(storedBefore, merged) + syncer := authentik.NewSourceSyncer(s.db, s.cfg) + idpSection, _ := merged["identity_providers"].(map[string]any) + if err := syncer.SyncSectionWithRemoved(ctx, idpSection, removed); err != nil { + s.logger.Warn("identity provider sync failed", "error", err) + } + } + return s.GetOrgSettings(ctx) } diff --git a/internal/api/apiresponse/codes.go b/internal/api/apiresponse/codes.go index ebe5fb3..2c537bd 100644 --- a/internal/api/apiresponse/codes.go +++ b/internal/api/apiresponse/codes.go @@ -8,6 +8,7 @@ const ( CodeAuthInvalidToken = "auth.invalid_token" CodeAuthUnauthorized = "auth.unauthorized" CodeAuthForbidden = "auth.forbidden" + CodeIdentityNotAllowed = "auth.identity_not_allowed" CodeInvalidQueryParam = "invalid_query_param" CodeInvalidRequest = "invalid_request_body" diff --git a/internal/api/middleware/auth.go b/internal/api/middleware/auth.go index 2e078c1..324eebf 100644 --- a/internal/api/middleware/auth.go +++ b/internal/api/middleware/auth.go @@ -11,6 +11,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/apitokens" "github.com/ultisuite/ulti-backend/internal/auth" + "github.com/ultisuite/ulti-backend/internal/orgpolicy" "github.com/ultisuite/ulti-backend/internal/permission" "github.com/ultisuite/ulti-backend/internal/securityaudit" "github.com/ultisuite/ulti-backend/internal/users" @@ -23,7 +24,7 @@ const ( apiTokenKey ctxKey = "api_token" ) -func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger) func(http.Handler) http.Handler { +func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger, orgPolicy *orgpolicy.Loader) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { header := r.Header.Get("Authorization") @@ -148,6 +149,23 @@ func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger) } return } + if orgPolicy != nil { + policy, err := orgPolicy.AuthAccessPolicy(r.Context()) + if err != nil { + slog.Error("load auth access policy", "sub", claims.Sub, "error", err) + } else if !policy.AllowsIdentity(claims.Email, claims) { + apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeIdentityNotAllowed, "identity not allowed by organization policy", nil) + if audit != nil { + audit.Log(r.Context(), claims.Sub, securityaudit.ActionTokenRejected, map[string]any{ + "reason": "identity_not_allowed", + "email": claims.Email, + "path": r.URL.Path, + "method": r.Method, + }) + } + return + } + } } else { claims.Groups = permission.WithSuiteDefaults(claims.Groups) } diff --git a/internal/api/office/public_handlers.go b/internal/api/office/public_handlers.go index 6376985..e1cdee0 100644 --- a/internal/api/office/public_handlers.go +++ b/internal/api/office/public_handlers.go @@ -56,7 +56,7 @@ func (h *Handler) PublicShareSession(w http.ResponseWriter, r *http.Request) { if password == "" { password = publicSharePassword(r) } - perms, err := h.svc.nc.GetPublicSharePathPermissions(r.Context(), token, req.Path, password) + perms, err := h.svc.nc.EffectivePublicSharePermissions(r.Context(), token, req.Path, password) if err != nil { apivalidate.WriteInternal(w, r) return @@ -81,6 +81,7 @@ func (h *Handler) PublicShareSession(w http.ResponseWriter, r *http.Request) { apiresponse.WriteJSON(w, http.StatusOK, map[string]any{ "config": cfg, "serverUrl": h.svc.PublicURL(), + "mode": mode, }) } diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 2d8109e..3a8250f 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -16,6 +16,10 @@ type Claims struct { Email string Name string Groups []string + Source string + HD string + TID string + Org string } type Verifier struct { @@ -96,6 +100,10 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error) Email string `json:"email"` Name string `json:"name"` Groups []string `json:"groups"` + HD string `json:"hd"` + TID string `json:"tid"` + Org string `json:"org"` + Source string `json:"ak-source"` } if err := token.Claims(&claims); err != nil { return nil, err @@ -106,5 +114,9 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error) Email: claims.Email, Name: claims.Name, Groups: claims.Groups, + HD: claims.HD, + TID: claims.TID, + Org: claims.Org, + Source: claims.Source, }, nil } diff --git a/internal/authentik/presets.go b/internal/authentik/presets.go new file mode 100644 index 0000000..a088f3c --- /dev/null +++ b/internal/authentik/presets.go @@ -0,0 +1,82 @@ +package authentik + +import "strings" + +type OAuthPreset struct { + ProviderType string + AuthorizationURL string + AccessTokenURL string + ProfileURL string + DefaultScopes string + OrganizationClaim string +} + +func OAuthPresetFor(provider string) OAuthPreset { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "google": + return OAuthPreset{ + ProviderType: "google", + AuthorizationURL: "https://accounts.google.com/o/oauth2/auth", + AccessTokenURL: "https://oauth2.googleapis.com/token", + ProfileURL: "https://www.googleapis.com/oauth2/v1/userinfo", + DefaultScopes: "openid email profile", + OrganizationClaim: "hd", + } + case "github": + return OAuthPreset{ + ProviderType: "github", + AuthorizationURL: "https://github.com/login/oauth/authorize", + AccessTokenURL: "https://github.com/login/oauth/access_token", + ProfileURL: "https://api.github.com/user", + DefaultScopes: "read:user user:email", + OrganizationClaim: "org", + } + case "linkedin": + return OAuthPreset{ + ProviderType: "openidconnect", + AuthorizationURL: "https://www.linkedin.com/oauth/v2/authorization", + AccessTokenURL: "https://www.linkedin.com/oauth/v2/accessToken", + ProfileURL: "https://api.linkedin.com/v2/userinfo", + DefaultScopes: "openid profile email", + OrganizationClaim: "", + } + case "microsoft": + return OAuthPreset{ + ProviderType: "azuread", + AuthorizationURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize", + AccessTokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token", + ProfileURL: "https://graph.microsoft.com/oidc/userinfo", + DefaultScopes: "openid email profile", + OrganizationClaim: "tid", + } + default: + return OAuthPreset{ + ProviderType: "openidconnect", + DefaultScopes: "openid email profile", + } + } +} + +func OAuthRedirectURI(publicBaseURL, slug string) string { + base := strings.TrimRight(strings.TrimSpace(publicBaseURL), "/") + if base == "" { + base = "http://localhost/auth" + } + return base + "/source/oauth/callback/" + strings.TrimSpace(slug) + "/" +} + +func AuthentikPublicBaseURL(apiURL string, publicHTTPS bool) string { + apiURL = strings.TrimSpace(apiURL) + if apiURL == "" { + if publicHTTPS { + return "https://localhost/auth" + } + return "http://localhost/auth" + } + u := strings.TrimSuffix(apiURL, "/") + u = strings.TrimSuffix(u, "/api/v3") + if publicHTTPS && strings.HasPrefix(u, "http://") { + u = "https://" + strings.TrimPrefix(u, "http://") + } + return u +} diff --git a/internal/authentik/presets_test.go b/internal/authentik/presets_test.go new file mode 100644 index 0000000..2b11c2d --- /dev/null +++ b/internal/authentik/presets_test.go @@ -0,0 +1,46 @@ +package authentik + +import "testing" + +func TestOAuthPresetForGoogle(t *testing.T) { + preset := OAuthPresetFor("google") + if preset.ProviderType != "google" { + t.Fatalf("provider type = %q", preset.ProviderType) + } + if preset.AuthorizationURL == "" || preset.AccessTokenURL == "" { + t.Fatal("expected google oauth endpoints") + } +} + +func TestOAuthRedirectURI(t *testing.T) { + got := OAuthRedirectURI("http://localhost/auth", "google-workspace") + want := "http://localhost/auth/source/oauth/callback/google-workspace/" + if got != want { + t.Fatalf("redirect uri = %q, want %q", got, want) + } +} + +func TestRemovedIdentityProviders(t *testing.T) { + before := map[string]any{ + "identity_providers": map[string]any{ + "providers": []any{ + map[string]any{"id": "a", "authentik_pk": 1, "type": "oauth"}, + map[string]any{"id": "b", "authentik_pk": 2, "type": "saml"}, + }, + }, + } + after := map[string]any{ + "identity_providers": map[string]any{ + "providers": []any{ + map[string]any{"id": "a", "authentik_pk": 1, "type": "oauth"}, + }, + }, + } + removed := RemovedIdentityProviders(before, after) + if len(removed) != 1 { + t.Fatalf("removed count = %d, want 1", len(removed)) + } + if id, _ := removed[0]["id"].(string); id != "b" { + t.Fatalf("removed id = %q", id) + } +} diff --git a/internal/authentik/source_sync.go b/internal/authentik/source_sync.go new file mode 100644 index 0000000..0a6459c --- /dev/null +++ b/internal/authentik/source_sync.go @@ -0,0 +1,492 @@ +package authentik + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/ultisuite/ulti-backend/internal/config" +) + +const orgSettingsSingletonID = 1 + +type SourceSyncer struct { + db *pgxpool.Pool + cfg *config.Config +} + +func NewSourceSyncer(db *pgxpool.Pool, cfg *config.Config) *SourceSyncer { + return &SourceSyncer{db: db, cfg: cfg} +} + +func (s *SourceSyncer) client() *Client { + apiURL := strings.TrimSpace(s.cfg.AuthentikAPIURL) + token := strings.TrimSpace(s.cfg.AuthentikAPIToken) + if apiURL == "" || token == "" { + return nil + } + return NewClient(apiURL, token) +} + +func (s *SourceSyncer) SyncFromStoredPolicy(ctx context.Context) error { + stored, err := s.loadStoredPolicy(ctx) + if err != nil { + return err + } + idpSection, _ := stored["identity_providers"].(map[string]any) + if idpSection == nil { + return nil + } + return s.syncSection(ctx, idpSection, stored, nil) +} + +func (s *SourceSyncer) SyncSection(ctx context.Context, idpSection map[string]any) error { + stored, err := s.loadStoredPolicy(ctx) + if err != nil { + return err + } + return s.syncSection(ctx, idpSection, stored, nil) +} + +func (s *SourceSyncer) SyncSectionWithRemoved(ctx context.Context, idpSection map[string]any, removed []map[string]any) error { + stored, err := s.loadStoredPolicy(ctx) + if err != nil { + return err + } + return s.syncSection(ctx, idpSection, stored, removed) +} + +func (s *SourceSyncer) syncSection(ctx context.Context, idpSection, fullStored map[string]any, removed []map[string]any) error { + client := s.client() + if client == nil { + return s.markAllPending(ctx, idpSection, "Authentik API not configured") + } + + if err := client.Ping(ctx); err != nil { + return s.markAllPending(ctx, idpSection, err.Error()) + } + + authFlow, enrollmentFlow, err := client.ResolveSourceFlows(ctx) + if err != nil { + return s.markAllPending(ctx, idpSection, err.Error()) + } + + allowSelfEnrollment, _ := idpSection["allow_self_enrollment"].(bool) + if err := s.syncEnrollment(ctx, client, allowSelfEnrollment); err != nil { + return err + } + + providers, _ := idpSection["providers"].([]any) + now := time.Now().UTC().Format(time.RFC3339) + + for _, item := range providers { + pm, ok := item.(map[string]any) + if !ok { + continue + } + id, _ := pm["id"].(string) + if id == "" { + continue + } + pk, syncErr := s.syncProvider(ctx, client, pm, authFlow, enrollmentFlow, fullStored) + if syncErr != nil { + pm["sync_status"] = "error" + pm["sync_error"] = syncErr.Error() + } else { + pm["sync_status"] = "synced" + pm["sync_error"] = "" + pm["last_synced_at"] = now + if pk != "" { + if n, err := strconv.Atoi(pk); err == nil { + pm["authentik_pk"] = n + } else { + pm["authentik_pk"] = pk + } + } + } + } + + if err := s.deleteRemovedSources(ctx, client, removed); err != nil { + return err + } + + idpSection["providers"] = providers + fullStored["identity_providers"] = idpSection + return s.persistPolicy(ctx, fullStored) +} + +func (s *SourceSyncer) syncEnrollment(ctx context.Context, client *Client, allow bool) error { + brand, err := client.FindDefaultBrand(ctx) + if err != nil { + return err + } + if allow { + flow, found, err := client.FindFlowInstanceBySlug(ctx, "ulti-enrollment") + if err != nil { + return err + } + if !found { + return nil + } + return client.SetBrandEnrollmentFlow(ctx, brand.PK, &flow.PK) + } + return client.SetBrandEnrollmentFlow(ctx, brand.PK, nil) +} + +func (s *SourceSyncer) syncProvider(ctx context.Context, client *Client, pm map[string]any, authFlow, enrollmentFlow string, fullStored map[string]any) (string, error) { + providerType, _ := pm["type"].(string) + enabled, _ := pm["enabled"].(bool) + slug, _ := pm["slug"].(string) + name, _ := pm["name"].(string) + if slug == "" { + return "", fmt.Errorf("missing slug") + } + if name == "" { + name = slug + } + + pk := authentikPK(pm["authentik_pk"]) + switch providerType { + case "oauth": + req, err := buildOAuthSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name) + if err != nil { + return "", err + } + if pk == "" { + createdPK, err := client.CreateOAuthSource(ctx, req) + if err != nil { + return "", err + } + return createdPK, nil + } + return pk, client.UpdateOAuthSource(ctx, pk, req) + case "saml": + req := buildSAMLSourceRequest(pm, authFlow, enrollmentFlow, enabled, slug, name) + if pk == "" { + createdPK, err := client.CreateSAMLSource(ctx, req) + if err != nil { + return "", err + } + return createdPK, nil + } + return pk, client.UpdateSAMLSource(ctx, pk, req) + case "ldap": + req, err := buildLDAPSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name) + if err != nil { + return "", err + } + if pk == "" { + createdPK, err := client.CreateLDAPSource(ctx, req) + if err != nil { + return "", err + } + return createdPK, nil + } + return pk, client.UpdateLDAPSource(ctx, pk, req) + default: + return "", fmt.Errorf("unsupported provider type %q", providerType) + } +} + +func buildOAuthSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (OAuthSourceRequest, error) { + oauth, _ := pm["oauth"].(map[string]any) + providerName, _ := oauth["provider"].(string) + preset := OAuthPresetFor(providerName) + + clientID, _ := oauth["client_id"].(string) + clientSecret := resolveProviderSecret(pm, fullStored, "oauth", "client_secret") + + authURL := stringOr(oauth["authorization_url"], preset.AuthorizationURL) + tokenURL := stringOr(oauth["token_url"], preset.AccessTokenURL) + profileURL := stringOr(oauth["profile_url"], preset.ProfileURL) + scopes := stringOr(oauth["scopes"], preset.DefaultScopes) + providerType := preset.ProviderType + if providerName == "custom" { + providerType = "openidconnect" + } + + if strings.TrimSpace(clientID) == "" { + return OAuthSourceRequest{}, fmt.Errorf("oauth client_id required") + } + if strings.TrimSpace(clientSecret) == "" { + return OAuthSourceRequest{}, fmt.Errorf("oauth client_secret required") + } + + return OAuthSourceRequest{ + Name: name, + Slug: slug, + Enabled: enabled, + ProviderType: providerType, + ClientID: clientID, + ClientSecret: clientSecret, + AuthorizationURL: authURL, + AccessTokenURL: tokenURL, + ProfileURL: profileURL, + Scopes: scopes, + AuthenticationFlow: authFlow, + EnrollmentFlow: enrollmentFlow, + }, nil +} + +func buildSAMLSourceRequest(pm map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) SAMLSourceRequest { + saml, _ := pm["saml"].(map[string]any) + return SAMLSourceRequest{ + Name: name, + Slug: slug, + Enabled: enabled, + MetadataURL: stringValue(saml["metadata_url"]), + MetadataXML: stringValue(saml["metadata_xml"]), + EntityID: stringValue(saml["entity_id"]), + SSOURL: stringValue(saml["sso_url"]), + SLOURL: stringValue(saml["slo_url"]), + SigningCert: stringValue(saml["signing_cert"]), + AuthenticationFlow: authFlow, + EnrollmentFlow: enrollmentFlow, + } +} + +func buildLDAPSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (LDAPSourceRequest, error) { + ldap, _ := pm["ldap"].(map[string]any) + bindPassword := resolveProviderSecret(pm, fullStored, "ldap", "bind_password") + if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" { + return LDAPSourceRequest{}, fmt.Errorf("ldap server_uri required") + } + if strings.TrimSpace(stringValue(ldap["bind_dn"])) == "" { + return LDAPSourceRequest{}, fmt.Errorf("ldap bind_dn required") + } + if strings.TrimSpace(bindPassword) == "" { + return LDAPSourceRequest{}, fmt.Errorf("ldap bind_password required") + } + return LDAPSourceRequest{ + Name: name, + Slug: slug, + Enabled: enabled, + ServerURI: stringValue(ldap["server_uri"]), + BindDN: stringValue(ldap["bind_dn"]), + BindPassword: bindPassword, + BaseDN: stringValue(ldap["base_dn"]), + UserFilter: stringValue(ldap["user_filter"]), + StartTLS: boolValue(ldap["start_tls"]), + SyncUsers: boolValue(ldap["sync_users"]), + AuthenticationFlow: authFlow, + EnrollmentFlow: enrollmentFlow, + }, nil +} + +func (s *SourceSyncer) deleteRemovedSources(ctx context.Context, client *Client, removed []map[string]any) error { + for _, pm := range removed { + pk := authentikPK(pm["authentik_pk"]) + if pk == "" { + continue + } + providerType, _ := pm["type"].(string) + switch providerType { + case "oauth": + _ = client.DeleteOAuthSource(ctx, pk) + case "saml": + _ = client.DeleteSAMLSource(ctx, pk) + case "ldap": + _ = client.DeleteLDAPSource(ctx, pk) + } + } + return nil +} + +func RemovedIdentityProviders(before, after map[string]any) []map[string]any { + beforeSection, _ := before["identity_providers"].(map[string]any) + afterSection, _ := after["identity_providers"].(map[string]any) + beforeProviders, _ := beforeSection["providers"].([]any) + afterProviders, _ := afterSection["providers"].([]any) + + afterIDs := make(map[string]struct{}, len(afterProviders)) + for _, item := range afterProviders { + pm, ok := item.(map[string]any) + if !ok { + continue + } + if id, _ := pm["id"].(string); id != "" { + afterIDs[id] = struct{}{} + } + } + + removed := make([]map[string]any, 0) + for _, item := range beforeProviders { + pm, ok := item.(map[string]any) + if !ok { + continue + } + id, _ := pm["id"].(string) + if id == "" { + continue + } + if _, stillPresent := afterIDs[id]; !stillPresent { + removed = append(removed, pm) + } + } + return removed +} + +func (s *SourceSyncer) markAllPending(ctx context.Context, idpSection map[string]any, reason string) error { + providers, _ := idpSection["providers"].([]any) + for _, item := range providers { + pm, ok := item.(map[string]any) + if !ok { + continue + } + pm["sync_status"] = "error" + pm["sync_error"] = reason + } + idpSection["providers"] = providers + stored, err := s.loadStoredPolicy(ctx) + if err != nil { + return err + } + stored["identity_providers"] = idpSection + return s.persistPolicy(ctx, stored) +} + +func (s *SourceSyncer) loadStoredPolicy(ctx context.Context) (map[string]any, error) { + var raw []byte + err := s.db.QueryRow(ctx, `SELECT settings FROM org_settings WHERE id = $1`, orgSettingsSingletonID).Scan(&raw) + if err != nil && err != pgx.ErrNoRows { + return nil, err + } + stored := map[string]any{} + if len(raw) > 0 { + if err := json.Unmarshal(raw, &stored); err != nil { + return nil, err + } + } + return stored, nil +} + +func (s *SourceSyncer) persistPolicy(ctx context.Context, policy map[string]any) error { + raw, err := json.Marshal(policy) + if err != nil { + return err + } + _, err = s.db.Exec(ctx, ` + UPDATE org_settings SET settings = $2, updated_at = NOW() WHERE id = $1 + `, orgSettingsSingletonID, raw) + return err +} + +func (s *SourceSyncer) RedirectURI(slug string) string { + return OAuthRedirectURI(AuthentikPublicBaseURL(s.cfg.AuthentikAPIURL, s.cfg.AuthentikPublicHTTPS), slug) +} + +func (s *SourceSyncer) TestProvider(ctx context.Context, providerID string) error { + stored, err := s.loadStoredPolicy(ctx) + if err != nil { + return err + } + pm, err := findProviderByID(stored, providerID) + if err != nil { + return err + } + client := s.client() + if client == nil { + return fmt.Errorf("Authentik API not configured") + } + if err := client.Ping(ctx); err != nil { + return err + } + providerType, _ := pm["type"].(string) + switch providerType { + case "ldap": + ldap, _ := pm["ldap"].(map[string]any) + if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" { + return fmt.Errorf("ldap server_uri required") + } + case "saml": + saml, _ := pm["saml"].(map[string]any) + if strings.TrimSpace(stringValue(saml["metadata_url"])) == "" && + strings.TrimSpace(stringValue(saml["sso_url"])) == "" { + return fmt.Errorf("saml metadata_url or sso_url required") + } + case "oauth": + oauth, _ := pm["oauth"].(map[string]any) + if strings.TrimSpace(stringValue(oauth["client_id"])) == "" { + return fmt.Errorf("oauth client_id required") + } + } + return nil +} + +func findProviderByID(stored map[string]any, id string) (map[string]any, error) { + idpSection, _ := stored["identity_providers"].(map[string]any) + providers, _ := idpSection["providers"].([]any) + for _, item := range providers { + pm, ok := item.(map[string]any) + if !ok { + continue + } + if pid, _ := pm["id"].(string); pid == id { + return pm, nil + } + } + return nil, fmt.Errorf("provider not found") +} + +func resolveProviderSecret(pm, fullStored map[string]any, section, key string) string { + nested, _ := pm[section].(map[string]any) + if val := strings.TrimSpace(stringValue(nested[key])); val != "" { + return val + } + id, _ := pm["id"].(string) + if id == "" { + return "" + } + existing, _ := fullStored["identity_providers"].(map[string]any) + existingProviders, _ := existing["providers"].([]any) + for _, item := range existingProviders { + em, ok := item.(map[string]any) + if !ok { + continue + } + if eid, _ := em["id"].(string); eid != id { + continue + } + esec, _ := em[section].(map[string]any) + return strings.TrimSpace(stringValue(esec[key])) + } + return "" +} + +func authentikPK(v any) string { + switch t := v.(type) { + case string: + return strings.TrimSpace(t) + case float64: + return strconv.Itoa(int(t)) + case int: + return strconv.Itoa(t) + case int64: + return strconv.Itoa(int(t)) + default: + return "" + } +} + +func stringValue(v any) string { + s, _ := v.(string) + return s +} + +func stringOr(v any, fallback string) string { + if s := strings.TrimSpace(stringValue(v)); s != "" { + return s + } + return fallback +} + +func boolValue(v any) bool { + b, ok := v.(bool) + return ok && b +} diff --git a/internal/authentik/sources.go b/internal/authentik/sources.go new file mode 100644 index 0000000..7a5080f --- /dev/null +++ b/internal/authentik/sources.go @@ -0,0 +1,322 @@ +package authentik + +import ( + "context" + "fmt" + "net/url" + "strings" +) + +type sourceRef struct { + PK string `json:"pk"` + Slug string `json:"slug"` + Name string `json:"name"` +} + +type brandRef struct { + PK string `json:"pk"` + BrandUUID string `json:"brand_uuid"` + Domain string `json:"domain"` + FlowEnrollment *string `json:"flow_enrollment"` +} + +type flowInstance struct { + PK string `json:"pk"` + Slug string `json:"slug"` + Name string `json:"name"` +} + +func (c *Client) FindOAuthSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) { + q := url.Values{} + q.Set("slug", slug) + var out listResponse[sourceRef] + if err := c.getJSON(ctx, "/api/v3/sources/oauth/?"+q.Encode(), &out); err != nil { + return nil, false, err + } + if len(out.Results) == 0 { + return nil, false, nil + } + return &out.Results[0], true, nil +} + +func (c *Client) FindSAMLSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) { + q := url.Values{} + q.Set("slug", slug) + var out listResponse[sourceRef] + if err := c.getJSON(ctx, "/api/v3/sources/saml/?"+q.Encode(), &out); err != nil { + return nil, false, err + } + if len(out.Results) == 0 { + return nil, false, nil + } + return &out.Results[0], true, nil +} + +func (c *Client) FindLDAPSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) { + q := url.Values{} + q.Set("slug", slug) + var out listResponse[sourceRef] + if err := c.getJSON(ctx, "/api/v3/sources/ldap/?"+q.Encode(), &out); err != nil { + return nil, false, err + } + if len(out.Results) == 0 { + return nil, false, nil + } + return &out.Results[0], true, nil +} + +type OAuthSourceRequest struct { + Name string + Slug string + Enabled bool + ProviderType string + ClientID string + ClientSecret string + AuthorizationURL string + AccessTokenURL string + ProfileURL string + Scopes string + AuthenticationFlow string + EnrollmentFlow string +} + +func (c *Client) CreateOAuthSource(ctx context.Context, req OAuthSourceRequest) (string, error) { + body := map[string]any{ + "name": req.Name, + "slug": req.Slug, + "enabled": req.Enabled, + "provider_type": req.ProviderType, + "consumer_key": req.ClientID, + "consumer_secret": req.ClientSecret, + "authorization_url": req.AuthorizationURL, + "access_token_url": req.AccessTokenURL, + "profile_url": req.ProfileURL, + "authentication_flow": req.AuthenticationFlow, + "enrollment_flow": req.EnrollmentFlow, + "user_matching_mode": "email_link", + "policy_engine_mode": "any", + "request_token_url": "", + "additional_scopes": req.Scopes, + } + var created sourceRef + if err := c.postJSON(ctx, "/api/v3/sources/oauth/", body, &created); err != nil { + return "", err + } + return created.PK, nil +} + +func (c *Client) UpdateOAuthSource(ctx context.Context, pk string, req OAuthSourceRequest) error { + body := map[string]any{ + "name": req.Name, + "slug": req.Slug, + "enabled": req.Enabled, + "provider_type": req.ProviderType, + "consumer_key": req.ClientID, + "authorization_url": req.AuthorizationURL, + "access_token_url": req.AccessTokenURL, + "profile_url": req.ProfileURL, + "authentication_flow": req.AuthenticationFlow, + "enrollment_flow": req.EnrollmentFlow, + "additional_scopes": req.Scopes, + } + if strings.TrimSpace(req.ClientSecret) != "" { + body["consumer_secret"] = req.ClientSecret + } + return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/oauth/%s/", pk), body) +} + +func (c *Client) DeleteOAuthSource(ctx context.Context, pk string) error { + return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/oauth/%s/", pk)) +} + +type SAMLSourceRequest struct { + Name string + Slug string + Enabled bool + MetadataURL string + MetadataXML string + EntityID string + SSOURL string + SLOURL string + SigningCert string + AuthenticationFlow string + EnrollmentFlow string +} + +func (c *Client) CreateSAMLSource(ctx context.Context, req SAMLSourceRequest) (string, error) { + body := map[string]any{ + "name": req.Name, + "slug": req.Slug, + "enabled": req.Enabled, + "metadata_url": req.MetadataURL, + "sso_url": req.SSOURL, + "slo_url": req.SLOURL, + "issuer": req.EntityID, + "authentication_flow": req.AuthenticationFlow, + "enrollment_flow": req.EnrollmentFlow, + "user_matching_mode": "email_link", + "policy_engine_mode": "any", + } + if strings.TrimSpace(req.MetadataXML) != "" { + body["metadata"] = req.MetadataXML + } + if strings.TrimSpace(req.SigningCert) != "" { + body["verification_kp"] = req.SigningCert + } + var created sourceRef + if err := c.postJSON(ctx, "/api/v3/sources/saml/", body, &created); err != nil { + return "", err + } + return created.PK, nil +} + +func (c *Client) UpdateSAMLSource(ctx context.Context, pk string, req SAMLSourceRequest) error { + body := map[string]any{ + "name": req.Name, + "slug": req.Slug, + "enabled": req.Enabled, + "metadata_url": req.MetadataURL, + "sso_url": req.SSOURL, + "slo_url": req.SLOURL, + "issuer": req.EntityID, + "authentication_flow": req.AuthenticationFlow, + "enrollment_flow": req.EnrollmentFlow, + } + if strings.TrimSpace(req.MetadataXML) != "" { + body["metadata"] = req.MetadataXML + } + if strings.TrimSpace(req.SigningCert) != "" { + body["verification_kp"] = req.SigningCert + } + return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/saml/%s/", pk), body) +} + +func (c *Client) DeleteSAMLSource(ctx context.Context, pk string) error { + return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/saml/%s/", pk)) +} + +type LDAPSourceRequest struct { + Name string + Slug string + Enabled bool + ServerURI string + BindDN string + BindPassword string + BaseDN string + UserFilter string + StartTLS bool + SyncUsers bool + AuthenticationFlow string + EnrollmentFlow string +} + +func (c *Client) CreateLDAPSource(ctx context.Context, req LDAPSourceRequest) (string, error) { + body := map[string]any{ + "name": req.Name, + "slug": req.Slug, + "enabled": req.Enabled, + "server_uri": req.ServerURI, + "bind_cn": req.BindDN, + "bind_password": req.BindPassword, + "base_dn": req.BaseDN, + "search_mode": "direct", + "start_tls": req.StartTLS, + "sync_users": req.SyncUsers, + "authentication_flow": req.AuthenticationFlow, + "enrollment_flow": req.EnrollmentFlow, + "user_matching_mode": "email_link", + "policy_engine_mode": "any", + } + if filter := strings.TrimSpace(req.UserFilter); filter != "" { + body["search_filter"] = filter + } + var created sourceRef + if err := c.postJSON(ctx, "/api/v3/sources/ldap/", body, &created); err != nil { + return "", err + } + return created.PK, nil +} + +func (c *Client) UpdateLDAPSource(ctx context.Context, pk string, req LDAPSourceRequest) error { + body := map[string]any{ + "name": req.Name, + "slug": req.Slug, + "enabled": req.Enabled, + "server_uri": req.ServerURI, + "bind_cn": req.BindDN, + "base_dn": req.BaseDN, + "start_tls": req.StartTLS, + "sync_users": req.SyncUsers, + "authentication_flow": req.AuthenticationFlow, + "enrollment_flow": req.EnrollmentFlow, + } + if filter := strings.TrimSpace(req.UserFilter); filter != "" { + body["search_filter"] = filter + } + if strings.TrimSpace(req.BindPassword) != "" { + body["bind_password"] = req.BindPassword + } + return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/ldap/%s/", pk), body) +} + +func (c *Client) DeleteLDAPSource(ctx context.Context, pk string) error { + return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/ldap/%s/", pk)) +} + +func (c *Client) deleteJSON(ctx context.Context, path string) error { + return c.doJSON(ctx, "DELETE", path, nil, nil) +} + +func (c *Client) FindDefaultBrand(ctx context.Context) (*brandRef, error) { + var out listResponse[brandRef] + if err := c.getJSON(ctx, "/api/v3/core/brands/", &out); err != nil { + return nil, err + } + for _, b := range out.Results { + if b.Domain == "authentik-default" || strings.Contains(strings.ToLower(b.Domain), "default") { + return &b, nil + } + } + if len(out.Results) > 0 { + return &out.Results[0], nil + } + return nil, fmt.Errorf("no brand found") +} + +func (c *Client) SetBrandEnrollmentFlow(ctx context.Context, brandPK string, flowPK *string) error { + body := map[string]any{} + if flowPK == nil || strings.TrimSpace(*flowPK) == "" { + body["flow_enrollment"] = nil + } else { + body["flow_enrollment"] = *flowPK + } + return c.patchJSON(ctx, fmt.Sprintf("/api/v3/core/brands/%s/", brandPK), body) +} + +func (c *Client) FindFlowInstanceBySlug(ctx context.Context, slug string) (*flowInstance, bool, error) { + q := url.Values{} + q.Set("slug", slug) + var out listResponse[flowInstance] + if err := c.getJSON(ctx, "/api/v3/flows/instances/?"+q.Encode(), &out); err != nil { + return nil, false, err + } + if len(out.Results) == 0 { + return nil, false, nil + } + return &out.Results[0], true, nil +} + +func (c *Client) ResolveSourceFlows(ctx context.Context) (authFlow, enrollmentFlow string, err error) { + authFlow, err = c.FindFlowBySlug(ctx, "default-source-authentication") + if err != nil { + authFlow, err = c.FindFlowBySlug(ctx, "default-authentication-flow") + } + if err != nil { + return "", "", err + } + enrollmentFlow, err = c.FindFlowBySlug(ctx, "default-source-enrollment") + if err != nil { + enrollmentFlow, _ = c.FindFlowBySlug(ctx, "default-enrollment-flow") + } + return authFlow, enrollmentFlow, nil +} diff --git a/internal/integrationtest/admin/org_settings_test.go b/internal/integrationtest/admin/org_settings_test.go index 4fbc629..424aef6 100644 --- a/internal/integrationtest/admin/org_settings_test.go +++ b/internal/integrationtest/admin/org_settings_test.go @@ -58,6 +58,78 @@ func TestAdminOrgSettings(t *testing.T) { } } +func TestAdminOrgSettingsIdentityProvidersSecrets(t *testing.T) { + h := integrationtest.RequireHarness(t) + adminClient, _ := integrationtest.RequireAdminClient(t, h) + + providerID := "test-oauth-provider" + putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{ + "policy": map[string]any{ + "identity_providers": map[string]any{ + "allow_self_enrollment": true, + "providers": []any{ + map[string]any{ + "id": providerID, + "name": "Google Workspace", + "slug": "google-workspace", + "type": "oauth", + "enabled": true, + "oauth": map[string]any{ + "provider": "google", + "client_id": "client-id", + "client_secret": "super-secret", + "scopes": "openid email profile", + }, + "allowed_email_domains": []any{"company.com"}, + }, + }, + }, + }, + }) + integrationtest.FailIf(err, t, "put identity providers") + integrationtest.FailUnlessStatus(t, putResp, 200) + + var afterPut map[string]any + integrationtest.DecodeJSON(t, putResp, &afterPut) + secrets, ok := afterPut["secrets"].(map[string]any) + if !ok { + t.Fatalf("missing secrets: %#v", afterPut) + } + idpSecrets, ok := secrets["identity_providers"].(map[string]any) + if !ok { + t.Fatalf("missing identity provider secrets: %#v", secrets) + } + providerSecrets, ok := idpSecrets[providerID].(map[string]any) + if !ok { + t.Fatalf("missing provider secrets entry: %#v", idpSecrets) + } + oauthSecret, ok := providerSecrets["oauth_client_secret"].(map[string]any) + if !ok || oauthSecret["configured"] != true { + t.Fatalf("oauth secret not configured: %#v", providerSecrets) + } + + policy, ok := afterPut["policy"].(map[string]any) + if !ok { + t.Fatalf("missing policy") + } + idpPolicy, ok := policy["identity_providers"].(map[string]any) + if !ok { + t.Fatalf("missing identity_providers policy") + } + providers, ok := idpPolicy["providers"].([]any) + if !ok || len(providers) != 1 { + t.Fatalf("expected one provider") + } + provider, ok := providers[0].(map[string]any) + if !ok { + t.Fatalf("invalid provider payload") + } + oauth, ok := provider["oauth"].(map[string]any) + if !ok || oauth["client_secret"] != "" { + t.Fatalf("client_secret should be masked on GET") + } +} + func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) { h := integrationtest.RequireHarness(t) adminClient, _ := integrationtest.RequireAdminClient(t, h) diff --git a/internal/nextcloud/public_share.go b/internal/nextcloud/public_share.go index 6af9b8a..49cea9d 100644 --- a/internal/nextcloud/public_share.go +++ b/internal/nextcloud/public_share.go @@ -314,6 +314,24 @@ func (c *Client) GetPublicSharePermissions(ctx context.Context, token, password return c.GetPublicSharePathPermissions(ctx, token, "/", password) } +// EffectivePublicSharePermissions returns share permissions for a path. +// Nextcloud often omits oc:permissions on nested WebDAV nodes; fall back to root share bits. +func (c *Client) EffectivePublicSharePermissions(ctx context.Context, token, relPath, password string) (int, error) { + root, err := c.GetPublicSharePermissions(ctx, token, password) + if err != nil { + return 0, err + } + relPath = NormalizeClientPath(relPath) + if relPath == "/" { + return root, nil + } + pathPerms, err := c.GetPublicSharePathPermissions(ctx, token, relPath, password) + if err != nil || pathPerms == 0 { + return root, nil + } + return root | pathPerms, nil +} + const propfindPublicRevisionBody = ` diff --git a/internal/orgpolicy/auth.go b/internal/orgpolicy/auth.go new file mode 100644 index 0000000..d66d9c4 --- /dev/null +++ b/internal/orgpolicy/auth.go @@ -0,0 +1,144 @@ +package orgpolicy + +import ( + "strings" + + "github.com/ultisuite/ulti-backend/internal/auth" +) + +type IdentityProviderPolicy struct { + ID string + Slug string + Type string + Enabled bool + AllowedEmailDomains []string + AllowedIdentities []string + AllowedOrganizations []string +} + +type AuthAccessPolicy struct { + AllowSelfEnrollment bool + Providers []IdentityProviderPolicy +} + +func (p AuthAccessPolicy) AllowsIdentity(email string, claims *auth.Claims) bool { + enabled := make([]IdentityProviderPolicy, 0) + for _, provider := range p.Providers { + if provider.Enabled { + enabled = append(enabled, provider) + } + } + if len(enabled) == 0 { + return true + } + + if claims != nil && strings.TrimSpace(claims.Source) != "" { + for _, provider := range enabled { + if provider.Slug == claims.Source && providerAllows(provider, email, claims) { + return true + } + } + } + + hasRestrictions := false + for _, provider := range enabled { + if providerHasRestrictions(provider) { + hasRestrictions = true + if providerAllows(provider, email, claims) { + return true + } + } + } + if !hasRestrictions { + return true + } + return false +} + +func providerHasRestrictions(provider IdentityProviderPolicy) bool { + return len(provider.AllowedEmailDomains) > 0 || + len(provider.AllowedIdentities) > 0 || + len(provider.AllowedOrganizations) > 0 +} + +func providerAllows(provider IdentityProviderPolicy, email string, claims *auth.Claims) bool { + if len(provider.AllowedIdentities) > 0 { + if !containsFold(provider.AllowedIdentities, email) { + return false + } + } + if len(provider.AllowedEmailDomains) > 0 { + if !emailDomainAllowed(email, provider.AllowedEmailDomains) { + return false + } + } + if len(provider.AllowedOrganizations) > 0 { + if claims == nil || !organizationAllowed(claims, provider.AllowedOrganizations) { + return false + } + } + return true +} + +func emailDomainAllowed(email string, domains []string) bool { + email = strings.ToLower(strings.TrimSpace(email)) + at := strings.LastIndex(email, "@") + if at < 0 { + return false + } + domain := email[at+1:] + for _, allowed := range domains { + allowed = strings.ToLower(strings.TrimSpace(strings.TrimPrefix(allowed, "@"))) + if allowed == "" { + continue + } + if domain == allowed || strings.HasSuffix(domain, "."+allowed) { + return true + } + } + return false +} + +func organizationAllowed(claims *auth.Claims, orgs []string) bool { + candidates := []string{ + claims.HD, + claims.TID, + claims.Org, + } + for _, org := range orgs { + org = strings.TrimSpace(org) + if org == "" { + continue + } + for _, candidate := range candidates { + if strings.EqualFold(strings.TrimSpace(candidate), org) { + return true + } + } + } + return false +} + +func containsFold(list []string, value string) bool { + value = strings.ToLower(strings.TrimSpace(value)) + for _, item := range list { + if strings.ToLower(strings.TrimSpace(item)) == value { + return true + } + } + return false +} + +func stringSlice(v any) []string { + raw, ok := v.([]any) + if !ok { + return nil + } + out := make([]string, 0, len(raw)) + for _, item := range raw { + if s, ok := item.(string); ok && strings.TrimSpace(s) != "" { + out = append(out, strings.TrimSpace(s)) + } + } + return out +} diff --git a/internal/orgpolicy/auth_test.go b/internal/orgpolicy/auth_test.go new file mode 100644 index 0000000..118c6e2 --- /dev/null +++ b/internal/orgpolicy/auth_test.go @@ -0,0 +1,72 @@ +package orgpolicy + +import ( + "testing" + + "github.com/ultisuite/ulti-backend/internal/auth" +) + +func TestAuthAccessPolicyAllowsOpenProviders(t *testing.T) { + policy := AuthAccessPolicy{ + Providers: []IdentityProviderPolicy{ + {Enabled: true, Slug: "google"}, + }, + } + claims := &auth.Claims{Email: "user@example.com"} + if !policy.AllowsIdentity(claims.Email, claims) { + t.Fatal("expected open provider to allow any identity") + } +} + +func TestAuthAccessPolicyRejectsUnknownDomain(t *testing.T) { + policy := AuthAccessPolicy{ + Providers: []IdentityProviderPolicy{ + { + Enabled: true, + Slug: "google", + AllowedEmailDomains: []string{"company.com"}, + }, + }, + } + claims := &auth.Claims{Email: "user@gmail.com"} + if policy.AllowsIdentity(claims.Email, claims) { + t.Fatal("expected domain restriction to reject identity") + } +} + +func TestAuthAccessPolicyAllowsMatchingOrganization(t *testing.T) { + policy := AuthAccessPolicy{ + Providers: []IdentityProviderPolicy{ + { + Enabled: true, + Slug: "google", + AllowedOrganizations: []string{"company.com"}, + }, + }, + } + claims := &auth.Claims{Email: "user@company.com", HD: "company.com"} + if !policy.AllowsIdentity(claims.Email, claims) { + t.Fatal("expected matching hosted domain to allow identity") + } +} + +func TestAuthAccessPolicyMatchesSourceSpecificProvider(t *testing.T) { + policy := AuthAccessPolicy{ + Providers: []IdentityProviderPolicy{ + { + Enabled: true, + Slug: "corp-google", + AllowedEmailDomains: []string{"company.com"}, + }, + { + Enabled: true, + Slug: "partner-google", + AllowedEmailDomains: []string{"partner.org"}, + }, + }, + } + claims := &auth.Claims{Email: "user@partner.org", Source: "partner-google"} + if !policy.AllowsIdentity(claims.Email, claims) { + t.Fatal("expected source-specific provider rules to allow identity") + } +} diff --git a/internal/orgpolicy/loader.go b/internal/orgpolicy/loader.go index 9514834..2d2239f 100644 --- a/internal/orgpolicy/loader.go +++ b/internal/orgpolicy/loader.go @@ -28,10 +28,12 @@ type Loader struct { db *pgxpool.Pool cfg *config.Config - mu sync.Mutex - cached FilePolicies - cachedAt time.Time - ttl time.Duration + mu sync.Mutex + cached FilePolicies + cachedAt time.Time + authCached AuthAccessPolicy + authCachedAt time.Time + ttl time.Duration } func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader { @@ -120,6 +122,77 @@ func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) { return true, fp.VirusTotalAPIKey, nil } +func (l *Loader) AuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) { + l.mu.Lock() + if !l.authCachedAt.IsZero() && time.Since(l.authCachedAt) < l.ttl { + out := l.authCached + l.mu.Unlock() + return out, nil + } + l.mu.Unlock() + + policy, err := l.loadAuthAccessPolicy(ctx) + if err != nil { + return AuthAccessPolicy{}, err + } + + l.mu.Lock() + l.authCached = policy + l.authCachedAt = time.Now() + l.mu.Unlock() + return policy, nil +} + +func (l *Loader) loadAuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) { + var raw []byte + err := l.db.QueryRow(ctx, ` + SELECT settings FROM org_settings WHERE id = $1 + `, orgSettingsSingletonID).Scan(&raw) + if err != nil && err != pgx.ErrNoRows { + return AuthAccessPolicy{}, err + } + + stored := map[string]any{} + if len(raw) > 0 { + if err := json.Unmarshal(raw, &stored); err != nil { + return AuthAccessPolicy{}, err + } + } + + idp, _ := stored["identity_providers"].(map[string]any) + if idp == nil { + return AuthAccessPolicy{AllowSelfEnrollment: true}, nil + } + + allowSelfEnrollment := true + if v, ok := idp["allow_self_enrollment"].(bool); ok { + allowSelfEnrollment = v + } + + providersRaw, _ := idp["providers"].([]any) + providers := make([]IdentityProviderPolicy, 0, len(providersRaw)) + for _, item := range providersRaw { + pm, ok := item.(map[string]any) + if !ok { + continue + } + providers = append(providers, IdentityProviderPolicy{ + ID: stringValue(pm["id"]), + Slug: stringValue(pm["slug"]), + Type: stringValue(pm["type"]), + Enabled: boolValue(pm["enabled"]), + AllowedEmailDomains: stringSlice(pm["allowed_email_domains"]), + AllowedIdentities: stringSlice(pm["allowed_identities"]), + AllowedOrganizations: stringSlice(pm["allowed_organizations"]), + }) + } + + return AuthAccessPolicy{ + AllowSelfEnrollment: allowSelfEnrollment, + Providers: providers, + }, nil +} + func boolValue(v any) bool { switch t := v.(type) { case bool: diff --git a/internal/server/bootstrap.go b/internal/server/bootstrap.go index 07f7468..e47e818 100644 --- a/internal/server/bootstrap.go +++ b/internal/server/bootstrap.go @@ -284,7 +284,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { JWTSecret: cfg.OnlyOfficeJWTSecret, }) officeHandler := office.NewHandler(officeSvc, driveSvc) - r.Mount("/api/v1/office", officeHandler.Routes(middleware.Auth(verifierHolder, pool, auditLogger))) + r.Mount("/api/v1/office", officeHandler.Routes(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader))) driveHandler.SetPublicOffice(officeHandler) } if driveHandler != nil { @@ -292,7 +292,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { } r.Group(func(r chi.Router) { - r.Use(middleware.Auth(verifierHolder, pool, auditLogger)) + r.Use(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader)) r.Use(middleware.EnforceApiTokenPolicy()) r.Mount("/api/v1/users", usersapi.NewHandler(pool).Routes())