ultisuite-backend/internal/migration/service.go
R3D347HR4Y 7143a36c19
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run
feat(mail): integrate Stalwart hosted mail and migration features
- Added configuration options for Stalwart hosted mail in .env.example.
- Updated Docker Compose to include Stalwart service with health checks.
- Introduced new API endpoints for managing mail domains and migration projects.
- Enhanced Authentik blueprints for user enrollment and post-migration security.
- Updated OAuth handling for Google and Microsoft migration processes.
- Improved error handling and response structures in the mail API.
- Added integration tests for email claiming and migration workflows.
2026-06-13 12:47:08 +02:00

500 lines
15 KiB
Go

package migration
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/redis/go-redis/v9"
"golang.org/x/oauth2"
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
)
// ErrInviteNotFound reports that no migration invite matches the supplied
// token. GetInviteByToken returns it, and ClaimInvite passes it through.
var ErrInviteNotFound = errors.New("migration invite not found")

// ErrInviteClaimed is returned by ClaimInvite when the invite has already
// been claimed.
var ErrInviteClaimed = errors.New("migration invite already claimed")

// ErrEmailMismatch is returned by ClaimInvite when the claimer's identity
// does not match the invite's address(es).
var ErrEmailMismatch = errors.New("email does not match invite")

// ErrMigrationDomainNotActive is returned by ClaimInvite when the project's
// hosted mail domain cannot receive mailboxes yet.
var ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active")

// ErrMigrationDomainMismatch is returned by ClaimInvite when the invite
// address does not belong to the project's hosted mail domain.
var ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain")
// Service implements migration projects, invites, per-user migration jobs and
// stored source-provider OAuth credentials on top of PostgreSQL (pgx).
type Service struct {
// db is the pool used for every query issued by this service.
db *pgxpool.Pool
// rdb is not referenced by the code in this file — presumably used by other
// methods elsewhere in the package; TODO confirm.
rdb *redis.Client
// creds encrypts OAuth tokens in StoreMigrationToken; when nil that method
// returns an error.
creds *credentials.Manager
// hosted resolves mail domains and provisions mailboxes in ClaimInvite. It may
// be nil, in which case those steps are skipped.
hosted *hosted.Service
// oauth is not referenced in this file — presumably used by the OAuth flow
// code elsewhere in the package; TODO confirm.
oauth *OAuthService
// cutover is replaced via SetCutoverConfig; it is not read in this file.
cutover CutoverConfig
}
// NewService wires a migration Service to its dependencies.
//
// hostedSvc may be nil: ClaimInvite then skips domain lookup and mailbox
// provisioning. creds may be nil, but StoreMigrationToken will then fail.
// Cutover settings start as the zero value; use SetCutoverConfig to set them.
func NewService(db *pgxpool.Pool, rdb *redis.Client, creds *credentials.Manager, hostedSvc *hosted.Service, oauth *OAuthService) *Service {
	svc := new(Service)
	svc.db = db
	svc.rdb = rdb
	svc.creds = creds
	svc.hosted = hostedSvc
	svc.oauth = oauth
	return svc
}
// SetCutoverConfig replaces the service's cutover settings. The assignment is
// not synchronized, so call it during setup, before the Service is shared
// between goroutines.
func (s *Service) SetCutoverConfig(config CutoverConfig) {
	s.cutover = config
}
// Project is a migration project as returned by the project queries in this
// service and serialized to API clients. Declaration order is the JSON key
// order. Which fields are filled depends on projectSelectSQL and
// newProjectScanner, which are defined elsewhere in this package.
type Project struct {
ID string `json:"id"`
// DomainID is the hosted mail domain the project migrates into. It is empty
// when the project has none (CreateProject stores an empty ID as NULL).
DomainID string `json:"domain_id,omitempty"`
Name string `json:"name"`
// SourceProvider is trimmed and lower-cased by CreateProject, which defaults
// it to "google" when blank.
SourceProvider string `json:"source_provider"`
// AuthMode is normalized per provider via NormalizeAuthMode on creation.
AuthMode string `json:"auth_mode"`
// Status values seen in this file include 'active' (set by ActivateProject)
// and 'cutover'. PendingJobs only returns jobs of projects in those states.
Status string `json:"status"`
CutoverAt *string `json:"cutover_at,omitempty"`
DeltaMode bool `json:"delta_mode"`
CreatedAt string `json:"created_at"`
MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"`
MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"`
MicrosoftAdminConsentError string `json:"microsoft_admin_consent_error,omitempty"`
// CutoverDNS is presumably populated outside the plain project queries
// (e.g. by cutover DNS checks) — TODO confirm.
CutoverDNS *hosted.DNSCheckReport `json:"cutover_dns,omitempty"`
}
// Invite is a per-address invitation to join a migration project.
type Invite struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
// Email is the primary invited address; CreateInvite stores it trimmed and
// lower-cased.
Email string `json:"email"`
// AlternateEmails are extra addresses (cleaned by normalizeAlternateEmails)
// that ClaimInvite passes to InviteEmailMatchesIdentity alongside Email.
AlternateEmails []string `json:"alternate_emails,omitempty"`
// Token is the secret claim token. Only CreateInvite returns it; the other
// invite lookups in this file do not select it, so it is omitted there.
Token string `json:"token,omitempty"`
// Status: ClaimInvite moves an invite from 'invited' to 'claimed'.
Status string `json:"status"`
// ClaimedAt is claimed_at rendered as text; nil when the column is NULL
// (presumably before the invite is claimed).
ClaimedAt *string `json:"claimed_at,omitempty"`
// UserID is the claiming user's ID, or empty when user_id is NULL.
UserID string `json:"user_id,omitempty"`
}
// Job tracks the migration of one service for one user in one project.
type Job struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
UserID string `json:"user_id"`
// Service is one of the services ClaimInvite creates jobs for:
// "mail", "contacts", "calendar" or "drive".
Service string `json:"service"`
// Status starts as 'pending'. UpdateJobProgress writes later values such as
// 'running', 'completed' or 'failed'.
Status string `json:"status"`
// CursorJSON is presumably worker resume state; its schema is not defined
// here. The readers in this file turn NULL or undecodable JSON into an empty
// map, so it is never nil when read back.
CursorJSON map[string]any `json:"cursor_json"`
// StatsJSON holds worker-reported statistics (schema not defined here); it is
// never nil when read back, for the same reason as CursorJSON.
StatsJSON map[string]any `json:"stats_json"`
// Error is the last error text passed to UpdateJobProgress.
Error string `json:"error,omitempty"`
// StartedAt is set the first time the job is reported as 'running'.
StartedAt *string `json:"started_at,omitempty"`
// CompletedAt is set whenever the job is reported 'completed' or 'failed'.
CompletedAt *string `json:"completed_at,omitempty"`
}
// UserStatus is the per-user migration view built by GetUserStatus. The zero
// value is returned by GetActiveUserStatus for users with no claimed invite.
type UserStatus struct {
Project Project `json:"project"`
// NOTE(review): encoding/json never treats a struct value as empty, so
// omitempty has no effect here. "invite" is always emitted, as a zero-valued
// object when the user has no invite. Switching to *Invite or Go 1.24's
// omitzero would change the API output — coordinate with clients first.
Invite Invite `json:"invite,omitempty"`
// Jobs may be nil (no jobs yet), which encodes as JSON null.
Jobs []Job `json:"jobs"`
Onboarding OnboardingHints `json:"onboarding"`
}
// CreateProject inserts a new migration project and returns the stored row.
//
// name is trimmed and must be non-empty. sourceProvider is trimmed and
// lower-cased, defaulting to "google" when blank. domainID is optional: an
// empty or whitespace-only value stores a NULL domain. authMode is normalized
// for the provider via NormalizeAuthMode.
//
// On error the returned Project is the zero value.
func (s *Service) CreateProject(ctx context.Context, name, sourceProvider, domainID, authMode string) (Project, error) {
	name = strings.TrimSpace(name)
	if name == "" {
		return Project{}, errors.New("project name required")
	}
	sourceProvider = strings.ToLower(strings.TrimSpace(sourceProvider))
	if sourceProvider == "" {
		sourceProvider = "google"
	}
	authMode = NormalizeAuthMode(sourceProvider, authMode)
	// Trim so a whitespace-only domain ID becomes NULL via NULLIF instead of
	// failing the ::uuid cast.
	domainID = strings.TrimSpace(domainID)

	sc := newProjectScanner()
	err := s.db.QueryRow(ctx, `
INSERT INTO migration_projects (name, source_provider, domain_id, auth_mode)
VALUES ($1, $2, NULLIF($3, '')::uuid, $4)
RETURNING `+projectSelectSQL("")+`
`, name, sourceProvider, domainID, authMode).Scan(sc.targets()...)
	if err != nil {
		// Don't hand back a half-scanned row.
		return Project{}, err
	}
	return sc.result(), nil
}
// ListProjects returns every migration project, newest first (by created_at).
// It returns a nil slice when there are none. If iteration fails part-way,
// the rows read so far are returned together with the error.
func (s *Service) ListProjects(ctx context.Context) ([]Project, error) {
	rows, queryErr := s.db.Query(ctx, `
SELECT `+projectSelectSQL("")+`
FROM migration_projects ORDER BY created_at DESC
`)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var projects []Project
	for rows.Next() {
		scanner := newProjectScanner()
		if scanErr := rows.Scan(scanner.targets()...); scanErr != nil {
			return nil, scanErr
		}
		projects = append(projects, scanner.result())
	}
	iterErr := rows.Err()
	return projects, iterErr
}
// CreateInvite issues a fresh invite token for email in projectID and returns
// the stored invite, including the token.
//
// email is trimmed and lower-cased and must be non-empty. alternateEmails
// goes through normalizeAlternateEmails, which drops the primary address,
// repeats, and entries that are not valid addresses.
func (s *Service) CreateInvite(ctx context.Context, projectID, email string, alternateEmails []string) (Invite, error) {
	primary := strings.ToLower(strings.TrimSpace(email))
	if primary == "" {
		return Invite{}, fmt.Errorf("email required")
	}
	extras := normalizeAlternateEmails(primary, alternateEmails)

	token, tokenErr := hosted.NewInviteToken()
	if tokenErr != nil {
		return Invite{}, tokenErr
	}

	var inv Invite
	dest := []any{
		&inv.ID, &inv.ProjectID, &inv.Email, &inv.Token, &inv.Status,
		&inv.ClaimedAt, &inv.UserID, &inv.AlternateEmails,
	}
	insertErr := s.db.QueryRow(ctx, `
INSERT INTO migration_invites (project_id, email, token, alternate_emails)
VALUES ($1::uuid, $2, $3, $4)
RETURNING id::text, project_id::text, email, token, status, claimed_at::text, COALESCE(user_id::text, ''), alternate_emails
`, projectID, primary, token, extras).Scan(dest...)
	return inv, insertErr
}
// normalizeAlternateEmails cleans the extra addresses attached to an invite.
//
// Each entry is passed through normalizeInviteEmail. The following are
// dropped: blank results, values rejected by isEmailAddress, repeats, and
// copies of the invite's own (normalized) address. Input order is kept. The
// result is nil when nothing survives.
func normalizeAlternateEmails(inviteEmail string, alternateEmails []string) []string {
	seen := make(map[string]struct{}, len(alternateEmails)+1)
	seen[normalizeInviteEmail(inviteEmail)] = struct{}{}

	// accept reports whether addr is a new, valid address and records it.
	accept := func(addr string) bool {
		if addr == "" || !isEmailAddress(addr) {
			return false
		}
		if _, dup := seen[addr]; dup {
			return false
		}
		seen[addr] = struct{}{}
		return true
	}

	var kept []string
	for _, candidate := range alternateEmails {
		if addr := normalizeInviteEmail(candidate); accept(addr) {
			kept = append(kept, addr)
		}
	}
	return kept
}
// ImportInvites creates one invite (without alternate addresses) for each
// non-blank entry of emails, in order. It stops at the first failure and
// returns how many invites were created before it. Invites are inserted one
// at a time, so earlier invites remain after a failure.
func (s *Service) ImportInvites(ctx context.Context, projectID string, emails []string) (int, error) {
	created := 0
	for i := range emails {
		addr := strings.ToLower(strings.TrimSpace(emails[i]))
		if addr == "" {
			continue
		}
		if _, err := s.CreateInvite(ctx, projectID, addr, nil); err != nil {
			return created, err
		}
		created++
	}
	return created, nil
}
// GetInviteByToken loads the invite identified by token together with its
// migration project. The returned Invite's Token field is not populated.
//
// It returns ErrInviteNotFound when no invite matches, including for a blank
// token, which is rejected without querying. On any error, both returned
// values are zero.
func (s *Service) GetInviteByToken(ctx context.Context, token string) (Invite, Project, error) {
	if strings.TrimSpace(token) == "" {
		return Invite{}, Project{}, ErrInviteNotFound
	}
	var inv Invite
	sc := newProjectScanner()
	// Invite columns first, then the project columns in projectSelectSQL order.
	scanArgs := append([]any{
		&inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID, &inv.AlternateEmails,
	}, sc.targets()...)
	err := s.db.QueryRow(ctx, `
SELECT i.id::text, i.project_id::text, i.email, i.status, i.claimed_at::text, COALESCE(i.user_id::text, ''), i.alternate_emails,
`+projectSelectSQL("p")+`
FROM migration_invites i
JOIN migration_projects p ON p.id = i.project_id
WHERE i.token = $1
`, token).Scan(scanArgs...)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		return Invite{}, Project{}, ErrInviteNotFound
	case err != nil:
		// Don't leak a partially scanned invite/project.
		return Invite{}, Project{}, err
	}
	return inv, sc.result(), nil
}
// ClaimInvite binds the invite identified by token to userID and prepares the
// user's migration, returning the resulting status.
//
// Steps:
//  1. Load the invite and project; reject an invite that is already claimed.
//  2. If the project has a mail domain and a hosted mail service is
//     configured, load the domain.
//  3. Check identity against the invite (InviteEmailMatchesIdentity).
//  4. Check that the mailbox address belongs to that domain, and that the
//     domain is active or a platform domain.
//  5. In a transaction:
//     - atomically move the invite from 'invited' to 'claimed';
//     - provision the hosted mailbox;
//     - create one pending job per migrated service.
//
// Errors: ErrInviteNotFound; ErrInviteClaimed (also when the invite is no
// longer 'invited' at update time, e.g. a concurrent claim won); and
// ErrEmailMismatch. Domain problems give ErrMigrationDomainMismatch or
// ErrMigrationDomainNotActive. Otherwise the underlying database or
// provisioning error is returned.
//
// NOTE(review): ProvisionMailbox is called with ctx, not tx, so it is not part
// of the transaction. A later failure rolls back the claim and jobs but
// leaves the mailbox in place; a retry tolerates that via
// hosted.ErrAddressTaken.
func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identity ClaimIdentity, displayName, password string) (UserStatus, error) {
	inv, proj, err := s.GetInviteByToken(ctx, token)
	if err != nil {
		return UserStatus{}, err
	}
	if inv.Status == "claimed" {
		return UserStatus{}, ErrInviteClaimed
	}

	projectDomain := ""
	var hostedDomain *hosted.DomainRow
	if strings.TrimSpace(proj.DomainID) != "" && s.hosted != nil {
		domain, err := s.hosted.GetDomain(ctx, proj.DomainID)
		if err != nil {
			return UserStatus{}, fmt.Errorf("migration domain: %w", err)
		}
		hostedDomain = &domain
		projectDomain = domain.Name
	}
	if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) {
		return UserStatus{}, ErrEmailMismatch
	}

	mailboxEmail := normalizeInviteEmail(inv.Email)
	// These checks need no database state, so run them before opening the
	// transaction.
	if hostedDomain != nil {
		at := strings.LastIndex(mailboxEmail, "@")
		if at <= 0 || !strings.EqualFold(mailboxEmail[at+1:], hostedDomain.Name) {
			return UserStatus{}, ErrMigrationDomainMismatch
		}
		if hostedDomain.Status != "active" && !hostedDomain.IsPlatformDomain {
			return UserStatus{}, ErrMigrationDomainNotActive
		}
	}

	tx, err := s.db.Begin(ctx)
	if err != nil {
		return UserStatus{}, err
	}
	// After a successful Commit, Rollback only reports that the tx is closed;
	// that result is intentionally ignored.
	defer tx.Rollback(ctx)

	tag, err := tx.Exec(ctx, `
UPDATE migration_invites
SET status = 'claimed', claimed_at = NOW(), user_id = $1::uuid
WHERE id = $2::uuid AND status = 'invited'
`, userID, inv.ID)
	if err != nil {
		return UserStatus{}, err
	}
	// The status guard makes the claim atomic. If nothing was updated, the
	// invite was claimed (or left the 'invited' state) after it was read
	// above, so this request must not provision a mailbox or create jobs.
	if tag.RowsAffected() == 0 {
		return UserStatus{}, ErrInviteClaimed
	}

	if s.hosted != nil {
		provision := hosted.ProvisionMailboxInput{
			UserID:      userID,
			Email:       mailboxEmail,
			DisplayName: displayName,
			Password:    password,
			QuotaBytes:  0,
		}
		if hostedDomain != nil {
			provision.DomainID = proj.DomainID
		}
		if _, err := s.hosted.ProvisionMailbox(ctx, provision); err != nil {
			switch {
			case errors.Is(err, hosted.ErrDomainNotActive):
				return UserStatus{}, ErrMigrationDomainNotActive
			case errors.Is(err, hosted.ErrAddressTaken):
				// The mailbox already exists (presumably from an earlier
				// attempt); treat it as provisioned.
			default:
				return UserStatus{}, err
			}
		}
	}

	for _, svc := range []string{"mail", "contacts", "calendar", "drive"} {
		if _, err := tx.Exec(ctx, `
INSERT INTO migration_jobs (project_id, user_id, service, status)
VALUES ($1::uuid, $2::uuid, $3, 'pending')
ON CONFLICT (project_id, user_id, service) DO NOTHING
`, proj.ID, userID, svc); err != nil {
			return UserStatus{}, err
		}
	}
	if err := tx.Commit(ctx); err != nil {
		return UserStatus{}, err
	}
	return s.GetUserStatus(ctx, userID, proj.ID)
}
// StoreMigrationToken encrypts an OAuth token with the credential manager and
// upserts it into migration_credentials for (userID, projectID, provider).
// The upsert clears any previous revocation. A zero token expiry is stored as
// NULL.
//
// It returns an error if no credential manager is configured or token is nil.
func (s *Service) StoreMigrationToken(ctx context.Context, userID, projectID, provider string, token *oauth2.Token, scopes []string) error {
	if s.creds == nil {
		return fmt.Errorf("credential manager not configured")
	}
	if token == nil {
		return errors.New("oauth token required")
	}
	enc, err := s.creds.EncryptCredential(credentials.Credential{
		AuthType:      credentials.AuthOAuth2,
		AccessToken:   token.AccessToken,
		RefreshToken:  token.RefreshToken,
		Expiry:        token.Expiry,
		OAuthProvider: provider,
	})
	if err != nil {
		return fmt.Errorf("encrypt migration token: %w", err)
	}
	var expiresAt *time.Time
	if !token.Expiry.IsZero() {
		// Copy so the stored pointer does not alias the caller's token.
		expiry := token.Expiry
		expiresAt = &expiry
	}
	_, err = s.db.Exec(ctx, `
INSERT INTO migration_credentials (user_id, project_id, provider, encrypted_token, scopes, expires_at)
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6)
ON CONFLICT (user_id, project_id, provider) DO UPDATE SET
encrypted_token = EXCLUDED.encrypted_token,
scopes = EXCLUDED.scopes,
expires_at = EXCLUDED.expires_at,
revoked_at = NULL
`, userID, projectID, provider, enc, scopes, expiresAt)
	if err != nil {
		return fmt.Errorf("store migration token: %w", err)
	}
	return nil
}
// GetUserStatus assembles the migration view for userID within projectID.
// The view contains:
//   - the project;
//   - the user's invite in that project, or a zero Invite when none exists;
//   - the user's jobs;
//   - onboarding hints.
//
// A missing project yields pgx.ErrNoRows. Database errors from the invite
// and job lookups are returned rather than silently ignored.
func (s *Service) GetUserStatus(ctx context.Context, userID, projectID string) (UserStatus, error) {
	sc := newProjectScanner()
	if err := s.db.QueryRow(ctx, `
SELECT `+projectSelectSQL("")+`
FROM migration_projects WHERE id = $1::uuid
`, projectID).Scan(sc.targets()...); err != nil {
		return UserStatus{}, err
	}
	proj := sc.result()

	var inv Invite
	err := s.db.QueryRow(ctx, `
SELECT id::text, project_id::text, email, status, claimed_at::text, COALESCE(user_id::text, ''), alternate_emails
FROM migration_invites WHERE project_id = $1::uuid AND user_id = $2::uuid
`, projectID, userID).Scan(
		&inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID, &inv.AlternateEmails,
	)
	switch {
	case errors.Is(err, pgx.ErrNoRows):
		// The user need not hold an invite in this project; report an empty one.
		inv = Invite{}
	case err != nil:
		return UserStatus{}, fmt.Errorf("load migration invite: %w", err)
	}

	jobs, err := s.listJobs(ctx, projectID, userID)
	if err != nil {
		return UserStatus{}, err
	}
	return UserStatus{
		Project:    proj,
		Invite:     inv,
		Jobs:       jobs,
		Onboarding: s.BuildOnboardingHints(ctx, userID, proj, inv),
	}, nil
}
// GetActiveUserStatus returns the status for the project of the user's most
// recently claimed invite. A user with no claimed invite gets a zero
// UserStatus and a nil error.
func (s *Service) GetActiveUserStatus(ctx context.Context, userID string) (UserStatus, error) {
	var projectID string
	row := s.db.QueryRow(ctx, `
SELECT project_id::text FROM migration_invites
WHERE user_id = $1::uuid AND status = 'claimed'
ORDER BY claimed_at DESC NULLS LAST LIMIT 1
`, userID)
	switch scanErr := row.Scan(&projectID); {
	case errors.Is(scanErr, pgx.ErrNoRows):
		return UserStatus{}, nil
	case scanErr != nil:
		return UserStatus{}, scanErr
	}
	return s.GetUserStatus(ctx, userID, projectID)
}
// listJobs returns the user's jobs in projectID, ordered by service name. The
// result is nil when the user has no jobs. Row decoding (including the
// empty-map fallback for NULL or undecodable JSON columns) is done by
// scanJobs.
func (s *Service) listJobs(ctx context.Context, projectID, userID string) ([]Job, error) {
	rows, err := s.db.Query(ctx, `
SELECT id::text, project_id::text, user_id::text, service, status,
cursor_json, stats_json, error, started_at::text, completed_at::text
FROM migration_jobs
WHERE project_id = $1::uuid AND user_id = $2::uuid
ORDER BY service ASC
`, projectID, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	// The column list matches what scanJobs expects, so reuse it instead of
	// keeping a second copy of the decode loop.
	return scanJobs(rows)
}
// PendingJobs returns up to limit jobs that match both conditions:
//   - the job is 'pending' or 'running';
//   - its project is 'active' or 'cutover'.
//
// Jobs are ordered least recently updated first. A non-positive limit falls
// back to 10.
//
// NOTE(review): rows are neither locked nor marked as taken here, so
// concurrent callers can receive the same jobs.
func (s *Service) PendingJobs(ctx context.Context, limit int) ([]Job, error) {
	const defaultLimit = 10
	if limit < 1 {
		limit = defaultLimit
	}
	rows, queryErr := s.db.Query(ctx, `
SELECT j.id::text, j.project_id::text, j.user_id::text, j.service, j.status,
j.cursor_json, j.stats_json, j.error, j.started_at::text, j.completed_at::text
FROM migration_jobs j
JOIN migration_projects p ON p.id = j.project_id
WHERE j.status IN ('pending', 'running')
AND p.status IN ('active', 'cutover')
ORDER BY j.updated_at ASC
LIMIT $1
`, limit)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()
	jobs, scanErr := scanJobs(rows)
	return jobs, scanErr
}
// scanJobs drains rows into Jobs. Each row must carry, in order: id,
// project_id, user_id, service, status, cursor_json, stats_json, error,
// started_at, completed_at.
//
// cursor_json and stats_json are decoded as JSON objects. A NULL value, or one
// that does not decode into an object, becomes an empty map, so both maps are
// always non-nil. On a scan error nothing is returned. On an iteration error
// the jobs read so far are returned with it. scanJobs does not close rows;
// its callers do.
func scanJobs(rows pgx.Rows) ([]Job, error) {
	decodeObject := func(raw []byte) map[string]any {
		var obj map[string]any
		// Best effort by design: bad or NULL JSON falls back to {}.
		_ = json.Unmarshal(raw, &obj)
		if obj == nil {
			obj = map[string]any{}
		}
		return obj
	}

	var jobs []Job
	for rows.Next() {
		var job Job
		var cursorRaw, statsRaw []byte
		scanErr := rows.Scan(
			&job.ID, &job.ProjectID, &job.UserID, &job.Service, &job.Status,
			&cursorRaw, &statsRaw, &job.Error, &job.StartedAt, &job.CompletedAt,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		job.CursorJSON = decodeObject(cursorRaw)
		job.StatsJSON = decodeObject(statsRaw)
		jobs = append(jobs, job)
	}
	return jobs, rows.Err()
}
// UpdateJobProgress overwrites a job's status, cursor, stats and error text.
//
// started_at is set the first time the job is reported as 'running'.
// completed_at is set, and refreshed on each later report, whenever status is
// 'completed' or 'failed'. A nil cursor or stats map is stored as JSON null.
// An unknown jobID updates nothing and is not reported as an error.
//
// If cursor or stats cannot be encoded as JSON (for example a NaN float or a
// func value), it returns an error without touching the database.
func (s *Service) UpdateJobProgress(ctx context.Context, jobID, status string, cursor, stats map[string]any, jobErr string) error {
	// Previously these errors were ignored. A nil []byte would then be sent as
	// SQL NULL and silently wipe the stored cursor/stats.
	cursorRaw, err := json.Marshal(cursor)
	if err != nil {
		return fmt.Errorf("encode job cursor: %w", err)
	}
	statsRaw, err := json.Marshal(stats)
	if err != nil {
		return fmt.Errorf("encode job stats: %w", err)
	}
	_, err = s.db.Exec(ctx, `
UPDATE migration_jobs SET
status = $2,
cursor_json = $3,
stats_json = $4,
error = $5,
started_at = COALESCE(started_at, CASE WHEN $2 = 'running' THEN NOW() ELSE NULL END),
completed_at = CASE WHEN $2 IN ('completed', 'failed') THEN NOW() ELSE completed_at END,
updated_at = NOW()
WHERE id = $1::uuid
`, jobID, status, cursorRaw, statsRaw, jobErr)
	return err
}
// ActivateProject sets the project's status to 'active' and returns the
// updated row. PendingJobs only returns jobs of 'active' or 'cutover'
// projects, so activation makes the project's jobs eligible for processing.
//
// It returns pgx.ErrNoRows if no project has projectID. On any error the
// returned Project is the zero value.
func (s *Service) ActivateProject(ctx context.Context, projectID string) (Project, error) {
	sc := newProjectScanner()
	err := s.db.QueryRow(ctx, `
UPDATE migration_projects SET status = 'active', updated_at = NOW()
WHERE id = $1::uuid
RETURNING `+projectSelectSQL("")+`
`, projectID).Scan(sc.targets()...)
	if err != nil {
		// Don't hand back a half-scanned row.
		return Project{}, err
	}
	return sc.result(), nil
}
// LookupUserID maps an external identity ID (users.external_id) to the
// internal user ID. When no user matches, it returns "" and pgx.ErrNoRows.
func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) {
	var id string
	row := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID)
	if scanErr := row.Scan(&id); scanErr != nil {
		return id, scanErr
	}
	return id, nil
}
// randomState returns an unguessable, URL-safe string. It encodes 24 bytes
// from crypto/rand as unpadded base64url, giving 32 characters.
// NOTE(review): not called in this file; presumably used as an OAuth "state"
// value elsewhere in the package.
func randomState() (string, error) {
	var entropy [24]byte
	if _, readErr := rand.Read(entropy[:]); readErr != nil {
		return "", readErr
	}
	encoded := base64.RawURLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}