- Introduced new endpoints for managing identity providers, including retrieval of redirect URIs and testing/syncing providers. - Enhanced organization settings to include identity provider configurations, allowing for self-enrollment and domain restrictions. - Implemented caching for access policies and added validation for identity provider secrets. - Added integration tests to ensure proper functionality of identity provider management and policy enforcement.
493 lines
14 KiB
Go
493 lines
14 KiB
Go
package authentik
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/ultisuite/ulti-backend/internal/config"
)
|
|
|
|
// orgSettingsSingletonID is the org_settings row id that loadStoredPolicy
// reads and persistPolicy updates; all org policy JSON lives in that row.
const orgSettingsSingletonID = 1
|
|
|
|
// SourceSyncer pushes the identity_providers section of the stored org
// policy to Authentik as OAuth, SAML or LDAP sources and writes the
// per-provider sync results back into the org_settings row.
type SourceSyncer struct {
	// db is used to read and update the org_settings row.
	db *pgxpool.Pool

	// cfg supplies the Authentik API URL, API token and public-HTTPS setting.
	cfg *config.Config
}
|
|
|
|
func NewSourceSyncer(db *pgxpool.Pool, cfg *config.Config) *SourceSyncer {
|
|
return &SourceSyncer{db: db, cfg: cfg}
|
|
}
|
|
|
|
func (s *SourceSyncer) client() *Client {
|
|
apiURL := strings.TrimSpace(s.cfg.AuthentikAPIURL)
|
|
token := strings.TrimSpace(s.cfg.AuthentikAPIToken)
|
|
if apiURL == "" || token == "" {
|
|
return nil
|
|
}
|
|
return NewClient(apiURL, token)
|
|
}
|
|
|
|
func (s *SourceSyncer) SyncFromStoredPolicy(ctx context.Context) error {
|
|
stored, err := s.loadStoredPolicy(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
idpSection, _ := stored["identity_providers"].(map[string]any)
|
|
if idpSection == nil {
|
|
return nil
|
|
}
|
|
return s.syncSection(ctx, idpSection, stored, nil)
|
|
}
|
|
|
|
func (s *SourceSyncer) SyncSection(ctx context.Context, idpSection map[string]any) error {
|
|
stored, err := s.loadStoredPolicy(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.syncSection(ctx, idpSection, stored, nil)
|
|
}
|
|
|
|
func (s *SourceSyncer) SyncSectionWithRemoved(ctx context.Context, idpSection map[string]any, removed []map[string]any) error {
|
|
stored, err := s.loadStoredPolicy(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return s.syncSection(ctx, idpSection, stored, removed)
|
|
}
|
|
|
|
func (s *SourceSyncer) syncSection(ctx context.Context, idpSection, fullStored map[string]any, removed []map[string]any) error {
|
|
client := s.client()
|
|
if client == nil {
|
|
return s.markAllPending(ctx, idpSection, "Authentik API not configured")
|
|
}
|
|
|
|
if err := client.Ping(ctx); err != nil {
|
|
return s.markAllPending(ctx, idpSection, err.Error())
|
|
}
|
|
|
|
authFlow, enrollmentFlow, err := client.ResolveSourceFlows(ctx)
|
|
if err != nil {
|
|
return s.markAllPending(ctx, idpSection, err.Error())
|
|
}
|
|
|
|
allowSelfEnrollment, _ := idpSection["allow_self_enrollment"].(bool)
|
|
if err := s.syncEnrollment(ctx, client, allowSelfEnrollment); err != nil {
|
|
return err
|
|
}
|
|
|
|
providers, _ := idpSection["providers"].([]any)
|
|
now := time.Now().UTC().Format(time.RFC3339)
|
|
|
|
for _, item := range providers {
|
|
pm, ok := item.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
id, _ := pm["id"].(string)
|
|
if id == "" {
|
|
continue
|
|
}
|
|
pk, syncErr := s.syncProvider(ctx, client, pm, authFlow, enrollmentFlow, fullStored)
|
|
if syncErr != nil {
|
|
pm["sync_status"] = "error"
|
|
pm["sync_error"] = syncErr.Error()
|
|
} else {
|
|
pm["sync_status"] = "synced"
|
|
pm["sync_error"] = ""
|
|
pm["last_synced_at"] = now
|
|
if pk != "" {
|
|
if n, err := strconv.Atoi(pk); err == nil {
|
|
pm["authentik_pk"] = n
|
|
} else {
|
|
pm["authentik_pk"] = pk
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := s.deleteRemovedSources(ctx, client, removed); err != nil {
|
|
return err
|
|
}
|
|
|
|
idpSection["providers"] = providers
|
|
fullStored["identity_providers"] = idpSection
|
|
return s.persistPolicy(ctx, fullStored)
|
|
}
|
|
|
|
func (s *SourceSyncer) syncEnrollment(ctx context.Context, client *Client, allow bool) error {
|
|
brand, err := client.FindDefaultBrand(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if allow {
|
|
flow, found, err := client.FindFlowInstanceBySlug(ctx, "ulti-enrollment")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if !found {
|
|
return nil
|
|
}
|
|
return client.SetBrandEnrollmentFlow(ctx, brand.PK, &flow.PK)
|
|
}
|
|
return client.SetBrandEnrollmentFlow(ctx, brand.PK, nil)
|
|
}
|
|
|
|
func (s *SourceSyncer) syncProvider(ctx context.Context, client *Client, pm map[string]any, authFlow, enrollmentFlow string, fullStored map[string]any) (string, error) {
|
|
providerType, _ := pm["type"].(string)
|
|
enabled, _ := pm["enabled"].(bool)
|
|
slug, _ := pm["slug"].(string)
|
|
name, _ := pm["name"].(string)
|
|
if slug == "" {
|
|
return "", fmt.Errorf("missing slug")
|
|
}
|
|
if name == "" {
|
|
name = slug
|
|
}
|
|
|
|
pk := authentikPK(pm["authentik_pk"])
|
|
switch providerType {
|
|
case "oauth":
|
|
req, err := buildOAuthSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if pk == "" {
|
|
createdPK, err := client.CreateOAuthSource(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return createdPK, nil
|
|
}
|
|
return pk, client.UpdateOAuthSource(ctx, pk, req)
|
|
case "saml":
|
|
req := buildSAMLSourceRequest(pm, authFlow, enrollmentFlow, enabled, slug, name)
|
|
if pk == "" {
|
|
createdPK, err := client.CreateSAMLSource(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return createdPK, nil
|
|
}
|
|
return pk, client.UpdateSAMLSource(ctx, pk, req)
|
|
case "ldap":
|
|
req, err := buildLDAPSourceRequest(pm, fullStored, authFlow, enrollmentFlow, enabled, slug, name)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if pk == "" {
|
|
createdPK, err := client.CreateLDAPSource(ctx, req)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return createdPK, nil
|
|
}
|
|
return pk, client.UpdateLDAPSource(ctx, pk, req)
|
|
default:
|
|
return "", fmt.Errorf("unsupported provider type %q", providerType)
|
|
}
|
|
}
|
|
|
|
func buildOAuthSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (OAuthSourceRequest, error) {
|
|
oauth, _ := pm["oauth"].(map[string]any)
|
|
providerName, _ := oauth["provider"].(string)
|
|
preset := OAuthPresetFor(providerName)
|
|
|
|
clientID, _ := oauth["client_id"].(string)
|
|
clientSecret := resolveProviderSecret(pm, fullStored, "oauth", "client_secret")
|
|
|
|
authURL := stringOr(oauth["authorization_url"], preset.AuthorizationURL)
|
|
tokenURL := stringOr(oauth["token_url"], preset.AccessTokenURL)
|
|
profileURL := stringOr(oauth["profile_url"], preset.ProfileURL)
|
|
scopes := stringOr(oauth["scopes"], preset.DefaultScopes)
|
|
providerType := preset.ProviderType
|
|
if providerName == "custom" {
|
|
providerType = "openidconnect"
|
|
}
|
|
|
|
if strings.TrimSpace(clientID) == "" {
|
|
return OAuthSourceRequest{}, fmt.Errorf("oauth client_id required")
|
|
}
|
|
if strings.TrimSpace(clientSecret) == "" {
|
|
return OAuthSourceRequest{}, fmt.Errorf("oauth client_secret required")
|
|
}
|
|
|
|
return OAuthSourceRequest{
|
|
Name: name,
|
|
Slug: slug,
|
|
Enabled: enabled,
|
|
ProviderType: providerType,
|
|
ClientID: clientID,
|
|
ClientSecret: clientSecret,
|
|
AuthorizationURL: authURL,
|
|
AccessTokenURL: tokenURL,
|
|
ProfileURL: profileURL,
|
|
Scopes: scopes,
|
|
AuthenticationFlow: authFlow,
|
|
EnrollmentFlow: enrollmentFlow,
|
|
}, nil
|
|
}
|
|
|
|
func buildSAMLSourceRequest(pm map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) SAMLSourceRequest {
|
|
saml, _ := pm["saml"].(map[string]any)
|
|
return SAMLSourceRequest{
|
|
Name: name,
|
|
Slug: slug,
|
|
Enabled: enabled,
|
|
MetadataURL: stringValue(saml["metadata_url"]),
|
|
MetadataXML: stringValue(saml["metadata_xml"]),
|
|
EntityID: stringValue(saml["entity_id"]),
|
|
SSOURL: stringValue(saml["sso_url"]),
|
|
SLOURL: stringValue(saml["slo_url"]),
|
|
SigningCert: stringValue(saml["signing_cert"]),
|
|
AuthenticationFlow: authFlow,
|
|
EnrollmentFlow: enrollmentFlow,
|
|
}
|
|
}
|
|
|
|
func buildLDAPSourceRequest(pm, fullStored map[string]any, authFlow, enrollmentFlow string, enabled bool, slug, name string) (LDAPSourceRequest, error) {
|
|
ldap, _ := pm["ldap"].(map[string]any)
|
|
bindPassword := resolveProviderSecret(pm, fullStored, "ldap", "bind_password")
|
|
if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
|
|
return LDAPSourceRequest{}, fmt.Errorf("ldap server_uri required")
|
|
}
|
|
if strings.TrimSpace(stringValue(ldap["bind_dn"])) == "" {
|
|
return LDAPSourceRequest{}, fmt.Errorf("ldap bind_dn required")
|
|
}
|
|
if strings.TrimSpace(bindPassword) == "" {
|
|
return LDAPSourceRequest{}, fmt.Errorf("ldap bind_password required")
|
|
}
|
|
return LDAPSourceRequest{
|
|
Name: name,
|
|
Slug: slug,
|
|
Enabled: enabled,
|
|
ServerURI: stringValue(ldap["server_uri"]),
|
|
BindDN: stringValue(ldap["bind_dn"]),
|
|
BindPassword: bindPassword,
|
|
BaseDN: stringValue(ldap["base_dn"]),
|
|
UserFilter: stringValue(ldap["user_filter"]),
|
|
StartTLS: boolValue(ldap["start_tls"]),
|
|
SyncUsers: boolValue(ldap["sync_users"]),
|
|
AuthenticationFlow: authFlow,
|
|
EnrollmentFlow: enrollmentFlow,
|
|
}, nil
|
|
}
|
|
|
|
func (s *SourceSyncer) deleteRemovedSources(ctx context.Context, client *Client, removed []map[string]any) error {
|
|
for _, pm := range removed {
|
|
pk := authentikPK(pm["authentik_pk"])
|
|
if pk == "" {
|
|
continue
|
|
}
|
|
providerType, _ := pm["type"].(string)
|
|
switch providerType {
|
|
case "oauth":
|
|
_ = client.DeleteOAuthSource(ctx, pk)
|
|
case "saml":
|
|
_ = client.DeleteSAMLSource(ctx, pk)
|
|
case "ldap":
|
|
_ = client.DeleteLDAPSource(ctx, pk)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// RemovedIdentityProviders returns the provider maps listed under
// before["identity_providers"]["providers"] whose non-empty "id" no longer
// appears among the providers in after.
//
// Non-map entries and providers without an "id" are ignored. Missing or
// malformed sections are treated as empty. The result preserves the order of
// before and is never nil.
func RemovedIdentityProviders(before, after map[string]any) []map[string]any {
	providerList := func(policy map[string]any) []any {
		section, _ := policy["identity_providers"].(map[string]any)
		list, _ := section["providers"].([]any)
		return list
	}
	oldList := providerList(before)
	newList := providerList(after)

	kept := make(map[string]struct{}, len(newList))
	for _, entry := range newList {
		if m, isMap := entry.(map[string]any); isMap {
			if key, _ := m["id"].(string); key != "" {
				kept[key] = struct{}{}
			}
		}
	}

	result := make([]map[string]any, 0)
	for _, entry := range oldList {
		m, isMap := entry.(map[string]any)
		if !isMap {
			continue
		}
		key, _ := m["id"].(string)
		if key == "" {
			continue
		}
		if _, found := kept[key]; found {
			continue
		}
		result = append(result, m)
	}
	return result
}
|
|
|
|
func (s *SourceSyncer) markAllPending(ctx context.Context, idpSection map[string]any, reason string) error {
|
|
providers, _ := idpSection["providers"].([]any)
|
|
for _, item := range providers {
|
|
pm, ok := item.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
pm["sync_status"] = "error"
|
|
pm["sync_error"] = reason
|
|
}
|
|
idpSection["providers"] = providers
|
|
stored, err := s.loadStoredPolicy(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
stored["identity_providers"] = idpSection
|
|
return s.persistPolicy(ctx, stored)
|
|
}
|
|
|
|
func (s *SourceSyncer) loadStoredPolicy(ctx context.Context) (map[string]any, error) {
|
|
var raw []byte
|
|
err := s.db.QueryRow(ctx, `SELECT settings FROM org_settings WHERE id = $1`, orgSettingsSingletonID).Scan(&raw)
|
|
if err != nil && err != pgx.ErrNoRows {
|
|
return nil, err
|
|
}
|
|
stored := map[string]any{}
|
|
if len(raw) > 0 {
|
|
if err := json.Unmarshal(raw, &stored); err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
return stored, nil
|
|
}
|
|
|
|
func (s *SourceSyncer) persistPolicy(ctx context.Context, policy map[string]any) error {
|
|
raw, err := json.Marshal(policy)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = s.db.Exec(ctx, `
|
|
UPDATE org_settings SET settings = $2, updated_at = NOW() WHERE id = $1
|
|
`, orgSettingsSingletonID, raw)
|
|
return err
|
|
}
|
|
|
|
func (s *SourceSyncer) RedirectURI(slug string) string {
|
|
return OAuthRedirectURI(AuthentikPublicBaseURL(s.cfg.AuthentikAPIURL, s.cfg.AuthentikPublicHTTPS), slug)
|
|
}
|
|
|
|
func (s *SourceSyncer) TestProvider(ctx context.Context, providerID string) error {
|
|
stored, err := s.loadStoredPolicy(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
pm, err := findProviderByID(stored, providerID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
client := s.client()
|
|
if client == nil {
|
|
return fmt.Errorf("Authentik API not configured")
|
|
}
|
|
if err := client.Ping(ctx); err != nil {
|
|
return err
|
|
}
|
|
providerType, _ := pm["type"].(string)
|
|
switch providerType {
|
|
case "ldap":
|
|
ldap, _ := pm["ldap"].(map[string]any)
|
|
if strings.TrimSpace(stringValue(ldap["server_uri"])) == "" {
|
|
return fmt.Errorf("ldap server_uri required")
|
|
}
|
|
case "saml":
|
|
saml, _ := pm["saml"].(map[string]any)
|
|
if strings.TrimSpace(stringValue(saml["metadata_url"])) == "" &&
|
|
strings.TrimSpace(stringValue(saml["sso_url"])) == "" {
|
|
return fmt.Errorf("saml metadata_url or sso_url required")
|
|
}
|
|
case "oauth":
|
|
oauth, _ := pm["oauth"].(map[string]any)
|
|
if strings.TrimSpace(stringValue(oauth["client_id"])) == "" {
|
|
return fmt.Errorf("oauth client_id required")
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// findProviderByID returns the provider map under
// stored["identity_providers"]["providers"] whose "id" equals id. If there
// is none, it returns an error. Non-map entries are skipped.
// Note that an empty id matches the first provider map without a string "id".
func findProviderByID(stored map[string]any, id string) (map[string]any, error) {
	section, _ := stored["identity_providers"].(map[string]any)
	list, _ := section["providers"].([]any)

	for i := 0; i < len(list); i++ {
		candidate, isMap := list[i].(map[string]any)
		if !isMap {
			continue
		}
		if candidateID, _ := candidate["id"].(string); candidateID == id {
			return candidate, nil
		}
	}
	return nil, fmt.Errorf("provider not found")
}
|
|
|
|
func resolveProviderSecret(pm, fullStored map[string]any, section, key string) string {
|
|
nested, _ := pm[section].(map[string]any)
|
|
if val := strings.TrimSpace(stringValue(nested[key])); val != "" {
|
|
return val
|
|
}
|
|
id, _ := pm["id"].(string)
|
|
if id == "" {
|
|
return ""
|
|
}
|
|
existing, _ := fullStored["identity_providers"].(map[string]any)
|
|
existingProviders, _ := existing["providers"].([]any)
|
|
for _, item := range existingProviders {
|
|
em, ok := item.(map[string]any)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if eid, _ := em["id"].(string); eid != id {
|
|
continue
|
|
}
|
|
esec, _ := em[section].(map[string]any)
|
|
return strings.TrimSpace(stringValue(esec[key]))
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// authentikPK normalizes an "authentik_pk" value from a policy map into its
// string form. It returns "" when the value is absent or of an unsupported
// type; syncProvider treats "" as "no Authentik source yet".
//
// After a JSON round trip the value is a float64. It is an int when
// syncSection set it in the same process, and it may also be a string.
// json.Number (from decoders using UseNumber) and int32 are accepted too.
// Integers are formatted through int64 so large keys are not truncated on
// 32-bit platforms, where the previous int conversion could overflow.
func authentikPK(v any) string {
	switch t := v.(type) {
	case string:
		return strings.TrimSpace(t)
	case json.Number:
		return strings.TrimSpace(t.String())
	case float64:
		return strconv.FormatInt(int64(t), 10)
	case int:
		return strconv.Itoa(t)
	case int32:
		return strconv.FormatInt(int64(t), 10)
	case int64:
		return strconv.FormatInt(t, 10)
	default:
		return ""
	}
}
|
|
|
|
// stringValue returns v unchanged when it holds a string, and "" for
// anything else, including nil.
func stringValue(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	return ""
}
|
|
|
|
func stringOr(v any, fallback string) string {
|
|
if s := strings.TrimSpace(stringValue(v)); s != "" {
|
|
return s
|
|
}
|
|
return fallback
|
|
}
|
|
|
|
// boolValue reports whether v holds the boolean true. nil, false and
// non-bool values all yield false.
func boolValue(v any) bool {
	switch b := v.(type) {
	case bool:
		return b
	default:
		return false
	}
}
|