package authentik

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/ultisuite/ulti-backend/internal/config"
)

// orgSettingsSingletonID is the id of the org_settings row that this file
// reads the policy from and writes it back to.
const orgSettingsSingletonID = 1

// SourceSyncer pushes the identity providers described in the stored policy
// to Authentik sources and records the per-provider outcome back into the
// policy's "identity_providers" section.
type SourceSyncer struct {
	db  *pgxpool.Pool
	cfg *config.Config
}

// NewSourceSyncer returns a SourceSyncer backed by db and configured by cfg.
func NewSourceSyncer(db *pgxpool.Pool, cfg *config.Config) *SourceSyncer {
	return &SourceSyncer{db: db, cfg: cfg}
}

// client builds an Authentik API client from the configuration. It returns
// nil when either the API URL or the API token is blank after trimming.
func (s *SourceSyncer) client() *Client {
	apiURL := strings.TrimSpace(s.cfg.AuthentikAPIURL)
	token := strings.TrimSpace(s.cfg.AuthentikAPIToken)
	if apiURL == "" || token == "" {
		return nil
	}
	return NewClient(apiURL, token)
}

// SyncFromStoredPolicy syncs the "identity_providers" section of the stored
// policy. It is a no-op when that section is absent or is not a JSON object.
func (s *SourceSyncer) SyncFromStoredPolicy(ctx context.Context) error {
	stored, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	idpSection, _ := stored["identity_providers"].(map[string]any)
	if idpSection == nil {
		return nil
	}
	return s.syncSection(ctx, idpSection, stored, nil)
}

// SyncSection syncs the given identity-provider section without deleting any
// sources. It is equivalent to SyncSectionWithRemoved with no removed
// providers.
func (s *SourceSyncer) SyncSection(ctx context.Context, idpSection map[string]any) error {
	return s.SyncSectionWithRemoved(ctx, idpSection, nil)
}

// SyncSectionWithRemoved syncs the given identity-provider section and
// additionally asks Authentik to delete the sources of the providers listed
// in removed (see RemovedIdentityProviders). The section is mutated in place
// with the sync results and persisted as part of the stored policy.
func (s *SourceSyncer) SyncSectionWithRemoved(ctx context.Context, idpSection map[string]any, removed []map[string]any) error {
	stored, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	return s.syncSection(ctx, idpSection, stored, removed)
}

// syncSection is the shared implementation behind the exported Sync* methods.
//
// When the Authentik client is not configured, the ping fails, or the source
// flows cannot be resolved, every provider is marked with the failure reason
// and the result of persisting that is returned. Otherwise the enrollment
// setting is applied, each provider that has a non-empty "id" is created or
// updated, removed sources are deleted, and the section is stored under
// fullStored["identity_providers"] and persisted.
//
// Per-provider failures do not abort the run; they are recorded in the
// provider's "sync_status" and "sync_error" fields.
func (s *SourceSyncer) syncSection(ctx context.Context, idpSection, fullStored map[string]any, removed []map[string]any) error {
	// Writing to a nil map panics; treat a nil section as an empty one.
	if idpSection == nil {
		idpSection = map[string]any{}
	}

	client := s.client()
	if client == nil {
		return s.markAllPending(ctx, idpSection, "Authentik API not configured")
	}
	if err := client.Ping(ctx); err != nil {
		return s.markAllPending(ctx, idpSection, err.Error())
	}
	authFlow, enrollmentFlow, err := client.ResolveSourceFlows(ctx)
	if err != nil {
		return s.markAllPending(ctx, idpSection, err.Error())
	}

	allowSelfEnrollment, _ := idpSection["allow_self_enrollment"].(bool)
	if err := s.syncEnrollment(ctx, client, allowSelfEnrollment); err != nil {
		return err
	}

	providers, _ := idpSection["providers"].([]any)
	// One timestamp for the whole run so all providers synced together agree.
	now := time.Now().UTC().Format(time.RFC3339)
	for _, item := range providers {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if id, _ := pm["id"].(string); id == "" {
			continue
		}

		pk, syncErr := s.syncProvider(ctx, client, pm, authFlow, enrollmentFlow, fullStored)
		if syncErr != nil {
			pm["sync_status"] = "error"
			pm["sync_error"] = syncErr.Error()
			continue
		}

		pm["sync_status"] = "synced"
		pm["sync_error"] = ""
		pm["last_synced_at"] = now
		if pk == "" {
			continue
		}
		// Store the key as a number when it parses as one, otherwise verbatim.
		if n, err := strconv.Atoi(pk); err == nil {
			pm["authentik_pk"] = n
		} else {
			pm["authentik_pk"] = pk
		}
	}

	if err := s.deleteRemovedSources(ctx, client, removed); err != nil {
		return err
	}

	idpSection["providers"] = providers
	fullStored["identity_providers"] = idpSection
	return s.persistPolicy(ctx, fullStored)
}

// syncEnrollment applies the self-enrollment setting to the default brand.
// When allow is true, the flow with slug "ulti-enrollment" is set as the
// brand's enrollment flow; if no such flow is found the brand is left
// untouched. When allow is false, SetBrandEnrollmentFlow is called with nil.
func (s *SourceSyncer) syncEnrollment(ctx context.Context, client *Client, allow bool) error {
	brand, err := client.FindDefaultBrand(ctx)
	if err != nil {
		return err
	}
	if !allow {
		return client.SetBrandEnrollmentFlow(ctx, brand.PK, nil)
	}

	flow, found, err := client.FindFlowInstanceBySlug(ctx, "ulti-enrollment")
	if err != nil {
		return err
	}
	if !found {
		return nil
	}
	return client.SetBrandEnrollmentFlow(ctx, brand.PK, &flow.PK)
}

// syncProvider creates or updates the Authentik source for one provider map.
//
// The provider's "type" selects the source kind ("oauth", "saml" or "ldap").
// A source is created when the provider has no usable "authentik_pk" and
// updated otherwise. The returned string is the source's primary key: the
// newly created one on create, the existing one on update. A missing "slug"
// or an unsupported type is reported as an error; a missing "name" falls back
// to the slug.
func (s *SourceSyncer) syncProvider(ctx context.Context, client *Client, pm map[string]any, authFlow, enrollmentFlow string, fullStored map[string]any) (string, error) {
	providerType, _ := pm["type"].(string)
	enabled, _ := pm["enabled"].(bool)
	slug, _ := pm["slug"].(string)
	name, _ := pm["name"].(string)
	if slug == "" {
		return "", fmt.Errorf("missing slug")
	}
	if name == "" {
		name = slug
	}
	pk := authentikPK(pm["authentik_pk"])

	switch providerType {
	case "oauth":
		req, err := buildOAuthSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
		if err != nil {
			return "", err
		}
		if pk == "" {
			createdPK, err := client.CreateOAuthSource(ctx, req)
			if err != nil {
				return "", err
			}
			return createdPK, nil
		}
		return pk, client.UpdateOAuthSource(ctx, pk, req)

	case "saml":
		req := buildSAMLSourceRequest(pm, authFlow, enrollmentFlow, enabled, slug, name)
		if pk == "" {
			createdPK, err := client.CreateSAMLSource(ctx, req)
			if err != nil {
				return "", err
			}
			return createdPK, nil
		}
		return pk, client.UpdateSAMLSource(ctx, pk, req)

	case "ldap":
		req, err := buildLDAPSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
		if err != nil {
			return "", err
		}
		if pk == "" {
			createdPK, err := client.CreateLDAPSource(ctx, req)
			if err != nil {
				return "", err
			}
			return createdPK, nil
		}
		return pk, client.UpdateLDAPSource(ctx, pk, req)

	default:
		return "", fmt.Errorf("unsupported provider type %q", providerType)
	}
}

// buildOAuthSourceRequest assembles the OAuth source request from the
// provider's "oauth" object. URLs and scopes that are blank fall back to the
// preset returned by OAuthPresetFor for the configured provider name; the
// provider name "custom" forces the provider type "openidconnect". The client
// secret is resolved via resolveProviderSecret. A blank client_id or
// client_secret is an error.
func buildOAuthSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (OAuthSourceRequest, error) {
	oauth, _ := pm["oauth"].(map[string]any)
	providerName, _ := oauth["provider"].(string)
	preset := OAuthPresetFor(providerName)

	clientID, _ := oauth["client_id"].(string)
	clientSecret := resolveProviderSecret(pm, fullStored, "oauth", "client_secret")

	providerType := preset.ProviderType
	if providerName == "custom" {
		providerType = "openidconnect"
	}

	if strings.TrimSpace(clientID) == "" {
		return OAuthSourceRequest{}, fmt.Errorf("oauth client_id required")
	}
	if strings.TrimSpace(clientSecret) == "" {
		return OAuthSourceRequest{}, fmt.Errorf("oauth client_secret required")
	}

	return OAuthSourceRequest{
		Name:               name,
		Slug:               slug,
		Enabled:            enabled,
		ProviderType:       providerType,
		ClientID:           clientID,
		ClientSecret:       clientSecret,
		AuthorizationURL:   stringOr(oauth["authorization_url"], preset.AuthorizationURL),
		AccessTokenURL:     stringOr(oauth["token_url"], preset.AccessTokenURL),
		ProfileURL:         stringOr(oauth["profile_url"], preset.ProfileURL),
		Scopes:             stringOr(oauth["scopes"], preset.DefaultScopes),
		AuthenticationFlow: authFlow,
		EnrollmentFlow:     enrollmentFlow,
	}, nil
}

// buildSAMLSourceRequest assembles the SAML source request from the
// provider's "saml" object. Missing or non-string fields become empty
// strings; no validation is performed here.
func buildSAMLSourceRequest(pm map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) SAMLSourceRequest {
	saml, _ := pm["saml"].(map[string]any)
	return SAMLSourceRequest{
		Name:               name,
		Slug:               slug,
		Enabled:            enabled,
		MetadataURL:        stringValue(saml["metadata_url"]),
		MetadataXML:        stringValue(saml["metadata_xml"]),
		EntityID:           stringValue(saml["entity_id"]),
		SSOURL:             stringValue(saml["sso_url"]),
		SLOURL:             stringValue(saml["slo_url"]),
		SigningCert:        stringValue(saml["signing_cert"]),
		AuthenticationFlow: authFlow,
		EnrollmentFlow:     enrollmentFlow,
	}
}

// buildLDAPSourceRequest assembles the LDAP source request from the
// provider's "ldap" object. The bind password is resolved via
// resolveProviderSecret. A blank server_uri, bind_dn or bind_password is an
// error.
func buildLDAPSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (LDAPSourceRequest, error) {
	ldap, _ := pm["ldap"].(map[string]any)
	serverURI := stringValue(ldap["server_uri"])
	bindDN := stringValue(ldap["bind_dn"])
	bindPassword := resolveProviderSecret(pm, fullStored, "ldap", "bind_password")

	if strings.TrimSpace(serverURI) == "" {
		return LDAPSourceRequest{}, fmt.Errorf("ldap server_uri required")
	}
	if strings.TrimSpace(bindDN) == "" {
		return LDAPSourceRequest{}, fmt.Errorf("ldap bind_dn required")
	}
	if strings.TrimSpace(bindPassword) == "" {
		return LDAPSourceRequest{}, fmt.Errorf("ldap bind_password required")
	}

	return LDAPSourceRequest{
		Name:               name,
		Slug:               slug,
		Enabled:            enabled,
		ServerURI:          serverURI,
		BindDN:             bindDN,
		BindPassword:       bindPassword,
		BaseDN:             stringValue(ldap["base_dn"]),
		UserFilter:         stringValue(ldap["user_filter"]),
		StartTLS:           boolValue(ldap["start_tls"]),
		SyncUsers:          boolValue(ldap["sync_users"]),
		AuthenticationFlow: authFlow,
		EnrollmentFlow:     enrollmentFlow,
	}, nil
}

// deleteRemovedSources asks Authentik to delete the source of every removed
// provider that carries an "authentik_pk". Providers without a key or with an
// unrecognized type are skipped. It currently always returns nil.
func (s *SourceSyncer) deleteRemovedSources(ctx context.Context, client *Client, removed []map[string]any) error {
	for _, pm := range removed {
		pk := authentikPK(pm["authentik_pk"])
		if pk == "" {
			continue
		}
		providerType, _ := pm["type"].(string)
		// Deletion is best-effort: a failure here is deliberately ignored so
		// that the sync results gathered by the caller are still persisted.
		switch providerType {
		case "oauth":
			_ = client.DeleteOAuthSource(ctx, pk)
		case "saml":
			_ = client.DeleteSAMLSource(ctx, pk)
		case "ldap":
			_ = client.DeleteLDAPSource(ctx, pk)
		}
	}
	return nil
}

// RemovedIdentityProviders returns the provider maps that are listed under
// identity_providers.providers in before but whose "id" no longer appears in
// after. Entries that are not objects or have an empty "id" are ignored. The
// result is never nil.
func RemovedIdentityProviders(before, after map[string]any) []map[string]any {
	beforeSection, _ := before["identity_providers"].(map[string]any)
	afterSection, _ := after["identity_providers"].(map[string]any)
	beforeProviders, _ := beforeSection["providers"].([]any)
	afterProviders, _ := afterSection["providers"].([]any)

	afterIDs := make(map[string]struct{}, len(afterProviders))
	for _, item := range afterProviders {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if id, _ := pm["id"].(string); id != "" {
			afterIDs[id] = struct{}{}
		}
	}

	removed := make([]map[string]any, 0)
	for _, item := range beforeProviders {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		id, _ := pm["id"].(string)
		if id == "" {
			continue
		}
		if _, stillPresent := afterIDs[id]; !stillPresent {
			removed = append(removed, pm)
		}
	}
	return removed
}

// markAllPending records reason on every provider of idpSection and persists
// the section into a freshly loaded copy of the stored policy.
//
// NOTE(review): despite the name, the status written is "error", not
// "pending" — confirm which value consumers expect before changing either.
func (s *SourceSyncer) markAllPending(ctx context.Context, idpSection map[string]any, reason string) error {
	// Writing to a nil map panics; treat a nil section as an empty one.
	if idpSection == nil {
		idpSection = map[string]any{}
	}

	providers, _ := idpSection["providers"].([]any)
	for _, item := range providers {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		pm["sync_status"] = "error"
		pm["sync_error"] = reason
	}
	idpSection["providers"] = providers

	stored, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	stored["identity_providers"] = idpSection
	return s.persistPolicy(ctx, stored)
}

// loadStoredPolicy reads and decodes the settings column of the singleton
// org_settings row. A missing row, an empty value, or a JSON null all yield
// an empty, non-nil map so callers can write to the result safely.
func (s *SourceSyncer) loadStoredPolicy(ctx context.Context) (map[string]any, error) {
	var raw []byte
	err := s.db.QueryRow(ctx, `SELECT settings FROM org_settings WHERE id = $1`, orgSettingsSingletonID).Scan(&raw)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("load org settings: %w", err)
	}

	stored := map[string]any{}
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &stored); err != nil {
			return nil, fmt.Errorf("decode org settings: %w", err)
		}
	}
	// json.Unmarshal sets the map to nil when the document is JSON null.
	if stored == nil {
		stored = map[string]any{}
	}
	return stored, nil
}

// persistPolicy encodes policy as JSON and writes it to the settings column
// of the singleton org_settings row, refreshing updated_at.
//
// NOTE(review): this is an UPDATE only; if the row does not exist the
// statement affects no rows and no error is reported — confirm the row is
// always provisioned elsewhere.
func (s *SourceSyncer) persistPolicy(ctx context.Context, policy map[string]any) error {
	raw, err := json.Marshal(policy)
	if err != nil {
		return fmt.Errorf("encode org settings: %w", err)
	}
	_, err = s.db.Exec(ctx, `
		UPDATE org_settings
		SET settings = $2, updated_at = NOW()
		WHERE id = $1
	`, orgSettingsSingletonID, raw)
	if err != nil {
		return fmt.Errorf("store org settings: %w", err)
	}
	return nil
}

// RedirectURI returns the OAuth redirect URI for the source with the given
// slug, as computed by OAuthRedirectURI from the public base URL that
// AuthentikPublicBaseURL derives from the configuration.
func (s *SourceSyncer) RedirectURI(slug string) string {
	return OAuthRedirectURI(AuthentikPublicBaseURL(s.cfg.AuthentikAPIURL, s.cfg.AuthentikPublicHTTPS), slug)
}

// TestProvider checks that the stored provider with the given id can be
// synced: the provider must exist, the Authentik API must be configured and
// answer a ping, and the type-specific required fields must be present
// (ldap: server_uri; saml: metadata_url or sso_url; oauth: client_id).
// Providers of any other type pass once the ping succeeds.
func (s *SourceSyncer) TestProvider(ctx context.Context, providerID string) error {
	stored, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	pm, err := findProviderByID(stored, providerID)
	if err != nil {
		return err
	}

	client := s.client()
	if client == nil {
		return fmt.Errorf("Authentik API not configured")
	}
	if err := client.Ping(ctx); err != nil {
		return err
	}

	providerType, _ := pm["type"].(string)
	switch providerType {
	case "ldap":
		ldap, _ := pm["ldap"].(map[string]any)
		if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
			return fmt.Errorf("ldap server_uri required")
		}
	case "saml":
		saml, _ := pm["saml"].(map[string]any)
		metadataURL := strings.TrimSpace(stringValue(saml["metadata_url"]))
		ssoURL := strings.TrimSpace(stringValue(saml["sso_url"]))
		if metadataURL == "" && ssoURL == "" {
			return fmt.Errorf("saml metadata_url or sso_url required")
		}
	case "oauth":
		oauth, _ := pm["oauth"].(map[string]any)
		if strings.TrimSpace(stringValue(oauth["client_id"])) == "" {
			return fmt.Errorf("oauth client_id required")
		}
	}
	return nil
}

// findProviderByID returns the first provider object under
// identity_providers.providers whose "id" equals id, or an error when there
// is none.
func findProviderByID(stored map[string]any, id string) (map[string]any, error) {
	idpSection, _ := stored["identity_providers"].(map[string]any)
	providers, _ := idpSection["providers"].([]any)
	for _, item := range providers {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if pid, _ := pm["id"].(string); pid == id {
			return pm, nil
		}
	}
	return nil, fmt.Errorf("provider not found")
}

// resolveProviderSecret returns the trimmed secret pm[section][key]. When
// that value is blank, it falls back to the same field of the provider with
// the same "id" in fullStored, so a secret that was not supplied again keeps
// its stored value. It returns "" when neither place has one.
func resolveProviderSecret(pm, fullStored map[string]any, section, key string) string {
	nested, _ := pm[section].(map[string]any)
	if val := strings.TrimSpace(stringValue(nested[key])); val != "" {
		return val
	}

	id, _ := pm["id"].(string)
	if id == "" {
		return ""
	}

	existing, _ := fullStored["identity_providers"].(map[string]any)
	existingProviders, _ := existing["providers"].([]any)
	for _, item := range existingProviders {
		em, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if eid, _ := em["id"].(string); eid != id {
			continue
		}
		// Only the first provider with a matching id is consulted.
		esec, _ := em[section].(map[string]any)
		return strings.TrimSpace(stringValue(esec[key]))
	}
	return ""
}

// authentikPK normalizes a stored primary key to its string form. Strings
// are trimmed; float64, int and int64 values are rendered as base-10
// integers (a float64 is truncated toward zero). Any other type yields "".
func authentikPK(v any) string {
	switch t := v.(type) {
	case string:
		return strings.TrimSpace(t)
	case float64:
		// Convert through int64 so large keys survive where int is 32 bits.
		return strconv.FormatInt(int64(t), 10)
	case int:
		return strconv.Itoa(t)
	case int64:
		return strconv.FormatInt(t, 10)
	default:
		return ""
	}
}

// stringValue returns v as a string, or "" when v is not a string.
func stringValue(v any) string {
	s, _ := v.(string)
	return s
}

// stringOr returns the trimmed string held by v, or fallback when v is not a
// string or is blank after trimming.
func stringOr(v any, fallback string) string {
	if s := strings.TrimSpace(stringValue(v)); s != "" {
		return s
	}
	return fallback
}

// boolValue reports whether v is the boolean true.
func boolValue(v any) bool {
	b, ok := v.(bool)
	return ok && b
}