package migration import ( "context" "crypto/rand" "encoding/base64" "encoding/json" "errors" "fmt" "strings" "time" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/redis/go-redis/v9" "golang.org/x/oauth2" "github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/hosted" ) var ( ErrInviteNotFound = errors.New("migration invite not found") ErrInviteClaimed = errors.New("migration invite already claimed") ErrEmailMismatch = errors.New("email does not match invite") ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active") ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain") ) type Service struct { db *pgxpool.Pool rdb *redis.Client creds *credentials.Manager hosted *hosted.Service oauth *OAuthService cutover CutoverConfig } func NewService(db *pgxpool.Pool, rdb *redis.Client, creds *credentials.Manager, hostedSvc *hosted.Service, oauth *OAuthService) *Service { return &Service{db: db, rdb: rdb, creds: creds, hosted: hostedSvc, oauth: oauth} } func (s *Service) SetCutoverConfig(cfg CutoverConfig) { s.cutover = cfg } type Project struct { ID string `json:"id"` DomainID string `json:"domain_id,omitempty"` Name string `json:"name"` SourceProvider string `json:"source_provider"` AuthMode string `json:"auth_mode"` Status string `json:"status"` CutoverAt *string `json:"cutover_at,omitempty"` DeltaMode bool `json:"delta_mode"` CreatedAt string `json:"created_at"` MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"` MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"` MicrosoftAdminConsentError string `json:"microsoft_admin_consent_error,omitempty"` CutoverDNS *hosted.DNSCheckReport `json:"cutover_dns,omitempty"` } type Invite struct { ID string `json:"id"` ProjectID string `json:"project_id"` Email string `json:"email"` AlternateEmails []string `json:"alternate_emails,omitempty"` Token string `json:"token,omitempty"` Status string `json:"status"` ClaimedAt *string `json:"claimed_at,omitempty"` UserID string `json:"user_id,omitempty"` } type Job struct { ID string `json:"id"` ProjectID string `json:"project_id"` UserID string `json:"user_id"` Service string `json:"service"` Status string `json:"status"` CursorJSON map[string]any `json:"cursor_json"` StatsJSON map[string]any `json:"stats_json"` Error string `json:"error,omitempty"` StartedAt *string `json:"started_at,omitempty"` CompletedAt *string `json:"completed_at,omitempty"` } type UserStatus struct { Project Project `json:"project"` Invite Invite `json:"invite,omitempty"` Jobs []Job `json:"jobs"` Onboarding OnboardingHints `json:"onboarding"` } func (s *Service) CreateProject(ctx context.Context, name, sourceProvider, domainID, authMode string) (Project, error) { name = strings.TrimSpace(name) if name == "" { return Project{}, fmt.Errorf("project name required") } sourceProvider = strings.ToLower(strings.TrimSpace(sourceProvider)) if sourceProvider == "" { sourceProvider = "google" } authMode = NormalizeAuthMode(sourceProvider, authMode) sc := newProjectScanner() err := s.db.QueryRow(ctx, ` INSERT INTO migration_projects (name, source_provider, domain_id, auth_mode) VALUES ($1, $2, NULLIF($3, '')::uuid, $4) RETURNING `+projectSelectSQL("")+` `, name, sourceProvider, domainID, authMode).Scan(sc.targets()...) return sc.result(), err } func (s *Service) ListProjects(ctx context.Context) ([]Project, error) { rows, err := s.db.Query(ctx, ` SELECT `+projectSelectSQL("")+` FROM migration_projects ORDER BY created_at DESC `) if err != nil { return nil, err } defer rows.Close() var out []Project for rows.Next() { sc := newProjectScanner() if err := rows.Scan(sc.targets()...); err != nil { return nil, err } out = append(out, sc.result()) } return out, rows.Err() } func (s *Service) CreateInvite(ctx context.Context, projectID, email string, alternateEmails []string) (Invite, error) { email = strings.ToLower(strings.TrimSpace(email)) if email == "" { return Invite{}, fmt.Errorf("email required") } alternates := normalizeAlternateEmails(email, alternateEmails) token, err := hosted.NewInviteToken() if err != nil { return Invite{}, err } var row Invite err = s.db.QueryRow(ctx, ` INSERT INTO migration_invites (project_id, email, token, alternate_emails) VALUES ($1::uuid, $2, $3, $4) RETURNING id::text, project_id::text, email, token, status, claimed_at::text, COALESCE(user_id::text, ''), alternate_emails `, projectID, email, token, alternates).Scan( &row.ID, &row.ProjectID, &row.Email, &row.Token, &row.Status, &row.ClaimedAt, &row.UserID, &row.AlternateEmails, ) return row, err } func normalizeAlternateEmails(inviteEmail string, alternateEmails []string) []string { inviteEmail = normalizeInviteEmail(inviteEmail) seen := map[string]struct{}{inviteEmail: {}} var out []string for _, raw := range alternateEmails { email := normalizeInviteEmail(raw) if email == "" || !isEmailAddress(email) { continue } if _, ok := seen[email]; ok { continue } seen[email] = struct{}{} out = append(out, email) } return out } func (s *Service) ImportInvites(ctx context.Context, projectID string, emails []string) (int, error) { count := 0 for _, email := range emails { email = strings.ToLower(strings.TrimSpace(email)) if email == "" { continue } if _, err := s.CreateInvite(ctx, projectID, email, nil); err != nil { return count, err } count++ } return count, nil } func (s *Service) GetInviteByToken(ctx context.Context, token string) (Invite, Project, error) { var inv Invite sc := newProjectScanner() scanArgs := append([]any{ &inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID, &inv.AlternateEmails, }, sc.targets()...) err := s.db.QueryRow(ctx, ` SELECT i.id::text, i.project_id::text, i.email, i.status, i.claimed_at::text, COALESCE(i.user_id::text, ''), i.alternate_emails, `+projectSelectSQL("p")+` FROM migration_invites i JOIN migration_projects p ON p.id = i.project_id WHERE i.token = $1 `, token).Scan(scanArgs...) if errors.Is(err, pgx.ErrNoRows) { return Invite{}, Project{}, ErrInviteNotFound } return inv, sc.result(), err } func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identity ClaimIdentity, displayName, password string) (UserStatus, error) { inv, proj, err := s.GetInviteByToken(ctx, token) if err != nil { return UserStatus{}, err } if inv.Status == "claimed" { return UserStatus{}, ErrInviteClaimed } projectDomain := "" var hostedDomain *hosted.DomainRow if strings.TrimSpace(proj.DomainID) != "" && s.hosted != nil { domain, err := s.hosted.GetDomain(ctx, proj.DomainID) if err != nil { return UserStatus{}, fmt.Errorf("migration domain: %w", err) } hostedDomain = &domain projectDomain = domain.Name } if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) { return UserStatus{}, ErrEmailMismatch } mailboxEmail := normalizeInviteEmail(inv.Email) tx, err := s.db.Begin(ctx) if err != nil { return UserStatus{}, err } defer tx.Rollback(ctx) _, err = tx.Exec(ctx, ` UPDATE migration_invites SET status = 'claimed', claimed_at = NOW(), user_id = $1::uuid WHERE id = $2::uuid AND status = 'invited' `, userID, inv.ID) if err != nil { return UserStatus{}, err } if s.hosted != nil { provision := hosted.ProvisionMailboxInput{ UserID: userID, Email: mailboxEmail, DisplayName: displayName, Password: password, QuotaBytes: 0, } if hostedDomain != nil { at := strings.LastIndex(mailboxEmail, "@") if at <= 0 || !strings.EqualFold(mailboxEmail[at+1:], hostedDomain.Name) { return UserStatus{}, ErrMigrationDomainMismatch } if hostedDomain.Status != "active" && !hostedDomain.IsPlatformDomain { return UserStatus{}, ErrMigrationDomainNotActive } provision.DomainID = proj.DomainID } _, err = s.hosted.ProvisionMailbox(ctx, provision) if err != nil { if errors.Is(err, hosted.ErrDomainNotActive) { return UserStatus{}, ErrMigrationDomainNotActive } if !errors.Is(err, hosted.ErrAddressTaken) { return UserStatus{}, err } } } services := []string{"mail", "contacts", "calendar", "drive"} for _, svc := range services { _, err = tx.Exec(ctx, ` INSERT INTO migration_jobs (project_id, user_id, service, status) VALUES ($1::uuid, $2::uuid, $3, 'pending') ON CONFLICT (project_id, user_id, service) DO NOTHING `, proj.ID, userID, svc) if err != nil { return UserStatus{}, err } } if err := tx.Commit(ctx); err != nil { return UserStatus{}, err } return s.GetUserStatus(ctx, userID, proj.ID) } func (s *Service) StoreMigrationToken(ctx context.Context, userID, projectID, provider string, token *oauth2.Token, scopes []string) error { if s.creds == nil { return fmt.Errorf("credential manager not configured") } payload, err := json.Marshal(map[string]any{ "access_token": token.AccessToken, "refresh_token": token.RefreshToken, "expiry": token.Expiry.UTC().Format(time.RFC3339), "token_type": token.TokenType, }) if err != nil { return err } enc, err := s.creds.EncryptCredential(credentials.Credential{ AuthType: credentials.AuthOAuth2, AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, Expiry: token.Expiry, OAuthProvider: provider, }) if err != nil { _ = payload return err } var expiresAt *time.Time if !token.Expiry.IsZero() { expiresAt = &token.Expiry } _, err = s.db.Exec(ctx, ` INSERT INTO migration_credentials (user_id, project_id, provider, encrypted_token, scopes, expires_at) VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6) ON CONFLICT (user_id, project_id, provider) DO UPDATE SET encrypted_token = EXCLUDED.encrypted_token, scopes = EXCLUDED.scopes, expires_at = EXCLUDED.expires_at, revoked_at = NULL `, userID, projectID, provider, enc, scopes, expiresAt) return err } func (s *Service) GetUserStatus(ctx context.Context, userID, projectID string) (UserStatus, error) { sc := newProjectScanner() err := s.db.QueryRow(ctx, ` SELECT `+projectSelectSQL("")+` FROM migration_projects WHERE id = $1::uuid `, projectID).Scan(sc.targets()...) proj := sc.result() if err != nil { return UserStatus{}, err } var inv Invite _ = s.db.QueryRow(ctx, ` SELECT id::text, project_id::text, email, status, claimed_at::text, COALESCE(user_id::text, '') FROM migration_invites WHERE project_id = $1::uuid AND user_id = $2::uuid `, projectID, userID).Scan( &inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID, ) jobs, err := s.listJobs(ctx, projectID, userID) if err != nil { return UserStatus{}, err } return UserStatus{ Project: proj, Invite: inv, Jobs: jobs, Onboarding: s.BuildOnboardingHints(ctx, userID, proj, inv), }, nil } func (s *Service) GetActiveUserStatus(ctx context.Context, userID string) (UserStatus, error) { var projectID string err := s.db.QueryRow(ctx, ` SELECT project_id::text FROM migration_invites WHERE user_id = $1::uuid AND status = 'claimed' ORDER BY claimed_at DESC NULLS LAST LIMIT 1 `, userID).Scan(&projectID) if errors.Is(err, pgx.ErrNoRows) { return UserStatus{}, nil } if err != nil { return UserStatus{}, err } return s.GetUserStatus(ctx, userID, projectID) } func (s *Service) listJobs(ctx context.Context, projectID, userID string) ([]Job, error) { rows, err := s.db.Query(ctx, ` SELECT id::text, project_id::text, user_id::text, service, status, cursor_json, stats_json, error, started_at::text, completed_at::text FROM migration_jobs WHERE project_id = $1::uuid AND user_id = $2::uuid ORDER BY service ASC `, projectID, userID) if err != nil { return nil, err } defer rows.Close() var out []Job for rows.Next() { var row Job var cursorRaw, statsRaw []byte if err := rows.Scan( &row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status, &cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt, ); err != nil { return nil, err } _ = json.Unmarshal(cursorRaw, &row.CursorJSON) _ = json.Unmarshal(statsRaw, &row.StatsJSON) if row.CursorJSON == nil { row.CursorJSON = map[string]any{} } if row.StatsJSON == nil { row.StatsJSON = map[string]any{} } out = append(out, row) } return out, rows.Err() } func (s *Service) PendingJobs(ctx context.Context, limit int) ([]Job, error) { if limit <= 0 { limit = 10 } rows, err := s.db.Query(ctx, ` SELECT j.id::text, j.project_id::text, j.user_id::text, j.service, j.status, j.cursor_json, j.stats_json, j.error, j.started_at::text, j.completed_at::text FROM migration_jobs j JOIN migration_projects p ON p.id = j.project_id WHERE j.status IN ('pending', 'running') AND p.status IN ('active', 'cutover') ORDER BY j.updated_at ASC LIMIT $1 `, limit) if err != nil { return nil, err } defer rows.Close() return scanJobs(rows) } func scanJobs(rows pgx.Rows) ([]Job, error) { var out []Job for rows.Next() { var row Job var cursorRaw, statsRaw []byte if err := rows.Scan( &row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status, &cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt, ); err != nil { return nil, err } _ = json.Unmarshal(cursorRaw, &row.CursorJSON) _ = json.Unmarshal(statsRaw, &row.StatsJSON) if row.CursorJSON == nil { row.CursorJSON = map[string]any{} } if row.StatsJSON == nil { row.StatsJSON = map[string]any{} } out = append(out, row) } return out, rows.Err() } func (s *Service) UpdateJobProgress(ctx context.Context, jobID, status string, cursor, stats map[string]any, jobErr string) error { cursorRaw, _ := json.Marshal(cursor) statsRaw, _ := json.Marshal(stats) _, err := s.db.Exec(ctx, ` UPDATE migration_jobs SET status = $2, cursor_json = $3, stats_json = $4, error = $5, started_at = COALESCE(started_at, CASE WHEN $2 = 'running' THEN NOW() ELSE NULL END), completed_at = CASE WHEN $2 IN ('completed', 'failed') THEN NOW() ELSE completed_at END, updated_at = NOW() WHERE id = $1::uuid `, jobID, status, cursorRaw, statsRaw, jobErr) return err } func (s *Service) ActivateProject(ctx context.Context, projectID string) (Project, error) { sc := newProjectScanner() err := s.db.QueryRow(ctx, ` UPDATE migration_projects SET status = 'active', updated_at = NOW() WHERE id = $1::uuid RETURNING `+projectSelectSQL("")+` `, projectID).Scan(sc.targets()...) return sc.result(), err } func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) { var userID string err := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID) return userID, err } func randomState() (string, error) { b := make([]byte, 24) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil }