package orgpolicy

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"math"
	"strings"
	"sync"
	"time"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"

	"github.com/ultisuite/ulti-backend/internal/config"
)

const (
	// orgSettingsSingletonID is the id of the org_settings row the loader reads.
	orgSettingsSingletonID = 1

	// defaultMaxUploadMiB is the upload limit, in MiB, used when the stored
	// settings carry no usable "max_upload_mib" value.
	defaultMaxUploadMiB = 512

	// uploadBytesPerMiB converts a MiB count into bytes.
	uploadBytesPerMiB = 1024 * 1024

	// maxUploadMiBCap is the largest MiB count whose size in bytes still fits
	// in an int64; larger configured values are clamped to it.
	maxUploadMiBCap = math.MaxInt64 / uploadBytesPerMiB
)

// FilePolicies holds runtime file upload policy from org settings.
type FilePolicies struct {
	// VirusScanEnabled mirrors the "virus_scan_enabled" setting.
	VirusScanEnabled bool
	// VirusTotalAPIKey is the trimmed "virustotal_api_key" setting, or the
	// trimmed key from the process config when the setting is blank.
	VirusTotalAPIKey string
	// MaxUploadBytes is the upload size limit in bytes.
	MaxUploadBytes int64
}

// Loader reads policies from the org_settings table and caches each policy
// kind separately for ttl.
type Loader struct {
	db  *pgxpool.Pool
	cfg *config.Config

	// mu guards the cache fields below. It is never held across a database
	// query.
	mu           sync.Mutex
	cached       FilePolicies
	cachedAt     time.Time
	authCached   AuthAccessPolicy
	authCachedAt time.Time
	ttl          time.Duration
}

// NewLoader returns a Loader that reads from db, falls back to cfg for the
// VirusTotal API key, and caches loaded policies for 60 seconds.
func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader {
	return &Loader{
		db:  db,
		cfg: cfg,
		ttl: 60 * time.Second,
	}
}

// FilePolicies returns the file upload policy, served from the cache while the
// cached value is younger than the loader's ttl and reloaded from the database
// otherwise. A failed reload returns the error and leaves the cache untouched.
func (l *Loader) FilePolicies(ctx context.Context) (FilePolicies, error) {
	l.mu.Lock()
	if !l.cachedAt.IsZero() && time.Since(l.cachedAt) < l.ttl {
		out := l.cached
		l.mu.Unlock()
		return out, nil
	}
	l.mu.Unlock()

	// The lock is released while querying, so callers that observe an expired
	// cache at the same time may each run the query; the last writer wins.
	fp, err := l.loadFilePolicies(ctx)
	if err != nil {
		return FilePolicies{}, err
	}

	l.mu.Lock()
	l.cached = fp
	l.cachedAt = time.Now()
	l.mu.Unlock()
	return fp, nil
}

// loadStoredSettings fetches the settings JSON of the singleton org_settings
// row and decodes it into a generic map. A missing row or an empty column
// yields an empty map rather than an error.
func (l *Loader) loadStoredSettings(ctx context.Context) (map[string]any, error) {
	var raw []byte
	err := l.db.QueryRow(ctx,
		`SELECT settings FROM org_settings WHERE id = $1`,
		orgSettingsSingletonID,
	).Scan(&raw)
	// errors.Is also matches a wrapped ErrNoRows, unlike a direct comparison.
	if err != nil && !errors.Is(err, pgx.ErrNoRows) {
		return nil, fmt.Errorf("querying org settings: %w", err)
	}

	stored := map[string]any{}
	if len(raw) > 0 {
		if err := json.Unmarshal(raw, &stored); err != nil {
			return nil, fmt.Errorf("decoding org settings: %w", err)
		}
	}
	return stored, nil
}

// loadFilePolicies builds FilePolicies from the "file_policies" section of the
// stored settings, bypassing the cache.
func (l *Loader) loadFilePolicies(ctx context.Context) (FilePolicies, error) {
	stored, err := l.loadStoredSettings(ctx)
	if err != nil {
		return FilePolicies{}, err
	}

	// A missing or mistyped section leaves section nil; reading from a nil map
	// is safe and makes every lookup below fall back to its default.
	section, _ := stored["file_policies"].(map[string]any)

	// Trim the stored key the same way as the config fallback so both sources
	// produce a key without surrounding whitespace.
	apiKey := strings.TrimSpace(stringValue(section["virustotal_api_key"]))
	if apiKey == "" && l.cfg != nil {
		apiKey = strings.TrimSpace(l.cfg.VirusTotalAPIKey)
	}

	return FilePolicies{
		VirusScanEnabled: boolValue(section["virus_scan_enabled"]),
		VirusTotalAPIKey: apiKey,
		MaxUploadBytes:   uploadLimitMiB(section["max_upload_mib"]) * uploadBytesPerMiB,
	}, nil
}

// uploadLimitMiB converts a decoded "max_upload_mib" setting into a whole,
// positive MiB count. Missing, non-numeric and non-positive values yield
// defaultMaxUploadMiB. Positive fractions below one are raised to 1 so they
// do not truncate to a zero-byte limit, other fractional values are truncated,
// and values above maxUploadMiBCap are clamped so the byte size cannot
// overflow int64.
func uploadLimitMiB(v any) int64 {
	var mib int64
	switch t := v.(type) {
	case float64:
		// encoding/json decodes every JSON number stored in an any as float64.
		switch {
		case !(t > 0): // non-positive, or NaN
			return defaultMaxUploadMiB
		case t >= float64(maxUploadMiBCap):
			return maxUploadMiBCap
		case t < 1:
			return 1
		}
		return int64(t)
	case int:
		mib = int64(t)
	case int64:
		mib = t
	default:
		return defaultMaxUploadMiB
	}

	if mib <= 0 {
		return defaultMaxUploadMiB
	}
	if mib > maxUploadMiBCap {
		return maxUploadMiBCap
	}
	return mib
}

// ScanEnabled reports whether virus scanning should run and, if so, the API
// key to use. Scanning counts as disabled when the policy flag is off or when
// no non-blank API key is available; in that case the returned key is empty.
func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) {
	fp, err := l.FilePolicies(ctx)
	if err != nil {
		return false, "", err
	}
	if !fp.VirusScanEnabled || strings.TrimSpace(fp.VirusTotalAPIKey) == "" {
		return false, "", nil
	}
	return true, fp.VirusTotalAPIKey, nil
}

// AuthAccessPolicy returns the authentication access policy, served from the
// cache while the cached value is younger than the loader's ttl and reloaded
// from the database otherwise. A failed reload returns the error and leaves
// the cache untouched.
func (l *Loader) AuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) {
	l.mu.Lock()
	if !l.authCachedAt.IsZero() && time.Since(l.authCachedAt) < l.ttl {
		out := l.authCached
		l.mu.Unlock()
		return out, nil
	}
	l.mu.Unlock()

	// As in FilePolicies, the query runs without holding the lock.
	policy, err := l.loadAuthAccessPolicy(ctx)
	if err != nil {
		return AuthAccessPolicy{}, err
	}

	l.mu.Lock()
	l.authCached = policy
	l.authCachedAt = time.Now()
	l.mu.Unlock()
	return policy, nil
}

// loadAuthAccessPolicy builds an AuthAccessPolicy from the
// "identity_providers" section of the stored settings, bypassing the cache.
// Self-enrollment is allowed unless the section explicitly sets
// "allow_self_enrollment" to a boolean false. Entries of "providers" that are
// not JSON objects are skipped.
func (l *Loader) loadAuthAccessPolicy(ctx context.Context) (AuthAccessPolicy, error) {
	stored, err := l.loadStoredSettings(ctx)
	if err != nil {
		return AuthAccessPolicy{}, err
	}

	idp, _ := stored["identity_providers"].(map[string]any)
	if idp == nil {
		return AuthAccessPolicy{AllowSelfEnrollment: true}, nil
	}

	allowSelfEnrollment := true
	if v, ok := idp["allow_self_enrollment"].(bool); ok {
		allowSelfEnrollment = v
	}

	providersRaw, _ := idp["providers"].([]any)
	providers := make([]IdentityProviderPolicy, 0, len(providersRaw))
	for _, item := range providersRaw {
		pm, ok := item.(map[string]any)
		if !ok {
			continue
		}
		providers = append(providers, IdentityProviderPolicy{
			ID:                   stringValue(pm["id"]),
			Slug:                 stringValue(pm["slug"]),
			Type:                 stringValue(pm["type"]),
			Enabled:              boolValue(pm["enabled"]),
			AllowedEmailDomains:  stringSlice(pm["allowed_email_domains"]),
			AllowedIdentities:    stringSlice(pm["allowed_identities"]),
			AllowedOrganizations: stringSlice(pm["allowed_organizations"]),
		})
	}

	return AuthAccessPolicy{
		AllowSelfEnrollment: allowSelfEnrollment,
		Providers:           providers,
	}, nil
}

// boolValue returns v when it holds a bool and false for anything else,
// including nil.
func boolValue(v any) bool {
	b, _ := v.(bool)
	return b
}

// stringValue returns v when it holds a string and "" for anything else,
// including nil.
func stringValue(v any) string {
	s, _ := v.(string)
	return s
}