feat(identity-providers): add management for identity providers in admin API
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run

- Introduced new endpoints for managing identity providers, including retrieval of redirect URIs and testing/syncing providers.
- Enhanced organization settings to include identity provider configurations, allowing for self-enrollment and domain restrictions.
- Implemented caching for access policies and added validation for identity provider secrets.
- Added integration tests to ensure proper functionality of identity provider management and policy enforcement.
This commit is contained in:
R3D347HR4Y 2026-06-09 09:36:38 +02:00
parent b90edf317c
commit d3c930cac6
17 changed files with 1627 additions and 11 deletions

View File

@ -60,6 +60,10 @@ func (h *Handler) Routes() chi.Router {
r.With(read).Get("/org/settings", h.GetOrgSettings)
r.With(write).Put("/org/settings", h.PutOrgSettings)
r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI)
r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider)
r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider)
return r
}

View File

@ -0,0 +1,98 @@
package admin
import (
"context"
"fmt"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
"github.com/ultisuite/ulti-backend/internal/api/middleware"
"github.com/ultisuite/ulti-backend/internal/authentik"
)
func (s *Service) SyncIdentityProviders(ctx context.Context, actorSub, providerID string) (map[string]any, error) {
stored, _, _, err := s.loadOrgPolicyRaw(ctx)
if err != nil {
return nil, err
}
idpSection, _ := stored["identity_providers"].(map[string]any)
if idpSection == nil {
return nil, fmt.Errorf("identity_providers not configured")
}
providers, _ := idpSection["providers"].([]any)
found := false
for _, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
if id, _ := pm["id"].(string); id == providerID {
found = true
break
}
}
if !found {
return nil, fmt.Errorf("provider not found")
}
syncer := authentik.NewSourceSyncer(s.db, s.cfg)
if err := syncer.SyncSection(ctx, idpSection); err != nil {
return nil, err
}
s.logAudit(ctx, actorSub, "sync_identity_provider", map[string]any{
"provider_id": providerID,
})
return s.GetOrgSettings(ctx)
}
func (h *Handler) GetIdentityProviderRedirectURI(w http.ResponseWriter, r *http.Request) {
slug := chi.URLParam(r, "slug")
if slug == "" {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
apivalidate.FieldDetail{Field: "slug", Message: "required"},
))
return
}
syncer := authentik.NewSourceSyncer(h.svc.db, h.svc.cfg)
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
"slug": slug,
"redirect_uri": syncer.RedirectURI(slug),
})
}
func (h *Handler) TestIdentityProvider(w http.ResponseWriter, r *http.Request) {
providerID := chi.URLParam(r, "providerID")
if providerID == "" {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
apivalidate.FieldDetail{Field: "providerID", Message: "required"},
))
return
}
syncer := authentik.NewSourceSyncer(h.svc.db, h.svc.cfg)
if err := syncer.TestProvider(r.Context(), providerID); err != nil {
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, err.Error(), nil)
return
}
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
}
func (h *Handler) SyncIdentityProvider(w http.ResponseWriter, r *http.Request) {
providerID := chi.URLParam(r, "providerID")
if providerID == "" {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
apivalidate.FieldDetail{Field: "providerID", Message: "required"},
))
return
}
claims := middleware.ClaimsFromContext(r.Context())
payload, err := h.svc.SyncIdentityProviders(r.Context(), claims.Sub, providerID)
if err != nil {
h.logger.Error("sync identity provider", "error", err, "provider_id", providerID)
apivalidate.WriteInternal(w, r)
return
}
apiresponse.WriteJSON(w, http.StatusOK, payload)
}

View File

@ -9,6 +9,7 @@ import (
"github.com/jackc/pgx/v5"
"github.com/ultisuite/ulti-backend/internal/config"
"github.com/ultisuite/ulti-backend/internal/authentik"
)
const orgSettingsSingletonID = 1
@ -24,6 +25,11 @@ func defaultOrgPolicy() map[string]any {
"allow_password_fallback": false,
"default_groups": "ulti-users",
},
"identity_providers": map[string]any{
"allow_self_enrollment": true,
"default_login_source": "",
"providers": []any{},
},
"two_factor": map[string]any{
"required_for_all": false,
"required_for_admins": true,
@ -204,6 +210,9 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any {
mergeSearchProviderSecrets(existing, patchWS, merged)
}
}
if patchIDP, ok := patch["identity_providers"].(map[string]any); ok {
mergeIdentityProviderSecrets(existing, patchIDP, merged)
}
return merged
}
@ -291,6 +300,65 @@ func mergeSearchProviderSecrets(existing map[string]any, patchWS map[string]any,
merged["search"] = mergedSearch
}
func mergeIdentityProviderSecrets(existing, patchIDP, merged map[string]any) {
patchProviders, _ := patchIDP["providers"].([]any)
if len(patchProviders) == 0 {
return
}
existingIDP, _ := existing["identity_providers"].(map[string]any)
existingProviders, _ := existingIDP["providers"].([]any)
mergedIDP, _ := merged["identity_providers"].(map[string]any)
if mergedIDP == nil {
return
}
mergedProviders := make([]any, len(patchProviders))
for i, item := range patchProviders {
pm, ok := item.(map[string]any)
if !ok {
mergedProviders[i] = item
continue
}
id, _ := pm["id"].(string)
for _, secretPath := range []struct{ section, key string }{
{"oauth", "client_secret"},
{"ldap", "bind_password"},
{"saml", "signing_cert"},
} {
mergeProviderSecretField(pm, existingProviders, id, secretPath.section, secretPath.key)
}
mergedProviders[i] = pm
}
mergedIDP["providers"] = mergedProviders
merged["identity_providers"] = mergedIDP
}
func mergeProviderSecretField(pm map[string]any, existingProviders []any, id, section, key string) {
nested, _ := pm[section].(map[string]any)
if nested == nil {
return
}
val, _ := nested[key].(string)
if strings.TrimSpace(val) != "" || id == "" {
return
}
for _, ep := range existingProviders {
em, ok := ep.(map[string]any)
if !ok {
continue
}
if eid, _ := em["id"].(string); eid != id {
continue
}
esec, _ := em[section].(map[string]any)
if esec == nil {
continue
}
if ek, ok := esec[key].(string); ok && ek != "" {
nested[key] = ek
}
}
}
func maskOrgPolicy(policy map[string]any) map[string]any {
cloned := deepCloneMap(policy)
maskStringField(cloned, "nextcloud", "admin_password")
@ -331,9 +399,41 @@ func maskOrgPolicy(policy map[string]any) map[string]any {
}
}
}
maskIdentityProviderSecrets(cloned)
return cloned
}
func maskIdentityProviderSecrets(policy map[string]any) {
idp, ok := policy["identity_providers"].(map[string]any)
if !ok {
return
}
providers, ok := idp["providers"].([]any)
if !ok {
return
}
for i, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
for _, section := range []string{"oauth", "ldap", "saml"} {
nested, ok := pm[section].(map[string]any)
if !ok {
continue
}
for _, key := range []string{"client_secret", "bind_password", "signing_cert"} {
if v, _ := nested[key].(string); strings.TrimSpace(v) != "" {
nested[key] = ""
}
}
pm[section] = nested
}
providers[i] = pm
}
idp["providers"] = providers
}
func maskStringField(m map[string]any, section, key string) {
sec, ok := m[section].(map[string]any)
if !ok {
@ -361,7 +461,7 @@ func secretConfigured(policy map[string]any, section, key string) bool {
}
func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string]any {
return map[string]any{
secrets := map[string]any{
"nextcloud_admin_password": map[string]any{
"configured": secretConfigured(policy, "nextcloud", "admin_password") || strings.TrimSpace(cfg.NCAdminPass) != "",
},
@ -381,10 +481,57 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string
"configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "",
},
}
if idpSecrets := buildIdentityProviderSecretsStatus(policy); len(idpSecrets) > 0 {
secrets["identity_providers"] = idpSecrets
}
return secrets
}
func buildIdentityProviderSecretsStatus(policy map[string]any) map[string]any {
idp, ok := policy["identity_providers"].(map[string]any)
if !ok {
return nil
}
providers, ok := idp["providers"].([]any)
if !ok {
return nil
}
out := make(map[string]any)
for _, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
id, _ := pm["id"].(string)
if id == "" {
continue
}
entry := map[string]any{}
if oauth, ok := pm["oauth"].(map[string]any); ok {
if v, _ := oauth["client_secret"].(string); strings.TrimSpace(v) != "" {
entry["oauth_client_secret"] = map[string]any{"configured": true}
}
}
if ldap, ok := pm["ldap"].(map[string]any); ok {
if v, _ := ldap["bind_password"].(string); strings.TrimSpace(v) != "" {
entry["ldap_bind_password"] = map[string]any{"configured": true}
}
}
if saml, ok := pm["saml"].(map[string]any); ok {
if v, _ := saml["signing_cert"].(string); strings.TrimSpace(v) != "" {
entry["saml_signing_cert"] = map[string]any{"configured": true}
}
}
if len(entry) > 0 {
out[id] = entry
}
}
return out
}
func buildOrgEffective(cfg *config.Config) map[string]any {
authentikEnabled := strings.TrimSpace(cfg.AuthentikAPIToken) != "" || strings.TrimSpace(cfg.OIDCIssuer) != ""
publicBase := authentik.AuthentikPublicBaseURL(cfg.AuthentikAPIURL, cfg.AuthentikPublicHTTPS)
return map[string]any{
"authentik": map[string]any{
"enabled": authentikEnabled,
@ -392,6 +539,10 @@ func buildOrgEffective(cfg *config.Config) map[string]any {
"client_id": cfg.OIDCClientID,
"issuer": cfg.OIDCIssuer,
},
"identity_providers": map[string]any{
"authentik_public_url": publicBase,
"oauth_redirect_template": publicBase + "/source/oauth/callback/{slug}/",
},
"nextcloud": map[string]any{
"enabled": cfg.NextcloudEnabled,
"base_url": firstNonEmpty(cfg.NextcloudPublicURL, cfg.NextcloudURL),
@ -445,11 +596,11 @@ func (s *Service) GetOrgSettings(ctx context.Context) (map[string]any, error) {
}
func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map[string]any) (map[string]any, error) {
stored, _, _, err := s.loadOrgPolicyRaw(ctx)
storedBefore, _, _, err := s.loadOrgPolicyRaw(ctx)
if err != nil {
return nil, err
}
merged := mergeOrgSecrets(stored, patch)
merged := mergeOrgSecrets(storedBefore, patch)
raw, err := json.Marshal(merged)
if err != nil {
return nil, err
@ -468,6 +619,16 @@ func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map
s.logAudit(ctx, actorSub, "update_org_settings", map[string]any{
"sections": mapKeys(patch),
})
if _, ok := patch["identity_providers"]; ok {
removed := authentik.RemovedIdentityProviders(storedBefore, merged)
syncer := authentik.NewSourceSyncer(s.db, s.cfg)
idpSection, _ := merged["identity_providers"].(map[string]any)
if err := syncer.SyncSectionWithRemoved(ctx, idpSection, removed); err != nil {
s.logger.Warn("identity provider sync failed", "error", err)
}
}
return s.GetOrgSettings(ctx)
}

View File

@ -8,6 +8,7 @@ const (
CodeAuthInvalidToken = "auth.invalid_token"
CodeAuthUnauthorized = "auth.unauthorized"
CodeAuthForbidden = "auth.forbidden"
CodeIdentityNotAllowed = "auth.identity_not_allowed"
CodeInvalidQueryParam = "invalid_query_param"
CodeInvalidRequest = "invalid_request_body"

View File

@ -11,6 +11,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/apitokens"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
"github.com/ultisuite/ulti-backend/internal/permission"
"github.com/ultisuite/ulti-backend/internal/securityaudit"
"github.com/ultisuite/ulti-backend/internal/users"
@ -23,7 +24,7 @@ const (
apiTokenKey ctxKey = "api_token"
)
func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger) func(http.Handler) http.Handler {
func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger, orgPolicy *orgpolicy.Loader) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
header := r.Header.Get("Authorization")
@ -148,6 +149,23 @@ func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger)
}
return
}
if orgPolicy != nil {
policy, err := orgPolicy.AuthAccessPolicy(r.Context())
if err != nil {
slog.Error("load auth access policy", "sub", claims.Sub, "error", err)
} else if !policy.AllowsIdentity(claims.Email, claims) {
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeIdentityNotAllowed, "identity not allowed by organization policy", nil)
if audit != nil {
audit.Log(r.Context(), claims.Sub, securityaudit.ActionTokenRejected, map[string]any{
"reason": "identity_not_allowed",
"email": claims.Email,
"path": r.URL.Path,
"method": r.Method,
})
}
return
}
}
} else {
claims.Groups = permission.WithSuiteDefaults(claims.Groups)
}

View File

@ -56,7 +56,7 @@ func (h *Handler) PublicShareSession(w http.ResponseWriter, r *http.Request) {
if password == "" {
password = publicSharePassword(r)
}
perms, err := h.svc.nc.GetPublicSharePathPermissions(r.Context(), token, req.Path, password)
perms, err := h.svc.nc.EffectivePublicSharePermissions(r.Context(), token, req.Path, password)
if err != nil {
apivalidate.WriteInternal(w, r)
return
@ -81,6 +81,7 @@ func (h *Handler) PublicShareSession(w http.ResponseWriter, r *http.Request) {
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
"config": cfg,
"serverUrl": h.svc.PublicURL(),
"mode": mode,
})
}

View File

@ -16,6 +16,10 @@ type Claims struct {
Email string
Name string
Groups []string
Source string
HD string
TID string
Org string
}
type Verifier struct {
@ -96,6 +100,10 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error)
Email string `json:"email"`
Name string `json:"name"`
Groups []string `json:"groups"`
HD string `json:"hd"`
TID string `json:"tid"`
Org string `json:"org"`
Source string `json:"ak-source"`
}
if err := token.Claims(&claims); err != nil {
return nil, err
@ -106,5 +114,9 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error)
Email: claims.Email,
Name: claims.Name,
Groups: claims.Groups,
HD: claims.HD,
TID: claims.TID,
Org: claims.Org,
Source: claims.Source,
}, nil
}

View File

@ -0,0 +1,82 @@
package authentik
import "strings"
type OAuthPreset struct {
ProviderType string
AuthorizationURL string
AccessTokenURL string
ProfileURL string
DefaultScopes string
OrganizationClaim string
}
func OAuthPresetFor(provider string) OAuthPreset {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "google":
return OAuthPreset{
ProviderType: "google",
AuthorizationURL: "https://accounts.google.com/o/oauth2/auth",
AccessTokenURL: "https://oauth2.googleapis.com/token",
ProfileURL: "https://www.googleapis.com/oauth2/v1/userinfo",
DefaultScopes: "openid email profile",
OrganizationClaim: "hd",
}
case "github":
return OAuthPreset{
ProviderType: "github",
AuthorizationURL: "https://github.com/login/oauth/authorize",
AccessTokenURL: "https://github.com/login/oauth/access_token",
ProfileURL: "https://api.github.com/user",
DefaultScopes: "read:user user:email",
OrganizationClaim: "org",
}
case "linkedin":
return OAuthPreset{
ProviderType: "openidconnect",
AuthorizationURL: "https://www.linkedin.com/oauth/v2/authorization",
AccessTokenURL: "https://www.linkedin.com/oauth/v2/accessToken",
ProfileURL: "https://api.linkedin.com/v2/userinfo",
DefaultScopes: "openid profile email",
OrganizationClaim: "",
}
case "microsoft":
return OAuthPreset{
ProviderType: "azuread",
AuthorizationURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
AccessTokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
ProfileURL: "https://graph.microsoft.com/oidc/userinfo",
DefaultScopes: "openid email profile",
OrganizationClaim: "tid",
}
default:
return OAuthPreset{
ProviderType: "openidconnect",
DefaultScopes: "openid email profile",
}
}
}
func OAuthRedirectURI(publicBaseURL, slug string) string {
base := strings.TrimRight(strings.TrimSpace(publicBaseURL), "/")
if base == "" {
base = "http://localhost/auth"
}
return base + "/source/oauth/callback/" + strings.TrimSpace(slug) + "/"
}
func AuthentikPublicBaseURL(apiURL string, publicHTTPS bool) string {
apiURL = strings.TrimSpace(apiURL)
if apiURL == "" {
if publicHTTPS {
return "https://localhost/auth"
}
return "http://localhost/auth"
}
u := strings.TrimSuffix(apiURL, "/")
u = strings.TrimSuffix(u, "/api/v3")
if publicHTTPS && strings.HasPrefix(u, "http://") {
u = "https://" + strings.TrimPrefix(u, "http://")
}
return u
}

View File

@ -0,0 +1,46 @@
package authentik
import "testing"
func TestOAuthPresetForGoogle(t *testing.T) {
preset := OAuthPresetFor("google")
if preset.ProviderType != "google" {
t.Fatalf("provider type = %q", preset.ProviderType)
}
if preset.AuthorizationURL == "" || preset.AccessTokenURL == "" {
t.Fatal("expected google oauth endpoints")
}
}
func TestOAuthRedirectURI(t *testing.T) {
got := OAuthRedirectURI("http://localhost/auth", "google-workspace")
want := "http://localhost/auth/source/oauth/callback/google-workspace/"
if got != want {
t.Fatalf("redirect uri = %q, want %q", got, want)
}
}
func TestRemovedIdentityProviders(t *testing.T) {
before := map[string]any{
"identity_providers": map[string]any{
"providers": []any{
map[string]any{"id": "a", "authentik_pk": 1, "type": "oauth"},
map[string]any{"id": "b", "authentik_pk": 2, "type": "saml"},
},
},
}
after := map[string]any{
"identity_providers": map[string]any{
"providers": []any{
map[string]any{"id": "a", "authentik_pk": 1, "type": "oauth"},
},
},
}
removed := RemovedIdentityProviders(before, after)
if len(removed) != 1 {
t.Fatalf("removed count = %d, want 1", len(removed))
}
if id, _ := removed[0]["id"].(string); id != "b" {
t.Fatalf("removed id = %q", id)
}
}

View File

@ -0,0 +1,492 @@
package authentik
import (
"context"
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/config"
)
const orgSettingsSingletonID = 1
type SourceSyncer struct {
db *pgxpool.Pool
cfg *config.Config
}
func NewSourceSyncer(db *pgxpool.Pool, cfg *config.Config) *SourceSyncer {
return &SourceSyncer{db: db, cfg: cfg}
}
func (s *SourceSyncer) client() *Client {
apiURL := strings.TrimSpace(s.cfg.AuthentikAPIURL)
token := strings.TrimSpace(s.cfg.AuthentikAPIToken)
if apiURL == "" || token == "" {
return nil
}
return NewClient(apiURL, token)
}
func (s *SourceSyncer) SyncFromStoredPolicy(ctx context.Context) error {
stored, err := s.loadStoredPolicy(ctx)
if err != nil {
return err
}
idpSection, _ := stored["identity_providers"].(map[string]any)
if idpSection == nil {
return nil
}
return s.syncSection(ctx, idpSection, stored, nil)
}
func (s *SourceSyncer) SyncSection(ctx context.Context, idpSection map[string]any) error {
stored, err := s.loadStoredPolicy(ctx)
if err != nil {
return err
}
return s.syncSection(ctx, idpSection, stored, nil)
}
func (s *SourceSyncer) SyncSectionWithRemoved(ctx context.Context, idpSection map[string]any, removed []map[string]any) error {
stored, err := s.loadStoredPolicy(ctx)
if err != nil {
return err
}
return s.syncSection(ctx, idpSection, stored, removed)
}
func (s *SourceSyncer) syncSection(ctx context.Context, idpSection, fullStored map[string]any, removed []map[string]any) error {
client := s.client()
if client == nil {
return s.markAllPending(ctx, idpSection, "Authentik API not configured")
}
if err := client.Ping(ctx); err != nil {
return s.markAllPending(ctx, idpSection, err.Error())
}
authFlow, enrollmentFlow, err := client.ResolveSourceFlows(ctx)
if err != nil {
return s.markAllPending(ctx, idpSection, err.Error())
}
allowSelfEnrollment, _ := idpSection["allow_self_enrollment"].(bool)
if err := s.syncEnrollment(ctx, client, allowSelfEnrollment); err != nil {
return err
}
providers, _ := idpSection["providers"].([]any)
now := time.Now().UTC().Format(time.RFC3339)
for _, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
id, _ := pm["id"].(string)
if id == "" {
continue
}
pk, syncErr := s.syncProvider(ctx, client, pm, authFlow, enrollmentFlow, fullStored)
if syncErr != nil {
pm["sync_status"] = "error"
pm["sync_error"] = syncErr.Error()
} else {
pm["sync_status"] = "synced"
pm["sync_error"] = ""
pm["last_synced_at"] = now
if pk != "" {
if n, err := strconv.Atoi(pk); err == nil {
pm["authentik_pk"] = n
} else {
pm["authentik_pk"] = pk
}
}
}
}
if err := s.deleteRemovedSources(ctx, client, removed); err != nil {
return err
}
idpSection["providers"] = providers
fullStored["identity_providers"] = idpSection
return s.persistPolicy(ctx, fullStored)
}
func (s *SourceSyncer) syncEnrollment(ctx context.Context, client *Client, allow bool) error {
brand, err := client.FindDefaultBrand(ctx)
if err != nil {
return err
}
if allow {
flow, found, err := client.FindFlowInstanceBySlug(ctx, "ulti-enrollment")
if err != nil {
return err
}
if !found {
return nil
}
return client.SetBrandEnrollmentFlow(ctx, brand.PK, &flow.PK)
}
return client.SetBrandEnrollmentFlow(ctx, brand.PK, nil)
}
func (s *SourceSyncer) syncProvider(ctx context.Context, client *Client, pm map[string]any, authFlow, enrollmentFlow string, fullStored map[string]any) (string, error) {
providerType, _ := pm["type"].(string)
enabled, _ := pm["enabled"].(bool)
slug, _ := pm["slug"].(string)
name, _ := pm["name"].(string)
if slug == "" {
return "", fmt.Errorf("missing slug")
}
if name == "" {
name = slug
}
pk := authentikPK(pm["authentik_pk"])
switch providerType {
case "oauth":
req, err := buildOAuthSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
if err != nil {
return "", err
}
if pk == "" {
createdPK, err := client.CreateOAuthSource(ctx, req)
if err != nil {
return "", err
}
return createdPK, nil
}
return pk, client.UpdateOAuthSource(ctx, pk, req)
case "saml":
req := buildSAMLSourceRequest(pm, authFlow, enrollmentFlow, enabled, slug, name)
if pk == "" {
createdPK, err := client.CreateSAMLSource(ctx, req)
if err != nil {
return "", err
}
return createdPK, nil
}
return pk, client.UpdateSAMLSource(ctx, pk, req)
case "ldap":
req, err := buildLDAPSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
if err != nil {
return "", err
}
if pk == "" {
createdPK, err := client.CreateLDAPSource(ctx, req)
if err != nil {
return "", err
}
return createdPK, nil
}
return pk, client.UpdateLDAPSource(ctx, pk, req)
default:
return "", fmt.Errorf("unsupported provider type %q", providerType)
}
}
func buildOAuthSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (OAuthSourceRequest, error) {
oauth, _ := pm["oauth"].(map[string]any)
providerName, _ := oauth["provider"].(string)
preset := OAuthPresetFor(providerName)
clientID, _ := oauth["client_id"].(string)
clientSecret := resolveProviderSecret(pm, fullStored, "oauth", "client_secret")
authURL := stringOr(oauth["authorization_url"], preset.AuthorizationURL)
tokenURL := stringOr(oauth["token_url"], preset.AccessTokenURL)
profileURL := stringOr(oauth["profile_url"], preset.ProfileURL)
scopes := stringOr(oauth["scopes"], preset.DefaultScopes)
providerType := preset.ProviderType
if providerName == "custom" {
providerType = "openidconnect"
}
if strings.TrimSpace(clientID) == "" {
return OAuthSourceRequest{}, fmt.Errorf("oauth client_id required")
}
if strings.TrimSpace(clientSecret) == "" {
return OAuthSourceRequest{}, fmt.Errorf("oauth client_secret required")
}
return OAuthSourceRequest{
Name: name,
Slug: slug,
Enabled: enabled,
ProviderType: providerType,
ClientID: clientID,
ClientSecret: clientSecret,
AuthorizationURL: authURL,
AccessTokenURL: tokenURL,
ProfileURL: profileURL,
Scopes: scopes,
AuthenticationFlow: authFlow,
EnrollmentFlow: enrollmentFlow,
}, nil
}
func buildSAMLSourceRequest(pm map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) SAMLSourceRequest {
saml, _ := pm["saml"].(map[string]any)
return SAMLSourceRequest{
Name: name,
Slug: slug,
Enabled: enabled,
MetadataURL: stringValue(saml["metadata_url"]),
MetadataXML: stringValue(saml["metadata_xml"]),
EntityID: stringValue(saml["entity_id"]),
SSOURL: stringValue(saml["sso_url"]),
SLOURL: stringValue(saml["slo_url"]),
SigningCert: stringValue(saml["signing_cert"]),
AuthenticationFlow: authFlow,
EnrollmentFlow: enrollmentFlow,
}
}
func buildLDAPSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (LDAPSourceRequest, error) {
ldap, _ := pm["ldap"].(map[string]any)
bindPassword := resolveProviderSecret(pm, fullStored, "ldap", "bind_password")
if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
return LDAPSourceRequest{}, fmt.Errorf("ldap server_uri required")
}
if strings.TrimSpace(stringValue(ldap["bind_dn"])) == "" {
return LDAPSourceRequest{}, fmt.Errorf("ldap bind_dn required")
}
if strings.TrimSpace(bindPassword) == "" {
return LDAPSourceRequest{}, fmt.Errorf("ldap bind_password required")
}
return LDAPSourceRequest{
Name: name,
Slug: slug,
Enabled: enabled,
ServerURI: stringValue(ldap["server_uri"]),
BindDN: stringValue(ldap["bind_dn"]),
BindPassword: bindPassword,
BaseDN: stringValue(ldap["base_dn"]),
UserFilter: stringValue(ldap["user_filter"]),
StartTLS: boolValue(ldap["start_tls"]),
SyncUsers: boolValue(ldap["sync_users"]),
AuthenticationFlow: authFlow,
EnrollmentFlow: enrollmentFlow,
}, nil
}
func (s *SourceSyncer) deleteRemovedSources(ctx context.Context, client *Client, removed []map[string]any) error {
for _, pm := range removed {
pk := authentikPK(pm["authentik_pk"])
if pk == "" {
continue
}
providerType, _ := pm["type"].(string)
switch providerType {
case "oauth":
_ = client.DeleteOAuthSource(ctx, pk)
case "saml":
_ = client.DeleteSAMLSource(ctx, pk)
case "ldap":
_ = client.DeleteLDAPSource(ctx, pk)
}
}
return nil
}
func RemovedIdentityProviders(before, after map[string]any) []map[string]any {
beforeSection, _ := before["identity_providers"].(map[string]any)
afterSection, _ := after["identity_providers"].(map[string]any)
beforeProviders, _ := beforeSection["providers"].([]any)
afterProviders, _ := afterSection["providers"].([]any)
afterIDs := make(map[string]struct{}, len(afterProviders))
for _, item := range afterProviders {
pm, ok := item.(map[string]any)
if !ok {
continue
}
if id, _ := pm["id"].(string); id != "" {
afterIDs[id] = struct{}{}
}
}
removed := make([]map[string]any, 0)
for _, item := range beforeProviders {
pm, ok := item.(map[string]any)
if !ok {
continue
}
id, _ := pm["id"].(string)
if id == "" {
continue
}
if _, stillPresent := afterIDs[id]; !stillPresent {
removed = append(removed, pm)
}
}
return removed
}
func (s *SourceSyncer) markAllPending(ctx context.Context, idpSection map[string]any, reason string) error {
providers, _ := idpSection["providers"].([]any)
for _, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
pm["sync_status"] = "error"
pm["sync_error"] = reason
}
idpSection["providers"] = providers
stored, err := s.loadStoredPolicy(ctx)
if err != nil {
return err
}
stored["identity_providers"] = idpSection
return s.persistPolicy(ctx, stored)
}
func (s *SourceSyncer) loadStoredPolicy(ctx context.Context) (map[string]any, error) {
var raw []byte
err := s.db.QueryRow(ctx, `SELECT settings FROM org_settings WHERE id = $1`, orgSettingsSingletonID).Scan(&raw)
if err != nil && err != pgx.ErrNoRows {
return nil, err
}
stored := map[string]any{}
if len(raw) > 0 {
if err := json.Unmarshal(raw, &stored); err != nil {
return nil, err
}
}
return stored, nil
}
func (s *SourceSyncer) persistPolicy(ctx context.Context, policy map[string]any) error {
raw, err := json.Marshal(policy)
if err != nil {
return err
}
_, err = s.db.Exec(ctx, `
UPDATE org_settings SET settings = $2, updated_at = NOW() WHERE id = $1
`, orgSettingsSingletonID, raw)
return err
}
func (s *SourceSyncer) RedirectURI(slug string) string {
return OAuthRedirectURI(AuthentikPublicBaseURL(s.cfg.AuthentikAPIURL, s.cfg.AuthentikPublicHTTPS), slug)
}
func (s *SourceSyncer) TestProvider(ctx context.Context, providerID string) error {
stored, err := s.loadStoredPolicy(ctx)
if err != nil {
return err
}
pm, err := findProviderByID(stored, providerID)
if err != nil {
return err
}
client := s.client()
if client == nil {
return fmt.Errorf("Authentik API not configured")
}
if err := client.Ping(ctx); err != nil {
return err
}
providerType, _ := pm["type"].(string)
switch providerType {
case "ldap":
ldap, _ := pm["ldap"].(map[string]any)
if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
return fmt.Errorf("ldap server_uri required")
}
case "saml":
saml, _ := pm["saml"].(map[string]any)
if strings.TrimSpace(stringValue(saml["metadata_url"])) == "" &&
strings.TrimSpace(stringValue(saml["sso_url"])) == "" {
return fmt.Errorf("saml metadata_url or sso_url required")
}
case "oauth":
oauth, _ := pm["oauth"].(map[string]any)
if strings.TrimSpace(stringValue(oauth["client_id"])) == "" {
return fmt.Errorf("oauth client_id required")
}
}
return nil
}
func findProviderByID(stored map[string]any, id string) (map[string]any, error) {
idpSection, _ := stored["identity_providers"].(map[string]any)
providers, _ := idpSection["providers"].([]any)
for _, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
if pid, _ := pm["id"].(string); pid == id {
return pm, nil
}
}
return nil, fmt.Errorf("provider not found")
}
func resolveProviderSecret(pm, fullStored map[string]any, section, key string) string {
nested, _ := pm[section].(map[string]any)
if val := strings.TrimSpace(stringValue(nested[key])); val != "" {
return val
}
id, _ := pm["id"].(string)
if id == "" {
return ""
}
existing, _ := fullStored["identity_providers"].(map[string]any)
existingProviders, _ := existing["providers"].([]any)
for _, item := range existingProviders {
em, ok := item.(map[string]any)
if !ok {
continue
}
if eid, _ := em["id"].(string); eid != id {
continue
}
esec, _ := em[section].(map[string]any)
return strings.TrimSpace(stringValue(esec[key]))
}
return ""
}
func authentikPK(v any) string {
switch t := v.(type) {
case string:
return strings.TrimSpace(t)
case float64:
return strconv.Itoa(int(t))
case int:
return strconv.Itoa(t)
case int64:
return strconv.Itoa(int(t))
default:
return ""
}
}
func stringValue(v any) string {
s, _ := v.(string)
return s
}
func stringOr(v any, fallback string) string {
if s := strings.TrimSpace(stringValue(v)); s != "" {
return s
}
return fallback
}
func boolValue(v any) bool {
b, ok := v.(bool)
return ok && b
}

View File

@ -0,0 +1,322 @@
package authentik
import (
"context"
"fmt"
"net/url"
"strings"
)
type sourceRef struct {
PK string `json:"pk"`
Slug string `json:"slug"`
Name string `json:"name"`
}
type brandRef struct {
PK string `json:"pk"`
BrandUUID string `json:"brand_uuid"`
Domain string `json:"domain"`
FlowEnrollment *string `json:"flow_enrollment"`
}
type flowInstance struct {
PK string `json:"pk"`
Slug string `json:"slug"`
Name string `json:"name"`
}
func (c *Client) FindOAuthSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) {
q := url.Values{}
q.Set("slug", slug)
var out listResponse[sourceRef]
if err := c.getJSON(ctx, "/api/v3/sources/oauth/?"+q.Encode(), &out); err != nil {
return nil, false, err
}
if len(out.Results) == 0 {
return nil, false, nil
}
return &out.Results[0], true, nil
}
func (c *Client) FindSAMLSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) {
q := url.Values{}
q.Set("slug", slug)
var out listResponse[sourceRef]
if err := c.getJSON(ctx, "/api/v3/sources/saml/?"+q.Encode(), &out); err != nil {
return nil, false, err
}
if len(out.Results) == 0 {
return nil, false, nil
}
return &out.Results[0], true, nil
}
func (c *Client) FindLDAPSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) {
q := url.Values{}
q.Set("slug", slug)
var out listResponse[sourceRef]
if err := c.getJSON(ctx, "/api/v3/sources/ldap/?"+q.Encode(), &out); err != nil {
return nil, false, err
}
if len(out.Results) == 0 {
return nil, false, nil
}
return &out.Results[0], true, nil
}
type OAuthSourceRequest struct {
Name string
Slug string
Enabled bool
ProviderType string
ClientID string
ClientSecret string
AuthorizationURL string
AccessTokenURL string
ProfileURL string
Scopes string
AuthenticationFlow string
EnrollmentFlow string
}
func (c *Client) CreateOAuthSource(ctx context.Context, req OAuthSourceRequest) (string, error) {
body := map[string]any{
"name": req.Name,
"slug": req.Slug,
"enabled": req.Enabled,
"provider_type": req.ProviderType,
"consumer_key": req.ClientID,
"consumer_secret": req.ClientSecret,
"authorization_url": req.AuthorizationURL,
"access_token_url": req.AccessTokenURL,
"profile_url": req.ProfileURL,
"authentication_flow": req.AuthenticationFlow,
"enrollment_flow": req.EnrollmentFlow,
"user_matching_mode": "email_link",
"policy_engine_mode": "any",
"request_token_url": "",
"additional_scopes": req.Scopes,
}
var created sourceRef
if err := c.postJSON(ctx, "/api/v3/sources/oauth/", body, &created); err != nil {
return "", err
}
return created.PK, nil
}
func (c *Client) UpdateOAuthSource(ctx context.Context, pk string, req OAuthSourceRequest) error {
body := map[string]any{
"name": req.Name,
"slug": req.Slug,
"enabled": req.Enabled,
"provider_type": req.ProviderType,
"consumer_key": req.ClientID,
"authorization_url": req.AuthorizationURL,
"access_token_url": req.AccessTokenURL,
"profile_url": req.ProfileURL,
"authentication_flow": req.AuthenticationFlow,
"enrollment_flow": req.EnrollmentFlow,
"additional_scopes": req.Scopes,
}
if strings.TrimSpace(req.ClientSecret) != "" {
body["consumer_secret"] = req.ClientSecret
}
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/oauth/%s/", pk), body)
}
func (c *Client) DeleteOAuthSource(ctx context.Context, pk string) error {
return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/oauth/%s/", pk))
}
type SAMLSourceRequest struct {
Name string
Slug string
Enabled bool
MetadataURL string
MetadataXML string
EntityID string
SSOURL string
SLOURL string
SigningCert string
AuthenticationFlow string
EnrollmentFlow string
}
func (c *Client) CreateSAMLSource(ctx context.Context, req SAMLSourceRequest) (string, error) {
body := map[string]any{
"name": req.Name,
"slug": req.Slug,
"enabled": req.Enabled,
"metadata_url": req.MetadataURL,
"sso_url": req.SSOURL,
"slo_url": req.SLOURL,
"issuer": req.EntityID,
"authentication_flow": req.AuthenticationFlow,
"enrollment_flow": req.EnrollmentFlow,
"user_matching_mode": "email_link",
"policy_engine_mode": "any",
}
if strings.TrimSpace(req.MetadataXML) != "" {
body["metadata"] = req.MetadataXML
}
if strings.TrimSpace(req.SigningCert) != "" {
body["verification_kp"] = req.SigningCert
}
var created sourceRef
if err := c.postJSON(ctx, "/api/v3/sources/saml/", body, &created); err != nil {
return "", err
}
return created.PK, nil
}
func (c *Client) UpdateSAMLSource(ctx context.Context, pk string, req SAMLSourceRequest) error {
body := map[string]any{
"name": req.Name,
"slug": req.Slug,
"enabled": req.Enabled,
"metadata_url": req.MetadataURL,
"sso_url": req.SSOURL,
"slo_url": req.SLOURL,
"issuer": req.EntityID,
"authentication_flow": req.AuthenticationFlow,
"enrollment_flow": req.EnrollmentFlow,
}
if strings.TrimSpace(req.MetadataXML) != "" {
body["metadata"] = req.MetadataXML
}
if strings.TrimSpace(req.SigningCert) != "" {
body["verification_kp"] = req.SigningCert
}
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/saml/%s/", pk), body)
}
func (c *Client) DeleteSAMLSource(ctx context.Context, pk string) error {
return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/saml/%s/", pk))
}
type LDAPSourceRequest struct {
Name string
Slug string
Enabled bool
ServerURI string
BindDN string
BindPassword string
BaseDN string
UserFilter string
StartTLS bool
SyncUsers bool
AuthenticationFlow string
EnrollmentFlow string
}
func (c *Client) CreateLDAPSource(ctx context.Context, req LDAPSourceRequest) (string, error) {
body := map[string]any{
"name": req.Name,
"slug": req.Slug,
"enabled": req.Enabled,
"server_uri": req.ServerURI,
"bind_cn": req.BindDN,
"bind_password": req.BindPassword,
"base_dn": req.BaseDN,
"search_mode": "direct",
"start_tls": req.StartTLS,
"sync_users": req.SyncUsers,
"authentication_flow": req.AuthenticationFlow,
"enrollment_flow": req.EnrollmentFlow,
"user_matching_mode": "email_link",
"policy_engine_mode": "any",
}
if filter := strings.TrimSpace(req.UserFilter); filter != "" {
body["search_filter"] = filter
}
var created sourceRef
if err := c.postJSON(ctx, "/api/v3/sources/ldap/", body, &created); err != nil {
return "", err
}
return created.PK, nil
}
func (c *Client) UpdateLDAPSource(ctx context.Context, pk string, req LDAPSourceRequest) error {
body := map[string]any{
"name": req.Name,
"slug": req.Slug,
"enabled": req.Enabled,
"server_uri": req.ServerURI,
"bind_cn": req.BindDN,
"base_dn": req.BaseDN,
"start_tls": req.StartTLS,
"sync_users": req.SyncUsers,
"authentication_flow": req.AuthenticationFlow,
"enrollment_flow": req.EnrollmentFlow,
}
if filter := strings.TrimSpace(req.UserFilter); filter != "" {
body["search_filter"] = filter
}
if strings.TrimSpace(req.BindPassword) != "" {
body["bind_password"] = req.BindPassword
}
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/ldap/%s/", pk), body)
}
func (c *Client) DeleteLDAPSource(ctx context.Context, pk string) error {
return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/ldap/%s/", pk))
}
func (c *Client) deleteJSON(ctx context.Context, path string) error {
return c.doJSON(ctx, "DELETE", path, nil, nil)
}
func (c *Client) FindDefaultBrand(ctx context.Context) (*brandRef, error) {
var out listResponse[brandRef]
if err := c.getJSON(ctx, "/api/v3/core/brands/", &out); err != nil {
return nil, err
}
for _, b := range out.Results {
if b.Domain == "authentik-default" || strings.Contains(strings.ToLower(b.Domain), "default") {
return &b, nil
}
}
if len(out.Results) > 0 {
return &out.Results[0], nil
}
return nil, fmt.Errorf("no brand found")
}
func (c *Client) SetBrandEnrollmentFlow(ctx context.Context, brandPK string, flowPK *string) error {
body := map[string]any{}
if flowPK == nil || strings.TrimSpace(*flowPK) == "" {
body["flow_enrollment"] = nil
} else {
body["flow_enrollment"] = *flowPK
}
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/core/brands/%s/", brandPK), body)
}
func (c *Client) FindFlowInstanceBySlug(ctx context.Context, slug string) (*flowInstance, bool, error) {
q := url.Values{}
q.Set("slug", slug)
var out listResponse[flowInstance]
if err := c.getJSON(ctx, "/api/v3/flows/instances/?"+q.Encode(), &out); err != nil {
return nil, false, err
}
if len(out.Results) == 0 {
return nil, false, nil
}
return &out.Results[0], true, nil
}
func (c *Client) ResolveSourceFlows(ctx context.Context) (authFlow, enrollmentFlow string, err error) {
authFlow, err = c.FindFlowBySlug(ctx, "default-source-authentication")
if err != nil {
authFlow, err = c.FindFlowBySlug(ctx, "default-authentication-flow")
}
if err != nil {
return "", "", err
}
enrollmentFlow, err = c.FindFlowBySlug(ctx, "default-source-enrollment")
if err != nil {
enrollmentFlow, _ = c.FindFlowBySlug(ctx, "default-enrollment-flow")
}
return authFlow, enrollmentFlow, nil
}

View File

@ -58,6 +58,78 @@ func TestAdminOrgSettings(t *testing.T) {
}
}
func TestAdminOrgSettingsIdentityProvidersSecrets(t *testing.T) {
h := integrationtest.RequireHarness(t)
adminClient, _ := integrationtest.RequireAdminClient(t, h)
providerID := "test-oauth-provider"
putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
"policy": map[string]any{
"identity_providers": map[string]any{
"allow_self_enrollment": true,
"providers": []any{
map[string]any{
"id": providerID,
"name": "Google Workspace",
"slug": "google-workspace",
"type": "oauth",
"enabled": true,
"oauth": map[string]any{
"provider": "google",
"client_id": "client-id",
"client_secret": "super-secret",
"scopes": "openid email profile",
},
"allowed_email_domains": []any{"company.com"},
},
},
},
},
})
integrationtest.FailIf(err, t, "put identity providers")
integrationtest.FailUnlessStatus(t, putResp, 200)
var afterPut map[string]any
integrationtest.DecodeJSON(t, putResp, &afterPut)
secrets, ok := afterPut["secrets"].(map[string]any)
if !ok {
t.Fatalf("missing secrets: %#v", afterPut)
}
idpSecrets, ok := secrets["identity_providers"].(map[string]any)
if !ok {
t.Fatalf("missing identity provider secrets: %#v", secrets)
}
providerSecrets, ok := idpSecrets[providerID].(map[string]any)
if !ok {
t.Fatalf("missing provider secrets entry: %#v", idpSecrets)
}
oauthSecret, ok := providerSecrets["oauth_client_secret"].(map[string]any)
if !ok || oauthSecret["configured"] != true {
t.Fatalf("oauth secret not configured: %#v", providerSecrets)
}
policy, ok := afterPut["policy"].(map[string]any)
if !ok {
t.Fatalf("missing policy")
}
idpPolicy, ok := policy["identity_providers"].(map[string]any)
if !ok {
t.Fatalf("missing identity_providers policy")
}
providers, ok := idpPolicy["providers"].([]any)
if !ok || len(providers) != 1 {
t.Fatalf("expected one provider")
}
provider, ok := providers[0].(map[string]any)
if !ok {
t.Fatalf("invalid provider payload")
}
oauth, ok := provider["oauth"].(map[string]any)
if !ok || oauth["client_secret"] != "" {
t.Fatalf("client_secret should be masked on GET")
}
}
func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) {
h := integrationtest.RequireHarness(t)
adminClient, _ := integrationtest.RequireAdminClient(t, h)

View File

@ -314,6 +314,24 @@ func (c *Client) GetPublicSharePermissions(ctx context.Context, token, password
return c.GetPublicSharePathPermissions(ctx, token, "/", password)
}
// EffectivePublicSharePermissions returns share permissions for a path.
// Nextcloud often omits oc:permissions on nested WebDAV nodes; fall back to root share bits.
func (c *Client) EffectivePublicSharePermissions(ctx context.Context, token, relPath, password string) (int, error) {
root, err := c.GetPublicSharePermissions(ctx, token, password)
if err != nil {
return 0, err
}
relPath = NormalizeClientPath(relPath)
if relPath == "/" {
return root, nil
}
pathPerms, err := c.GetPublicSharePathPermissions(ctx, token, relPath, password)
if err != nil || pathPerms == 0 {
return root, nil
}
return root | pathPerms, nil
}
const propfindPublicRevisionBody = `<?xml version="1.0" encoding="UTF-8"?>
<d:propfind xmlns:d="DAV:" xmlns:oc="http://owncloud.org/ns">
<d:prop>

144
internal/orgpolicy/auth.go Normal file
View File

@ -0,0 +1,144 @@
package orgpolicy
import (
"strings"
"github.com/ultisuite/ulti-backend/internal/auth"
)
type IdentityProviderPolicy struct {
ID string
Slug string
Type string
Enabled bool
AllowedEmailDomains []string
AllowedIdentities []string
AllowedOrganizations []string
}
type AuthAccessPolicy struct {
AllowSelfEnrollment bool
Providers []IdentityProviderPolicy
}
func (p AuthAccessPolicy) AllowsIdentity(email string, claims *auth.Claims) bool {
enabled := make([]IdentityProviderPolicy, 0)
for _, provider := range p.Providers {
if provider.Enabled {
enabled = append(enabled, provider)
}
}
if len(enabled) == 0 {
return true
}
if claims != nil && strings.TrimSpace(claims.Source) != "" {
for _, provider := range enabled {
if provider.Slug == claims.Source && providerAllows(provider, email, claims) {
return true
}
}
}
hasRestrictions := false
for _, provider := range enabled {
if providerHasRestrictions(provider) {
hasRestrictions = true
if providerAllows(provider, email, claims) {
return true
}
}
}
if !hasRestrictions {
return true
}
return false
}
func providerHasRestrictions(provider IdentityProviderPolicy) bool {
return len(provider.AllowedEmailDomains) > 0 ||
len(provider.AllowedIdentities) > 0 ||
len(provider.AllowedOrganizations) > 0
}
func providerAllows(provider IdentityProviderPolicy, email string, claims *auth.Claims) bool {
if len(provider.AllowedIdentities) > 0 {
if !containsFold(provider.AllowedIdentities, email) {
return false
}
}
if len(provider.AllowedEmailDomains) > 0 {
if !emailDomainAllowed(email, provider.AllowedEmailDomains) {
return false
}
}
if len(provider.AllowedOrganizations) > 0 {
if claims == nil || !organizationAllowed(claims, provider.AllowedOrganizations) {
return false
}
}
return true
}
func emailDomainAllowed(email string, domains []string) bool {
email = strings.ToLower(strings.TrimSpace(email))
at := strings.LastIndex(email, "@")
if at < 0 {
return false
}
domain := email[at+1:]
for _, allowed := range domains {
allowed = strings.ToLower(strings.TrimSpace(strings.TrimPrefix(allowed, "@")))
if allowed == "" {
continue
}
if domain == allowed || strings.HasSuffix(domain, "."+allowed) {
return true
}
}
return false
}
func organizationAllowed(claims *auth.Claims, orgs []string) bool {
candidates := []string{
claims.HD,
claims.TID,
claims.Org,
}
for _, org := range orgs {
org = strings.TrimSpace(org)
if org == "" {
continue
}
for _, candidate := range candidates {
if strings.EqualFold(strings.TrimSpace(candidate), org) {
return true
}
}
}
return false
}
func containsFold(list []string, value string) bool {
value = strings.ToLower(strings.TrimSpace(value))
for _, item := range list {
if strings.ToLower(strings.TrimSpace(item)) == value {
return true
}
}
return false
}
func stringSlice(v any) []string {
raw, ok := v.([]any)
if !ok {
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
if s, ok := item.(string); ok && strings.TrimSpace(s) != "" {
out = append(out, strings.TrimSpace(s))
}
}
return out
}

View File

@ -0,0 +1,72 @@
package orgpolicy
import (
"testing"
"github.com/ultisuite/ulti-backend/internal/auth"
)
func TestAuthAccessPolicyAllowsOpenProviders(t *testing.T) {
policy := AuthAccessPolicy{
Providers: []IdentityProviderPolicy{
{Enabled: true, Slug: "google"},
},
}
claims := &auth.Claims{Email: "user@example.com"}
if !policy.AllowsIdentity(claims.Email, claims) {
t.Fatal("expected open provider to allow any identity")
}
}
func TestAuthAccessPolicyRejectsUnknownDomain(t *testing.T) {
policy := AuthAccessPolicy{
Providers: []IdentityProviderPolicy{
{
Enabled: true,
Slug: "google",
AllowedEmailDomains: []string{"company.com"},
},
},
}
claims := &auth.Claims{Email: "user@gmail.com"}
if policy.AllowsIdentity(claims.Email, claims) {
t.Fatal("expected domain restriction to reject identity")
}
}
func TestAuthAccessPolicyAllowsMatchingOrganization(t *testing.T) {
policy := AuthAccessPolicy{
Providers: []IdentityProviderPolicy{
{
Enabled: true,
Slug: "google",
AllowedOrganizations: []string{"company.com"},
},
},
}
claims := &auth.Claims{Email: "user@company.com", HD: "company.com"}
if !policy.AllowsIdentity(claims.Email, claims) {
t.Fatal("expected matching hosted domain to allow identity")
}
}
func TestAuthAccessPolicyMatchesSourceSpecificProvider(t *testing.T) {
policy := AuthAccessPolicy{
Providers: []IdentityProviderPolicy{
{
Enabled: true,
Slug: "corp-google",
AllowedEmailDomains: []string{"company.com"},
},
{
Enabled: true,
Slug: "partner-google",
AllowedEmailDomains: []string{"partner.org"},
},
},
}
claims := &auth.Claims{Email: "user@partner.org", Source: "partner-google"}
if !policy.AllowsIdentity(claims.Email, claims) {
t.Fatal("expected source-specific provider rules to allow identity")
}
}

View File

@ -28,10 +28,12 @@ type Loader struct {
db *pgxpool.Pool
cfg *config.Config
mu sync.Mutex
cached FilePolicies
cachedAt time.Time
ttl time.Duration
mu sync.Mutex
cached FilePolicies
cachedAt time.Time
authCached AuthAccessPolicy
authCachedAt time.Time
ttl time.Duration
}
func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader {
@ -120,6 +122,77 @@ func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) {
return true, fp.VirusTotalAPIKey, nil
}
func (l *Loader) AuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) {
l.mu.Lock()
if !l.authCachedAt.IsZero() && time.Since(l.authCachedAt) < l.ttl {
out := l.authCached
l.mu.Unlock()
return out, nil
}
l.mu.Unlock()
policy, err := l.loadAuthAccessPolicy(ctx)
if err != nil {
return AuthAccessPolicy{}, err
}
l.mu.Lock()
l.authCached = policy
l.authCachedAt = time.Now()
l.mu.Unlock()
return policy, nil
}
func (l *Loader) loadAuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) {
var raw []byte
err := l.db.QueryRow(ctx, `
SELECT settings FROM org_settings WHERE id = $1
`, orgSettingsSingletonID).Scan(&raw)
if err != nil && err != pgx.ErrNoRows {
return AuthAccessPolicy{}, err
}
stored := map[string]any{}
if len(raw) > 0 {
if err := json.Unmarshal(raw, &stored); err != nil {
return AuthAccessPolicy{}, err
}
}
idp, _ := stored["identity_providers"].(map[string]any)
if idp == nil {
return AuthAccessPolicy{AllowSelfEnrollment: true}, nil
}
allowSelfEnrollment := true
if v, ok := idp["allow_self_enrollment"].(bool); ok {
allowSelfEnrollment = v
}
providersRaw, _ := idp["providers"].([]any)
providers := make([]IdentityProviderPolicy, 0, len(providersRaw))
for _, item := range providersRaw {
pm, ok := item.(map[string]any)
if !ok {
continue
}
providers = append(providers, IdentityProviderPolicy{
ID: stringValue(pm["id"]),
Slug: stringValue(pm["slug"]),
Type: stringValue(pm["type"]),
Enabled: boolValue(pm["enabled"]),
AllowedEmailDomains: stringSlice(pm["allowed_email_domains"]),
AllowedIdentities: stringSlice(pm["allowed_identities"]),
AllowedOrganizations: stringSlice(pm["allowed_organizations"]),
})
}
return AuthAccessPolicy{
AllowSelfEnrollment: allowSelfEnrollment,
Providers: providers,
}, nil
}
func boolValue(v any) bool {
switch t := v.(type) {
case bool:

View File

@ -284,7 +284,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
JWTSecret: cfg.OnlyOfficeJWTSecret,
})
officeHandler := office.NewHandler(officeSvc, driveSvc)
r.Mount("/api/v1/office", officeHandler.Routes(middleware.Auth(verifierHolder, pool, auditLogger)))
r.Mount("/api/v1/office", officeHandler.Routes(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader)))
driveHandler.SetPublicOffice(officeHandler)
}
if driveHandler != nil {
@ -292,7 +292,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
}
r.Group(func(r chi.Router) {
r.Use(middleware.Auth(verifierHolder, pool, auditLogger))
r.Use(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader))
r.Use(middleware.EnforceApiTokenPolicy())
r.Mount("/api/v1/users", usersapi.NewHandler(pool).Routes())