feat(identity-providers): add management for identity providers in admin API
- Introduced new endpoints for managing identity providers, including retrieval of redirect URIs and testing/syncing providers. - Enhanced organization settings to include identity provider configurations, allowing for self-enrollment and domain restrictions. - Implemented caching for access policies and added validation for identity provider secrets. - Added integration tests to ensure proper functionality of identity provider management and policy enforcement.
This commit is contained in:
parent
b90edf317c
commit
d3c930cac6
@ -60,6 +60,10 @@ func (h *Handler) Routes() chi.Router {
|
||||
r.With(read).Get("/org/settings", h.GetOrgSettings)
|
||||
r.With(write).Put("/org/settings", h.PutOrgSettings)
|
||||
|
||||
r.With(read).Get("/org/identity-providers/redirect-uri/{slug}", h.GetIdentityProviderRedirectURI)
|
||||
r.With(write).Post("/org/identity-providers/{providerID}/test", h.TestIdentityProvider)
|
||||
r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
98
internal/api/admin/identity_providers.go
Normal file
98
internal/api/admin/identity_providers.go
Normal file
@ -0,0 +1,98 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
)
|
||||
|
||||
func (s *Service) SyncIdentityProviders(ctx context.Context, actorSub, providerID string) (map[string]any, error) {
|
||||
stored, _, _, err := s.loadOrgPolicyRaw(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idpSection, _ := stored["identity_providers"].(map[string]any)
|
||||
if idpSection == nil {
|
||||
return nil, fmt.Errorf("identity_providers not configured")
|
||||
}
|
||||
providers, _ := idpSection["providers"].([]any)
|
||||
found := false
|
||||
for _, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if id, _ := pm["id"].(string); id == providerID {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
return nil, fmt.Errorf("provider not found")
|
||||
}
|
||||
|
||||
syncer := authentik.NewSourceSyncer(s.db, s.cfg)
|
||||
if err := syncer.SyncSection(ctx, idpSection); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.logAudit(ctx, actorSub, "sync_identity_provider", map[string]any{
|
||||
"provider_id": providerID,
|
||||
})
|
||||
return s.GetOrgSettings(ctx)
|
||||
}
|
||||
|
||||
func (h *Handler) GetIdentityProviderRedirectURI(w http.ResponseWriter, r *http.Request) {
|
||||
slug := chi.URLParam(r, "slug")
|
||||
if slug == "" {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "slug", Message: "required"},
|
||||
))
|
||||
return
|
||||
}
|
||||
syncer := authentik.NewSourceSyncer(h.svc.db, h.svc.cfg)
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"slug": slug,
|
||||
"redirect_uri": syncer.RedirectURI(slug),
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) TestIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
providerID := chi.URLParam(r, "providerID")
|
||||
if providerID == "" {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "providerID", Message: "required"},
|
||||
))
|
||||
return
|
||||
}
|
||||
syncer := authentik.NewSourceSyncer(h.svc.db, h.svc.cfg)
|
||||
if err := syncer.TestProvider(r.Context(), providerID); err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"ok": true})
|
||||
}
|
||||
|
||||
func (h *Handler) SyncIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
providerID := chi.URLParam(r, "providerID")
|
||||
if providerID == "" {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(
|
||||
apivalidate.FieldDetail{Field: "providerID", Message: "required"},
|
||||
))
|
||||
return
|
||||
}
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
payload, err := h.svc.SyncIdentityProviders(r.Context(), claims.Sub, providerID)
|
||||
if err != nil {
|
||||
h.logger.Error("sync identity provider", "error", err, "provider_id", providerID)
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, payload)
|
||||
}
|
||||
@ -9,6 +9,7 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
)
|
||||
|
||||
const orgSettingsSingletonID = 1
|
||||
@ -24,6 +25,11 @@ func defaultOrgPolicy() map[string]any {
|
||||
"allow_password_fallback": false,
|
||||
"default_groups": "ulti-users",
|
||||
},
|
||||
"identity_providers": map[string]any{
|
||||
"allow_self_enrollment": true,
|
||||
"default_login_source": "",
|
||||
"providers": []any{},
|
||||
},
|
||||
"two_factor": map[string]any{
|
||||
"required_for_all": false,
|
||||
"required_for_admins": true,
|
||||
@ -204,6 +210,9 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any {
|
||||
mergeSearchProviderSecrets(existing, patchWS, merged)
|
||||
}
|
||||
}
|
||||
if patchIDP, ok := patch["identity_providers"].(map[string]any); ok {
|
||||
mergeIdentityProviderSecrets(existing, patchIDP, merged)
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
@ -291,6 +300,65 @@ func mergeSearchProviderSecrets(existing map[string]any, patchWS map[string]any,
|
||||
merged["search"] = mergedSearch
|
||||
}
|
||||
|
||||
func mergeIdentityProviderSecrets(existing, patchIDP, merged map[string]any) {
|
||||
patchProviders, _ := patchIDP["providers"].([]any)
|
||||
if len(patchProviders) == 0 {
|
||||
return
|
||||
}
|
||||
existingIDP, _ := existing["identity_providers"].(map[string]any)
|
||||
existingProviders, _ := existingIDP["providers"].([]any)
|
||||
mergedIDP, _ := merged["identity_providers"].(map[string]any)
|
||||
if mergedIDP == nil {
|
||||
return
|
||||
}
|
||||
mergedProviders := make([]any, len(patchProviders))
|
||||
for i, item := range patchProviders {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
mergedProviders[i] = item
|
||||
continue
|
||||
}
|
||||
id, _ := pm["id"].(string)
|
||||
for _, secretPath := range []struct{ section, key string }{
|
||||
{"oauth", "client_secret"},
|
||||
{"ldap", "bind_password"},
|
||||
{"saml", "signing_cert"},
|
||||
} {
|
||||
mergeProviderSecretField(pm, existingProviders, id, secretPath.section, secretPath.key)
|
||||
}
|
||||
mergedProviders[i] = pm
|
||||
}
|
||||
mergedIDP["providers"] = mergedProviders
|
||||
merged["identity_providers"] = mergedIDP
|
||||
}
|
||||
|
||||
func mergeProviderSecretField(pm map[string]any, existingProviders []any, id, section, key string) {
|
||||
nested, _ := pm[section].(map[string]any)
|
||||
if nested == nil {
|
||||
return
|
||||
}
|
||||
val, _ := nested[key].(string)
|
||||
if strings.TrimSpace(val) != "" || id == "" {
|
||||
return
|
||||
}
|
||||
for _, ep := range existingProviders {
|
||||
em, ok := ep.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if eid, _ := em["id"].(string); eid != id {
|
||||
continue
|
||||
}
|
||||
esec, _ := em[section].(map[string]any)
|
||||
if esec == nil {
|
||||
continue
|
||||
}
|
||||
if ek, ok := esec[key].(string); ok && ek != "" {
|
||||
nested[key] = ek
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func maskOrgPolicy(policy map[string]any) map[string]any {
|
||||
cloned := deepCloneMap(policy)
|
||||
maskStringField(cloned, "nextcloud", "admin_password")
|
||||
@ -331,9 +399,41 @@ func maskOrgPolicy(policy map[string]any) map[string]any {
|
||||
}
|
||||
}
|
||||
}
|
||||
maskIdentityProviderSecrets(cloned)
|
||||
return cloned
|
||||
}
|
||||
|
||||
func maskIdentityProviderSecrets(policy map[string]any) {
|
||||
idp, ok := policy["identity_providers"].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
providers, ok := idp["providers"].([]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for i, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, section := range []string{"oauth", "ldap", "saml"} {
|
||||
nested, ok := pm[section].(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, key := range []string{"client_secret", "bind_password", "signing_cert"} {
|
||||
if v, _ := nested[key].(string); strings.TrimSpace(v) != "" {
|
||||
nested[key] = ""
|
||||
}
|
||||
}
|
||||
pm[section] = nested
|
||||
}
|
||||
providers[i] = pm
|
||||
}
|
||||
idp["providers"] = providers
|
||||
}
|
||||
|
||||
func maskStringField(m map[string]any, section, key string) {
|
||||
sec, ok := m[section].(map[string]any)
|
||||
if !ok {
|
||||
@ -361,7 +461,7 @@ func secretConfigured(policy map[string]any, section, key string) bool {
|
||||
}
|
||||
|
||||
func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string]any {
|
||||
return map[string]any{
|
||||
secrets := map[string]any{
|
||||
"nextcloud_admin_password": map[string]any{
|
||||
"configured": secretConfigured(policy, "nextcloud", "admin_password") || strings.TrimSpace(cfg.NCAdminPass) != "",
|
||||
},
|
||||
@ -381,10 +481,57 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string
|
||||
"configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "",
|
||||
},
|
||||
}
|
||||
if idpSecrets := buildIdentityProviderSecretsStatus(policy); len(idpSecrets) > 0 {
|
||||
secrets["identity_providers"] = idpSecrets
|
||||
}
|
||||
return secrets
|
||||
}
|
||||
|
||||
func buildIdentityProviderSecretsStatus(policy map[string]any) map[string]any {
|
||||
idp, ok := policy["identity_providers"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
providers, ok := idp["providers"].([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any)
|
||||
for _, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := pm["id"].(string)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
entry := map[string]any{}
|
||||
if oauth, ok := pm["oauth"].(map[string]any); ok {
|
||||
if v, _ := oauth["client_secret"].(string); strings.TrimSpace(v) != "" {
|
||||
entry["oauth_client_secret"] = map[string]any{"configured": true}
|
||||
}
|
||||
}
|
||||
if ldap, ok := pm["ldap"].(map[string]any); ok {
|
||||
if v, _ := ldap["bind_password"].(string); strings.TrimSpace(v) != "" {
|
||||
entry["ldap_bind_password"] = map[string]any{"configured": true}
|
||||
}
|
||||
}
|
||||
if saml, ok := pm["saml"].(map[string]any); ok {
|
||||
if v, _ := saml["signing_cert"].(string); strings.TrimSpace(v) != "" {
|
||||
entry["saml_signing_cert"] = map[string]any{"configured": true}
|
||||
}
|
||||
}
|
||||
if len(entry) > 0 {
|
||||
out[id] = entry
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildOrgEffective(cfg *config.Config) map[string]any {
|
||||
authentikEnabled := strings.TrimSpace(cfg.AuthentikAPIToken) != "" || strings.TrimSpace(cfg.OIDCIssuer) != ""
|
||||
publicBase := authentik.AuthentikPublicBaseURL(cfg.AuthentikAPIURL, cfg.AuthentikPublicHTTPS)
|
||||
return map[string]any{
|
||||
"authentik": map[string]any{
|
||||
"enabled": authentikEnabled,
|
||||
@ -392,6 +539,10 @@ func buildOrgEffective(cfg *config.Config) map[string]any {
|
||||
"client_id": cfg.OIDCClientID,
|
||||
"issuer": cfg.OIDCIssuer,
|
||||
},
|
||||
"identity_providers": map[string]any{
|
||||
"authentik_public_url": publicBase,
|
||||
"oauth_redirect_template": publicBase + "/source/oauth/callback/{slug}/",
|
||||
},
|
||||
"nextcloud": map[string]any{
|
||||
"enabled": cfg.NextcloudEnabled,
|
||||
"base_url": firstNonEmpty(cfg.NextcloudPublicURL, cfg.NextcloudURL),
|
||||
@ -445,11 +596,11 @@ func (s *Service) GetOrgSettings(ctx context.Context) (map[string]any, error) {
|
||||
}
|
||||
|
||||
func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map[string]any) (map[string]any, error) {
|
||||
stored, _, _, err := s.loadOrgPolicyRaw(ctx)
|
||||
storedBefore, _, _, err := s.loadOrgPolicyRaw(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
merged := mergeOrgSecrets(stored, patch)
|
||||
merged := mergeOrgSecrets(storedBefore, patch)
|
||||
raw, err := json.Marshal(merged)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -468,6 +619,16 @@ func (s *Service) PutOrgSettings(ctx context.Context, actorSub string, patch map
|
||||
s.logAudit(ctx, actorSub, "update_org_settings", map[string]any{
|
||||
"sections": mapKeys(patch),
|
||||
})
|
||||
|
||||
if _, ok := patch["identity_providers"]; ok {
|
||||
removed := authentik.RemovedIdentityProviders(storedBefore, merged)
|
||||
syncer := authentik.NewSourceSyncer(s.db, s.cfg)
|
||||
idpSection, _ := merged["identity_providers"].(map[string]any)
|
||||
if err := syncer.SyncSectionWithRemoved(ctx, idpSection, removed); err != nil {
|
||||
s.logger.Warn("identity provider sync failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetOrgSettings(ctx)
|
||||
}
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ const (
|
||||
CodeAuthInvalidToken = "auth.invalid_token"
|
||||
CodeAuthUnauthorized = "auth.unauthorized"
|
||||
CodeAuthForbidden = "auth.forbidden"
|
||||
CodeIdentityNotAllowed = "auth.identity_not_allowed"
|
||||
|
||||
CodeInvalidQueryParam = "invalid_query_param"
|
||||
CodeInvalidRequest = "invalid_request_body"
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/apitokens"
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||
"github.com/ultisuite/ulti-backend/internal/permission"
|
||||
"github.com/ultisuite/ulti-backend/internal/securityaudit"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
@ -23,7 +24,7 @@ const (
|
||||
apiTokenKey ctxKey = "api_token"
|
||||
)
|
||||
|
||||
func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger) func(http.Handler) http.Handler {
|
||||
func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger, orgPolicy *orgpolicy.Loader) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
header := r.Header.Get("Authorization")
|
||||
@ -148,6 +149,23 @@ func Auth(verifier *auth.Holder, db *pgxpool.Pool, audit *securityaudit.Logger)
|
||||
}
|
||||
return
|
||||
}
|
||||
if orgPolicy != nil {
|
||||
policy, err := orgPolicy.AuthAccessPolicy(r.Context())
|
||||
if err != nil {
|
||||
slog.Error("load auth access policy", "sub", claims.Sub, "error", err)
|
||||
} else if !policy.AllowsIdentity(claims.Email, claims) {
|
||||
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeIdentityNotAllowed, "identity not allowed by organization policy", nil)
|
||||
if audit != nil {
|
||||
audit.Log(r.Context(), claims.Sub, securityaudit.ActionTokenRejected, map[string]any{
|
||||
"reason": "identity_not_allowed",
|
||||
"email": claims.Email,
|
||||
"path": r.URL.Path,
|
||||
"method": r.Method,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
claims.Groups = permission.WithSuiteDefaults(claims.Groups)
|
||||
}
|
||||
|
||||
@ -56,7 +56,7 @@ func (h *Handler) PublicShareSession(w http.ResponseWriter, r *http.Request) {
|
||||
if password == "" {
|
||||
password = publicSharePassword(r)
|
||||
}
|
||||
perms, err := h.svc.nc.GetPublicSharePathPermissions(r.Context(), token, req.Path, password)
|
||||
perms, err := h.svc.nc.EffectivePublicSharePermissions(r.Context(), token, req.Path, password)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
@ -81,6 +81,7 @@ func (h *Handler) PublicShareSession(w http.ResponseWriter, r *http.Request) {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"config": cfg,
|
||||
"serverUrl": h.svc.PublicURL(),
|
||||
"mode": mode,
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -16,6 +16,10 @@ type Claims struct {
|
||||
Email string
|
||||
Name string
|
||||
Groups []string
|
||||
Source string
|
||||
HD string
|
||||
TID string
|
||||
Org string
|
||||
}
|
||||
|
||||
type Verifier struct {
|
||||
@ -96,6 +100,10 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error)
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Groups []string `json:"groups"`
|
||||
HD string `json:"hd"`
|
||||
TID string `json:"tid"`
|
||||
Org string `json:"org"`
|
||||
Source string `json:"ak-source"`
|
||||
}
|
||||
if err := token.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
@ -106,5 +114,9 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error)
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Groups: claims.Groups,
|
||||
HD: claims.HD,
|
||||
TID: claims.TID,
|
||||
Org: claims.Org,
|
||||
Source: claims.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
82
internal/authentik/presets.go
Normal file
82
internal/authentik/presets.go
Normal file
@ -0,0 +1,82 @@
|
||||
package authentik
|
||||
|
||||
import "strings"
|
||||
|
||||
type OAuthPreset struct {
|
||||
ProviderType string
|
||||
AuthorizationURL string
|
||||
AccessTokenURL string
|
||||
ProfileURL string
|
||||
DefaultScopes string
|
||||
OrganizationClaim string
|
||||
}
|
||||
|
||||
func OAuthPresetFor(provider string) OAuthPreset {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "google":
|
||||
return OAuthPreset{
|
||||
ProviderType: "google",
|
||||
AuthorizationURL: "https://accounts.google.com/o/oauth2/auth",
|
||||
AccessTokenURL: "https://oauth2.googleapis.com/token",
|
||||
ProfileURL: "https://www.googleapis.com/oauth2/v1/userinfo",
|
||||
DefaultScopes: "openid email profile",
|
||||
OrganizationClaim: "hd",
|
||||
}
|
||||
case "github":
|
||||
return OAuthPreset{
|
||||
ProviderType: "github",
|
||||
AuthorizationURL: "https://github.com/login/oauth/authorize",
|
||||
AccessTokenURL: "https://github.com/login/oauth/access_token",
|
||||
ProfileURL: "https://api.github.com/user",
|
||||
DefaultScopes: "read:user user:email",
|
||||
OrganizationClaim: "org",
|
||||
}
|
||||
case "linkedin":
|
||||
return OAuthPreset{
|
||||
ProviderType: "openidconnect",
|
||||
AuthorizationURL: "https://www.linkedin.com/oauth/v2/authorization",
|
||||
AccessTokenURL: "https://www.linkedin.com/oauth/v2/accessToken",
|
||||
ProfileURL: "https://api.linkedin.com/v2/userinfo",
|
||||
DefaultScopes: "openid profile email",
|
||||
OrganizationClaim: "",
|
||||
}
|
||||
case "microsoft":
|
||||
return OAuthPreset{
|
||||
ProviderType: "azuread",
|
||||
AuthorizationURL: "https://login.microsoftonline.com/common/oauth2/v2.0/authorize",
|
||||
AccessTokenURL: "https://login.microsoftonline.com/common/oauth2/v2.0/token",
|
||||
ProfileURL: "https://graph.microsoft.com/oidc/userinfo",
|
||||
DefaultScopes: "openid email profile",
|
||||
OrganizationClaim: "tid",
|
||||
}
|
||||
default:
|
||||
return OAuthPreset{
|
||||
ProviderType: "openidconnect",
|
||||
DefaultScopes: "openid email profile",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func OAuthRedirectURI(publicBaseURL, slug string) string {
|
||||
base := strings.TrimRight(strings.TrimSpace(publicBaseURL), "/")
|
||||
if base == "" {
|
||||
base = "http://localhost/auth"
|
||||
}
|
||||
return base + "/source/oauth/callback/" + strings.TrimSpace(slug) + "/"
|
||||
}
|
||||
|
||||
func AuthentikPublicBaseURL(apiURL string, publicHTTPS bool) string {
|
||||
apiURL = strings.TrimSpace(apiURL)
|
||||
if apiURL == "" {
|
||||
if publicHTTPS {
|
||||
return "https://localhost/auth"
|
||||
}
|
||||
return "http://localhost/auth"
|
||||
}
|
||||
u := strings.TrimSuffix(apiURL, "/")
|
||||
u = strings.TrimSuffix(u, "/api/v3")
|
||||
if publicHTTPS && strings.HasPrefix(u, "http://") {
|
||||
u = "https://" + strings.TrimPrefix(u, "http://")
|
||||
}
|
||||
return u
|
||||
}
|
||||
46
internal/authentik/presets_test.go
Normal file
46
internal/authentik/presets_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package authentik
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOAuthPresetForGoogle(t *testing.T) {
|
||||
preset := OAuthPresetFor("google")
|
||||
if preset.ProviderType != "google" {
|
||||
t.Fatalf("provider type = %q", preset.ProviderType)
|
||||
}
|
||||
if preset.AuthorizationURL == "" || preset.AccessTokenURL == "" {
|
||||
t.Fatal("expected google oauth endpoints")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOAuthRedirectURI(t *testing.T) {
|
||||
got := OAuthRedirectURI("http://localhost/auth", "google-workspace")
|
||||
want := "http://localhost/auth/source/oauth/callback/google-workspace/"
|
||||
if got != want {
|
||||
t.Fatalf("redirect uri = %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemovedIdentityProviders(t *testing.T) {
|
||||
before := map[string]any{
|
||||
"identity_providers": map[string]any{
|
||||
"providers": []any{
|
||||
map[string]any{"id": "a", "authentik_pk": 1, "type": "oauth"},
|
||||
map[string]any{"id": "b", "authentik_pk": 2, "type": "saml"},
|
||||
},
|
||||
},
|
||||
}
|
||||
after := map[string]any{
|
||||
"identity_providers": map[string]any{
|
||||
"providers": []any{
|
||||
map[string]any{"id": "a", "authentik_pk": 1, "type": "oauth"},
|
||||
},
|
||||
},
|
||||
}
|
||||
removed := RemovedIdentityProviders(before, after)
|
||||
if len(removed) != 1 {
|
||||
t.Fatalf("removed count = %d, want 1", len(removed))
|
||||
}
|
||||
if id, _ := removed[0]["id"].(string); id != "b" {
|
||||
t.Fatalf("removed id = %q", id)
|
||||
}
|
||||
}
|
||||
492
internal/authentik/source_sync.go
Normal file
492
internal/authentik/source_sync.go
Normal file
@ -0,0 +1,492 @@
|
||||
package authentik
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
)
|
||||
|
||||
const orgSettingsSingletonID = 1
|
||||
|
||||
type SourceSyncer struct {
|
||||
db *pgxpool.Pool
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewSourceSyncer(db *pgxpool.Pool, cfg *config.Config) *SourceSyncer {
|
||||
return &SourceSyncer{db: db, cfg: cfg}
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) client() *Client {
|
||||
apiURL := strings.TrimSpace(s.cfg.AuthentikAPIURL)
|
||||
token := strings.TrimSpace(s.cfg.AuthentikAPIToken)
|
||||
if apiURL == "" || token == "" {
|
||||
return nil
|
||||
}
|
||||
return NewClient(apiURL, token)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) SyncFromStoredPolicy(ctx context.Context) error {
|
||||
stored, err := s.loadStoredPolicy(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idpSection, _ := stored["identity_providers"].(map[string]any)
|
||||
if idpSection == nil {
|
||||
return nil
|
||||
}
|
||||
return s.syncSection(ctx, idpSection, stored, nil)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) SyncSection(ctx context.Context, idpSection map[string]any) error {
|
||||
stored, err := s.loadStoredPolicy(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.syncSection(ctx, idpSection, stored, nil)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) SyncSectionWithRemoved(ctx context.Context, idpSection map[string]any, removed []map[string]any) error {
|
||||
stored, err := s.loadStoredPolicy(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.syncSection(ctx, idpSection, stored, removed)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) syncSection(ctx context.Context, idpSection, fullStored map[string]any, removed []map[string]any) error {
|
||||
client := s.client()
|
||||
if client == nil {
|
||||
return s.markAllPending(ctx, idpSection, "Authentik API not configured")
|
||||
}
|
||||
|
||||
if err := client.Ping(ctx); err != nil {
|
||||
return s.markAllPending(ctx, idpSection, err.Error())
|
||||
}
|
||||
|
||||
authFlow, enrollmentFlow, err := client.ResolveSourceFlows(ctx)
|
||||
if err != nil {
|
||||
return s.markAllPending(ctx, idpSection, err.Error())
|
||||
}
|
||||
|
||||
allowSelfEnrollment, _ := idpSection["allow_self_enrollment"].(bool)
|
||||
if err := s.syncEnrollment(ctx, client, allowSelfEnrollment); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
providers, _ := idpSection["providers"].([]any)
|
||||
now := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
for _, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := pm["id"].(string)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
pk, syncErr := s.syncProvider(ctx, client, pm, authFlow, enrollmentFlow, fullStored)
|
||||
if syncErr != nil {
|
||||
pm["sync_status"] = "error"
|
||||
pm["sync_error"] = syncErr.Error()
|
||||
} else {
|
||||
pm["sync_status"] = "synced"
|
||||
pm["sync_error"] = ""
|
||||
pm["last_synced_at"] = now
|
||||
if pk != "" {
|
||||
if n, err := strconv.Atoi(pk); err == nil {
|
||||
pm["authentik_pk"] = n
|
||||
} else {
|
||||
pm["authentik_pk"] = pk
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.deleteRemovedSources(ctx, client, removed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idpSection["providers"] = providers
|
||||
fullStored["identity_providers"] = idpSection
|
||||
return s.persistPolicy(ctx, fullStored)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) syncEnrollment(ctx context.Context, client *Client, allow bool) error {
|
||||
brand, err := client.FindDefaultBrand(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if allow {
|
||||
flow, found, err := client.FindFlowInstanceBySlug(ctx, "ulti-enrollment")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !found {
|
||||
return nil
|
||||
}
|
||||
return client.SetBrandEnrollmentFlow(ctx, brand.PK, &flow.PK)
|
||||
}
|
||||
return client.SetBrandEnrollmentFlow(ctx, brand.PK, nil)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) syncProvider(ctx context.Context, client *Client, pm map[string]any, authFlow, enrollmentFlow string, fullStored map[string]any) (string, error) {
|
||||
providerType, _ := pm["type"].(string)
|
||||
enabled, _ := pm["enabled"].(bool)
|
||||
slug, _ := pm["slug"].(string)
|
||||
name, _ := pm["name"].(string)
|
||||
if slug == "" {
|
||||
return "", fmt.Errorf("missing slug")
|
||||
}
|
||||
if name == "" {
|
||||
name = slug
|
||||
}
|
||||
|
||||
pk := authentikPK(pm["authentik_pk"])
|
||||
switch providerType {
|
||||
case "oauth":
|
||||
req, err := buildOAuthSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pk == "" {
|
||||
createdPK, err := client.CreateOAuthSource(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return createdPK, nil
|
||||
}
|
||||
return pk, client.UpdateOAuthSource(ctx, pk, req)
|
||||
case "saml":
|
||||
req := buildSAMLSourceRequest(pm, authFlow, enrollmentFlow, enabled, slug, name)
|
||||
if pk == "" {
|
||||
createdPK, err := client.CreateSAMLSource(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return createdPK, nil
|
||||
}
|
||||
return pk, client.UpdateSAMLSource(ctx, pk, req)
|
||||
case "ldap":
|
||||
req, err := buildLDAPSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if pk == "" {
|
||||
createdPK, err := client.CreateLDAPSource(ctx, req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return createdPK, nil
|
||||
}
|
||||
return pk, client.UpdateLDAPSource(ctx, pk, req)
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported provider type %q", providerType)
|
||||
}
|
||||
}
|
||||
|
||||
func buildOAuthSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (OAuthSourceRequest, error) {
|
||||
oauth, _ := pm["oauth"].(map[string]any)
|
||||
providerName, _ := oauth["provider"].(string)
|
||||
preset := OAuthPresetFor(providerName)
|
||||
|
||||
clientID, _ := oauth["client_id"].(string)
|
||||
clientSecret := resolveProviderSecret(pm, fullStored, "oauth", "client_secret")
|
||||
|
||||
authURL := stringOr(oauth["authorization_url"], preset.AuthorizationURL)
|
||||
tokenURL := stringOr(oauth["token_url"], preset.AccessTokenURL)
|
||||
profileURL := stringOr(oauth["profile_url"], preset.ProfileURL)
|
||||
scopes := stringOr(oauth["scopes"], preset.DefaultScopes)
|
||||
providerType := preset.ProviderType
|
||||
if providerName == "custom" {
|
||||
providerType = "openidconnect"
|
||||
}
|
||||
|
||||
if strings.TrimSpace(clientID) == "" {
|
||||
return OAuthSourceRequest{}, fmt.Errorf("oauth client_id required")
|
||||
}
|
||||
if strings.TrimSpace(clientSecret) == "" {
|
||||
return OAuthSourceRequest{}, fmt.Errorf("oauth client_secret required")
|
||||
}
|
||||
|
||||
return OAuthSourceRequest{
|
||||
Name: name,
|
||||
Slug: slug,
|
||||
Enabled: enabled,
|
||||
ProviderType: providerType,
|
||||
ClientID: clientID,
|
||||
ClientSecret: clientSecret,
|
||||
AuthorizationURL: authURL,
|
||||
AccessTokenURL: tokenURL,
|
||||
ProfileURL: profileURL,
|
||||
Scopes: scopes,
|
||||
AuthenticationFlow: authFlow,
|
||||
EnrollmentFlow: enrollmentFlow,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildSAMLSourceRequest(pm map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) SAMLSourceRequest {
|
||||
saml, _ := pm["saml"].(map[string]any)
|
||||
return SAMLSourceRequest{
|
||||
Name: name,
|
||||
Slug: slug,
|
||||
Enabled: enabled,
|
||||
MetadataURL: stringValue(saml["metadata_url"]),
|
||||
MetadataXML: stringValue(saml["metadata_xml"]),
|
||||
EntityID: stringValue(saml["entity_id"]),
|
||||
SSOURL: stringValue(saml["sso_url"]),
|
||||
SLOURL: stringValue(saml["slo_url"]),
|
||||
SigningCert: stringValue(saml["signing_cert"]),
|
||||
AuthenticationFlow: authFlow,
|
||||
EnrollmentFlow: enrollmentFlow,
|
||||
}
|
||||
}
|
||||
|
||||
func buildLDAPSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (LDAPSourceRequest, error) {
|
||||
ldap, _ := pm["ldap"].(map[string]any)
|
||||
bindPassword := resolveProviderSecret(pm, fullStored, "ldap", "bind_password")
|
||||
if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
|
||||
return LDAPSourceRequest{}, fmt.Errorf("ldap server_uri required")
|
||||
}
|
||||
if strings.TrimSpace(stringValue(ldap["bind_dn"])) == "" {
|
||||
return LDAPSourceRequest{}, fmt.Errorf("ldap bind_dn required")
|
||||
}
|
||||
if strings.TrimSpace(bindPassword) == "" {
|
||||
return LDAPSourceRequest{}, fmt.Errorf("ldap bind_password required")
|
||||
}
|
||||
return LDAPSourceRequest{
|
||||
Name: name,
|
||||
Slug: slug,
|
||||
Enabled: enabled,
|
||||
ServerURI: stringValue(ldap["server_uri"]),
|
||||
BindDN: stringValue(ldap["bind_dn"]),
|
||||
BindPassword: bindPassword,
|
||||
BaseDN: stringValue(ldap["base_dn"]),
|
||||
UserFilter: stringValue(ldap["user_filter"]),
|
||||
StartTLS: boolValue(ldap["start_tls"]),
|
||||
SyncUsers: boolValue(ldap["sync_users"]),
|
||||
AuthenticationFlow: authFlow,
|
||||
EnrollmentFlow: enrollmentFlow,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) deleteRemovedSources(ctx context.Context, client *Client, removed []map[string]any) error {
|
||||
for _, pm := range removed {
|
||||
pk := authentikPK(pm["authentik_pk"])
|
||||
if pk == "" {
|
||||
continue
|
||||
}
|
||||
providerType, _ := pm["type"].(string)
|
||||
switch providerType {
|
||||
case "oauth":
|
||||
_ = client.DeleteOAuthSource(ctx, pk)
|
||||
case "saml":
|
||||
_ = client.DeleteSAMLSource(ctx, pk)
|
||||
case "ldap":
|
||||
_ = client.DeleteLDAPSource(ctx, pk)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemovedIdentityProviders(before, after map[string]any) []map[string]any {
|
||||
beforeSection, _ := before["identity_providers"].(map[string]any)
|
||||
afterSection, _ := after["identity_providers"].(map[string]any)
|
||||
beforeProviders, _ := beforeSection["providers"].([]any)
|
||||
afterProviders, _ := afterSection["providers"].([]any)
|
||||
|
||||
afterIDs := make(map[string]struct{}, len(afterProviders))
|
||||
for _, item := range afterProviders {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if id, _ := pm["id"].(string); id != "" {
|
||||
afterIDs[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
removed := make([]map[string]any, 0)
|
||||
for _, item := range beforeProviders {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
id, _ := pm["id"].(string)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, stillPresent := afterIDs[id]; !stillPresent {
|
||||
removed = append(removed, pm)
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) markAllPending(ctx context.Context, idpSection map[string]any, reason string) error {
|
||||
providers, _ := idpSection["providers"].([]any)
|
||||
for _, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
pm["sync_status"] = "error"
|
||||
pm["sync_error"] = reason
|
||||
}
|
||||
idpSection["providers"] = providers
|
||||
stored, err := s.loadStoredPolicy(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stored["identity_providers"] = idpSection
|
||||
return s.persistPolicy(ctx, stored)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) loadStoredPolicy(ctx context.Context) (map[string]any, error) {
|
||||
var raw []byte
|
||||
err := s.db.QueryRow(ctx, `SELECT settings FROM org_settings WHERE id = $1`, orgSettingsSingletonID).Scan(&raw)
|
||||
if err != nil && err != pgx.ErrNoRows {
|
||||
return nil, err
|
||||
}
|
||||
stored := map[string]any{}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &stored); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return stored, nil
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) persistPolicy(ctx context.Context, policy map[string]any) error {
|
||||
raw, err := json.Marshal(policy)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE org_settings SET settings = $2, updated_at = NOW() WHERE id = $1
|
||||
`, orgSettingsSingletonID, raw)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) RedirectURI(slug string) string {
|
||||
return OAuthRedirectURI(AuthentikPublicBaseURL(s.cfg.AuthentikAPIURL, s.cfg.AuthentikPublicHTTPS), slug)
|
||||
}
|
||||
|
||||
func (s *SourceSyncer) TestProvider(ctx context.Context, providerID string) error {
|
||||
stored, err := s.loadStoredPolicy(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pm, err := findProviderByID(stored, providerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
client := s.client()
|
||||
if client == nil {
|
||||
return fmt.Errorf("Authentik API not configured")
|
||||
}
|
||||
if err := client.Ping(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
providerType, _ := pm["type"].(string)
|
||||
switch providerType {
|
||||
case "ldap":
|
||||
ldap, _ := pm["ldap"].(map[string]any)
|
||||
if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
|
||||
return fmt.Errorf("ldap server_uri required")
|
||||
}
|
||||
case "saml":
|
||||
saml, _ := pm["saml"].(map[string]any)
|
||||
if strings.TrimSpace(stringValue(saml["metadata_url"])) == "" &&
|
||||
strings.TrimSpace(stringValue(saml["sso_url"])) == "" {
|
||||
return fmt.Errorf("saml metadata_url or sso_url required")
|
||||
}
|
||||
case "oauth":
|
||||
oauth, _ := pm["oauth"].(map[string]any)
|
||||
if strings.TrimSpace(stringValue(oauth["client_id"])) == "" {
|
||||
return fmt.Errorf("oauth client_id required")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func findProviderByID(stored map[string]any, id string) (map[string]any, error) {
|
||||
idpSection, _ := stored["identity_providers"].(map[string]any)
|
||||
providers, _ := idpSection["providers"].([]any)
|
||||
for _, item := range providers {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if pid, _ := pm["id"].(string); pid == id {
|
||||
return pm, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("provider not found")
|
||||
}
|
||||
|
||||
func resolveProviderSecret(pm, fullStored map[string]any, section, key string) string {
|
||||
nested, _ := pm[section].(map[string]any)
|
||||
if val := strings.TrimSpace(stringValue(nested[key])); val != "" {
|
||||
return val
|
||||
}
|
||||
id, _ := pm["id"].(string)
|
||||
if id == "" {
|
||||
return ""
|
||||
}
|
||||
existing, _ := fullStored["identity_providers"].(map[string]any)
|
||||
existingProviders, _ := existing["providers"].([]any)
|
||||
for _, item := range existingProviders {
|
||||
em, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if eid, _ := em["id"].(string); eid != id {
|
||||
continue
|
||||
}
|
||||
esec, _ := em[section].(map[string]any)
|
||||
return strings.TrimSpace(stringValue(esec[key]))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func authentikPK(v any) string {
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(t)
|
||||
case float64:
|
||||
return strconv.Itoa(int(t))
|
||||
case int:
|
||||
return strconv.Itoa(t)
|
||||
case int64:
|
||||
return strconv.Itoa(int(t))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
|
||||
func stringOr(v any, fallback string) string {
|
||||
if s := strings.TrimSpace(stringValue(v)); s != "" {
|
||||
return s
|
||||
}
|
||||
return fallback
|
||||
}
|
||||
|
||||
func boolValue(v any) bool {
|
||||
b, ok := v.(bool)
|
||||
return ok && b
|
||||
}
|
||||
322
internal/authentik/sources.go
Normal file
322
internal/authentik/sources.go
Normal file
@ -0,0 +1,322 @@
|
||||
package authentik
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type sourceRef struct {
|
||||
PK string `json:"pk"`
|
||||
Slug string `json:"slug"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type brandRef struct {
|
||||
PK string `json:"pk"`
|
||||
BrandUUID string `json:"brand_uuid"`
|
||||
Domain string `json:"domain"`
|
||||
FlowEnrollment *string `json:"flow_enrollment"`
|
||||
}
|
||||
|
||||
type flowInstance struct {
|
||||
PK string `json:"pk"`
|
||||
Slug string `json:"slug"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (c *Client) FindOAuthSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) {
|
||||
q := url.Values{}
|
||||
q.Set("slug", slug)
|
||||
var out listResponse[sourceRef]
|
||||
if err := c.getJSON(ctx, "/api/v3/sources/oauth/?"+q.Encode(), &out); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(out.Results) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &out.Results[0], true, nil
|
||||
}
|
||||
|
||||
func (c *Client) FindSAMLSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) {
|
||||
q := url.Values{}
|
||||
q.Set("slug", slug)
|
||||
var out listResponse[sourceRef]
|
||||
if err := c.getJSON(ctx, "/api/v3/sources/saml/?"+q.Encode(), &out); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(out.Results) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &out.Results[0], true, nil
|
||||
}
|
||||
|
||||
func (c *Client) FindLDAPSourceBySlug(ctx context.Context, slug string) (*sourceRef, bool, error) {
|
||||
q := url.Values{}
|
||||
q.Set("slug", slug)
|
||||
var out listResponse[sourceRef]
|
||||
if err := c.getJSON(ctx, "/api/v3/sources/ldap/?"+q.Encode(), &out); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(out.Results) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &out.Results[0], true, nil
|
||||
}
|
||||
|
||||
type OAuthSourceRequest struct {
|
||||
Name string
|
||||
Slug string
|
||||
Enabled bool
|
||||
ProviderType string
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
AuthorizationURL string
|
||||
AccessTokenURL string
|
||||
ProfileURL string
|
||||
Scopes string
|
||||
AuthenticationFlow string
|
||||
EnrollmentFlow string
|
||||
}
|
||||
|
||||
func (c *Client) CreateOAuthSource(ctx context.Context, req OAuthSourceRequest) (string, error) {
|
||||
body := map[string]any{
|
||||
"name": req.Name,
|
||||
"slug": req.Slug,
|
||||
"enabled": req.Enabled,
|
||||
"provider_type": req.ProviderType,
|
||||
"consumer_key": req.ClientID,
|
||||
"consumer_secret": req.ClientSecret,
|
||||
"authorization_url": req.AuthorizationURL,
|
||||
"access_token_url": req.AccessTokenURL,
|
||||
"profile_url": req.ProfileURL,
|
||||
"authentication_flow": req.AuthenticationFlow,
|
||||
"enrollment_flow": req.EnrollmentFlow,
|
||||
"user_matching_mode": "email_link",
|
||||
"policy_engine_mode": "any",
|
||||
"request_token_url": "",
|
||||
"additional_scopes": req.Scopes,
|
||||
}
|
||||
var created sourceRef
|
||||
if err := c.postJSON(ctx, "/api/v3/sources/oauth/", body, &created); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return created.PK, nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateOAuthSource(ctx context.Context, pk string, req OAuthSourceRequest) error {
|
||||
body := map[string]any{
|
||||
"name": req.Name,
|
||||
"slug": req.Slug,
|
||||
"enabled": req.Enabled,
|
||||
"provider_type": req.ProviderType,
|
||||
"consumer_key": req.ClientID,
|
||||
"authorization_url": req.AuthorizationURL,
|
||||
"access_token_url": req.AccessTokenURL,
|
||||
"profile_url": req.ProfileURL,
|
||||
"authentication_flow": req.AuthenticationFlow,
|
||||
"enrollment_flow": req.EnrollmentFlow,
|
||||
"additional_scopes": req.Scopes,
|
||||
}
|
||||
if strings.TrimSpace(req.ClientSecret) != "" {
|
||||
body["consumer_secret"] = req.ClientSecret
|
||||
}
|
||||
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/oauth/%s/", pk), body)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteOAuthSource(ctx context.Context, pk string) error {
|
||||
return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/oauth/%s/", pk))
|
||||
}
|
||||
|
||||
type SAMLSourceRequest struct {
|
||||
Name string
|
||||
Slug string
|
||||
Enabled bool
|
||||
MetadataURL string
|
||||
MetadataXML string
|
||||
EntityID string
|
||||
SSOURL string
|
||||
SLOURL string
|
||||
SigningCert string
|
||||
AuthenticationFlow string
|
||||
EnrollmentFlow string
|
||||
}
|
||||
|
||||
func (c *Client) CreateSAMLSource(ctx context.Context, req SAMLSourceRequest) (string, error) {
|
||||
body := map[string]any{
|
||||
"name": req.Name,
|
||||
"slug": req.Slug,
|
||||
"enabled": req.Enabled,
|
||||
"metadata_url": req.MetadataURL,
|
||||
"sso_url": req.SSOURL,
|
||||
"slo_url": req.SLOURL,
|
||||
"issuer": req.EntityID,
|
||||
"authentication_flow": req.AuthenticationFlow,
|
||||
"enrollment_flow": req.EnrollmentFlow,
|
||||
"user_matching_mode": "email_link",
|
||||
"policy_engine_mode": "any",
|
||||
}
|
||||
if strings.TrimSpace(req.MetadataXML) != "" {
|
||||
body["metadata"] = req.MetadataXML
|
||||
}
|
||||
if strings.TrimSpace(req.SigningCert) != "" {
|
||||
body["verification_kp"] = req.SigningCert
|
||||
}
|
||||
var created sourceRef
|
||||
if err := c.postJSON(ctx, "/api/v3/sources/saml/", body, &created); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return created.PK, nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateSAMLSource(ctx context.Context, pk string, req SAMLSourceRequest) error {
|
||||
body := map[string]any{
|
||||
"name": req.Name,
|
||||
"slug": req.Slug,
|
||||
"enabled": req.Enabled,
|
||||
"metadata_url": req.MetadataURL,
|
||||
"sso_url": req.SSOURL,
|
||||
"slo_url": req.SLOURL,
|
||||
"issuer": req.EntityID,
|
||||
"authentication_flow": req.AuthenticationFlow,
|
||||
"enrollment_flow": req.EnrollmentFlow,
|
||||
}
|
||||
if strings.TrimSpace(req.MetadataXML) != "" {
|
||||
body["metadata"] = req.MetadataXML
|
||||
}
|
||||
if strings.TrimSpace(req.SigningCert) != "" {
|
||||
body["verification_kp"] = req.SigningCert
|
||||
}
|
||||
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/saml/%s/", pk), body)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteSAMLSource(ctx context.Context, pk string) error {
|
||||
return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/saml/%s/", pk))
|
||||
}
|
||||
|
||||
type LDAPSourceRequest struct {
|
||||
Name string
|
||||
Slug string
|
||||
Enabled bool
|
||||
ServerURI string
|
||||
BindDN string
|
||||
BindPassword string
|
||||
BaseDN string
|
||||
UserFilter string
|
||||
StartTLS bool
|
||||
SyncUsers bool
|
||||
AuthenticationFlow string
|
||||
EnrollmentFlow string
|
||||
}
|
||||
|
||||
func (c *Client) CreateLDAPSource(ctx context.Context, req LDAPSourceRequest) (string, error) {
|
||||
body := map[string]any{
|
||||
"name": req.Name,
|
||||
"slug": req.Slug,
|
||||
"enabled": req.Enabled,
|
||||
"server_uri": req.ServerURI,
|
||||
"bind_cn": req.BindDN,
|
||||
"bind_password": req.BindPassword,
|
||||
"base_dn": req.BaseDN,
|
||||
"search_mode": "direct",
|
||||
"start_tls": req.StartTLS,
|
||||
"sync_users": req.SyncUsers,
|
||||
"authentication_flow": req.AuthenticationFlow,
|
||||
"enrollment_flow": req.EnrollmentFlow,
|
||||
"user_matching_mode": "email_link",
|
||||
"policy_engine_mode": "any",
|
||||
}
|
||||
if filter := strings.TrimSpace(req.UserFilter); filter != "" {
|
||||
body["search_filter"] = filter
|
||||
}
|
||||
var created sourceRef
|
||||
if err := c.postJSON(ctx, "/api/v3/sources/ldap/", body, &created); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return created.PK, nil
|
||||
}
|
||||
|
||||
func (c *Client) UpdateLDAPSource(ctx context.Context, pk string, req LDAPSourceRequest) error {
|
||||
body := map[string]any{
|
||||
"name": req.Name,
|
||||
"slug": req.Slug,
|
||||
"enabled": req.Enabled,
|
||||
"server_uri": req.ServerURI,
|
||||
"bind_cn": req.BindDN,
|
||||
"base_dn": req.BaseDN,
|
||||
"start_tls": req.StartTLS,
|
||||
"sync_users": req.SyncUsers,
|
||||
"authentication_flow": req.AuthenticationFlow,
|
||||
"enrollment_flow": req.EnrollmentFlow,
|
||||
}
|
||||
if filter := strings.TrimSpace(req.UserFilter); filter != "" {
|
||||
body["search_filter"] = filter
|
||||
}
|
||||
if strings.TrimSpace(req.BindPassword) != "" {
|
||||
body["bind_password"] = req.BindPassword
|
||||
}
|
||||
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/sources/ldap/%s/", pk), body)
|
||||
}
|
||||
|
||||
func (c *Client) DeleteLDAPSource(ctx context.Context, pk string) error {
|
||||
return c.deleteJSON(ctx, fmt.Sprintf("/api/v3/sources/ldap/%s/", pk))
|
||||
}
|
||||
|
||||
func (c *Client) deleteJSON(ctx context.Context, path string) error {
|
||||
return c.doJSON(ctx, "DELETE", path, nil, nil)
|
||||
}
|
||||
|
||||
func (c *Client) FindDefaultBrand(ctx context.Context) (*brandRef, error) {
|
||||
var out listResponse[brandRef]
|
||||
if err := c.getJSON(ctx, "/api/v3/core/brands/", &out); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, b := range out.Results {
|
||||
if b.Domain == "authentik-default" || strings.Contains(strings.ToLower(b.Domain), "default") {
|
||||
return &b, nil
|
||||
}
|
||||
}
|
||||
if len(out.Results) > 0 {
|
||||
return &out.Results[0], nil
|
||||
}
|
||||
return nil, fmt.Errorf("no brand found")
|
||||
}
|
||||
|
||||
func (c *Client) SetBrandEnrollmentFlow(ctx context.Context, brandPK string, flowPK *string) error {
|
||||
body := map[string]any{}
|
||||
if flowPK == nil || strings.TrimSpace(*flowPK) == "" {
|
||||
body["flow_enrollment"] = nil
|
||||
} else {
|
||||
body["flow_enrollment"] = *flowPK
|
||||
}
|
||||
return c.patchJSON(ctx, fmt.Sprintf("/api/v3/core/brands/%s/", brandPK), body)
|
||||
}
|
||||
|
||||
func (c *Client) FindFlowInstanceBySlug(ctx context.Context, slug string) (*flowInstance, bool, error) {
|
||||
q := url.Values{}
|
||||
q.Set("slug", slug)
|
||||
var out listResponse[flowInstance]
|
||||
if err := c.getJSON(ctx, "/api/v3/flows/instances/?"+q.Encode(), &out); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if len(out.Results) == 0 {
|
||||
return nil, false, nil
|
||||
}
|
||||
return &out.Results[0], true, nil
|
||||
}
|
||||
|
||||
func (c *Client) ResolveSourceFlows(ctx context.Context) (authFlow, enrollmentFlow string, err error) {
|
||||
authFlow, err = c.FindFlowBySlug(ctx, "default-source-authentication")
|
||||
if err != nil {
|
||||
authFlow, err = c.FindFlowBySlug(ctx, "default-authentication-flow")
|
||||
}
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
enrollmentFlow, err = c.FindFlowBySlug(ctx, "default-source-enrollment")
|
||||
if err != nil {
|
||||
enrollmentFlow, _ = c.FindFlowBySlug(ctx, "default-enrollment-flow")
|
||||
}
|
||||
return authFlow, enrollmentFlow, nil
|
||||
}
|
||||
@ -58,6 +58,78 @@ func TestAdminOrgSettings(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminOrgSettingsIdentityProvidersSecrets(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
adminClient, _ := integrationtest.RequireAdminClient(t, h)
|
||||
|
||||
providerID := "test-oauth-provider"
|
||||
putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
|
||||
"policy": map[string]any{
|
||||
"identity_providers": map[string]any{
|
||||
"allow_self_enrollment": true,
|
||||
"providers": []any{
|
||||
map[string]any{
|
||||
"id": providerID,
|
||||
"name": "Google Workspace",
|
||||
"slug": "google-workspace",
|
||||
"type": "oauth",
|
||||
"enabled": true,
|
||||
"oauth": map[string]any{
|
||||
"provider": "google",
|
||||
"client_id": "client-id",
|
||||
"client_secret": "super-secret",
|
||||
"scopes": "openid email profile",
|
||||
},
|
||||
"allowed_email_domains": []any{"company.com"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
integrationtest.FailIf(err, t, "put identity providers")
|
||||
integrationtest.FailUnlessStatus(t, putResp, 200)
|
||||
|
||||
var afterPut map[string]any
|
||||
integrationtest.DecodeJSON(t, putResp, &afterPut)
|
||||
secrets, ok := afterPut["secrets"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing secrets: %#v", afterPut)
|
||||
}
|
||||
idpSecrets, ok := secrets["identity_providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing identity provider secrets: %#v", secrets)
|
||||
}
|
||||
providerSecrets, ok := idpSecrets[providerID].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing provider secrets entry: %#v", idpSecrets)
|
||||
}
|
||||
oauthSecret, ok := providerSecrets["oauth_client_secret"].(map[string]any)
|
||||
if !ok || oauthSecret["configured"] != true {
|
||||
t.Fatalf("oauth secret not configured: %#v", providerSecrets)
|
||||
}
|
||||
|
||||
policy, ok := afterPut["policy"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing policy")
|
||||
}
|
||||
idpPolicy, ok := policy["identity_providers"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing identity_providers policy")
|
||||
}
|
||||
providers, ok := idpPolicy["providers"].([]any)
|
||||
if !ok || len(providers) != 1 {
|
||||
t.Fatalf("expected one provider")
|
||||
}
|
||||
provider, ok := providers[0].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("invalid provider payload")
|
||||
}
|
||||
oauth, ok := provider["oauth"].(map[string]any)
|
||||
if !ok || oauth["client_secret"] != "" {
|
||||
t.Fatalf("client_secret should be masked on GET")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
adminClient, _ := integrationtest.RequireAdminClient(t, h)
|
||||
|
||||
@ -314,6 +314,24 @@ func (c *Client) GetPublicSharePermissions(ctx context.Context, token, password
|
||||
return c.GetPublicSharePathPermissions(ctx, token, "/", password)
|
||||
}
|
||||
|
||||
// EffectivePublicSharePermissions returns share permissions for a path.
|
||||
// Nextcloud often omits oc:permissions on nested WebDAV nodes; fall back to root share bits.
|
||||
func (c *Client) EffectivePublicSharePermissions(ctx context.Context, token, relPath, password string) (int, error) {
|
||||
root, err := c.GetPublicSharePermissions(ctx, token, password)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
relPath = NormalizeClientPath(relPath)
|
||||
if relPath == "/" {
|
||||
return root, nil
|
||||
}
|
||||
pathPerms, err := c.GetPublicSharePathPermissions(ctx, token, relPath, password)
|
||||
if err != nil || pathPerms == 0 {
|
||||
return root, nil
|
||||
}
|
||||
return root | pathPerms, nil
|
||||
}
|
||||
|
||||
const propfindPublicRevisionBody = `<?xml version="1.0" encoding="UTF-8"?>
|
||||
<d:propfind xmlns:d="DAV:" xmlns:oc="http://owncloud.org/ns">
|
||||
<d:prop>
|
||||
|
||||
144
internal/orgpolicy/auth.go
Normal file
144
internal/orgpolicy/auth.go
Normal file
@ -0,0 +1,144 @@
|
||||
package orgpolicy
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
)
|
||||
|
||||
type IdentityProviderPolicy struct {
|
||||
ID string
|
||||
Slug string
|
||||
Type string
|
||||
Enabled bool
|
||||
AllowedEmailDomains []string
|
||||
AllowedIdentities []string
|
||||
AllowedOrganizations []string
|
||||
}
|
||||
|
||||
type AuthAccessPolicy struct {
|
||||
AllowSelfEnrollment bool
|
||||
Providers []IdentityProviderPolicy
|
||||
}
|
||||
|
||||
func (p AuthAccessPolicy) AllowsIdentity(email string, claims *auth.Claims) bool {
|
||||
enabled := make([]IdentityProviderPolicy, 0)
|
||||
for _, provider := range p.Providers {
|
||||
if provider.Enabled {
|
||||
enabled = append(enabled, provider)
|
||||
}
|
||||
}
|
||||
if len(enabled) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if claims != nil && strings.TrimSpace(claims.Source) != "" {
|
||||
for _, provider := range enabled {
|
||||
if provider.Slug == claims.Source && providerAllows(provider, email, claims) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
hasRestrictions := false
|
||||
for _, provider := range enabled {
|
||||
if providerHasRestrictions(provider) {
|
||||
hasRestrictions = true
|
||||
if providerAllows(provider, email, claims) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
if !hasRestrictions {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func providerHasRestrictions(provider IdentityProviderPolicy) bool {
|
||||
return len(provider.AllowedEmailDomains) > 0 ||
|
||||
len(provider.AllowedIdentities) > 0 ||
|
||||
len(provider.AllowedOrganizations) > 0
|
||||
}
|
||||
|
||||
func providerAllows(provider IdentityProviderPolicy, email string, claims *auth.Claims) bool {
|
||||
if len(provider.AllowedIdentities) > 0 {
|
||||
if !containsFold(provider.AllowedIdentities, email) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(provider.AllowedEmailDomains) > 0 {
|
||||
if !emailDomainAllowed(email, provider.AllowedEmailDomains) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
if len(provider.AllowedOrganizations) > 0 {
|
||||
if claims == nil || !organizationAllowed(claims, provider.AllowedOrganizations) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func emailDomainAllowed(email string, domains []string) bool {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
at := strings.LastIndex(email, "@")
|
||||
if at < 0 {
|
||||
return false
|
||||
}
|
||||
domain := email[at+1:]
|
||||
for _, allowed := range domains {
|
||||
allowed = strings.ToLower(strings.TrimSpace(strings.TrimPrefix(allowed, "@")))
|
||||
if allowed == "" {
|
||||
continue
|
||||
}
|
||||
if domain == allowed || strings.HasSuffix(domain, "."+allowed) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func organizationAllowed(claims *auth.Claims, orgs []string) bool {
|
||||
candidates := []string{
|
||||
claims.HD,
|
||||
claims.TID,
|
||||
claims.Org,
|
||||
}
|
||||
for _, org := range orgs {
|
||||
org = strings.TrimSpace(org)
|
||||
if org == "" {
|
||||
continue
|
||||
}
|
||||
for _, candidate := range candidates {
|
||||
if strings.EqualFold(strings.TrimSpace(candidate), org) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsFold(list []string, value string) bool {
|
||||
value = strings.ToLower(strings.TrimSpace(value))
|
||||
for _, item := range list {
|
||||
if strings.ToLower(strings.TrimSpace(item)) == value {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func stringSlice(v any) []string {
|
||||
raw, ok := v.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
if s, ok := item.(string); ok && strings.TrimSpace(s) != "" {
|
||||
out = append(out, strings.TrimSpace(s))
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
72
internal/orgpolicy/auth_test.go
Normal file
72
internal/orgpolicy/auth_test.go
Normal file
@ -0,0 +1,72 @@
|
||||
package orgpolicy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
)
|
||||
|
||||
func TestAuthAccessPolicyAllowsOpenProviders(t *testing.T) {
|
||||
policy := AuthAccessPolicy{
|
||||
Providers: []IdentityProviderPolicy{
|
||||
{Enabled: true, Slug: "google"},
|
||||
},
|
||||
}
|
||||
claims := &auth.Claims{Email: "user@example.com"}
|
||||
if !policy.AllowsIdentity(claims.Email, claims) {
|
||||
t.Fatal("expected open provider to allow any identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAccessPolicyRejectsUnknownDomain(t *testing.T) {
|
||||
policy := AuthAccessPolicy{
|
||||
Providers: []IdentityProviderPolicy{
|
||||
{
|
||||
Enabled: true,
|
||||
Slug: "google",
|
||||
AllowedEmailDomains: []string{"company.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
claims := &auth.Claims{Email: "user@gmail.com"}
|
||||
if policy.AllowsIdentity(claims.Email, claims) {
|
||||
t.Fatal("expected domain restriction to reject identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAccessPolicyAllowsMatchingOrganization(t *testing.T) {
|
||||
policy := AuthAccessPolicy{
|
||||
Providers: []IdentityProviderPolicy{
|
||||
{
|
||||
Enabled: true,
|
||||
Slug: "google",
|
||||
AllowedOrganizations: []string{"company.com"},
|
||||
},
|
||||
},
|
||||
}
|
||||
claims := &auth.Claims{Email: "user@company.com", HD: "company.com"}
|
||||
if !policy.AllowsIdentity(claims.Email, claims) {
|
||||
t.Fatal("expected matching hosted domain to allow identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthAccessPolicyMatchesSourceSpecificProvider(t *testing.T) {
|
||||
policy := AuthAccessPolicy{
|
||||
Providers: []IdentityProviderPolicy{
|
||||
{
|
||||
Enabled: true,
|
||||
Slug: "corp-google",
|
||||
AllowedEmailDomains: []string{"company.com"},
|
||||
},
|
||||
{
|
||||
Enabled: true,
|
||||
Slug: "partner-google",
|
||||
AllowedEmailDomains: []string{"partner.org"},
|
||||
},
|
||||
},
|
||||
}
|
||||
claims := &auth.Claims{Email: "user@partner.org", Source: "partner-google"}
|
||||
if !policy.AllowsIdentity(claims.Email, claims) {
|
||||
t.Fatal("expected source-specific provider rules to allow identity")
|
||||
}
|
||||
}
|
||||
@ -28,10 +28,12 @@ type Loader struct {
|
||||
db *pgxpool.Pool
|
||||
cfg *config.Config
|
||||
|
||||
mu sync.Mutex
|
||||
cached FilePolicies
|
||||
cachedAt time.Time
|
||||
ttl time.Duration
|
||||
mu sync.Mutex
|
||||
cached FilePolicies
|
||||
cachedAt time.Time
|
||||
authCached AuthAccessPolicy
|
||||
authCachedAt time.Time
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader {
|
||||
@ -120,6 +122,77 @@ func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) {
|
||||
return true, fp.VirusTotalAPIKey, nil
|
||||
}
|
||||
|
||||
func (l *Loader) AuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) {
|
||||
l.mu.Lock()
|
||||
if !l.authCachedAt.IsZero() && time.Since(l.authCachedAt) < l.ttl {
|
||||
out := l.authCached
|
||||
l.mu.Unlock()
|
||||
return out, nil
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
policy, err := l.loadAuthAccessPolicy(ctx)
|
||||
if err != nil {
|
||||
return AuthAccessPolicy{}, err
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
l.authCached = policy
|
||||
l.authCachedAt = time.Now()
|
||||
l.mu.Unlock()
|
||||
return policy, nil
|
||||
}
|
||||
|
||||
func (l *Loader) loadAuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) {
|
||||
var raw []byte
|
||||
err := l.db.QueryRow(ctx, `
|
||||
SELECT settings FROM org_settings WHERE id = $1
|
||||
`, orgSettingsSingletonID).Scan(&raw)
|
||||
if err != nil && err != pgx.ErrNoRows {
|
||||
return AuthAccessPolicy{}, err
|
||||
}
|
||||
|
||||
stored := map[string]any{}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &stored); err != nil {
|
||||
return AuthAccessPolicy{}, err
|
||||
}
|
||||
}
|
||||
|
||||
idp, _ := stored["identity_providers"].(map[string]any)
|
||||
if idp == nil {
|
||||
return AuthAccessPolicy{AllowSelfEnrollment: true}, nil
|
||||
}
|
||||
|
||||
allowSelfEnrollment := true
|
||||
if v, ok := idp["allow_self_enrollment"].(bool); ok {
|
||||
allowSelfEnrollment = v
|
||||
}
|
||||
|
||||
providersRaw, _ := idp["providers"].([]any)
|
||||
providers := make([]IdentityProviderPolicy, 0, len(providersRaw))
|
||||
for _, item := range providersRaw {
|
||||
pm, ok := item.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
providers = append(providers, IdentityProviderPolicy{
|
||||
ID: stringValue(pm["id"]),
|
||||
Slug: stringValue(pm["slug"]),
|
||||
Type: stringValue(pm["type"]),
|
||||
Enabled: boolValue(pm["enabled"]),
|
||||
AllowedEmailDomains: stringSlice(pm["allowed_email_domains"]),
|
||||
AllowedIdentities: stringSlice(pm["allowed_identities"]),
|
||||
AllowedOrganizations: stringSlice(pm["allowed_organizations"]),
|
||||
})
|
||||
}
|
||||
|
||||
return AuthAccessPolicy{
|
||||
AllowSelfEnrollment: allowSelfEnrollment,
|
||||
Providers: providers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func boolValue(v any) bool {
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
|
||||
@ -284,7 +284,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
JWTSecret: cfg.OnlyOfficeJWTSecret,
|
||||
})
|
||||
officeHandler := office.NewHandler(officeSvc, driveSvc)
|
||||
r.Mount("/api/v1/office", officeHandler.Routes(middleware.Auth(verifierHolder, pool, auditLogger)))
|
||||
r.Mount("/api/v1/office", officeHandler.Routes(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader)))
|
||||
driveHandler.SetPublicOffice(officeHandler)
|
||||
}
|
||||
if driveHandler != nil {
|
||||
@ -292,7 +292,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
}
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middleware.Auth(verifierHolder, pool, auditLogger))
|
||||
r.Use(middleware.Auth(verifierHolder, pool, auditLogger, orgPolicyLoader))
|
||||
r.Use(middleware.EnforceApiTokenPolicy())
|
||||
|
||||
r.Mount("/api/v1/users", usersapi.NewHandler(pool).Routes())
|
||||
|
||||
Loading…
Reference in New Issue
Block a user