- Added configuration options for Stalwart hosted mail in .env.example. - Updated Docker Compose to include Stalwart service with health checks. - Introduced new API endpoints for managing mail domains and migration projects. - Enhanced Authentik blueprints for user enrollment and post-migration security. - Updated OAuth handling for Google and Microsoft migration processes. - Improved error handling and response structures in the mail API. - Added integration tests for email claiming and migration workflows.
500 lines
15 KiB
Go
500 lines
15 KiB
Go
package migration
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
"github.com/redis/go-redis/v9"
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
|
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
|
)
|
|
|
|
// Sentinel errors returned by the migration Service. Callers should compare
// with errors.Is.
var (
	// ErrInviteNotFound is returned when no migration invite matches a token.
	ErrInviteNotFound = errors.New("migration invite not found")

	// ErrInviteClaimed is returned by ClaimInvite when the invite has already
	// been claimed.
	ErrInviteClaimed = errors.New("migration invite already claimed")

	// ErrEmailMismatch is returned by ClaimInvite when the claiming identity
	// does not match the invite address or its alternates.
	ErrEmailMismatch = errors.New("email does not match invite")

	// ErrMigrationDomainNotActive is returned by ClaimInvite when the project's
	// hosted domain is neither active nor a platform domain, or when mailbox
	// provisioning reports hosted.ErrDomainNotActive.
	ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active")

	// ErrMigrationDomainMismatch is returned by ClaimInvite when the invite
	// address is not on the project's hosted mail domain.
	ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain")
)
|
|
|
|
type Service struct {
|
|
db *pgxpool.Pool
|
|
rdb *redis.Client
|
|
creds *credentials.Manager
|
|
hosted *hosted.Service
|
|
oauth *OAuthService
|
|
cutover CutoverConfig
|
|
}
|
|
|
|
func NewService(db *pgxpool.Pool, rdb *redis.Client, creds *credentials.Manager, hostedSvc *hosted.Service, oauth *OAuthService) *Service {
|
|
return &Service{db: db, rdb: rdb, creds: creds, hosted: hostedSvc, oauth: oauth}
|
|
}
|
|
|
|
func (s *Service) SetCutoverConfig(cfg CutoverConfig) {
|
|
s.cutover = cfg
|
|
}
|
|
|
|
type Project struct {
|
|
ID string `json:"id"`
|
|
DomainID string `json:"domain_id,omitempty"`
|
|
Name string `json:"name"`
|
|
SourceProvider string `json:"source_provider"`
|
|
AuthMode string `json:"auth_mode"`
|
|
Status string `json:"status"`
|
|
CutoverAt *string `json:"cutover_at,omitempty"`
|
|
DeltaMode bool `json:"delta_mode"`
|
|
CreatedAt string `json:"created_at"`
|
|
MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"`
|
|
MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"`
|
|
MicrosoftAdminConsentError string `json:"microsoft_admin_consent_error,omitempty"`
|
|
CutoverDNS *hosted.DNSCheckReport `json:"cutover_dns,omitempty"`
|
|
}
|
|
|
|
// Invite is a per-address invitation to join a migration project.
//
// Field order determines JSON key order and must be preserved.
type Invite struct {
	ID        string `json:"id"`
	ProjectID string `json:"project_id"`

	// Email is the invited address, stored trimmed and lower-cased.
	Email string `json:"email"`

	// AlternateEmails are extra addresses that may also claim the invite.
	AlternateEmails []string `json:"alternate_emails,omitempty"`

	// Token is the secret claim token. CreateInvite returns it;
	// GetInviteByToken and GetUserStatus do not select it, so it is empty there.
	Token string `json:"token,omitempty"`

	// Status is "invited" until ClaimInvite moves it to "claimed".
	Status string `json:"status"`

	// ClaimedAt is the claim timestamp as text, or nil while unclaimed.
	ClaimedAt *string `json:"claimed_at,omitempty"`

	// UserID is the claiming user, or "" while unclaimed.
	UserID string `json:"user_id,omitempty"`
}
|
|
|
|
// Job is one per-user, per-service migration unit (the services seeded by
// ClaimInvite are "mail", "contacts", "calendar" and "drive").
//
// Field order determines JSON key order and must be preserved.
type Job struct {
	ID        string `json:"id"`
	ProjectID string `json:"project_id"`
	UserID    string `json:"user_id"`
	Service   string `json:"service"`

	// Status values used in this file: "pending", "running", "completed",
	// "failed".
	Status string `json:"status"`

	// CursorJSON and StatsJSON are free-form JSON objects; readers in this
	// file replace missing or undecodable values with an empty map.
	CursorJSON map[string]any `json:"cursor_json"`
	StatsJSON  map[string]any `json:"stats_json"`

	// Error is the last error text reported for the job.
	Error string `json:"error,omitempty"`

	// StartedAt and CompletedAt are timestamps as text, nil when unset.
	StartedAt   *string `json:"started_at,omitempty"`
	CompletedAt *string `json:"completed_at,omitempty"`
}
|
|
|
|
type UserStatus struct {
|
|
Project Project `json:"project"`
|
|
Invite Invite `json:"invite,omitempty"`
|
|
Jobs []Job `json:"jobs"`
|
|
Onboarding OnboardingHints `json:"onboarding"`
|
|
}
|
|
|
|
func (s *Service) CreateProject(ctx context.Context, name, sourceProvider, domainID, authMode string) (Project, error) {
|
|
name = strings.TrimSpace(name)
|
|
if name == "" {
|
|
return Project{}, fmt.Errorf("project name required")
|
|
}
|
|
sourceProvider = strings.ToLower(strings.TrimSpace(sourceProvider))
|
|
if sourceProvider == "" {
|
|
sourceProvider = "google"
|
|
}
|
|
authMode = NormalizeAuthMode(sourceProvider, authMode)
|
|
sc := newProjectScanner()
|
|
err := s.db.QueryRow(ctx, `
|
|
INSERT INTO migration_projects (name, source_provider, domain_id, auth_mode)
|
|
VALUES ($1, $2, NULLIF($3, '')::uuid, $4)
|
|
RETURNING `+projectSelectSQL("")+`
|
|
`, name, sourceProvider, domainID, authMode).Scan(sc.targets()...)
|
|
return sc.result(), err
|
|
}
|
|
|
|
func (s *Service) ListProjects(ctx context.Context) ([]Project, error) {
|
|
rows, err := s.db.Query(ctx, `
|
|
SELECT `+projectSelectSQL("")+`
|
|
FROM migration_projects ORDER BY created_at DESC
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []Project
|
|
for rows.Next() {
|
|
sc := newProjectScanner()
|
|
if err := rows.Scan(sc.targets()...); err != nil {
|
|
return nil, err
|
|
}
|
|
out = append(out, sc.result())
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func (s *Service) CreateInvite(ctx context.Context, projectID, email string, alternateEmails []string) (Invite, error) {
|
|
email = strings.ToLower(strings.TrimSpace(email))
|
|
if email == "" {
|
|
return Invite{}, fmt.Errorf("email required")
|
|
}
|
|
alternates := normalizeAlternateEmails(email, alternateEmails)
|
|
token, err := hosted.NewInviteToken()
|
|
if err != nil {
|
|
return Invite{}, err
|
|
}
|
|
var row Invite
|
|
err = s.db.QueryRow(ctx, `
|
|
INSERT INTO migration_invites (project_id, email, token, alternate_emails)
|
|
VALUES ($1::uuid, $2, $3, $4)
|
|
RETURNING id::text, project_id::text, email, token, status, claimed_at::text, COALESCE(user_id::text, ''), alternate_emails
|
|
`, projectID, email, token, alternates).Scan(
|
|
&row.ID, &row.ProjectID, &row.Email, &row.Token, &row.Status, &row.ClaimedAt, &row.UserID, &row.AlternateEmails,
|
|
)
|
|
return row, err
|
|
}
|
|
|
|
func normalizeAlternateEmails(inviteEmail string, alternateEmails []string) []string {
|
|
inviteEmail = normalizeInviteEmail(inviteEmail)
|
|
seen := map[string]struct{}{inviteEmail: {}}
|
|
var out []string
|
|
for _, raw := range alternateEmails {
|
|
email := normalizeInviteEmail(raw)
|
|
if email == "" || !isEmailAddress(email) {
|
|
continue
|
|
}
|
|
if _, ok := seen[email]; ok {
|
|
continue
|
|
}
|
|
seen[email] = struct{}{}
|
|
out = append(out, email)
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (s *Service) ImportInvites(ctx context.Context, projectID string, emails []string) (int, error) {
|
|
count := 0
|
|
for _, email := range emails {
|
|
email = strings.ToLower(strings.TrimSpace(email))
|
|
if email == "" {
|
|
continue
|
|
}
|
|
if _, err := s.CreateInvite(ctx, projectID, email, nil); err != nil {
|
|
return count, err
|
|
}
|
|
count++
|
|
}
|
|
return count, nil
|
|
}
|
|
|
|
func (s *Service) GetInviteByToken(ctx context.Context, token string) (Invite, Project, error) {
|
|
var inv Invite
|
|
sc := newProjectScanner()
|
|
scanArgs := append([]any{
|
|
&inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID, &inv.AlternateEmails,
|
|
}, sc.targets()...)
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT i.id::text, i.project_id::text, i.email, i.status, i.claimed_at::text, COALESCE(i.user_id::text, ''), i.alternate_emails,
|
|
`+projectSelectSQL("p")+`
|
|
FROM migration_invites i
|
|
JOIN migration_projects p ON p.id = i.project_id
|
|
WHERE i.token = $1
|
|
`, token).Scan(scanArgs...)
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return Invite{}, Project{}, ErrInviteNotFound
|
|
}
|
|
return inv, sc.result(), err
|
|
}
|
|
|
|
func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identity ClaimIdentity, displayName, password string) (UserStatus, error) {
|
|
inv, proj, err := s.GetInviteByToken(ctx, token)
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
if inv.Status == "claimed" {
|
|
return UserStatus{}, ErrInviteClaimed
|
|
}
|
|
|
|
projectDomain := ""
|
|
var hostedDomain *hosted.DomainRow
|
|
if strings.TrimSpace(proj.DomainID) != "" && s.hosted != nil {
|
|
domain, err := s.hosted.GetDomain(ctx, proj.DomainID)
|
|
if err != nil {
|
|
return UserStatus{}, fmt.Errorf("migration domain: %w", err)
|
|
}
|
|
hostedDomain = &domain
|
|
projectDomain = domain.Name
|
|
}
|
|
if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) {
|
|
return UserStatus{}, ErrEmailMismatch
|
|
}
|
|
mailboxEmail := normalizeInviteEmail(inv.Email)
|
|
|
|
tx, err := s.db.Begin(ctx)
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
_, err = tx.Exec(ctx, `
|
|
UPDATE migration_invites
|
|
SET status = 'claimed', claimed_at = NOW(), user_id = $1::uuid
|
|
WHERE id = $2::uuid AND status = 'invited'
|
|
`, userID, inv.ID)
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
|
|
if s.hosted != nil {
|
|
provision := hosted.ProvisionMailboxInput{
|
|
UserID: userID,
|
|
Email: mailboxEmail,
|
|
DisplayName: displayName,
|
|
Password: password,
|
|
QuotaBytes: 0,
|
|
}
|
|
if hostedDomain != nil {
|
|
at := strings.LastIndex(mailboxEmail, "@")
|
|
if at <= 0 || !strings.EqualFold(mailboxEmail[at+1:], hostedDomain.Name) {
|
|
return UserStatus{}, ErrMigrationDomainMismatch
|
|
}
|
|
if hostedDomain.Status != "active" && !hostedDomain.IsPlatformDomain {
|
|
return UserStatus{}, ErrMigrationDomainNotActive
|
|
}
|
|
provision.DomainID = proj.DomainID
|
|
}
|
|
_, err = s.hosted.ProvisionMailbox(ctx, provision)
|
|
if err != nil {
|
|
if errors.Is(err, hosted.ErrDomainNotActive) {
|
|
return UserStatus{}, ErrMigrationDomainNotActive
|
|
}
|
|
if !errors.Is(err, hosted.ErrAddressTaken) {
|
|
return UserStatus{}, err
|
|
}
|
|
}
|
|
}
|
|
|
|
services := []string{"mail", "contacts", "calendar", "drive"}
|
|
for _, svc := range services {
|
|
_, err = tx.Exec(ctx, `
|
|
INSERT INTO migration_jobs (project_id, user_id, service, status)
|
|
VALUES ($1::uuid, $2::uuid, $3, 'pending')
|
|
ON CONFLICT (project_id, user_id, service) DO NOTHING
|
|
`, proj.ID, userID, svc)
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
|
|
return s.GetUserStatus(ctx, userID, proj.ID)
|
|
}
|
|
|
|
func (s *Service) StoreMigrationToken(ctx context.Context, userID, projectID, provider string, token *oauth2.Token, scopes []string) error {
|
|
if s.creds == nil {
|
|
return fmt.Errorf("credential manager not configured")
|
|
}
|
|
payload, err := json.Marshal(map[string]any{
|
|
"access_token": token.AccessToken,
|
|
"refresh_token": token.RefreshToken,
|
|
"expiry": token.Expiry.UTC().Format(time.RFC3339),
|
|
"token_type": token.TokenType,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
enc, err := s.creds.EncryptCredential(credentials.Credential{
|
|
AuthType: credentials.AuthOAuth2,
|
|
AccessToken: token.AccessToken,
|
|
RefreshToken: token.RefreshToken,
|
|
Expiry: token.Expiry,
|
|
OAuthProvider: provider,
|
|
})
|
|
if err != nil {
|
|
_ = payload
|
|
return err
|
|
}
|
|
var expiresAt *time.Time
|
|
if !token.Expiry.IsZero() {
|
|
expiresAt = &token.Expiry
|
|
}
|
|
_, err = s.db.Exec(ctx, `
|
|
INSERT INTO migration_credentials (user_id, project_id, provider, encrypted_token, scopes, expires_at)
|
|
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6)
|
|
ON CONFLICT (user_id, project_id, provider) DO UPDATE SET
|
|
encrypted_token = EXCLUDED.encrypted_token,
|
|
scopes = EXCLUDED.scopes,
|
|
expires_at = EXCLUDED.expires_at,
|
|
revoked_at = NULL
|
|
`, userID, projectID, provider, enc, scopes, expiresAt)
|
|
return err
|
|
}
|
|
|
|
func (s *Service) GetUserStatus(ctx context.Context, userID, projectID string) (UserStatus, error) {
|
|
sc := newProjectScanner()
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT `+projectSelectSQL("")+`
|
|
FROM migration_projects WHERE id = $1::uuid
|
|
`, projectID).Scan(sc.targets()...)
|
|
proj := sc.result()
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
|
|
var inv Invite
|
|
_ = s.db.QueryRow(ctx, `
|
|
SELECT id::text, project_id::text, email, status, claimed_at::text, COALESCE(user_id::text, '')
|
|
FROM migration_invites WHERE project_id = $1::uuid AND user_id = $2::uuid
|
|
`, projectID, userID).Scan(
|
|
&inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID,
|
|
)
|
|
|
|
jobs, err := s.listJobs(ctx, projectID, userID)
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
return UserStatus{
|
|
Project: proj,
|
|
Invite: inv,
|
|
Jobs: jobs,
|
|
Onboarding: s.BuildOnboardingHints(ctx, userID, proj, inv),
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) GetActiveUserStatus(ctx context.Context, userID string) (UserStatus, error) {
|
|
var projectID string
|
|
err := s.db.QueryRow(ctx, `
|
|
SELECT project_id::text FROM migration_invites
|
|
WHERE user_id = $1::uuid AND status = 'claimed'
|
|
ORDER BY claimed_at DESC NULLS LAST LIMIT 1
|
|
`, userID).Scan(&projectID)
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return UserStatus{}, nil
|
|
}
|
|
if err != nil {
|
|
return UserStatus{}, err
|
|
}
|
|
return s.GetUserStatus(ctx, userID, projectID)
|
|
}
|
|
|
|
func (s *Service) listJobs(ctx context.Context, projectID, userID string) ([]Job, error) {
|
|
rows, err := s.db.Query(ctx, `
|
|
SELECT id::text, project_id::text, user_id::text, service, status,
|
|
cursor_json, stats_json, error, started_at::text, completed_at::text
|
|
FROM migration_jobs
|
|
WHERE project_id = $1::uuid AND user_id = $2::uuid
|
|
ORDER BY service ASC
|
|
`, projectID, userID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var out []Job
|
|
for rows.Next() {
|
|
var row Job
|
|
var cursorRaw, statsRaw []byte
|
|
if err := rows.Scan(
|
|
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
|
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
|
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
|
if row.CursorJSON == nil {
|
|
row.CursorJSON = map[string]any{}
|
|
}
|
|
if row.StatsJSON == nil {
|
|
row.StatsJSON = map[string]any{}
|
|
}
|
|
out = append(out, row)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func (s *Service) PendingJobs(ctx context.Context, limit int) ([]Job, error) {
|
|
if limit <= 0 {
|
|
limit = 10
|
|
}
|
|
rows, err := s.db.Query(ctx, `
|
|
SELECT j.id::text, j.project_id::text, j.user_id::text, j.service, j.status,
|
|
j.cursor_json, j.stats_json, j.error, j.started_at::text, j.completed_at::text
|
|
FROM migration_jobs j
|
|
JOIN migration_projects p ON p.id = j.project_id
|
|
WHERE j.status IN ('pending', 'running')
|
|
AND p.status IN ('active', 'cutover')
|
|
ORDER BY j.updated_at ASC
|
|
LIMIT $1
|
|
`, limit)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
return scanJobs(rows)
|
|
}
|
|
|
|
func scanJobs(rows pgx.Rows) ([]Job, error) {
|
|
var out []Job
|
|
for rows.Next() {
|
|
var row Job
|
|
var cursorRaw, statsRaw []byte
|
|
if err := rows.Scan(
|
|
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
|
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
|
); err != nil {
|
|
return nil, err
|
|
}
|
|
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
|
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
|
if row.CursorJSON == nil {
|
|
row.CursorJSON = map[string]any{}
|
|
}
|
|
if row.StatsJSON == nil {
|
|
row.StatsJSON = map[string]any{}
|
|
}
|
|
out = append(out, row)
|
|
}
|
|
return out, rows.Err()
|
|
}
|
|
|
|
func (s *Service) UpdateJobProgress(ctx context.Context, jobID, status string, cursor, stats map[string]any, jobErr string) error {
|
|
cursorRaw, _ := json.Marshal(cursor)
|
|
statsRaw, _ := json.Marshal(stats)
|
|
_, err := s.db.Exec(ctx, `
|
|
UPDATE migration_jobs SET
|
|
status = $2,
|
|
cursor_json = $3,
|
|
stats_json = $4,
|
|
error = $5,
|
|
started_at = COALESCE(started_at, CASE WHEN $2 = 'running' THEN NOW() ELSE NULL END),
|
|
completed_at = CASE WHEN $2 IN ('completed', 'failed') THEN NOW() ELSE completed_at END,
|
|
updated_at = NOW()
|
|
WHERE id = $1::uuid
|
|
`, jobID, status, cursorRaw, statsRaw, jobErr)
|
|
return err
|
|
}
|
|
|
|
func (s *Service) ActivateProject(ctx context.Context, projectID string) (Project, error) {
|
|
sc := newProjectScanner()
|
|
err := s.db.QueryRow(ctx, `
|
|
UPDATE migration_projects SET status = 'active', updated_at = NOW()
|
|
WHERE id = $1::uuid
|
|
RETURNING `+projectSelectSQL("")+`
|
|
`, projectID).Scan(sc.targets()...)
|
|
return sc.result(), err
|
|
}
|
|
|
|
func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) {
|
|
var userID string
|
|
err := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID)
|
|
return userID, err
|
|
}
|
|
|
|
// randomState returns 24 bytes from crypto/rand, encoded as unpadded URL-safe
// base64 (32 characters).
// NOTE(review): presumably used as an OAuth "state" value by callers
// elsewhere in the package.
func randomState() (string, error) {
	var entropy [24]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
|