feat(migration): enhance migration API with roster and audit export features
Some checks are pending
CI / Go tests (push) Waiting to run
CI / Integration tests (push) Waiting to run
CI / DB migrations (push) Waiting to run

- Added endpoints for listing and importing migration rosters.
- Introduced audit export functionality for migration jobs in CSV and NDJSON formats.
- Implemented tenant mismatch validation for Microsoft migration claims.
- Enhanced error handling for email claiming and migration processes.
- Added integration tests for roster import and claim workflows.
This commit is contained in:
R3D347HR4Y 2026-06-13 13:11:30 +02:00
parent 7143a36c19
commit 1ffd0817d8
39 changed files with 3335 additions and 175 deletions

View File

@ -3,6 +3,7 @@ package admin
import ( import (
"encoding/csv" "encoding/csv"
"errors" "errors"
"fmt"
"io" "io"
"net/http" "net/http"
"strings" "strings"
@ -41,12 +42,20 @@ func (h *Handler) registerMailAdminRoutes(r chi.Router, read, write func(http.Ha
r.With(write).Post("/projects/{projectID}/cutover", h.StartMigrationCutover) r.With(write).Post("/projects/{projectID}/cutover", h.StartMigrationCutover)
r.With(write).Post("/projects/{projectID}/invites", h.CreateMigrationInvite) r.With(write).Post("/projects/{projectID}/invites", h.CreateMigrationInvite)
r.With(write).Post("/projects/{projectID}/invites/import", h.ImportMigrationInvites) r.With(write).Post("/projects/{projectID}/invites/import", h.ImportMigrationInvites)
r.With(read).Get("/projects/{projectID}/roster", h.ListMigrationRoster)
r.With(write).Post("/projects/{projectID}/roster", h.ImportMigrationRoster)
r.With(read).Get("/projects/{projectID}/jobs", h.ListMigrationProjectJobs) r.With(read).Get("/projects/{projectID}/jobs", h.ListMigrationProjectJobs)
r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit", h.ListMigrationJobAudit) r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit", h.ListMigrationJobAudit)
r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit/summary", h.MigrationJobAuditSummary) r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit/summary", h.MigrationJobAuditSummary)
r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit/export", h.ExportMigrationJobAudit)
r.With(read).Get("/projects/{projectID}/audit/export", h.ExportMigrationProjectAudit)
r.With(write).Post("/projects/{projectID}/jobs/retry-failed", h.RetryMigrationFailedJobs) r.With(write).Post("/projects/{projectID}/jobs/retry-failed", h.RetryMigrationFailedJobs)
r.With(write).Post("/projects/{projectID}/jobs/{jobID}/retry", h.RetryMigrationJob) r.With(write).Post("/projects/{projectID}/jobs/{jobID}/retry", h.RetryMigrationJob)
r.With(write).Post("/projects/{projectID}/jobs/{jobID}/reset-cursor", h.ResetMigrationJobCursor) r.With(write).Post("/projects/{projectID}/jobs/{jobID}/reset-cursor", h.ResetMigrationJobCursor)
r.With(write).Patch("/projects/{projectID}/shared-drive-mode", h.UpdateMigrationSharedDriveMode)
r.With(read).Get("/projects/{projectID}/shared-drives", h.ListMigrationSharedDrives)
r.With(write).Post("/projects/{projectID}/shared-drives/{driveID}/approve", h.ApproveMigrationSharedDrive)
r.With(write).Post("/projects/{projectID}/shared-drives/{driveID}/reject", h.RejectMigrationSharedDrive)
r.With(read).Get("/microsoft/admin-consent-url", h.MicrosoftMigrationAdminConsentURL) r.With(read).Get("/microsoft/admin-consent-url", h.MicrosoftMigrationAdminConsentURL)
r.With(read).Get("/microsoft/admin-consents", h.ListMicrosoftAdminConsents) r.With(read).Get("/microsoft/admin-consents", h.ListMicrosoftAdminConsents)
}) })
@ -249,6 +258,81 @@ func (h *Handler) ImportMigrationInvites(w http.ResponseWriter, r *http.Request)
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"imported": count}) apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"imported": count})
} }
func (h *Handler) ListMigrationRoster(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
rows, err := h.svc.migration.ListRoster(r.Context(), chi.URLParam(r, "projectID"))
if err != nil {
apivalidate.WriteInternal(w, r)
return
}
if rows == nil {
rows = []migr.RosterEntry{}
}
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"roster": rows})
}
func (h *Handler) ImportMigrationRoster(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
projectID := chi.URLParam(r, "projectID")
var inputs []migr.RosterRowInput
contentType := r.Header.Get("Content-Type")
if strings.Contains(contentType, "multipart/form-data") {
file, _, err := r.FormFile("file")
if err == nil {
defer file.Close()
parsed, err := migr.ParseRosterCSV(file)
if err != nil {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
Field: "file", Message: err.Error(),
}))
return
}
inputs = parsed
}
}
if len(inputs) == 0 {
var body struct {
CSV string `json:"csv"`
Rows []migr.RosterRowInput `json:"rows"`
}
if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &body); err != nil {
return
}
if len(body.Rows) > 0 {
inputs = body.Rows
} else if strings.TrimSpace(body.CSV) != "" {
parsed, err := migr.ParseRosterCSV(strings.NewReader(body.CSV))
if err != nil {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
Field: "csv", Message: err.Error(),
}))
return
}
inputs = parsed
}
}
if len(inputs) == 0 {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
Field: "csv", Message: "roster csv or rows required",
}))
return
}
result, err := h.svc.migration.ImportRoster(r.Context(), projectID, inputs)
if err != nil {
apivalidate.WriteInternal(w, r)
return
}
apiresponse.WriteJSON(w, http.StatusOK, result)
}
func (h *Handler) MicrosoftMigrationAdminConsentURL(w http.ResponseWriter, r *http.Request) { func (h *Handler) MicrosoftMigrationAdminConsentURL(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil { if h.svc.migration == nil {
apivalidate.WriteInternal(w, r) apivalidate.WriteInternal(w, r)
@ -388,3 +472,137 @@ func (h *Handler) MigrationJobAuditSummary(w http.ResponseWriter, r *http.Reques
} }
apiresponse.WriteJSON(w, http.StatusOK, summary) apiresponse.WriteJSON(w, http.StatusOK, summary)
} }
func (h *Handler) ExportMigrationJobAudit(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
format, verr := validateExportFormat(r.URL.Query().Get("format"))
if verr != nil {
apivalidate.WriteValidationError(w, r, verr)
return
}
projectID := chi.URLParam(r, "projectID")
jobID := chi.URLParam(r, "jobID")
meta, err := h.svc.migration.PrepareJobAuditExport(r.Context(), projectID, jobID, format)
if err != nil {
if strings.Contains(err.Error(), "not found") {
apiresponse.WriteError(w, r, http.StatusNotFound, "migration_job_not_found", err.Error(), nil)
return
}
apivalidate.WriteInternal(w, r)
return
}
w.Header().Set("Content-Type", meta.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, meta.FileName))
w.WriteHeader(http.StatusOK)
if err := h.svc.migration.WriteJobAuditExport(
r.Context(),
projectID,
jobID,
r.URL.Query().Get("status"),
format,
w,
); err != nil {
h.logger.Error("export migration job audit", "error", err)
}
}
func (h *Handler) ExportMigrationProjectAudit(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
format, verr := validateExportFormat(r.URL.Query().Get("format"))
if verr != nil {
apivalidate.WriteValidationError(w, r, verr)
return
}
projectID := chi.URLParam(r, "projectID")
meta, err := h.svc.migration.PrepareProjectAuditExport(r.Context(), projectID, format)
if err != nil {
if strings.Contains(err.Error(), "not found") {
apiresponse.WriteError(w, r, http.StatusNotFound, "migration_project_not_found", err.Error(), nil)
return
}
apivalidate.WriteInternal(w, r)
return
}
w.Header().Set("Content-Type", meta.ContentType)
w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"`, meta.FileName))
w.WriteHeader(http.StatusOK)
if err := h.svc.migration.WriteProjectAuditExport(
r.Context(),
projectID,
r.URL.Query().Get("status"),
format,
w,
); err != nil {
h.logger.Error("export migration project audit", "error", err)
}
}
type updateSharedDriveModeRequest struct {
Mode string `json:"shared_drive_mode"`
}
func (h *Handler) UpdateMigrationSharedDriveMode(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
var req updateSharedDriveModeRequest
if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &req); err != nil {
return
}
row, err := h.svc.migration.UpdateSharedDriveMode(r.Context(), chi.URLParam(r, "projectID"), req.Mode)
if err != nil {
apivalidate.WriteInternal(w, r)
return
}
apiresponse.WriteJSON(w, http.StatusOK, row)
}
func (h *Handler) ListMigrationSharedDrives(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
rows, err := h.svc.migration.ListSharedDrives(r.Context(), chi.URLParam(r, "projectID"), r.URL.Query().Get("status"))
if err != nil {
apivalidate.WriteInternal(w, r)
return
}
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"shared_drives": rows})
}
func (h *Handler) ApproveMigrationSharedDrive(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
row, err := h.svc.migration.ApproveSharedDrive(r.Context(), chi.URLParam(r, "projectID"), chi.URLParam(r, "driveID"))
if err != nil {
apiresponse.WriteError(w, r, http.StatusNotFound, "shared_drive_not_found", err.Error(), nil)
return
}
apiresponse.WriteJSON(w, http.StatusOK, row)
}
func (h *Handler) RejectMigrationSharedDrive(w http.ResponseWriter, r *http.Request) {
if h.svc.migration == nil {
apivalidate.WriteInternal(w, r)
return
}
row, err := h.svc.migration.RejectSharedDrive(r.Context(), chi.URLParam(r, "projectID"), chi.URLParam(r, "driveID"))
if err != nil {
apiresponse.WriteError(w, r, http.StatusNotFound, "shared_drive_not_found", err.Error(), nil)
return
}
apiresponse.WriteJSON(w, http.StatusOK, row)
}

View File

@ -162,6 +162,8 @@ func (h *Handler) ClaimInvite(w http.ResponseWriter, r *http.Request) {
errCode = "invite_already_claimed" errCode = "invite_already_claimed"
case err == migr.ErrEmailMismatch: case err == migr.ErrEmailMismatch:
errCode = "email_mismatch" errCode = "email_mismatch"
case err == migr.ErrTenantMismatch:
errCode = "tenant_mismatch"
case err == migr.ErrMigrationDomainNotActive: case err == migr.ErrMigrationDomainNotActive:
errCode = "migration_domain_not_active" errCode = "migration_domain_not_active"
case err == migr.ErrMigrationDomainMismatch: case err == migr.ErrMigrationDomainMismatch:

View File

@ -215,6 +215,8 @@ func buildTestConfig(env Env, infra *infra, oidc *OIDCServer) *config.Config {
MailActiveCredentialKeyID: "v1", MailActiveCredentialKeyID: "v1",
MailWebhookSharedSecret: "test-webhook-secret", MailWebhookSharedSecret: "test-webhook-secret",
MailAppURL: "http://localhost:3004", MailAppURL: "http://localhost:3004",
ProvisionWebhookSecret: "test-provision-secret",
PlatformMailDomain: "ultisuite.local",
SearchEngine: "postgres", SearchEngine: "postgres",
MeilisearchURL: env.MeilisearchURL, MeilisearchURL: env.MeilisearchURL,
MeilisearchKey: env.MeilisearchKey, MeilisearchKey: env.MeilisearchKey,

View File

@ -101,7 +101,7 @@ func TestClaimInviteRejectsEmailMismatch(t *testing.T) {
integrationtest.FailUnlessStatus(t, actResp, 200) integrationtest.FailUnlessStatus(t, actResp, 200)
inviteEmail := "victim-" + uuid.NewString() + "@example.com" inviteEmail := "victim-" + uuid.NewString() + "@example.com"
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{ inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{
"email": inviteEmail, "email": inviteEmail,
}) })
integrationtest.FailIf(err, t, "create invite") integrationtest.FailIf(err, t, "create invite")
@ -127,3 +127,131 @@ func TestClaimInviteRejectsEmailMismatch(t *testing.T) {
integrationtest.FailIf(err, t, "claim invite") integrationtest.FailIf(err, t, "claim invite")
integrationtest.AssertErrorCode(t, claimResp, 400, "email_mismatch") integrationtest.AssertErrorCode(t, claimResp, 400, "email_mismatch")
} }
func TestClaimInviteRejectsMicrosoftTenantMismatch(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
t.Fatalf("ensure admin: %v", err)
}
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
t.Fatalf("grant admin: %v", err)
}
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
"name": "Tenant mismatch",
"source_provider": "microsoft",
})
integrationtest.FailIf(err, t, "create project")
integrationtest.FailUnlessStatus(t, createResp, 201)
var created struct {
ID string `json:"id"`
}
integrationtest.DecodeJSON(t, createResp, &created)
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
integrationtest.FailIf(err, t, "activate project")
integrationtest.FailUnlessStatus(t, actResp, 200)
expectedTenant := "11111111-2222-3333-4444-555555555555"
if _, err := h.Pool.Exec(ctx, `
UPDATE migration_projects SET microsoft_tenant_id = $1 WHERE id = $2::uuid
`, expectedTenant, created.ID); err != nil {
t.Fatalf("set tenant: %v", err)
}
inviteEmail := "tenant-user-" + uuid.NewString()[:8] + "@example.com"
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{
"email": inviteEmail,
})
integrationtest.FailIf(err, t, "create invite")
integrationtest.FailUnlessStatus(t, inviteResp, 201)
var invite struct {
Token string `json:"token"`
}
integrationtest.DecodeJSON(t, inviteResp, &invite)
wrongTenantClaims := integrationtest.RegularUser(integrationtest.NewExternalID("tenant-mismatch"))
wrongTenantClaims.Email = inviteEmail
wrongTenantClaims.TID = "99999999-aaaa-bbbb-cccc-dddddddddddd"
wrongTenantClient, err := h.Client(wrongTenantClaims)
integrationtest.FailIf(err, t, "wrong tenant client")
if _, err := users.EnsureUser(ctx, h.Pool, wrongTenantClaims); err != nil {
t.Fatalf("ensure user: %v", err)
}
claimResp, err := wrongTenantClient.Post("/api/v1/migration/claim", map[string]string{
"token": invite.Token,
})
integrationtest.FailIf(err, t, "claim invite")
integrationtest.AssertErrorCode(t, claimResp, 400, "tenant_mismatch")
}
func TestClaimInviteGoogleProjectIgnoresTenant(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
t.Fatalf("ensure admin: %v", err)
}
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
t.Fatalf("grant admin: %v", err)
}
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
"name": "Google ignores tenant",
"source_provider": "google",
})
integrationtest.FailIf(err, t, "create project")
integrationtest.FailUnlessStatus(t, createResp, 201)
var created struct {
ID string `json:"id"`
}
integrationtest.DecodeJSON(t, createResp, &created)
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
integrationtest.FailIf(err, t, "activate project")
integrationtest.FailUnlessStatus(t, actResp, 200)
if _, err := h.Pool.Exec(ctx, `
UPDATE migration_projects SET microsoft_tenant_id = $1 WHERE id = $2::uuid
`, "11111111-2222-3333-4444-555555555555", created.ID); err != nil {
t.Fatalf("set tenant: %v", err)
}
inviteEmail := "google-user-" + uuid.NewString()[:8] + "@example.com"
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{
"email": inviteEmail,
})
integrationtest.FailIf(err, t, "create invite")
integrationtest.FailUnlessStatus(t, inviteResp, 201)
var invite struct {
Token string `json:"token"`
}
integrationtest.DecodeJSON(t, inviteResp, &invite)
userClaims := integrationtest.RegularUser(integrationtest.NewExternalID("google-tenant-ignore"))
userClaims.Email = inviteEmail
userClaims.TID = "wrong-tenant-id"
userClient, err := h.Client(userClaims)
integrationtest.FailIf(err, t, "user client")
if _, err := users.EnsureUser(ctx, h.Pool, userClaims); err != nil {
t.Fatalf("ensure user: %v", err)
}
claimResp, err := userClient.Post("/api/v1/migration/claim", map[string]string{
"token": invite.Token,
"password": "test-password-123",
})
integrationtest.FailIf(err, t, "claim invite")
integrationtest.FailUnlessStatus(t, claimResp, 200)
}

View File

@ -495,6 +495,98 @@ func TestGraphMailDeltaDeletesRemoved(t *testing.T) {
} }
} }
func TestGraphFolderDeltaDeletesRemoved(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
userID, err := users.EnsureUser(ctx, h.Pool, integrationtest.RegularUser(integrationtest.NewExternalID("graph-folder-delta")))
integrationtest.FailIf(err, t, "ensure user")
var accountID string
err = h.Pool.QueryRow(ctx, `
INSERT INTO mail_accounts (user_id, email, provider, is_active)
VALUES ($1::uuid, 'graph-folder-delta@test.local', 'hosted', true)
RETURNING id::text
`, userID).Scan(&accountID)
integrationtest.FailIf(err, t, "insert mail account")
uid := migr.RemoteMessageUIDForTest("msg-folder-removed")
_, err = h.Pool.Exec(ctx, `
INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, date, snippet, body_text, body_html, flags, labels)
SELECT $1::uuid, f.id, $2, '<test@local>', 'To delete', '[]', '[]', NOW(), '', '', '', '{}', '{}'
FROM mail_folders f WHERE f.account_id = $1::uuid AND f.remote_name = 'INBOX' LIMIT 1
`, accountID, uid)
integrationtest.FailIf(err, t, "seed message")
inboxID := "inbox-folder-id"
sentID := "sent-folder-id"
client := graphRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/mailFolders") {
_, _ = w.Write([]byte(`{"value":[
{"id":"` + inboxID + `","displayName":"Inbox","wellKnownName":"inbox"},
{"id":"` + sentID + `","displayName":"Sent","wellKnownName":"sentitems"}
]}`))
return
}
if strings.Contains(r.URL.Path, "/mailFolders/"+inboxID+"/messages/delta") {
_, _ = w.Write([]byte(`{
"value":[{"id":"msg-folder-removed","@removed":{"reason":"deleted"}}],
"@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/mailFolders/` + inboxID + `/messages/delta?token=inbox-done"
}`))
return
}
if strings.Contains(r.URL.Path, "/mailFolders/"+sentID+"/messages/delta") {
_, _ = w.Write([]byte(`{
"value":[],
"@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/mailFolders/` + sentID + `/messages/delta?token=sent-done"
}`))
return
}
http.NotFound(w, r)
})
importer := migr.NewGraphImporter(h.Pool).WithHTTPClient(client).WithBaseURL("https://graph.microsoft.com")
job := &migr.Job{
UserID: userID,
CursorJSON: map[string]any{
"graphFolderQueue": []any{inboxID, sentID},
"folderDeltaLinks": map[string]any{
inboxID: "https://graph.microsoft.com/v1.0/me/mailFolders/" + inboxID + "/messages/delta?token=inbox-old",
sentID: "https://graph.microsoft.com/v1.0/me/mailFolders/" + sentID + "/messages/delta?token=sent-old",
},
},
StatsJSON: map[string]any{},
}
for {
var finalStatus string
err = importer.ImportBatch(ctx, job, "token", true, func(status string, cursor, stats map[string]any, jobErr string) error {
if jobErr != "" {
t.Fatalf("import error: %s", jobErr)
}
finalStatus = status
return nil
})
integrationtest.FailIf(err, t, "import batch")
if finalStatus == "completed" {
break
}
}
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
if deleted != 1 {
t.Fatalf("delta_deleted = %v, want 1", job.StatsJSON["delta_deleted"])
}
var count int
if err := h.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM messages WHERE account_id = $1::uuid AND uid = $2`, accountID, uid).Scan(&count); err != nil {
t.Fatalf("count messages: %v", err)
}
if count != 0 {
t.Fatalf("message count = %d, want 0", count)
}
}
func TestGmailHistoryDeltaDeletesMessage(t *testing.T) { func TestGmailHistoryDeltaDeletesMessage(t *testing.T) {
h := integrationtest.RequireHarness(t) h := integrationtest.RequireHarness(t)
ctx := context.Background() ctx := context.Background()

View File

@ -4,6 +4,7 @@ package migration_test
import ( import (
"context" "context"
"encoding/json"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
@ -200,6 +201,38 @@ func TestMigrationInviteClaimFlow(t *testing.T) {
t.Fatalf("failed audit items = %+v", failedList.Items) t.Fatalf("failed audit items = %+v", failedList.Items)
} }
csvExportResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs/" + mailJobID + "/audit/export?format=csv")
integrationtest.FailIf(err, t, "audit export csv")
integrationtest.FailUnlessStatus(t, csvExportResp, 200)
csvText := string(csvExportResp.Body)
if !strings.Contains(csvText, "item_id,rel_path,status,error,service,timestamp") {
t.Fatalf("csv headers missing: %q", csvText)
}
if !strings.Contains(csvText, "msg-fail") || !strings.Contains(csvText, "upload timeout") {
t.Fatalf("csv body = %q", csvText)
}
ndExportResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs/" + mailJobID + "/audit/export?format=ndjson")
integrationtest.FailIf(err, t, "audit export ndjson")
integrationtest.FailUnlessStatus(t, ndExportResp, 200)
ndBody := string(ndExportResp.Body)
for _, line := range strings.Split(strings.TrimSpace(ndBody), "\n") {
if line == "" {
continue
}
var row struct {
ItemID string `json:"item_id"`
Status string `json:"status"`
Service string `json:"service"`
}
if err := json.Unmarshal([]byte(line), &row); err != nil {
t.Fatalf("ndjson line invalid: %q err=%v", line, err)
}
if row.ItemID == "" || row.Status == "" || row.Service == "" {
t.Fatalf("ndjson row incomplete: %+v", row)
}
}
resetResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/jobs/"+mailJobID+"/reset-cursor", nil) resetResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/jobs/"+mailJobID+"/reset-cursor", nil)
integrationtest.FailIf(err, t, "reset cursor") integrationtest.FailIf(err, t, "reset cursor")
integrationtest.FailUnlessStatus(t, resetResp, 200) integrationtest.FailUnlessStatus(t, resetResp, 200)
@ -247,12 +280,16 @@ func TestGraphImportWritesMessages(t *testing.T) {
integrationtest.FailIf(err, t, "insert mail account") integrationtest.FailIf(err, t, "insert mail account")
folderID := "inbox-folder-id" folderID := "inbox-folder-id"
sentFolderID := "sent-folder-id"
messagesListed := false messagesListed := false
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch { switch {
case strings.Contains(r.URL.Path, "/mailFolders"): case strings.HasSuffix(r.URL.Path, "/mailFolders"):
_, _ = w.Write([]byte(`{"value":[{"id":"` + folderID + `","displayName":"Inbox","wellKnownName":"inbox"}]}`)) _, _ = w.Write([]byte(`{"value":[
case strings.Contains(r.URL.Path, "/messages"): {"id":"` + folderID + `","displayName":"Inbox","wellKnownName":"inbox"},
{"id":"` + sentFolderID + `","displayName":"Sent Items","wellKnownName":"sentitems"}
]}`))
case strings.Contains(r.URL.Path, "/mailFolders/"+folderID+"/messages"):
messagesListed = true messagesListed = true
_, _ = w.Write([]byte(`{"value":[{ _, _ = w.Write([]byte(`{"value":[{
"id":"msg-1", "id":"msg-1",
@ -266,6 +303,8 @@ func TestGraphImportWritesMessages(t *testing.T) {
"isRead":true, "isRead":true,
"internetMessageId":"<graph-test@example.com>" "internetMessageId":"<graph-test@example.com>"
}]}`)) }]}`))
case strings.Contains(r.URL.Path, "/mailFolders/"+sentFolderID+"/messages"):
_, _ = w.Write([]byte(`{"value":[]}`))
default: default:
http.NotFound(w, r) http.NotFound(w, r)
} }
@ -278,16 +317,23 @@ func TestGraphImportWritesMessages(t *testing.T) {
CursorJSON: map[string]any{}, CursorJSON: map[string]any{},
StatsJSON: map[string]any{}, StatsJSON: map[string]any{},
} }
for {
var finalStatus string
err = importer.ImportBatch(ctx, job, "test-token", false, func(status string, cursor, stats map[string]any, jobErr string) error { err = importer.ImportBatch(ctx, job, "test-token", false, func(status string, cursor, stats map[string]any, jobErr string) error {
if jobErr != "" { if jobErr != "" {
t.Fatalf("import error: %s", jobErr) t.Fatalf("import error: %s", jobErr)
} }
if status != "completed" { finalStatus = status
t.Fatalf("status = %q, want completed", status)
}
return nil return nil
}) })
integrationtest.FailIf(err, t, "import batch") integrationtest.FailIf(err, t, "import batch")
if finalStatus == "completed" {
break
}
if finalStatus != "pending" {
t.Fatalf("status = %q, want pending or completed", finalStatus)
}
}
if !messagesListed { if !messagesListed {
t.Fatal("graph messages endpoint not called") t.Fatal("graph messages endpoint not called")
} }

View File

@ -0,0 +1,287 @@
//go:build integration
package migration_test
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"testing"
"github.com/google/uuid"
"github.com/ultisuite/ulti-backend/internal/integrationtest"
"github.com/ultisuite/ulti-backend/internal/migration"
"github.com/ultisuite/ulti-backend/internal/users"
)
const testProvisionSecret = "test-provision-secret"
func postProvision(t *testing.T, h *integrationtest.Harness, body map[string]any) *integrationtest.Response {
t.Helper()
data, err := json.Marshal(body)
integrationtest.FailIf(err, t, "marshal provision body")
req, err := http.NewRequest(http.MethodPost, h.Server.URL+"/internal/provision/user", bytes.NewReader(data))
integrationtest.FailIf(err, t, "provision request")
req.Header.Set("Content-Type", "application/json")
req.Header.Set("X-Provision-Secret", testProvisionSecret)
resp, err := http.DefaultClient.Do(req)
integrationtest.FailIf(err, t, "provision call")
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
integrationtest.FailIf(err, t, "read provision response")
return &integrationtest.Response{Status: resp.StatusCode, Body: bodyBytes, Header: resp.Header}
}
func TestProvisionEnrollThenClaim(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
t.Fatalf("ensure admin: %v", err)
}
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
t.Fatalf("grant admin: %v", err)
}
domainName := "enroll-claim-" + uuid.NewString()[:8] + ".test"
var domainID string
err := h.Pool.QueryRow(ctx, `
INSERT INTO mail_domains (name, status, is_platform_domain)
VALUES ($1, 'active', false)
RETURNING id::text
`, domainName).Scan(&domainID)
integrationtest.FailIf(err, t, "insert domain")
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
"name": "Enroll then claim",
"source_provider": "google",
"domain_id": domainID,
})
integrationtest.FailIf(err, t, "create project")
integrationtest.FailUnlessStatus(t, createResp, 201)
var created struct {
ID string `json:"id"`
}
integrationtest.DecodeJSON(t, createResp, &created)
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
integrationtest.FailIf(err, t, "activate project")
integrationtest.FailUnlessStatus(t, actResp, 200)
migrateeEmail := "user@" + domainName
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
"email": migrateeEmail,
})
integrationtest.FailIf(err, t, "create invite")
integrationtest.FailUnlessStatus(t, inviteResp, 201)
var invite struct {
Token string `json:"token"`
}
integrationtest.DecodeJSON(t, inviteResp, &invite)
externalID := integrationtest.NewExternalID("enroll-claim")
provisionResp := postProvision(t, h, map[string]any{
"email": migrateeEmail,
"name": "Migratee",
"password": "enroll-password-123",
"external_id": externalID,
})
if provisionResp.Status != 200 {
t.Fatalf("provision status = %d, want 200; body=%s", provisionResp.Status, string(provisionResp.Body))
}
migrateeClaims := integrationtest.RegularUser(externalID)
migrateeClaims.Email = migrateeEmail
migrateeClient, err := h.Client(migrateeClaims)
integrationtest.FailIf(err, t, "migratee client")
if _, err := users.EnsureUser(ctx, h.Pool, migrateeClaims); err != nil {
t.Fatalf("ensure migratee: %v", err)
}
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
"token": invite.Token,
"password": "claim-password-123",
})
integrationtest.FailIf(err, t, "claim invite")
integrationtest.FailUnlessStatus(t, claimResp, 200)
audit, err := migration.AuditProvisionByEmail(ctx, h.Pool, migrateeEmail)
integrationtest.FailIf(err, t, "audit provision")
if audit.Users != 1 {
t.Fatalf("users = %d, want 1", audit.Users)
}
if audit.Mailboxes != 1 {
t.Fatalf("mailboxes = %d, want 1", audit.Mailboxes)
}
if audit.MailAccounts != 1 {
t.Fatalf("mail_accounts = %d, want 1", audit.MailAccounts)
}
}
func TestProvisionClaimThenEnroll(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
t.Fatalf("ensure admin: %v", err)
}
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
t.Fatalf("grant admin: %v", err)
}
domainName := "claim-enroll-" + uuid.NewString()[:8] + ".test"
var domainID string
err := h.Pool.QueryRow(ctx, `
INSERT INTO mail_domains (name, status, is_platform_domain)
VALUES ($1, 'active', false)
RETURNING id::text
`, domainName).Scan(&domainID)
integrationtest.FailIf(err, t, "insert domain")
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
"name": "Claim then enroll",
"source_provider": "google",
"domain_id": domainID,
})
integrationtest.FailIf(err, t, "create project")
integrationtest.FailUnlessStatus(t, createResp, 201)
var created struct {
ID string `json:"id"`
}
integrationtest.DecodeJSON(t, createResp, &created)
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
integrationtest.FailIf(err, t, "activate project")
integrationtest.FailUnlessStatus(t, actResp, 200)
migrateeEmail := "user@" + domainName
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
"email": migrateeEmail,
})
integrationtest.FailIf(err, t, "create invite")
integrationtest.FailUnlessStatus(t, inviteResp, 201)
var invite struct {
Token string `json:"token"`
}
integrationtest.DecodeJSON(t, inviteResp, &invite)
externalID := integrationtest.NewExternalID("claim-enroll")
migrateeClaims := integrationtest.RegularUser(externalID)
migrateeClaims.Email = migrateeEmail
migrateeClient, err := h.Client(migrateeClaims)
integrationtest.FailIf(err, t, "migratee client")
userID, err := users.EnsureUser(ctx, h.Pool, migrateeClaims)
integrationtest.FailIf(err, t, "ensure migratee")
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
"token": invite.Token,
"password": "claim-password-123",
})
integrationtest.FailIf(err, t, "claim invite")
integrationtest.FailUnlessStatus(t, claimResp, 200)
provisionResp := postProvision(t, h, map[string]any{
"email": migrateeEmail,
"name": "Migratee",
"password": "enroll-password-123",
"external_id": externalID,
})
if provisionResp.Status != 200 {
t.Fatalf("provision status = %d, want 200; body=%s", provisionResp.Status, string(provisionResp.Body))
}
var provisionBody struct {
UserID string `json:"user_id"`
}
integrationtest.DecodeJSON(t, provisionResp, &provisionBody)
if provisionBody.UserID != userID {
t.Fatalf("provision user_id = %q, want %q", provisionBody.UserID, userID)
}
audit, err := migration.AuditProvisionByEmail(ctx, h.Pool, migrateeEmail)
integrationtest.FailIf(err, t, "audit provision")
if audit.Users != 1 {
t.Fatalf("users = %d, want 1", audit.Users)
}
if audit.Mailboxes != 1 {
t.Fatalf("mailboxes = %d, want 1", audit.Mailboxes)
}
if audit.MailAccounts != 1 {
t.Fatalf("mail_accounts = %d, want 1", audit.MailAccounts)
}
}
func TestProvisionEnrollmentDefersMailboxForPendingInvite(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
t.Fatalf("ensure admin: %v", err)
}
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
t.Fatalf("grant admin: %v", err)
}
platformEmail := "pending-" + uuid.NewString()[:8] + "@ultisuite.local"
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
"name": "Pending invite defer",
"source_provider": "google",
})
integrationtest.FailIf(err, t, "create project")
integrationtest.FailUnlessStatus(t, createResp, 201)
var created struct {
ID string `json:"id"`
}
integrationtest.DecodeJSON(t, createResp, &created)
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
integrationtest.FailIf(err, t, "activate project")
integrationtest.FailUnlessStatus(t, actResp, 200)
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
"email": platformEmail,
})
integrationtest.FailIf(err, t, "create invite")
integrationtest.FailUnlessStatus(t, inviteResp, 201)
externalID := integrationtest.NewExternalID("pending-invite")
provisionResp := postProvision(t, h, map[string]any{
"email": platformEmail,
"name": "Pending User",
"password": "enroll-password-123",
"external_id": externalID,
})
if provisionResp.Status != 200 {
t.Fatalf("provision status = %d, want 200; body=%s", provisionResp.Status, string(provisionResp.Body))
}
var provisionBody struct {
MailboxDeferred bool `json:"mailbox_deferred"`
}
integrationtest.DecodeJSON(t, provisionResp, &provisionBody)
if !provisionBody.MailboxDeferred {
t.Fatal("expected mailbox_deferred=true for pending invite enrollment")
}
audit, err := migration.AuditProvisionByEmail(ctx, h.Pool, platformEmail)
integrationtest.FailIf(err, t, "audit after deferred enroll")
if audit.Users != 1 {
t.Fatalf("users = %d, want 1", audit.Users)
}
if audit.Mailboxes != 0 {
t.Fatalf("mailboxes = %d, want 0 before claim", audit.Mailboxes)
}
}

View File

@ -0,0 +1,133 @@
//go:build integration
package migration_test
import (
"context"
"testing"
"github.com/google/uuid"
"github.com/ultisuite/ulti-backend/internal/integrationtest"
"github.com/ultisuite/ulti-backend/internal/users"
)
func TestMigrationRosterImportAndClaim(t *testing.T) {
h := integrationtest.RequireHarness(t)
ctx := context.Background()
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
t.Fatalf("ensure admin: %v", err)
}
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
t.Fatalf("grant admin: %v", err)
}
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
"name": "Roster migration",
"source_provider": "google",
})
integrationtest.FailIf(err, t, "create project")
integrationtest.FailUnlessStatus(t, createResp, 201)
var created struct {
ID string `json:"id"`
}
integrationtest.DecodeJSON(t, createResp, &created)
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
integrationtest.FailIf(err, t, "activate project")
integrationtest.FailUnlessStatus(t, actResp, 200)
csv := "email,display_name,alternate_emails\n" +
"migratee-" + uuid.NewString() + "@example.com,Test User,alt-" + uuid.NewString() + "@example.com\n"
importResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/roster", map[string]string{
"csv": csv,
})
integrationtest.FailIf(err, t, "import roster")
integrationtest.FailUnlessStatus(t, importResp, 200)
var importResult struct {
Created int `json:"created"`
SkippedDuplicates int `json:"skipped_duplicates"`
}
integrationtest.DecodeJSON(t, importResp, &importResult)
if importResult.Created != 1 {
t.Fatalf("expected 1 created, got %#v", importResult)
}
listResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/roster")
integrationtest.FailIf(err, t, "list roster")
integrationtest.FailUnlessStatus(t, listResp, 200)
var rosterList struct {
Roster []struct {
Email string `json:"email"`
DisplayName string `json:"display_name"`
Status string `json:"status"`
} `json:"roster"`
}
integrationtest.DecodeJSON(t, listResp, &rosterList)
if len(rosterList.Roster) != 1 {
t.Fatalf("expected 1 roster entry, got %d", len(rosterList.Roster))
}
if rosterList.Roster[0].Status != "invited" {
t.Fatalf("expected invited status, got %q", rosterList.Roster[0].Status)
}
migrateeEmail := rosterList.Roster[0].Email
dupResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/roster", map[string]string{
"csv": csv,
})
integrationtest.FailIf(err, t, "duplicate import")
integrationtest.FailUnlessStatus(t, dupResp, 200)
var dupResult struct {
SkippedDuplicates int `json:"skipped_duplicates"`
}
integrationtest.DecodeJSON(t, dupResp, &dupResult)
if dupResult.SkippedDuplicates != 1 {
t.Fatalf("expected 1 skipped duplicate, got %d", dupResult.SkippedDuplicates)
}
var inviteToken string
err = h.Pool.QueryRow(ctx, `
SELECT token FROM migration_invites WHERE project_id = $1::uuid AND email = $2
`, created.ID, migrateeEmail).Scan(&inviteToken)
if err != nil {
t.Fatalf("lookup invite token: %v", err)
}
if inviteToken == "" {
t.Fatal("missing invite token for roster entry")
}
migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("roster-migratee"))
migrateeClaims.Email = migrateeEmail
migrateeClient, err := h.Client(migrateeClaims)
integrationtest.FailIf(err, t, "migratee client")
if _, err := users.EnsureUser(ctx, h.Pool, migrateeClaims); err != nil {
t.Fatalf("ensure migratee: %v", err)
}
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
"token": inviteToken,
"password": "test-password-123",
"display_name": "Test User",
})
integrationtest.FailIf(err, t, "claim invite")
integrationtest.FailUnlessStatus(t, claimResp, 200)
var rosterStatus string
err = h.Pool.QueryRow(ctx, `
SELECT status FROM migration_roster WHERE project_id = $1::uuid AND email = $2
`, created.ID, migrateeEmail).Scan(&rosterStatus)
if err != nil {
t.Fatalf("roster status: %v", err)
}
if rosterStatus != "claimed" {
t.Fatalf("expected claimed roster, got %q", rosterStatus)
}
}

View File

@ -99,6 +99,9 @@ func (s *OIDCServer) IssueToken(claims *auth.Claims) (string, error) {
"name": claims.Name, "name": claims.Name,
"groups": claims.Groups, "groups": claims.Groups,
}) })
if tid := strings.TrimSpace(claims.TID); tid != "" {
builder = builder.Claims(map[string]any{"tid": tid})
}
return builder.Serialize() return builder.Serialize()
} }

View File

@ -266,6 +266,118 @@ type ProvisionMailboxResult struct {
MailAccountID string MailAccountID string
} }
// EnsureMailboxProvisioned creates a mailbox or links an existing one to the requested user.
func (s *Service) EnsureMailboxProvisioned(ctx context.Context, in ProvisionMailboxInput) (ProvisionMailboxResult, error) {
email := strings.ToLower(strings.TrimSpace(in.Email))
existing, err := s.lookupMailboxByEmail(ctx, email)
if errors.Is(err, pgx.ErrNoRows) {
return s.ProvisionMailbox(ctx, in)
}
if err != nil {
return ProvisionMailboxResult{}, err
}
return s.reconcileExistingMailbox(ctx, existing, in)
}
func (s *Service) lookupMailboxByEmail(ctx context.Context, email string) (MailboxRow, error) {
email = strings.ToLower(strings.TrimSpace(email))
at := strings.LastIndex(email, "@")
if at <= 0 {
return MailboxRow{}, fmt.Errorf("invalid email")
}
localPart := email[:at]
domainName := email[at+1:]
localPart, err := normalizeLocalPart(localPart)
if err != nil {
return MailboxRow{}, err
}
var row MailboxRow
err = s.db.QueryRow(ctx, `
SELECT mb.id::text, mb.domain_id::text, mb.local_part,
lower(mb.local_part || '@' || md.name),
COALESCE(mb.user_id::text, ''),
COALESCE(mb.mail_account_id::text, ''),
mb.stalwart_account_id, mb.quota_bytes, mb.status
FROM mailboxes mb
JOIN mail_domains md ON md.id = mb.domain_id
WHERE md.name = $1 AND mb.local_part = $2
`, domainName, localPart).Scan(
&row.ID, &row.DomainID, &row.LocalPart, &row.Email,
&row.UserID, &row.MailAccountID, &row.StalwartAccountID, &row.QuotaBytes, &row.Status,
)
return row, err
}
func (s *Service) reconcileExistingMailbox(ctx context.Context, existing MailboxRow, in ProvisionMailboxInput) (ProvisionMailboxResult, error) {
userID := strings.TrimSpace(in.UserID)
if userID != "" && existing.UserID != "" && existing.UserID != userID {
return ProvisionMailboxResult{}, ErrAddressTaken
}
if userID == "" {
userID = existing.UserID
}
mailAccountID := existing.MailAccountID
if userID != "" && existing.UserID != userID {
if err := s.LinkMailboxToUser(ctx, existing.ID, userID); err != nil {
return ProvisionMailboxResult{}, err
}
existing.UserID = userID
}
if userID != "" && mailAccountID == "" && strings.TrimSpace(in.Password) != "" {
email := strings.ToLower(strings.TrimSpace(in.Email))
err := s.db.QueryRow(ctx, `
SELECT id::text FROM mail_accounts
WHERE user_id = $1::uuid AND lower(email) = $2
LIMIT 1
`, userID, email).Scan(&mailAccountID)
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
return ProvisionMailboxResult{}, err
}
if errors.Is(err, pgx.ErrNoRows) {
enc, err := s.encryptHostedCredential(email, in.Password)
if err != nil {
return ProvisionMailboxResult{}, err
}
displayName := strings.TrimSpace(in.DisplayName)
if displayName == "" {
displayName = email
}
err = s.db.QueryRow(ctx, `
INSERT INTO mail_accounts (
user_id, name, email, provider,
imap_host, imap_port, imap_tls,
smtp_host, smtp_port, smtp_tls,
credentials, is_active
)
VALUES ($1, $2, $3, 'hosted', $4, $5, $6, $7, $8, $9, $10, true)
RETURNING id::text
`, userID, displayName, email,
s.imapHost, s.imapPort, s.imapTLS,
s.smtpHost, s.smtpPort, s.smtpTLS,
enc,
).Scan(&mailAccountID)
if err != nil {
return ProvisionMailboxResult{}, err
}
}
_, err = s.db.Exec(ctx, `
UPDATE mailboxes SET mail_account_id = $1::uuid, updated_at = NOW()
WHERE id = $2::uuid AND (mail_account_id IS NULL OR mail_account_id = $1::uuid)
`, mailAccountID, existing.ID)
if err != nil {
return ProvisionMailboxResult{}, err
}
existing.MailAccountID = mailAccountID
}
return ProvisionMailboxResult{
Mailbox: existing,
MailAccountID: mailAccountID,
}, nil
}
func (s *Service) ProvisionMailbox(ctx context.Context, in ProvisionMailboxInput) (ProvisionMailboxResult, error) { func (s *Service) ProvisionMailbox(ctx context.Context, in ProvisionMailboxInput) (ProvisionMailboxResult, error) {
email := strings.ToLower(strings.TrimSpace(in.Email)) email := strings.ToLower(strings.TrimSpace(in.Email))
at := strings.LastIndex(email, "@") at := strings.LastIndex(email, "@")

View File

@ -11,6 +11,7 @@ type ClaimIdentity struct {
Email string Email string
PreferredUsername string PreferredUsername string
UPN string UPN string
TenantID string
} }
func ClaimIdentityFromAuth(c *auth.Claims) ClaimIdentity { func ClaimIdentityFromAuth(c *auth.Claims) ClaimIdentity {
@ -21,6 +22,7 @@ func ClaimIdentityFromAuth(c *auth.Claims) ClaimIdentity {
Email: c.Email, Email: c.Email,
PreferredUsername: c.PreferredUsername, PreferredUsername: c.PreferredUsername,
UPN: c.UPN, UPN: c.UPN,
TenantID: c.TID,
} }
} }
@ -90,7 +92,17 @@ func inviteMatchTargets(inviteEmail string, alternateEmails []string) []string {
return out return out
} }
func localPartAliasMatch(a, b string) bool { func isGmailAliasDomain(domain string) bool {
switch strings.ToLower(strings.TrimSpace(domain)) {
case "gmail.com", "googlemail.com":
return true
default:
return false
}
}
// gmailLocalPartAliasMatch applies Gmail dot/plus normalization only on gmail.com / googlemail.com.
func gmailLocalPartAliasMatch(a, b string) bool {
aLocal, aDomain, okA := emailLocalAndDomain(a) aLocal, aDomain, okA := emailLocalAndDomain(a)
bLocal, bDomain, okB := emailLocalAndDomain(b) bLocal, bDomain, okB := emailLocalAndDomain(b)
if !okA || !okB { if !okA || !okB {
@ -99,6 +111,9 @@ func localPartAliasMatch(a, b string) bool {
if !strings.EqualFold(aDomain, bDomain) { if !strings.EqualFold(aDomain, bDomain) {
return false return false
} }
if !isGmailAliasDomain(aDomain) {
return false
}
return normalizeEmailLocalPart(aLocal) == normalizeEmailLocalPart(bLocal) return normalizeEmailLocalPart(aLocal) == normalizeEmailLocalPart(bLocal)
} }
@ -140,7 +155,7 @@ func InviteEmailMatchesIdentity(inviteEmail string, alternateEmails []string, pr
if candidate == target { if candidate == target {
return true return true
} }
if localPartAliasMatch(target, candidate) { if gmailLocalPartAliasMatch(target, candidate) {
return true return true
} }
} }
@ -148,3 +163,19 @@ func InviteEmailMatchesIdentity(inviteEmail string, alternateEmails []string, pr
return projectDomainUPNMatch(inviteEmail, projectDomain, identity) return projectDomainUPNMatch(inviteEmail, projectDomain, identity)
} }
// validateMicrosoftTenantClaim rejects claims when the OIDC tid does not match the project's pinned tenant.
func validateMicrosoftTenantClaim(proj Project, tokenTenantID string) error {
if strings.ToLower(strings.TrimSpace(proj.SourceProvider)) != "microsoft" {
return nil
}
expected := strings.TrimSpace(proj.MicrosoftTenantID)
if expected == "" {
return nil
}
got := strings.TrimSpace(tokenTenantID)
if got == "" || !strings.EqualFold(got, expected) {
return ErrTenantMismatch
}
return nil
}

View File

@ -41,16 +41,24 @@ func TestInviteEmailMatchesIdentityGmailDotAlias(t *testing.T) {
if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) { if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) {
t.Fatal("expected exact match baseline") t.Fatal("expected exact match baseline")
} }
id = ClaimIdentity{Email: "a.l.i.c.e.smith@gmail.com"}
if !InviteEmailMatchesIdentity("alice.smith@gmail.com", nil, "", id) {
t.Fatal("expected dot-insensitive local-part match on gmail.com")
}
id = ClaimIdentity{Email: "a.l.i.c.e.smith@acme.com"} id = ClaimIdentity{Email: "a.l.i.c.e.smith@acme.com"}
if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) { if InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) {
t.Fatal("expected dot-insensitive local-part match") t.Fatal("expected reject dot-alias on non-gmail domain")
} }
} }
func TestInviteEmailMatchesIdentityPlusTag(t *testing.T) { func TestInviteEmailMatchesIdentityPlusTag(t *testing.T) {
id := ClaimIdentity{Email: "alice+tag@acme.com"} id := ClaimIdentity{Email: "alice+tag@gmail.com"}
if !InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) { if !InviteEmailMatchesIdentity("alice@gmail.com", nil, "", id) {
t.Fatal("expected plus-tag stripped match") t.Fatal("expected plus-tag stripped match on gmail.com")
}
id = ClaimIdentity{Email: "alice+tag@acme.com"}
if InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
t.Fatal("expected reject plus-tag alias on non-gmail domain")
} }
} }
@ -90,3 +98,29 @@ func TestInviteEmailMatchesIdentityIgnoresNonEmailPreferredUsername(t *testing.T
t.Fatal("expected reject when preferred_username is not an email") t.Fatal("expected reject when preferred_username is not an email")
} }
} }
func TestValidateMicrosoftTenantClaim(t *testing.T) {
msProj := Project{SourceProvider: "microsoft", MicrosoftTenantID: "tenant-abc"}
if err := validateMicrosoftTenantClaim(msProj, "tenant-abc"); err != nil {
t.Fatalf("expected match: %v", err)
}
if err := validateMicrosoftTenantClaim(msProj, "TENANT-ABC"); err != nil {
t.Fatalf("expected case-insensitive match: %v", err)
}
if err := validateMicrosoftTenantClaim(msProj, "other-tenant"); err != ErrTenantMismatch {
t.Fatalf("expected tenant mismatch, got %v", err)
}
if err := validateMicrosoftTenantClaim(msProj, ""); err != ErrTenantMismatch {
t.Fatalf("expected reject empty tid when tenant pinned: %v", err)
}
googleProj := Project{SourceProvider: "google", MicrosoftTenantID: "tenant-abc"}
if err := validateMicrosoftTenantClaim(googleProj, "wrong"); err != nil {
t.Fatalf("google project should ignore tenant: %v", err)
}
noTenant := Project{SourceProvider: "microsoft"}
if err := validateMicrosoftTenantClaim(noTenant, "any"); err != nil {
t.Fatalf("expected skip when project tenant unset: %v", err)
}
}

View File

@ -83,8 +83,9 @@ func (d *DriveImporter) importGoogleDriveDelta(ctx context.Context, job *Job, ac
return fmt.Errorf("google drive delta token missing") return fmt.Errorf("google drive delta token missing")
} }
listURL := "https://www.googleapis.com/drive/v3/changes?pageSize=100&spaces=drive&includeRemoved=true&fields=" + listURL := "https://www.googleapis.com/drive/v3/changes?pageSize=100&spaces=drive&includeRemoved=true" +
url.QueryEscape("nextPageToken,newStartPageToken,changes(fileId,removed,file(id,name,mimeType,size,parents,trashed))") + "&includeItemsFromAllDrives=true&supportsAllDrives=true&fields=" +
url.QueryEscape("nextPageToken,newStartPageToken,changes(fileId,removed,file(id,name,mimeType,size,parents,trashed,driveId))") +
"&pageToken=" + url.QueryEscape(pageToken) "&pageToken=" + url.QueryEscape(pageToken)
body, err := apiGet(ctx, d.client, listURL, accessToken) body, err := apiGet(ctx, d.client, listURL, accessToken)
@ -119,7 +120,7 @@ func (d *DriveImporter) importGoogleDriveDelta(ctx context.Context, job *Job, ac
} }
item := googleFileToDriveItem(*change.File) item := googleFileToDriveItem(*change.File)
relPath := d.resolveDriveRelPath(items, item) relPath := d.resolveDriveRelPath(items, item)
if err := d.uploadDriveItem(ctx, accessToken, ncUserID, root, relPath, item, items, &imported, &exported, &skipped, job.StatsJSON); err != nil { if err := d.uploadDriveItem(ctx, job, accessToken, ncUserID, root, relPath, item, items, &imported, &exported, &skipped, job.StatsJSON); err != nil {
return err return err
} }
batch++ batch++
@ -160,6 +161,7 @@ type googleDriveFile struct {
Size string `json:"size"` Size string `json:"size"`
Parents []string `json:"parents"` Parents []string `json:"parents"`
Trashed bool `json:"trashed"` Trashed bool `json:"trashed"`
DriveID string `json:"driveId"`
} }
func googleFileToDriveItem(f googleDriveFile) driveItem { func googleFileToDriveItem(f googleDriveFile) driveItem {
@ -173,6 +175,7 @@ func googleFileToDriveItem(f googleDriveFile) driveItem {
IsFolder: f.MimeType == "application/vnd.google-apps.folder", IsFolder: f.MimeType == "application/vnd.google-apps.folder",
Size: size, Size: size,
MimeType: f.MimeType, MimeType: f.MimeType,
DriveID: f.DriveID,
} }
if len(f.Parents) > 0 { if len(f.Parents) > 0 {
item.ParentID = f.Parents[0] item.ParentID = f.Parents[0]
@ -186,7 +189,7 @@ func googleFileToDriveItem(f googleDriveFile) driveItem {
item.ExportExt = ext item.ExportExt = ext
item.Name = driveExportFileName(f.Name, ext) item.Name = driveExportFileName(f.Name, ext)
} else { } else {
item.Download = "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(f.ID) + "?alt=media" item.Download = googleDriveDownloadURL(f.ID, f.DriveID != "")
} }
return item return item
} }
@ -232,7 +235,7 @@ func (d *DriveImporter) importMicrosoftDriveDelta(ctx context.Context, job *Job,
} }
driveItem := graphDriveToItem(d.userUPN, item) driveItem := graphDriveToItem(d.userUPN, item)
relPath := d.resolveDriveRelPath(items, driveItem) relPath := d.resolveDriveRelPath(items, driveItem)
if err := d.uploadDriveItem(ctx, accessToken, ncUserID, root, relPath, driveItem, items, &imported, nil, &skipped, job.StatsJSON); err != nil { if err := d.uploadDriveItem(ctx, job, accessToken, ncUserID, root, relPath, driveItem, items, &imported, nil, &skipped, job.StatsJSON); err != nil {
return err return err
} }
batch++ batch++
@ -307,8 +310,15 @@ func (d *DriveImporter) resolveDriveRelPath(items *ImportedItemStore, item drive
return path.Join(parentRel, sanitizeDrivePath(item.Name)) return path.Join(parentRel, sanitizeDrivePath(item.Name))
} }
func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUserID, root, relPath string, item driveItem, items *ImportedItemStore, imported, exported, skipped *float64, stats map[string]any) error { func (d *DriveImporter) uploadDriveItem(ctx context.Context, job *Job, accessToken, ncUserID, root, relPath string, item driveItem, items *ImportedItemStore, imported, exported, skipped *float64, stats map[string]any) error {
targetPath := path.Join(root, relPath) targetPath := path.Join(root, relPath)
shared := item.DriveID != ""
if d.alreadyImportedShared(item.DriveID, item.ID, shared) {
if skipped != nil {
*skipped++
}
return items.MarkSkipped(ctx, item.ID, "dedup: shared drive file already imported by project", relPath)
}
if item.IsFolder { if item.IsFolder {
if err := d.nc.CreateFolder(ctx, ncUserID, targetPath); err != nil { if err := d.nc.CreateFolder(ctx, ncUserID, targetPath); err != nil {
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
@ -335,7 +345,7 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser
} }
targetPath = path.Join(path.Dir(targetPath), fileName) targetPath = path.Join(path.Dir(targetPath), fileName)
relPath = path.Join(path.Dir(relPath), fileName) relPath = path.Join(path.Dir(relPath), fileName)
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, 0); err != nil {
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
return markErr return markErr
} }
@ -361,13 +371,6 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser
} }
} }
} else { } else {
if item.Size > maxDriveFileBytes {
if skipped != nil {
*skipped++
}
reason := fmt.Sprintf("file exceeds %d byte limit", maxDriveFileBytes)
return items.MarkSkipped(ctx, item.ID, reason, relPath)
}
content, contentType, err := d.downloadDriveFile(ctx, accessToken, item) content, contentType, err := d.downloadDriveFile(ctx, accessToken, item)
if err != nil { if err != nil {
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
@ -376,7 +379,7 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser
incJobStat(stats, "failed") incJobStat(stats, "failed")
return nil return nil
} }
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, item.Size); err != nil {
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
return markErr return markErr
} }
@ -390,6 +393,9 @@ func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUser
if err := items.MarkPath(ctx, item.ID, relPath); err != nil { if err := items.MarkPath(ctx, item.ID, relPath); err != nil {
return err return err
} }
if err := d.markSharedImported(ctx, item.DriveID, item.ID, relPath, job.ID, shared); err != nil {
return err
}
if imported != nil { if imported != nil {
*imported++ *imported++
} }

View File

@ -44,6 +44,8 @@ func driveExportFileName(name, ext string) string {
type driveFolderRef struct { type driveFolderRef struct {
ID string ID string
Path string Path string
DriveID string // Google shared drive ID; empty for My Drive
Shared bool
} }
func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderRef { func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderRef {
@ -56,8 +58,10 @@ func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderR
} }
id, _ := m["id"].(string) id, _ := m["id"].(string)
p, _ := m["path"].(string) p, _ := m["path"].(string)
driveID, _ := m["driveId"].(string)
shared, _ := m["shared"].(bool)
if id != "" { if id != "" {
out = append(out, driveFolderRef{ID: id, Path: p}) out = append(out, driveFolderRef{ID: id, Path: p, DriveID: driveID, Shared: shared})
} }
} }
if len(out) == 0 { if len(out) == 0 {
@ -72,14 +76,16 @@ func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderR
func writeDriveFolderQueue(cursor map[string]any, queue []driveFolderRef) { func writeDriveFolderQueue(cursor map[string]any, queue []driveFolderRef) {
raw := make([]any, 0, len(queue)) raw := make([]any, 0, len(queue))
for _, f := range queue { for _, f := range queue {
raw = append(raw, map[string]any{"id": f.ID, "path": f.Path}) raw = append(raw, map[string]any{
"id": f.ID, "path": f.Path, "driveId": f.DriveID, "shared": f.Shared,
})
} }
cursor["folderQueue"] = raw cursor["folderQueue"] = raw
} }
func enqueueDriveFolder(queue []driveFolderRef, folder driveFolderRef) []driveFolderRef { func enqueueDriveFolder(queue []driveFolderRef, folder driveFolderRef) []driveFolderRef {
for _, existing := range queue { for _, existing := range queue {
if existing.ID == folder.ID { if existing.ID == folder.ID && existing.DriveID == folder.DriveID {
return queue return queue
} }
} }

View File

@ -22,6 +22,9 @@ type DriveImporter struct {
nc *nextcloud.Client nc *nextcloud.Client
client *http.Client client *http.Client
userUPN string userUPN string
projectID string
sharedDriveMode string
sharedDedup *SharedDriveItemStore
} }
func NewDriveImporter(db *pgxpool.Pool, nc *nextcloud.Client) *DriveImporter { func NewDriveImporter(db *pgxpool.Pool, nc *nextcloud.Client) *DriveImporter {
@ -40,6 +43,31 @@ func (d *DriveImporter) WithHTTPClient(c *http.Client) *DriveImporter {
return d return d
} }
func (d *DriveImporter) WithProject(projectID, sharedDriveMode string, dedup *SharedDriveItemStore) *DriveImporter {
d.projectID = strings.TrimSpace(projectID)
d.sharedDriveMode = NormalizeSharedDriveMode(sharedDriveMode)
d.sharedDedup = dedup
return d
}
func (d *DriveImporter) isSharedDriveDedup(driveID string, shared bool) bool {
return shared && driveID != "" && d.sharedDedup != nil
}
func (d *DriveImporter) alreadyImportedShared(driveID, sourceID string, shared bool) bool {
if !d.isSharedDriveDedup(driveID, shared) {
return false
}
return d.sharedDedup.Has(driveID, sourceID)
}
func (d *DriveImporter) markSharedImported(ctx context.Context, driveID, sourceID, relPath, jobID string, shared bool) error {
if !d.isSharedDriveDedup(driveID, shared) {
return nil
}
return d.sharedDedup.MarkImported(ctx, driveID, sourceID, relPath, jobID)
}
func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, provider string, delta bool, update progressUpdater) error { func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, provider string, delta bool, update progressUpdater) error {
if d.nc == nil { if d.nc == nil {
return fmt.Errorf("nextcloud required for drive migration") return fmt.Errorf("nextcloud required for drive migration")
@ -59,6 +87,18 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
return err return err
} }
if provider == "google" && !jsonBool(job.CursorJSON["sharedDrivesBootstrapped"]) {
if err := d.bootstrapSharedDrives(ctx, job, accessToken); err != nil {
return err
}
job.CursorJSON["sharedDrivesBootstrapped"] = true
}
if provider == "google" {
if err := d.mergeSharedDriveFolders(ctx, job, provider); err != nil {
return err
}
}
if delta && d.hasDriveDeltaCursor(job, provider) { if delta && d.hasDriveDeltaCursor(job, provider) {
return d.importDriveDelta(ctx, job, accessToken, provider, ncUserID, root, store, update) return d.importDriveDelta(ctx, job, accessToken, provider, ncUserID, root, store, update)
} }
@ -99,6 +139,14 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
if alreadyImported(store, item.ID) { if alreadyImported(store, item.ID) {
continue continue
} }
if d.alreadyImportedShared(current.DriveID, item.ID, current.Shared) {
skipped++
if err := store.MarkSkipped(ctx, item.ID, "dedup: shared drive file already imported by project", relPathForItem(current, item)); err != nil {
return err
}
batch++
continue
}
relPath := path.Join(current.Path, sanitizeDrivePath(item.Name)) relPath := path.Join(current.Path, sanitizeDrivePath(item.Name))
targetPath := path.Join(root, relPath) targetPath := path.Join(root, relPath)
if item.IsFolder { if item.IsFolder {
@ -113,7 +161,9 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
if err := store.MarkPath(ctx, item.ID, relPath); err != nil { if err := store.MarkPath(ctx, item.ID, relPath); err != nil {
return err return err
} }
queue = enqueueDriveFolder(queue, driveFolderRef{ID: item.ID, Path: relPath}) queue = enqueueDriveFolder(queue, driveFolderRef{
ID: item.ID, Path: relPath, DriveID: current.DriveID, Shared: current.Shared,
})
} else { } else {
if item.Export { if item.Export {
content, contentType, fileName, err := d.downloadGoogleExport(ctx, accessToken, item) content, contentType, fileName, err := d.downloadGoogleExport(ctx, accessToken, item)
@ -127,7 +177,7 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
} }
targetPath = path.Join(path.Dir(targetPath), fileName) targetPath = path.Join(path.Dir(targetPath), fileName)
relPath = path.Join(path.Dir(relPath), fileName) relPath = path.Join(path.Dir(relPath), fileName)
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, 0); err != nil {
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
return markErr return markErr
} }
@ -151,15 +201,6 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
} }
} }
} else { } else {
if item.Size > maxDriveFileBytes {
skipped++
reason := fmt.Sprintf("file exceeds %d byte limit", maxDriveFileBytes)
if err := store.MarkSkipped(ctx, item.ID, reason, relPath); err != nil {
return err
}
batch++
continue
}
content, contentType, err := d.downloadDriveFile(ctx, accessToken, item) content, contentType, err := d.downloadDriveFile(ctx, accessToken, item)
if err != nil { if err != nil {
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
@ -169,7 +210,7 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
batch++ batch++
continue continue
} }
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil { if err := d.uploadToNextcloud(ctx, ncUserID, targetPath, content, contentType, item.Size); err != nil {
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil { if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
return markErr return markErr
} }
@ -186,6 +227,9 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
if err := store.MarkPath(ctx, item.ID, relPath); err != nil { if err := store.MarkPath(ctx, item.ID, relPath); err != nil {
return err return err
} }
if err := d.markSharedImported(ctx, current.DriveID, item.ID, relPath, job.ID, current.Shared); err != nil {
return err
}
} }
imported++ imported++
batch++ batch++
@ -193,7 +237,9 @@ func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken,
for _, sub := range subfolders { for _, sub := range subfolders {
relPath := path.Join(current.Path, sanitizeDrivePath(sub.Name)) relPath := path.Join(current.Path, sanitizeDrivePath(sub.Name))
queue = enqueueDriveFolder(queue, driveFolderRef{ID: sub.ID, Path: relPath}) queue = enqueueDriveFolder(queue, driveFolderRef{
ID: sub.ID, Path: relPath, DriveID: current.DriveID, Shared: current.Shared,
})
} }
writeDriveFolderQueue(job.CursorJSON, queue) writeDriveFolderQueue(job.CursorJSON, queue)
@ -234,6 +280,7 @@ type driveItem struct {
Export bool Export bool
ExportMime string ExportMime string
ExportExt string ExportExt string
DriveID string
} }
type driveSubfolder struct { type driveSubfolder struct {
@ -247,6 +294,7 @@ func (d *DriveImporter) listDriveFolderItems(ctx context.Context, accessToken, p
pageToken, _ := cursor["pageToken"].(string) pageToken, _ := cursor["pageToken"].(string)
q := url.QueryEscape("'" + folder.ID + "' in parents and trashed=false") q := url.QueryEscape("'" + folder.ID + "' in parents and trashed=false")
listURL := "https://www.googleapis.com/drive/v3/files?pageSize=100&fields=nextPageToken,files(id,name,mimeType,size)&q=" + q listURL := "https://www.googleapis.com/drive/v3/files?pageSize=100&fields=nextPageToken,files(id,name,mimeType,size)&q=" + q
listURL += googleDriveListParams(folder)
if pageToken != "" { if pageToken != "" {
listURL += "&pageToken=" + url.QueryEscape(pageToken) listURL += "&pageToken=" + url.QueryEscape(pageToken)
} }
@ -289,7 +337,7 @@ func (d *DriveImporter) listDriveFolderItems(ctx context.Context, accessToken, p
item.ExportExt = ext item.ExportExt = ext
item.Name = driveExportFileName(f.Name, ext) item.Name = driveExportFileName(f.Name, ext)
} else { } else {
item.Download = "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(f.ID) + "?alt=media" item.Download = googleDriveDownloadURL(f.ID, folder.Shared)
} }
out = append(out, item) out = append(out, item)
} }
@ -378,6 +426,9 @@ func (d *DriveImporter) downloadGoogleExport(ctx context.Context, accessToken st
url.PathEscape(item.ID), url.PathEscape(item.ID),
url.QueryEscape(item.ExportMime), url.QueryEscape(item.ExportMime),
) )
if item.DriveID != "" {
exportURL += "&supportsAllDrives=true"
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, exportURL, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, exportURL, nil)
if err != nil { if err != nil {
return nil, "", "", err return nil, "", "", err
@ -403,3 +454,95 @@ func sanitizeDrivePath(name string) string {
} }
return name return name
} }
func relPathForItem(folder driveFolderRef, item driveItem) string {
return path.Join(folder.Path, sanitizeDrivePath(item.Name))
}
func jsonBool(v any) bool {
switch t := v.(type) {
case bool:
return t
case float64:
return t != 0
case string:
return t == "true" || t == "1"
default:
return false
}
}
func googleDriveListParams(folder driveFolderRef) string {
if folder.Shared && folder.DriveID != "" {
return "&corpora=drive&driveId=" + url.QueryEscape(folder.DriveID) +
"&includeItemsFromAllDrives=true&supportsAllDrives=true"
}
return "&supportsAllDrives=true"
}
func googleDriveDownloadURL(fileID string, shared bool) string {
u := "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(fileID) + "?alt=media"
if shared {
u += "&supportsAllDrives=true"
}
return u
}
func (d *DriveImporter) uploadToNextcloud(ctx context.Context, ncUserID, targetPath string, content io.ReadCloser, contentType string, size int64) error {
defer content.Close()
if size > maxDriveFileBytes {
return d.nc.UploadStreaming(ctx, ncUserID, targetPath, content, contentType, size)
}
return d.nc.Upload(ctx, ncUserID, targetPath, content, contentType)
}
func (d *DriveImporter) bootstrapSharedDrives(ctx context.Context, job *Job, accessToken string) error {
pageToken := ""
for {
listURL := "https://www.googleapis.com/drive/v3/drives?pageSize=100&fields=nextPageToken,drives(id,name)"
if pageToken != "" {
listURL += "&pageToken=" + url.QueryEscape(pageToken)
}
body, err := apiGet(ctx, d.client, listURL, accessToken)
if err != nil {
return err
}
var parsed struct {
Drives []struct {
ID string `json:"id"`
Name string `json:"name"`
} `json:"drives"`
NextPageToken string `json:"nextPageToken"`
}
if err := json.Unmarshal(body, &parsed); err != nil {
return err
}
for _, drive := range parsed.Drives {
if err := d.upsertDiscoveredSharedDrive(ctx, job.ProjectID, job.UserID, drive.ID, drive.Name, d.sharedDriveMode); err != nil {
return err
}
}
if parsed.NextPageToken == "" {
break
}
pageToken = parsed.NextPageToken
}
return nil
}
func (d *DriveImporter) mergeSharedDriveFolders(ctx context.Context, job *Job, provider string) error {
if provider != "google" {
return nil
}
queue := readDriveFolderQueue(job.CursorJSON, provider)
sharedFolders, err := d.loadApprovedSharedDriveFolders(ctx, job.ProjectID)
if err != nil {
return err
}
for _, folder := range sharedFolders {
queue = enqueueDriveFolder(queue, folder)
}
writeDriveFolderQueue(job.CursorJSON, queue)
return nil
}

View File

@ -0,0 +1,237 @@
package migration
import (
"context"
"fmt"
"github.com/jackc/pgx/v5/pgxpool"
)
const (
SharedDriveModeAuto = "auto"
SharedDriveModeManual = "manual"
SharedDriveStatusPending = "pending"
SharedDriveStatusApproved = "approved"
SharedDriveStatusRejected = "rejected"
)
type SharedDrive struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
DriveID string `json:"drive_id"`
Name string `json:"name"`
Status string `json:"status"`
DiscoveredByUserID *string `json:"discovered_by_user_id,omitempty"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
func NormalizeSharedDriveMode(mode string) string {
switch mode {
case SharedDriveModeManual:
return SharedDriveModeManual
default:
return SharedDriveModeAuto
}
}
// SharedDriveItemStore tracks project-level imports for shared drive files (cross-user dedup).
type SharedDriveItemStore struct {
db *pgxpool.Pool
projectID string
done map[string]struct{} // key: driveID + ":" + sourceID
}
func NewSharedDriveItemStoreMemory() *SharedDriveItemStore {
return &SharedDriveItemStore{done: map[string]struct{}{}}
}
func LoadSharedDriveItemStore(ctx context.Context, db *pgxpool.Pool, projectID string) (*SharedDriveItemStore, error) {
store := &SharedDriveItemStore{
db: db,
projectID: projectID,
done: map[string]struct{}{},
}
if db == nil || projectID == "" {
return store, nil
}
rows, err := db.Query(ctx, `
SELECT drive_id, source_id
FROM migration_shared_drive_items
WHERE project_id = $1::uuid
`, projectID)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var driveID, sourceID string
if err := rows.Scan(&driveID, &sourceID); err != nil {
return nil, err
}
store.done[sharedDriveItemKey(driveID, sourceID)] = struct{}{}
}
return store, rows.Err()
}
func sharedDriveItemKey(driveID, sourceID string) string {
return driveID + ":" + sourceID
}
func (s *SharedDriveItemStore) Has(driveID, sourceID string) bool {
if s == nil || driveID == "" || sourceID == "" {
return false
}
_, ok := s.done[sharedDriveItemKey(driveID, sourceID)]
return ok
}
func (s *SharedDriveItemStore) MarkImported(ctx context.Context, driveID, sourceID, relPath, jobID string) error {
if driveID == "" || sourceID == "" {
return nil
}
s.done[sharedDriveItemKey(driveID, sourceID)] = struct{}{}
if s.db == nil || s.projectID == "" {
return nil
}
_, err := s.db.Exec(ctx, `
INSERT INTO migration_shared_drive_items (project_id, drive_id, source_id, rel_path, imported_by_job_id)
VALUES ($1::uuid, $2, $3, $4, NULLIF($5, '')::uuid)
ON CONFLICT (project_id, drive_id, source_id) DO NOTHING
`, s.projectID, driveID, sourceID, relPath, jobID)
return err
}
func (s *Service) UpdateSharedDriveMode(ctx context.Context, projectID, mode string) (Project, error) {
mode = NormalizeSharedDriveMode(mode)
sc := newProjectScanner()
err := s.db.QueryRow(ctx, `
UPDATE migration_projects
SET shared_drive_mode = $2, updated_at = NOW()
WHERE id = $1::uuid
RETURNING `+projectSelectSQL("")+`
`, projectID, mode).Scan(sc.targets()...)
return sc.result(), err
}
func (s *Service) ListSharedDrives(ctx context.Context, projectID, statusFilter string) ([]SharedDrive, error) {
query := `
SELECT id::text, project_id::text, drive_id, name, status,
NULLIF(discovered_by_user_id::text, ''), created_at::text, updated_at::text
FROM migration_shared_drives
WHERE project_id = $1::uuid
`
args := []any{projectID}
if statusFilter != "" {
query += ` AND status = $2`
args = append(args, statusFilter)
}
query += ` ORDER BY name ASC, created_at ASC`
rows, err := s.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
var out []SharedDrive
for rows.Next() {
var row SharedDrive
if err := rows.Scan(
&row.ID, &row.ProjectID, &row.DriveID, &row.Name, &row.Status,
&row.DiscoveredByUserID, &row.CreatedAt, &row.UpdatedAt,
); err != nil {
return nil, err
}
out = append(out, row)
}
return out, rows.Err()
}
func (s *Service) SetSharedDriveStatus(ctx context.Context, projectID, driveID, status string) (SharedDrive, error) {
var row SharedDrive
err := s.db.QueryRow(ctx, `
UPDATE migration_shared_drives
SET status = $3, updated_at = NOW()
WHERE project_id = $1::uuid AND drive_id = $2
RETURNING id::text, project_id::text, drive_id, name, status,
NULLIF(discovered_by_user_id::text, ''), created_at::text, updated_at::text
`, projectID, driveID, status).Scan(
&row.ID, &row.ProjectID, &row.DriveID, &row.Name, &row.Status,
&row.DiscoveredByUserID, &row.CreatedAt, &row.UpdatedAt,
)
if err != nil {
return SharedDrive{}, fmt.Errorf("shared drive not found")
}
return row, nil
}
func (s *Service) ApproveSharedDrive(ctx context.Context, projectID, driveID string) (SharedDrive, error) {
return s.SetSharedDriveStatus(ctx, projectID, driveID, SharedDriveStatusApproved)
}
func (s *Service) RejectSharedDrive(ctx context.Context, projectID, driveID string) (SharedDrive, error) {
return s.SetSharedDriveStatus(ctx, projectID, driveID, SharedDriveStatusRejected)
}
func (d *DriveImporter) upsertDiscoveredSharedDrive(ctx context.Context, projectID, userID, driveID, name, mode string) error {
if d.db == nil {
return nil
}
autoApprove := NormalizeSharedDriveMode(mode) == SharedDriveModeAuto
initialStatus := SharedDriveStatusPending
if autoApprove {
initialStatus = SharedDriveStatusApproved
}
_, err := d.db.Exec(ctx, `
INSERT INTO migration_shared_drives (project_id, drive_id, name, status, discovered_by_user_id)
VALUES ($1::uuid, $2, $3, $4, NULLIF($5, '')::uuid)
ON CONFLICT (project_id, drive_id) DO UPDATE
SET name = COALESCE(NULLIF(EXCLUDED.name, ''), migration_shared_drives.name),
status = CASE
WHEN migration_shared_drives.status = 'rejected' THEN 'rejected'
WHEN migration_shared_drives.status = 'approved' THEN 'approved'
WHEN $6 = 'auto' THEN 'approved'
ELSE migration_shared_drives.status
END,
updated_at = NOW()
`, projectID, driveID, name, initialStatus, userID, NormalizeSharedDriveMode(mode))
return err
}
func (d *DriveImporter) loadApprovedSharedDriveFolders(ctx context.Context, projectID string) ([]driveFolderRef, error) {
if d.db == nil {
return nil, nil
}
rows, err := d.db.Query(ctx, `
SELECT drive_id, name
FROM migration_shared_drives
WHERE project_id = $1::uuid AND status = 'approved'
ORDER BY name ASC
`, projectID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []driveFolderRef
for rows.Next() {
var id, name string
if err := rows.Scan(&id, &name); err != nil {
return nil, err
}
out = append(out, driveFolderRef{
ID: id,
Path: pathJoinSharedDrive(name),
DriveID: id,
Shared: true,
})
}
return out, rows.Err()
}
func pathJoinSharedDrive(name string) string {
return "Shared Drives/" + sanitizeDrivePath(name)
}

View File

@ -0,0 +1,116 @@
package migration
import (
"context"
"net/http"
"strings"
"testing"
)
func TestSharedDriveItemDedup(t *testing.T) {
store := NewSharedDriveItemStoreMemory()
ctx := context.Background()
if store.Has("drive-1", "file-1") {
t.Fatal("expected miss before mark")
}
if err := store.MarkImported(ctx, "drive-1", "file-1", "Shared Drives/Team/doc.pdf", "job-1"); err != nil {
t.Fatal(err)
}
if !store.Has("drive-1", "file-1") {
t.Fatal("expected hit after mark")
}
if store.Has("drive-1", "file-2") {
t.Fatal("different file should not match")
}
if store.Has("drive-2", "file-1") {
t.Fatal("different drive should not match")
}
}
func TestDriveImporterSharedDedup(t *testing.T) {
d := &DriveImporter{sharedDedup: NewSharedDriveItemStoreMemory()}
ctx := context.Background()
if err := d.sharedDedup.MarkImported(ctx, "sd-1", "f-1", "path", "job-a"); err != nil {
t.Fatal(err)
}
if !d.alreadyImportedShared("sd-1", "f-1", true) {
t.Fatal("expected shared dedup hit")
}
if d.alreadyImportedShared("sd-1", "f-2", true) {
t.Fatal("expected miss for other file")
}
if d.alreadyImportedShared("", "f-1", false) {
t.Fatal("personal drive should not dedup at project level")
}
}
func TestBootstrapSharedDrivesDiscovery(t *testing.T) {
client := mockGoogleHTTPClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/drive/v3/drives") {
_, _ = w.Write([]byte(`{"drives":[{"id":"sd-team","name":"Team Drive"}]}`))
return
}
http.NotFound(w, r)
})
d := NewDriveImporter(nil, nil).WithHTTPClient(client).WithProject("proj-1", SharedDriveModeAuto, NewSharedDriveItemStoreMemory())
job := &Job{ProjectID: "proj-1", UserID: "user-1", CursorJSON: map[string]any{}}
if err := d.bootstrapSharedDrives(context.Background(), job, "token"); err != nil {
t.Fatalf("bootstrap: %v", err)
}
}
func TestGoogleDriveListParams(t *testing.T) {
myDrive := googleDriveListParams(driveFolderRef{ID: "root"})
if !strings.Contains(myDrive, "supportsAllDrives=true") {
t.Fatalf("my drive params: %q", myDrive)
}
if strings.Contains(myDrive, "corpora=drive") {
t.Fatalf("my drive should not use corpora=drive: %q", myDrive)
}
shared := googleDriveListParams(driveFolderRef{ID: "sd-1", DriveID: "sd-1", Shared: true})
if !strings.Contains(shared, "corpora=drive") || !strings.Contains(shared, "driveId=sd-1") {
t.Fatalf("shared drive params: %q", shared)
}
}
func TestGoogleDriveDownloadURL(t *testing.T) {
personal := googleDriveDownloadURL("file-1", false)
if strings.Contains(personal, "supportsAllDrives") {
t.Fatalf("personal download: %q", personal)
}
shared := googleDriveDownloadURL("file-1", true)
if !strings.Contains(shared, "supportsAllDrives=true") {
t.Fatalf("shared download: %q", shared)
}
}
func TestMergeSharedDriveFolders(t *testing.T) {
d := NewDriveImporter(nil, nil)
job := &Job{CursorJSON: map[string]any{}}
if err := d.mergeSharedDriveFolders(context.Background(), job, "google"); err != nil {
t.Fatal(err)
}
queue := readDriveFolderQueue(job.CursorJSON, "google")
if len(queue) != 1 || queue[0].ID != "root" {
t.Fatalf("without db only root queue: %#v", queue)
}
manualQueue := readDriveFolderQueue(map[string]any{}, "google")
manualQueue = enqueueDriveFolder(manualQueue, driveFolderRef{
ID: "sd-1", Path: "Shared Drives/Finance", DriveID: "sd-1", Shared: true,
})
if len(manualQueue) != 2 || !manualQueue[1].Shared {
t.Fatalf("manual enqueue: %#v", manualQueue)
}
}
func TestNormalizeSharedDriveMode(t *testing.T) {
if got := NormalizeSharedDriveMode("manual"); got != SharedDriveModeManual {
t.Fatalf("got %q", got)
}
if got := NormalizeSharedDriveMode(""); got != SharedDriveModeAuto {
t.Fatalf("default got %q", got)
}
}

View File

@ -700,28 +700,4 @@ func truncateRunes(s string, n int) string {
return string(r[:n]) return string(r[:n])
} }
func LinkHostedMailboxByEmail(ctx context.Context, db *pgxpool.Pool, userID, email string) error {
email = strings.ToLower(strings.TrimSpace(email))
if email == "" {
return nil
}
_, err := db.Exec(ctx, `
UPDATE mailboxes SET user_id = $1::uuid, updated_at = NOW()
WHERE user_id IS NULL AND lower(local_part || '@' || (SELECT name FROM mail_domains d WHERE d.id = mailboxes.domain_id)) = $2
`, userID, email)
if err != nil {
return err
}
_, err = db.Exec(ctx, `
UPDATE mail_accounts ma SET user_id = $1::uuid, updated_at = NOW()
FROM mailboxes mb
JOIN mail_domains md ON md.id = mb.domain_id
WHERE mb.mail_account_id = ma.id
AND ma.user_id IS NULL
AND mb.user_id = $1::uuid
AND lower(mb.local_part || '@' || md.name) = $2
`, userID, email)
return err
}
var _ = pgx.ErrNoRows var _ = pgx.ErrNoRows

View File

@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"sort"
"strings" "strings"
"time" "time"
@ -134,6 +135,9 @@ func (g *GraphImporter) ImportBatch(
} }
if delta { if delta {
if len(graphFolderDeltaLinks(job.CursorJSON)) > 0 {
return g.importFolderDelta(ctx, job, accessToken, accountID, items, update)
}
deltaLink, _ := job.CursorJSON["deltaLink"].(string) deltaLink, _ := job.CursorJSON["deltaLink"].(string)
if deltaLink != "" { if deltaLink != "" {
more, err := g.importDeltaPage(ctx, job, accessToken, accountID, deltaLink, items) more, err := g.importDeltaPage(ctx, job, accessToken, accountID, deltaLink, items)
@ -147,12 +151,36 @@ func (g *GraphImporter) ImportBatch(
} }
} }
return g.importFullFolders(ctx, job, accessToken, accountID, items, delta, update)
}
func (g *GraphImporter) importFullFolders(
ctx context.Context,
job *Job,
accessToken, accountID string,
items *ImportedItemStore,
captureDelta bool,
update func(status string, cursor, stats map[string]any, jobErr string) error,
) error {
queue := g.folderQueue(job.CursorJSON)
folderIndex := int(jsonNumber(job.CursorJSON["folderIndex"]))
if folderIndex >= len(queue) {
if captureDelta {
if err := g.bootstrapFolderDeltaLinks(ctx, accessToken, queue, job.CursorJSON); err != nil {
return err
}
}
job.StatsJSON["phase"] = "imported"
return update("completed", job.CursorJSON, job.StatsJSON, "")
}
folderID := queue[folderIndex]
nextLink, _ := job.CursorJSON["nextLink"].(string) nextLink, _ := job.CursorJSON["nextLink"].(string)
var listURL string var listURL string
if nextLink != "" { if nextLink != "" {
listURL = nextLink listURL = nextLink
} else { } else {
listURL = g.graphURL(g.userBase()+"/messages?$top=100&$orderby="+url.QueryEscape("receivedDateTime desc")+"&$select="+graphMessageSelect) listURL = g.folderMessagesURL(folderID)
} }
body, err := g.apiGet(ctx, listURL, accessToken) body, err := g.apiGet(ctx, listURL, accessToken)
@ -210,17 +238,111 @@ func (g *GraphImporter) ImportBatch(
} }
delete(job.CursorJSON, "nextLink") delete(job.CursorJSON, "nextLink")
if delta { job.CursorJSON["folderIndex"] = float64(folderIndex + 1)
if listed.DeltaLink != "" { return update("pending", job.CursorJSON, job.StatsJSON, "")
job.CursorJSON["deltaLink"] = listed.DeltaLink
} else if link, err := g.initDeltaLink(ctx, accessToken); err == nil && link != "" {
job.CursorJSON["deltaLink"] = link
} }
}
job.StatsJSON["phase"] = "imported" func (g *GraphImporter) importFolderDelta(
ctx context.Context,
job *Job,
accessToken, accountID string,
items *ImportedItemStore,
update func(status string, cursor, stats map[string]any, jobErr string) error,
) error {
queue := g.folderQueue(job.CursorJSON)
folderIndex := int(jsonNumber(job.CursorJSON["folderIndex"]))
if folderIndex >= len(queue) {
job.StatsJSON["phase"] = "delta"
return update("completed", job.CursorJSON, job.StatsJSON, "") return update("completed", job.CursorJSON, job.StatsJSON, "")
} }
folderID := queue[folderIndex]
deltaLinks := graphFolderDeltaLinks(job.CursorJSON)
deltaLink := deltaLinks[folderID]
if deltaLink == "" {
deltaLink, _ = job.CursorJSON["nextLink"].(string)
}
if deltaLink == "" {
link, err := g.initFolderDeltaLink(ctx, accessToken, folderID)
if err != nil {
return err
}
deltaLink = link
}
more, err := g.importFolderDeltaPage(ctx, job, accessToken, accountID, folderID, deltaLink, items)
if err != nil {
return err
}
if more {
return update("pending", job.CursorJSON, job.StatsJSON, "")
}
delete(job.CursorJSON, "nextLink")
job.CursorJSON["folderIndex"] = float64(folderIndex + 1)
return update("pending", job.CursorJSON, job.StatsJSON, "")
}
func (g *GraphImporter) importFolderDeltaPage(
ctx context.Context,
job *Job,
accessToken, accountID, folderID, deltaLink string,
items *ImportedItemStore,
) (more bool, err error) {
body, err := g.apiGet(ctx, deltaLink, accessToken)
if err != nil {
return false, err
}
var parsed struct {
Value []graphMessage `json:"value"`
NextLink string `json:"@odata.nextLink"`
DeltaLink string `json:"@odata.deltaLink"`
}
if err := json.Unmarshal(body, &parsed); err != nil {
return false, err
}
deltaCount, _ := job.StatsJSON["delta_imported"].(float64)
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
for _, msg := range parsed.Value {
if msg.Removed != nil {
if err := g.deleteByGraphID(ctx, accountID, msg.ID); err != nil {
return false, err
}
deleted++
continue
}
if alreadyImported(items, msg.ID) {
continue
}
ok, err := g.importOne(ctx, accountID, msg)
if err != nil {
if markErr := items.MarkFailed(ctx, msg.ID, err.Error(), ""); markErr != nil {
return false, markErr
}
incJobStat(job.StatsJSON, "failed")
continue
}
if err := items.MarkImported(ctx, msg.ID); err != nil {
return false, err
}
if ok {
deltaCount++
}
}
job.StatsJSON["delta_imported"] = deltaCount
job.StatsJSON["delta_deleted"] = deleted
if parsed.NextLink != "" {
setGraphFolderDeltaLink(job.CursorJSON, folderID, parsed.NextLink)
job.StatsJSON["phase"] = "delta"
return true, nil
}
if parsed.DeltaLink != "" {
setGraphFolderDeltaLink(job.CursorJSON, folderID, parsed.DeltaLink)
}
job.StatsJSON["phase"] = "delta"
return false, nil
}
func (g *GraphImporter) importDeltaPage(ctx context.Context, job *Job, accessToken, accountID, deltaLink string, items *ImportedItemStore) (more bool, err error) { func (g *GraphImporter) importDeltaPage(ctx context.Context, job *Job, accessToken, accountID, deltaLink string, items *ImportedItemStore) (more bool, err error) {
body, err := g.apiGet(ctx, deltaLink, accessToken) body, err := g.apiGet(ctx, deltaLink, accessToken)
if err != nil { if err != nil {
@ -276,8 +398,16 @@ func (g *GraphImporter) importDeltaPage(ctx context.Context, job *Job, accessTok
return false, nil return false, nil
} }
func (g *GraphImporter) initDeltaLink(ctx context.Context, accessToken string) (string, error) { func (g *GraphImporter) folderMessagesURL(folderID string) string {
body, err := g.apiGet(ctx, g.graphURL(g.userBase()+"/messages/delta?$select=id"), accessToken) path := g.userBase() + "/mailFolders/" + url.PathEscape(folderID) + "/messages" +
"?$top=100&$orderby=" + url.QueryEscape("receivedDateTime desc") +
"&$select=" + graphMessageSelect
return g.graphURL(path)
}
func (g *GraphImporter) initFolderDeltaLink(ctx context.Context, accessToken, folderID string) (string, error) {
path := g.userBase() + "/mailFolders/" + url.PathEscape(folderID) + "/messages/delta?$select=id"
body, err := g.apiGet(ctx, g.graphURL(path), accessToken)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -294,6 +424,37 @@ func (g *GraphImporter) initDeltaLink(ctx context.Context, accessToken string) (
return parsed.NextLink, nil return parsed.NextLink, nil
} }
func (g *GraphImporter) bootstrapFolderDeltaLinks(ctx context.Context, accessToken string, queue []string, cursor map[string]any) error {
for _, folderID := range queue {
if graphFolderDeltaLinks(cursor)[folderID] != "" {
continue
}
link, err := g.initFolderDeltaLink(ctx, accessToken, folderID)
if err != nil {
return err
}
if link != "" {
setGraphFolderDeltaLink(cursor, folderID, link)
}
}
delete(cursor, "deltaLink")
delete(cursor, "folderIndex")
return nil
}
func (g *GraphImporter) folderQueue(cursor map[string]any) []string {
if queue := readGraphFolderQueue(cursor); len(queue) > 0 {
return queue
}
ids := make([]string, 0, len(g.folders))
for id := range g.folders {
ids = append(ids, id)
}
sort.Strings(ids)
writeGraphFolderQueue(cursor, ids)
return ids
}
func (g *GraphImporter) importOne(ctx context.Context, accountID string, msg graphMessage) (bool, error) { func (g *GraphImporter) importOne(ctx context.Context, accountID string, msg graphMessage) (bool, error) {
meta := g.folders[msg.ParentFolderID] meta := g.folders[msg.ParentFolderID]
if meta.RemoteName == "" { if meta.RemoteName == "" {
@ -391,7 +552,9 @@ func (g *GraphImporter) ensureGraphFolders(ctx context.Context, accessToken stri
if len(g.folders) > 0 { if len(g.folders) > 0 {
return nil return nil
} }
body, err := g.apiGet(ctx, g.graphURL(g.userBase()+"/mailFolders?$top=100&$select=id,displayName,wellKnownName"), accessToken) listURL := g.graphURL(g.userBase() + "/mailFolders?$top=100&$select=id,displayName,wellKnownName")
for listURL != "" {
body, err := g.apiGet(ctx, listURL, accessToken)
if err != nil { if err != nil {
return err return err
} }
@ -401,6 +564,7 @@ func (g *GraphImporter) ensureGraphFolders(ctx context.Context, accessToken stri
DisplayName string `json:"displayName"` DisplayName string `json:"displayName"`
WellKnownName string `json:"wellKnownName"` WellKnownName string `json:"wellKnownName"`
} `json:"value"` } `json:"value"`
NextLink string `json:"@odata.nextLink"`
} }
if err := json.Unmarshal(body, &parsed); err != nil { if err := json.Unmarshal(body, &parsed); err != nil {
return err return err
@ -409,6 +573,8 @@ func (g *GraphImporter) ensureGraphFolders(ctx context.Context, accessToken stri
remote, ftype := graphWellKnownFolder(f.WellKnownName, f.DisplayName) remote, ftype := graphWellKnownFolder(f.WellKnownName, f.DisplayName)
g.folders[f.ID] = graphFolderMeta{RemoteName: remote, FolderType: ftype} g.folders[f.ID] = graphFolderMeta{RemoteName: remote, FolderType: ftype}
} }
listURL = parsed.NextLink
}
return nil return nil
} }

View File

@ -1,6 +1,11 @@
package migration package migration
import "testing" import (
"context"
"net/http"
"strings"
"testing"
)
func TestGraphWellKnownFolder(t *testing.T) { func TestGraphWellKnownFolder(t *testing.T) {
remote, ftype := graphWellKnownFolder("inbox", "Inbox") remote, ftype := graphWellKnownFolder("inbox", "Inbox")
@ -47,3 +52,92 @@ func TestRemoteMessageUIDMatchesGmailUID(t *testing.T) {
t.Fatal("uid helpers diverged") t.Fatal("uid helpers diverged")
} }
} }
func TestGraphFolderQueueSortedAndCached(t *testing.T) {
g := NewGraphImporter(nil)
g.folders = map[string]graphFolderMeta{
"sent-folder": {RemoteName: "SENT", FolderType: "sent"},
"inbox-folder": {RemoteName: "INBOX", FolderType: "inbox"},
}
cursor := map[string]any{}
queue := g.folderQueue(cursor)
if len(queue) != 2 {
t.Fatalf("queue len = %d", len(queue))
}
if queue[0] != "inbox-folder" || queue[1] != "sent-folder" {
t.Fatalf("queue order = %v", queue)
}
cached := readGraphFolderQueue(cursor)
if len(cached) != 2 || cached[0] != "inbox-folder" {
t.Fatalf("cached queue = %v", cached)
}
}
func TestGraphFolderMessagesURLUsesMailFoldersPath(t *testing.T) {
g := NewGraphImporter(nil).WithBaseURL("https://graph.test")
listURL := g.folderMessagesURL("folder-abc")
if !strings.Contains(listURL, "/mailFolders/folder-abc/messages") {
t.Fatalf("url = %q", listURL)
}
if strings.Contains(listURL, "/me/messages") {
t.Fatalf("flat messages path should not be used: %q", listURL)
}
}
func TestGraphEnsureFoldersPaginates(t *testing.T) {
pages := 0
client := mockGraphHTTPClient(t, func(w http.ResponseWriter, r *http.Request) {
if !strings.HasSuffix(r.URL.Path, "/mailFolders") {
http.NotFound(w, r)
return
}
pages++
if pages == 1 {
_, _ = w.Write([]byte(`{
"value":[{"id":"inbox-id","displayName":"Inbox","wellKnownName":"inbox"}],
"@odata.nextLink":"https://graph.microsoft.com/v1.0/me/mailFolders?$top=100&$skip=100"
}`))
return
}
_, _ = w.Write([]byte(`{"value":[{"id":"sent-id","displayName":"Sent","wellKnownName":"sentitems"}]}`))
})
g := NewGraphImporter(nil).WithHTTPClient(client)
if err := g.ensureGraphFolders(context.Background(), "token"); err != nil {
t.Fatalf("ensure folders: %v", err)
}
if pages != 2 {
t.Fatalf("pages = %d, want 2", pages)
}
if len(g.folders) != 2 {
t.Fatalf("folders = %d", len(g.folders))
}
}
func TestGraphInitFolderDeltaLink(t *testing.T) {
client := mockGraphHTTPClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/mailFolders/inbox-id/messages/delta") {
_, _ = w.Write([]byte(`{"@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/mailFolders/inbox-id/messages/delta?token=done"}`))
return
}
http.NotFound(w, r)
})
g := NewGraphImporter(nil).WithHTTPClient(client)
link, err := g.initFolderDeltaLink(context.Background(), "token", "inbox-id")
if err != nil {
t.Fatalf("init delta: %v", err)
}
if !strings.Contains(link, "/mailFolders/inbox-id/messages/delta") {
t.Fatalf("delta link = %q", link)
}
}
func TestGraphFolderDeltaLinkHelpers(t *testing.T) {
cursor := map[string]any{}
setGraphFolderDeltaLink(cursor, "inbox-id", "https://delta/inbox")
links := graphFolderDeltaLinks(cursor)
if links["inbox-id"] != "https://delta/inbox" {
t.Fatalf("links = %v", links)
}
}

View File

@ -107,6 +107,50 @@ func setCalendarDeltaLink(cursor map[string]any, calID, link string) {
raw[calID] = link raw[calID] = link
} }
func graphFolderDeltaLinks(cursor map[string]any) map[string]string {
raw, _ := cursor["folderDeltaLinks"].(map[string]any)
out := make(map[string]string, len(raw))
for k, v := range raw {
if s, ok := v.(string); ok && s != "" {
out[k] = s
}
}
return out
}
func setGraphFolderDeltaLink(cursor map[string]any, folderID, link string) {
if folderID == "" || link == "" {
return
}
raw, _ := cursor["folderDeltaLinks"].(map[string]any)
if raw == nil {
raw = map[string]any{}
cursor["folderDeltaLinks"] = raw
}
raw[folderID] = link
}
func readGraphFolderQueue(cursor map[string]any) []string {
raw, _ := cursor["graphFolderQueue"].([]any)
out := make([]string, 0, len(raw))
for _, v := range raw {
if s, ok := v.(string); ok && s != "" {
out = append(out, s)
}
}
return out
}
func writeGraphFolderQueue(cursor map[string]any, ids []string) {
queue := make([]any, 0, len(ids))
for _, id := range ids {
if id != "" {
queue = append(queue, id)
}
}
cursor["graphFolderQueue"] = queue
}
func migrationContactPath(bookPath, provider, sourceID string) string { func migrationContactPath(bookPath, provider, sourceID string) string {
uid := sanitizeMigrationUID(provider, sourceID) uid := sanitizeMigrationUID(provider, sourceID)
return bookPath + uid + ".vcf" return bookPath + uid + ".vcf"

View File

@ -0,0 +1,97 @@
package migration
import (
"context"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
)
// HasPendingMigrationInvite reports whether an unclaimed invite exists for the email.
func HasPendingMigrationInvite(ctx context.Context, db *pgxpool.Pool, email string) (bool, error) {
if db == nil {
return false, nil
}
email = normalizeInviteEmail(email)
if email == "" {
return false, nil
}
var exists bool
err := db.QueryRow(ctx, `
SELECT EXISTS(
SELECT 1 FROM migration_invites
WHERE status = 'invited' AND lower(email) = lower($1)
)
`, email).Scan(&exists)
return exists, err
}
// ProvisionAudit counts user-linked provision artifacts for test verification.
type ProvisionAudit struct {
Users int
Mailboxes int
MailAccounts int
NCPrincipals int
}
// AuditProvisionByEmail counts rows tied to an email across users, mailboxes, mail accounts, and Nextcloud credentials.
func AuditProvisionByEmail(ctx context.Context, db *pgxpool.Pool, email string) (ProvisionAudit, error) {
var audit ProvisionAudit
if db == nil {
return audit, nil
}
email = strings.ToLower(strings.TrimSpace(email))
if email == "" {
return audit, nil
}
if err := db.QueryRow(ctx, `
SELECT COUNT(*) FROM users WHERE lower(email) = $1
`, email).Scan(&audit.Users); err != nil {
return audit, err
}
if err := db.QueryRow(ctx, `
SELECT COUNT(*)
FROM mailboxes mb
JOIN mail_domains md ON md.id = mb.domain_id
WHERE lower(mb.local_part || '@' || md.name) = $1
`, email).Scan(&audit.Mailboxes); err != nil {
return audit, err
}
if err := db.QueryRow(ctx, `
SELECT COUNT(*) FROM mail_accounts WHERE lower(email) = $1
`, email).Scan(&audit.MailAccounts); err != nil {
return audit, err
}
if err := db.QueryRow(ctx, `
SELECT COUNT(*) FROM nextcloud_dav_credentials WHERE nc_user_id = $1
`, email).Scan(&audit.NCPrincipals); err != nil {
return audit, err
}
return audit, nil
}
// LinkHostedMailboxByEmail attaches orphan mailboxes/mail_accounts (e.g. from claim-before-enroll) to a user.
func LinkHostedMailboxByEmail(ctx context.Context, db *pgxpool.Pool, userID, email string) error {
email = strings.ToLower(strings.TrimSpace(email))
if email == "" {
return nil
}
_, err := db.Exec(ctx, `
UPDATE mailboxes SET user_id = $1::uuid, updated_at = NOW()
WHERE user_id IS NULL AND lower(local_part || '@' || (SELECT name FROM mail_domains d WHERE d.id = mailboxes.domain_id)) = $2
`, userID, email)
if err != nil {
return err
}
_, err = db.Exec(ctx, `
UPDATE mail_accounts ma SET user_id = $1::uuid, updated_at = NOW()
FROM mailboxes mb
JOIN mail_domains md ON md.id = mb.domain_id
WHERE mb.mail_account_id = ma.id
AND ma.user_id IS NULL
AND mb.user_id = $1::uuid
AND lower(mb.local_part || '@' || md.name) = $2
`, userID, email)
return err
}

View File

@ -0,0 +1,35 @@
package migration
import (
"testing"
)
func TestHasPendingMigrationInviteNilDB(t *testing.T) {
ok, err := HasPendingMigrationInvite(t.Context(), nil, "user@example.com")
if err != nil {
t.Fatalf("HasPendingMigrationInvite() error = %v", err)
}
if ok {
t.Fatal("expected false with nil db")
}
}
func TestAuditProvisionByEmailEmpty(t *testing.T) {
audit, err := AuditProvisionByEmail(t.Context(), nil, "")
if err != nil {
t.Fatalf("AuditProvisionByEmail() error = %v", err)
}
if audit.Users != 0 || audit.Mailboxes != 0 || audit.MailAccounts != 0 || audit.NCPrincipals != 0 {
t.Fatalf("expected zero audit, got %#v", audit)
}
}
func TestAuditProvisionByEmailNormalizesEmail(t *testing.T) {
audit, err := AuditProvisionByEmail(t.Context(), nil, " User@Example.COM ")
if err != nil {
t.Fatalf("AuditProvisionByEmail() error = %v", err)
}
if audit.Users != 0 {
t.Fatalf("expected zero users with nil db, got %d", audit.Users)
}
}

View File

@ -0,0 +1,319 @@
package migration
import (
"context"
"encoding/csv"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
const jobAuditExportFlushEvery = 100
var jobAuditCSVHeaders = []string{"item_id", "rel_path", "status", "error", "service", "timestamp"}
var projectAuditCSVHeaders = []string{"job_id", "item_id", "rel_path", "status", "error", "service", "timestamp"}
// JobAuditExportMeta describes a migration job audit export download.
type JobAuditExportMeta struct {
ContentType string
FileName string
}
// JobAuditExportRow is one exported audit line.
type JobAuditExportRow struct {
JobID string `json:"job_id,omitempty"`
ItemID string `json:"item_id"`
RelPath string `json:"rel_path,omitempty"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
Service string `json:"service"`
Timestamp string `json:"timestamp"`
}
// PrepareJobAuditExport validates the job belongs to the project and returns download metadata.
func (s *Service) PrepareJobAuditExport(ctx context.Context, projectID, jobID, format string) (JobAuditExportMeta, error) {
if _, err := s.verifyJobInProject(ctx, projectID, jobID); err != nil {
return JobAuditExportMeta{}, err
}
return jobAuditExportMeta(format, jobID, false), nil
}
// PrepareProjectAuditExport validates the project and returns download metadata.
func (s *Service) PrepareProjectAuditExport(ctx context.Context, projectID, format string) (JobAuditExportMeta, error) {
if err := s.verifyProjectExists(ctx, projectID); err != nil {
return JobAuditExportMeta{}, err
}
return jobAuditExportMeta(format, projectID, true), nil
}
// WriteJobAuditExport streams audit rows for one job to w. Call PrepareJobAuditExport first.
func (s *Service) WriteJobAuditExport(
ctx context.Context,
projectID, jobID, statusFilter, format string,
w io.Writer,
) error {
service, err := s.verifyJobInProject(ctx, projectID, jobID)
if err != nil {
return err
}
statusFilter = normalizeAuditStatusFilter(statusFilter)
listSQL := `
SELECT source_id, rel_path, status, reason, imported_at::text
FROM migration_imported_items
WHERE job_id = $1::uuid
`
listArgs := []any{jobID}
if statusFilter != "" {
listSQL += ` AND status = $2`
listArgs = append(listArgs, statusFilter)
}
listSQL += ` ORDER BY imported_at DESC, source_id ASC`
rows, err := s.db.Query(ctx, listSQL, listArgs...)
if err != nil {
return err
}
defer rows.Close()
if err := streamJobAuditRows(format, w, service, jobID, rows); err != nil {
return err
}
return rows.Err()
}
// WriteProjectAuditExport streams audit rows for all jobs in a project to w.
func (s *Service) WriteProjectAuditExport(
ctx context.Context,
projectID, statusFilter, format string,
w io.Writer,
) error {
if err := s.verifyProjectExists(ctx, projectID); err != nil {
return err
}
statusFilter = normalizeAuditStatusFilter(statusFilter)
listSQL := `
SELECT j.id::text, j.service, i.source_id, i.rel_path, i.status, i.reason, i.imported_at::text
FROM migration_imported_items i
JOIN migration_jobs j ON j.id = i.job_id
WHERE j.project_id = $1::uuid
`
listArgs := []any{projectID}
if statusFilter != "" {
listSQL += ` AND i.status = $2`
listArgs = append(listArgs, statusFilter)
}
listSQL += ` ORDER BY i.imported_at DESC, i.source_id ASC`
rows, err := s.db.Query(ctx, listSQL, listArgs...)
if err != nil {
return err
}
defer rows.Close()
if err := streamProjectAuditRows(format, w, rows); err != nil {
return err
}
return rows.Err()
}
func (s *Service) verifyProjectExists(ctx context.Context, projectID string) error {
var exists bool
err := s.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM migration_projects WHERE id = $1::uuid)`, projectID).Scan(&exists)
if err != nil {
return err
}
if !exists {
return fmt.Errorf("project not found")
}
return nil
}
func jobAuditExportMeta(format, id string, projectLevel bool) JobAuditExportMeta {
now := time.Now().UTC().Format("20060102T150405Z")
shortID := id
if len(shortID) > 8 {
shortID = shortID[:8]
}
prefix := "migration-job-audit"
if projectLevel {
prefix = "migration-project-audit"
}
ext := "ndjson"
contentType := "application/x-ndjson; charset=utf-8"
if format == "csv" {
ext = "csv"
contentType = "text/csv; charset=utf-8"
}
return JobAuditExportMeta{
ContentType: contentType,
FileName: fmt.Sprintf("%s-%s-%s.%s", prefix, shortID, now, ext),
}
}
type jobAuditRowScanner interface {
Next() bool
Scan(dest ...any) error
Err() error
}
func streamJobAuditRows(format string, w io.Writer, service, jobID string, rows jobAuditRowScanner) error {
flusher, _ := w.(http.Flusher)
switch format {
case "csv":
cw := csv.NewWriter(w)
if err := cw.Write(jobAuditCSVHeaders); err != nil {
return err
}
count := 0
for rows.Next() {
var itemID, relPath, status, reason, importedAt string
if err := rows.Scan(&itemID, &relPath, &status, &reason, &importedAt); err != nil {
return err
}
if err := writeJobAuditCSVRow(cw, JobAuditExportRow{
JobID: jobID,
ItemID: itemID,
RelPath: relPath,
Status: status,
Error: reason,
Service: service,
Timestamp: importedAt,
}); err != nil {
return err
}
count++
if count%jobAuditExportFlushEvery == 0 {
cw.Flush()
if err := cw.Error(); err != nil {
return err
}
if flusher != nil {
flusher.Flush()
}
}
}
cw.Flush()
return cw.Error()
default:
enc := json.NewEncoder(w)
count := 0
for rows.Next() {
var itemID, relPath, status, reason, importedAt string
if err := rows.Scan(&itemID, &relPath, &status, &reason, &importedAt); err != nil {
return err
}
if err := enc.Encode(JobAuditExportRow{
JobID: jobID,
ItemID: itemID,
RelPath: relPath,
Status: status,
Error: reason,
Service: service,
Timestamp: importedAt,
}); err != nil {
return err
}
count++
if count%jobAuditExportFlushEvery == 0 && flusher != nil {
flusher.Flush()
}
}
return nil
}
}
func streamProjectAuditRows(format string, w io.Writer, rows jobAuditRowScanner) error {
flusher, _ := w.(http.Flusher)
switch format {
case "csv":
cw := csv.NewWriter(w)
if err := cw.Write(projectAuditCSVHeaders); err != nil {
return err
}
count := 0
for rows.Next() {
var jobID, service, itemID, relPath, status, reason, importedAt string
if err := rows.Scan(&jobID, &service, &itemID, &relPath, &status, &reason, &importedAt); err != nil {
return err
}
if err := writeProjectAuditCSVRow(cw, JobAuditExportRow{
JobID: jobID,
ItemID: itemID,
RelPath: relPath,
Status: status,
Error: reason,
Service: service,
Timestamp: importedAt,
}); err != nil {
return err
}
count++
if count%jobAuditExportFlushEvery == 0 {
cw.Flush()
if err := cw.Error(); err != nil {
return err
}
if flusher != nil {
flusher.Flush()
}
}
}
cw.Flush()
return cw.Error()
default:
enc := json.NewEncoder(w)
count := 0
for rows.Next() {
var jobID, service, itemID, relPath, status, reason, importedAt string
if err := rows.Scan(&jobID, &service, &itemID, &relPath, &status, &reason, &importedAt); err != nil {
return err
}
if err := enc.Encode(JobAuditExportRow{
JobID: jobID,
ItemID: itemID,
RelPath: relPath,
Status: status,
Error: reason,
Service: service,
Timestamp: importedAt,
}); err != nil {
return err
}
count++
if count%jobAuditExportFlushEvery == 0 && flusher != nil {
flusher.Flush()
}
}
return nil
}
}
func writeJobAuditCSVRow(w *csv.Writer, row JobAuditExportRow) error {
return w.Write([]string{
row.ItemID,
row.RelPath,
row.Status,
row.Error,
row.Service,
row.Timestamp,
})
}
func writeProjectAuditCSVRow(w *csv.Writer, row JobAuditExportRow) error {
return w.Write([]string{
row.JobID,
row.ItemID,
row.RelPath,
row.Status,
row.Error,
row.Service,
row.Timestamp,
})
}

View File

@ -0,0 +1,115 @@
package migration
import (
"bytes"
"encoding/csv"
"encoding/json"
"strings"
"testing"
)
func TestJobAuditExportCSVFormat(t *testing.T) {
var buf bytes.Buffer
cw := csv.NewWriter(&buf)
if err := cw.Write(jobAuditCSVHeaders); err != nil {
t.Fatal(err)
}
row := JobAuditExportRow{
ItemID: "msg-fail",
RelPath: "Inbox/foo.eml",
Status: ItemStatusFailed,
Error: "upload timeout",
Service: "mail",
Timestamp: "2026-06-13T12:00:00Z",
}
if err := writeJobAuditCSVRow(cw, row); err != nil {
t.Fatal(err)
}
cw.Flush()
if err := cw.Error(); err != nil {
t.Fatal(err)
}
reader := csv.NewReader(strings.NewReader(buf.String()))
records, err := reader.ReadAll()
if err != nil {
t.Fatal(err)
}
if len(records) != 2 {
t.Fatalf("records = %d, want 2", len(records))
}
if got := strings.Join(records[0], ","); got != "item_id,rel_path,status,error,service,timestamp" {
t.Fatalf("headers = %q", got)
}
if records[1][0] != "msg-fail" || records[1][2] != ItemStatusFailed || records[1][3] != "upload timeout" {
t.Fatalf("row = %#v", records[1])
}
}
func TestProjectAuditExportCSVFormat(t *testing.T) {
var buf bytes.Buffer
cw := csv.NewWriter(&buf)
if err := cw.Write(projectAuditCSVHeaders); err != nil {
t.Fatal(err)
}
if err := writeProjectAuditCSVRow(cw, JobAuditExportRow{
JobID: "job-1",
ItemID: "file-1",
Status: ItemStatusImported,
Service: "drive",
Timestamp: "2026-06-13T12:00:00Z",
}); err != nil {
t.Fatal(err)
}
cw.Flush()
reader := csv.NewReader(strings.NewReader(buf.String()))
records, err := reader.ReadAll()
if err != nil {
t.Fatal(err)
}
if records[0][0] != "job_id" || records[1][0] != "job-1" {
t.Fatalf("records = %#v", records)
}
}
func TestJobAuditExportNDJSONFormat(t *testing.T) {
var buf bytes.Buffer
enc := json.NewEncoder(&buf)
if err := enc.Encode(JobAuditExportRow{
ItemID: "msg-skip",
Status: ItemStatusSkipped,
Error: "file too large",
Service: "mail",
Timestamp: "2026-06-13T12:00:00Z",
}); err != nil {
t.Fatal(err)
}
line := strings.TrimSpace(buf.String())
var decoded JobAuditExportRow
if err := json.Unmarshal([]byte(line), &decoded); err != nil {
t.Fatal(err)
}
if decoded.ItemID != "msg-skip" || decoded.Status != ItemStatusSkipped || decoded.Error != "file too large" {
t.Fatalf("decoded = %#v", decoded)
}
}
func TestJobAuditExportMeta(t *testing.T) {
csvMeta := jobAuditExportMeta("csv", "01234567-abcd-efgh", false)
if csvMeta.ContentType != "text/csv; charset=utf-8" {
t.Fatalf("csv content type = %q", csvMeta.ContentType)
}
if !strings.HasSuffix(csvMeta.FileName, ".csv") {
t.Fatalf("csv filename = %q", csvMeta.FileName)
}
ndMeta := jobAuditExportMeta("ndjson", "01234567-abcd-efgh", true)
if ndMeta.ContentType != "application/x-ndjson; charset=utf-8" {
t.Fatalf("ndjson content type = %q", ndMeta.ContentType)
}
if !strings.HasPrefix(ndMeta.FileName, "migration-project-audit-") {
t.Fatalf("project filename = %q", ndMeta.FileName)
}
}

View File

@ -7,29 +7,26 @@ import (
"github.com/ultisuite/ulti-backend/internal/mail/hosted" "github.com/ultisuite/ulti-backend/internal/mail/hosted"
) )
var projectSelectColumns = []string{
"id::text",
"COALESCE(domain_id::text, '')",
"name",
"source_provider",
"auth_mode",
"status",
"cutover_at::text",
"delta_mode",
"created_at::text",
"NULLIF(microsoft_tenant_id, '')",
"microsoft_admin_consent_at::text",
"COALESCE(NULLIF(microsoft_admin_consent_error, ''), '')",
"cutover_dns_json",
}
func projectSelectSQL(tablePrefix string) string { func projectSelectSQL(tablePrefix string) string {
if tablePrefix != "" && !strings.HasSuffix(tablePrefix, ".") { if tablePrefix != "" && !strings.HasSuffix(tablePrefix, ".") {
tablePrefix += "." tablePrefix += "."
} }
cols := make([]string, len(projectSelectColumns)) p := tablePrefix
for i, col := range projectSelectColumns { cols := []string{
cols[i] = tablePrefix + col p + "id::text",
"COALESCE(" + p + "domain_id::text, '')",
p + "name",
p + "source_provider",
p + "auth_mode",
p + "status",
p + "cutover_at::text",
p + "delta_mode",
p + "shared_drive_mode",
p + "created_at::text",
"COALESCE(NULLIF(" + p + "microsoft_tenant_id, ''), '')",
p + "microsoft_admin_consent_at::text",
"COALESCE(NULLIF(" + p + "microsoft_admin_consent_error, ''), '')",
p + "cutover_dns_json",
} }
return strings.Join(cols, ", ") return strings.Join(cols, ", ")
} }
@ -47,7 +44,7 @@ func (s *projectScanner) targets() []any {
return []any{ return []any{
&s.project.ID, &s.project.DomainID, &s.project.Name, &s.project.SourceProvider, &s.project.ID, &s.project.DomainID, &s.project.Name, &s.project.SourceProvider,
&s.project.AuthMode, &s.project.Status, &s.project.CutoverAt, &s.project.DeltaMode, &s.project.AuthMode, &s.project.Status, &s.project.CutoverAt, &s.project.DeltaMode,
&s.project.CreatedAt, &s.project.MicrosoftTenantID, &s.project.MicrosoftAdminConsentAt, &s.project.SharedDriveMode, &s.project.CreatedAt, &s.project.MicrosoftTenantID, &s.project.MicrosoftAdminConsentAt,
&s.project.MicrosoftAdminConsentError, &s.cutoverDNSRaw, &s.project.MicrosoftAdminConsentError, &s.cutoverDNSRaw,
} }
} }

View File

@ -0,0 +1,347 @@
package migration
import (
"context"
"encoding/csv"
"errors"
"fmt"
"io"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
)
const (
RosterStatusPending = "pending"
RosterStatusInvited = "invited"
RosterStatusClaimed = "claimed"
)
type RosterEntry struct {
ID string `json:"id"`
ProjectID string `json:"project_id"`
Email string `json:"email"`
DisplayName string `json:"display_name,omitempty"`
AlternateEmails []string `json:"alternate_emails,omitempty"`
Status string `json:"status"`
InviteID string `json:"invite_id,omitempty"`
CreatedAt string `json:"created_at"`
UpdatedAt string `json:"updated_at"`
}
type RosterRowInput struct {
Email string
DisplayName string
AlternateEmails []string
}
type RosterImportRowError struct {
Row int `json:"row"`
Email string `json:"email,omitempty"`
Message string `json:"message"`
}
type RosterImportResult struct {
Created int `json:"created"`
SkippedDuplicates int `json:"skipped_duplicates"`
Errors []RosterImportRowError `json:"errors,omitempty"`
}
var rosterHeaderAliases = map[string]string{
"email": "email",
"e-mail": "email",
"mail": "email",
"address": "email",
"display_name": "display_name",
"displayname": "display_name",
"name": "display_name",
"full_name": "display_name",
"alternate_emails": "alternate_emails",
"alternate_emails_": "alternate_emails",
"alternates": "alternate_emails",
"alias": "alternate_emails",
"aliases": "alternate_emails",
}
func ParseRosterCSV(r io.Reader) ([]RosterRowInput, error) {
reader := csv.NewReader(r)
reader.TrimLeadingSpace = true
reader.FieldsPerRecord = -1
var rows []RosterRowInput
lineNum := 0
emailCol := 0
displayCol := -1
alternateCol := -1
headerResolved := false
for {
record, err := reader.Read()
if errors.Is(err, io.EOF) {
break
}
if err != nil {
return nil, fmt.Errorf("csv row %d: %w", lineNum+1, err)
}
lineNum++
if len(record) == 0 {
continue
}
for len(record) > 0 && strings.TrimSpace(record[len(record)-1]) == "" {
record = record[:len(record)-1]
}
if len(record) == 0 {
continue
}
if !headerResolved && looksLikeRosterHeader(record) {
for i, col := range record {
key := normalizeRosterHeader(col)
switch key {
case "email":
emailCol = i
case "display_name":
displayCol = i
case "alternate_emails":
alternateCol = i
}
}
headerResolved = true
continue
}
headerResolved = true
if emailCol >= len(record) {
continue
}
email := normalizeInviteEmail(record[emailCol])
if email == "" {
continue
}
if !isEmailAddress(email) {
return nil, fmt.Errorf("csv row %d: invalid email %q", lineNum, record[emailCol])
}
row := RosterRowInput{Email: email}
if displayCol >= 0 && displayCol < len(record) {
row.DisplayName = strings.TrimSpace(record[displayCol])
}
if alternateCol >= 0 && alternateCol < len(record) {
row.AlternateEmails = parseAlternateEmailsField(record[alternateCol])
} else if len(record) > 1 && displayCol < 0 && alternateCol < 0 {
// email,display_name or email,alternates without header
if len(record) > 1 {
second := strings.TrimSpace(record[1])
if strings.Contains(second, "@") {
row.AlternateEmails = parseAlternateEmailsField(second)
} else {
row.DisplayName = second
}
}
if len(record) > 2 {
row.AlternateEmails = parseAlternateEmailsField(record[2])
}
}
row.AlternateEmails = normalizeAlternateEmails(email, row.AlternateEmails)
rows = append(rows, row)
}
return rows, nil
}
func looksLikeRosterHeader(record []string) bool {
if len(record) == 0 {
return false
}
first := normalizeRosterHeader(record[0])
if first == "email" {
return true
}
for _, col := range record {
if normalizeRosterHeader(col) == "email" {
return true
}
}
return false
}
func normalizeRosterHeader(col string) string {
key := strings.ToLower(strings.TrimSpace(col))
key = strings.ReplaceAll(key, " ", "_")
key = strings.ReplaceAll(key, "-", "_")
if mapped, ok := rosterHeaderAliases[key]; ok {
return mapped
}
return key
}
func parseAlternateEmailsField(raw string) []string {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil
}
raw = strings.Trim(raw, "\"'")
parts := strings.FieldsFunc(raw, func(r rune) bool {
return r == ';' || r == '|' || r == ','
})
var out []string
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
out = append(out, p)
}
}
return out
}
func (s *Service) ListRoster(ctx context.Context, projectID string) ([]RosterEntry, error) {
rows, err := s.db.Query(ctx, `
SELECT id::text, project_id::text, email, display_name, alternate_emails, status,
COALESCE(invite_id::text, ''), created_at::text, updated_at::text
FROM migration_roster
WHERE project_id = $1::uuid
ORDER BY email
`, projectID)
if err != nil {
return nil, err
}
defer rows.Close()
var out []RosterEntry
for rows.Next() {
var row RosterEntry
if err := rows.Scan(
&row.ID, &row.ProjectID, &row.Email, &row.DisplayName, &row.AlternateEmails,
&row.Status, &row.InviteID, &row.CreatedAt, &row.UpdatedAt,
); err != nil {
return nil, err
}
out = append(out, row)
}
return out, rows.Err()
}
func (s *Service) ImportRoster(ctx context.Context, projectID string, inputs []RosterRowInput) (RosterImportResult, error) {
result := RosterImportResult{}
for i, input := range inputs {
rowNum := i + 1
email := normalizeInviteEmail(input.Email)
if email == "" || !isEmailAddress(email) {
result.Errors = append(result.Errors, RosterImportRowError{
Row: rowNum, Email: input.Email, Message: "invalid email",
})
continue
}
alternates := normalizeAlternateEmails(email, input.AlternateEmails)
displayName := strings.TrimSpace(input.DisplayName)
existingStatus, err := s.rosterStatusByEmail(ctx, projectID, email)
if err != nil {
return result, err
}
if existingStatus != "" {
result.SkippedDuplicates++
continue
}
tx, err := s.db.Begin(ctx)
if err != nil {
return result, err
}
var rosterID string
err = tx.QueryRow(ctx, `
INSERT INTO migration_roster (project_id, email, display_name, alternate_emails, status)
VALUES ($1::uuid, $2, $3, $4, $5)
RETURNING id::text
`, projectID, email, displayName, alternates, RosterStatusPending).Scan(&rosterID)
if err != nil {
_ = tx.Rollback(ctx)
if isUniqueViolation(err) {
result.SkippedDuplicates++
continue
}
result.Errors = append(result.Errors, RosterImportRowError{
Row: rowNum, Email: email, Message: err.Error(),
})
continue
}
token, err := hosted.NewInviteToken()
if err != nil {
_ = tx.Rollback(ctx)
return result, err
}
var inviteID string
err = tx.QueryRow(ctx, `
INSERT INTO migration_invites (project_id, email, token, alternate_emails)
VALUES ($1::uuid, $2, $3, $4)
RETURNING id::text
`, projectID, email, token, alternates).Scan(&inviteID)
if err != nil {
_ = tx.Rollback(ctx)
if isUniqueViolation(err) {
result.SkippedDuplicates++
continue
}
result.Errors = append(result.Errors, RosterImportRowError{
Row: rowNum, Email: email, Message: err.Error(),
})
continue
}
_, err = tx.Exec(ctx, `
UPDATE migration_roster
SET status = $1, invite_id = $2::uuid, updated_at = NOW()
WHERE id = $3::uuid
`, RosterStatusInvited, inviteID, rosterID)
if err != nil {
_ = tx.Rollback(ctx)
result.Errors = append(result.Errors, RosterImportRowError{
Row: rowNum, Email: email, Message: err.Error(),
})
continue
}
if err := tx.Commit(ctx); err != nil {
result.Errors = append(result.Errors, RosterImportRowError{
Row: rowNum, Email: email, Message: err.Error(),
})
continue
}
result.Created++
}
return result, nil
}
func (s *Service) rosterStatusByEmail(ctx context.Context, projectID, email string) (string, error) {
var status string
err := s.db.QueryRow(ctx, `
SELECT status FROM migration_roster
WHERE project_id = $1::uuid AND email = $2
`, projectID, email).Scan(&status)
if errors.Is(err, pgx.ErrNoRows) {
return "", nil
}
return status, err
}
func (s *Service) markRosterClaimed(ctx context.Context, tx pgx.Tx, projectID, inviteID, email string) error {
_, err := tx.Exec(ctx, `
UPDATE migration_roster
SET status = $1, updated_at = NOW()
WHERE project_id = $2::uuid
AND (invite_id = $3::uuid OR (invite_id IS NULL AND email = $4))
`, RosterStatusClaimed, projectID, inviteID, email)
return err
}
func isUniqueViolation(err error) bool {
var pgErr *pgconn.PgError
return errors.As(err, &pgErr) && pgErr.Code == "23505"
}

View File

@ -0,0 +1,69 @@
package migration
import (
"strings"
"testing"
)
func TestParseRosterCSVWithHeader(t *testing.T) {
csv := `email,display_name,alternate_emails
alice@corp.com,Alice Corp,alice.old@corp.com;bob.alias@corp.com
bob@corp.com,Bob,
`
rows, err := ParseRosterCSV(strings.NewReader(csv))
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(rows) != 2 {
t.Fatalf("expected 2 rows, got %d", len(rows))
}
if rows[0].Email != "alice@corp.com" || rows[0].DisplayName != "Alice Corp" {
t.Fatalf("row0: %#v", rows[0])
}
if len(rows[0].AlternateEmails) != 2 {
t.Fatalf("row0 alternates: %#v", rows[0].AlternateEmails)
}
if rows[1].Email != "bob@corp.com" || rows[1].DisplayName != "Bob" {
t.Fatalf("row1: %#v", rows[1])
}
}
func TestParseRosterCSVWithoutHeader(t *testing.T) {
csv := "alice@corp.com,Alice\nbob@corp.com\n"
rows, err := ParseRosterCSV(strings.NewReader(csv))
if err != nil {
t.Fatalf("parse: %v", err)
}
if len(rows) != 2 {
t.Fatalf("expected 2 rows, got %d", len(rows))
}
if rows[0].DisplayName != "Alice" {
t.Fatalf("display name: %#v", rows[0])
}
if rows[1].DisplayName != "" {
t.Fatalf("row1 display: %#v", rows[1])
}
}
func TestParseRosterCSVInvalidEmail(t *testing.T) {
_, err := ParseRosterCSV(strings.NewReader("not-an-email,Someone\n"))
if err == nil {
t.Fatal("expected invalid email error")
}
}
func TestParseAlternateEmailsField(t *testing.T) {
got := parseAlternateEmailsField(`a@x.com; b@x.com | c@x.com`)
if len(got) != 3 {
t.Fatalf("got %#v", got)
}
}
func TestLooksLikeRosterHeader(t *testing.T) {
if !looksLikeRosterHeader([]string{"Email", "Name"}) {
t.Fatal("expected header detection")
}
if looksLikeRosterHeader([]string{"alice@corp.com", "Alice"}) {
t.Fatal("data row should not be header")
}
}

View File

@ -23,6 +23,7 @@ var (
ErrInviteNotFound = errors.New("migration invite not found") ErrInviteNotFound = errors.New("migration invite not found")
ErrInviteClaimed = errors.New("migration invite already claimed") ErrInviteClaimed = errors.New("migration invite already claimed")
ErrEmailMismatch = errors.New("email does not match invite") ErrEmailMismatch = errors.New("email does not match invite")
ErrTenantMismatch = errors.New("microsoft tenant does not match project")
ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active") ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active")
ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain") ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain")
) )
@ -53,6 +54,7 @@ type Project struct {
Status string `json:"status"` Status string `json:"status"`
CutoverAt *string `json:"cutover_at,omitempty"` CutoverAt *string `json:"cutover_at,omitempty"`
DeltaMode bool `json:"delta_mode"` DeltaMode bool `json:"delta_mode"`
SharedDriveMode string `json:"shared_drive_mode"`
CreatedAt string `json:"created_at"` CreatedAt string `json:"created_at"`
MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"` MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"`
MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"` MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"`
@ -136,6 +138,9 @@ func (s *Service) CreateInvite(ctx context.Context, projectID, email string, alt
return Invite{}, fmt.Errorf("email required") return Invite{}, fmt.Errorf("email required")
} }
alternates := normalizeAlternateEmails(email, alternateEmails) alternates := normalizeAlternateEmails(email, alternateEmails)
if alternates == nil {
alternates = []string{}
}
token, err := hosted.NewInviteToken() token, err := hosted.NewInviteToken()
if err != nil { if err != nil {
return Invite{}, err return Invite{}, err
@ -222,6 +227,9 @@ func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identit
hostedDomain = &domain hostedDomain = &domain
projectDomain = domain.Name projectDomain = domain.Name
} }
if err := validateMicrosoftTenantClaim(proj, identity.TenantID); err != nil {
return UserStatus{}, err
}
if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) { if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) {
return UserStatus{}, ErrEmailMismatch return UserStatus{}, ErrEmailMismatch
} }
@ -242,7 +250,12 @@ func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identit
return UserStatus{}, err return UserStatus{}, err
} }
if err := s.markRosterClaimed(ctx, tx, proj.ID, inv.ID, mailboxEmail); err != nil {
return UserStatus{}, err
}
if s.hosted != nil { if s.hosted != nil {
_ = LinkHostedMailboxByEmail(ctx, s.db, userID, mailboxEmail)
provision := hosted.ProvisionMailboxInput{ provision := hosted.ProvisionMailboxInput{
UserID: userID, UserID: userID,
Email: mailboxEmail, Email: mailboxEmail,
@ -260,16 +273,14 @@ func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identit
} }
provision.DomainID = proj.DomainID provision.DomainID = proj.DomainID
} }
_, err = s.hosted.ProvisionMailbox(ctx, provision) _, err = s.hosted.EnsureMailboxProvisioned(ctx, provision)
if err != nil { if err != nil {
if errors.Is(err, hosted.ErrDomainNotActive) { if errors.Is(err, hosted.ErrDomainNotActive) {
return UserStatus{}, ErrMigrationDomainNotActive return UserStatus{}, ErrMigrationDomainNotActive
} }
if !errors.Is(err, hosted.ErrAddressTaken) {
return UserStatus{}, err return UserStatus{}, err
} }
} }
}
services := []string{"mail", "contacts", "calendar", "drive"} services := []string{"mail", "contacts", "calendar", "drive"}
for _, svc := range services { for _, svc := range services {
@ -485,6 +496,10 @@ func (s *Service) ActivateProject(ctx context.Context, projectID string) (Projec
} }
func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) { func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) {
externalID = strings.TrimSpace(externalID)
if externalID == "" {
return "", pgx.ErrNoRows
}
var userID string var userID string
err := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID) err := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID)
return userID, err return userID, err

View File

@ -165,7 +165,20 @@ func (w *Worker) processJob(ctx context.Context, job Job) (string, error) {
procErr = NewCalendarImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update) procErr = NewCalendarImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update)
case "drive": case "drive":
selfManaged = true selfManaged = true
procErr = NewDriveImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update) sharedMode, err := w.projectSharedDriveMode(ctx, job.ProjectID)
if err != nil {
outcome = "failed"
return outcome, err
}
dedup, err := LoadSharedDriveItemStore(ctx, w.db, job.ProjectID)
if err != nil {
outcome = "failed"
return outcome, err
}
procErr = NewDriveImporter(w.db, w.nc).
WithUserPrincipal(graphUserUPN).
WithProject(job.ProjectID, sharedMode, dedup).
ImportBatch(ctx, &job, accessToken, provider, delta, update)
default: default:
procErr = fmt.Errorf("unknown service %q", job.Service) procErr = fmt.Errorf("unknown service %q", job.Service)
} }
@ -253,6 +266,15 @@ func (w *Worker) projectMicrosoftTenant(ctx context.Context, projectID string) (
return tenantID, nil return tenantID, nil
} }
func (w *Worker) projectSharedDriveMode(ctx context.Context, projectID string) (string, error) {
var mode string
err := w.db.QueryRow(ctx, `
SELECT COALESCE(shared_drive_mode, 'auto')
FROM migration_projects WHERE id = $1::uuid
`, projectID).Scan(&mode)
return NormalizeSharedDriveMode(mode), err
}
func (w *Worker) inviteEmail(ctx context.Context, projectID, userID string) (string, error) { func (w *Worker) inviteEmail(ctx context.Context, projectID, userID string) (string, error) {
var email string var email string
err := w.db.QueryRow(ctx, ` err := w.db.QueryRow(ctx, `

View File

@ -0,0 +1,65 @@
package nextcloud
import (
"bytes"
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
)
const defaultUploadChunkSize = 10 * 1024 * 1024
// UploadStreaming uploads large files via Nextcloud chunked DAV assembly.
func (c *Client) UploadStreaming(ctx context.Context, userID, targetPath string, content io.Reader, contentType string, totalSize int64) error {
uploadID, err := newChunkUploadID()
if err != nil {
return err
}
buf := make([]byte, defaultUploadChunkSize)
var uploaded int64
chunkIndex := 0
for {
n, readErr := io.ReadFull(content, buf)
if n == 0 && readErr == io.EOF {
break
}
if n > 0 {
chunkIndex++
if err := c.UploadChunk(ctx, userID, uploadID, chunkIndexName(chunkIndex), bytes.NewReader(buf[:n]), contentType); err != nil {
_ = c.AbortChunkUpload(ctx, userID, uploadID)
return err
}
uploaded += int64(n)
}
if readErr == io.EOF || readErr == io.ErrUnexpectedEOF {
break
}
if readErr != nil {
_ = c.AbortChunkUpload(ctx, userID, uploadID)
return readErr
}
}
finalSize := uploaded
if totalSize > 0 {
finalSize = totalSize
}
if err := c.AssembleChunks(ctx, userID, uploadID, targetPath, finalSize); err != nil {
_ = c.AbortChunkUpload(ctx, userID, uploadID)
return err
}
return nil
}
func newChunkUploadID() (string, error) {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
return "", fmt.Errorf("chunk upload id: %w", err)
}
return hex.EncodeToString(b), nil
}
func chunkIndexName(index int) string {
return fmt.Sprintf("%d", index)
}

View File

@ -6,11 +6,9 @@ import (
"net/http" "net/http"
"strings" "strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/mail/hosted" "github.com/ultisuite/ulti-backend/internal/mail/hosted"
"github.com/ultisuite/ulti-backend/internal/migration" "github.com/ultisuite/ulti-backend/internal/migration"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
@ -75,25 +73,24 @@ func (h *Handler) ProvisionUser(w http.ResponseWriter, r *http.Request) {
} }
ctx := r.Context() ctx := r.Context()
var userID string
externalID := strings.TrimSpace(req.ExternalID) externalID := strings.TrimSpace(req.ExternalID)
if externalID != "" { userID, err := users.ResolveProvisionUser(ctx, h.db, externalID, email, req.Name)
err := h.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID)
if errors.Is(err, pgx.ErrNoRows) {
userID, err = users.EnsureUser(ctx, h.db, &auth.Claims{
Sub: externalID,
Email: email,
Name: req.Name,
})
}
if err != nil { if err != nil {
h.logger.Error("ensure user", "error", err) h.logger.Error("resolve user", "error", err, "email", email)
apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to provision user", nil) apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to provision user", nil)
return return
} }
skipMailbox, err := migration.HasPendingMigrationInvite(ctx, h.db, email)
if err != nil {
h.logger.Error("check migration invite", "error", err, "email", email)
apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to check migration invite", nil)
return
} }
result, err := h.hosted.ProvisionMailbox(ctx, hosted.ProvisionMailboxInput{ var result hosted.ProvisionMailboxResult
if !skipMailbox {
result, err = h.hosted.EnsureMailboxProvisioned(ctx, hosted.ProvisionMailboxInput{
UserID: userID, UserID: userID,
Email: email, Email: email,
DisplayName: req.Name, DisplayName: req.Name,
@ -108,23 +105,28 @@ func (h *Handler) ProvisionUser(w http.ResponseWriter, r *http.Request) {
apiresponse.WriteError(w, r, http.StatusConflict, "provision_failed", err.Error(), nil) apiresponse.WriteError(w, r, http.StatusConflict, "provision_failed", err.Error(), nil)
return return
} }
if userID != "" {
_ = migration.LinkHostedMailboxByEmail(ctx, h.db, userID, email)
} }
_ = migration.LinkHostedMailboxByEmail(ctx, h.db, userID, email)
if h.nc != nil && userID != "" && externalID != "" { if h.nc != nil && userID != "" && externalID != "" {
if _, err := h.nc.EnsurePrincipal(ctx, email, externalID, req.Name); err != nil { if _, err := h.nc.EnsurePrincipal(ctx, email, externalID, req.Name); err != nil {
h.logger.Warn("nextcloud provision", "error", err) h.logger.Warn("nextcloud provision", "error", err)
} }
} }
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{ resp := map[string]any{
"user_id": userID, "user_id": userID,
"email": email, "email": email,
"mailbox_id": result.Mailbox.ID, }
"mail_account_id": result.MailAccountID, if !skipMailbox {
}) resp["mailbox_id"] = result.Mailbox.ID
resp["mail_account_id"] = result.MailAccountID
}
if skipMailbox {
resp["mailbox_deferred"] = true
}
apiresponse.WriteJSON(w, http.StatusOK, resp)
} }
// CheckAddress validates local part availability (Authentik expression policy or public API). // CheckAddress validates local part availability (Authentik expression policy or public API).

View File

@ -0,0 +1,52 @@
package provision
import (
"bytes"
"net/http"
"net/http/httptest"
"testing"
)
func TestDecodeProvisionBodyAuthentikPayload(t *testing.T) {
body := []byte(`{
"email": "alice@ultisuite.fr",
"password": "secret",
"name": "Alice",
"external_id": "uuid-123",
"user": {"email": "ignored@example.com", "uuid": "ignored"}
}`)
req := httptest.NewRequest(http.MethodPost, "/internal/provision/user", bytes.NewReader(body))
got, err := decodeProvisionBody(req)
if err != nil {
t.Fatalf("decodeProvisionBody() error = %v", err)
}
if got.Email != "alice@ultisuite.fr" || got.ExternalID != "uuid-123" || got.Name != "Alice" {
t.Fatalf("decodeProvisionBody() = %#v", got)
}
}
func TestAuthorizeProvisionSecret(t *testing.T) {
req := httptest.NewRequest(http.MethodPost, "/internal/provision/user", nil)
req.Header.Set("X-Provision-Secret", "topsecret")
if !authorizeProvision(req, "topsecret") {
t.Fatal("expected header secret to authorize")
}
req = httptest.NewRequest(http.MethodPost, "/internal/provision/user?secret=topsecret", nil)
if !authorizeProvision(req, "topsecret") {
t.Fatal("expected query secret to authorize")
}
if authorizeProvision(req, "wrong") {
t.Fatal("expected wrong secret to fail")
}
}
func TestNormalizeProvisionRequestUsesUsername(t *testing.T) {
req := provisionUserRequest{Username: "bob@ultisuite.fr"}
normalizeProvisionRequest(&req)
if req.Email != "bob@ultisuite.fr" {
t.Fatalf("email = %q, want bob@ultisuite.fr", req.Email)
}
if req.Name != "bob@ultisuite.fr" {
t.Fatalf("name = %q, want fallback to email", req.Name)
}
}

View File

@ -2,13 +2,14 @@ package users
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"strings" "strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/migration"
) )
// ProvisionEmail returns the email stored for a newly provisioned user. // ProvisionEmail returns the email stored for a newly provisioned user.
@ -59,6 +60,66 @@ func EnsureUser(ctx context.Context, db *pgxpool.Pool, claims *auth.Claims) (str
return "", fmt.Errorf("bootstrap platform admin: %w", err) return "", fmt.Errorf("bootstrap platform admin: %w", err)
} }
} }
_ = migration.LinkHostedMailboxByEmail(ctx, db, userID, email)
return userID, nil return userID, nil
} }
// LookupUserID returns the internal user UUID for an OIDC subject.
func LookupUserID(ctx context.Context, db *pgxpool.Pool, externalID string) (string, error) {
if db == nil {
return "", fmt.Errorf("database not configured")
}
externalID = strings.TrimSpace(externalID)
if externalID == "" {
return "", pgx.ErrNoRows
}
var userID string
err := db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID)
return userID, err
}
// LookupUserIDByEmail returns the internal user UUID for a stored email address.
func LookupUserIDByEmail(ctx context.Context, db *pgxpool.Pool, email string) (string, error) {
if db == nil {
return "", fmt.Errorf("database not configured")
}
email = strings.ToLower(strings.TrimSpace(email))
if email == "" {
return "", pgx.ErrNoRows
}
var userID string
err := db.QueryRow(ctx, `SELECT id::text FROM users WHERE lower(email) = $1`, email).Scan(&userID)
return userID, err
}
// ResolveProvisionUser finds an existing user by external_id or email, or creates one from Authentik enrollment data.
func ResolveProvisionUser(ctx context.Context, db *pgxpool.Pool, externalID, email, name string) (string, error) {
externalID = strings.TrimSpace(externalID)
email = strings.ToLower(strings.TrimSpace(email))
if externalID != "" {
userID, err := LookupUserID(ctx, db, externalID)
if err == nil {
return userID, nil
}
if !errors.Is(err, pgx.ErrNoRows) {
return "", err
}
}
if email != "" {
userID, err := LookupUserIDByEmail(ctx, db, email)
if err == nil {
return userID, nil
}
if !errors.Is(err, pgx.ErrNoRows) {
return "", err
}
}
if externalID == "" {
return "", fmt.Errorf("cannot provision user without external_id or existing email")
}
return EnsureUser(ctx, db, &auth.Claims{
Sub: externalID,
Email: email,
Name: name,
})
}

View File

@ -0,0 +1,26 @@
package users
import (
"testing"
)
func TestLookupUserIDEmptyExternalID(t *testing.T) {
_, err := LookupUserID(t.Context(), nil, "")
if err == nil {
t.Fatal("expected error for empty external id")
}
}
func TestLookupUserIDByEmailEmpty(t *testing.T) {
_, err := LookupUserIDByEmail(t.Context(), nil, "")
if err == nil {
t.Fatal("expected error for empty email")
}
}
func TestResolveProvisionUserRequiresIdentity(t *testing.T) {
_, err := ResolveProvisionUser(t.Context(), nil, "", "user@example.com", "User")
if err == nil {
t.Fatal("expected error without external_id or existing user")
}
}

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS migration_roster;

View File

@ -0,0 +1,17 @@
CREATE TABLE migration_roster (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE,
email TEXT NOT NULL,
display_name TEXT NOT NULL DEFAULT '',
alternate_emails TEXT[] NOT NULL DEFAULT '{}',
status TEXT NOT NULL DEFAULT 'pending',
invite_id UUID REFERENCES migration_invites(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE (project_id, email),
CONSTRAINT migration_roster_status_check CHECK (status IN ('pending', 'invited', 'claimed'))
);
CREATE INDEX idx_migration_roster_project ON migration_roster(project_id);
CREATE INDEX idx_migration_roster_status ON migration_roster(status);
CREATE INDEX idx_migration_roster_invite ON migration_roster(invite_id);

View File

@ -0,0 +1,5 @@
DROP TABLE IF EXISTS migration_shared_drive_items;
DROP TABLE IF EXISTS migration_shared_drives;
ALTER TABLE migration_projects DROP CONSTRAINT IF EXISTS migration_projects_shared_drive_mode_check;
ALTER TABLE migration_projects DROP COLUMN IF EXISTS shared_drive_mode;

View File

@ -0,0 +1,39 @@
-- Shared drive import mode and project-level dedup for Google Workspace migrations.
ALTER TABLE migration_projects
ADD COLUMN shared_drive_mode TEXT NOT NULL DEFAULT 'auto';
ALTER TABLE migration_projects
ADD CONSTRAINT migration_projects_shared_drive_mode_check
CHECK (shared_drive_mode IN ('auto', 'manual'));
CREATE TABLE migration_shared_drives (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE,
drive_id TEXT NOT NULL,
name TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL DEFAULT 'pending',
discovered_by_user_id UUID REFERENCES users(id) ON DELETE SET NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(project_id, drive_id),
CONSTRAINT migration_shared_drives_status_check
CHECK (status IN ('pending', 'approved', 'rejected'))
);
CREATE INDEX idx_migration_shared_drives_project_status
ON migration_shared_drives(project_id, status);
-- Project-level dedup: one import per shared-drive file across all users in a project.
CREATE TABLE migration_shared_drive_items (
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE,
drive_id TEXT NOT NULL,
source_id TEXT NOT NULL,
rel_path TEXT NOT NULL DEFAULT '',
imported_by_job_id UUID REFERENCES migration_jobs(id) ON DELETE SET NULL,
imported_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
UNIQUE(project_id, drive_id, source_id)
);
CREATE INDEX idx_migration_shared_drive_items_project
ON migration_shared_drive_items(project_id, drive_id);