ultisuite-backend/internal/authentik/source_sync.go
R3D347HR4Y d3c930cac6
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
feat(identity-providers): add management for identity providers in admin API
- Introduced new endpoints for managing identity providers, including retrieval of redirect URIs and testing/syncing providers.
- Enhanced organization settings to include identity provider configurations, allowing for self-enrollment and domain restrictions.
- Implemented caching for access policies and added validation for identity provider secrets.
- Added integration tests to ensure proper functionality of identity provider management and policy enforcement.
2026-06-09 09:36:38 +02:00

493 lines
14 KiB
Go

package authentik
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/ultisuite/ulti-backend/internal/config"
)
// orgSettingsSingletonID is the id of the single org_settings row whose
// settings document loadStoredPolicy reads and persistPolicy writes.
const orgSettingsSingletonID = 1
// SourceSyncer pushes the identity providers configured in the org settings
// document to Authentik as OAuth, SAML or LDAP sources, and writes each
// provider's sync outcome back into that document.
type SourceSyncer struct {
// db is the pool used to read and write the org_settings row.
db *pgxpool.Pool
// cfg supplies the Authentik API URL, API token and public-HTTPS setting.
cfg *config.Config
}
// NewSourceSyncer returns a SourceSyncer that keeps sync state in the
// org_settings row reachable through db and talks to Authentik using the API
// settings in cfg. Neither argument is validated here.
func NewSourceSyncer(db *pgxpool.Pool, cfg *config.Config) *SourceSyncer {
	syncer := new(SourceSyncer)
	syncer.db, syncer.cfg = db, cfg
	return syncer
}
// client builds a fresh Authentik API client from the configured URL and token.
// It returns nil when either value is blank after trimming, which callers treat
// as "Authentik API not configured".
func (s *SourceSyncer) client() *Client {
	baseURL := strings.TrimSpace(s.cfg.AuthentikAPIURL)
	apiToken := strings.TrimSpace(s.cfg.AuthentikAPIToken)
	if baseURL != "" && apiToken != "" {
		return NewClient(baseURL, apiToken)
	}
	return nil
}
// SyncFromStoredPolicy re-syncs the identity providers already saved in the org
// settings document to Authentik. It is a no-op (returning nil) when the
// document has no "identity_providers" object. Removed providers are not
// deleted from Authentik by this entry point.
func (s *SourceSyncer) SyncFromStoredPolicy(ctx context.Context) error {
	policy, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	if section, ok := policy["identity_providers"].(map[string]any); ok && section != nil {
		return s.syncSection(ctx, section, policy, nil)
	}
	return nil
}
// SyncSection syncs the given identity_providers section against the currently
// stored org settings. It is SyncSectionWithRemoved with no removed providers,
// so no Authentik sources are deleted. idpSection is mutated in place with
// per-provider sync results before being persisted.
func (s *SourceSyncer) SyncSection(ctx context.Context, idpSection map[string]any) error {
	return s.SyncSectionWithRemoved(ctx, idpSection, nil)
}
// SyncSectionWithRemoved syncs idpSection to Authentik against the currently
// stored org settings. It also deletes, best-effort, the Authentik sources of
// the provider entries in removed (for example as computed by
// RemovedIdentityProviders). idpSection is mutated in place with per-provider
// sync results and then persisted as the document's "identity_providers" value.
func (s *SourceSyncer) SyncSectionWithRemoved(ctx context.Context, idpSection map[string]any, removed []map[string]any) error {
	current, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	return s.syncSection(ctx, idpSection, current, removed)
}
func (s *SourceSyncer) syncSection(ctx context.Context, idpSection, fullStored map[string]any, removed []map[string]any) error {
client := s.client()
if client == nil {
return s.markAllPending(ctx, idpSection, "Authentik API not configured")
}
if err := client.Ping(ctx); err != nil {
return s.markAllPending(ctx, idpSection, err.Error())
}
authFlow, enrollmentFlow, err := client.ResolveSourceFlows(ctx)
if err != nil {
return s.markAllPending(ctx, idpSection, err.Error())
}
allowSelfEnrollment, _ := idpSection["allow_self_enrollment"].(bool)
if err := s.syncEnrollment(ctx, client, allowSelfEnrollment); err != nil {
return err
}
providers, _ := idpSection["providers"].([]any)
now := time.Now().UTC().Format(time.RFC3339)
for _, item := range providers {
pm, ok := item.(map[string]any)
if !ok {
continue
}
id, _ := pm["id"].(string)
if id == "" {
continue
}
pk, syncErr := s.syncProvider(ctx, client, pm, authFlow, enrollmentFlow, fullStored)
if syncErr != nil {
pm["sync_status"] = "error"
pm["sync_error"] = syncErr.Error()
} else {
pm["sync_status"] = "synced"
pm["sync_error"] = ""
pm["last_synced_at"] = now
if pk != "" {
if n, err := strconv.Atoi(pk); err == nil {
pm["authentik_pk"] = n
} else {
pm["authentik_pk"] = pk
}
}
}
}
if err := s.deleteRemovedSources(ctx, client, removed); err != nil {
return err
}
idpSection["providers"] = providers
fullStored["identity_providers"] = idpSection
return s.persistPolicy(ctx, fullStored)
}
// syncEnrollment aligns the default Authentik brand's enrollment flow with the
// self-enrollment setting.
//
// When allow is false the brand's enrollment flow is cleared. When allow is
// true the brand is pointed at the flow with slug "ulti-enrollment"; if no such
// flow exists nothing is changed and nil is returned.
// NOTE(review): that silent skip leaves self-enrollment in whatever state it
// was in; confirm this is intended rather than an error condition.
func (s *SourceSyncer) syncEnrollment(ctx context.Context, client *Client, allow bool) error {
	brand, err := client.FindDefaultBrand(ctx)
	if err != nil {
		return err
	}
	if !allow {
		return client.SetBrandEnrollmentFlow(ctx, brand.PK, nil)
	}
	flow, found, err := client.FindFlowInstanceBySlug(ctx, "ulti-enrollment")
	switch {
	case err != nil:
		return err
	case !found:
		return nil
	}
	return client.SetBrandEnrollmentFlow(ctx, brand.PK, &flow.PK)
}
// syncProvider creates or updates the Authentik source for one provider entry
// and returns that source's Authentik primary key.
//
// The entry's "type" selects the source kind ("oauth", "saml" or "ldap"); its
// "name" defaults to the slug. When the entry has no authentik_pk yet, the
// source is created and the new key is returned. Otherwise the existing source
// is updated and its key is returned, even alongside an update error. A missing
// slug, a request-building error or an unknown type is returned without
// contacting Authentik.
func (s *SourceSyncer) syncProvider(ctx context.Context, client *Client, pm map[string]any, authFlow, enrollmentFlow string, fullStored map[string]any) (string, error) {
	kind, _ := pm["type"].(string)
	enabled, _ := pm["enabled"].(bool)
	slug, _ := pm["slug"].(string)
	if slug == "" {
		return "", fmt.Errorf("missing slug")
	}
	displayName, _ := pm["name"].(string)
	if displayName == "" {
		displayName = slug
	}
	existingPK := authentikPK(pm["authentik_pk"])

	// upsert updates the known source, or creates one when no key exists yet.
	// A failed create yields an empty key.
	upsert := func(create func() (string, error), update func(pk string) error) (string, error) {
		if existingPK != "" {
			return existingPK, update(existingPK)
		}
		newPK, err := create()
		if err != nil {
			return "", err
		}
		return newPK, nil
	}

	switch kind {
	case "oauth":
		req, err := buildOAuthSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, displayName)
		if err != nil {
			return "", err
		}
		return upsert(
			func() (string, error) { return client.CreateOAuthSource(ctx, req) },
			func(pk string) error { return client.UpdateOAuthSource(ctx, pk, req) },
		)
	case "saml":
		req := buildSAMLSourceRequest(pm, authFlow, enrollmentFlow, enabled, slug, displayName)
		return upsert(
			func() (string, error) { return client.CreateSAMLSource(ctx, req) },
			func(pk string) error { return client.UpdateSAMLSource(ctx, pk, req) },
		)
	case "ldap":
		req, err := buildLDAPSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, displayName)
		if err != nil {
			return "", err
		}
		return upsert(
			func() (string, error) { return client.CreateLDAPSource(ctx, req) },
			func(pk string) error { return client.UpdateLDAPSource(ctx, pk, req) },
		)
	default:
		return "", fmt.Errorf("unsupported provider type %q", kind)
	}
}
// buildOAuthSourceRequest assembles the Authentik OAuth source request for a
// provider entry.
//
// URLs and scopes come from pm["oauth"] when set, otherwise from the preset
// returned by OAuthPresetFor for pm["oauth"]["provider"]. A "custom" provider
// is sent with provider type "openidconnect". The client secret is taken from
// the entry or, if blank there, from the stored copy of the same provider (see
// resolveProviderSecret).
//
// It returns an error when the client ID or client secret is blank. The client
// ID is trimmed before validation and is sent in that trimmed form, so
// surrounding whitespace in the configured value is not forwarded to Authentik.
func buildOAuthSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (OAuthSourceRequest, error) {
	oauth, _ := pm["oauth"].(map[string]any)
	providerName, _ := oauth["provider"].(string)
	preset := OAuthPresetFor(providerName)

	rawClientID, _ := oauth["client_id"].(string)
	clientID := strings.TrimSpace(rawClientID)
	if clientID == "" {
		return OAuthSourceRequest{}, fmt.Errorf("oauth client_id required")
	}
	// resolveProviderSecret already returns a trimmed value.
	clientSecret := resolveProviderSecret(pm, fullStored, "oauth", "client_secret")
	if clientSecret == "" {
		return OAuthSourceRequest{}, fmt.Errorf("oauth client_secret required")
	}

	providerType := preset.ProviderType
	if providerName == "custom" {
		providerType = "openidconnect"
	}
	return OAuthSourceRequest{
		Name:               name,
		Slug:               slug,
		Enabled:            enabled,
		ProviderType:       providerType,
		ClientID:           clientID,
		ClientSecret:       clientSecret,
		AuthorizationURL:   stringOr(oauth["authorization_url"], preset.AuthorizationURL),
		AccessTokenURL:     stringOr(oauth["token_url"], preset.AccessTokenURL),
		ProfileURL:         stringOr(oauth["profile_url"], preset.ProfileURL),
		Scopes:             stringOr(oauth["scopes"], preset.DefaultScopes),
		AuthenticationFlow: authFlow,
		EnrollmentFlow:     enrollmentFlow,
	}, nil
}
// buildSAMLSourceRequest assembles the Authentik SAML source request for a
// provider entry, copying metadata and endpoint settings from pm["saml"].
// Missing or non-string settings become empty strings. No validation is
// performed here.
func buildSAMLSourceRequest(pm map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) SAMLSourceRequest {
	samlCfg, _ := pm["saml"].(map[string]any)
	setting := func(key string) string { return stringValue(samlCfg[key]) }

	var req SAMLSourceRequest
	req.Name = name
	req.Slug = slug
	req.Enabled = enabled
	req.MetadataURL = setting("metadata_url")
	req.MetadataXML = setting("metadata_xml")
	req.EntityID = setting("entity_id")
	req.SSOURL = setting("sso_url")
	req.SLOURL = setting("slo_url")
	req.SigningCert = setting("signing_cert")
	req.AuthenticationFlow = authFlow
	req.EnrollmentFlow = enrollmentFlow
	return req
}
// buildLDAPSourceRequest assembles the Authentik LDAP source request for a
// provider entry from pm["ldap"].
//
// The bind password is taken from the entry or, if blank there, from the stored
// copy of the same provider (see resolveProviderSecret). An error is returned
// when the server URI, bind DN or bind password is blank. The server URI and
// bind DN are sent in the same trimmed form that was validated, rather than
// with any surrounding whitespace from the configured value.
func buildLDAPSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (LDAPSourceRequest, error) {
	ldap, _ := pm["ldap"].(map[string]any)

	serverURI := strings.TrimSpace(stringValue(ldap["server_uri"]))
	if serverURI == "" {
		return LDAPSourceRequest{}, fmt.Errorf("ldap server_uri required")
	}
	bindDN := strings.TrimSpace(stringValue(ldap["bind_dn"]))
	if bindDN == "" {
		return LDAPSourceRequest{}, fmt.Errorf("ldap bind_dn required")
	}
	// resolveProviderSecret already returns a trimmed value.
	bindPassword := resolveProviderSecret(pm, fullStored, "ldap", "bind_password")
	if bindPassword == "" {
		return LDAPSourceRequest{}, fmt.Errorf("ldap bind_password required")
	}

	return LDAPSourceRequest{
		Name:               name,
		Slug:               slug,
		Enabled:            enabled,
		ServerURI:          serverURI,
		BindDN:             bindDN,
		BindPassword:       bindPassword,
		BaseDN:             stringValue(ldap["base_dn"]),
		UserFilter:         stringValue(ldap["user_filter"]),
		StartTLS:           boolValue(ldap["start_tls"]),
		SyncUsers:          boolValue(ldap["sync_users"]),
		AuthenticationFlow: authFlow,
		EnrollmentFlow:     enrollmentFlow,
	}, nil
}
// deleteRemovedSources deletes the Authentik source behind each entry in
// removed that carries an authentik_pk. The delete endpoint is chosen by the
// entry's "type" ("oauth", "saml" or "ldap"); other types are skipped.
//
// Deletion is best-effort: every delete error is deliberately discarded and
// the function always returns nil, presumably so that cleanup failures cannot
// block persisting the new configuration.
// NOTE(review): failures are currently invisible; consider logging them.
func (s *SourceSyncer) deleteRemovedSources(ctx context.Context, client *Client, removed []map[string]any) error {
	for _, entry := range removed {
		kind, _ := entry["type"].(string)
		sourcePK := authentikPK(entry["authentik_pk"])
		if sourcePK == "" {
			// Never reached Authentik, so there is nothing to delete.
			continue
		}
		switch kind {
		case "oauth":
			_ = client.DeleteOAuthSource(ctx, sourcePK)
		case "saml":
			_ = client.DeleteSAMLSource(ctx, sourcePK)
		case "ldap":
			_ = client.DeleteLDAPSource(ctx, sourcePK)
		}
	}
	return nil
}
// RemovedIdentityProviders returns the provider entries present in
// before["identity_providers"]["providers"] whose "id" no longer appears in the
// same list in after.
//
// Entries that are not objects, or that have no non-empty string "id", are
// ignored on both sides. The result preserves before's order, contains the
// original entry maps (not copies), and is a non-nil, possibly empty slice.
func RemovedIdentityProviders(before, after map[string]any) []map[string]any {
	providerList := func(policy map[string]any) []any {
		section, _ := policy["identity_providers"].(map[string]any)
		list, _ := section["providers"].([]any)
		return list
	}

	laterList := providerList(after)
	kept := make(map[string]struct{}, len(laterList))
	for _, entry := range laterList {
		if pm, ok := entry.(map[string]any); ok {
			if id, _ := pm["id"].(string); id != "" {
				kept[id] = struct{}{}
			}
		}
	}

	removed := make([]map[string]any, 0)
	for _, entry := range providerList(before) {
		pm, ok := entry.(map[string]any)
		if !ok {
			continue
		}
		id, _ := pm["id"].(string)
		if _, present := kept[id]; id == "" || present {
			continue
		}
		removed = append(removed, pm)
	}
	return removed
}
// markAllPending flags every provider entry in idpSection as failed with
// reason. It then stores idpSection into a freshly loaded copy of the org
// settings document and persists it. idpSection's entries are mutated before
// the load, so they carry the flags even if loading or persisting fails.
//
// NOTE(review): despite the name, the status written is "error", not
// "pending"; confirm which value consumers expect before renaming either.
func (s *SourceSyncer) markAllPending(ctx context.Context, idpSection map[string]any, reason string) error {
	list, _ := idpSection["providers"].([]any)
	for _, entry := range list {
		if pm, ok := entry.(map[string]any); ok {
			pm["sync_status"] = "error"
			pm["sync_error"] = reason
		}
	}
	idpSection["providers"] = list

	policy, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	policy["identity_providers"] = idpSection
	return s.persistPolicy(ctx, policy)
}
// loadStoredPolicy reads the settings document of the singleton org_settings
// row and decodes it into a generic map.
//
// A missing row, an SQL NULL/empty document, or a JSON `null` document all
// yield an empty, non-nil map. Callers write into the returned map, and
// encoding/json decodes `null` into a nil map, so a nil map would make those
// writes panic. Query and decode failures are returned wrapped.
func (s *SourceSyncer) loadStoredPolicy(ctx context.Context) (map[string]any, error) {
	var raw []byte
	err := s.db.QueryRow(ctx, `SELECT settings FROM org_settings WHERE id = $1`, orgSettingsSingletonID).Scan(&raw)
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("loading org settings: %w", err)
	}
	stored := map[string]any{}
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &stored); err != nil {
			return nil, fmt.Errorf("decoding org settings: %w", err)
		}
	}
	if stored == nil {
		stored = map[string]any{}
	}
	return stored, nil
}
// persistPolicy JSON-encodes policy and writes it as the settings document of
// the singleton org_settings row, bumping updated_at.
//
// It only updates an existing row. If no org_settings row with the singleton id
// exists, an error is returned instead of reporting success for a write that
// changed nothing. Encoding and database errors are returned wrapped.
func (s *SourceSyncer) persistPolicy(ctx context.Context, policy map[string]any) error {
	raw, err := json.Marshal(policy)
	if err != nil {
		return fmt.Errorf("encoding org settings: %w", err)
	}
	tag, err := s.db.Exec(ctx, `
UPDATE org_settings SET settings = $2, updated_at = NOW() WHERE id = $1
`, orgSettingsSingletonID, raw)
	if err != nil {
		return fmt.Errorf("updating org settings: %w", err)
	}
	if tag.RowsAffected() == 0 {
		return fmt.Errorf("updating org settings: no org_settings row with id %d", orgSettingsSingletonID)
	}
	return nil
}
// RedirectURI returns the OAuth redirect URI for the source with the given
// slug. It is computed by OAuthRedirectURI from the public Authentik base URL,
// which AuthentikPublicBaseURL derives from the configured API URL and the
// AuthentikPublicHTTPS setting.
func (s *SourceSyncer) RedirectURI(slug string) string {
	publicBase := AuthentikPublicBaseURL(s.cfg.AuthentikAPIURL, s.cfg.AuthentikPublicHTTPS)
	return OAuthRedirectURI(publicBase, slug)
}
// TestProvider runs a pre-flight check for the stored identity provider with
// the given ID.
//
// It confirms the Authentik API is configured and reachable (Ping), then checks
// that the provider's stored configuration has the fields a sync requires for
// its type. Nothing is created or modified in Authentik. The checks mirror the
// validation in buildOAuthSourceRequest and buildLDAPSourceRequest, so a
// passing test is not followed by a sync that fails validation:
//   - oauth: client_id and client_secret
//   - ldap:  server_uri, bind_dn and bind_password
//   - saml:  at least one of metadata_url, metadata_xml or sso_url
//
// Any other provider type is reported as unsupported, matching syncProvider.
// The secrets are read straight from the stored entry, which is the same place
// resolveProviderSecret would fall back to.
func (s *SourceSyncer) TestProvider(ctx context.Context, providerID string) error {
	stored, err := s.loadStoredPolicy(ctx)
	if err != nil {
		return err
	}
	pm, err := findProviderByID(stored, providerID)
	if err != nil {
		return err
	}
	client := s.client()
	if client == nil {
		return fmt.Errorf("Authentik API not configured")
	}
	if err := client.Ping(ctx); err != nil {
		return err
	}

	blank := func(section map[string]any, key string) bool {
		return strings.TrimSpace(stringValue(section[key])) == ""
	}
	providerType, _ := pm["type"].(string)
	switch providerType {
	case "ldap":
		ldap, _ := pm["ldap"].(map[string]any)
		switch {
		case blank(ldap, "server_uri"):
			return fmt.Errorf("ldap server_uri required")
		case blank(ldap, "bind_dn"):
			return fmt.Errorf("ldap bind_dn required")
		case blank(ldap, "bind_password"):
			return fmt.Errorf("ldap bind_password required")
		}
	case "saml":
		saml, _ := pm["saml"].(map[string]any)
		// buildSAMLSourceRequest also forwards metadata_xml, so an inline
		// metadata document is a valid configuration on its own.
		if blank(saml, "metadata_url") && blank(saml, "metadata_xml") && blank(saml, "sso_url") {
			return fmt.Errorf("saml metadata_url, metadata_xml or sso_url required")
		}
	case "oauth":
		oauth, _ := pm["oauth"].(map[string]any)
		switch {
		case blank(oauth, "client_id"):
			return fmt.Errorf("oauth client_id required")
		case blank(oauth, "client_secret"):
			return fmt.Errorf("oauth client_secret required")
		}
	default:
		return fmt.Errorf("unsupported provider type %q", providerType)
	}
	return nil
}
// findProviderByID returns the entry in stored["identity_providers"]["providers"]
// whose "id" equals id.
//
// Non-object entries are skipped. An empty id is always reported as not found,
// so it cannot accidentally match an entry that lacks an id; elsewhere in this
// file entries without an id are treated as unidentifiable. The returned map is
// the entry itself, not a copy.
func findProviderByID(stored map[string]any, id string) (map[string]any, error) {
	if id == "" {
		return nil, fmt.Errorf("provider not found")
	}
	idpSection, _ := stored["identity_providers"].(map[string]any)
	providers, _ := idpSection["providers"].([]any)
	for _, item := range providers {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		if pid, _ := pm["id"].(string); pid == id {
			return pm, nil
		}
	}
	return nil, fmt.Errorf("provider not found")
}
// resolveProviderSecret returns the trimmed secret pm[section][key].
//
// When that value is blank, it falls back to the same field on the stored
// provider with the same "id" in fullStored["identity_providers"]["providers"].
// This lets a caller omit a secret and keep the previously saved one. Only the
// first stored entry with a matching id is consulted. The result is "" when
// nothing is found or pm has no id.
func resolveProviderSecret(pm, fullStored map[string]any, section, key string) string {
	incoming, _ := pm[section].(map[string]any)
	if secret := strings.TrimSpace(stringValue(incoming[key])); secret != "" {
		return secret
	}

	id, _ := pm["id"].(string)
	if id == "" {
		return ""
	}
	storedSection, _ := fullStored["identity_providers"].(map[string]any)
	storedProviders, _ := storedSection["providers"].([]any)
	for _, entry := range storedProviders {
		candidate, ok := entry.(map[string]any)
		if !ok {
			continue
		}
		if candidateID, _ := candidate["id"].(string); candidateID == id {
			previous, _ := candidate[section].(map[string]any)
			return strings.TrimSpace(stringValue(previous[key]))
		}
	}
	return ""
}
// authentikPK normalises a stored Authentik primary key to its string form.
//
// The value may arrive in several forms:
//   - a string (trimmed);
//   - a float64 after a round-trip through encoding/json;
//   - an int, as written back by the sync loop;
//   - an int32/int64;
//   - a json.Number, if a caller decoded with UseNumber (trimmed, kept as-is).
//
// Anything else, including nil, yields "", meaning "no Authentik object yet".
// Integers are formatted via int64 so large keys are not truncated through int
// on 32-bit platforms.
func authentikPK(v any) string {
	switch t := v.(type) {
	case string:
		return strings.TrimSpace(t)
	case json.Number:
		return strings.TrimSpace(t.String())
	case float64:
		return strconv.FormatInt(int64(t), 10)
	case int:
		return strconv.Itoa(t)
	case int32:
		return strconv.FormatInt(int64(t), 10)
	case int64:
		return strconv.FormatInt(t, 10)
	default:
		return ""
	}
}
// stringValue returns v when it is a string (untrimmed), and "" for any other
// value, including nil.
func stringValue(v any) string {
	if str, ok := v.(string); ok {
		return str
	}
	return ""
}
// stringOr returns v trimmed when it is a string with non-whitespace content.
// Otherwise it returns fallback unchanged.
func stringOr(v any, fallback string) string {
	trimmed := strings.TrimSpace(stringValue(v))
	if trimmed == "" {
		return fallback
	}
	return trimmed
}
// boolValue returns v when it is a bool, and false for any other value,
// including nil and strings such as "true".
func boolValue(v any) bool {
	if flag, ok := v.(bool); ok {
		return flag
	}
	return false
}