From 1ffd0817d825911411aaba0aa37a89c6cfc2d3ae Mon Sep 17 00:00:00 2001 From: R3D347HR4Y Date: Sat, 13 Jun 2026 13:11:30 +0200 Subject: [PATCH] feat(migration): enhance migration API with roster and audit export features - Added endpoints for listing and importing migration rosters. - Introduced audit export functionality for migration jobs in CSV and NDJSON formats. - Implemented tenant mismatch validation for Microsoft migration claims. - Enhanced error handling for email claiming and migration processes. - Added integration tests for roster import and claim workflows. --- internal/api/admin/handlers_mail_domains.go | 218 +++++++++++ internal/api/migration/handlers.go | 2 + internal/integrationtest/harness.go | 2 + .../migration/claim_email_test.go | 130 ++++++- .../integrationtest/migration/delta_test.go | 92 +++++ .../migration/migration_test.go | 68 +++- .../migration/provision_unify_test.go | 287 +++++++++++++++ .../integrationtest/migration/roster_test.go | 133 +++++++ internal/integrationtest/oidc.go | 3 + internal/mail/hosted/service.go | 112 ++++++ internal/migration/claim_email_match.go | 35 +- internal/migration/claim_email_match_test.go | 44 ++- internal/migration/drive_delta.go | 36 +- internal/migration/drive_helpers.go | 16 +- internal/migration/drive_import.go | 179 ++++++++- internal/migration/drive_shared.go | 237 ++++++++++++ internal/migration/drive_shared_test.go | 116 ++++++ internal/migration/gmail_import.go | 24 -- internal/migration/graph_import.go | 220 +++++++++-- internal/migration/graph_import_test.go | 96 ++++- internal/migration/import_helpers.go | 44 +++ internal/migration/invite_provision.go | 97 +++++ internal/migration/invite_provision_test.go | 35 ++ internal/migration/job_audit_export.go | 319 ++++++++++++++++ internal/migration/job_audit_export_test.go | 115 ++++++ internal/migration/project_columns.go | 37 +- internal/migration/roster.go | 347 ++++++++++++++++++ internal/migration/roster_test.go | 69 ++++ internal/migration/service.go | 23 +- internal/migration/worker.go | 24 +- internal/nextcloud/upload_streaming.go | 65 ++++ internal/provision/handler.go | 80 ++-- internal/provision/handler_test.go | 52 +++ internal/users/provision.go | 65 +++- internal/users/provision_resolve_test.go | 26 ++ migrations/000048_migration_roster.down.sql | 1 + migrations/000048_migration_roster.up.sql | 17 + .../000049_migration_shared_drives.down.sql | 5 + .../000049_migration_shared_drives.up.sql | 39 ++ 39 files changed, 3335 insertions(+), 175 deletions(-) create mode 100644 internal/integrationtest/migration/provision_unify_test.go create mode 100644 internal/integrationtest/migration/roster_test.go create mode 100644 internal/migration/drive_shared.go create mode 100644 internal/migration/drive_shared_test.go create mode 100644 internal/migration/invite_provision.go create mode 100644 internal/migration/invite_provision_test.go create mode 100644 internal/migration/job_audit_export.go create mode 100644 internal/migration/job_audit_export_test.go create mode 100644 internal/migration/roster.go create mode 100644 internal/migration/roster_test.go create mode 100644 internal/nextcloud/upload_streaming.go create mode 100644 internal/provision/handler_test.go create mode 100644 internal/users/provision_resolve_test.go create mode 100644 migrations/000048_migration_roster.down.sql create mode 100644 migrations/000048_migration_roster.up.sql create mode 100644 migrations/000049_migration_shared_drives.down.sql create mode 100644 migrations/000049_migration_shared_drives.up.sql diff --git a/internal/api/admin/handlers_mail_domains.go b/internal/api/admin/handlers_mail_domains.go index fe65b54..889d384 100644 --- a/internal/api/admin/handlers_mail_domains.go +++ b/internal/api/admin/handlers_mail_domains.go @@ -3,6 +3,7 @@ package admin import ( "encoding/csv" "errors" + "fmt" "io" "net/http" "strings" @@ -41,12 +42,20 @@ func (h *Handler) registerMailAdminRoutes(r chi.Router, read, write func(http.Ha r.With(write).Post("/projects/{projectID}/cutover", h.StartMigrationCutover) r.With(write).Post("/projects/{projectID}/invites", h.CreateMigrationInvite) r.With(write).Post("/projects/{projectID}/invites/import", h.ImportMigrationInvites) + r.With(read).Get("/projects/{projectID}/roster", h.ListMigrationRoster) + r.With(write).Post("/projects/{projectID}/roster", h.ImportMigrationRoster) r.With(read).Get("/projects/{projectID}/jobs", h.ListMigrationProjectJobs) r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit", h.ListMigrationJobAudit) r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit/summary", h.MigrationJobAuditSummary) + r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit/export", h.ExportMigrationJobAudit) + r.With(read).Get("/projects/{projectID}/audit/export", h.ExportMigrationProjectAudit) r.With(write).Post("/projects/{projectID}/jobs/retry-failed", h.RetryMigrationFailedJobs) r.With(write).Post("/projects/{projectID}/jobs/{jobID}/retry", h.RetryMigrationJob) r.With(write).Post("/projects/{projectID}/jobs/{jobID}/reset-cursor", h.ResetMigrationJobCursor) + r.With(write).Patch("/projects/{projectID}/shared-drive-mode", h.UpdateMigrationSharedDriveMode) + r.With(read).Get("/projects/{projectID}/shared-drives", h.ListMigrationSharedDrives) + r.With(write).Post("/projects/{projectID}/shared-drives/{driveID}/approve", h.ApproveMigrationSharedDrive) + r.With(write).Post("/projects/{projectID}/shared-drives/{driveID}/reject", h.RejectMigrationSharedDrive) r.With(read).Get("/microsoft/admin-consent-url", h.MicrosoftMigrationAdminConsentURL) r.With(read).Get("/microsoft/admin-consents", h.ListMicrosoftAdminConsents) }) @@ -249,6 +258,81 @@ func (h *Handler) ImportMigrationInvites(w http.ResponseWriter, r *http.Request) apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"imported": count}) } +func (h *Handler) ListMigrationRoster(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + rows, err := h.svc.migration.ListRoster(r.Context(), chi.URLParam(r, "projectID")) + if err != nil { + apivalidate.WriteInternal(w, r) + return + } + if rows == nil { + rows = []migr.RosterEntry{} + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"roster": rows}) +} + +func (h *Handler) ImportMigrationRoster(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + projectID := chi.URLParam(r, "projectID") + + var inputs []migr.RosterRowInput + contentType := r.Header.Get("Content-Type") + if strings.Contains(contentType, "multipart/form-data") { + file, _, err := r.FormFile("file") + if err == nil { + defer file.Close() + parsed, err := migr.ParseRosterCSV(file) + if err != nil { + apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "file", Message: err.Error(), + })) + return + } + inputs = parsed + } + } + if len(inputs) == 0 { + var body struct { + CSV string `json:"csv"` + Rows []migr.RosterRowInput `json:"rows"` + } + if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &body); err != nil { + return + } + if len(body.Rows) > 0 { + inputs = body.Rows + } else if strings.TrimSpace(body.CSV) != "" { + parsed, err := migr.ParseRosterCSV(strings.NewReader(body.CSV)) + if err != nil { + apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "csv", Message: err.Error(), + })) + return + } + inputs = parsed + } + } + if len(inputs) == 0 { + apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "csv", Message: "roster csv or rows required", + })) + return + } + + result, err := h.svc.migration.ImportRoster(r.Context(), projectID, inputs) + if err != nil { + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, result) +} + func (h *Handler) MicrosoftMigrationAdminConsentURL(w http.ResponseWriter, r *http.Request) { if h.svc.migration == nil { apivalidate.WriteInternal(w, r) @@ -388,3 +472,137 @@ func (h *Handler) MigrationJobAuditSummary(w http.ResponseWriter, r *http.Reques } apiresponse.WriteJSON(w, http.StatusOK, summary) } + +func (h *Handler) ExportMigrationJobAudit(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + format, verr := validateExportFormat(r.URL.Query().Get("format")) + if verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + projectID := chi.URLParam(r, "projectID") + jobID := chi.URLParam(r, "jobID") + meta, err := h.svc.migration.PrepareJobAuditExport(r.Context(), projectID, jobID, format) + if err != nil { + if strings.Contains(err.Error(), "not found") { + apiresponse.WriteError(w, r, http.StatusNotFound, "migration_job_not_found", err.Error(), nil) + return + } + apivalidate.WriteInternal(w, r) + return + } + + w.Header().Set("Content-Type", meta.ContentType) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, meta.FileName)) + w.WriteHeader(http.StatusOK) + if err := h.svc.migration.WriteJobAuditExport( + r.Context(), + projectID, + jobID, + r.URL.Query().Get("status"), + format, + w, + ); err != nil { + h.logger.Error("export migration job audit", "error", err) + } +} + +func (h *Handler) ExportMigrationProjectAudit(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + format, verr := validateExportFormat(r.URL.Query().Get("format")) + if verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + projectID := chi.URLParam(r, "projectID") + meta, err := h.svc.migration.PrepareProjectAuditExport(r.Context(), projectID, format) + if err != nil { + if strings.Contains(err.Error(), "not found") { + apiresponse.WriteError(w, r, http.StatusNotFound, "migration_project_not_found", err.Error(), nil) + return + } + apivalidate.WriteInternal(w, r) + return + } + + w.Header().Set("Content-Type", meta.ContentType) + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, meta.FileName)) + w.WriteHeader(http.StatusOK) + if err := h.svc.migration.WriteProjectAuditExport( + r.Context(), + projectID, + r.URL.Query().Get("status"), + format, + w, + ); err != nil { + h.logger.Error("export migration project audit", "error", err) + } +} + +type updateSharedDriveModeRequest struct { + Mode string `json:"shared_drive_mode"` +} + +func (h *Handler) UpdateMigrationSharedDriveMode(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + var req updateSharedDriveModeRequest + if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &req); err != nil { + return + } + row, err := h.svc.migration.UpdateSharedDriveMode(r.Context(), chi.URLParam(r, "projectID"), req.Mode) + if err != nil { + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, row) +} + +func (h *Handler) ListMigrationSharedDrives(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + rows, err := h.svc.migration.ListSharedDrives(r.Context(), chi.URLParam(r, "projectID"), r.URL.Query().Get("status")) + if err != nil { + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"shared_drives": rows}) +} + +func (h *Handler) ApproveMigrationSharedDrive(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + row, err := h.svc.migration.ApproveSharedDrive(r.Context(), chi.URLParam(r, "projectID"), chi.URLParam(r, "driveID")) + if err != nil { + apiresponse.WriteError(w, r, http.StatusNotFound, "shared_drive_not_found", err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, row) +} + +func (h *Handler) RejectMigrationSharedDrive(w http.ResponseWriter, r *http.Request) { + if h.svc.migration == nil { + apivalidate.WriteInternal(w, r) + return + } + row, err := h.svc.migration.RejectSharedDrive(r.Context(), chi.URLParam(r, "projectID"), chi.URLParam(r, "driveID")) + if err != nil { + apiresponse.WriteError(w, r, http.StatusNotFound, "shared_drive_not_found", err.Error(), nil) + return + } + apiresponse.WriteJSON(w, http.StatusOK, row) +} diff --git a/internal/api/migration/handlers.go b/internal/api/migration/handlers.go index 7e25f2b..5dd8e22 100644 --- a/internal/api/migration/handlers.go +++ b/internal/api/migration/handlers.go @@ -162,6 +162,8 @@ func (h *Handler) ClaimInvite(w http.ResponseWriter, r *http.Request) { errCode = "invite_already_claimed" case err == migr.ErrEmailMismatch: errCode = "email_mismatch" + case err == migr.ErrTenantMismatch: + errCode = "tenant_mismatch" case err == migr.ErrMigrationDomainNotActive: errCode = "migration_domain_not_active" case err == migr.ErrMigrationDomainMismatch: diff --git a/internal/integrationtest/harness.go b/internal/integrationtest/harness.go index 3686be8..1ceeb27 100644 --- a/internal/integrationtest/harness.go +++ b/internal/integrationtest/harness.go @@ -215,6 +215,8 @@ func buildTestConfig(env Env, infra *infra, oidc *OIDCServer) *config.Config { MailActiveCredentialKeyID: "v1", MailWebhookSharedSecret: "test-webhook-secret", MailAppURL: "http://localhost:3004", + ProvisionWebhookSecret: "test-provision-secret", + PlatformMailDomain: "ultisuite.local", SearchEngine: "postgres", MeilisearchURL: env.MeilisearchURL, MeilisearchKey: env.MeilisearchKey, diff --git a/internal/integrationtest/migration/claim_email_test.go b/internal/integrationtest/migration/claim_email_test.go index c381514..a673601 100644 --- a/internal/integrationtest/migration/claim_email_test.go +++ b/internal/integrationtest/migration/claim_email_test.go @@ -101,7 +101,7 @@ func TestClaimInviteRejectsEmailMismatch(t *testing.T) { integrationtest.FailUnlessStatus(t, actResp, 200) inviteEmail := "victim-" + uuid.NewString() + "@example.com" - inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{ + inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{ "email": inviteEmail, }) integrationtest.FailIf(err, t, "create invite") @@ -127,3 +127,131 @@ func TestClaimInviteRejectsEmailMismatch(t *testing.T) { integrationtest.FailIf(err, t, "claim invite") integrationtest.AssertErrorCode(t, claimResp, 400, "email_mismatch") } + +func TestClaimInviteRejectsMicrosoftTenantMismatch(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + adminClient, adminClaims := integrationtest.RequireAdminClient(t, h) + if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil { + t.Fatalf("ensure admin: %v", err) + } + if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil { + t.Fatalf("grant admin: %v", err) + } + + createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{ + "name": "Tenant mismatch", + "source_provider": "microsoft", + }) + integrationtest.FailIf(err, t, "create project") + integrationtest.FailUnlessStatus(t, createResp, 201) + + var created struct { + ID string `json:"id"` + } + integrationtest.DecodeJSON(t, createResp, &created) + + actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil) + integrationtest.FailIf(err, t, "activate project") + integrationtest.FailUnlessStatus(t, actResp, 200) + + expectedTenant := "11111111-2222-3333-4444-555555555555" + if _, err := h.Pool.Exec(ctx, ` + UPDATE migration_projects SET microsoft_tenant_id = $1 WHERE id = $2::uuid + `, expectedTenant, created.ID); err != nil { + t.Fatalf("set tenant: %v", err) + } + + inviteEmail := "tenant-user-" + uuid.NewString()[:8] + "@example.com" + inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{ + "email": inviteEmail, + }) + integrationtest.FailIf(err, t, "create invite") + integrationtest.FailUnlessStatus(t, inviteResp, 201) + + var invite struct { + Token string `json:"token"` + } + integrationtest.DecodeJSON(t, inviteResp, &invite) + + wrongTenantClaims := integrationtest.RegularUser(integrationtest.NewExternalID("tenant-mismatch")) + wrongTenantClaims.Email = inviteEmail + wrongTenantClaims.TID = "99999999-aaaa-bbbb-cccc-dddddddddddd" + wrongTenantClient, err := h.Client(wrongTenantClaims) + integrationtest.FailIf(err, t, "wrong tenant client") + + if _, err := users.EnsureUser(ctx, h.Pool, wrongTenantClaims); err != nil { + t.Fatalf("ensure user: %v", err) + } + + claimResp, err := wrongTenantClient.Post("/api/v1/migration/claim", map[string]string{ + "token": invite.Token, + }) + integrationtest.FailIf(err, t, "claim invite") + integrationtest.AssertErrorCode(t, claimResp, 400, "tenant_mismatch") +} + +func TestClaimInviteGoogleProjectIgnoresTenant(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + adminClient, adminClaims := integrationtest.RequireAdminClient(t, h) + if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil { + t.Fatalf("ensure admin: %v", err) + } + if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil { + t.Fatalf("grant admin: %v", err) + } + + createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{ + "name": "Google ignores tenant", + "source_provider": "google", + }) + integrationtest.FailIf(err, t, "create project") + integrationtest.FailUnlessStatus(t, createResp, 201) + + var created struct { + ID string `json:"id"` + } + integrationtest.DecodeJSON(t, createResp, &created) + + actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil) + integrationtest.FailIf(err, t, "activate project") + integrationtest.FailUnlessStatus(t, actResp, 200) + + if _, err := h.Pool.Exec(ctx, ` + UPDATE migration_projects SET microsoft_tenant_id = $1 WHERE id = $2::uuid + `, "11111111-2222-3333-4444-555555555555", created.ID); err != nil { + t.Fatalf("set tenant: %v", err) + } + + inviteEmail := "google-user-" + uuid.NewString()[:8] + "@example.com" + inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{ + "email": inviteEmail, + }) + integrationtest.FailIf(err, t, "create invite") + integrationtest.FailUnlessStatus(t, inviteResp, 201) + + var invite struct { + Token string `json:"token"` + } + integrationtest.DecodeJSON(t, inviteResp, &invite) + + userClaims := integrationtest.RegularUser(integrationtest.NewExternalID("google-tenant-ignore")) + userClaims.Email = inviteEmail + userClaims.TID = "wrong-tenant-id" + userClient, err := h.Client(userClaims) + integrationtest.FailIf(err, t, "user client") + + if _, err := users.EnsureUser(ctx, h.Pool, userClaims); err != nil { + t.Fatalf("ensure user: %v", err) + } + + claimResp, err := userClient.Post("/api/v1/migration/claim", map[string]string{ + "token": invite.Token, + "password": "test-password-123", + }) + integrationtest.FailIf(err, t, "claim invite") + integrationtest.FailUnlessStatus(t, claimResp, 200) +} diff --git a/internal/integrationtest/migration/delta_test.go b/internal/integrationtest/migration/delta_test.go index 68fb304..ad41387 100644 --- a/internal/integrationtest/migration/delta_test.go +++ b/internal/integrationtest/migration/delta_test.go @@ -495,6 +495,98 @@ func TestGraphMailDeltaDeletesRemoved(t *testing.T) { } } +func TestGraphFolderDeltaDeletesRemoved(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + userID, err := users.EnsureUser(ctx, h.Pool, integrationtest.RegularUser(integrationtest.NewExternalID("graph-folder-delta"))) + integrationtest.FailIf(err, t, "ensure user") + + var accountID string + err = h.Pool.QueryRow(ctx, ` + INSERT INTO mail_accounts (user_id, email, provider, is_active) + VALUES ($1::uuid, 'graph-folder-delta@test.local', 'hosted', true) + RETURNING id::text + `, userID).Scan(&accountID) + integrationtest.FailIf(err, t, "insert mail account") + + uid := migr.RemoteMessageUIDForTest("msg-folder-removed") + _, err = h.Pool.Exec(ctx, ` + INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, date, snippet, body_text, body_html, flags, labels) + SELECT $1::uuid, f.id, $2, '', 'To delete', '[]', '[]', NOW(), '', '', '', '{}', '{}' + FROM mail_folders f WHERE f.account_id = $1::uuid AND f.remote_name = 'INBOX' LIMIT 1 + `, accountID, uid) + integrationtest.FailIf(err, t, "seed message") + + inboxID := "inbox-folder-id" + sentID := "sent-folder-id" + client := graphRewriteClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/mailFolders") { + _, _ = w.Write([]byte(`{"value":[ + {"id":"` + inboxID + `","displayName":"Inbox","wellKnownName":"inbox"}, + {"id":"` + sentID + `","displayName":"Sent","wellKnownName":"sentitems"} + ]}`)) + return + } + if strings.Contains(r.URL.Path, "/mailFolders/"+inboxID+"/messages/delta") { + _, _ = w.Write([]byte(`{ + "value":[{"id":"msg-folder-removed","@removed":{"reason":"deleted"}}], + "@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/mailFolders/` + inboxID + `/messages/delta?token=inbox-done" + }`)) + return + } + if strings.Contains(r.URL.Path, "/mailFolders/"+sentID+"/messages/delta") { + _, _ = w.Write([]byte(`{ + "value":[], + "@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/mailFolders/` + sentID + `/messages/delta?token=sent-done" + }`)) + return + } + http.NotFound(w, r) + }) + + importer := migr.NewGraphImporter(h.Pool).WithHTTPClient(client).WithBaseURL("https://graph.microsoft.com") + job := &migr.Job{ + UserID: userID, + CursorJSON: map[string]any{ + "graphFolderQueue": []any{inboxID, sentID}, + "folderDeltaLinks": map[string]any{ + inboxID: "https://graph.microsoft.com/v1.0/me/mailFolders/" + inboxID + "/messages/delta?token=inbox-old", + sentID: "https://graph.microsoft.com/v1.0/me/mailFolders/" + sentID + "/messages/delta?token=sent-old", + }, + }, + StatsJSON: map[string]any{}, + } + + for { + var finalStatus string + err = importer.ImportBatch(ctx, job, "token", true, func(status string, cursor, stats map[string]any, jobErr string) error { + if jobErr != "" { + t.Fatalf("import error: %s", jobErr) + } + finalStatus = status + return nil + }) + integrationtest.FailIf(err, t, "import batch") + if finalStatus == "completed" { + break + } + } + + deleted, _ := job.StatsJSON["delta_deleted"].(float64) + if deleted != 1 { + t.Fatalf("delta_deleted = %v, want 1", job.StatsJSON["delta_deleted"]) + } + + var count int + if err := h.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM messages WHERE account_id = $1::uuid AND uid = $2`, accountID, uid).Scan(&count); err != nil { + t.Fatalf("count messages: %v", err) + } + if count != 0 { + t.Fatalf("message count = %d, want 0", count) + } +} + func TestGmailHistoryDeltaDeletesMessage(t *testing.T) { h := integrationtest.RequireHarness(t) ctx := context.Background() diff --git a/internal/integrationtest/migration/migration_test.go b/internal/integrationtest/migration/migration_test.go index 2ff47cf..4fb82b9 100644 --- a/internal/integrationtest/migration/migration_test.go +++ b/internal/integrationtest/migration/migration_test.go @@ -4,6 +4,7 @@ package migration_test import ( "context" + "encoding/json" "net/http" "net/http/httptest" "strings" @@ -200,6 +201,38 @@ func TestMigrationInviteClaimFlow(t *testing.T) { t.Fatalf("failed audit items = %+v", failedList.Items) } + csvExportResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs/" + mailJobID + "/audit/export?format=csv") + integrationtest.FailIf(err, t, "audit export csv") + integrationtest.FailUnlessStatus(t, csvExportResp, 200) + csvText := string(csvExportResp.Body) + if !strings.Contains(csvText, "item_id,rel_path,status,error,service,timestamp") { + t.Fatalf("csv headers missing: %q", csvText) + } + if !strings.Contains(csvText, "msg-fail") || !strings.Contains(csvText, "upload timeout") { + t.Fatalf("csv body = %q", csvText) + } + + ndExportResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs/" + mailJobID + "/audit/export?format=ndjson") + integrationtest.FailIf(err, t, "audit export ndjson") + integrationtest.FailUnlessStatus(t, ndExportResp, 200) + ndBody := string(ndExportResp.Body) + for _, line := range strings.Split(strings.TrimSpace(ndBody), "\n") { + if line == "" { + continue + } + var row struct { + ItemID string `json:"item_id"` + Status string `json:"status"` + Service string `json:"service"` + } + if err := json.Unmarshal([]byte(line), &row); err != nil { + t.Fatalf("ndjson line invalid: %q err=%v", line, err) + } + if row.ItemID == "" || row.Status == "" || row.Service == "" { + t.Fatalf("ndjson row incomplete: %+v", row) + } + } + resetResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/jobs/"+mailJobID+"/reset-cursor", nil) integrationtest.FailIf(err, t, "reset cursor") integrationtest.FailUnlessStatus(t, resetResp, 200) @@ -247,12 +280,16 @@ func TestGraphImportWritesMessages(t *testing.T) { integrationtest.FailIf(err, t, "insert mail account") folderID := "inbox-folder-id" + sentFolderID := "sent-folder-id" messagesListed := false srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { - case strings.Contains(r.URL.Path, "/mailFolders"): - _, _ = w.Write([]byte(`{"value":[{"id":"` + folderID + `","displayName":"Inbox","wellKnownName":"inbox"}]}`)) - case strings.Contains(r.URL.Path, "/messages"): + case strings.HasSuffix(r.URL.Path, "/mailFolders"): + _, _ = w.Write([]byte(`{"value":[ + {"id":"` + folderID + `","displayName":"Inbox","wellKnownName":"inbox"}, + {"id":"` + sentFolderID + `","displayName":"Sent Items","wellKnownName":"sentitems"} + ]}`)) + case strings.Contains(r.URL.Path, "/mailFolders/"+folderID+"/messages"): messagesListed = true _, _ = w.Write([]byte(`{"value":[{ "id":"msg-1", @@ -266,6 +303,8 @@ func TestGraphImportWritesMessages(t *testing.T) { "isRead":true, "internetMessageId":"" }]}`)) + case strings.Contains(r.URL.Path, "/mailFolders/"+sentFolderID+"/messages"): + _, _ = w.Write([]byte(`{"value":[]}`)) default: http.NotFound(w, r) } @@ -278,16 +317,23 @@ func TestGraphImportWritesMessages(t *testing.T) { CursorJSON: map[string]any{}, StatsJSON: map[string]any{}, } - err = importer.ImportBatch(ctx, job, "test-token", false, func(status string, cursor, stats map[string]any, jobErr string) error { - if jobErr != "" { - t.Fatalf("import error: %s", jobErr) + for { + var finalStatus string + err = importer.ImportBatch(ctx, job, "test-token", false, func(status string, cursor, stats map[string]any, jobErr string) error { + if jobErr != "" { + t.Fatalf("import error: %s", jobErr) + } + finalStatus = status + return nil + }) + integrationtest.FailIf(err, t, "import batch") + if finalStatus == "completed" { + break } - if status != "completed" { - t.Fatalf("status = %q, want completed", status) + if finalStatus != "pending" { + t.Fatalf("status = %q, want pending or completed", finalStatus) } - return nil - }) - integrationtest.FailIf(err, t, "import batch") + } if !messagesListed { t.Fatal("graph messages endpoint not called") } diff --git a/internal/integrationtest/migration/provision_unify_test.go b/internal/integrationtest/migration/provision_unify_test.go new file mode 100644 index 0000000..75c8501 --- /dev/null +++ b/internal/integrationtest/migration/provision_unify_test.go @@ -0,0 +1,287 @@ +//go:build integration + +package migration_test + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "testing" + + "github.com/google/uuid" + + "github.com/ultisuite/ulti-backend/internal/integrationtest" + "github.com/ultisuite/ulti-backend/internal/migration" + "github.com/ultisuite/ulti-backend/internal/users" +) + +const testProvisionSecret = "test-provision-secret" + +func postProvision(t *testing.T, h *integrationtest.Harness, body map[string]any) *integrationtest.Response { + t.Helper() + data, err := json.Marshal(body) + integrationtest.FailIf(err, t, "marshal provision body") + req, err := http.NewRequest(http.MethodPost, h.Server.URL+"/internal/provision/user", bytes.NewReader(data)) + integrationtest.FailIf(err, t, "provision request") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Provision-Secret", testProvisionSecret) + resp, err := http.DefaultClient.Do(req) + integrationtest.FailIf(err, t, "provision call") + defer resp.Body.Close() + bodyBytes, err := io.ReadAll(resp.Body) + integrationtest.FailIf(err, t, "read provision response") + return &integrationtest.Response{Status: resp.StatusCode, Body: bodyBytes, Header: resp.Header} +} + +func TestProvisionEnrollThenClaim(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + adminClient, adminClaims := integrationtest.RequireAdminClient(t, h) + if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil { + t.Fatalf("ensure admin: %v", err) + } + if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil { + t.Fatalf("grant admin: %v", err) + } + + domainName := "enroll-claim-" + uuid.NewString()[:8] + ".test" + var domainID string + err := h.Pool.QueryRow(ctx, ` + INSERT INTO mail_domains (name, status, is_platform_domain) + VALUES ($1, 'active', false) + RETURNING id::text + `, domainName).Scan(&domainID) + integrationtest.FailIf(err, t, "insert domain") + + createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{ + "name": "Enroll then claim", + "source_provider": "google", + "domain_id": domainID, + }) + integrationtest.FailIf(err, t, "create project") + integrationtest.FailUnlessStatus(t, createResp, 201) + + var created struct { + ID string `json:"id"` + } + integrationtest.DecodeJSON(t, createResp, &created) + + actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil) + integrationtest.FailIf(err, t, "activate project") + integrationtest.FailUnlessStatus(t, actResp, 200) + + migrateeEmail := "user@" + domainName + inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{ + "email": migrateeEmail, + }) + integrationtest.FailIf(err, t, "create invite") + integrationtest.FailUnlessStatus(t, inviteResp, 201) + + var invite struct { + Token string `json:"token"` + } + integrationtest.DecodeJSON(t, inviteResp, &invite) + + externalID := integrationtest.NewExternalID("enroll-claim") + provisionResp := postProvision(t, h, map[string]any{ + "email": migrateeEmail, + "name": "Migratee", + "password": "enroll-password-123", + "external_id": externalID, + }) + if provisionResp.Status != 200 { + t.Fatalf("provision status = %d, want 200; body=%s", provisionResp.Status, string(provisionResp.Body)) + } + + migrateeClaims := integrationtest.RegularUser(externalID) + migrateeClaims.Email = migrateeEmail + migrateeClient, err := h.Client(migrateeClaims) + integrationtest.FailIf(err, t, "migratee client") + + if _, err := users.EnsureUser(ctx, h.Pool, migrateeClaims); err != nil { + t.Fatalf("ensure migratee: %v", err) + } + + claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{ + "token": invite.Token, + "password": "claim-password-123", + }) + integrationtest.FailIf(err, t, "claim invite") + integrationtest.FailUnlessStatus(t, claimResp, 200) + + audit, err := migration.AuditProvisionByEmail(ctx, h.Pool, migrateeEmail) + integrationtest.FailIf(err, t, "audit provision") + if audit.Users != 1 { + t.Fatalf("users = %d, want 1", audit.Users) + } + if audit.Mailboxes != 1 { + t.Fatalf("mailboxes = %d, want 1", audit.Mailboxes) + } + if audit.MailAccounts != 1 { + t.Fatalf("mail_accounts = %d, want 1", audit.MailAccounts) + } +} + +func TestProvisionClaimThenEnroll(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + adminClient, adminClaims := integrationtest.RequireAdminClient(t, h) + if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil { + t.Fatalf("ensure admin: %v", err) + } + if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil { + t.Fatalf("grant admin: %v", err) + } + + domainName := "claim-enroll-" + uuid.NewString()[:8] + ".test" + var domainID string + err := h.Pool.QueryRow(ctx, ` + INSERT INTO mail_domains (name, status, is_platform_domain) + VALUES ($1, 'active', false) + RETURNING id::text + `, domainName).Scan(&domainID) + integrationtest.FailIf(err, t, "insert domain") + + createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{ + "name": "Claim then enroll", + "source_provider": "google", + "domain_id": domainID, + }) + integrationtest.FailIf(err, t, "create project") + integrationtest.FailUnlessStatus(t, createResp, 201) + + var created struct { + ID string `json:"id"` + } + integrationtest.DecodeJSON(t, createResp, &created) + + actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil) + integrationtest.FailIf(err, t, "activate project") + integrationtest.FailUnlessStatus(t, actResp, 200) + + migrateeEmail := "user@" + domainName + inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{ + "email": migrateeEmail, + }) + integrationtest.FailIf(err, t, "create invite") + integrationtest.FailUnlessStatus(t, inviteResp, 201) + + var invite struct { + Token string `json:"token"` + } + integrationtest.DecodeJSON(t, inviteResp, &invite) + + externalID := integrationtest.NewExternalID("claim-enroll") + migrateeClaims := integrationtest.RegularUser(externalID) + migrateeClaims.Email = migrateeEmail + migrateeClient, err := h.Client(migrateeClaims) + integrationtest.FailIf(err, t, "migratee client") + + userID, err := users.EnsureUser(ctx, h.Pool, migrateeClaims) + integrationtest.FailIf(err, t, "ensure migratee") + + claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{ + "token": invite.Token, + "password": "claim-password-123", + }) + integrationtest.FailIf(err, t, "claim invite") + integrationtest.FailUnlessStatus(t, claimResp, 200) + + provisionResp := postProvision(t, h, map[string]any{ + "email": migrateeEmail, + "name": "Migratee", + "password": "enroll-password-123", + "external_id": externalID, + }) + if provisionResp.Status != 200 { + t.Fatalf("provision status = %d, want 200; body=%s", provisionResp.Status, string(provisionResp.Body)) + } + + var provisionBody struct { + UserID string `json:"user_id"` + } + integrationtest.DecodeJSON(t, provisionResp, &provisionBody) + if provisionBody.UserID != userID { + t.Fatalf("provision user_id = %q, want %q", provisionBody.UserID, userID) + } + + audit, err := migration.AuditProvisionByEmail(ctx, h.Pool, migrateeEmail) + integrationtest.FailIf(err, t, "audit provision") + if audit.Users != 1 { + t.Fatalf("users = %d, want 1", audit.Users) + } + if audit.Mailboxes != 1 { + t.Fatalf("mailboxes = %d, want 1", audit.Mailboxes) + } + if audit.MailAccounts != 1 { + t.Fatalf("mail_accounts = %d, want 1", audit.MailAccounts) + } +} + +func TestProvisionEnrollmentDefersMailboxForPendingInvite(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + adminClient, adminClaims := integrationtest.RequireAdminClient(t, h) + if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil { + t.Fatalf("ensure admin: %v", err) + } + if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil { + t.Fatalf("grant admin: %v", err) + } + + platformEmail := "pending-" + uuid.NewString()[:8] + "@ultisuite.local" + createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{ + "name": "Pending invite defer", + "source_provider": "google", + }) + integrationtest.FailIf(err, t, "create project") + integrationtest.FailUnlessStatus(t, createResp, 201) + + var created struct { + ID string `json:"id"` + } + integrationtest.DecodeJSON(t, createResp, &created) + + actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil) + integrationtest.FailIf(err, t, "activate project") + integrationtest.FailUnlessStatus(t, actResp, 200) + + inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{ + "email": platformEmail, + }) + integrationtest.FailIf(err, t, "create invite") + integrationtest.FailUnlessStatus(t, inviteResp, 201) + + externalID := integrationtest.NewExternalID("pending-invite") + provisionResp := postProvision(t, h, map[string]any{ + "email": platformEmail, + "name": "Pending User", + "password": "enroll-password-123", + "external_id": externalID, + }) + if provisionResp.Status != 200 { + t.Fatalf("provision status = %d, want 200; body=%s", provisionResp.Status, string(provisionResp.Body)) + } + + var provisionBody struct { + MailboxDeferred bool `json:"mailbox_deferred"` + } + integrationtest.DecodeJSON(t, provisionResp, &provisionBody) + if !provisionBody.MailboxDeferred { + t.Fatal("expected mailbox_deferred=true for pending invite enrollment") + } + + audit, err := migration.AuditProvisionByEmail(ctx, h.Pool, platformEmail) + integrationtest.FailIf(err, t, "audit after deferred enroll") + if audit.Users != 1 { + t.Fatalf("users = %d, want 1", audit.Users) + } + if audit.Mailboxes != 0 { + t.Fatalf("mailboxes = %d, want 0 before claim", audit.Mailboxes) + } +} diff --git a/internal/integrationtest/migration/roster_test.go b/internal/integrationtest/migration/roster_test.go new file mode 100644 index 0000000..4c0ceeb --- /dev/null +++ b/internal/integrationtest/migration/roster_test.go @@ -0,0 +1,133 @@ +//go:build integration + +package migration_test + +import ( + "context" + "testing" + + "github.com/google/uuid" + + "github.com/ultisuite/ulti-backend/internal/integrationtest" + "github.com/ultisuite/ulti-backend/internal/users" +) + +func TestMigrationRosterImportAndClaim(t *testing.T) { + h := integrationtest.RequireHarness(t) + ctx := context.Background() + + adminClient, adminClaims := integrationtest.RequireAdminClient(t, h) + if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil { + t.Fatalf("ensure admin: %v", err) + } + if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil { + t.Fatalf("grant admin: %v", err) + } + + createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{ + "name": "Roster migration", + "source_provider": "google", + }) + integrationtest.FailIf(err, t, "create project") + integrationtest.FailUnlessStatus(t, createResp, 201) + + var created struct { + ID string `json:"id"` + } + integrationtest.DecodeJSON(t, createResp, &created) + + actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil) + integrationtest.FailIf(err, t, "activate project") + integrationtest.FailUnlessStatus(t, actResp, 200) + + csv := "email,display_name,alternate_emails\n" + + "migratee-" + uuid.NewString() + "@example.com,Test User,alt-" + uuid.NewString() + "@example.com\n" + + importResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/roster", map[string]string{ + "csv": csv, + }) + integrationtest.FailIf(err, t, "import roster") + integrationtest.FailUnlessStatus(t, importResp, 200) + + var importResult struct { + Created int `json:"created"` + SkippedDuplicates int `json:"skipped_duplicates"` + } + integrationtest.DecodeJSON(t, importResp, &importResult) + if importResult.Created != 1 { + t.Fatalf("expected 1 created, got %#v", importResult) + } + + listResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/roster") + integrationtest.FailIf(err, t, "list roster") + integrationtest.FailUnlessStatus(t, listResp, 200) + + var rosterList struct { + Roster []struct { + Email string `json:"email"` + DisplayName string `json:"display_name"` + Status string `json:"status"` + } `json:"roster"` + } + integrationtest.DecodeJSON(t, listResp, &rosterList) + if len(rosterList.Roster) != 1 { + t.Fatalf("expected 1 roster entry, got %d", len(rosterList.Roster)) + } + if rosterList.Roster[0].Status != "invited" { + t.Fatalf("expected invited status, got %q", rosterList.Roster[0].Status) + } + migrateeEmail := rosterList.Roster[0].Email + + dupResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/roster", map[string]string{ + "csv": csv, + }) + integrationtest.FailIf(err, t, "duplicate import") + integrationtest.FailUnlessStatus(t, dupResp, 200) + + var dupResult struct { + SkippedDuplicates int `json:"skipped_duplicates"` + } + integrationtest.DecodeJSON(t, dupResp, &dupResult) + if dupResult.SkippedDuplicates != 1 { + t.Fatalf("expected 1 skipped duplicate, got %d", dupResult.SkippedDuplicates) + } + + var inviteToken string + err = h.Pool.QueryRow(ctx, ` + SELECT token FROM migration_invites WHERE project_id = $1::uuid AND email = $2 + `, created.ID, migrateeEmail).Scan(&inviteToken) + if err != nil { + t.Fatalf("lookup invite token: %v", err) + } + if inviteToken == "" { + t.Fatal("missing invite token for roster entry") + } + + migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("roster-migratee")) + migrateeClaims.Email = migrateeEmail + migrateeClient, err := h.Client(migrateeClaims) + integrationtest.FailIf(err, t, "migratee client") + + if _, err := users.EnsureUser(ctx, h.Pool, migrateeClaims); err != nil { + t.Fatalf("ensure migratee: %v", err) + } + + claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{ + "token": inviteToken, + "password": "test-password-123", + "display_name": "Test User", + }) + integrationtest.FailIf(err, t, "claim invite") + integrationtest.FailUnlessStatus(t, claimResp, 200) + + var rosterStatus string + err = h.Pool.QueryRow(ctx, ` + SELECT status FROM migration_roster WHERE project_id = $1::uuid AND email = $2 + `, created.ID, migrateeEmail).Scan(&rosterStatus) + if err != nil { + t.Fatalf("roster status: %v", err) + } + if rosterStatus != "claimed" { + t.Fatalf("expected claimed roster, got %q", rosterStatus) + } +} diff --git a/internal/integrationtest/oidc.go b/internal/integrationtest/oidc.go index f86f309..7aa3a63 100644 --- a/internal/integrationtest/oidc.go +++ b/internal/integrationtest/oidc.go @@ -99,6 +99,9 @@ func (s *OIDCServer) IssueToken(claims *auth.Claims) (string, error) { "name": claims.Name, "groups": claims.Groups, }) + if tid := strings.TrimSpace(claims.TID); tid != "" { + builder = builder.Claims(map[string]any{"tid": tid}) + } return builder.Serialize() } diff --git a/internal/mail/hosted/service.go b/internal/mail/hosted/service.go index cd77623..cb3f97e 100644 --- a/internal/mail/hosted/service.go +++ b/internal/mail/hosted/service.go @@ -266,6 +266,118 @@ type ProvisionMailboxResult struct { MailAccountID string } +// EnsureMailboxProvisioned creates a mailbox or links an existing one to the requested user. +func (s *Service) EnsureMailboxProvisioned(ctx context.Context, in ProvisionMailboxInput) (ProvisionMailboxResult, error) { + email := strings.ToLower(strings.TrimSpace(in.Email)) + existing, err := s.lookupMailboxByEmail(ctx, email) + if errors.Is(err, pgx.ErrNoRows) { + return s.ProvisionMailbox(ctx, in) + } + if err != nil { + return ProvisionMailboxResult{}, err + } + return s.reconcileExistingMailbox(ctx, existing, in) +} + +func (s *Service) lookupMailboxByEmail(ctx context.Context, email string) (MailboxRow, error) { + email = strings.ToLower(strings.TrimSpace(email)) + at := strings.LastIndex(email, "@") + if at <= 0 { + return MailboxRow{}, fmt.Errorf("invalid email") + } + localPart := email[:at] + domainName := email[at+1:] + localPart, err := normalizeLocalPart(localPart) + if err != nil { + return MailboxRow{}, err + } + var row MailboxRow + err = s.db.QueryRow(ctx, ` + SELECT mb.id::text, mb.domain_id::text, mb.local_part, + lower(mb.local_part || '@' || md.name), + COALESCE(mb.user_id::text, ''), + COALESCE(mb.mail_account_id::text, ''), + mb.stalwart_account_id, mb.quota_bytes, mb.status + FROM mailboxes mb + JOIN mail_domains md ON md.id = mb.domain_id + WHERE md.name = $1 AND mb.local_part = $2 + `, domainName, localPart).Scan( + &row.ID, &row.DomainID, &row.LocalPart, &row.Email, + &row.UserID, &row.MailAccountID, &row.StalwartAccountID, &row.QuotaBytes, &row.Status, + ) + return row, err +} + +func (s *Service) reconcileExistingMailbox(ctx context.Context, existing MailboxRow, in ProvisionMailboxInput) (ProvisionMailboxResult, error) { + userID := strings.TrimSpace(in.UserID) + if userID != "" && existing.UserID != "" && existing.UserID != userID { + return ProvisionMailboxResult{}, ErrAddressTaken + } + if userID == "" { + userID = existing.UserID + } + mailAccountID := existing.MailAccountID + + if userID != "" && existing.UserID != userID { + if err := s.LinkMailboxToUser(ctx, existing.ID, userID); err != nil { + return ProvisionMailboxResult{}, err + } + existing.UserID = userID + } + + if userID != "" && mailAccountID == "" && strings.TrimSpace(in.Password) != "" { + email := strings.ToLower(strings.TrimSpace(in.Email)) + err := s.db.QueryRow(ctx, ` + SELECT id::text FROM mail_accounts + WHERE user_id = $1::uuid AND lower(email) = $2 + LIMIT 1 + `, userID, email).Scan(&mailAccountID) + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return ProvisionMailboxResult{}, err + } + if errors.Is(err, pgx.ErrNoRows) { + enc, err := s.encryptHostedCredential(email, in.Password) + if err != nil { + return ProvisionMailboxResult{}, err + } + displayName := strings.TrimSpace(in.DisplayName) + if displayName == "" { + displayName = email + } + err = s.db.QueryRow(ctx, ` + INSERT INTO mail_accounts ( + user_id, name, email, provider, + imap_host, imap_port, imap_tls, + smtp_host, smtp_port, smtp_tls, + credentials, is_active + ) + VALUES ($1, $2, $3, 'hosted', $4, $5, $6, $7, $8, $9, $10, true) + RETURNING id::text + `, userID, displayName, email, + s.imapHost, s.imapPort, s.imapTLS, + s.smtpHost, s.smtpPort, s.smtpTLS, + enc, + ).Scan(&mailAccountID) + if err != nil { + return ProvisionMailboxResult{}, err + } + } + _, err = s.db.Exec(ctx, ` + UPDATE mailboxes SET mail_account_id = $1::uuid, updated_at = NOW() + WHERE id = $2::uuid AND (mail_account_id IS NULL OR mail_account_id = $1::uuid) + `, mailAccountID, existing.ID) + if err != nil { + return ProvisionMailboxResult{}, err + } + existing.MailAccountID = mailAccountID + } + + return ProvisionMailboxResult{ + Mailbox: existing, + MailAccountID: mailAccountID, + }, nil +} + func (s *Service) ProvisionMailbox(ctx context.Context, in ProvisionMailboxInput) (ProvisionMailboxResult, error) { email := strings.ToLower(strings.TrimSpace(in.Email)) at := strings.LastIndex(email, "@") diff --git a/internal/migration/claim_email_match.go b/internal/migration/claim_email_match.go index 1cbc901..3774d73 100644 --- a/internal/migration/claim_email_match.go +++ b/internal/migration/claim_email_match.go @@ -11,6 +11,7 @@ type ClaimIdentity struct { Email string PreferredUsername string UPN string + TenantID string } func ClaimIdentityFromAuth(c *auth.Claims) ClaimIdentity { @@ -21,6 +22,7 @@ func ClaimIdentityFromAuth(c *auth.Claims) ClaimIdentity { Email: c.Email, PreferredUsername: c.PreferredUsername, UPN: c.UPN, + TenantID: c.TID, } } @@ -90,7 +92,17 @@ func inviteMatchTargets(inviteEmail string, alternateEmails []string) []string { return out } -func localPartAliasMatch(a, b string) bool { +func isGmailAliasDomain(domain string) bool { + switch strings.ToLower(strings.TrimSpace(domain)) { + case "gmail.com", "googlemail.com": + return true + default: + return false + } +} + +// gmailLocalPartAliasMatch applies Gmail dot/plus normalization only on gmail.com / googlemail.com. +func gmailLocalPartAliasMatch(a, b string) bool { aLocal, aDomain, okA := emailLocalAndDomain(a) bLocal, bDomain, okB := emailLocalAndDomain(b) if !okA || !okB { @@ -99,6 +111,9 @@ func localPartAliasMatch(a, b string) bool { if !strings.EqualFold(aDomain, bDomain) { return false } + if !isGmailAliasDomain(aDomain) { + return false + } return normalizeEmailLocalPart(aLocal) == normalizeEmailLocalPart(bLocal) } @@ -140,7 +155,7 @@ func InviteEmailMatchesIdentity(inviteEmail string, alternateEmails []string, pr if candidate == target { return true } - if localPartAliasMatch(target, candidate) { + if gmailLocalPartAliasMatch(target, candidate) { return true } } @@ -148,3 +163,19 @@ func InviteEmailMatchesIdentity(inviteEmail string, alternateEmails []string, pr return projectDomainUPNMatch(inviteEmail, projectDomain, identity) } + +// validateMicrosoftTenantClaim rejects claims when the OIDC tid does not match the project's pinned tenant. +func validateMicrosoftTenantClaim(proj Project, tokenTenantID string) error { + if strings.ToLower(strings.TrimSpace(proj.SourceProvider)) != "microsoft" { + return nil + } + expected := strings.TrimSpace(proj.MicrosoftTenantID) + if expected == "" { + return nil + } + got := strings.TrimSpace(tokenTenantID) + if got == "" || !strings.EqualFold(got, expected) { + return ErrTenantMismatch + } + return nil +} diff --git a/internal/migration/claim_email_match_test.go b/internal/migration/claim_email_match_test.go index 623ae8d..b3a92bc 100644 --- a/internal/migration/claim_email_match_test.go +++ b/internal/migration/claim_email_match_test.go @@ -41,16 +41,24 @@ func TestInviteEmailMatchesIdentityGmailDotAlias(t *testing.T) { if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) { t.Fatal("expected exact match baseline") } + id = ClaimIdentity{Email: "a.l.i.c.e.smith@gmail.com"} + if !InviteEmailMatchesIdentity("alice.smith@gmail.com", nil, "", id) { + t.Fatal("expected dot-insensitive local-part match on gmail.com") + } id = ClaimIdentity{Email: "a.l.i.c.e.smith@acme.com"} - if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) { - t.Fatal("expected dot-insensitive local-part match") + if InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) { + t.Fatal("expected reject dot-alias on non-gmail domain") } } func TestInviteEmailMatchesIdentityPlusTag(t *testing.T) { - id := ClaimIdentity{Email: "alice+tag@acme.com"} - if !InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) { - t.Fatal("expected plus-tag stripped match") + id := ClaimIdentity{Email: "alice+tag@gmail.com"} + if !InviteEmailMatchesIdentity("alice@gmail.com", nil, "", id) { + t.Fatal("expected plus-tag stripped match on gmail.com") + } + id = ClaimIdentity{Email: "alice+tag@acme.com"} + if InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) { + t.Fatal("expected reject plus-tag alias on non-gmail domain") } } @@ -90,3 +98,29 @@ func TestInviteEmailMatchesIdentityIgnoresNonEmailPreferredUsername(t *testing.T t.Fatal("expected reject when preferred_username is not an email") } } + +func TestValidateMicrosoftTenantClaim(t *testing.T) { + msProj := Project{SourceProvider: "microsoft", MicrosoftTenantID: "tenant-abc"} + if err := validateMicrosoftTenantClaim(msProj, "tenant-abc"); err != nil { + t.Fatalf("expected match: %v", err) + } + if err := validateMicrosoftTenantClaim(msProj, "TENANT-ABC"); err != nil { + t.Fatalf("expected case-insensitive match: %v", err) + } + if err := validateMicrosoftTenantClaim(msProj, "other-tenant"); err != ErrTenantMismatch { + t.Fatalf("expected tenant mismatch, got %v", err) + } + if err := validateMicrosoftTenantClaim(msProj, ""); err != ErrTenantMismatch { + t.Fatalf("expected reject empty tid when tenant pinned: %v", err) + } + + googleProj := Project{SourceProvider: "google", MicrosoftTenantID: "tenant-abc"} + if err := validateMicrosoftTenantClaim(googleProj, "wrong"); err != nil { + t.Fatalf("google project should ignore tenant: %v", err) + } + + noTenant := Project{SourceProvider: "microsoft"} + if err := validateMicrosoftTenantClaim(noTenant, "any"); err != nil { + t.Fatalf("expected skip when project tenant unset: %v", err) + } +} diff --git a/internal/migration/drive_delta.go b/internal/migration/drive_delta.go index ae9dab7..4122256 100644 --- a/internal/migration/drive_delta.go +++ b/internal/migration/drive_delta.go @@ -83,8 +83,9 @@ func (d *DriveImporter) importGoogleDriveDelta(ctx context.Context, job *Job, ac return fmt.Errorf("google drive delta token missing") } - listURL := "https://www.googleapis.com/drive/v3/changes?pageSize=100&spaces=drive&includeRemoved=true&fields=" + - url.QueryEscape("nextPageToken,newStartPageToken,changes(fileId,removed,file(id,name,mimeType,size,parents,trashed))") + + listURL := "https://www.googleapis.com/drive/v3/changes?pageSize=100&spaces=drive&includeRemoved=true" + + "&includeItemsFromAllDrives=true&supportsAllDrives=true&fields=" + + url.QueryEscape("nextPageToken,newStartPageToken,changes(fileId,removed,file(id,name,mimeType,size,parents,trashed,driveId))") + "&pageToken=" + url.QueryEscape(pageToken) body, err := apiGet(ctx, d.client, listURL, accessToken) @@ -119,7 +120,7 @@ func (d *DriveImporter) importGoogleDriveDelta(ctx context.Context, job *Job, ac } item := googleFileToDriveItem(*change.File) relPath := d.resolveDriveRelPath(items, item) - if err := d.uploadDriveItem(ctx, accessToken, ncUserID, root, relPath, item, items, &imported, &exported, &skipped, job.StatsJSON); err != nil { + if err := d.uploadDriveItem(ctx, job, accessToken, ncUserID, root, relPath, item, items, &imported, &exported, &skipped, job.StatsJSON); err != nil { return err } batch++ @@ -160,6 +161,7 @@ type googleDriveFile struct { Size string `json:"size"` Parents []string `json:"parents"` Trashed bool `json:"trashed"` + DriveID string `json:"driveId"` } func googleFileToDriveItem(f googleDriveFile) driveItem { @@ -173,6 +175,7 @@ func googleFileToDriveItem(f googleDriveFile) driveItem { IsFolder: f.MimeType == "application/vnd.google-apps.folder", Size: size, MimeType: f.MimeType, + DriveID: f.DriveID, } if len(f.Parents) > 0 { item.ParentID = f.Parents[0] @@ -186,7 +189,7 @@ func googleFileToDriveItem(f googleDriveFile) driveItem { item.ExportExt = ext item.Name = driveExportFileName(f.Name, ext) } else { - item.Download = "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(f.ID) + "?alt=media" + item.Download = googleDriveDownloadURL(f.ID, f.DriveID != "") } return item } @@ -232,7 +235,7 @@ func (d *DriveImporter) importMicrosoftDriveDelta(ctx context.Context, job *Job, } driveItem := graphDriveToItem(d.userUPN, item) relPath := d.resolveDriveRelPath(items, driveItem) - if err := d.uploadDriveItem(ctx, accessToken, ncUserID, root, relPath, driveItem, items, &imported, nil, &skipped, job.StatsJSON); err != nil { + if err := d.uploadDriveItem(ctx, job, accessToken, ncUserID, root, relPath, driveItem, items, &imported, nil, &skipped, job.StatsJSON); err != nil { return err } batch++ @@ -307,8 +310,15 @@ func (d *DriveImporter) resolveDriveRelPath(items *ImportedItemStore, item drive return path.Join(parentRel, sanitizeDrivePath(item.Name)) } -func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUserID, root, relPath string, item driveItem, items *ImportedItemStore, imported, exported, skipped *float64, stats map[string]any) error { +func (d *DriveImporter) uploadDriveItem(ctx context.Context, job *Job, accessToken, ncUserID, root, relPath string, item driveItem, items *ImportedItemStore, imported, exported, skipped *float64, stats map[string]any) error { targetPath := path.Join(root, relPath) + shared := item.DriveID != "" + if d.alreadyImportedShared(item.DriveID, item.ID, shared) { + if skipped != nil { + *skipped++ + } + return items.MarkSkipped(ctx, item.ID, "dedup: shared drive file already imported by project", relPath) + } if item.IsFolder { if err := d.nc.CreateFolder(ctx, ncUserID, targetPath); err != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { @@ -335,7 +345,7 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser } targetPath = path.Join(path.Dir(targetPath), fileName) relPath = path.Join(path.Dir(relPath), fileName) - if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { + if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, 0); err != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { return markErr } @@ -361,13 +371,6 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser } } } else { - if item.Size > maxDriveFileBytes { - if skipped != nil { - *skipped++ - } - reason := fmt.Sprintf("file exceeds %d byte limit", maxDriveFileBytes) - return items.MarkSkipped(ctx, item.ID, reason, relPath) - } content, contentType, err := d.downloadDriveFile(ctx, accessToken, item) if err != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { @@ -376,7 +379,7 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser incJobStat(stats, "failed") return nil } - if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { + if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, item.Size); err != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { return markErr } @@ -390,6 +393,9 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser if err := items.MarkPath(ctx, item.ID, relPath); err != nil { return err } + if err := d.markSharedImported(ctx, item.DriveID, item.ID, relPath, job.ID, shared); err != nil { + return err + } if imported != nil { *imported++ } diff --git a/internal/migration/drive_helpers.go b/internal/migration/drive_helpers.go index b2b7a9a..ab2a085 100644 --- a/internal/migration/drive_helpers.go +++ b/internal/migration/drive_helpers.go @@ -42,8 +42,10 @@ func driveExportFileName(name, ext string) string { } type driveFolderRef struct { - ID string - Path string + ID string + Path string + DriveID string // Google shared drive ID; empty for My Drive + Shared bool } func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderRef { @@ -56,8 +58,10 @@ func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderR } id, _ := m["id"].(string) p, _ := m["path"].(string) + driveID, _ := m["driveId"].(string) + shared, _ := m["shared"].(bool) if id != "" { - out = append(out, driveFolderRef{ID: id, Path: p}) + out = append(out, driveFolderRef{ID: id, Path: p, DriveID: driveID, Shared: shared}) } } if len(out) == 0 { @@ -72,14 +76,16 @@ func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderR func writeDriveFolderQueue(cursor map[string]any, queue []driveFolderRef) { raw := make([]any, 0, len(queue)) for _, f := range queue { - raw = append(raw, map[string]any{"id": f.ID, "path": f.Path}) + raw = append(raw, map[string]any{ + "id": f.ID, "path": f.Path, "driveId": f.DriveID, "shared": f.Shared, + }) } cursor["folderQueue"] = raw } func enqueueDriveFolder(queue []driveFolderRef, folder driveFolderRef) []driveFolderRef { for _, existing := range queue { - if existing.ID == folder.ID { + if existing.ID == folder.ID && existing.DriveID == folder.DriveID { return queue } } diff --git a/internal/migration/drive_import.go b/internal/migration/drive_import.go index aed8f52..d908498 100644 --- a/internal/migration/drive_import.go +++ b/internal/migration/drive_import.go @@ -18,10 +18,13 @@ import ( const maxDriveFileBytes = 25 * 1024 * 1024 type DriveImporter struct { - db *pgxpool.Pool - nc *nextcloud.Client - client *http.Client - userUPN string + db *pgxpool.Pool + nc *nextcloud.Client + client *http.Client + userUPN string + projectID string + sharedDriveMode string + sharedDedup *SharedDriveItemStore } func NewDriveImporter(db *pgxpool.Pool, nc *nextcloud.Client) *DriveImporter { @@ -40,6 +43,31 @@ func (d *DriveImporter) WithHTTPClient(c *http.Client) *DriveImporter { return d } +func (d *DriveImporter) WithProject(projectID, sharedDriveMode string, dedup *SharedDriveItemStore) *DriveImporter { + d.projectID = strings.TrimSpace(projectID) + d.sharedDriveMode = NormalizeSharedDriveMode(sharedDriveMode) + d.sharedDedup = dedup + return d +} + +func (d *DriveImporter) isSharedDriveDedup(driveID string, shared bool) bool { + return shared && driveID != "" && d.sharedDedup != nil +} + +func (d *DriveImporter) alreadyImportedShared(driveID, sourceID string, shared bool) bool { + if !d.isSharedDriveDedup(driveID, shared) { + return false + } + return d.sharedDedup.Has(driveID, sourceID) +} + +func (d *DriveImporter) markSharedImported(ctx context.Context, driveID, sourceID, relPath, jobID string, shared bool) error { + if !d.isSharedDriveDedup(driveID, shared) { + return nil + } + return d.sharedDedup.MarkImported(ctx, driveID, sourceID, relPath, jobID) +} + func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, provider string, delta bool, update progressUpdater) error { if d.nc == nil { return fmt.Errorf("nextcloud required for drive migration") @@ -59,6 +87,18 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, return err } + if provider == "google" && !jsonBool(job.CursorJSON["sharedDrivesBootstrapped"]) { + if err := d.bootstrapSharedDrives(ctx, job, accessToken); err != nil { + return err + } + job.CursorJSON["sharedDrivesBootstrapped"] = true + } + if provider == "google" { + if err := d.mergeSharedDriveFolders(ctx, job, provider); err != nil { + return err + } + } + if delta && d.hasDriveDeltaCursor(job, provider) { return d.importDriveDelta(ctx, job, accessToken, provider, ncUserID, root, store, update) } @@ -99,6 +139,14 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, if alreadyImported(store, item.ID) { continue } + if d.alreadyImportedShared(current.DriveID, item.ID, current.Shared) { + skipped++ + if err := store.MarkSkipped(ctx, item.ID, "dedup: shared drive file already imported by project", relPathForItem(current, item)); err != nil { + return err + } + batch++ + continue + } relPath := path.Join(current.Path, sanitizeDrivePath(item.Name)) targetPath := path.Join(root, relPath) if item.IsFolder { @@ -113,7 +161,9 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, if err := store.MarkPath(ctx, item.ID, relPath); err != nil { return err } - queue = enqueueDriveFolder(queue, driveFolderRef{ID: item.ID, Path: relPath}) + queue = enqueueDriveFolder(queue, driveFolderRef{ + ID: item.ID, Path: relPath, DriveID: current.DriveID, Shared: current.Shared, + }) } else { if item.Export { content, contentType, fileName, err := d.downloadGoogleExport(ctx, accessToken, item) @@ -127,7 +177,7 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, } targetPath = path.Join(path.Dir(targetPath), fileName) relPath = path.Join(path.Dir(relPath), fileName) - if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { + if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, 0); err != nil { if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { return markErr } @@ -151,15 +201,6 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, } } } else { - if item.Size > maxDriveFileBytes { - skipped++ - reason := fmt.Sprintf("file exceeds %d byte limit", maxDriveFileBytes) - if err := store.MarkSkipped(ctx, item.ID, reason, relPath); err != nil { - return err - } - batch++ - continue - } content, contentType, err := d.downloadDriveFile(ctx, accessToken, item) if err != nil { if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { @@ -169,7 +210,7 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, batch++ continue } - if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { + if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, item.Size); err != nil { if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { return markErr } @@ -186,6 +227,9 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, if err := store.MarkPath(ctx, item.ID, relPath); err != nil { return err } + if err := d.markSharedImported(ctx, current.DriveID, item.ID, relPath, job.ID, current.Shared); err != nil { + return err + } } imported++ batch++ @@ -193,7 +237,9 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, for _, sub := range subfolders { relPath := path.Join(current.Path, sanitizeDrivePath(sub.Name)) - queue = enqueueDriveFolder(queue, driveFolderRef{ID: sub.ID, Path: relPath}) + queue = enqueueDriveFolder(queue, driveFolderRef{ + ID: sub.ID, Path: relPath, DriveID: current.DriveID, Shared: current.Shared, + }) } writeDriveFolderQueue(job.CursorJSON, queue) @@ -234,6 +280,7 @@ type driveItem struct { Export bool ExportMime string ExportExt string + DriveID string } type driveSubfolder struct { @@ -247,6 +294,7 @@ func (d *DriveImporter) listDriveFolderItems(ctx context.Context, accessToken, p pageToken, _ := cursor["pageToken"].(string) q := url.QueryEscape("'" + folder.ID + "' in parents and trashed=false") listURL := "https://www.googleapis.com/drive/v3/files?pageSize=100&fields=nextPageToken,files(id,name,mimeType,size)&q=" + q + listURL += googleDriveListParams(folder) if pageToken != "" { listURL += "&pageToken=" + url.QueryEscape(pageToken) } @@ -289,7 +337,7 @@ func (d *DriveImporter) listDriveFolderItems(ctx context.Context, accessToken, p item.ExportExt = ext item.Name = driveExportFileName(f.Name, ext) } else { - item.Download = "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(f.ID) + "?alt=media" + item.Download = googleDriveDownloadURL(f.ID, folder.Shared) } out = append(out, item) } @@ -378,6 +426,9 @@ func (d *DriveImporter) downloadGoogleExport(ctx context.Context, accessToken st url.PathEscape(item.ID), url.QueryEscape(item.ExportMime), ) + if item.DriveID != "" { + exportURL += "&supportsAllDrives=true" + } req, err := http.NewRequestWithContext(ctx, http.MethodGet, exportURL, nil) if err != nil { return nil, "", "", err @@ -403,3 +454,95 @@ func sanitizeDrivePath(name string) string { } return name } + +func relPathForItem(folder driveFolderRef, item driveItem) string { + return path.Join(folder.Path, sanitizeDrivePath(item.Name)) +} + +func jsonBool(v any) bool { + switch t := v.(type) { + case bool: + return t + case float64: + return t != 0 + case string: + return t == "true" || t == "1" + default: + return false + } +} + +func googleDriveListParams(folder driveFolderRef) string { + if folder.Shared && folder.DriveID != "" { + return "&corpora=drive&driveId=" + url.QueryEscape(folder.DriveID) + + "&includeItemsFromAllDrives=true&supportsAllDrives=true" + } + return "&supportsAllDrives=true" +} + +func googleDriveDownloadURL(fileID string, shared bool) string { + u := "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(fileID) + "?alt=media" + if shared { + u += "&supportsAllDrives=true" + } + return u +} + +func (d *DriveImporter) uploadToNextcloud(ctx context.Context, ncUserID, targetPath string, content io.ReadCloser, contentType string, size int64) error { + defer content.Close() + if size > maxDriveFileBytes { + return d.nc.UploadStreaming(ctx, ncUserID, targetPath, content, contentType, size) + } + return d.nc.Upload(ctx, ncUserID, targetPath, content, contentType) +} + +func (d *DriveImporter) bootstrapSharedDrives(ctx context.Context, job *Job, accessToken string) error { + pageToken := "" + for { + listURL := "https://www.googleapis.com/drive/v3/drives?pageSize=100&fields=nextPageToken,drives(id,name)" + if pageToken != "" { + listURL += "&pageToken=" + url.QueryEscape(pageToken) + } + body, err := apiGet(ctx, d.client, listURL, accessToken) + if err != nil { + return err + } + var parsed struct { + Drives []struct { + ID string `json:"id"` + Name string `json:"name"` + } `json:"drives"` + NextPageToken string `json:"nextPageToken"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return err + } + for _, drive := range parsed.Drives { + if err := d.upsertDiscoveredSharedDrive(ctx, job.ProjectID, job.UserID, drive.ID, drive.Name, d.sharedDriveMode); err != nil { + return err + } + } + if parsed.NextPageToken == "" { + break + } + pageToken = parsed.NextPageToken + } + return nil +} + +func (d *DriveImporter) mergeSharedDriveFolders(ctx context.Context, job *Job, provider string) error { + if provider != "google" { + return nil + } + queue := readDriveFolderQueue(job.CursorJSON, provider) + sharedFolders, err := d.loadApprovedSharedDriveFolders(ctx, job.ProjectID) + if err != nil { + return err + } + for _, folder := range sharedFolders { + queue = enqueueDriveFolder(queue, folder) + } + writeDriveFolderQueue(job.CursorJSON, queue) + return nil +} + diff --git a/internal/migration/drive_shared.go b/internal/migration/drive_shared.go new file mode 100644 index 0000000..6283cec --- /dev/null +++ b/internal/migration/drive_shared.go @@ -0,0 +1,237 @@ +package migration + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5/pgxpool" +) + +const ( + SharedDriveModeAuto = "auto" + SharedDriveModeManual = "manual" + + SharedDriveStatusPending = "pending" + SharedDriveStatusApproved = "approved" + SharedDriveStatusRejected = "rejected" +) + +type SharedDrive struct { + ID string `json:"id"` + ProjectID string `json:"project_id"` + DriveID string `json:"drive_id"` + Name string `json:"name"` + Status string `json:"status"` + DiscoveredByUserID *string `json:"discovered_by_user_id,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +func NormalizeSharedDriveMode(mode string) string { + switch mode { + case SharedDriveModeManual: + return SharedDriveModeManual + default: + return SharedDriveModeAuto + } +} + +// SharedDriveItemStore tracks project-level imports for shared drive files (cross-user dedup). +type SharedDriveItemStore struct { + db *pgxpool.Pool + projectID string + done map[string]struct{} // key: driveID + ":" + sourceID +} + +func NewSharedDriveItemStoreMemory() *SharedDriveItemStore { + return &SharedDriveItemStore{done: map[string]struct{}{}} +} + +func LoadSharedDriveItemStore(ctx context.Context, db *pgxpool.Pool, projectID string) (*SharedDriveItemStore, error) { + store := &SharedDriveItemStore{ + db: db, + projectID: projectID, + done: map[string]struct{}{}, + } + if db == nil || projectID == "" { + return store, nil + } + rows, err := db.Query(ctx, ` + SELECT drive_id, source_id + FROM migration_shared_drive_items + WHERE project_id = $1::uuid + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var driveID, sourceID string + if err := rows.Scan(&driveID, &sourceID); err != nil { + return nil, err + } + store.done[sharedDriveItemKey(driveID, sourceID)] = struct{}{} + } + return store, rows.Err() +} + +func sharedDriveItemKey(driveID, sourceID string) string { + return driveID + ":" + sourceID +} + +func (s *SharedDriveItemStore) Has(driveID, sourceID string) bool { + if s == nil || driveID == "" || sourceID == "" { + return false + } + _, ok := s.done[sharedDriveItemKey(driveID, sourceID)] + return ok +} + +func (s *SharedDriveItemStore) MarkImported(ctx context.Context, driveID, sourceID, relPath, jobID string) error { + if driveID == "" || sourceID == "" { + return nil + } + s.done[sharedDriveItemKey(driveID, sourceID)] = struct{}{} + if s.db == nil || s.projectID == "" { + return nil + } + _, err := s.db.Exec(ctx, ` + INSERT INTO migration_shared_drive_items (project_id, drive_id, source_id, rel_path, imported_by_job_id) + VALUES ($1::uuid, $2, $3, $4, NULLIF($5, '')::uuid) + ON CONFLICT (project_id, drive_id, source_id) DO NOTHING + `, s.projectID, driveID, sourceID, relPath, jobID) + return err +} + +func (s *Service) UpdateSharedDriveMode(ctx context.Context, projectID, mode string) (Project, error) { + mode = NormalizeSharedDriveMode(mode) + sc := newProjectScanner() + err := s.db.QueryRow(ctx, ` + UPDATE migration_projects + SET shared_drive_mode = $2, updated_at = NOW() + WHERE id = $1::uuid + RETURNING `+projectSelectSQL("")+` + `, projectID, mode).Scan(sc.targets()...) + return sc.result(), err +} + +func (s *Service) ListSharedDrives(ctx context.Context, projectID, statusFilter string) ([]SharedDrive, error) { + query := ` + SELECT id::text, project_id::text, drive_id, name, status, + NULLIF(discovered_by_user_id::text, ''), created_at::text, updated_at::text + FROM migration_shared_drives + WHERE project_id = $1::uuid + ` + args := []any{projectID} + if statusFilter != "" { + query += ` AND status = $2` + args = append(args, statusFilter) + } + query += ` ORDER BY name ASC, created_at ASC` + + rows, err := s.db.Query(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []SharedDrive + for rows.Next() { + var row SharedDrive + if err := rows.Scan( + &row.ID, &row.ProjectID, &row.DriveID, &row.Name, &row.Status, + &row.DiscoveredByUserID, &row.CreatedAt, &row.UpdatedAt, + ); err != nil { + return nil, err + } + out = append(out, row) + } + return out, rows.Err() +} + +func (s *Service) SetSharedDriveStatus(ctx context.Context, projectID, driveID, status string) (SharedDrive, error) { + var row SharedDrive + err := s.db.QueryRow(ctx, ` + UPDATE migration_shared_drives + SET status = $3, updated_at = NOW() + WHERE project_id = $1::uuid AND drive_id = $2 + RETURNING id::text, project_id::text, drive_id, name, status, + NULLIF(discovered_by_user_id::text, ''), created_at::text, updated_at::text + `, projectID, driveID, status).Scan( + &row.ID, &row.ProjectID, &row.DriveID, &row.Name, &row.Status, + &row.DiscoveredByUserID, &row.CreatedAt, &row.UpdatedAt, + ) + if err != nil { + return SharedDrive{}, fmt.Errorf("shared drive not found") + } + return row, nil +} + +func (s *Service) ApproveSharedDrive(ctx context.Context, projectID, driveID string) (SharedDrive, error) { + return s.SetSharedDriveStatus(ctx, projectID, driveID, SharedDriveStatusApproved) +} + +func (s *Service) RejectSharedDrive(ctx context.Context, projectID, driveID string) (SharedDrive, error) { + return s.SetSharedDriveStatus(ctx, projectID, driveID, SharedDriveStatusRejected) +} + +func (d *DriveImporter) upsertDiscoveredSharedDrive(ctx context.Context, projectID, userID, driveID, name, mode string) error { + if d.db == nil { + return nil + } + autoApprove := NormalizeSharedDriveMode(mode) == SharedDriveModeAuto + initialStatus := SharedDriveStatusPending + if autoApprove { + initialStatus = SharedDriveStatusApproved + } + _, err := d.db.Exec(ctx, ` + INSERT INTO migration_shared_drives (project_id, drive_id, name, status, discovered_by_user_id) + VALUES ($1::uuid, $2, $3, $4, NULLIF($5, '')::uuid) + ON CONFLICT (project_id, drive_id) DO UPDATE + SET name = COALESCE(NULLIF(EXCLUDED.name, ''), migration_shared_drives.name), + status = CASE + WHEN migration_shared_drives.status = 'rejected' THEN 'rejected' + WHEN migration_shared_drives.status = 'approved' THEN 'approved' + WHEN $6 = 'auto' THEN 'approved' + ELSE migration_shared_drives.status + END, + updated_at = NOW() + `, projectID, driveID, name, initialStatus, userID, NormalizeSharedDriveMode(mode)) + return err +} + +func (d *DriveImporter) loadApprovedSharedDriveFolders(ctx context.Context, projectID string) ([]driveFolderRef, error) { + if d.db == nil { + return nil, nil + } + rows, err := d.db.Query(ctx, ` + SELECT drive_id, name + FROM migration_shared_drives + WHERE project_id = $1::uuid AND status = 'approved' + ORDER BY name ASC + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []driveFolderRef + for rows.Next() { + var id, name string + if err := rows.Scan(&id, &name); err != nil { + return nil, err + } + out = append(out, driveFolderRef{ + ID: id, + Path: pathJoinSharedDrive(name), + DriveID: id, + Shared: true, + }) + } + return out, rows.Err() +} + +func pathJoinSharedDrive(name string) string { + return "Shared Drives/" + sanitizeDrivePath(name) +} + diff --git a/internal/migration/drive_shared_test.go b/internal/migration/drive_shared_test.go new file mode 100644 index 0000000..16e68d9 --- /dev/null +++ b/internal/migration/drive_shared_test.go @@ -0,0 +1,116 @@ +package migration + +import ( + "context" + "net/http" + "strings" + "testing" +) + +func TestSharedDriveItemDedup(t *testing.T) { + store := NewSharedDriveItemStoreMemory() + ctx := context.Background() + if store.Has("drive-1", "file-1") { + t.Fatal("expected miss before mark") + } + if err := store.MarkImported(ctx, "drive-1", "file-1", "Shared Drives/Team/doc.pdf", "job-1"); err != nil { + t.Fatal(err) + } + if !store.Has("drive-1", "file-1") { + t.Fatal("expected hit after mark") + } + if store.Has("drive-1", "file-2") { + t.Fatal("different file should not match") + } + if store.Has("drive-2", "file-1") { + t.Fatal("different drive should not match") + } +} + +func TestDriveImporterSharedDedup(t *testing.T) { + d := &DriveImporter{sharedDedup: NewSharedDriveItemStoreMemory()} + ctx := context.Background() + if err := d.sharedDedup.MarkImported(ctx, "sd-1", "f-1", "path", "job-a"); err != nil { + t.Fatal(err) + } + if !d.alreadyImportedShared("sd-1", "f-1", true) { + t.Fatal("expected shared dedup hit") + } + if d.alreadyImportedShared("sd-1", "f-2", true) { + t.Fatal("expected miss for other file") + } + if d.alreadyImportedShared("", "f-1", false) { + t.Fatal("personal drive should not dedup at project level") + } +} + +func TestBootstrapSharedDrivesDiscovery(t *testing.T) { + client := mockGoogleHTTPClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/drive/v3/drives") { + _, _ = w.Write([]byte(`{"drives":[{"id":"sd-team","name":"Team Drive"}]}`)) + return + } + http.NotFound(w, r) + }) + + d := NewDriveImporter(nil, nil).WithHTTPClient(client).WithProject("proj-1", SharedDriveModeAuto, NewSharedDriveItemStoreMemory()) + job := &Job{ProjectID: "proj-1", UserID: "user-1", CursorJSON: map[string]any{}} + if err := d.bootstrapSharedDrives(context.Background(), job, "token"); err != nil { + t.Fatalf("bootstrap: %v", err) + } +} + +func TestGoogleDriveListParams(t *testing.T) { + myDrive := googleDriveListParams(driveFolderRef{ID: "root"}) + if !strings.Contains(myDrive, "supportsAllDrives=true") { + t.Fatalf("my drive params: %q", myDrive) + } + if strings.Contains(myDrive, "corpora=drive") { + t.Fatalf("my drive should not use corpora=drive: %q", myDrive) + } + + shared := googleDriveListParams(driveFolderRef{ID: "sd-1", DriveID: "sd-1", Shared: true}) + if !strings.Contains(shared, "corpora=drive") || !strings.Contains(shared, "driveId=sd-1") { + t.Fatalf("shared drive params: %q", shared) + } +} + +func TestGoogleDriveDownloadURL(t *testing.T) { + personal := googleDriveDownloadURL("file-1", false) + if strings.Contains(personal, "supportsAllDrives") { + t.Fatalf("personal download: %q", personal) + } + shared := googleDriveDownloadURL("file-1", true) + if !strings.Contains(shared, "supportsAllDrives=true") { + t.Fatalf("shared download: %q", shared) + } +} + +func TestMergeSharedDriveFolders(t *testing.T) { + d := NewDriveImporter(nil, nil) + job := &Job{CursorJSON: map[string]any{}} + if err := d.mergeSharedDriveFolders(context.Background(), job, "google"); err != nil { + t.Fatal(err) + } + queue := readDriveFolderQueue(job.CursorJSON, "google") + if len(queue) != 1 || queue[0].ID != "root" { + t.Fatalf("without db only root queue: %#v", queue) + } + + manualQueue := readDriveFolderQueue(map[string]any{}, "google") + manualQueue = enqueueDriveFolder(manualQueue, driveFolderRef{ + ID: "sd-1", Path: "Shared Drives/Finance", DriveID: "sd-1", Shared: true, + }) + if len(manualQueue) != 2 || !manualQueue[1].Shared { + t.Fatalf("manual enqueue: %#v", manualQueue) + } +} + +func TestNormalizeSharedDriveMode(t *testing.T) { + if got := NormalizeSharedDriveMode("manual"); got != SharedDriveModeManual { + t.Fatalf("got %q", got) + } + if got := NormalizeSharedDriveMode(""); got != SharedDriveModeAuto { + t.Fatalf("default got %q", got) + } +} diff --git a/internal/migration/gmail_import.go b/internal/migration/gmail_import.go index d880619..ec007e7 100644 --- a/internal/migration/gmail_import.go +++ b/internal/migration/gmail_import.go @@ -700,28 +700,4 @@ func truncateRunes(s string, n int) string { return string(r[:n]) } -func LinkHostedMailboxByEmail(ctx context.Context, db *pgxpool.Pool, userID, email string) error { - email = strings.ToLower(strings.TrimSpace(email)) - if email == "" { - return nil - } - _, err := db.Exec(ctx, ` - UPDATE mailboxes SET user_id = $1::uuid, updated_at = NOW() - WHERE user_id IS NULL AND lower(local_part || '@' || (SELECT name FROM mail_domains d WHERE d.id = mailboxes.domain_id)) = $2 - `, userID, email) - if err != nil { - return err - } - _, err = db.Exec(ctx, ` - UPDATE mail_accounts ma SET user_id = $1::uuid, updated_at = NOW() - FROM mailboxes mb - JOIN mail_domains md ON md.id = mb.domain_id - WHERE mb.mail_account_id = ma.id - AND ma.user_id IS NULL - AND mb.user_id = $1::uuid - AND lower(mb.local_part || '@' || md.name) = $2 - `, userID, email) - return err -} - var _ = pgx.ErrNoRows diff --git a/internal/migration/graph_import.go b/internal/migration/graph_import.go index 60b5143..283b8dd 100644 --- a/internal/migration/graph_import.go +++ b/internal/migration/graph_import.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "sort" "strings" "time" @@ -134,6 +135,9 @@ func (g *GraphImporter) ImportBatch( } if delta { + if len(graphFolderDeltaLinks(job.CursorJSON)) > 0 { + return g.importFolderDelta(ctx, job, accessToken, accountID, items, update) + } deltaLink, _ := job.CursorJSON["deltaLink"].(string) if deltaLink != "" { more, err := g.importDeltaPage(ctx, job, accessToken, accountID, deltaLink, items) @@ -147,12 +151,36 @@ func (g *GraphImporter) ImportBatch( } } + return g.importFullFolders(ctx, job, accessToken, accountID, items, delta, update) +} + +func (g *GraphImporter) importFullFolders( + ctx context.Context, + job *Job, + accessToken, accountID string, + items *ImportedItemStore, + captureDelta bool, + update func(status string, cursor, stats map[string]any, jobErr string) error, +) error { + queue := g.folderQueue(job.CursorJSON) + folderIndex := int(jsonNumber(job.CursorJSON["folderIndex"])) + if folderIndex >= len(queue) { + if captureDelta { + if err := g.bootstrapFolderDeltaLinks(ctx, accessToken, queue, job.CursorJSON); err != nil { + return err + } + } + job.StatsJSON["phase"] = "imported" + return update("completed", job.CursorJSON, job.StatsJSON, "") + } + + folderID := queue[folderIndex] nextLink, _ := job.CursorJSON["nextLink"].(string) var listURL string if nextLink != "" { listURL = nextLink } else { - listURL = g.graphURL(g.userBase()+"/messages?$top=100&$orderby="+url.QueryEscape("receivedDateTime desc")+"&$select="+graphMessageSelect) + listURL = g.folderMessagesURL(folderID) } body, err := g.apiGet(ctx, listURL, accessToken) @@ -210,15 +238,109 @@ func (g *GraphImporter) ImportBatch( } delete(job.CursorJSON, "nextLink") - if delta { - if listed.DeltaLink != "" { - job.CursorJSON["deltaLink"] = listed.DeltaLink - } else if link, err := g.initDeltaLink(ctx, accessToken); err == nil && link != "" { - job.CursorJSON["deltaLink"] = link + job.CursorJSON["folderIndex"] = float64(folderIndex + 1) + return update("pending", job.CursorJSON, job.StatsJSON, "") +} + +func (g *GraphImporter) importFolderDelta( + ctx context.Context, + job *Job, + accessToken, accountID string, + items *ImportedItemStore, + update func(status string, cursor, stats map[string]any, jobErr string) error, +) error { + queue := g.folderQueue(job.CursorJSON) + folderIndex := int(jsonNumber(job.CursorJSON["folderIndex"])) + if folderIndex >= len(queue) { + job.StatsJSON["phase"] = "delta" + return update("completed", job.CursorJSON, job.StatsJSON, "") + } + + folderID := queue[folderIndex] + deltaLinks := graphFolderDeltaLinks(job.CursorJSON) + deltaLink := deltaLinks[folderID] + if deltaLink == "" { + deltaLink, _ = job.CursorJSON["nextLink"].(string) + } + if deltaLink == "" { + link, err := g.initFolderDeltaLink(ctx, accessToken, folderID) + if err != nil { + return err + } + deltaLink = link + } + + more, err := g.importFolderDeltaPage(ctx, job, accessToken, accountID, folderID, deltaLink, items) + if err != nil { + return err + } + if more { + return update("pending", job.CursorJSON, job.StatsJSON, "") + } + + delete(job.CursorJSON, "nextLink") + job.CursorJSON["folderIndex"] = float64(folderIndex + 1) + return update("pending", job.CursorJSON, job.StatsJSON, "") +} + +func (g *GraphImporter) importFolderDeltaPage( + ctx context.Context, + job *Job, + accessToken, accountID, folderID, deltaLink string, + items *ImportedItemStore, +) (more bool, err error) { + body, err := g.apiGet(ctx, deltaLink, accessToken) + if err != nil { + return false, err + } + var parsed struct { + Value []graphMessage `json:"value"` + NextLink string `json:"@odata.nextLink"` + DeltaLink string `json:"@odata.deltaLink"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return false, err + } + deltaCount, _ := job.StatsJSON["delta_imported"].(float64) + deleted, _ := job.StatsJSON["delta_deleted"].(float64) + for _, msg := range parsed.Value { + if msg.Removed != nil { + if err := g.deleteByGraphID(ctx, accountID, msg.ID); err != nil { + return false, err + } + deleted++ + continue + } + if alreadyImported(items, msg.ID) { + continue + } + ok, err := g.importOne(ctx, accountID, msg) + if err != nil { + if markErr := items.MarkFailed(ctx, msg.ID, err.Error(), ""); markErr != nil { + return false, markErr + } + incJobStat(job.StatsJSON, "failed") + continue + } + if err := items.MarkImported(ctx, msg.ID); err != nil { + return false, err + } + if ok { + deltaCount++ } } - job.StatsJSON["phase"] = "imported" - return update("completed", job.CursorJSON, job.StatsJSON, "") + job.StatsJSON["delta_imported"] = deltaCount + job.StatsJSON["delta_deleted"] = deleted + if parsed.NextLink != "" { + setGraphFolderDeltaLink(job.CursorJSON, folderID, parsed.NextLink) + job.StatsJSON["phase"] = "delta" + return true, nil + } + if parsed.DeltaLink != "" { + setGraphFolderDeltaLink(job.CursorJSON, folderID, parsed.DeltaLink) + } + job.StatsJSON["phase"] = "delta" + return false, nil } func (g *GraphImporter) importDeltaPage(ctx context.Context, job *Job, accessToken, accountID, deltaLink string, items *ImportedItemStore) (more bool, err error) { @@ -276,8 +398,16 @@ func (g *GraphImporter) importDeltaPage(ctx context.Context, job *Job, accessTok return false, nil } -func (g *GraphImporter) initDeltaLink(ctx context.Context, accessToken string) (string, error) { - body, err := g.apiGet(ctx, g.graphURL(g.userBase()+"/messages/delta?$select=id"), accessToken) +func (g *GraphImporter) folderMessagesURL(folderID string) string { + path := g.userBase() + "/mailFolders/" + url.PathEscape(folderID) + "/messages" + + "?$top=100&$orderby=" + url.QueryEscape("receivedDateTime desc") + + "&$select=" + graphMessageSelect + return g.graphURL(path) +} + +func (g *GraphImporter) initFolderDeltaLink(ctx context.Context, accessToken, folderID string) (string, error) { + path := g.userBase() + "/mailFolders/" + url.PathEscape(folderID) + "/messages/delta?$select=id" + body, err := g.apiGet(ctx, g.graphURL(path), accessToken) if err != nil { return "", err } @@ -294,6 +424,37 @@ func (g *GraphImporter) initDeltaLink(ctx context.Context, accessToken string) ( return parsed.NextLink, nil } +func (g *GraphImporter) bootstrapFolderDeltaLinks(ctx context.Context, accessToken string, queue []string, cursor map[string]any) error { + for _, folderID := range queue { + if graphFolderDeltaLinks(cursor)[folderID] != "" { + continue + } + link, err := g.initFolderDeltaLink(ctx, accessToken, folderID) + if err != nil { + return err + } + if link != "" { + setGraphFolderDeltaLink(cursor, folderID, link) + } + } + delete(cursor, "deltaLink") + delete(cursor, "folderIndex") + return nil +} + +func (g *GraphImporter) folderQueue(cursor map[string]any) []string { + if queue := readGraphFolderQueue(cursor); len(queue) > 0 { + return queue + } + ids := make([]string, 0, len(g.folders)) + for id := range g.folders { + ids = append(ids, id) + } + sort.Strings(ids) + writeGraphFolderQueue(cursor, ids) + return ids +} + func (g *GraphImporter) importOne(ctx context.Context, accountID string, msg graphMessage) (bool, error) { meta := g.folders[msg.ParentFolderID] if meta.RemoteName == "" { @@ -391,23 +552,28 @@ func (g *GraphImporter) ensureGraphFolders(ctx context.Context, accessToken stri if len(g.folders) > 0 { return nil } - body, err := g.apiGet(ctx, g.graphURL(g.userBase()+"/mailFolders?$top=100&$select=id,displayName,wellKnownName"), accessToken) - if err != nil { - return err - } - var parsed struct { - Value []struct { - ID string `json:"id"` - DisplayName string `json:"displayName"` - WellKnownName string `json:"wellKnownName"` - } `json:"value"` - } - if err := json.Unmarshal(body, &parsed); err != nil { - return err - } - for _, f := range parsed.Value { - remote, ftype := graphWellKnownFolder(f.WellKnownName, f.DisplayName) - g.folders[f.ID] = graphFolderMeta{RemoteName: remote, FolderType: ftype} + listURL := g.graphURL(g.userBase() + "/mailFolders?$top=100&$select=id,displayName,wellKnownName") + for listURL != "" { + body, err := g.apiGet(ctx, listURL, accessToken) + if err != nil { + return err + } + var parsed struct { + Value []struct { + ID string `json:"id"` + DisplayName string `json:"displayName"` + WellKnownName string `json:"wellKnownName"` + } `json:"value"` + NextLink string `json:"@odata.nextLink"` + } + if err := json.Unmarshal(body, &parsed); err != nil { + return err + } + for _, f := range parsed.Value { + remote, ftype := graphWellKnownFolder(f.WellKnownName, f.DisplayName) + g.folders[f.ID] = graphFolderMeta{RemoteName: remote, FolderType: ftype} + } + listURL = parsed.NextLink } return nil } diff --git a/internal/migration/graph_import_test.go b/internal/migration/graph_import_test.go index ac9ea3b..129488f 100644 --- a/internal/migration/graph_import_test.go +++ b/internal/migration/graph_import_test.go @@ -1,6 +1,11 @@ package migration -import "testing" +import ( + "context" + "net/http" + "strings" + "testing" +) func TestGraphWellKnownFolder(t *testing.T) { remote, ftype := graphWellKnownFolder("inbox", "Inbox") @@ -47,3 +52,92 @@ func TestRemoteMessageUIDMatchesGmailUID(t *testing.T) { t.Fatal("uid helpers diverged") } } + +func TestGraphFolderQueueSortedAndCached(t *testing.T) { + g := NewGraphImporter(nil) + g.folders = map[string]graphFolderMeta{ + "sent-folder": {RemoteName: "SENT", FolderType: "sent"}, + "inbox-folder": {RemoteName: "INBOX", FolderType: "inbox"}, + } + cursor := map[string]any{} + queue := g.folderQueue(cursor) + if len(queue) != 2 { + t.Fatalf("queue len = %d", len(queue)) + } + if queue[0] != "inbox-folder" || queue[1] != "sent-folder" { + t.Fatalf("queue order = %v", queue) + } + cached := readGraphFolderQueue(cursor) + if len(cached) != 2 || cached[0] != "inbox-folder" { + t.Fatalf("cached queue = %v", cached) + } +} + +func TestGraphFolderMessagesURLUsesMailFoldersPath(t *testing.T) { + g := NewGraphImporter(nil).WithBaseURL("https://graph.test") + listURL := g.folderMessagesURL("folder-abc") + if !strings.Contains(listURL, "/mailFolders/folder-abc/messages") { + t.Fatalf("url = %q", listURL) + } + if strings.Contains(listURL, "/me/messages") { + t.Fatalf("flat messages path should not be used: %q", listURL) + } +} + +func TestGraphEnsureFoldersPaginates(t *testing.T) { + pages := 0 + client := mockGraphHTTPClient(t, func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(r.URL.Path, "/mailFolders") { + http.NotFound(w, r) + return + } + pages++ + if pages == 1 { + _, _ = w.Write([]byte(`{ + "value":[{"id":"inbox-id","displayName":"Inbox","wellKnownName":"inbox"}], + "@odata.nextLink":"https://graph.microsoft.com/v1.0/me/mailFolders?$top=100&$skip=100" + }`)) + return + } + _, _ = w.Write([]byte(`{"value":[{"id":"sent-id","displayName":"Sent","wellKnownName":"sentitems"}]}`)) + }) + + g := NewGraphImporter(nil).WithHTTPClient(client) + if err := g.ensureGraphFolders(context.Background(), "token"); err != nil { + t.Fatalf("ensure folders: %v", err) + } + if pages != 2 { + t.Fatalf("pages = %d, want 2", pages) + } + if len(g.folders) != 2 { + t.Fatalf("folders = %d", len(g.folders)) + } +} + +func TestGraphInitFolderDeltaLink(t *testing.T) { + client := mockGraphHTTPClient(t, func(w http.ResponseWriter, r *http.Request) { + if strings.Contains(r.URL.Path, "/mailFolders/inbox-id/messages/delta") { + _, _ = w.Write([]byte(`{"@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/mailFolders/inbox-id/messages/delta?token=done"}`)) + return + } + http.NotFound(w, r) + }) + + g := NewGraphImporter(nil).WithHTTPClient(client) + link, err := g.initFolderDeltaLink(context.Background(), "token", "inbox-id") + if err != nil { + t.Fatalf("init delta: %v", err) + } + if !strings.Contains(link, "/mailFolders/inbox-id/messages/delta") { + t.Fatalf("delta link = %q", link) + } +} + +func TestGraphFolderDeltaLinkHelpers(t *testing.T) { + cursor := map[string]any{} + setGraphFolderDeltaLink(cursor, "inbox-id", "https://delta/inbox") + links := graphFolderDeltaLinks(cursor) + if links["inbox-id"] != "https://delta/inbox" { + t.Fatalf("links = %v", links) + } +} diff --git a/internal/migration/import_helpers.go b/internal/migration/import_helpers.go index 127a9fd..63b9446 100644 --- a/internal/migration/import_helpers.go +++ b/internal/migration/import_helpers.go @@ -107,6 +107,50 @@ func setCalendarDeltaLink(cursor map[string]any, calID, link string) { raw[calID] = link } +func graphFolderDeltaLinks(cursor map[string]any) map[string]string { + raw, _ := cursor["folderDeltaLinks"].(map[string]any) + out := make(map[string]string, len(raw)) + for k, v := range raw { + if s, ok := v.(string); ok && s != "" { + out[k] = s + } + } + return out +} + +func setGraphFolderDeltaLink(cursor map[string]any, folderID, link string) { + if folderID == "" || link == "" { + return + } + raw, _ := cursor["folderDeltaLinks"].(map[string]any) + if raw == nil { + raw = map[string]any{} + cursor["folderDeltaLinks"] = raw + } + raw[folderID] = link +} + +func readGraphFolderQueue(cursor map[string]any) []string { + raw, _ := cursor["graphFolderQueue"].([]any) + out := make([]string, 0, len(raw)) + for _, v := range raw { + if s, ok := v.(string); ok && s != "" { + out = append(out, s) + } + } + return out +} + +func writeGraphFolderQueue(cursor map[string]any, ids []string) { + queue := make([]any, 0, len(ids)) + for _, id := range ids { + if id != "" { + queue = append(queue, id) + } + } + cursor["graphFolderQueue"] = queue +} + func migrationContactPath(bookPath, provider, sourceID string) string { uid := sanitizeMigrationUID(provider, sourceID) return bookPath + uid + ".vcf" diff --git a/internal/migration/invite_provision.go b/internal/migration/invite_provision.go new file mode 100644 index 0000000..a061008 --- /dev/null +++ b/internal/migration/invite_provision.go @@ -0,0 +1,97 @@ +package migration + +import ( + "context" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" +) + +// HasPendingMigrationInvite reports whether an unclaimed invite exists for the email. +func HasPendingMigrationInvite(ctx context.Context, db *pgxpool.Pool, email string) (bool, error) { + if db == nil { + return false, nil + } + email = normalizeInviteEmail(email) + if email == "" { + return false, nil + } + var exists bool + err := db.QueryRow(ctx, ` + SELECT EXISTS( + SELECT 1 FROM migration_invites + WHERE status = 'invited' AND lower(email) = lower($1) + ) + `, email).Scan(&exists) + return exists, err +} + +// ProvisionAudit counts user-linked provision artifacts for test verification. +type ProvisionAudit struct { + Users int + Mailboxes int + MailAccounts int + NCPrincipals int +} + +// AuditProvisionByEmail counts rows tied to an email across users, mailboxes, mail accounts, and Nextcloud credentials. +func AuditProvisionByEmail(ctx context.Context, db *pgxpool.Pool, email string) (ProvisionAudit, error) { + var audit ProvisionAudit + if db == nil { + return audit, nil + } + email = strings.ToLower(strings.TrimSpace(email)) + if email == "" { + return audit, nil + } + + if err := db.QueryRow(ctx, ` + SELECT COUNT(*) FROM users WHERE lower(email) = $1 + `, email).Scan(&audit.Users); err != nil { + return audit, err + } + if err := db.QueryRow(ctx, ` + SELECT COUNT(*) + FROM mailboxes mb + JOIN mail_domains md ON md.id = mb.domain_id + WHERE lower(mb.local_part || '@' || md.name) = $1 + `, email).Scan(&audit.Mailboxes); err != nil { + return audit, err + } + if err := db.QueryRow(ctx, ` + SELECT COUNT(*) FROM mail_accounts WHERE lower(email) = $1 + `, email).Scan(&audit.MailAccounts); err != nil { + return audit, err + } + if err := db.QueryRow(ctx, ` + SELECT COUNT(*) FROM nextcloud_dav_credentials WHERE nc_user_id = $1 + `, email).Scan(&audit.NCPrincipals); err != nil { + return audit, err + } + return audit, nil +} + +// LinkHostedMailboxByEmail attaches orphan mailboxes/mail_accounts (e.g. from claim-before-enroll) to a user. +func LinkHostedMailboxByEmail(ctx context.Context, db *pgxpool.Pool, userID, email string) error { + email = strings.ToLower(strings.TrimSpace(email)) + if email == "" { + return nil + } + _, err := db.Exec(ctx, ` + UPDATE mailboxes SET user_id = $1::uuid, updated_at = NOW() + WHERE user_id IS NULL AND lower(local_part || '@' || (SELECT name FROM mail_domains d WHERE d.id = mailboxes.domain_id)) = $2 + `, userID, email) + if err != nil { + return err + } + _, err = db.Exec(ctx, ` + UPDATE mail_accounts ma SET user_id = $1::uuid, updated_at = NOW() + FROM mailboxes mb + JOIN mail_domains md ON md.id = mb.domain_id + WHERE mb.mail_account_id = ma.id + AND ma.user_id IS NULL + AND mb.user_id = $1::uuid + AND lower(mb.local_part || '@' || md.name) = $2 + `, userID, email) + return err +} diff --git a/internal/migration/invite_provision_test.go b/internal/migration/invite_provision_test.go new file mode 100644 index 0000000..cd44a40 --- /dev/null +++ b/internal/migration/invite_provision_test.go @@ -0,0 +1,35 @@ +package migration + +import ( + "testing" +) + +func TestHasPendingMigrationInviteNilDB(t *testing.T) { + ok, err := HasPendingMigrationInvite(t.Context(), nil, "user@example.com") + if err != nil { + t.Fatalf("HasPendingMigrationInvite() error = %v", err) + } + if ok { + t.Fatal("expected false with nil db") + } +} + +func TestAuditProvisionByEmailEmpty(t *testing.T) { + audit, err := AuditProvisionByEmail(t.Context(), nil, "") + if err != nil { + t.Fatalf("AuditProvisionByEmail() error = %v", err) + } + if audit.Users != 0 || audit.Mailboxes != 0 || audit.MailAccounts != 0 || audit.NCPrincipals != 0 { + t.Fatalf("expected zero audit, got %#v", audit) + } +} + +func TestAuditProvisionByEmailNormalizesEmail(t *testing.T) { + audit, err := AuditProvisionByEmail(t.Context(), nil, " User@Example.COM ") + if err != nil { + t.Fatalf("AuditProvisionByEmail() error = %v", err) + } + if audit.Users != 0 { + t.Fatalf("expected zero users with nil db, got %d", audit.Users) + } +} diff --git a/internal/migration/job_audit_export.go b/internal/migration/job_audit_export.go new file mode 100644 index 0000000..b0b4699 --- /dev/null +++ b/internal/migration/job_audit_export.go @@ -0,0 +1,319 @@ +package migration + +import ( + "context" + "encoding/csv" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +const jobAuditExportFlushEvery = 100 + +var jobAuditCSVHeaders = []string{"item_id", "rel_path", "status", "error", "service", "timestamp"} + +var projectAuditCSVHeaders = []string{"job_id", "item_id", "rel_path", "status", "error", "service", "timestamp"} + +// JobAuditExportMeta describes a migration job audit export download. +type JobAuditExportMeta struct { + ContentType string + FileName string +} + +// JobAuditExportRow is one exported audit line. +type JobAuditExportRow struct { + JobID string `json:"job_id,omitempty"` + ItemID string `json:"item_id"` + RelPath string `json:"rel_path,omitempty"` + Status string `json:"status"` + Error string `json:"error,omitempty"` + Service string `json:"service"` + Timestamp string `json:"timestamp"` +} + +// PrepareJobAuditExport validates the job belongs to the project and returns download metadata. +func (s *Service) PrepareJobAuditExport(ctx context.Context, projectID, jobID, format string) (JobAuditExportMeta, error) { + if _, err := s.verifyJobInProject(ctx, projectID, jobID); err != nil { + return JobAuditExportMeta{}, err + } + return jobAuditExportMeta(format, jobID, false), nil +} + +// PrepareProjectAuditExport validates the project and returns download metadata. +func (s *Service) PrepareProjectAuditExport(ctx context.Context, projectID, format string) (JobAuditExportMeta, error) { + if err := s.verifyProjectExists(ctx, projectID); err != nil { + return JobAuditExportMeta{}, err + } + return jobAuditExportMeta(format, projectID, true), nil +} + +// WriteJobAuditExport streams audit rows for one job to w. Call PrepareJobAuditExport first. +func (s *Service) WriteJobAuditExport( + ctx context.Context, + projectID, jobID, statusFilter, format string, + w io.Writer, +) error { + service, err := s.verifyJobInProject(ctx, projectID, jobID) + if err != nil { + return err + } + + statusFilter = normalizeAuditStatusFilter(statusFilter) + + listSQL := ` + SELECT source_id, rel_path, status, reason, imported_at::text + FROM migration_imported_items + WHERE job_id = $1::uuid + ` + listArgs := []any{jobID} + if statusFilter != "" { + listSQL += ` AND status = $2` + listArgs = append(listArgs, statusFilter) + } + listSQL += ` ORDER BY imported_at DESC, source_id ASC` + + rows, err := s.db.Query(ctx, listSQL, listArgs...) + if err != nil { + return err + } + defer rows.Close() + + if err := streamJobAuditRows(format, w, service, jobID, rows); err != nil { + return err + } + return rows.Err() +} + +// WriteProjectAuditExport streams audit rows for all jobs in a project to w. +func (s *Service) WriteProjectAuditExport( + ctx context.Context, + projectID, statusFilter, format string, + w io.Writer, +) error { + if err := s.verifyProjectExists(ctx, projectID); err != nil { + return err + } + + statusFilter = normalizeAuditStatusFilter(statusFilter) + + listSQL := ` + SELECT j.id::text, j.service, i.source_id, i.rel_path, i.status, i.reason, i.imported_at::text + FROM migration_imported_items i + JOIN migration_jobs j ON j.id = i.job_id + WHERE j.project_id = $1::uuid + ` + listArgs := []any{projectID} + if statusFilter != "" { + listSQL += ` AND i.status = $2` + listArgs = append(listArgs, statusFilter) + } + listSQL += ` ORDER BY i.imported_at DESC, i.source_id ASC` + + rows, err := s.db.Query(ctx, listSQL, listArgs...) + if err != nil { + return err + } + defer rows.Close() + + if err := streamProjectAuditRows(format, w, rows); err != nil { + return err + } + return rows.Err() +} + +func (s *Service) verifyProjectExists(ctx context.Context, projectID string) error { + var exists bool + err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM migration_projects WHERE id = $1::uuid)`, projectID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return fmt.Errorf("project not found") + } + return nil +} + +func jobAuditExportMeta(format, id string, projectLevel bool) JobAuditExportMeta { + now := time.Now().UTC().Format("20060102T150405Z") + shortID := id + if len(shortID) > 8 { + shortID = shortID[:8] + } + prefix := "migration-job-audit" + if projectLevel { + prefix = "migration-project-audit" + } + ext := "ndjson" + contentType := "application/x-ndjson; charset=utf-8" + if format == "csv" { + ext = "csv" + contentType = "text/csv; charset=utf-8" + } + return JobAuditExportMeta{ + ContentType: contentType, + FileName: fmt.Sprintf("%s-%s-%s.%s", prefix, shortID, now, ext), + } +} + +type jobAuditRowScanner interface { + Next() bool + Scan(dest ...any) error + Err() error +} + +func streamJobAuditRows(format string, w io.Writer, service, jobID string, rows jobAuditRowScanner) error { + flusher, _ := w.(http.Flusher) + switch format { + case "csv": + cw := csv.NewWriter(w) + if err := cw.Write(jobAuditCSVHeaders); err != nil { + return err + } + count := 0 + for rows.Next() { + var itemID, relPath, status, reason, importedAt string + if err := rows.Scan(&itemID, &relPath, &status, &reason, &importedAt); err != nil { + return err + } + if err := writeJobAuditCSVRow(cw, JobAuditExportRow{ + JobID: jobID, + ItemID: itemID, + RelPath: relPath, + Status: status, + Error: reason, + Service: service, + Timestamp: importedAt, + }); err != nil { + return err + } + count++ + if count%jobAuditExportFlushEvery == 0 { + cw.Flush() + if err := cw.Error(); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + } + } + cw.Flush() + return cw.Error() + default: + enc := json.NewEncoder(w) + count := 0 + for rows.Next() { + var itemID, relPath, status, reason, importedAt string + if err := rows.Scan(&itemID, &relPath, &status, &reason, &importedAt); err != nil { + return err + } + if err := enc.Encode(JobAuditExportRow{ + JobID: jobID, + ItemID: itemID, + RelPath: relPath, + Status: status, + Error: reason, + Service: service, + Timestamp: importedAt, + }); err != nil { + return err + } + count++ + if count%jobAuditExportFlushEvery == 0 && flusher != nil { + flusher.Flush() + } + } + return nil + } +} + +func streamProjectAuditRows(format string, w io.Writer, rows jobAuditRowScanner) error { + flusher, _ := w.(http.Flusher) + switch format { + case "csv": + cw := csv.NewWriter(w) + if err := cw.Write(projectAuditCSVHeaders); err != nil { + return err + } + count := 0 + for rows.Next() { + var jobID, service, itemID, relPath, status, reason, importedAt string + if err := rows.Scan(&jobID, &service, &itemID, &relPath, &status, &reason, &importedAt); err != nil { + return err + } + if err := writeProjectAuditCSVRow(cw, JobAuditExportRow{ + JobID: jobID, + ItemID: itemID, + RelPath: relPath, + Status: status, + Error: reason, + Service: service, + Timestamp: importedAt, + }); err != nil { + return err + } + count++ + if count%jobAuditExportFlushEvery == 0 { + cw.Flush() + if err := cw.Error(); err != nil { + return err + } + if flusher != nil { + flusher.Flush() + } + } + } + cw.Flush() + return cw.Error() + default: + enc := json.NewEncoder(w) + count := 0 + for rows.Next() { + var jobID, service, itemID, relPath, status, reason, importedAt string + if err := rows.Scan(&jobID, &service, &itemID, &relPath, &status, &reason, &importedAt); err != nil { + return err + } + if err := enc.Encode(JobAuditExportRow{ + JobID: jobID, + ItemID: itemID, + RelPath: relPath, + Status: status, + Error: reason, + Service: service, + Timestamp: importedAt, + }); err != nil { + return err + } + count++ + if count%jobAuditExportFlushEvery == 0 && flusher != nil { + flusher.Flush() + } + } + return nil + } +} + +func writeJobAuditCSVRow(w *csv.Writer, row JobAuditExportRow) error { + return w.Write([]string{ + row.ItemID, + row.RelPath, + row.Status, + row.Error, + row.Service, + row.Timestamp, + }) +} + +func writeProjectAuditCSVRow(w *csv.Writer, row JobAuditExportRow) error { + return w.Write([]string{ + row.JobID, + row.ItemID, + row.RelPath, + row.Status, + row.Error, + row.Service, + row.Timestamp, + }) +} diff --git a/internal/migration/job_audit_export_test.go b/internal/migration/job_audit_export_test.go new file mode 100644 index 0000000..d2bf0ba --- /dev/null +++ b/internal/migration/job_audit_export_test.go @@ -0,0 +1,115 @@ +package migration + +import ( + "bytes" + "encoding/csv" + "encoding/json" + "strings" + "testing" +) + +func TestJobAuditExportCSVFormat(t *testing.T) { + var buf bytes.Buffer + cw := csv.NewWriter(&buf) + if err := cw.Write(jobAuditCSVHeaders); err != nil { + t.Fatal(err) + } + row := JobAuditExportRow{ + ItemID: "msg-fail", + RelPath: "Inbox/foo.eml", + Status: ItemStatusFailed, + Error: "upload timeout", + Service: "mail", + Timestamp: "2026-06-13T12:00:00Z", + } + if err := writeJobAuditCSVRow(cw, row); err != nil { + t.Fatal(err) + } + cw.Flush() + if err := cw.Error(); err != nil { + t.Fatal(err) + } + + reader := csv.NewReader(strings.NewReader(buf.String())) + records, err := reader.ReadAll() + if err != nil { + t.Fatal(err) + } + if len(records) != 2 { + t.Fatalf("records = %d, want 2", len(records)) + } + if got := strings.Join(records[0], ","); got != "item_id,rel_path,status,error,service,timestamp" { + t.Fatalf("headers = %q", got) + } + if records[1][0] != "msg-fail" || records[1][2] != ItemStatusFailed || records[1][3] != "upload timeout" { + t.Fatalf("row = %#v", records[1]) + } +} + +func TestProjectAuditExportCSVFormat(t *testing.T) { + var buf bytes.Buffer + cw := csv.NewWriter(&buf) + if err := cw.Write(projectAuditCSVHeaders); err != nil { + t.Fatal(err) + } + if err := writeProjectAuditCSVRow(cw, JobAuditExportRow{ + JobID: "job-1", + ItemID: "file-1", + Status: ItemStatusImported, + Service: "drive", + Timestamp: "2026-06-13T12:00:00Z", + }); err != nil { + t.Fatal(err) + } + cw.Flush() + + reader := csv.NewReader(strings.NewReader(buf.String())) + records, err := reader.ReadAll() + if err != nil { + t.Fatal(err) + } + if records[0][0] != "job_id" || records[1][0] != "job-1" { + t.Fatalf("records = %#v", records) + } +} + +func TestJobAuditExportNDJSONFormat(t *testing.T) { + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + if err := enc.Encode(JobAuditExportRow{ + ItemID: "msg-skip", + Status: ItemStatusSkipped, + Error: "file too large", + Service: "mail", + Timestamp: "2026-06-13T12:00:00Z", + }); err != nil { + t.Fatal(err) + } + + line := strings.TrimSpace(buf.String()) + var decoded JobAuditExportRow + if err := json.Unmarshal([]byte(line), &decoded); err != nil { + t.Fatal(err) + } + if decoded.ItemID != "msg-skip" || decoded.Status != ItemStatusSkipped || decoded.Error != "file too large" { + t.Fatalf("decoded = %#v", decoded) + } +} + +func TestJobAuditExportMeta(t *testing.T) { + csvMeta := jobAuditExportMeta("csv", "01234567-abcd-efgh", false) + if csvMeta.ContentType != "text/csv; charset=utf-8" { + t.Fatalf("csv content type = %q", csvMeta.ContentType) + } + if !strings.HasSuffix(csvMeta.FileName, ".csv") { + t.Fatalf("csv filename = %q", csvMeta.FileName) + } + + ndMeta := jobAuditExportMeta("ndjson", "01234567-abcd-efgh", true) + if ndMeta.ContentType != "application/x-ndjson; charset=utf-8" { + t.Fatalf("ndjson content type = %q", ndMeta.ContentType) + } + if !strings.HasPrefix(ndMeta.FileName, "migration-project-audit-") { + t.Fatalf("project filename = %q", ndMeta.FileName) + } +} diff --git a/internal/migration/project_columns.go b/internal/migration/project_columns.go index 564f8bc..0e59ee4 100644 --- a/internal/migration/project_columns.go +++ b/internal/migration/project_columns.go @@ -7,29 +7,26 @@ import ( "github.com/ultisuite/ulti-backend/internal/mail/hosted" ) -var projectSelectColumns = []string{ - "id::text", - "COALESCE(domain_id::text, '')", - "name", - "source_provider", - "auth_mode", - "status", - "cutover_at::text", - "delta_mode", - "created_at::text", - "NULLIF(microsoft_tenant_id, '')", - "microsoft_admin_consent_at::text", - "COALESCE(NULLIF(microsoft_admin_consent_error, ''), '')", - "cutover_dns_json", -} - func projectSelectSQL(tablePrefix string) string { if tablePrefix != "" && !strings.HasSuffix(tablePrefix, ".") { tablePrefix += "." } - cols := make([]string, len(projectSelectColumns)) - for i, col := range projectSelectColumns { - cols[i] = tablePrefix + col + p := tablePrefix + cols := []string{ + p + "id::text", + "COALESCE(" + p + "domain_id::text, '')", + p + "name", + p + "source_provider", + p + "auth_mode", + p + "status", + p + "cutover_at::text", + p + "delta_mode", + p + "shared_drive_mode", + p + "created_at::text", + "COALESCE(NULLIF(" + p + "microsoft_tenant_id, ''), '')", + p + "microsoft_admin_consent_at::text", + "COALESCE(NULLIF(" + p + "microsoft_admin_consent_error, ''), '')", + p + "cutover_dns_json", } return strings.Join(cols, ", ") } @@ -47,7 +44,7 @@ func (s *projectScanner) targets() []any { return []any{ &s.project.ID, &s.project.DomainID, &s.project.Name, &s.project.SourceProvider, &s.project.AuthMode, &s.project.Status, &s.project.CutoverAt, &s.project.DeltaMode, - &s.project.CreatedAt, &s.project.MicrosoftTenantID, &s.project.MicrosoftAdminConsentAt, + &s.project.SharedDriveMode, &s.project.CreatedAt, &s.project.MicrosoftTenantID, &s.project.MicrosoftAdminConsentAt, &s.project.MicrosoftAdminConsentError, &s.cutoverDNSRaw, } } diff --git a/internal/migration/roster.go b/internal/migration/roster.go new file mode 100644 index 0000000..7f03816 --- /dev/null +++ b/internal/migration/roster.go @@ -0,0 +1,347 @@ +package migration + +import ( + "context" + "encoding/csv" + "errors" + "fmt" + "io" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/ultisuite/ulti-backend/internal/mail/hosted" +) + +const ( + RosterStatusPending = "pending" + RosterStatusInvited = "invited" + RosterStatusClaimed = "claimed" +) + +type RosterEntry struct { + ID string `json:"id"` + ProjectID string `json:"project_id"` + Email string `json:"email"` + DisplayName string `json:"display_name,omitempty"` + AlternateEmails []string `json:"alternate_emails,omitempty"` + Status string `json:"status"` + InviteID string `json:"invite_id,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +type RosterRowInput struct { + Email string + DisplayName string + AlternateEmails []string +} + +type RosterImportRowError struct { + Row int `json:"row"` + Email string `json:"email,omitempty"` + Message string `json:"message"` +} + +type RosterImportResult struct { + Created int `json:"created"` + SkippedDuplicates int `json:"skipped_duplicates"` + Errors []RosterImportRowError `json:"errors,omitempty"` +} + +var rosterHeaderAliases = map[string]string{ + "email": "email", + "e-mail": "email", + "mail": "email", + "address": "email", + "display_name": "display_name", + "displayname": "display_name", + "name": "display_name", + "full_name": "display_name", + "alternate_emails": "alternate_emails", + "alternate_emails_": "alternate_emails", + "alternates": "alternate_emails", + "alias": "alternate_emails", + "aliases": "alternate_emails", +} + +func ParseRosterCSV(r io.Reader) ([]RosterRowInput, error) { + reader := csv.NewReader(r) + reader.TrimLeadingSpace = true + reader.FieldsPerRecord = -1 + + var rows []RosterRowInput + lineNum := 0 + emailCol := 0 + displayCol := -1 + alternateCol := -1 + headerResolved := false + + for { + record, err := reader.Read() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("csv row %d: %w", lineNum+1, err) + } + lineNum++ + + if len(record) == 0 { + continue + } + for len(record) > 0 && strings.TrimSpace(record[len(record)-1]) == "" { + record = record[:len(record)-1] + } + if len(record) == 0 { + continue + } + + if !headerResolved && looksLikeRosterHeader(record) { + for i, col := range record { + key := normalizeRosterHeader(col) + switch key { + case "email": + emailCol = i + case "display_name": + displayCol = i + case "alternate_emails": + alternateCol = i + } + } + headerResolved = true + continue + } + headerResolved = true + + if emailCol >= len(record) { + continue + } + email := normalizeInviteEmail(record[emailCol]) + if email == "" { + continue + } + if !isEmailAddress(email) { + return nil, fmt.Errorf("csv row %d: invalid email %q", lineNum, record[emailCol]) + } + + row := RosterRowInput{Email: email} + if displayCol >= 0 && displayCol < len(record) { + row.DisplayName = strings.TrimSpace(record[displayCol]) + } + if alternateCol >= 0 && alternateCol < len(record) { + row.AlternateEmails = parseAlternateEmailsField(record[alternateCol]) + } else if len(record) > 1 && displayCol < 0 && alternateCol < 0 { + // email,display_name or email,alternates without header + if len(record) > 1 { + second := strings.TrimSpace(record[1]) + if strings.Contains(second, "@") { + row.AlternateEmails = parseAlternateEmailsField(second) + } else { + row.DisplayName = second + } + } + if len(record) > 2 { + row.AlternateEmails = parseAlternateEmailsField(record[2]) + } + } + row.AlternateEmails = normalizeAlternateEmails(email, row.AlternateEmails) + rows = append(rows, row) + } + + return rows, nil +} + +func looksLikeRosterHeader(record []string) bool { + if len(record) == 0 { + return false + } + first := normalizeRosterHeader(record[0]) + if first == "email" { + return true + } + for _, col := range record { + if normalizeRosterHeader(col) == "email" { + return true + } + } + return false +} + +func normalizeRosterHeader(col string) string { + key := strings.ToLower(strings.TrimSpace(col)) + key = strings.ReplaceAll(key, " ", "_") + key = strings.ReplaceAll(key, "-", "_") + if mapped, ok := rosterHeaderAliases[key]; ok { + return mapped + } + return key +} + +func parseAlternateEmailsField(raw string) []string { + raw = strings.TrimSpace(raw) + if raw == "" { + return nil + } + raw = strings.Trim(raw, "\"'") + parts := strings.FieldsFunc(raw, func(r rune) bool { + return r == ';' || r == '|' || r == ',' + }) + var out []string + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out +} + +func (s *Service) ListRoster(ctx context.Context, projectID string) ([]RosterEntry, error) { + rows, err := s.db.Query(ctx, ` + SELECT id::text, project_id::text, email, display_name, alternate_emails, status, + COALESCE(invite_id::text, ''), created_at::text, updated_at::text + FROM migration_roster + WHERE project_id = $1::uuid + ORDER BY email + `, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + var out []RosterEntry + for rows.Next() { + var row RosterEntry + if err := rows.Scan( + &row.ID, &row.ProjectID, &row.Email, &row.DisplayName, &row.AlternateEmails, + &row.Status, &row.InviteID, &row.CreatedAt, &row.UpdatedAt, + ); err != nil { + return nil, err + } + out = append(out, row) + } + return out, rows.Err() +} + +func (s *Service) ImportRoster(ctx context.Context, projectID string, inputs []RosterRowInput) (RosterImportResult, error) { + result := RosterImportResult{} + for i, input := range inputs { + rowNum := i + 1 + email := normalizeInviteEmail(input.Email) + if email == "" || !isEmailAddress(email) { + result.Errors = append(result.Errors, RosterImportRowError{ + Row: rowNum, Email: input.Email, Message: "invalid email", + }) + continue + } + alternates := normalizeAlternateEmails(email, input.AlternateEmails) + displayName := strings.TrimSpace(input.DisplayName) + + existingStatus, err := s.rosterStatusByEmail(ctx, projectID, email) + if err != nil { + return result, err + } + if existingStatus != "" { + result.SkippedDuplicates++ + continue + } + + tx, err := s.db.Begin(ctx) + if err != nil { + return result, err + } + + var rosterID string + err = tx.QueryRow(ctx, ` + INSERT INTO migration_roster (project_id, email, display_name, alternate_emails, status) + VALUES ($1::uuid, $2, $3, $4, $5) + RETURNING id::text + `, projectID, email, displayName, alternates, RosterStatusPending).Scan(&rosterID) + if err != nil { + _ = tx.Rollback(ctx) + if isUniqueViolation(err) { + result.SkippedDuplicates++ + continue + } + result.Errors = append(result.Errors, RosterImportRowError{ + Row: rowNum, Email: email, Message: err.Error(), + }) + continue + } + + token, err := hosted.NewInviteToken() + if err != nil { + _ = tx.Rollback(ctx) + return result, err + } + + var inviteID string + err = tx.QueryRow(ctx, ` + INSERT INTO migration_invites (project_id, email, token, alternate_emails) + VALUES ($1::uuid, $2, $3, $4) + RETURNING id::text + `, projectID, email, token, alternates).Scan(&inviteID) + if err != nil { + _ = tx.Rollback(ctx) + if isUniqueViolation(err) { + result.SkippedDuplicates++ + continue + } + result.Errors = append(result.Errors, RosterImportRowError{ + Row: rowNum, Email: email, Message: err.Error(), + }) + continue + } + + _, err = tx.Exec(ctx, ` + UPDATE migration_roster + SET status = $1, invite_id = $2::uuid, updated_at = NOW() + WHERE id = $3::uuid + `, RosterStatusInvited, inviteID, rosterID) + if err != nil { + _ = tx.Rollback(ctx) + result.Errors = append(result.Errors, RosterImportRowError{ + Row: rowNum, Email: email, Message: err.Error(), + }) + continue + } + + if err := tx.Commit(ctx); err != nil { + result.Errors = append(result.Errors, RosterImportRowError{ + Row: rowNum, Email: email, Message: err.Error(), + }) + continue + } + result.Created++ + } + return result, nil +} + +func (s *Service) rosterStatusByEmail(ctx context.Context, projectID, email string) (string, error) { + var status string + err := s.db.QueryRow(ctx, ` + SELECT status FROM migration_roster + WHERE project_id = $1::uuid AND email = $2 + `, projectID, email).Scan(&status) + if errors.Is(err, pgx.ErrNoRows) { + return "", nil + } + return status, err +} + +func (s *Service) markRosterClaimed(ctx context.Context, tx pgx.Tx, projectID, inviteID, email string) error { + _, err := tx.Exec(ctx, ` + UPDATE migration_roster + SET status = $1, updated_at = NOW() + WHERE project_id = $2::uuid + AND (invite_id = $3::uuid OR (invite_id IS NULL AND email = $4)) + `, RosterStatusClaimed, projectID, inviteID, email) + return err +} + +func isUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + return errors.As(err, &pgErr) && pgErr.Code == "23505" +} diff --git a/internal/migration/roster_test.go b/internal/migration/roster_test.go new file mode 100644 index 0000000..37e94a5 --- /dev/null +++ b/internal/migration/roster_test.go @@ -0,0 +1,69 @@ +package migration + +import ( + "strings" + "testing" +) + +func TestParseRosterCSVWithHeader(t *testing.T) { + csv := `email,display_name,alternate_emails +alice@corp.com,Alice Corp,alice.old@corp.com;bob.alias@corp.com +bob@corp.com,Bob, +` + rows, err := ParseRosterCSV(strings.NewReader(csv)) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) + } + if rows[0].Email != "alice@corp.com" || rows[0].DisplayName != "Alice Corp" { + t.Fatalf("row0: %#v", rows[0]) + } + if len(rows[0].AlternateEmails) != 2 { + t.Fatalf("row0 alternates: %#v", rows[0].AlternateEmails) + } + if rows[1].Email != "bob@corp.com" || rows[1].DisplayName != "Bob" { + t.Fatalf("row1: %#v", rows[1]) + } +} + +func TestParseRosterCSVWithoutHeader(t *testing.T) { + csv := "alice@corp.com,Alice\nbob@corp.com\n" + rows, err := ParseRosterCSV(strings.NewReader(csv)) + if err != nil { + t.Fatalf("parse: %v", err) + } + if len(rows) != 2 { + t.Fatalf("expected 2 rows, got %d", len(rows)) + } + if rows[0].DisplayName != "Alice" { + t.Fatalf("display name: %#v", rows[0]) + } + if rows[1].DisplayName != "" { + t.Fatalf("row1 display: %#v", rows[1]) + } +} + +func TestParseRosterCSVInvalidEmail(t *testing.T) { + _, err := ParseRosterCSV(strings.NewReader("not-an-email,Someone\n")) + if err == nil { + t.Fatal("expected invalid email error") + } +} + +func TestParseAlternateEmailsField(t *testing.T) { + got := parseAlternateEmailsField(`a@x.com; b@x.com | c@x.com`) + if len(got) != 3 { + t.Fatalf("got %#v", got) + } +} + +func TestLooksLikeRosterHeader(t *testing.T) { + if !looksLikeRosterHeader([]string{"Email", "Name"}) { + t.Fatal("expected header detection") + } + if looksLikeRosterHeader([]string{"alice@corp.com", "Alice"}) { + t.Fatal("data row should not be header") + } +} diff --git a/internal/migration/service.go b/internal/migration/service.go index 9238da6..f8b0f67 100644 --- a/internal/migration/service.go +++ b/internal/migration/service.go @@ -23,6 +23,7 @@ var ( ErrInviteNotFound = errors.New("migration invite not found") ErrInviteClaimed = errors.New("migration invite already claimed") ErrEmailMismatch = errors.New("email does not match invite") + ErrTenantMismatch = errors.New("microsoft tenant does not match project") ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active") ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain") ) @@ -53,6 +54,7 @@ type Project struct { Status string `json:"status"` CutoverAt *string `json:"cutover_at,omitempty"` DeltaMode bool `json:"delta_mode"` + SharedDriveMode string `json:"shared_drive_mode"` CreatedAt string `json:"created_at"` MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"` MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"` @@ -136,6 +138,9 @@ func (s *Service) CreateInvite(ctx context.Context, projectID, email string, alt return Invite{}, fmt.Errorf("email required") } alternates := normalizeAlternateEmails(email, alternateEmails) + if alternates == nil { + alternates = []string{} + } token, err := hosted.NewInviteToken() if err != nil { return Invite{}, err @@ -222,6 +227,9 @@ func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identit hostedDomain = &domain projectDomain = domain.Name } + if err := validateMicrosoftTenantClaim(proj, identity.TenantID); err != nil { + return UserStatus{}, err + } if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) { return UserStatus{}, ErrEmailMismatch } @@ -242,7 +250,12 @@ func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identit return UserStatus{}, err } + if err := s.markRosterClaimed(ctx, tx, proj.ID, inv.ID, mailboxEmail); err != nil { + return UserStatus{}, err + } + if s.hosted != nil { + _ = LinkHostedMailboxByEmail(ctx, s.db, userID, mailboxEmail) provision := hosted.ProvisionMailboxInput{ UserID: userID, Email: mailboxEmail, @@ -260,14 +273,12 @@ func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identit } provision.DomainID = proj.DomainID } - _, err = s.hosted.ProvisionMailbox(ctx, provision) + _, err = s.hosted.EnsureMailboxProvisioned(ctx, provision) if err != nil { if errors.Is(err, hosted.ErrDomainNotActive) { return UserStatus{}, ErrMigrationDomainNotActive } - if !errors.Is(err, hosted.ErrAddressTaken) { - return UserStatus{}, err - } + return UserStatus{}, err } } @@ -485,6 +496,10 @@ func (s *Service) ActivateProject(ctx context.Context, projectID string) (Projec } func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) { + externalID = strings.TrimSpace(externalID) + if externalID == "" { + return "", pgx.ErrNoRows + } var userID string err := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID) return userID, err diff --git a/internal/migration/worker.go b/internal/migration/worker.go index 301b1f1..96a73da 100644 --- a/internal/migration/worker.go +++ b/internal/migration/worker.go @@ -165,7 +165,20 @@ func (w *Worker) processJob(ctx context.Context, job Job) (string, error) { procErr = NewCalendarImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update) case "drive": selfManaged = true - procErr = NewDriveImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update) + sharedMode, err := w.projectSharedDriveMode(ctx, job.ProjectID) + if err != nil { + outcome = "failed" + return outcome, err + } + dedup, err := LoadSharedDriveItemStore(ctx, w.db, job.ProjectID) + if err != nil { + outcome = "failed" + return outcome, err + } + procErr = NewDriveImporter(w.db, w.nc). + WithUserPrincipal(graphUserUPN). + WithProject(job.ProjectID, sharedMode, dedup). + ImportBatch(ctx, &job, accessToken, provider, delta, update) default: procErr = fmt.Errorf("unknown service %q", job.Service) } @@ -253,6 +266,15 @@ func (w *Worker) projectMicrosoftTenant(ctx context.Context, projectID string) ( return tenantID, nil } +func (w *Worker) projectSharedDriveMode(ctx context.Context, projectID string) (string, error) { + var mode string + err := w.db.QueryRow(ctx, ` + SELECT COALESCE(shared_drive_mode, 'auto') + FROM migration_projects WHERE id = $1::uuid + `, projectID).Scan(&mode) + return NormalizeSharedDriveMode(mode), err +} + func (w *Worker) inviteEmail(ctx context.Context, projectID, userID string) (string, error) { var email string err := w.db.QueryRow(ctx, ` diff --git a/internal/nextcloud/upload_streaming.go b/internal/nextcloud/upload_streaming.go new file mode 100644 index 0000000..0798e67 --- /dev/null +++ b/internal/nextcloud/upload_streaming.go @@ -0,0 +1,65 @@ +package nextcloud + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/hex" + "fmt" + "io" +) + +const defaultUploadChunkSize = 10 * 1024 * 1024 + +// UploadStreaming uploads large files via Nextcloud chunked DAV assembly. +func (c *Client) UploadStreaming(ctx context.Context, userID, targetPath string, content io.Reader, contentType string, totalSize int64) error { + uploadID, err := newChunkUploadID() + if err != nil { + return err + } + buf := make([]byte, defaultUploadChunkSize) + var uploaded int64 + chunkIndex := 0 + for { + n, readErr := io.ReadFull(content, buf) + if n == 0 && readErr == io.EOF { + break + } + if n > 0 { + chunkIndex++ + if err := c.UploadChunk(ctx, userID, uploadID, chunkIndexName(chunkIndex), bytes.NewReader(buf[:n]), contentType); err != nil { + _ = c.AbortChunkUpload(ctx, userID, uploadID) + return err + } + uploaded += int64(n) + } + if readErr == io.EOF || readErr == io.ErrUnexpectedEOF { + break + } + if readErr != nil { + _ = c.AbortChunkUpload(ctx, userID, uploadID) + return readErr + } + } + finalSize := uploaded + if totalSize > 0 { + finalSize = totalSize + } + if err := c.AssembleChunks(ctx, userID, uploadID, targetPath, finalSize); err != nil { + _ = c.AbortChunkUpload(ctx, userID, uploadID) + return err + } + return nil +} + +func newChunkUploadID() (string, error) { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("chunk upload id: %w", err) + } + return hex.EncodeToString(b), nil +} + +func chunkIndexName(index int) string { + return fmt.Sprintf("%d", index) +} diff --git a/internal/provision/handler.go b/internal/provision/handler.go index f9e4a5c..d1cc102 100644 --- a/internal/provision/handler.go +++ b/internal/provision/handler.go @@ -6,11 +6,9 @@ import ( "net/http" "strings" - "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/api/apiresponse" - "github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/mail/hosted" "github.com/ultisuite/ulti-backend/internal/migration" "github.com/ultisuite/ulti-backend/internal/nextcloud" @@ -75,56 +73,60 @@ func (h *Handler) ProvisionUser(w http.ResponseWriter, r *http.Request) { } ctx := r.Context() - var userID string externalID := strings.TrimSpace(req.ExternalID) - if externalID != "" { - err := h.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID) - if errors.Is(err, pgx.ErrNoRows) { - userID, err = users.EnsureUser(ctx, h.db, &auth.Claims{ - Sub: externalID, - Email: email, - Name: req.Name, - }) - } - if err != nil { - h.logger.Error("ensure user", "error", err) - apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to provision user", nil) - return - } - } - - result, err := h.hosted.ProvisionMailbox(ctx, hosted.ProvisionMailboxInput{ - UserID: userID, - Email: email, - DisplayName: req.Name, - Password: req.Password, - }) + userID, err := users.ResolveProvisionUser(ctx, h.db, externalID, email, req.Name) if err != nil { - if errors.Is(err, hosted.ErrAddressTaken) { - apiresponse.WriteError(w, r, http.StatusConflict, "address_taken", err.Error(), nil) - return - } - h.logger.Error("provision mailbox", "error", err, "email", email) - apiresponse.WriteError(w, r, http.StatusConflict, "provision_failed", err.Error(), nil) + h.logger.Error("resolve user", "error", err, "email", email) + apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to provision user", nil) return } - if userID != "" { - _ = migration.LinkHostedMailboxByEmail(ctx, h.db, userID, email) + skipMailbox, err := migration.HasPendingMigrationInvite(ctx, h.db, email) + if err != nil { + h.logger.Error("check migration invite", "error", err, "email", email) + apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to check migration invite", nil) + return } + var result hosted.ProvisionMailboxResult + if !skipMailbox { + result, err = h.hosted.EnsureMailboxProvisioned(ctx, hosted.ProvisionMailboxInput{ + UserID: userID, + Email: email, + DisplayName: req.Name, + Password: req.Password, + }) + if err != nil { + if errors.Is(err, hosted.ErrAddressTaken) { + apiresponse.WriteError(w, r, http.StatusConflict, "address_taken", err.Error(), nil) + return + } + h.logger.Error("provision mailbox", "error", err, "email", email) + apiresponse.WriteError(w, r, http.StatusConflict, "provision_failed", err.Error(), nil) + return + } + } + + _ = migration.LinkHostedMailboxByEmail(ctx, h.db, userID, email) + if h.nc != nil && userID != "" && externalID != "" { if _, err := h.nc.EnsurePrincipal(ctx, email, externalID, req.Name); err != nil { h.logger.Warn("nextcloud provision", "error", err) } } - apiresponse.WriteJSON(w, http.StatusOK, map[string]any{ - "user_id": userID, - "email": email, - "mailbox_id": result.Mailbox.ID, - "mail_account_id": result.MailAccountID, - }) + resp := map[string]any{ + "user_id": userID, + "email": email, + } + if !skipMailbox { + resp["mailbox_id"] = result.Mailbox.ID + resp["mail_account_id"] = result.MailAccountID + } + if skipMailbox { + resp["mailbox_deferred"] = true + } + apiresponse.WriteJSON(w, http.StatusOK, resp) } // CheckAddress validates local part availability (Authentik expression policy or public API). diff --git a/internal/provision/handler_test.go b/internal/provision/handler_test.go new file mode 100644 index 0000000..95165ae --- /dev/null +++ b/internal/provision/handler_test.go @@ -0,0 +1,52 @@ +package provision + +import ( + "bytes" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDecodeProvisionBodyAuthentikPayload(t *testing.T) { + body := []byte(`{ + "email": "alice@ultisuite.fr", + "password": "secret", + "name": "Alice", + "external_id": "uuid-123", + "user": {"email": "ignored@example.com", "uuid": "ignored"} + }`) + req := httptest.NewRequest(http.MethodPost, "/internal/provision/user", bytes.NewReader(body)) + got, err := decodeProvisionBody(req) + if err != nil { + t.Fatalf("decodeProvisionBody() error = %v", err) + } + if got.Email != "alice@ultisuite.fr" || got.ExternalID != "uuid-123" || got.Name != "Alice" { + t.Fatalf("decodeProvisionBody() = %#v", got) + } +} + +func TestAuthorizeProvisionSecret(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/internal/provision/user", nil) + req.Header.Set("X-Provision-Secret", "topsecret") + if !authorizeProvision(req, "topsecret") { + t.Fatal("expected header secret to authorize") + } + req = httptest.NewRequest(http.MethodPost, "/internal/provision/user?secret=topsecret", nil) + if !authorizeProvision(req, "topsecret") { + t.Fatal("expected query secret to authorize") + } + if authorizeProvision(req, "wrong") { + t.Fatal("expected wrong secret to fail") + } +} + +func TestNormalizeProvisionRequestUsesUsername(t *testing.T) { + req := provisionUserRequest{Username: "bob@ultisuite.fr"} + normalizeProvisionRequest(&req) + if req.Email != "bob@ultisuite.fr" { + t.Fatalf("email = %q, want bob@ultisuite.fr", req.Email) + } + if req.Name != "bob@ultisuite.fr" { + t.Fatalf("name = %q, want fallback to email", req.Name) + } +} diff --git a/internal/users/provision.go b/internal/users/provision.go index 5ebef39..cf69868 100644 --- a/internal/users/provision.go +++ b/internal/users/provision.go @@ -2,13 +2,14 @@ package users import ( "context" + "errors" "fmt" "strings" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/auth" - "github.com/ultisuite/ulti-backend/internal/migration" ) // ProvisionEmail returns the email stored for a newly provisioned user. @@ -59,6 +60,66 @@ func EnsureUser(ctx context.Context, db *pgxpool.Pool, claims *auth.Claims) (str return "", fmt.Errorf("bootstrap platform admin: %w", err) } } - _ = migration.LinkHostedMailboxByEmail(ctx, db, userID, email) return userID, nil } + +// LookupUserID returns the internal user UUID for an OIDC subject. +func LookupUserID(ctx context.Context, db *pgxpool.Pool, externalID string) (string, error) { + if db == nil { + return "", fmt.Errorf("database not configured") + } + externalID = strings.TrimSpace(externalID) + if externalID == "" { + return "", pgx.ErrNoRows + } + var userID string + err := db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID) + return userID, err +} + +// LookupUserIDByEmail returns the internal user UUID for a stored email address. +func LookupUserIDByEmail(ctx context.Context, db *pgxpool.Pool, email string) (string, error) { + if db == nil { + return "", fmt.Errorf("database not configured") + } + email = strings.ToLower(strings.TrimSpace(email)) + if email == "" { + return "", pgx.ErrNoRows + } + var userID string + err := db.QueryRow(ctx, `SELECT id::text FROM users WHERE lower(email) = $1`, email).Scan(&userID) + return userID, err +} + +// ResolveProvisionUser finds an existing user by external_id or email, or creates one from Authentik enrollment data. +func ResolveProvisionUser(ctx context.Context, db *pgxpool.Pool, externalID, email, name string) (string, error) { + externalID = strings.TrimSpace(externalID) + email = strings.ToLower(strings.TrimSpace(email)) + + if externalID != "" { + userID, err := LookupUserID(ctx, db, externalID) + if err == nil { + return userID, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return "", err + } + } + if email != "" { + userID, err := LookupUserIDByEmail(ctx, db, email) + if err == nil { + return userID, nil + } + if !errors.Is(err, pgx.ErrNoRows) { + return "", err + } + } + if externalID == "" { + return "", fmt.Errorf("cannot provision user without external_id or existing email") + } + return EnsureUser(ctx, db, &auth.Claims{ + Sub: externalID, + Email: email, + Name: name, + }) +} diff --git a/internal/users/provision_resolve_test.go b/internal/users/provision_resolve_test.go new file mode 100644 index 0000000..b980d91 --- /dev/null +++ b/internal/users/provision_resolve_test.go @@ -0,0 +1,26 @@ +package users + +import ( + "testing" +) + +func TestLookupUserIDEmptyExternalID(t *testing.T) { + _, err := LookupUserID(t.Context(), nil, "") + if err == nil { + t.Fatal("expected error for empty external id") + } +} + +func TestLookupUserIDByEmailEmpty(t *testing.T) { + _, err := LookupUserIDByEmail(t.Context(), nil, "") + if err == nil { + t.Fatal("expected error for empty email") + } +} + +func TestResolveProvisionUserRequiresIdentity(t *testing.T) { + _, err := ResolveProvisionUser(t.Context(), nil, "", "user@example.com", "User") + if err == nil { + t.Fatal("expected error without external_id or existing user") + } +} diff --git a/migrations/000048_migration_roster.down.sql b/migrations/000048_migration_roster.down.sql new file mode 100644 index 0000000..e621b8e --- /dev/null +++ b/migrations/000048_migration_roster.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS migration_roster; diff --git a/migrations/000048_migration_roster.up.sql b/migrations/000048_migration_roster.up.sql new file mode 100644 index 0000000..6794f8b --- /dev/null +++ b/migrations/000048_migration_roster.up.sql @@ -0,0 +1,17 @@ +CREATE TABLE migration_roster ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE, + email TEXT NOT NULL, + display_name TEXT NOT NULL DEFAULT '', + alternate_emails TEXT[] NOT NULL DEFAULT '{}', + status TEXT NOT NULL DEFAULT 'pending', + invite_id UUID REFERENCES migration_invites(id) ON DELETE SET NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE (project_id, email), + CONSTRAINT migration_roster_status_check CHECK (status IN ('pending', 'invited', 'claimed')) +); + +CREATE INDEX idx_migration_roster_project ON migration_roster(project_id); +CREATE INDEX idx_migration_roster_status ON migration_roster(status); +CREATE INDEX idx_migration_roster_invite ON migration_roster(invite_id); diff --git a/migrations/000049_migration_shared_drives.down.sql b/migrations/000049_migration_shared_drives.down.sql new file mode 100644 index 0000000..1f1ba4e --- /dev/null +++ b/migrations/000049_migration_shared_drives.down.sql @@ -0,0 +1,5 @@ +DROP TABLE IF EXISTS migration_shared_drive_items; +DROP TABLE IF EXISTS migration_shared_drives; + +ALTER TABLE migration_projects DROP CONSTRAINT IF EXISTS migration_projects_shared_drive_mode_check; +ALTER TABLE migration_projects DROP COLUMN IF EXISTS shared_drive_mode; diff --git a/migrations/000049_migration_shared_drives.up.sql b/migrations/000049_migration_shared_drives.up.sql new file mode 100644 index 0000000..77d3e5c --- /dev/null +++ b/migrations/000049_migration_shared_drives.up.sql @@ -0,0 +1,39 @@ +-- Shared drive import mode and project-level dedup for Google Workspace migrations. +ALTER TABLE migration_projects + ADD COLUMN shared_drive_mode TEXT NOT NULL DEFAULT 'auto'; + +ALTER TABLE migration_projects + ADD CONSTRAINT migration_projects_shared_drive_mode_check + CHECK (shared_drive_mode IN ('auto', 'manual')); + +CREATE TABLE migration_shared_drives ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE, + drive_id TEXT NOT NULL, + name TEXT NOT NULL DEFAULT '', + status TEXT NOT NULL DEFAULT 'pending', + discovered_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(project_id, drive_id), + CONSTRAINT migration_shared_drives_status_check + CHECK (status IN ('pending', 'approved', 'rejected')) +); + +CREATE INDEX idx_migration_shared_drives_project_status + ON migration_shared_drives(project_id, status); + +-- Project-level dedup: one import per shared-drive file across all users in a project. +CREATE TABLE migration_shared_drive_items ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE, + drive_id TEXT NOT NULL, + source_id TEXT NOT NULL, + rel_path TEXT NOT NULL DEFAULT '', + imported_by_job_id UUID REFERENCES migration_jobs(id) ON DELETE SET NULL, + imported_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(project_id, drive_id, source_id) +); + +CREATE INDEX idx_migration_shared_drive_items_project + ON migration_shared_drive_items(project_id, drive_id);