feat(mail): integrate Stalwart hosted mail and migration features
- Added configuration options for Stalwart hosted mail in .env.example. - Updated Docker Compose to include Stalwart service with health checks. - Introduced new API endpoints for managing mail domains and migration projects. - Enhanced Authentik blueprints for user enrollment and post-migration security. - Updated OAuth handling for Google and Microsoft migration processes. - Improved error handling and response structures in the mail API. - Added integration tests for email claiming and migration workflows.
This commit is contained in:
parent
1d063237b9
commit
7143a36c19
49
.env.example
49
.env.example
@ -259,6 +259,55 @@ MAIL_MICROSOFT_OAUTH_TENANT=common
|
||||
MAIL_OAUTH_REDIRECT_URL=
|
||||
MAIL_APP_URL=http://localhost/mail
|
||||
# Cible nginx → suite frontend unifié mail+drive (dev: Next sur l'hôte :3004 ; prod: suite-frontend:3000)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Stalwart hosted mail (optional — enable for @ultisuite.fr / custom domains)
|
||||
# -----------------------------------------------------------------------------
|
||||
STALWART_ENABLED=false
|
||||
STALWART_API_URL=http://stalwart:8080
|
||||
# STALWART_API_KEY — API key from Stalwart webadmin (management JMAP API)
|
||||
STALWART_IMAP_HOST=stalwart
|
||||
STALWART_IMAP_PORT=993
|
||||
STALWART_IMAP_TLS=true
|
||||
STALWART_SMTP_HOST=stalwart
|
||||
STALWART_SMTP_PORT=587
|
||||
STALWART_SMTP_TLS=true
|
||||
STALWART_RECOVERY_ADMIN=admin:changeme-stalwart-admin
|
||||
PLATFORM_MAIL_DOMAIN=ultisuite.fr
|
||||
# PROVISION_WEBHOOK_SECRET — shared secret for Authentik enrollment webhook → ultid
|
||||
PROVISION_WEBHOOK_SECRET=changeme-provision-webhook
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Migration OAuth (Google Workspace / Microsoft 365 import)
|
||||
# Falls back to MAIL_* OAuth credentials when unset.
|
||||
# -----------------------------------------------------------------------------
|
||||
MIGRATION_GOOGLE_OAUTH_CLIENT_ID=
|
||||
MIGRATION_GOOGLE_OAUTH_CLIENT_SECRET=
|
||||
MIGRATION_MICROSOFT_OAUTH_CLIENT_ID=
|
||||
MIGRATION_MICROSOFT_OAUTH_CLIENT_SECRET=
|
||||
MIGRATION_MICROSOFT_OAUTH_TENANT=common
|
||||
MIGRATION_OAUTH_REDIRECT_URL=
|
||||
MIGRATION_WORKER_INTERVAL=30s
|
||||
# Worker picks up to JOB_LIMIT pending jobs per tick (0 = concurrency*3, min 5).
|
||||
MIGRATION_WORKER_CONCURRENCY=2
|
||||
MIGRATION_WORKER_JOB_LIMIT=0
|
||||
# Items processed per worker tick (mail/contacts/calendar vs drive file downloads).
|
||||
MIGRATION_IMPORT_BATCH_SIZE=25
|
||||
MIGRATION_DRIVE_BATCH_SIZE=10
|
||||
# Rate-limit backoff for Google/Microsoft migration API calls (429 + Retry-After).
|
||||
MIGRATION_RATE_LIMIT_MAX_RETRIES=6
|
||||
MIGRATION_RATE_LIMIT_BASE_DELAY=2s
|
||||
MIGRATION_RATE_LIMIT_MAX_DELAY=2m
|
||||
# Microsoft 365 app-only (client credentials) uses MIGRATION_MICROSOFT_OAUTH_* above.
|
||||
# Requires Azure AD application permissions (Mail.Read, Calendars.Read, Contacts.Read, Files.Read.All)
|
||||
# and tenant admin consent; per-project tenant id is stored after admin consent redirect.
|
||||
# Google Workspace domain-wide delegation (service account JSON, one line)
|
||||
MIGRATION_GOOGLE_SERVICE_ACCOUNT_JSON=
|
||||
# MX hosts expected at cutover (comma-separated). Defaults: mail.{PLATFORM_MAIL_DOMAIN}, then STALWART_IMAP_HOST if FQDN.
|
||||
MIGRATION_CUTOVER_MX_HOSTS=
|
||||
# Block cutover when live MX does not match expected hosts (requires domain_id on project).
|
||||
MIGRATION_CUTOVER_REQUIRE_MX=false
|
||||
|
||||
MAIL_FRONTEND_UPSTREAM=host.docker.internal:3004
|
||||
MAIL_WEBHOOK_SHARED_SECRET_ROTATED_AT=2026-01-01T00:00:00Z
|
||||
|
||||
|
||||
@ -4,9 +4,10 @@ Blueprints in `blueprints/` are mounted into Authentik at `/blueprints/custom` a
|
||||
|
||||
| Fichier | Rôle |
|
||||
|---------|------|
|
||||
| `01-ulti-enrollment.yaml` | Inscription self-service (`ulti-enrollment`) |
|
||||
| `01-ulti-enrollment.yaml` | Inscription self-service (`ulti-enrollment`, @ultisuite.fr) |
|
||||
| `02-ulti-brand.yaml` | Branding UltiSuite + lien « Créer un compte » sur login |
|
||||
| `03-ulti-suite-groups.yaml` | Claim OIDC `groups` (RBAC contacts/calendar/drive/photos) |
|
||||
| `04-ulti-post-migration-security.yaml` | Flow WebAuthn/TOTP post-migration (`ulti-post-migration-security`) |
|
||||
| `ulti-oidc.yaml` | App OIDC Ultimail |
|
||||
| `nextcloud-oidc.yaml` | App OIDC Nextcloud |
|
||||
| `onlyoffice-oidc.yaml` | App OIDC OnlyOffice |
|
||||
@ -44,6 +45,10 @@ Flow public : `http://localhost/auth/if/flow/ulti-enrollment/`
|
||||
2. Nom et prénom, téléphone (optionnel), avatar (optionnel)
|
||||
3. Création du compte + connexion automatique
|
||||
|
||||
L'email d'inscription est construit comme `username@ultisuite.fr`. ultid peut provisionner la boîte Stalwart via `POST /internal/provision/user` (header `X-Provision-Secret: $PROVISION_WEBHOOK_SECRET`).
|
||||
|
||||
Flow post-migration (WebAuthn/TOTP) : `/auth/if/flow/ulti-post-migration-security/`
|
||||
|
||||
Sur la page de connexion Authentik, lien **« Besoin d'un compte ? S'inscrire »** (identification stage).
|
||||
|
||||
## Branding
|
||||
|
||||
@ -22,12 +22,26 @@ entries:
|
||||
attrs:
|
||||
field_key: username
|
||||
label: Adresse e-mail
|
||||
type: email
|
||||
type: text
|
||||
required: true
|
||||
placeholder: vous@exemple.com
|
||||
placeholder: prenom.nom
|
||||
placeholder_expression: false
|
||||
order: 0
|
||||
|
||||
- model: authentik_stages_prompt.prompt
|
||||
id: ulti-enroll-field-domain-hint
|
||||
identifiers:
|
||||
name: ulti-enrollment-field-domain-hint
|
||||
attrs:
|
||||
field_key: domain_hint
|
||||
label: Votre adresse sera
|
||||
type: static
|
||||
required: false
|
||||
initial_value: "@ultisuite.fr"
|
||||
initial_value_expression: false
|
||||
placeholder_expression: false
|
||||
order: 1
|
||||
|
||||
- model: authentik_stages_prompt.prompt
|
||||
id: ulti-enroll-field-email-sync
|
||||
identifiers:
|
||||
@ -37,7 +51,7 @@ entries:
|
||||
label: E-mail
|
||||
type: hidden
|
||||
required: true
|
||||
initial_value: "{{ prompt_data.username }}"
|
||||
initial_value: "{{ prompt_data.username }}@ultisuite.fr"
|
||||
initial_value_expression: true
|
||||
placeholder_expression: false
|
||||
order: 1
|
||||
@ -114,6 +128,7 @@ entries:
|
||||
attrs:
|
||||
fields:
|
||||
- !KeyOf ulti-enroll-field-email
|
||||
- !KeyOf ulti-enroll-field-domain-hint
|
||||
- !KeyOf ulti-enroll-field-email-sync
|
||||
- !KeyOf ulti-enroll-field-password
|
||||
- !KeyOf ulti-enroll-field-password-repeat
|
||||
@ -141,6 +156,55 @@ entries:
|
||||
identifiers:
|
||||
name: ulti-enrollment-user-login
|
||||
|
||||
- model: authentik_policies_expression.expressionpolicy
|
||||
id: ulti-enroll-policy-username-available
|
||||
identifiers:
|
||||
name: ulti-enrollment-username-available
|
||||
attrs:
|
||||
name: Ultimail — adresse disponible
|
||||
expression: |
|
||||
import json
|
||||
from urllib.request import urlopen
|
||||
local = (request.context.get("prompt_data") or {}).get("username", "").strip().lower()
|
||||
if not local or len(local) < 2:
|
||||
return False
|
||||
url = f"http://ultid:8080/api/v1/mail/addresses/check?local={local}&domain=ultisuite.fr"
|
||||
try:
|
||||
with urlopen(url, timeout=5) as resp:
|
||||
data = json.loads(resp.read().decode("utf-8"))
|
||||
return data.get("available") is True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
- model: authentik_policies.policybinding
|
||||
identifiers:
|
||||
order: 0
|
||||
target: !KeyOf ulti-enroll-prompt-credentials
|
||||
policy: !KeyOf ulti-enroll-policy-username-available
|
||||
attrs:
|
||||
enabled: true
|
||||
timeout: 10
|
||||
failure_result: false
|
||||
|
||||
- model: authentik_stages_webhook.webhookstage
|
||||
id: ulti-enroll-provision-webhook
|
||||
identifiers:
|
||||
name: ulti-enrollment-provision-webhook
|
||||
attrs:
|
||||
url: http://ultid:8080/internal/provision/user?secret=changeme-provision-webhook
|
||||
method: POST
|
||||
headers:
|
||||
X-Provision-Secret: changeme-provision-webhook
|
||||
Content-Type: application/json
|
||||
body: |
|
||||
{
|
||||
"email": "{{ prompt_data.email }}",
|
||||
"username": "{{ prompt_data.username }}",
|
||||
"password": "{{ prompt_data.password }}",
|
||||
"name": "{{ prompt_data.name }}",
|
||||
"external_id": "{{ user.uuid }}"
|
||||
}
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-enrollment-flow
|
||||
@ -159,6 +223,12 @@ entries:
|
||||
stage: !KeyOf ulti-enroll-user-write
|
||||
order: 30
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-enrollment-flow
|
||||
stage: !KeyOf ulti-enroll-provision-webhook
|
||||
order: 40
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-enrollment-flow
|
||||
|
||||
@ -0,0 +1,82 @@
|
||||
# Post-migration — encourage WebAuthn / TOTP before disabling legacy IdP
|
||||
version: 1
|
||||
metadata:
|
||||
name: Ulti post-migration security
|
||||
labels:
|
||||
blueprints.goauthentik.io/instantiate: "true"
|
||||
entries:
|
||||
- model: authentik_flows.flow
|
||||
id: ulti-post-migration-flow
|
||||
identifiers:
|
||||
slug: ulti-post-migration-security
|
||||
attrs:
|
||||
name: UltiSuite — Sécuriser votre compte
|
||||
title: Sécuriser votre compte UltiSuite
|
||||
designation: stage_configuration
|
||||
authentication: require_authenticated
|
||||
|
||||
- model: authentik_stages_prompt.prompt
|
||||
id: ulti-post-mig-info
|
||||
identifiers:
|
||||
name: ulti-post-migration-info
|
||||
attrs:
|
||||
field_key: info
|
||||
label: Information
|
||||
type: static
|
||||
required: false
|
||||
initial_value: >
|
||||
Votre migration est en cours ou terminée. Enregistrez une clé de sécurité (WebAuthn)
|
||||
ou une application TOTP pour vous connecter sans dépendre de Google Workspace ou Microsoft 365.
|
||||
initial_value_expression: false
|
||||
placeholder_expression: false
|
||||
order: 0
|
||||
|
||||
- model: authentik_stages_prompt.promptstage
|
||||
id: ulti-post-mig-prompt
|
||||
identifiers:
|
||||
name: ulti-post-migration-prompt
|
||||
attrs:
|
||||
fields:
|
||||
- !KeyOf ulti-post-mig-info
|
||||
|
||||
- model: authentik_stages_authenticator_webauthn.authenticatorwebauthnstage
|
||||
id: ulti-post-mig-webauthn
|
||||
identifiers:
|
||||
name: ulti-post-migration-webauthn
|
||||
attrs:
|
||||
user_verification: preferred
|
||||
device_type_restrictions: no_restrictions
|
||||
|
||||
- model: authentik_stages_authenticator_totp.authenticatortotpstage
|
||||
id: ulti-post-mig-totp
|
||||
identifiers:
|
||||
name: ulti-post-migration-totp
|
||||
|
||||
- model: authentik_stages_user_login.userloginstage
|
||||
id: ulti-post-mig-done
|
||||
identifiers:
|
||||
name: ulti-post-migration-done
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-post-migration-flow
|
||||
stage: !KeyOf ulti-post-mig-prompt
|
||||
order: 10
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-post-migration-flow
|
||||
stage: !KeyOf ulti-post-mig-webauthn
|
||||
order: 20
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-post-migration-flow
|
||||
stage: !KeyOf ulti-post-mig-totp
|
||||
order: 30
|
||||
|
||||
- model: authentik_flows.flowstagebinding
|
||||
identifiers:
|
||||
target: !KeyOf ulti-post-migration-flow
|
||||
stage: !KeyOf ulti-post-mig-done
|
||||
order: 100
|
||||
@ -219,6 +219,22 @@ services:
|
||||
depends_on:
|
||||
- ultid
|
||||
|
||||
stalwart:
|
||||
image: stalwartlabs/stalwart:v0.16
|
||||
restart: unless-stopped
|
||||
environment:
|
||||
STALWART_RECOVERY_ADMIN: ${STALWART_RECOVERY_ADMIN:-admin:changeme-stalwart-admin}
|
||||
volumes:
|
||||
- stalwart_data:/opt/stalwart
|
||||
networks:
|
||||
- ulti-net
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "wget -qO- http://127.0.0.1:8080/healthz >/dev/null 2>&1 || exit 1"]
|
||||
interval: 15s
|
||||
timeout: 5s
|
||||
retries: 8
|
||||
start_period: 30s
|
||||
|
||||
networks:
|
||||
ulti-net:
|
||||
driver: bridge
|
||||
@ -229,3 +245,4 @@ volumes:
|
||||
rustfs_data:
|
||||
prometheus_data:
|
||||
grafana_data:
|
||||
stalwart_data:
|
||||
|
||||
@ -472,3 +472,25 @@ server {
|
||||
return 404 "Not found\n";
|
||||
}
|
||||
}
|
||||
|
||||
# Stalwart webadmin + JMAP (mail.${DOMAIN})
|
||||
server {
|
||||
listen 80;
|
||||
server_name mail.${DOMAIN};
|
||||
|
||||
client_max_body_size 100M;
|
||||
|
||||
location / {
|
||||
resolver 127.0.0.11 valid=10s ipv6=off;
|
||||
set $stalwart_upstream stalwart:8080;
|
||||
|
||||
proxy_pass http://$stalwart_upstream;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection $connection_upgrade;
|
||||
}
|
||||
}
|
||||
|
||||
3
go.mod
3
go.mod
@ -23,12 +23,13 @@ require (
|
||||
github.com/testcontainers/testcontainers-go/modules/minio v0.35.0
|
||||
github.com/testcontainers/testcontainers-go/modules/postgres v0.35.0
|
||||
golang.org/x/net v0.55.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
golang.org/x/oauth2 v0.36.0
|
||||
golang.org/x/text v0.37.0
|
||||
golang.org/x/time v0.15.0
|
||||
)
|
||||
|
||||
require (
|
||||
cloud.google.com/go/compute/metadata v0.3.0 // indirect
|
||||
dario.cat/mergo v1.0.0 // indirect
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 // indirect
|
||||
github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect
|
||||
|
||||
6
go.sum
6
go.sum
@ -1,3 +1,5 @@
|
||||
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
|
||||
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
|
||||
dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
|
||||
dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
|
||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20240806141605-e8a1dd7889d6 h1:He8afgbRMd7mFxO99hRNu+6tazq8nFF9lIwo9JFroBk=
|
||||
@ -255,8 +257,8 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/net v0.55.0 h1:bcvxaJn3e1U6InsFWt1JUq1aSjnRxLzT2rtD2KfkDF8=
|
||||
golang.org/x/net v0.55.0/go.mod h1:L5U2KuzuOe1lY7Z+aWVIKK6qEeJXnXV9yzGA+WCHJww=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs=
|
||||
golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
|
||||
@ -15,6 +15,8 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
migr "github.com/ultisuite/ulti-backend/internal/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/permission"
|
||||
"github.com/ultisuite/ulti-backend/internal/securityaudit"
|
||||
@ -33,6 +35,14 @@ func NewHandler(db *pgxpool.Pool, audit *securityaudit.Logger, cfg *config.Confi
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) SetHostedService(svc *hosted.Service) {
|
||||
h.svc.SetHostedService(svc)
|
||||
}
|
||||
|
||||
func (h *Handler) SetMigrationService(svc *migr.Service) {
|
||||
h.svc.SetMigrationService(svc)
|
||||
}
|
||||
|
||||
func (h *Handler) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
read := middleware.RequireAdminScope(permission.AdminScopeRead)
|
||||
@ -65,6 +75,7 @@ func (h *Handler) Routes() chi.Router {
|
||||
r.With(write).Post("/org/identity-providers/{providerID}/sync", h.SyncIdentityProvider)
|
||||
|
||||
h.registerDriveAdminRoutes(r, read, write)
|
||||
h.registerMailAdminRoutes(r, read, write)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
390
internal/api/admin/handlers_mail_domains.go
Normal file
390
internal/api/admin/handlers_mail_domains.go
Normal file
@ -0,0 +1,390 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"encoding/csv"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
migr "github.com/ultisuite/ulti-backend/internal/migration"
|
||||
)
|
||||
|
||||
const maxAdminMailRequestBody = 1 << 20
|
||||
|
||||
func (h *Handler) registerMailAdminRoutes(r chi.Router, read, write func(http.Handler) http.Handler) {
|
||||
if h.svc.hosted == nil && h.svc.migration == nil {
|
||||
return
|
||||
}
|
||||
r.Route("/mail", func(r chi.Router) {
|
||||
if h.svc.hosted != nil {
|
||||
r.With(read).Get("/domains", h.ListMailDomains)
|
||||
r.With(write).Post("/domains", h.CreateMailDomain)
|
||||
r.With(read).Get("/domains/{domainID}", h.GetMailDomain)
|
||||
r.With(write).Post("/domains/{domainID}/verify-txt", h.VerifyMailDomainTXT)
|
||||
r.With(write).Post("/domains/{domainID}/verify-mx", h.VerifyMailDomainMX)
|
||||
}
|
||||
})
|
||||
r.Route("/migration", func(r chi.Router) {
|
||||
if h.svc.migration == nil {
|
||||
return
|
||||
}
|
||||
r.With(read).Get("/projects", h.ListMigrationProjects)
|
||||
r.With(write).Post("/projects", h.CreateMigrationProject)
|
||||
r.With(write).Post("/projects/{projectID}/activate", h.ActivateMigrationProject)
|
||||
r.With(read).Get("/projects/{projectID}/cutover-dns", h.PreflightMigrationCutoverDNS)
|
||||
r.With(write).Post("/projects/{projectID}/cutover", h.StartMigrationCutover)
|
||||
r.With(write).Post("/projects/{projectID}/invites", h.CreateMigrationInvite)
|
||||
r.With(write).Post("/projects/{projectID}/invites/import", h.ImportMigrationInvites)
|
||||
r.With(read).Get("/projects/{projectID}/jobs", h.ListMigrationProjectJobs)
|
||||
r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit", h.ListMigrationJobAudit)
|
||||
r.With(read).Get("/projects/{projectID}/jobs/{jobID}/audit/summary", h.MigrationJobAuditSummary)
|
||||
r.With(write).Post("/projects/{projectID}/jobs/retry-failed", h.RetryMigrationFailedJobs)
|
||||
r.With(write).Post("/projects/{projectID}/jobs/{jobID}/retry", h.RetryMigrationJob)
|
||||
r.With(write).Post("/projects/{projectID}/jobs/{jobID}/reset-cursor", h.ResetMigrationJobCursor)
|
||||
r.With(read).Get("/microsoft/admin-consent-url", h.MicrosoftMigrationAdminConsentURL)
|
||||
r.With(read).Get("/microsoft/admin-consents", h.ListMicrosoftAdminConsents)
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) ListMailDomains(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.svc.hosted.ListDomains(r.Context())
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"domains": rows})
|
||||
}
|
||||
|
||||
type createMailDomainRequest struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
func (h *Handler) CreateMailDomain(w http.ResponseWriter, r *http.Request) {
|
||||
var req createMailDomainRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
row, err := h.svc.hosted.CreateDomain(r.Context(), req.Name, false)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusCreated, row)
|
||||
}
|
||||
|
||||
func (h *Handler) GetMailDomain(w http.ResponseWriter, r *http.Request) {
|
||||
row, err := h.svc.hosted.GetDomain(r.Context(), chi.URLParam(r, "domainID"))
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, row)
|
||||
}
|
||||
|
||||
func (h *Handler) VerifyMailDomainTXT(w http.ResponseWriter, r *http.Request) {
|
||||
domainID := chi.URLParam(r, "domainID")
|
||||
row, report, err := h.svc.hosted.VerifyDomainTXTRecord(r.Context(), domainID)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "dns_txt_not_verified", err.Error(), map[string]any{"dns": report})
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"domain": row, "dns": report})
|
||||
}
|
||||
|
||||
func (h *Handler) VerifyMailDomainMX(w http.ResponseWriter, r *http.Request) {
|
||||
domainID := chi.URLParam(r, "domainID")
|
||||
expected := h.migrationCutoverConfig().ExpectedMXHosts
|
||||
row, report, err := h.svc.hosted.VerifyDomainMXRecord(r.Context(), domainID, expected)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "dns_mx_not_verified", err.Error(), map[string]any{"dns": report})
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"domain": row, "dns": report})
|
||||
}
|
||||
|
||||
type createMigrationProjectRequest struct {
|
||||
Name string `json:"name"`
|
||||
SourceProvider string `json:"source_provider"`
|
||||
DomainID string `json:"domain_id"`
|
||||
AuthMode string `json:"auth_mode"`
|
||||
}
|
||||
|
||||
func (h *Handler) CreateMigrationProject(w http.ResponseWriter, r *http.Request) {
|
||||
var req createMigrationProjectRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
row, err := h.svc.migration.CreateProject(r.Context(), req.Name, req.SourceProvider, req.DomainID, req.AuthMode)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusCreated, row)
|
||||
}
|
||||
|
||||
func (h *Handler) ListMigrationProjects(w http.ResponseWriter, r *http.Request) {
|
||||
rows, err := h.svc.migration.ListProjects(r.Context())
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"projects": rows})
|
||||
}
|
||||
|
||||
func (h *Handler) ActivateMigrationProject(w http.ResponseWriter, r *http.Request) {
|
||||
row, err := h.svc.migration.ActivateProject(r.Context(), chi.URLParam(r, "projectID"))
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, row)
|
||||
}
|
||||
|
||||
func (h *Handler) PreflightMigrationCutoverDNS(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
report, err := h.svc.migration.PreflightCutoverDNS(r.Context(), chi.URLParam(r, "projectID"), h.migrationCutoverConfig())
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"dns": report})
|
||||
}
|
||||
|
||||
func (h *Handler) StartMigrationCutover(w http.ResponseWriter, r *http.Request) {
|
||||
result, err := h.svc.migration.StartCutover(r.Context(), chi.URLParam(r, "projectID"))
|
||||
if errors.Is(err, migr.ErrCutoverMXNotReady) {
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "migration_cutover_mx_not_ready", err.Error(), map[string]any{
|
||||
"dns": result.DNS,
|
||||
"project": result.Project,
|
||||
})
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func (h *Handler) migrationCutoverConfig() migr.CutoverConfig {
|
||||
if h.svc.cfg == nil {
|
||||
return migr.CutoverConfig{}
|
||||
}
|
||||
return migr.CutoverConfig{
|
||||
ExpectedMXHosts: migr.ParseCutoverMXHosts(
|
||||
h.svc.cfg.MigrationCutoverMXHosts,
|
||||
h.svc.cfg.PlatformMailDomain,
|
||||
h.svc.cfg.StalwartIMAPHost,
|
||||
),
|
||||
RequireMX: h.svc.cfg.MigrationCutoverRequireMX,
|
||||
}
|
||||
}
|
||||
|
||||
type createMigrationInviteRequest struct {
|
||||
Email string `json:"email"`
|
||||
AlternateEmails []string `json:"alternate_emails,omitempty"`
|
||||
}
|
||||
|
||||
func (h *Handler) CreateMigrationInvite(w http.ResponseWriter, r *http.Request) {
|
||||
var req createMigrationInviteRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
row, err := h.svc.migration.CreateInvite(r.Context(), chi.URLParam(r, "projectID"), req.Email, req.AlternateEmails)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusCreated, row)
|
||||
}
|
||||
|
||||
func (h *Handler) ImportMigrationInvites(w http.ResponseWriter, r *http.Request) {
|
||||
var emails []string
|
||||
contentType := r.Header.Get("Content-Type")
|
||||
if strings.Contains(contentType, "multipart/form-data") {
|
||||
file, _, err := r.FormFile("file")
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
reader := csv.NewReader(file)
|
||||
for {
|
||||
record, err := reader.Read()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
|
||||
Field: "file", Message: "invalid csv",
|
||||
}))
|
||||
return
|
||||
}
|
||||
if len(record) > 0 {
|
||||
emails = append(emails, record[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(emails) == 0 {
|
||||
var body struct {
|
||||
Emails []string `json:"emails"`
|
||||
}
|
||||
if err := apivalidate.DecodeJSON(w, r, maxAdminMailRequestBody, &body); err != nil {
|
||||
return
|
||||
}
|
||||
emails = body.Emails
|
||||
}
|
||||
count, err := h.svc.migration.ImportInvites(r.Context(), chi.URLParam(r, "projectID"), emails)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"imported": count})
|
||||
}
|
||||
|
||||
func (h *Handler) MicrosoftMigrationAdminConsentURL(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
consentURL, err := h.svc.migration.MicrosoftAdminConsentURL(
|
||||
r.URL.Query().Get("tenant"),
|
||||
r.URL.Query().Get("project_id"),
|
||||
)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "admin_consent_unavailable", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"url": consentURL})
|
||||
}
|
||||
|
||||
func (h *Handler) ListMicrosoftAdminConsents(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
rows, err := h.svc.migration.ListMicrosoftAdminConsents(r.Context())
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"consents": rows})
|
||||
}
|
||||
|
||||
func (h *Handler) ListMigrationProjectJobs(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
rows, err := h.svc.migration.ListProjectJobs(r.Context(), chi.URLParam(r, "projectID"))
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"jobs": rows})
|
||||
}
|
||||
|
||||
func (h *Handler) RetryMigrationJob(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
row, err := h.svc.migration.RetryJob(r.Context(), chi.URLParam(r, "projectID"), chi.URLParam(r, "jobID"))
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusNotFound, "migration_job_not_retryable", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, row)
|
||||
}
|
||||
|
||||
func (h *Handler) ResetMigrationJobCursor(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
row, err := h.svc.migration.ResetJobCursor(r.Context(), chi.URLParam(r, "projectID"), chi.URLParam(r, "jobID"))
|
||||
if err != nil {
|
||||
status := http.StatusNotFound
|
||||
code := "migration_job_not_resettable"
|
||||
if strings.Contains(err.Error(), "running") {
|
||||
status = http.StatusConflict
|
||||
} else if strings.Contains(err.Error(), "not found") {
|
||||
code = "migration_job_not_found"
|
||||
}
|
||||
apiresponse.WriteError(w, r, status, code, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, row)
|
||||
}
|
||||
|
||||
func (h *Handler) RetryMigrationFailedJobs(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
count, err := h.svc.migration.RetryFailedJobs(r.Context(), chi.URLParam(r, "projectID"))
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"retried": count})
|
||||
}
|
||||
|
||||
func (h *Handler) ListMigrationJobAudit(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
params, err := query.ParseListRequest(r)
|
||||
if err != nil {
|
||||
apivalidate.WriteQueryError(w, r, err)
|
||||
return
|
||||
}
|
||||
items, pagination, err := h.svc.migration.ListJobAudit(
|
||||
r.Context(),
|
||||
chi.URLParam(r, "projectID"),
|
||||
chi.URLParam(r, "jobID"),
|
||||
r.URL.Query().Get("status"),
|
||||
params,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
apiresponse.WriteError(w, r, http.StatusNotFound, "migration_job_not_found", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"items": items,
|
||||
"pagination": pagination,
|
||||
})
|
||||
}
|
||||
|
||||
func (h *Handler) MigrationJobAuditSummary(w http.ResponseWriter, r *http.Request) {
|
||||
if h.svc.migration == nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
summary, err := h.svc.migration.JobAuditSummary(
|
||||
r.Context(),
|
||||
chi.URLParam(r, "projectID"),
|
||||
chi.URLParam(r, "jobID"),
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not found") {
|
||||
apiresponse.WriteError(w, r, http.StatusNotFound, "migration_job_not_found", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, summary)
|
||||
}
|
||||
@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
migr "github.com/ultisuite/ulti-backend/internal/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/permission"
|
||||
"github.com/ultisuite/ulti-backend/internal/securityaudit"
|
||||
@ -27,11 +29,13 @@ import (
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
type Service struct {
|
||||
db *pgxpool.Pool
|
||||
audit *securityaudit.Logger
|
||||
cfg *config.Config
|
||||
nc *nextcloud.Client
|
||||
logger *slog.Logger
|
||||
db *pgxpool.Pool
|
||||
audit *securityaudit.Logger
|
||||
cfg *config.Config
|
||||
nc *nextcloud.Client
|
||||
hosted *hosted.Service
|
||||
migration *migr.Service
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewService(db *pgxpool.Pool, audit *securityaudit.Logger, cfg *config.Config, nc *nextcloud.Client) *Service {
|
||||
@ -44,6 +48,14 @@ func NewService(db *pgxpool.Pool, audit *securityaudit.Logger, cfg *config.Confi
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetHostedService(hostedSvc *hosted.Service) {
|
||||
s.hosted = hostedSvc
|
||||
}
|
||||
|
||||
func (s *Service) SetMigrationService(migrationSvc *migr.Service) {
|
||||
s.migration = migrationSvc
|
||||
}
|
||||
|
||||
type UsersList struct {
|
||||
Users []map[string]any `json:"users"`
|
||||
Pagination query.PaginationMeta `json:"pagination,omitempty"`
|
||||
|
||||
@ -95,6 +95,7 @@ func (h *Handler) Routes() chi.Router {
|
||||
r.Post("/accounts/{accountID}/test", h.TestStoredAccountConnection)
|
||||
r.Get("/accounts/oauth/providers", h.ListOAuthProviders)
|
||||
r.Post("/accounts/oauth/start", h.StartOAuthAccount)
|
||||
r.Get("/addresses/check", h.CheckAddressAvailability)
|
||||
r.Get("/accounts/{accountID}", h.GetAccount)
|
||||
r.Put("/accounts/{accountID}", h.UpdateAccount)
|
||||
r.Delete("/accounts/{accountID}", h.DeleteAccount)
|
||||
|
||||
48
internal/api/mail/handlers_hosted.go
Normal file
48
internal/api/mail/handlers_hosted.go
Normal file
@ -0,0 +1,48 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
)
|
||||
|
||||
func (h *Handler) SetHostedService(svc *hosted.Service) {
|
||||
if s, ok := h.svc.(*Service); ok {
|
||||
s.SetHostedService(svc)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) CheckAddressAvailability(w http.ResponseWriter, r *http.Request) {
|
||||
svc := h.hostedService()
|
||||
if svc == nil {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"available": true, "reason": "hosted_mail_disabled"})
|
||||
return
|
||||
}
|
||||
local := strings.TrimSpace(r.URL.Query().Get("local"))
|
||||
domain := strings.TrimSpace(r.URL.Query().Get("domain"))
|
||||
if domain == "" {
|
||||
domain = strings.TrimSpace(r.URL.Query().Get("domain_name"))
|
||||
}
|
||||
if local == "" || domain == "" {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
|
||||
Field: "local", Message: "local and domain required",
|
||||
}))
|
||||
return
|
||||
}
|
||||
available, err := svc.IsAddressAvailable(r.Context(), domain, local)
|
||||
if err != nil {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"available": false, "reason": err.Error()})
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"available": available})
|
||||
}
|
||||
|
||||
func (h *Handler) hostedService() *hosted.Service {
|
||||
if s, ok := h.svc.(*Service); ok {
|
||||
return s.HostedService()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/imap"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/storage"
|
||||
@ -40,6 +41,7 @@ type Service struct {
|
||||
attachmentsBucket string
|
||||
driveUploader DriveUploader
|
||||
scanner *filescan.Scanner
|
||||
hosted *hosted.Service
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
@ -54,6 +56,14 @@ func NewService(db *pgxpool.Pool, audit *securityaudit.Logger, credentialManager
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) SetHostedService(svc *hosted.Service) {
|
||||
s.hosted = svc
|
||||
}
|
||||
|
||||
func (s *Service) HostedService() *hosted.Service {
|
||||
return s.hosted
|
||||
}
|
||||
|
||||
func (s *Service) DB() *pgxpool.Pool {
|
||||
return s.db
|
||||
}
|
||||
|
||||
257
internal/api/migration/handlers.go
Normal file
257
internal/api/migration/handlers.go
Normal file
@ -0,0 +1,257 @@
|
||||
package migrationapi
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
migr "github.com/ultisuite/ulti-backend/internal/migration"
|
||||
)
|
||||
|
||||
const maxMigrationRequestBody = 1 << 20
|
||||
|
||||
type Handler struct {
|
||||
svc *migr.Service
|
||||
oauth *migr.OAuthService
|
||||
appURL string
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewHandler(svc *migr.Service, oauth *migr.OAuthService, appURL string) *Handler {
|
||||
return &Handler{
|
||||
svc: svc,
|
||||
oauth: oauth,
|
||||
appURL: strings.TrimRight(strings.TrimSpace(appURL), "/"),
|
||||
logger: slog.Default().With("component", "migration-api"),
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) Routes() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
r.Post("/claim", h.ClaimInvite)
|
||||
r.Get("/status", h.GetStatus)
|
||||
r.Get("/oauth/providers", h.ListOAuthProviders)
|
||||
r.Post("/oauth/start", h.StartOAuth)
|
||||
return r
|
||||
}
|
||||
|
||||
func (h *Handler) OAuthCallback(w http.ResponseWriter, r *http.Request) {
|
||||
if handled, err := h.handleMicrosoftAdminConsentCallback(w, r); handled {
|
||||
if err != nil {
|
||||
h.logger.Error("microsoft admin consent", "error", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
state := r.URL.Query().Get("state")
|
||||
code := r.URL.Query().Get("code")
|
||||
if state == "" || code == "" {
|
||||
http.Redirect(w, r, h.appURL+"/onboard/migration?oauth=error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
pending, token, scopes, err := h.oauth.Exchange(r.Context(), state, code)
|
||||
if err != nil {
|
||||
h.logger.Error("oauth exchange", "error", err)
|
||||
http.Redirect(w, r, h.appURL+"/onboard/migration?oauth=error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
if err := h.svc.StoreMigrationToken(r.Context(), pending.UserID, pending.ProjectID, pending.Provider, token, scopes); err != nil {
|
||||
h.logger.Error("store migration token", "error", err)
|
||||
http.Redirect(w, r, h.appURL+"/onboard/migration?oauth=error", http.StatusFound)
|
||||
return
|
||||
}
|
||||
redirect := h.appURL + "/onboard/migration?oauth=success"
|
||||
if pending.InviteToken != "" {
|
||||
redirect = h.appURL + "/onboard/migration?oauth=success&token=" + pending.InviteToken
|
||||
}
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
}
|
||||
|
||||
func (h *Handler) handleMicrosoftAdminConsentCallback(w http.ResponseWriter, r *http.Request) (bool, error) {
|
||||
q := r.URL.Query()
|
||||
projectID := migr.ParseAdminConsentProjectID(q.Get("state"))
|
||||
granted := strings.EqualFold(q.Get("admin_consent"), "True")
|
||||
oauthErr := strings.TrimSpace(q.Get("error"))
|
||||
if !granted && (projectID == "" || oauthErr == "") {
|
||||
return false, nil
|
||||
}
|
||||
if h.oauth == nil {
|
||||
return true, migr.ErrProviderDisabled
|
||||
}
|
||||
record := migr.MicrosoftAdminConsentRecord{
|
||||
TenantID: q.Get("tenant"),
|
||||
ClientID: h.oauth.MicrosoftClientID(),
|
||||
ProjectID: projectID,
|
||||
Granted: granted,
|
||||
ErrorCode: oauthErr,
|
||||
ErrorDescription: q.Get("error_description"),
|
||||
}
|
||||
if err := h.svc.RecordMicrosoftAdminConsent(r.Context(), record); err != nil {
|
||||
redirect := h.appURL + "/admin/settings/mail-domains?microsoft_admin_consent=error"
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
return true, err
|
||||
}
|
||||
redirect := h.appURL + "/admin/settings/mail-domains?microsoft_admin_consent=success"
|
||||
if !granted {
|
||||
redirect = h.appURL + "/admin/settings/mail-domains?microsoft_admin_consent=error"
|
||||
}
|
||||
if tenant := strings.TrimSpace(record.TenantID); tenant != "" {
|
||||
redirect += "&tenant=" + url.QueryEscape(tenant)
|
||||
}
|
||||
if projectID != "" {
|
||||
redirect += "&project_id=" + url.QueryEscape(projectID)
|
||||
}
|
||||
http.Redirect(w, r, redirect, http.StatusFound)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (h *Handler) GetInvite(w http.ResponseWriter, r *http.Request) {
|
||||
token := strings.TrimSpace(r.URL.Query().Get("token"))
|
||||
if token == "" {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
|
||||
Field: "token", Message: "token required",
|
||||
}))
|
||||
return
|
||||
}
|
||||
inv, proj, err := h.svc.GetInviteByToken(r.Context(), token)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusNotFound, "invite_not_found", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"invite": inv,
|
||||
"project": proj,
|
||||
"onboarding": h.svc.BuildInviteOnboardingHints(proj, inv),
|
||||
})
|
||||
}
|
||||
|
||||
type claimRequest struct {
|
||||
Token string `json:"token"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func (h *Handler) ClaimInvite(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
if claims == nil {
|
||||
apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
|
||||
return
|
||||
}
|
||||
var req claimRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxMigrationRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
status, err := h.svc.ClaimInvite(r.Context(), req.Token, userID, migr.ClaimIdentityFromAuth(claims), claims.Name, req.Password)
|
||||
if err != nil {
|
||||
code := http.StatusBadRequest
|
||||
errCode := "claim_failed"
|
||||
switch {
|
||||
case err == migr.ErrInviteNotFound:
|
||||
code = http.StatusNotFound
|
||||
errCode = "invite_not_found"
|
||||
case err == migr.ErrInviteClaimed:
|
||||
errCode = "invite_already_claimed"
|
||||
case err == migr.ErrEmailMismatch:
|
||||
errCode = "email_mismatch"
|
||||
case err == migr.ErrMigrationDomainNotActive:
|
||||
errCode = "migration_domain_not_active"
|
||||
case err == migr.ErrMigrationDomainMismatch:
|
||||
errCode = "migration_domain_mismatch"
|
||||
}
|
||||
apiresponse.WriteError(w, r, code, errCode, err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, status)
|
||||
}
|
||||
|
||||
func (h *Handler) GetStatus(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
if claims == nil {
|
||||
apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
|
||||
return
|
||||
}
|
||||
userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
status, err := h.svc.GetActiveUserStatus(r.Context(), userID)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, status)
|
||||
}
|
||||
|
||||
func (h *Handler) ListOAuthProviders(w http.ResponseWriter, r *http.Request) {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"providers": h.oauth.EnabledProviders()})
|
||||
}
|
||||
|
||||
type startOAuthRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
InviteToken string `json:"invite_token"`
|
||||
ProjectID string `json:"project_id"`
|
||||
}
|
||||
|
||||
func (h *Handler) StartOAuth(w http.ResponseWriter, r *http.Request) {
|
||||
claims := middleware.ClaimsFromContext(r.Context())
|
||||
if claims == nil {
|
||||
apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
|
||||
return
|
||||
}
|
||||
var req startOAuthRequest
|
||||
if err := apivalidate.DecodeJSON(w, r, maxMigrationRequestBody, &req); err != nil {
|
||||
return
|
||||
}
|
||||
userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
|
||||
if err != nil {
|
||||
apivalidate.WriteInternal(w, r)
|
||||
return
|
||||
}
|
||||
projectID := req.ProjectID
|
||||
if projectID == "" && req.InviteToken != "" {
|
||||
inv, proj, err := h.svc.GetInviteByToken(r.Context(), req.InviteToken)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusNotFound, "invite_not_found", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
if inv.UserID != "" && inv.UserID != userID {
|
||||
apiresponse.WriteError(w, r, http.StatusForbidden, "forbidden", "invite belongs to another user", nil)
|
||||
return
|
||||
}
|
||||
projectID = proj.ID
|
||||
}
|
||||
if projectID == "" {
|
||||
st, err := h.svc.GetActiveUserStatus(r.Context(), userID)
|
||||
if err != nil || st.Project.ID == "" {
|
||||
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
|
||||
Field: "project_id", Message: "project_id required",
|
||||
}))
|
||||
return
|
||||
}
|
||||
projectID = st.Project.ID
|
||||
}
|
||||
provider := migr.Provider(strings.ToLower(strings.TrimSpace(req.Provider)))
|
||||
if provider == "" {
|
||||
provider = migr.ProviderGoogle
|
||||
}
|
||||
authURL, _, err := h.oauth.Start(r.Context(), migr.PendingOAuth{
|
||||
UserID: userID,
|
||||
ProjectID: projectID,
|
||||
InviteToken: req.InviteToken,
|
||||
}, provider)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "oauth_start_failed", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"auth_url": authURL})
|
||||
}
|
||||
@ -12,14 +12,16 @@ import (
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
Sub string
|
||||
Email string
|
||||
Name string
|
||||
Groups []string
|
||||
Source string
|
||||
HD string
|
||||
TID string
|
||||
Org string
|
||||
Sub string
|
||||
Email string
|
||||
PreferredUsername string
|
||||
UPN string
|
||||
Name string
|
||||
Groups []string
|
||||
Source string
|
||||
HD string
|
||||
TID string
|
||||
Org string
|
||||
}
|
||||
|
||||
type Verifier struct {
|
||||
@ -96,27 +98,31 @@ func (v *Verifier) Verify(ctx context.Context, rawToken string) (*Claims, error)
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Groups []string `json:"groups"`
|
||||
HD string `json:"hd"`
|
||||
TID string `json:"tid"`
|
||||
Org string `json:"org"`
|
||||
Source string `json:"ak-source"`
|
||||
Sub string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
UPN string `json:"upn"`
|
||||
Name string `json:"name"`
|
||||
Groups []string `json:"groups"`
|
||||
HD string `json:"hd"`
|
||||
TID string `json:"tid"`
|
||||
Org string `json:"org"`
|
||||
Source string `json:"ak-source"`
|
||||
}
|
||||
if err := token.Claims(&claims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Claims{
|
||||
Sub: claims.Sub,
|
||||
Email: claims.Email,
|
||||
Name: claims.Name,
|
||||
Groups: claims.Groups,
|
||||
HD: claims.HD,
|
||||
TID: claims.TID,
|
||||
Org: claims.Org,
|
||||
Source: claims.Source,
|
||||
Sub: claims.Sub,
|
||||
Email: claims.Email,
|
||||
PreferredUsername: claims.PreferredUsername,
|
||||
UPN: claims.UPN,
|
||||
Name: claims.Name,
|
||||
Groups: claims.Groups,
|
||||
HD: claims.HD,
|
||||
TID: claims.TID,
|
||||
Org: claims.Org,
|
||||
Source: claims.Source,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -115,6 +115,38 @@ type Config struct {
|
||||
MailOAuthRedirectURL string
|
||||
MailAppURL string
|
||||
|
||||
// Stalwart hosted mail
|
||||
StalwartEnabled bool
|
||||
StalwartAPIURL string
|
||||
StalwartAPIKey string
|
||||
StalwartIMAPHost string
|
||||
StalwartIMAPPort int
|
||||
StalwartIMAPTLS bool
|
||||
StalwartSMTPHost string
|
||||
StalwartSMTPPort int
|
||||
StalwartSMTPTLS bool
|
||||
PlatformMailDomain string
|
||||
ProvisionWebhookSecret string
|
||||
|
||||
// Migration OAuth (Google/Microsoft bulk import)
|
||||
MigrationGoogleOAuthClientID string
|
||||
MigrationGoogleOAuthClientSecret string
|
||||
MigrationMicrosoftOAuthClientID string
|
||||
MigrationMicrosoftOAuthSecret string
|
||||
MigrationMicrosoftOAuthTenant string
|
||||
MigrationOAuthRedirectURL string
|
||||
MigrationWorkerInterval time.Duration
|
||||
MigrationGoogleServiceAccountJSON string
|
||||
MigrationRateLimitMaxRetries int
|
||||
MigrationRateLimitBaseDelay time.Duration
|
||||
MigrationRateLimitMaxDelay time.Duration
|
||||
MigrationWorkerConcurrency int
|
||||
MigrationWorkerJobLimit int
|
||||
MigrationImportBatchSize int
|
||||
MigrationDriveBatchSize int
|
||||
MigrationCutoverMXHosts string
|
||||
MigrationCutoverRequireMX bool
|
||||
|
||||
// Secret rotation policy
|
||||
SecretRotationMaxAge time.Duration
|
||||
OIDCSecretRotatedAt time.Time
|
||||
@ -236,6 +268,36 @@ func Load() (*Config, error) {
|
||||
MailOAuthRedirectURL: os.Getenv("MAIL_OAUTH_REDIRECT_URL"),
|
||||
MailAppURL: envOrDefault("MAIL_APP_URL", envOrDefault("NEXT_PUBLIC_APP_URL", "http://localhost:3004")),
|
||||
|
||||
StalwartEnabled: envBool("STALWART_ENABLED", false),
|
||||
StalwartAPIURL: envOrDefault("STALWART_API_URL", "http://stalwart:8080"),
|
||||
StalwartAPIKey: secrets.Env("STALWART_API_KEY"),
|
||||
StalwartIMAPHost: envOrDefault("STALWART_IMAP_HOST", "stalwart"),
|
||||
StalwartIMAPPort: envInt("STALWART_IMAP_PORT", 993),
|
||||
StalwartIMAPTLS: envBool("STALWART_IMAP_TLS", true),
|
||||
StalwartSMTPHost: envOrDefault("STALWART_SMTP_HOST", "stalwart"),
|
||||
StalwartSMTPPort: envInt("STALWART_SMTP_PORT", 587),
|
||||
StalwartSMTPTLS: envBool("STALWART_SMTP_TLS", true),
|
||||
PlatformMailDomain: envOrDefault("PLATFORM_MAIL_DOMAIN", "ultisuite.fr"),
|
||||
ProvisionWebhookSecret: secrets.Env("PROVISION_WEBHOOK_SECRET"),
|
||||
|
||||
MigrationGoogleOAuthClientID: os.Getenv("MIGRATION_GOOGLE_OAUTH_CLIENT_ID"),
|
||||
MigrationGoogleOAuthClientSecret: secrets.Env("MIGRATION_GOOGLE_OAUTH_CLIENT_SECRET"),
|
||||
MigrationMicrosoftOAuthClientID: os.Getenv("MIGRATION_MICROSOFT_OAUTH_CLIENT_ID"),
|
||||
MigrationMicrosoftOAuthSecret: secrets.Env("MIGRATION_MICROSOFT_OAUTH_CLIENT_SECRET"),
|
||||
MigrationMicrosoftOAuthTenant: envOrDefault("MIGRATION_MICROSOFT_OAUTH_TENANT", "common"),
|
||||
MigrationOAuthRedirectURL: os.Getenv("MIGRATION_OAUTH_REDIRECT_URL"),
|
||||
MigrationWorkerInterval: envDuration("MIGRATION_WORKER_INTERVAL", 30*time.Second),
|
||||
MigrationGoogleServiceAccountJSON: secrets.Env("MIGRATION_GOOGLE_SERVICE_ACCOUNT_JSON"),
|
||||
MigrationRateLimitMaxRetries: envInt("MIGRATION_RATE_LIMIT_MAX_RETRIES", 6),
|
||||
MigrationRateLimitBaseDelay: envDuration("MIGRATION_RATE_LIMIT_BASE_DELAY", 2*time.Second),
|
||||
MigrationRateLimitMaxDelay: envDuration("MIGRATION_RATE_LIMIT_MAX_DELAY", 2*time.Minute),
|
||||
MigrationWorkerConcurrency: envInt("MIGRATION_WORKER_CONCURRENCY", 2),
|
||||
MigrationWorkerJobLimit: envInt("MIGRATION_WORKER_JOB_LIMIT", 0),
|
||||
MigrationImportBatchSize: envInt("MIGRATION_IMPORT_BATCH_SIZE", 25),
|
||||
MigrationDriveBatchSize: envInt("MIGRATION_DRIVE_BATCH_SIZE", 10),
|
||||
MigrationCutoverMXHosts: os.Getenv("MIGRATION_CUTOVER_MX_HOSTS"),
|
||||
MigrationCutoverRequireMX: envBool("MIGRATION_CUTOVER_REQUIRE_MX", false),
|
||||
|
||||
SecretRotationMaxAge: envDuration("SECRET_ROTATION_MAX_AGE", 90*24*time.Hour),
|
||||
OIDCSecretRotatedAt: envTime("ULTID_OIDC_CLIENT_SECRET_ROTATED_AT"),
|
||||
SMTPCredentialKeyRotatedAt: envTime("MAIL_CREDENTIAL_KEY_ROTATED_AT"),
|
||||
|
||||
@ -5,6 +5,7 @@ package integrationtest
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
@ -13,23 +14,28 @@ import (
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/minio/minio-go/v7"
|
||||
"github.com/minio/minio-go/v7/pkg/credentials"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/dbmigrate"
|
||||
"github.com/ultisuite/ulti-backend/internal/server"
|
||||
mailstorage "github.com/ultisuite/ulti-backend/internal/mail/storage"
|
||||
)
|
||||
|
||||
// Harness is the shared integration test environment.
|
||||
type Harness struct {
|
||||
Env Env
|
||||
Infra *infra
|
||||
OIDC *OIDCServer
|
||||
App *server.App
|
||||
Server *httptest.Server
|
||||
Pool *pgxpool.Pool
|
||||
Redis *redis.Client
|
||||
Env Env
|
||||
Infra *infra
|
||||
OIDC *OIDCServer
|
||||
App *server.App
|
||||
Server *httptest.Server
|
||||
Pool *pgxpool.Pool
|
||||
Redis *redis.Client
|
||||
AttachmentStorage *mailstorage.Client
|
||||
AttachmentsBucket string
|
||||
}
|
||||
|
||||
var (
|
||||
@ -128,6 +134,20 @@ func newHarness(ctx context.Context) (*Harness, error) {
|
||||
|
||||
cfg := buildTestConfig(env, infra, oidc)
|
||||
|
||||
minioClient, err := minio.New(cfg.RustFSEndpoint, &minio.Options{
|
||||
Creds: credentials.NewStaticV4(cfg.RustFSAccessKey, cfg.RustFSSecretKey, ""),
|
||||
Secure: cfg.RustFSUseSSL,
|
||||
})
|
||||
if err != nil {
|
||||
pool.Close()
|
||||
_ = rdb.Close()
|
||||
return nil, fmt.Errorf("rustfs client: %w", err)
|
||||
}
|
||||
attachmentStorage := mailstorage.NewClient(minioClient, cfg.MailAttachmentsBucket)
|
||||
if err := attachmentStorage.EnsureBucket(ctx); err != nil {
|
||||
slog.Warn("mail attachments bucket check failed", "error", err)
|
||||
}
|
||||
|
||||
app, err := server.New(ctx, cfg, server.Options{
|
||||
WithoutWorkers: true,
|
||||
SkipAuthentikProvisioner: true,
|
||||
@ -144,13 +164,15 @@ func newHarness(ctx context.Context) (*Harness, error) {
|
||||
ts := httptest.NewServer(app.Router)
|
||||
|
||||
return &Harness{
|
||||
Env: env,
|
||||
Infra: infra,
|
||||
OIDC: oidc,
|
||||
App: app,
|
||||
Server: ts,
|
||||
Pool: pool,
|
||||
Redis: rdb,
|
||||
Env: env,
|
||||
Infra: infra,
|
||||
OIDC: oidc,
|
||||
App: app,
|
||||
Server: ts,
|
||||
Pool: pool,
|
||||
Redis: rdb,
|
||||
AttachmentStorage: attachmentStorage,
|
||||
AttachmentsBucket: cfg.MailAttachmentsBucket,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
129
internal/integrationtest/migration/claim_email_test.go
Normal file
129
internal/integrationtest/migration/claim_email_test.go
Normal file
@ -0,0 +1,129 @@
|
||||
//go:build integration
|
||||
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/integrationtest"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
func TestClaimInviteFlexibleEmailMatch(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
|
||||
t.Fatalf("grant admin: %v", err)
|
||||
}
|
||||
|
||||
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
|
||||
"name": "Flexible email match",
|
||||
"source_provider": "microsoft",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create project")
|
||||
integrationtest.FailUnlessStatus(t, createResp, 201)
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, createResp, &created)
|
||||
|
||||
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
|
||||
integrationtest.FailIf(err, t, "activate project")
|
||||
integrationtest.FailUnlessStatus(t, actResp, 200)
|
||||
|
||||
inviteEmail := "alice-" + uuid.NewString()[:8] + "@example.com"
|
||||
ssoEmail := "alice.sso-" + uuid.NewString()[:8] + "@example.com"
|
||||
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]any{
|
||||
"email": inviteEmail,
|
||||
"alternate_emails": []string{ssoEmail},
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create invite")
|
||||
integrationtest.FailUnlessStatus(t, inviteResp, 201)
|
||||
|
||||
var invite struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, inviteResp, &invite)
|
||||
|
||||
migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("flex-claim"))
|
||||
migrateeClaims.Email = ssoEmail
|
||||
migrateeClaims.PreferredUsername = inviteEmail
|
||||
migrateeClient, err := h.Client(migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "migratee client")
|
||||
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, migrateeClaims); err != nil {
|
||||
t.Fatalf("ensure migratee: %v", err)
|
||||
}
|
||||
|
||||
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
|
||||
"token": invite.Token,
|
||||
"password": "test-password-123",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "claim invite")
|
||||
integrationtest.FailUnlessStatus(t, claimResp, 200)
|
||||
}
|
||||
|
||||
func TestClaimInviteRejectsEmailMismatch(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
|
||||
t.Fatalf("grant admin: %v", err)
|
||||
}
|
||||
|
||||
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
|
||||
"name": "Reject mismatch",
|
||||
"source_provider": "google",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create project")
|
||||
integrationtest.FailUnlessStatus(t, createResp, 201)
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, createResp, &created)
|
||||
|
||||
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
|
||||
integrationtest.FailIf(err, t, "activate project")
|
||||
integrationtest.FailUnlessStatus(t, actResp, 200)
|
||||
|
||||
inviteEmail := "victim-" + uuid.NewString() + "@example.com"
|
||||
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
|
||||
"email": inviteEmail,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create invite")
|
||||
integrationtest.FailUnlessStatus(t, inviteResp, 201)
|
||||
|
||||
var invite struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, inviteResp, &invite)
|
||||
|
||||
attackerClaims := integrationtest.RegularUser(integrationtest.NewExternalID("flex-claim-bad"))
|
||||
attackerClaims.Email = "attacker-" + uuid.NewString() + "@example.com"
|
||||
attackerClient, err := h.Client(attackerClaims)
|
||||
integrationtest.FailIf(err, t, "attacker client")
|
||||
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, attackerClaims); err != nil {
|
||||
t.Fatalf("ensure attacker: %v", err)
|
||||
}
|
||||
|
||||
claimResp, err := attackerClient.Post("/api/v1/migration/claim", map[string]string{
|
||||
"token": invite.Token,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "claim invite")
|
||||
integrationtest.AssertErrorCode(t, claimResp, 400, "email_mismatch")
|
||||
}
|
||||
562
internal/integrationtest/migration/delta_test.go
Normal file
562
internal/integrationtest/migration/delta_test.go
Normal file
@ -0,0 +1,562 @@
|
||||
//go:build integration
|
||||
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/integrationtest"
|
||||
migr "github.com/ultisuite/ulti-backend/internal/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
func TestMigrationCutoverResetsCompletedJobs(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
|
||||
t.Fatalf("grant admin: %v", err)
|
||||
}
|
||||
|
||||
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
|
||||
"name": "Cutover test",
|
||||
"source_provider": "google",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create project")
|
||||
integrationtest.FailUnlessStatus(t, createResp, 201)
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, createResp, &created)
|
||||
|
||||
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
|
||||
integrationtest.FailIf(err, t, "activate project")
|
||||
integrationtest.FailUnlessStatus(t, actResp, 200)
|
||||
|
||||
migrateeEmail := "cutover-" + created.ID[:8] + "@example.com"
|
||||
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
|
||||
"email": migrateeEmail,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create invite")
|
||||
integrationtest.FailUnlessStatus(t, inviteResp, 201)
|
||||
|
||||
var invite struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, inviteResp, &invite)
|
||||
|
||||
migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("cutover"))
|
||||
migrateeClaims.Email = migrateeEmail
|
||||
migrateeClient, err := h.Client(migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "migratee client")
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "ensure migratee")
|
||||
|
||||
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
|
||||
"token": invite.Token,
|
||||
"password": "test-password-123",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "claim invite")
|
||||
integrationtest.FailUnlessStatus(t, claimResp, 200)
|
||||
|
||||
_, err = h.Pool.Exec(ctx, `
|
||||
UPDATE migration_jobs SET status = 'completed', updated_at = NOW()
|
||||
WHERE project_id = $1::uuid AND user_id = $2::uuid
|
||||
`, created.ID, userID)
|
||||
integrationtest.FailIf(err, t, "mark jobs completed")
|
||||
|
||||
cutoverResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/cutover", nil)
|
||||
integrationtest.FailIf(err, t, "cutover")
|
||||
integrationtest.FailUnlessStatus(t, cutoverResp, 200)
|
||||
|
||||
var cutover struct {
|
||||
Project struct {
|
||||
Status string `json:"status"`
|
||||
DeltaMode bool `json:"delta_mode"`
|
||||
CutoverAt *string `json:"cutover_at"`
|
||||
} `json:"project"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, cutoverResp, &cutover)
|
||||
project := cutover.Project
|
||||
if project.Status != "cutover" || !project.DeltaMode || project.CutoverAt == nil {
|
||||
t.Fatalf("cutover project: %#v", project)
|
||||
}
|
||||
|
||||
var pendingCount int
|
||||
if err := h.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM migration_jobs
|
||||
WHERE project_id = $1::uuid AND user_id = $2::uuid AND status = 'pending'
|
||||
`, created.ID, userID).Scan(&pendingCount); err != nil {
|
||||
t.Fatalf("count pending jobs: %v", err)
|
||||
}
|
||||
if pendingCount != 4 {
|
||||
t.Fatalf("pending jobs = %d, want 4", pendingCount)
|
||||
}
|
||||
|
||||
listResp, err := adminClient.Get("/api/v1/admin/migration/projects")
|
||||
integrationtest.FailIf(err, t, "list projects after cutover")
|
||||
integrationtest.FailUnlessStatus(t, listResp, 200)
|
||||
|
||||
var listed struct {
|
||||
Projects []struct {
|
||||
ID string `json:"id"`
|
||||
CutoverDNS *struct {
|
||||
Warnings []string `json:"warnings"`
|
||||
} `json:"cutover_dns"`
|
||||
} `json:"projects"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, listResp, &listed)
|
||||
var found bool
|
||||
for _, p := range listed.Projects {
|
||||
if p.ID != created.ID {
|
||||
continue
|
||||
}
|
||||
found = true
|
||||
if p.CutoverDNS == nil {
|
||||
t.Fatal("expected cutover_dns on listed project")
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatal("cutover project not found in list")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleContactsDeltaDeletesRemoved(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, email := insertMigrationTestUser(t, h.Pool, "contacts-delta")
|
||||
nc, _ := mockNextcloudClient(t, h.Pool, email)
|
||||
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/v1/people/me/connections") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"connections":[{
|
||||
"resourceName":"people/deleted-1",
|
||||
"metadata":{"deleted":true}
|
||||
}],
|
||||
"nextSyncToken":"sync-next"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
importer := migr.NewContactsImporter(h.Pool, nc).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{
|
||||
"syncToken": "sync-old",
|
||||
"imported_ids": map[string]any{
|
||||
"people/deleted-1": true,
|
||||
},
|
||||
},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err := importer.ImportBatch(ctx, job, "token", "google", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
if status != "completed" {
|
||||
t.Fatalf("status = %q, want completed", status)
|
||||
}
|
||||
deleted, _ := stats["delta_deleted"].(float64)
|
||||
if deleted != 1 {
|
||||
t.Fatalf("delta_deleted = %v, want 1", stats["delta_deleted"])
|
||||
}
|
||||
if cursor["syncToken"] != "sync-next" {
|
||||
t.Fatalf("sync token = %v", cursor["syncToken"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
}
|
||||
|
||||
func TestGoogleCalendarDeltaUpdatesExisting(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, email := insertMigrationTestUser(t, h.Pool, "calendar-delta-update")
|
||||
nc, _ := mockNextcloudClient(t, h.Pool, email)
|
||||
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/calendar/v3/users/me/calendarList"):
|
||||
_, _ = w.Write([]byte(`{"items":[{"id":"primary","summary":"Primary"}]}`))
|
||||
case strings.Contains(r.URL.Path, "/calendar/v3/calendars/") && strings.Contains(r.URL.Path, "/events"):
|
||||
_, _ = w.Write([]byte(`{
|
||||
"items":[{"id":"evt-1","status":"confirmed","summary":"Updated meeting","start":{"dateTime":"2026-06-13T10:00:00Z"},"end":{"dateTime":"2026-06-13T11:00:00Z"}}],
|
||||
"nextSyncToken":"cal-sync-next"
|
||||
}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
importer := migr.NewCalendarImporter(h.Pool, nc).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{
|
||||
"calendarSyncTokens": map[string]any{"primary": "cal-sync-old"},
|
||||
"imported_ids": map[string]any{"primary:evt-1": true},
|
||||
},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err := importer.ImportBatch(ctx, job, "token", "google", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
updated, _ := stats["delta_updated"].(float64)
|
||||
if updated != 1 {
|
||||
t.Fatalf("delta_updated = %v, want 1", stats["delta_updated"])
|
||||
}
|
||||
imported, _ := stats["delta_imported"].(float64)
|
||||
if imported != 0 {
|
||||
t.Fatalf("delta_imported = %v, want 0", stats["delta_imported"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
}
|
||||
|
||||
func TestGoogleContactsDeltaUpdatesExisting(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, email := insertMigrationTestUser(t, h.Pool, "contacts-delta-update")
|
||||
nc, _ := mockNextcloudClient(t, h.Pool, email)
|
||||
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/v1/people/me/connections") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"connections":[{
|
||||
"resourceName":"people/abc",
|
||||
"names":[{"displayName":"Alice Updated"}],
|
||||
"emailAddresses":[{"value":"alice@example.com"}]
|
||||
}],
|
||||
"nextSyncToken":"sync-next"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
importer := migr.NewContactsImporter(h.Pool, nc).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{
|
||||
"syncToken": "sync-old",
|
||||
"imported_ids": map[string]any{
|
||||
"people/abc": true,
|
||||
},
|
||||
},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err := importer.ImportBatch(ctx, job, "token", "google", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
updated, _ := stats["delta_updated"].(float64)
|
||||
if updated != 1 {
|
||||
t.Fatalf("delta_updated = %v, want 1", stats["delta_updated"])
|
||||
}
|
||||
imported, _ := stats["delta_imported"].(float64)
|
||||
if imported != 0 {
|
||||
t.Fatalf("delta_imported = %v, want 0", stats["delta_imported"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
}
|
||||
|
||||
func TestGoogleDriveDeltaDeletesRemovedFile(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, email := insertMigrationTestUser(t, h.Pool, "drive-delta")
|
||||
nc, _ := mockNextcloudClient(t, h.Pool, email)
|
||||
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/drive/v3/changes") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"changes":[{"fileId":"file-removed","removed":true}],
|
||||
"newStartPageToken":"token-next"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
importer := migr.NewDriveImporter(h.Pool, nc).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{
|
||||
"driveChangeToken": "token-old",
|
||||
"imported_paths": map[string]any{
|
||||
"file-removed": "Docs/report.docx",
|
||||
},
|
||||
"imported_ids": map[string]any{
|
||||
"file-removed": true,
|
||||
},
|
||||
},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err := importer.ImportBatch(ctx, job, "token", "google", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
if status != "completed" {
|
||||
t.Fatalf("status = %q, want completed", status)
|
||||
}
|
||||
deleted, _ := stats["delta_deleted"].(float64)
|
||||
if deleted != 1 {
|
||||
t.Fatalf("delta_deleted = %v, want 1", stats["delta_deleted"])
|
||||
}
|
||||
if cursor["driveChangeToken"] != "token-next" {
|
||||
t.Fatalf("change token = %v", cursor["driveChangeToken"])
|
||||
}
|
||||
if _, ok := cursor["imported_ids"]; ok {
|
||||
t.Fatal("expected imported_ids stripped from cursor")
|
||||
}
|
||||
if _, ok := cursor["imported_paths"]; ok {
|
||||
t.Fatal("expected imported_paths stripped from cursor")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
}
|
||||
|
||||
func TestGoogleCalendarDeltaDeletesCancelled(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, email := insertMigrationTestUser(t, h.Pool, "calendar-delta")
|
||||
nc, _ := mockNextcloudClient(t, h.Pool, email)
|
||||
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/calendar/v3/users/me/calendarList"):
|
||||
_, _ = w.Write([]byte(`{"items":[{"id":"primary","summary":"Primary"}]}`))
|
||||
case strings.Contains(r.URL.Path, "/calendar/v3/calendars/") && strings.Contains(r.URL.Path, "/events"):
|
||||
_, _ = w.Write([]byte(`{
|
||||
"items":[{"id":"evt-1","status":"cancelled","summary":"Old meeting"}],
|
||||
"nextSyncToken":"cal-sync-next"
|
||||
}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
importer := migr.NewCalendarImporter(h.Pool, nc).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{
|
||||
"calendarSyncTokens": map[string]any{"primary": "cal-sync-old"},
|
||||
"imported_ids": map[string]any{"primary:evt-1": true},
|
||||
},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err := importer.ImportBatch(ctx, job, "token", "google", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
deleted, _ := stats["delta_deleted"].(float64)
|
||||
if deleted != 1 {
|
||||
t.Fatalf("delta_deleted = %v, want 1", stats["delta_deleted"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
}
|
||||
|
||||
func TestMicrosoftContactsDeltaRemoved(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, email := insertMigrationTestUser(t, h.Pool, "ms-contacts-delta")
|
||||
nc, _ := mockNextcloudClient(t, h.Pool, email)
|
||||
|
||||
client := graphRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/v1.0/me/contacts") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"value":[{"id":"c-1","@removed":{"reason":"deleted"}}],
|
||||
"@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/contacts/delta?token=next"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
importer := migr.NewContactsImporter(h.Pool, nc).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{
|
||||
"deltaLink": "https://graph.microsoft.com/v1.0/me/contacts/delta?token=old",
|
||||
"imported_ids": map[string]any{
|
||||
"c-1": true,
|
||||
},
|
||||
},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err := importer.ImportBatch(ctx, job, "token", "microsoft", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
deleted, _ := stats["delta_deleted"].(float64)
|
||||
if deleted != 1 {
|
||||
t.Fatalf("delta_deleted = %v, want 1", stats["delta_deleted"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
}
|
||||
|
||||
func TestGraphMailDeltaDeletesRemoved(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, integrationtest.RegularUser(integrationtest.NewExternalID("graph-delta-mail")))
|
||||
integrationtest.FailIf(err, t, "ensure user")
|
||||
|
||||
var accountID string
|
||||
err = h.Pool.QueryRow(ctx, `
|
||||
INSERT INTO mail_accounts (user_id, email, provider, is_active)
|
||||
VALUES ($1::uuid, 'graph-delta@test.local', 'hosted', true)
|
||||
RETURNING id::text
|
||||
`, userID).Scan(&accountID)
|
||||
integrationtest.FailIf(err, t, "insert mail account")
|
||||
|
||||
uid := migr.RemoteMessageUIDForTest("msg-removed-1")
|
||||
_, err = h.Pool.Exec(ctx, `
|
||||
INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, date, snippet, body_text, body_html, flags, labels)
|
||||
SELECT $1::uuid, f.id, $2, '<test@local>', 'To delete', '[]', '[]', NOW(), '', '', '', '{}', '{}'
|
||||
FROM mail_folders f WHERE f.account_id = $1::uuid AND f.remote_name = 'INBOX' LIMIT 1
|
||||
`, accountID, uid)
|
||||
integrationtest.FailIf(err, t, "seed message")
|
||||
|
||||
client := graphRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/mailFolders") {
|
||||
_, _ = w.Write([]byte(`{"value":[{"id":"inbox-id","displayName":"Inbox","wellKnownName":"inbox"}]}`))
|
||||
return
|
||||
}
|
||||
if strings.Contains(r.URL.Path, "/messages") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"value":[{"id":"msg-removed-1","@removed":{"reason":"deleted"}}],
|
||||
"@odata.deltaLink":"https://graph.microsoft.com/v1.0/me/messages/delta?token=next"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
importer := migr.NewGraphImporter(h.Pool).WithHTTPClient(client).WithBaseURL("https://graph.microsoft.com")
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{"deltaLink": "https://graph.microsoft.com/v1.0/me/messages/delta?token=old"},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err = importer.ImportBatch(ctx, job, "token", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
if status != "completed" {
|
||||
t.Fatalf("status = %q, want completed", status)
|
||||
}
|
||||
deleted, _ := stats["delta_deleted"].(float64)
|
||||
if deleted != 1 {
|
||||
t.Fatalf("delta_deleted = %v, want 1", stats["delta_deleted"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
|
||||
var count int
|
||||
if err := h.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM messages WHERE account_id = $1::uuid AND uid = $2`, accountID, uid).Scan(&count); err != nil {
|
||||
t.Fatalf("count messages: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("message count = %d, want 0", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailHistoryDeltaDeletesMessage(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, integrationtest.RegularUser(integrationtest.NewExternalID("gmail-delta-mail")))
|
||||
integrationtest.FailIf(err, t, "ensure user")
|
||||
|
||||
var accountID string
|
||||
err = h.Pool.QueryRow(ctx, `
|
||||
INSERT INTO mail_accounts (user_id, email, provider, is_active)
|
||||
VALUES ($1::uuid, 'gmail-delta@test.local', 'hosted', true)
|
||||
RETURNING id::text
|
||||
`, userID).Scan(&accountID)
|
||||
integrationtest.FailIf(err, t, "insert mail account")
|
||||
|
||||
gmailID := "abc123deleted"
|
||||
uid := migr.GmailUIDForTest(gmailID)
|
||||
_, err = h.Pool.Exec(ctx, `
|
||||
INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, date, snippet, body_text, body_html, flags, labels)
|
||||
SELECT $1::uuid, f.id, $2, '<test@local>', 'To delete', '[]', '[]', NOW(), '', '', '', '{}', '{}'
|
||||
FROM mail_folders f WHERE f.account_id = $1::uuid AND f.remote_name = 'INBOX' LIMIT 1
|
||||
`, accountID, uid)
|
||||
integrationtest.FailIf(err, t, "seed message")
|
||||
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/gmail/v1/users/me/history") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"history":[{"messagesDeleted":[{"message":{"id":"` + gmailID + `"}}]}],
|
||||
"historyId":"99999"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
importer := migr.NewGmailImporter(h.Pool).WithHTTPClient(client)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{"historyId": "88888"},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
|
||||
err = importer.ImportBatch(ctx, job, "token", true, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
deleted, _ := stats["delta_deleted"].(float64)
|
||||
if deleted != 1 {
|
||||
t.Fatalf("delta_deleted = %v, want 1", stats["delta_deleted"])
|
||||
}
|
||||
if cursor["historyId"] != "99999" {
|
||||
t.Fatalf("historyId = %v", cursor["historyId"])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
|
||||
var count int
|
||||
if err := h.Pool.QueryRow(ctx, `SELECT COUNT(*) FROM messages WHERE account_id = $1::uuid AND uid = $2`, accountID, uid).Scan(&count); err != nil {
|
||||
t.Fatalf("count messages: %v", err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatalf("message count = %d, want 0", count)
|
||||
}
|
||||
}
|
||||
159
internal/integrationtest/migration/domain_claim_test.go
Normal file
159
internal/integrationtest/migration/domain_claim_test.go
Normal file
@ -0,0 +1,159 @@
|
||||
//go:build integration
|
||||
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/integrationtest"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
func TestClaimInviteRequiresActiveProjectDomain(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
|
||||
t.Fatalf("grant admin: %v", err)
|
||||
}
|
||||
|
||||
domainName := "migration-" + uuid.NewString()[:8] + ".test"
|
||||
var domainID string
|
||||
err := h.Pool.QueryRow(ctx, `
|
||||
INSERT INTO mail_domains (name, status, is_platform_domain)
|
||||
VALUES ($1, 'pending_verification', false)
|
||||
RETURNING id::text
|
||||
`, domainName).Scan(&domainID)
|
||||
integrationtest.FailIf(err, t, "insert domain")
|
||||
|
||||
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
|
||||
"name": "Domain-bound migration",
|
||||
"source_provider": "google",
|
||||
"domain_id": domainID,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create project")
|
||||
integrationtest.FailUnlessStatus(t, createResp, 201)
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, createResp, &created)
|
||||
|
||||
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
|
||||
integrationtest.FailIf(err, t, "activate project")
|
||||
integrationtest.FailUnlessStatus(t, actResp, 200)
|
||||
|
||||
migrateeEmail := "user@" + domainName
|
||||
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
|
||||
"email": migrateeEmail,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create invite")
|
||||
integrationtest.FailUnlessStatus(t, inviteResp, 201)
|
||||
|
||||
var invite struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, inviteResp, &invite)
|
||||
|
||||
migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("domain-claim"))
|
||||
migrateeClaims.Email = migrateeEmail
|
||||
migrateeClient, err := h.Client(migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "migratee client")
|
||||
|
||||
_, err = users.EnsureUser(ctx, h.Pool, migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "ensure migratee")
|
||||
|
||||
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
|
||||
"token": invite.Token,
|
||||
"password": "test-password-123",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "claim invite")
|
||||
if claimResp.Status != 400 {
|
||||
t.Fatalf("status = %d, want 400 for inactive domain; body=%s", claimResp.Status, string(claimResp.Body))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaimInviteWithActiveDomainSucceeds(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
|
||||
t.Fatalf("grant admin: %v", err)
|
||||
}
|
||||
|
||||
domainName := "active-" + uuid.NewString()[:8] + ".test"
|
||||
var domainID string
|
||||
err := h.Pool.QueryRow(ctx, `
|
||||
INSERT INTO mail_domains (name, status, is_platform_domain)
|
||||
VALUES ($1, 'active', false)
|
||||
RETURNING id::text
|
||||
`, domainName).Scan(&domainID)
|
||||
integrationtest.FailIf(err, t, "insert domain")
|
||||
|
||||
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
|
||||
"name": "Active domain migration",
|
||||
"source_provider": "google",
|
||||
"domain_id": domainID,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create project")
|
||||
integrationtest.FailUnlessStatus(t, createResp, 201)
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, createResp, &created)
|
||||
|
||||
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
|
||||
integrationtest.FailIf(err, t, "activate project")
|
||||
integrationtest.FailUnlessStatus(t, actResp, 200)
|
||||
|
||||
migrateeEmail := "user@" + domainName
|
||||
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
|
||||
"email": migrateeEmail,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create invite")
|
||||
integrationtest.FailUnlessStatus(t, inviteResp, 201)
|
||||
|
||||
var invite struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, inviteResp, &invite)
|
||||
|
||||
migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("domain-claim-ok"))
|
||||
migrateeClaims.Email = migrateeEmail
|
||||
migrateeClient, err := h.Client(migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "migratee client")
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "ensure migratee")
|
||||
|
||||
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
|
||||
"token": invite.Token,
|
||||
"password": "test-password-123",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "claim invite")
|
||||
integrationtest.FailUnlessStatus(t, claimResp, 200)
|
||||
|
||||
var mailboxCount int
|
||||
if err := h.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM mailboxes mb
|
||||
JOIN mail_domains md ON md.id = mb.domain_id
|
||||
WHERE mb.user_id = $1::uuid AND md.id = $2::uuid
|
||||
`, userID, domainID).Scan(&mailboxCount); err != nil {
|
||||
t.Fatalf("count mailboxes: %v", err)
|
||||
}
|
||||
if mailboxCount != 1 {
|
||||
t.Fatalf("mailbox count = %d, want 1", mailboxCount)
|
||||
}
|
||||
}
|
||||
117
internal/integrationtest/migration/helpers_test.go
Normal file
117
internal/integrationtest/migration/helpers_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
//go:build integration
|
||||
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/integrationtest"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
func mockNextcloudServer(t *testing.T) *httptest.Server {
|
||||
t.Helper()
|
||||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodPut, "MKCOL", "MKCALENDAR":
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
case http.MethodDelete:
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
case "PROPFIND", "REPORT", http.MethodGet:
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`<?xml version="1.0"?><d:multistatus xmlns:d="DAV:"></d:multistatus>`))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
func testCredentialManager(t *testing.T) *credentials.Manager {
|
||||
t.Helper()
|
||||
mgr, err := credentials.NewManager(
|
||||
"v1:MDEyMzQ1Njc4OWFiY2RlZjAxMjM0NTY3ODlhYmNkZWY=",
|
||||
"v1",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("credential manager: %v", err)
|
||||
}
|
||||
return mgr
|
||||
}
|
||||
|
||||
func mockNextcloudClient(t *testing.T, pool *pgxpool.Pool, email string) (*nextcloud.Client, string) {
|
||||
t.Helper()
|
||||
srv := mockNextcloudServer(t)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
ncUserID := nextcloud.UserIDFromClaims(email, "")
|
||||
store := nextcloud.NewDAVCredentialStore(pool, testCredentialManager(t))
|
||||
if err := store.SaveToken(context.Background(), ncUserID, "mock-app-password"); err != nil {
|
||||
t.Fatalf("seed dav token: %v", err)
|
||||
}
|
||||
client := nextcloud.NewClient(srv.URL, "admin", "admin").WithDAVCredentials(store)
|
||||
return client, ncUserID
|
||||
}
|
||||
|
||||
func insertMigrationTestUser(t *testing.T, pool *pgxpool.Pool, prefix string) (userID, email string) {
|
||||
t.Helper()
|
||||
claims := integrationtest.RegularUser(integrationtest.NewExternalID(prefix))
|
||||
claims.Email = prefix + "-" + uuid.NewString() + "@migration.test"
|
||||
id, err := users.EnsureUser(context.Background(), pool, claims)
|
||||
integrationtest.FailIf(err, t, "ensure user")
|
||||
return id, claims.Email
|
||||
}
|
||||
|
||||
func googleRewriteClient(t *testing.T, handler http.HandlerFunc) *http.Client {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
return &http.Client{
|
||||
Transport: &hostRewriteTransport{
|
||||
mockBase: srv.URL,
|
||||
match: func(host string) bool {
|
||||
return strings.Contains(host, "googleapis.com")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func graphRewriteClient(t *testing.T, handler http.HandlerFunc) *http.Client {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
return &http.Client{
|
||||
Transport: &hostRewriteTransport{
|
||||
mockBase: srv.URL,
|
||||
match: func(host string) bool {
|
||||
return strings.Contains(host, "graph.microsoft.com")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type hostRewriteTransport struct {
|
||||
mockBase string
|
||||
match func(host string) bool
|
||||
}
|
||||
|
||||
func (rt *hostRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if rt.match(req.URL.Host) {
|
||||
mockURL, err := url.Parse(rt.mockBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.URL.Scheme = mockURL.Scheme
|
||||
req.URL.Host = mockURL.Host
|
||||
}
|
||||
return http.DefaultTransport.RoundTrip(req)
|
||||
}
|
||||
388
internal/integrationtest/migration/migration_test.go
Normal file
388
internal/integrationtest/migration/migration_test.go
Normal file
@ -0,0 +1,388 @@
|
||||
//go:build integration
|
||||
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/integrationtest"
|
||||
migr "github.com/ultisuite/ulti-backend/internal/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
func TestMigrationInviteClaimFlow(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
adminClient, adminClaims := integrationtest.RequireAdminClient(t, h)
|
||||
if _, err := users.EnsureUser(ctx, h.Pool, adminClaims); err != nil {
|
||||
t.Fatalf("ensure admin: %v", err)
|
||||
}
|
||||
if err := users.GrantPlatformAdmin(ctx, h.Pool, adminClaims.Sub); err != nil {
|
||||
t.Fatalf("grant admin: %v", err)
|
||||
}
|
||||
|
||||
createResp, err := adminClient.Post("/api/v1/admin/migration/projects", map[string]any{
|
||||
"name": "Test migration",
|
||||
"source_provider": "microsoft",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create project")
|
||||
integrationtest.FailUnlessStatus(t, createResp, 201)
|
||||
|
||||
var created struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, createResp, &created)
|
||||
if created.ID == "" {
|
||||
t.Fatalf("missing project id")
|
||||
}
|
||||
|
||||
actResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/activate", nil)
|
||||
integrationtest.FailIf(err, t, "activate project")
|
||||
integrationtest.FailUnlessStatus(t, actResp, 200)
|
||||
|
||||
migrateeEmail := "migratee-" + uuid.NewString() + "@example.com"
|
||||
inviteResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/invites", map[string]string{
|
||||
"email": migrateeEmail,
|
||||
})
|
||||
integrationtest.FailIf(err, t, "create invite")
|
||||
integrationtest.FailUnlessStatus(t, inviteResp, 201)
|
||||
|
||||
var invite struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, inviteResp, &invite)
|
||||
if invite.Token == "" {
|
||||
t.Fatalf("missing invite token")
|
||||
}
|
||||
|
||||
migrateeClaims := integrationtest.RegularUser(integrationtest.NewExternalID("migratee"))
|
||||
migrateeClaims.Email = migrateeEmail
|
||||
migrateeClient, err := h.Client(migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "migratee client")
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, migrateeClaims)
|
||||
integrationtest.FailIf(err, t, "ensure migratee")
|
||||
|
||||
claimResp, err := migrateeClient.Post("/api/v1/migration/claim", map[string]string{
|
||||
"token": invite.Token,
|
||||
"password": "test-password-123",
|
||||
})
|
||||
integrationtest.FailIf(err, t, "claim invite")
|
||||
integrationtest.FailUnlessStatus(t, claimResp, 200)
|
||||
|
||||
var status struct {
|
||||
Jobs []struct {
|
||||
Service string `json:"service"`
|
||||
Status string `json:"status"`
|
||||
} `json:"jobs"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, claimResp, &status)
|
||||
if len(status.Jobs) != 4 {
|
||||
t.Fatalf("expected 4 jobs, got %d", len(status.Jobs))
|
||||
}
|
||||
|
||||
var jobCount int
|
||||
if err := h.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM migration_jobs WHERE project_id = $1::uuid AND user_id = $2::uuid
|
||||
`, created.ID, userID).Scan(&jobCount); err != nil {
|
||||
t.Fatalf("count jobs: %v", err)
|
||||
}
|
||||
if jobCount != 4 {
|
||||
t.Fatalf("db job count = %d, want 4", jobCount)
|
||||
}
|
||||
|
||||
jobsResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs")
|
||||
integrationtest.FailIf(err, t, "list admin jobs")
|
||||
integrationtest.FailUnlessStatus(t, jobsResp, 200)
|
||||
|
||||
var adminJobs struct {
|
||||
Jobs []struct {
|
||||
ID string `json:"id"`
|
||||
Service string `json:"service"`
|
||||
Status string `json:"status"`
|
||||
Email string `json:"user_email"`
|
||||
} `json:"jobs"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, jobsResp, &adminJobs)
|
||||
if len(adminJobs.Jobs) != 4 {
|
||||
t.Fatalf("admin jobs = %d, want 4", len(adminJobs.Jobs))
|
||||
}
|
||||
for _, job := range adminJobs.Jobs {
|
||||
if job.Email != migrateeEmail {
|
||||
t.Fatalf("user_email = %q, want %q", job.Email, migrateeEmail)
|
||||
}
|
||||
}
|
||||
|
||||
var mailJobID string
|
||||
for _, job := range adminJobs.Jobs {
|
||||
if job.Service == "mail" {
|
||||
mailJobID = job.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
if mailJobID == "" {
|
||||
t.Fatal("mail job not found")
|
||||
}
|
||||
|
||||
_, err = h.Pool.Exec(ctx, `
|
||||
UPDATE migration_jobs SET status = 'failed', error = 'simulated failure', updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
`, mailJobID)
|
||||
integrationtest.FailIf(err, t, "mark job failed")
|
||||
|
||||
retryResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/jobs/"+mailJobID+"/retry", nil)
|
||||
integrationtest.FailIf(err, t, "retry job")
|
||||
integrationtest.FailUnlessStatus(t, retryResp, 200)
|
||||
|
||||
var retried struct {
|
||||
Status string `json:"status"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, retryResp, &retried)
|
||||
if retried.Status != "pending" {
|
||||
t.Fatalf("retried status = %q, want pending", retried.Status)
|
||||
}
|
||||
|
||||
_, err = h.Pool.Exec(ctx, `
|
||||
UPDATE migration_jobs
|
||||
SET cursor_json = '{"historyId":"123"}'::jsonb,
|
||||
stats_json = '{"imported":42}'::jsonb,
|
||||
status = 'completed',
|
||||
updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
`, mailJobID)
|
||||
integrationtest.FailIf(err, t, "seed job cursor")
|
||||
|
||||
_, err = h.Pool.Exec(ctx, `
|
||||
INSERT INTO migration_imported_items (job_id, source_id, status, reason)
|
||||
VALUES ($1::uuid, 'msg-abc', 'imported', ''),
|
||||
($1::uuid, 'msg-fail', 'failed', 'upload timeout'),
|
||||
($1::uuid, 'msg-skip', 'skipped', 'file too large')
|
||||
`, mailJobID)
|
||||
integrationtest.FailIf(err, t, "seed imported items")
|
||||
|
||||
summaryResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs/" + mailJobID + "/audit/summary")
|
||||
integrationtest.FailIf(err, t, "audit summary")
|
||||
integrationtest.FailUnlessStatus(t, summaryResp, 200)
|
||||
var auditSummary struct {
|
||||
Imported int64 `json:"imported"`
|
||||
Failed int64 `json:"failed"`
|
||||
Skipped int64 `json:"skipped"`
|
||||
Total int64 `json:"total"`
|
||||
Service string `json:"service"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, summaryResp, &auditSummary)
|
||||
if auditSummary.Imported != 1 || auditSummary.Failed != 1 || auditSummary.Skipped != 1 || auditSummary.Total != 3 {
|
||||
t.Fatalf("audit summary = %+v, want 1 imported / 1 failed / 1 skipped", auditSummary)
|
||||
}
|
||||
if auditSummary.Service != "mail" {
|
||||
t.Fatalf("audit service = %q, want mail", auditSummary.Service)
|
||||
}
|
||||
|
||||
failedResp, err := adminClient.Get("/api/v1/admin/migration/projects/" + created.ID + "/jobs/" + mailJobID + "/audit?status=failed")
|
||||
integrationtest.FailIf(err, t, "audit failed list")
|
||||
integrationtest.FailUnlessStatus(t, failedResp, 200)
|
||||
var failedList struct {
|
||||
Items []struct {
|
||||
SourceID string `json:"source_id"`
|
||||
Status string `json:"status"`
|
||||
Reason string `json:"reason"`
|
||||
} `json:"items"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, failedResp, &failedList)
|
||||
if len(failedList.Items) != 1 || failedList.Items[0].SourceID != "msg-fail" {
|
||||
t.Fatalf("failed audit items = %+v", failedList.Items)
|
||||
}
|
||||
|
||||
resetResp, err := adminClient.Post("/api/v1/admin/migration/projects/"+created.ID+"/jobs/"+mailJobID+"/reset-cursor", nil)
|
||||
integrationtest.FailIf(err, t, "reset cursor")
|
||||
integrationtest.FailUnlessStatus(t, resetResp, 200)
|
||||
|
||||
var reset struct {
|
||||
Status string `json:"status"`
|
||||
CursorJSON map[string]any `json:"cursor_json"`
|
||||
StatsJSON map[string]any `json:"stats_json"`
|
||||
}
|
||||
integrationtest.DecodeJSON(t, resetResp, &reset)
|
||||
if reset.Status != "pending" {
|
||||
t.Fatalf("reset status = %q, want pending", reset.Status)
|
||||
}
|
||||
if len(reset.CursorJSON) != 0 {
|
||||
t.Fatalf("cursor not cleared: %#v", reset.CursorJSON)
|
||||
}
|
||||
if len(reset.StatsJSON) != 0 {
|
||||
t.Fatalf("stats not cleared: %#v", reset.StatsJSON)
|
||||
}
|
||||
|
||||
var importedCount int
|
||||
if err := h.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM migration_imported_items WHERE job_id = $1::uuid
|
||||
`, mailJobID).Scan(&importedCount); err != nil {
|
||||
t.Fatalf("count imported items: %v", err)
|
||||
}
|
||||
if importedCount != 0 {
|
||||
t.Fatalf("imported items = %d, want 0", importedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphImportWritesMessages(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
ctx := context.Background()
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, integrationtest.RegularUser(integrationtest.NewExternalID("graph-import")))
|
||||
integrationtest.FailIf(err, t, "ensure user")
|
||||
|
||||
var accountID string
|
||||
err = h.Pool.QueryRow(ctx, `
|
||||
INSERT INTO mail_accounts (user_id, email, provider, is_active)
|
||||
VALUES ($1::uuid, 'graph-import@test.local', 'hosted', true)
|
||||
RETURNING id::text
|
||||
`, userID).Scan(&accountID)
|
||||
integrationtest.FailIf(err, t, "insert mail account")
|
||||
|
||||
folderID := "inbox-folder-id"
|
||||
messagesListed := false
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/mailFolders"):
|
||||
_, _ = w.Write([]byte(`{"value":[{"id":"` + folderID + `","displayName":"Inbox","wellKnownName":"inbox"}]}`))
|
||||
case strings.Contains(r.URL.Path, "/messages"):
|
||||
messagesListed = true
|
||||
_, _ = w.Write([]byte(`{"value":[{
|
||||
"id":"msg-1",
|
||||
"subject":"Hello Graph",
|
||||
"bodyPreview":"Preview text",
|
||||
"body":{"contentType":"text","content":"Body text"},
|
||||
"from":{"emailAddress":{"name":"Alice","address":"alice@example.com"}},
|
||||
"toRecipients":[{"emailAddress":{"name":"Bob","address":"bob@example.com"}}],
|
||||
"receivedDateTime":"2024-05-01T10:00:00Z",
|
||||
"parentFolderId":"` + folderID + `",
|
||||
"isRead":true,
|
||||
"internetMessageId":"<graph-test@example.com>"
|
||||
}]}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
importer := migr.NewGraphImporter(h.Pool).WithBaseURL(srv.URL)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
err = importer.ImportBatch(ctx, job, "test-token", false, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
if status != "completed" {
|
||||
t.Fatalf("status = %q, want completed", status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
if !messagesListed {
|
||||
t.Fatal("graph messages endpoint not called")
|
||||
}
|
||||
|
||||
var count int
|
||||
if err := h.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM messages WHERE account_id = $1::uuid AND subject = 'Hello Graph'
|
||||
`, accountID).Scan(&count); err != nil {
|
||||
t.Fatalf("count messages: %v", err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatalf("message count = %d, want 1", count)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGmailImportStoresAttachments(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
if h.AttachmentStorage == nil {
|
||||
t.Skip("attachment storage unavailable")
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
userID, err := users.EnsureUser(ctx, h.Pool, integrationtest.RegularUser(integrationtest.NewExternalID("gmail-att-import")))
|
||||
integrationtest.FailIf(err, t, "ensure user")
|
||||
|
||||
var accountID string
|
||||
err = h.Pool.QueryRow(ctx, `
|
||||
INSERT INTO mail_accounts (user_id, email, provider, is_active)
|
||||
VALUES ($1::uuid, 'gmail-att@test.local', 'hosted', true)
|
||||
RETURNING id::text
|
||||
`, userID).Scan(&accountID)
|
||||
integrationtest.FailIf(err, t, "insert mail account")
|
||||
|
||||
gmailID := "msg-with-att"
|
||||
client := googleRewriteClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case strings.Contains(r.URL.Path, "/users/me/profile"):
|
||||
_, _ = w.Write([]byte(`{"historyId":"12345"}`))
|
||||
case strings.Contains(r.URL.Path, "/attachments/att-123"):
|
||||
_, _ = w.Write([]byte(`{"size":5,"data":"aGVsbG8="}`))
|
||||
case strings.Contains(r.URL.Path, "/messages/"+gmailID):
|
||||
_, _ = w.Write([]byte(`{
|
||||
"id":"` + gmailID + `",
|
||||
"threadId":"t1",
|
||||
"labelIds":["INBOX"],
|
||||
"snippet":"see attached",
|
||||
"payload":{
|
||||
"mimeType":"multipart/mixed",
|
||||
"headers":[{"name":"Subject","value":"With attachment"}],
|
||||
"parts":[
|
||||
{"mimeType":"text/plain","body":{"data":"dGV4dA=="}},
|
||||
{
|
||||
"mimeType":"application/pdf",
|
||||
"headers":[{"name":"Content-Disposition","value":"attachment; filename=\"report.pdf\""}],
|
||||
"body":{"attachmentId":"att-123","size":5}
|
||||
}
|
||||
]
|
||||
}
|
||||
}`))
|
||||
case strings.HasSuffix(r.URL.Path, "/messages"):
|
||||
_, _ = w.Write([]byte(`{"messages":[{"id":"` + gmailID + `"}]}`))
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
})
|
||||
|
||||
importer := migr.NewGmailImporter(h.Pool).
|
||||
WithHTTPClient(client).
|
||||
WithStorage(h.AttachmentStorage, h.AttachmentsBucket)
|
||||
job := &migr.Job{
|
||||
UserID: userID,
|
||||
CursorJSON: map[string]any{},
|
||||
StatsJSON: map[string]any{},
|
||||
}
|
||||
err = importer.ImportBatch(ctx, job, "token", false, func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
if jobErr != "" {
|
||||
t.Fatalf("import error: %s", jobErr)
|
||||
}
|
||||
if status != "completed" {
|
||||
t.Fatalf("status = %q, want completed", status)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
integrationtest.FailIf(err, t, "import batch")
|
||||
|
||||
var attCount int
|
||||
var hasAttachments bool
|
||||
err = h.Pool.QueryRow(ctx, `
|
||||
SELECT COUNT(*)::int, COALESCE(BOOL_OR(m.has_attachments), false)
|
||||
FROM attachments a
|
||||
JOIN messages m ON m.id = a.message_id
|
||||
WHERE m.account_id = $1::uuid AND a.filename = 'report.pdf'
|
||||
`, accountID).Scan(&attCount, &hasAttachments)
|
||||
integrationtest.FailIf(err, t, "count attachments")
|
||||
if attCount != 1 || !hasAttachments {
|
||||
t.Fatalf("attachments = %d has_attachments = %v", attCount, hasAttachments)
|
||||
}
|
||||
}
|
||||
14
internal/integrationtest/migration/suite_test.go
Normal file
14
internal/integrationtest/migration/suite_test.go
Normal file
@ -0,0 +1,14 @@
|
||||
//go:build integration
|
||||
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/integrationtest"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(integrationtest.RunMain(m))
|
||||
}
|
||||
@ -93,9 +93,11 @@ func (s *OIDCServer) IssueToken(claims *auth.Claims) (string, error) {
|
||||
Expiry: jwt.NewNumericDate(now.Add(time.Hour)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
}).Claims(map[string]any{
|
||||
"email": claims.Email,
|
||||
"name": claims.Name,
|
||||
"groups": claims.Groups,
|
||||
"email": claims.Email,
|
||||
"preferred_username": claims.PreferredUsername,
|
||||
"upn": claims.UPN,
|
||||
"name": claims.Name,
|
||||
"groups": claims.Groups,
|
||||
})
|
||||
return builder.Serialize()
|
||||
}
|
||||
|
||||
166
internal/mail/hosted/dns_verify.go
Normal file
166
internal/mail/hosted/dns_verify.go
Normal file
@ -0,0 +1,166 @@
|
||||
package hosted
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// DNSCheckReport summarizes live DNS checks for a hosted mail domain.
|
||||
type DNSCheckReport struct {
|
||||
Domain string `json:"domain"`
|
||||
TXTVerified bool `json:"txt_verified"`
|
||||
TXTRecords []string `json:"txt_records,omitempty"`
|
||||
TXTExpected string `json:"txt_expected,omitempty"`
|
||||
MXVerified bool `json:"mx_verified"`
|
||||
MXRecords []string `json:"mx_records"`
|
||||
ExpectedMX []string `json:"expected_mx"`
|
||||
Warnings []string `json:"warnings,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
}
|
||||
|
||||
func LookupDomainMX(ctx context.Context, domain string) ([]string, error) {
|
||||
domain = strings.ToLower(strings.TrimSpace(domain))
|
||||
if domain == "" {
|
||||
return nil, fmt.Errorf("domain required")
|
||||
}
|
||||
mxRecords, err := (&net.Resolver{}).LookupMX(ctx, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Slice(mxRecords, func(i, j int) bool {
|
||||
return mxRecords[i].Pref < mxRecords[j].Pref
|
||||
})
|
||||
out := make([]string, 0, len(mxRecords))
|
||||
for _, mx := range mxRecords {
|
||||
host := strings.TrimSuffix(strings.ToLower(mx.Host), ".")
|
||||
if host != "" {
|
||||
out = append(out, host)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func LookupDomainTXT(ctx context.Context, name string) ([]string, error) {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
if name == "" {
|
||||
return nil, fmt.Errorf("txt name required")
|
||||
}
|
||||
records, err := (&net.Resolver{}).LookupTXT(ctx, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]string, 0, len(records))
|
||||
for _, record := range records {
|
||||
record = strings.TrimSpace(record)
|
||||
if record != "" {
|
||||
out = append(out, record)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func MXMatchesExpected(mxHosts, expected []string) bool {
|
||||
if len(mxHosts) == 0 || len(expected) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, mx := range mxHosts {
|
||||
mx = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(mx)), ".")
|
||||
for _, want := range expected {
|
||||
want = strings.TrimSuffix(strings.ToLower(strings.TrimSpace(want)), ".")
|
||||
if want == "" {
|
||||
continue
|
||||
}
|
||||
if mx == want || strings.HasSuffix(mx, "."+want) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TXTContainsToken(records []string, token string) bool {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return false
|
||||
}
|
||||
for _, record := range records {
|
||||
if strings.TrimSpace(record) == token {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Service) CheckDomainDNS(ctx context.Context, domainID string, expectedMX []string) (DomainRow, DNSCheckReport, error) {
|
||||
row, err := s.GetDomain(ctx, domainID)
|
||||
if err != nil {
|
||||
return DomainRow{}, DNSCheckReport{}, err
|
||||
}
|
||||
report := DNSCheckReport{
|
||||
Domain: row.Name,
|
||||
ExpectedMX: append([]string(nil), expectedMX...),
|
||||
TXTExpected: strings.TrimSpace(row.VerificationToken),
|
||||
}
|
||||
|
||||
txtName := "_ultisuite-verify." + row.Name
|
||||
txtRecords, err := LookupDomainTXT(ctx, txtName)
|
||||
if err != nil {
|
||||
report.Errors = append(report.Errors, "txt lookup: "+err.Error())
|
||||
} else {
|
||||
report.TXTRecords = txtRecords
|
||||
report.TXTVerified = TXTContainsToken(txtRecords, row.VerificationToken)
|
||||
if !report.TXTVerified && row.TXTVerifiedAt != nil {
|
||||
report.Warnings = append(report.Warnings, "txt record not found but domain was previously verified")
|
||||
report.TXTVerified = true
|
||||
}
|
||||
}
|
||||
|
||||
mxRecords, err := LookupDomainMX(ctx, row.Name)
|
||||
if err != nil {
|
||||
report.Errors = append(report.Errors, "mx lookup: "+err.Error())
|
||||
} else {
|
||||
report.MXRecords = mxRecords
|
||||
report.MXVerified = MXMatchesExpected(mxRecords, expectedMX)
|
||||
if !report.MXVerified && row.MXVerifiedAt != nil && len(expectedMX) == 0 {
|
||||
report.MXVerified = len(mxRecords) > 0
|
||||
}
|
||||
}
|
||||
return row, report, nil
|
||||
}
|
||||
|
||||
func (s *Service) VerifyDomainTXTRecord(ctx context.Context, domainID string) (DomainRow, DNSCheckReport, error) {
|
||||
row, report, err := s.CheckDomainDNS(ctx, domainID, nil)
|
||||
if err != nil {
|
||||
return DomainRow{}, DNSCheckReport{}, err
|
||||
}
|
||||
if !report.TXTVerified {
|
||||
return row, report, fmt.Errorf("txt verification token not found at _ultisuite-verify.%s", row.Name)
|
||||
}
|
||||
updated, err := s.MarkDomainVerified(ctx, domainID)
|
||||
if err != nil {
|
||||
return row, report, err
|
||||
}
|
||||
return updated, report, nil
|
||||
}
|
||||
|
||||
func (s *Service) VerifyDomainMXRecord(ctx context.Context, domainID string, expectedMX []string) (DomainRow, DNSCheckReport, error) {
|
||||
row, report, err := s.CheckDomainDNS(ctx, domainID, expectedMX)
|
||||
if err != nil {
|
||||
return DomainRow{}, DNSCheckReport{}, err
|
||||
}
|
||||
if len(expectedMX) == 0 {
|
||||
report.Warnings = append(report.Warnings, "expected mx hosts not configured")
|
||||
return row, report, fmt.Errorf("expected mx hosts not configured")
|
||||
}
|
||||
if !report.MXVerified {
|
||||
return row, report, fmt.Errorf("mx records %v do not match expected %v", report.MXRecords, expectedMX)
|
||||
}
|
||||
updated, err := s.MarkDomainMXVerified(ctx, domainID)
|
||||
if err != nil {
|
||||
return row, report, err
|
||||
}
|
||||
return updated, report, nil
|
||||
}
|
||||
24
internal/mail/hosted/dns_verify_test.go
Normal file
24
internal/mail/hosted/dns_verify_test.go
Normal file
@ -0,0 +1,24 @@
|
||||
package hosted
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestMXMatchesExpected(t *testing.T) {
|
||||
if !MXMatchesExpected([]string{"mail.acme.com."}, []string{"mail.acme.com"}) {
|
||||
t.Fatal("exact mx match")
|
||||
}
|
||||
if !MXMatchesExpected([]string{"mx1.mail.ultisuite.fr"}, []string{"mail.ultisuite.fr"}) {
|
||||
t.Fatal("suffix mx match")
|
||||
}
|
||||
if MXMatchesExpected([]string{"aspmx.l.google.com"}, []string{"mail.ultisuite.fr"}) {
|
||||
t.Fatal("google mx should not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTXTContainsToken(t *testing.T) {
|
||||
if !TXTContainsToken([]string{"abc123", "other"}, "abc123") {
|
||||
t.Fatal("expected token match")
|
||||
}
|
||||
if TXTContainsToken([]string{"wrong"}, "abc123") {
|
||||
t.Fatal("unexpected token match")
|
||||
}
|
||||
}
|
||||
435
internal/mail/hosted/service.go
Normal file
435
internal/mail/hosted/service.go
Normal file
@ -0,0 +1,435 @@
|
||||
package hosted
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/stalwart"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAddressTaken = errors.New("mail address already taken")
|
||||
ErrInvalidLocalPart = errors.New("invalid local part")
|
||||
ErrDomainNotActive = errors.New("mail domain not active")
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
db *pgxpool.Pool
|
||||
stlw *stalwart.Client
|
||||
creds *credentials.Manager
|
||||
imapHost string
|
||||
imapPort int
|
||||
imapTLS bool
|
||||
smtpHost string
|
||||
smtpPort int
|
||||
smtpTLS bool
|
||||
}
|
||||
|
||||
func NewService(db *pgxpool.Pool, stlw *stalwart.Client, creds *credentials.Manager) *Service {
|
||||
s := &Service{db: db, stlw: stlw, creds: creds}
|
||||
if stlw != nil {
|
||||
s.imapHost, s.imapPort, s.imapTLS = stlw.IMAPEndpoint()
|
||||
s.smtpHost, s.smtpPort, s.smtpTLS = stlw.SMTPEndpoint()
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
type DomainRow struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Status string `json:"status"`
|
||||
VerificationToken string `json:"verification_token,omitempty"`
|
||||
DKIMSelector string `json:"dkim_selector,omitempty"`
|
||||
DKIMPublicKey string `json:"dkim_public_key,omitempty"`
|
||||
StalwartDomainID string `json:"stalwart_domain_id,omitempty"`
|
||||
IsPlatformDomain bool `json:"is_platform_domain"`
|
||||
MXVerifiedAt *string `json:"mx_verified_at,omitempty"`
|
||||
TXTVerifiedAt *string `json:"txt_verified_at,omitempty"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
}
|
||||
|
||||
type MailboxRow struct {
|
||||
ID string `json:"id"`
|
||||
DomainID string `json:"domain_id"`
|
||||
LocalPart string `json:"local_part"`
|
||||
Email string `json:"email"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
MailAccountID string `json:"mail_account_id,omitempty"`
|
||||
StalwartAccountID string `json:"stalwart_account_id,omitempty"`
|
||||
QuotaBytes int64 `json:"quota_bytes"`
|
||||
Status string `json:"status"`
|
||||
}
|
||||
|
||||
func normalizeLocalPart(v string) (string, error) {
|
||||
v = strings.ToLower(strings.TrimSpace(v))
|
||||
if v == "" || len(v) > 64 {
|
||||
return "", ErrInvalidLocalPart
|
||||
}
|
||||
for _, ch := range v {
|
||||
if (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '.' || ch == '-' || ch == '_' || ch == '+' {
|
||||
continue
|
||||
}
|
||||
return "", ErrInvalidLocalPart
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (s *Service) IsAddressAvailable(ctx context.Context, domainName, localPart string) (bool, error) {
|
||||
localPart, err := normalizeLocalPart(localPart)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
domainName = strings.ToLower(strings.TrimSpace(domainName))
|
||||
var exists bool
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM mailboxes m
|
||||
JOIN mail_domains d ON d.id = m.domain_id
|
||||
WHERE d.name = $1 AND m.local_part = $2
|
||||
)
|
||||
`, domainName, localPart).Scan(&exists)
|
||||
return !exists, err
|
||||
}
|
||||
|
||||
func (s *Service) EnsurePlatformDomain(ctx context.Context, name string) (DomainRow, error) {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
if name == "" {
|
||||
return DomainRow{}, fmt.Errorf("platform domain name required")
|
||||
}
|
||||
var row DomainRow
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
FROM mail_domains WHERE name = $1
|
||||
`, name).Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
)
|
||||
if err == nil {
|
||||
return row, nil
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return DomainRow{}, err
|
||||
}
|
||||
token, err := randomToken(16)
|
||||
if err != nil {
|
||||
return DomainRow{}, err
|
||||
}
|
||||
stlwID := ""
|
||||
if s.stlw != nil && s.stlw.Enabled() {
|
||||
d, err := s.stlw.CreateDomain(ctx, name)
|
||||
if err != nil {
|
||||
return DomainRow{}, fmt.Errorf("stalwart create domain: %w", err)
|
||||
}
|
||||
stlwID = d.ID
|
||||
}
|
||||
status := "active"
|
||||
if !strings.HasSuffix(name, ".local") {
|
||||
status = "pending_verification"
|
||||
}
|
||||
err = s.db.QueryRow(ctx, `
|
||||
INSERT INTO mail_domains (name, status, verification_token, stalwart_domain_id, is_platform_domain, txt_verified_at)
|
||||
VALUES ($1, $2, $3, $4, true, CASE WHEN $2 = 'active' THEN NOW() ELSE NULL END)
|
||||
RETURNING id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
`, name, status, token, stlwID).Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
)
|
||||
return row, err
|
||||
}
|
||||
|
||||
func (s *Service) CreateDomain(ctx context.Context, name string, platform bool) (DomainRow, error) {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
if name == "" {
|
||||
return DomainRow{}, fmt.Errorf("domain name required")
|
||||
}
|
||||
token, err := randomToken(16)
|
||||
if err != nil {
|
||||
return DomainRow{}, err
|
||||
}
|
||||
stlwID := ""
|
||||
if s.stlw != nil && s.stlw.Enabled() {
|
||||
d, err := s.stlw.CreateDomain(ctx, name)
|
||||
if err != nil {
|
||||
return DomainRow{}, fmt.Errorf("stalwart create domain: %w", err)
|
||||
}
|
||||
stlwID = d.ID
|
||||
}
|
||||
var row DomainRow
|
||||
err = s.db.QueryRow(ctx, `
|
||||
INSERT INTO mail_domains (name, status, verification_token, stalwart_domain_id, is_platform_domain)
|
||||
VALUES ($1, 'pending_verification', $2, $3, $4)
|
||||
RETURNING id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
`, name, token, stlwID, platform).Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
)
|
||||
return row, err
|
||||
}
|
||||
|
||||
func (s *Service) ListDomains(ctx context.Context) ([]DomainRow, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
FROM mail_domains ORDER BY is_platform_domain DESC, name ASC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []DomainRow
|
||||
for rows.Next() {
|
||||
var row DomainRow
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) GetDomain(ctx context.Context, domainID string) (DomainRow, error) {
|
||||
var row DomainRow
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
FROM mail_domains WHERE id = $1
|
||||
`, domainID).Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
)
|
||||
return row, err
|
||||
}
|
||||
|
||||
func (s *Service) MarkDomainVerified(ctx context.Context, domainID string) (DomainRow, error) {
|
||||
var row DomainRow
|
||||
err := s.db.QueryRow(ctx, `
|
||||
UPDATE mail_domains
|
||||
SET status = 'active', txt_verified_at = COALESCE(txt_verified_at, NOW()), updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
`, domainID).Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
)
|
||||
return row, err
|
||||
}
|
||||
|
||||
func (s *Service) MarkDomainMXVerified(ctx context.Context, domainID string) (DomainRow, error) {
|
||||
var row DomainRow
|
||||
err := s.db.QueryRow(ctx, `
|
||||
UPDATE mail_domains
|
||||
SET mx_verified_at = NOW(), status = 'active', updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
`, domainID).Scan(
|
||||
&row.ID, &row.Name, &row.Status, &row.VerificationToken, &row.DKIMSelector, &row.DKIMPublicKey,
|
||||
&row.StalwartDomainID, &row.IsPlatformDomain, &row.MXVerifiedAt, &row.TXTVerifiedAt, &row.CreatedAt,
|
||||
)
|
||||
return row, err
|
||||
}
|
||||
|
||||
type ProvisionMailboxInput struct {
|
||||
UserID string
|
||||
Email string
|
||||
DisplayName string
|
||||
Password string
|
||||
QuotaBytes int64
|
||||
DomainID string
|
||||
}
|
||||
|
||||
type ProvisionMailboxResult struct {
|
||||
Mailbox MailboxRow
|
||||
MailAccountID string
|
||||
}
|
||||
|
||||
func (s *Service) ProvisionMailbox(ctx context.Context, in ProvisionMailboxInput) (ProvisionMailboxResult, error) {
|
||||
email := strings.ToLower(strings.TrimSpace(in.Email))
|
||||
at := strings.LastIndex(email, "@")
|
||||
if at <= 0 {
|
||||
return ProvisionMailboxResult{}, fmt.Errorf("invalid email")
|
||||
}
|
||||
localPart := email[:at]
|
||||
domainName := email[at+1:]
|
||||
localPart, err := normalizeLocalPart(localPart)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
|
||||
var domain DomainRow
|
||||
if strings.TrimSpace(in.DomainID) != "" {
|
||||
domain, err = s.GetDomain(ctx, in.DomainID)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
if !strings.EqualFold(domain.Name, domainName) {
|
||||
return ProvisionMailboxResult{}, fmt.Errorf("email domain %q does not match project domain %q", domainName, domain.Name)
|
||||
}
|
||||
} else {
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT id::text, name, status, verification_token, dkim_selector, dkim_public_key,
|
||||
stalwart_domain_id, is_platform_domain,
|
||||
mx_verified_at::text, txt_verified_at::text, created_at::text
|
||||
FROM mail_domains WHERE name = $1
|
||||
`, domainName).Scan(
|
||||
&domain.ID, &domain.Name, &domain.Status, &domain.VerificationToken, &domain.DKIMSelector, &domain.DKIMPublicKey,
|
||||
&domain.StalwartDomainID, &domain.IsPlatformDomain, &domain.MXVerifiedAt, &domain.TXTVerifiedAt, &domain.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
domain, err = s.EnsurePlatformDomain(ctx, domainName)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
} else {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
if domain.Status != "active" && !domain.IsPlatformDomain {
|
||||
return ProvisionMailboxResult{}, ErrDomainNotActive
|
||||
}
|
||||
|
||||
available, err := s.IsAddressAvailable(ctx, domainName, localPart)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
if !available {
|
||||
return ProvisionMailboxResult{}, ErrAddressTaken
|
||||
}
|
||||
|
||||
quota := in.QuotaBytes
|
||||
if quota <= 0 {
|
||||
quota = 5 * 1024 * 1024 * 1024
|
||||
}
|
||||
|
||||
stlwAccountID := ""
|
||||
if s.stlw != nil {
|
||||
acct, err := s.stlw.CreateAccount(ctx, domain.StalwartDomainID, localPart, in.Password, quota)
|
||||
if err != nil && !errors.Is(err, stalwart.ErrDisabled) {
|
||||
return ProvisionMailboxResult{}, fmt.Errorf("stalwart account: %w", err)
|
||||
}
|
||||
stlwAccountID = acct.ID
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
var mailboxID string
|
||||
err = tx.QueryRow(ctx, `
|
||||
INSERT INTO mailboxes (domain_id, local_part, user_id, stalwart_account_id, quota_bytes, status)
|
||||
VALUES ($1, $2, NULLIF($3, '')::uuid, $4, $5, 'active')
|
||||
RETURNING id::text
|
||||
`, domain.ID, localPart, in.UserID, stlwAccountID, quota).Scan(&mailboxID)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
|
||||
var mailAccountID string
|
||||
if in.UserID != "" {
|
||||
enc, err := s.encryptHostedCredential(email, in.Password)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
err = tx.QueryRow(ctx, `
|
||||
INSERT INTO mail_accounts (
|
||||
user_id, name, email, provider,
|
||||
imap_host, imap_port, imap_tls,
|
||||
smtp_host, smtp_port, smtp_tls,
|
||||
credentials, is_active
|
||||
)
|
||||
VALUES ($1, $2, $3, 'hosted', $4, $5, $6, $7, $8, $9, $10, true)
|
||||
RETURNING id::text
|
||||
`, in.UserID, in.DisplayName, email,
|
||||
s.imapHost, s.imapPort, s.imapTLS,
|
||||
s.smtpHost, s.smtpPort, s.smtpTLS,
|
||||
enc,
|
||||
).Scan(&mailAccountID)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE mailboxes SET user_id = $1::uuid, mail_account_id = $2::uuid, updated_at = NOW()
|
||||
WHERE id = $3::uuid
|
||||
`, in.UserID, mailAccountID, mailboxID)
|
||||
if err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return ProvisionMailboxResult{}, err
|
||||
}
|
||||
|
||||
return ProvisionMailboxResult{
|
||||
Mailbox: MailboxRow{
|
||||
ID: mailboxID,
|
||||
DomainID: domain.ID,
|
||||
LocalPart: localPart,
|
||||
Email: email,
|
||||
UserID: in.UserID,
|
||||
MailAccountID: mailAccountID,
|
||||
StalwartAccountID: stlwAccountID,
|
||||
QuotaBytes: quota,
|
||||
Status: "active",
|
||||
},
|
||||
MailAccountID: mailAccountID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) encryptHostedCredential(email, password string) ([]byte, error) {
|
||||
if s.creds == nil {
|
||||
return nil, fmt.Errorf("credential manager not configured")
|
||||
}
|
||||
return s.creds.EncryptCredential(credentials.Credential{
|
||||
AuthType: credentials.AuthPassword,
|
||||
Username: email,
|
||||
Password: password,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) LinkMailboxToUser(ctx context.Context, mailboxID, userID string) error {
|
||||
_, err := s.db.Exec(ctx, `
|
||||
UPDATE mailboxes SET user_id = $1::uuid, updated_at = NOW()
|
||||
WHERE id = $2::uuid AND (user_id IS NULL OR user_id = $1::uuid)
|
||||
`, userID, mailboxID)
|
||||
return err
|
||||
}
|
||||
|
||||
func randomToken(n int) (string, error) {
|
||||
b := make([]byte, n)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func NewInviteToken() (string, error) {
|
||||
return uuid.NewString(), nil
|
||||
}
|
||||
233
internal/mail/stalwart/client.go
Normal file
233
internal/mail/stalwart/client.go
Normal file
@ -0,0 +1,233 @@
|
||||
package stalwart
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrDisabled = errors.New("stalwart client disabled")
|
||||
ErrNotFound = errors.New("stalwart resource not found")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
BaseURL string
|
||||
APIKey string
|
||||
IMAPHost string
|
||||
IMAPPort int
|
||||
IMAPTLS bool
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPTLS bool
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
type Client struct {
|
||||
cfg Config
|
||||
}
|
||||
|
||||
func NewClient(cfg Config) *Client {
|
||||
if cfg.HTTPClient == nil {
|
||||
cfg.HTTPClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
if cfg.IMAPPort == 0 {
|
||||
cfg.IMAPPort = 993
|
||||
}
|
||||
if cfg.SMTPPort == 0 {
|
||||
cfg.SMTPPort = 587
|
||||
}
|
||||
return &Client{cfg: cfg}
|
||||
}
|
||||
|
||||
func (c *Client) Enabled() bool {
|
||||
return c != nil && c.cfg.Enabled && strings.TrimSpace(c.cfg.BaseURL) != ""
|
||||
}
|
||||
|
||||
func (c *Client) IMAPEndpoint() (host string, port int, tls bool) {
|
||||
return c.cfg.IMAPHost, c.cfg.IMAPPort, c.cfg.IMAPTLS
|
||||
}
|
||||
|
||||
func (c *Client) SMTPEndpoint() (host string, port int, tls bool) {
|
||||
return c.cfg.SMTPHost, c.cfg.SMTPPort, c.cfg.SMTPTLS
|
||||
}
|
||||
|
||||
type Domain struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
type Account struct {
|
||||
ID string
|
||||
Email string
|
||||
}
|
||||
|
||||
type jmapRequest struct {
|
||||
Using []string `json:"using"`
|
||||
MethodCalls []any `json:"methodCalls"`
|
||||
}
|
||||
|
||||
type jmapResponse struct {
|
||||
MethodResponses []json.RawMessage `json:"methodResponses"`
|
||||
}
|
||||
|
||||
func (c *Client) call(ctx context.Context, method string, args any, callID string) (json.RawMessage, error) {
|
||||
if !c.Enabled() {
|
||||
return nil, ErrDisabled
|
||||
}
|
||||
body, err := json.Marshal(jmapRequest{
|
||||
Using: []string{"urn:ietf:params:jmap:core", "urn:stalwart:jmap"},
|
||||
MethodCalls: []any{[]any{method, args, callID}},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
url := strings.TrimRight(c.cfg.BaseURL, "/") + "/api"
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
if c.cfg.APIKey != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+c.cfg.APIKey)
|
||||
}
|
||||
resp, err := c.cfg.HTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
raw, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if resp.StatusCode >= 400 {
|
||||
return nil, fmt.Errorf("stalwart api %s: %s", resp.Status, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
var envelope jmapResponse
|
||||
if err := json.Unmarshal(raw, &envelope); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(envelope.MethodResponses) == 0 {
|
||||
return nil, fmt.Errorf("stalwart: empty method response")
|
||||
}
|
||||
var parts []json.RawMessage
|
||||
if err := json.Unmarshal(envelope.MethodResponses[0], &parts); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(parts) < 2 {
|
||||
return nil, fmt.Errorf("stalwart: malformed method response")
|
||||
}
|
||||
var status string
|
||||
if err := json.Unmarshal(parts[0], &status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if status == "error" {
|
||||
return nil, fmt.Errorf("stalwart error: %s", string(parts[1]))
|
||||
}
|
||||
return parts[1], nil
|
||||
}
|
||||
|
||||
func (c *Client) CreateDomain(ctx context.Context, name string) (Domain, error) {
|
||||
name = strings.ToLower(strings.TrimSpace(name))
|
||||
if name == "" {
|
||||
return Domain{}, fmt.Errorf("domain name required")
|
||||
}
|
||||
if !c.Enabled() {
|
||||
return Domain{ID: "local-" + name, Name: name}, nil
|
||||
}
|
||||
raw, err := c.call(ctx, "x:Domain/set", map[string]any{
|
||||
"create": map[string]any{
|
||||
"d1": map[string]any{"name": name},
|
||||
},
|
||||
}, "c1")
|
||||
if err != nil {
|
||||
return Domain{}, err
|
||||
}
|
||||
var parsed struct {
|
||||
Created map[string]struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"created"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return Domain{}, err
|
||||
}
|
||||
for _, v := range parsed.Created {
|
||||
return Domain{ID: v.ID, Name: name}, nil
|
||||
}
|
||||
return Domain{}, fmt.Errorf("stalwart: domain not created")
|
||||
}
|
||||
|
||||
func (c *Client) CreateAccount(ctx context.Context, domainID, localPart, password string, quotaBytes int64) (Account, error) {
|
||||
localPart = strings.ToLower(strings.TrimSpace(localPart))
|
||||
if localPart == "" {
|
||||
return Account{}, fmt.Errorf("local part required")
|
||||
}
|
||||
if !c.Enabled() {
|
||||
return Account{ID: "local-" + localPart, Email: localPart + "@local"}, nil
|
||||
}
|
||||
fields := map[string]any{
|
||||
"name": localPart,
|
||||
"domainId": domainID,
|
||||
}
|
||||
if password != "" {
|
||||
fields["credentials"] = map[string]any{"password": password}
|
||||
}
|
||||
if quotaBytes > 0 {
|
||||
fields["quota"] = map[string]any{"maxDiskQuota": quotaBytes}
|
||||
}
|
||||
raw, err := c.call(ctx, "x:Account/set", map[string]any{
|
||||
"create": map[string]any{"a1": fields},
|
||||
}, "c1")
|
||||
if err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
var parsed struct {
|
||||
Created map[string]struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"created"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return Account{}, err
|
||||
}
|
||||
for _, v := range parsed.Created {
|
||||
return Account{ID: v.ID, Email: localPart}, nil
|
||||
}
|
||||
return Account{}, fmt.Errorf("stalwart: account not created")
|
||||
}
|
||||
|
||||
func (c *Client) SetAccountPassword(ctx context.Context, accountID, password string) error {
|
||||
if password == "" {
|
||||
return fmt.Errorf("password required")
|
||||
}
|
||||
if !c.Enabled() {
|
||||
return nil
|
||||
}
|
||||
_, err := c.call(ctx, "x:Account/set", map[string]any{
|
||||
"update": map[string]any{
|
||||
accountID: map[string]any{
|
||||
"credentials": map[string]any{"password": password},
|
||||
},
|
||||
},
|
||||
}, "c1")
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) DeleteAccount(ctx context.Context, accountID string) error {
|
||||
if accountID == "" {
|
||||
return nil
|
||||
}
|
||||
if !c.Enabled() {
|
||||
return nil
|
||||
}
|
||||
_, err := c.call(ctx, "x:Account/set", map[string]any{
|
||||
"destroy": []string{accountID},
|
||||
}, "c1")
|
||||
return err
|
||||
}
|
||||
28
internal/mail/stalwart/client_test.go
Normal file
28
internal/mail/stalwart/client_test.go
Normal file
@ -0,0 +1,28 @@
|
||||
package stalwart
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientDisabledCreateDomain(t *testing.T) {
|
||||
c := NewClient(Config{Enabled: false})
|
||||
d, err := c.CreateDomain(context.Background(), "example.com")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if d.Name != "example.com" {
|
||||
t.Fatalf("expected example.com, got %q", d.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientDisabledCreateAccount(t *testing.T) {
|
||||
c := NewClient(Config{Enabled: false})
|
||||
a, err := c.CreateAccount(context.Background(), "local-example.com", "alice", "secret", 1024)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if a.ID == "" {
|
||||
t.Fatal("expected local account id")
|
||||
}
|
||||
}
|
||||
161
internal/migration/admin_jobs.go
Normal file
161
internal/migration/admin_jobs.go
Normal file
@ -0,0 +1,161 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
// AdminJob is a migration job enriched with the invited user email for admin dashboards.
|
||||
type AdminJob struct {
|
||||
Job
|
||||
UserEmail string `json:"user_email"`
|
||||
}
|
||||
|
||||
func (s *Service) ListProjectJobs(ctx context.Context, projectID string) ([]AdminJob, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT j.id::text, j.project_id::text, j.user_id::text, j.service, j.status,
|
||||
j.cursor_json, j.stats_json, j.error, j.started_at::text, j.completed_at::text,
|
||||
COALESCE(i.email, '')
|
||||
FROM migration_jobs j
|
||||
LEFT JOIN migration_invites i
|
||||
ON i.project_id = j.project_id AND i.user_id = j.user_id
|
||||
WHERE j.project_id = $1::uuid
|
||||
ORDER BY COALESCE(i.email, ''), j.service ASC
|
||||
`, projectID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []AdminJob
|
||||
for rows.Next() {
|
||||
var row AdminJob
|
||||
var cursorRaw, statsRaw []byte
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
||||
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
||||
&row.UserEmail,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
||||
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
||||
if row.CursorJSON == nil {
|
||||
row.CursorJSON = map[string]any{}
|
||||
}
|
||||
if row.StatsJSON == nil {
|
||||
row.StatsJSON = map[string]any{}
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) RetryJob(ctx context.Context, projectID, jobID string) (Job, error) {
|
||||
var row Job
|
||||
var cursorRaw, statsRaw []byte
|
||||
err := s.db.QueryRow(ctx, `
|
||||
UPDATE migration_jobs
|
||||
SET status = 'pending', error = '', updated_at = NOW()
|
||||
WHERE id = $1::uuid AND project_id = $2::uuid AND status = 'failed'
|
||||
RETURNING id::text, project_id::text, user_id::text, service, status,
|
||||
cursor_json, stats_json, error, started_at::text, completed_at::text
|
||||
`, jobID, projectID).Scan(
|
||||
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
||||
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
||||
)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Job{}, fmt.Errorf("job not found or not retryable")
|
||||
}
|
||||
if err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
||||
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
||||
if row.CursorJSON == nil {
|
||||
row.CursorJSON = map[string]any{}
|
||||
}
|
||||
if row.StatsJSON == nil {
|
||||
row.StatsJSON = map[string]any{}
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
|
||||
func (s *Service) RetryFailedJobs(ctx context.Context, projectID string) (int64, error) {
|
||||
tag, err := s.db.Exec(ctx, `
|
||||
UPDATE migration_jobs
|
||||
SET status = 'pending', error = '', updated_at = NOW()
|
||||
WHERE project_id = $1::uuid AND status = 'failed'
|
||||
`, projectID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return tag.RowsAffected(), nil
|
||||
}
|
||||
|
||||
func (s *Service) ResetJobCursor(ctx context.Context, projectID, jobID string) (Job, error) {
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
var status string
|
||||
err = tx.QueryRow(ctx, `
|
||||
SELECT status FROM migration_jobs
|
||||
WHERE id = $1::uuid AND project_id = $2::uuid
|
||||
`, jobID, projectID).Scan(&status)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Job{}, fmt.Errorf("job not found")
|
||||
}
|
||||
if err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
if status == "running" {
|
||||
return Job{}, fmt.Errorf("job running; wait for completion before reset")
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(ctx, `
|
||||
DELETE FROM migration_imported_items WHERE job_id = $1::uuid
|
||||
`, jobID); err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
|
||||
var row Job
|
||||
var cursorRaw, statsRaw []byte
|
||||
err = tx.QueryRow(ctx, `
|
||||
UPDATE migration_jobs
|
||||
SET status = 'pending',
|
||||
cursor_json = '{}'::jsonb,
|
||||
stats_json = '{}'::jsonb,
|
||||
error = '',
|
||||
started_at = NULL,
|
||||
completed_at = NULL,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1::uuid AND project_id = $2::uuid
|
||||
RETURNING id::text, project_id::text, user_id::text, service, status,
|
||||
cursor_json, stats_json, error, started_at::text, completed_at::text
|
||||
`, jobID, projectID).Scan(
|
||||
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
||||
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return Job{}, err
|
||||
}
|
||||
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
||||
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
||||
if row.CursorJSON == nil {
|
||||
row.CursorJSON = map[string]any{}
|
||||
}
|
||||
if row.StatsJSON == nil {
|
||||
row.StatsJSON = map[string]any{}
|
||||
}
|
||||
return row, nil
|
||||
}
|
||||
36
internal/migration/auth_modes.go
Normal file
36
internal/migration/auth_modes.go
Normal file
@ -0,0 +1,36 @@
|
||||
package migration
|
||||
|
||||
import "strings"
|
||||
|
||||
const AuthModeOAuth = "oauth"
|
||||
const AuthModeGoogleDWD = "google_dwd"
|
||||
const AuthModeMicrosoftApp = "microsoft_app"
|
||||
|
||||
func NormalizeAuthMode(provider, authMode string) string {
|
||||
authMode = strings.ToLower(strings.TrimSpace(authMode))
|
||||
if authMode == "" {
|
||||
return AuthModeOAuth
|
||||
}
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
switch authMode {
|
||||
case AuthModeGoogleDWD:
|
||||
if provider == "google" {
|
||||
return AuthModeGoogleDWD
|
||||
}
|
||||
case AuthModeMicrosoftApp:
|
||||
if provider == "microsoft" {
|
||||
return AuthModeMicrosoftApp
|
||||
}
|
||||
}
|
||||
return AuthModeOAuth
|
||||
}
|
||||
|
||||
func UsesUserOAuth(provider, authMode string) bool {
|
||||
authMode = NormalizeAuthMode(provider, authMode)
|
||||
switch authMode {
|
||||
case AuthModeGoogleDWD, AuthModeMicrosoftApp:
|
||||
return false
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
46
internal/migration/batch_config.go
Normal file
46
internal/migration/batch_config.go
Normal file
@ -0,0 +1,46 @@
|
||||
package migration
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
defaultMailImportBatchSize = 25
|
||||
defaultDriveImportBatchSize = 10
|
||||
)
|
||||
|
||||
// ImportBatchConfig controls how many items each migration importer processes per worker tick.
|
||||
type ImportBatchConfig struct {
|
||||
Mail int
|
||||
Drive int
|
||||
}
|
||||
|
||||
var (
|
||||
importBatchMu sync.RWMutex
|
||||
importBatchConfig = ImportBatchConfig{
|
||||
Mail: defaultMailImportBatchSize,
|
||||
Drive: defaultDriveImportBatchSize,
|
||||
}
|
||||
)
|
||||
|
||||
// ConfigureImportBatch sets package-wide batch sizes for migration importers.
|
||||
func ConfigureImportBatch(cfg ImportBatchConfig) {
|
||||
importBatchMu.Lock()
|
||||
defer importBatchMu.Unlock()
|
||||
if cfg.Mail > 0 {
|
||||
importBatchConfig.Mail = cfg.Mail
|
||||
}
|
||||
if cfg.Drive > 0 {
|
||||
importBatchConfig.Drive = cfg.Drive
|
||||
}
|
||||
}
|
||||
|
||||
func mailImportBatchSize() int {
|
||||
importBatchMu.RLock()
|
||||
defer importBatchMu.RUnlock()
|
||||
return importBatchConfig.Mail
|
||||
}
|
||||
|
||||
func driveImportBatchSize() int {
|
||||
importBatchMu.RLock()
|
||||
defer importBatchMu.RUnlock()
|
||||
return importBatchConfig.Drive
|
||||
}
|
||||
25
internal/migration/batch_config_test.go
Normal file
25
internal/migration/batch_config_test.go
Normal file
@ -0,0 +1,25 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestConfigureImportBatch(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
ConfigureImportBatch(ImportBatchConfig{
|
||||
Mail: defaultMailImportBatchSize,
|
||||
Drive: defaultDriveImportBatchSize,
|
||||
})
|
||||
})
|
||||
|
||||
ConfigureImportBatch(ImportBatchConfig{Mail: 7, Drive: 3})
|
||||
if got := mailImportBatchSize(); got != 7 {
|
||||
t.Fatalf("mail batch = %d", got)
|
||||
}
|
||||
if got := driveImportBatchSize(); got != 3 {
|
||||
t.Fatalf("drive batch = %d", got)
|
||||
}
|
||||
|
||||
ConfigureImportBatch(ImportBatchConfig{Mail: 0, Drive: 0})
|
||||
if got := mailImportBatchSize(); got != 7 {
|
||||
t.Fatalf("mail batch unchanged = %d", got)
|
||||
}
|
||||
}
|
||||
70
internal/migration/calendar_delta_test.go
Normal file
70
internal/migration/calendar_delta_test.go
Normal file
@ -0,0 +1,70 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListGoogleCalendarCancelledEvent(t *testing.T) {
|
||||
client := mockGoogleHTTPClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/calendar/v3/calendars/") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"items":[{"id":"e1","status":"cancelled","summary":"gone"}],
|
||||
"nextSyncToken":"sync-1"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
c := NewCalendarImporter(nil, nil).WithHTTPClient(client)
|
||||
events, _, syncToken, err := c.listSourceEvents(
|
||||
context.Background(),
|
||||
"token",
|
||||
"google",
|
||||
sourceCalendar{ID: "primary"},
|
||||
"",
|
||||
"sync-old",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("list events: %v", err)
|
||||
}
|
||||
if len(events) != 1 || !events[0].Deleted {
|
||||
t.Fatalf("events: %#v", events)
|
||||
}
|
||||
if syncToken != "sync-1" {
|
||||
t.Fatalf("sync token = %q", syncToken)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListMicrosoftCalendarDeltaRemoved(t *testing.T) {
|
||||
client := mockGraphHTTPClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/events/delta") {
|
||||
_, _ = w.Write([]byte(`{
|
||||
"value":[{"id":"e1","@removed":{"reason":"deleted"}}],
|
||||
"@odata.deltaLink":"https://graph.microsoft.com/delta/next"
|
||||
}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
c := NewCalendarImporter(nil, nil).WithHTTPClient(client)
|
||||
events, next, err := c.listMicrosoftCalendarDelta(
|
||||
context.Background(),
|
||||
"token",
|
||||
"cal-1",
|
||||
"https://graph.microsoft.com/v1.0/me/calendars/cal-1/events/delta",
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("list delta: %v", err)
|
||||
}
|
||||
if len(events) != 1 || !events[0].Deleted {
|
||||
t.Fatalf("events: %#v", events)
|
||||
}
|
||||
if next == "" {
|
||||
t.Fatal("expected next cursor")
|
||||
}
|
||||
}
|
||||
542
internal/migration/calendar_import.go
Normal file
542
internal/migration/calendar_import.go
Normal file
@ -0,0 +1,542 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
|
||||
const migrationCalendarID = "migration-import"
|
||||
|
||||
type CalendarImporter struct {
|
||||
db *pgxpool.Pool
|
||||
nc *nextcloud.Client
|
||||
client *http.Client
|
||||
userUPN string
|
||||
}
|
||||
|
||||
func NewCalendarImporter(db *pgxpool.Pool, nc *nextcloud.Client) *CalendarImporter {
|
||||
return &CalendarImporter{db: db, nc: nc, client: migrationHTTPClient()}
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) WithUserPrincipal(upn string) *CalendarImporter {
|
||||
c.userUPN = strings.TrimSpace(upn)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) WithHTTPClient(client *http.Client) *CalendarImporter {
|
||||
if client != nil {
|
||||
c.client = client
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) ImportBatch(ctx context.Context, job *Job, accessToken, provider string, delta bool, update progressUpdater) error {
|
||||
if c.nc == nil {
|
||||
return fmt.Errorf("nextcloud required for calendar migration")
|
||||
}
|
||||
user, err := resolveMigrationUser(ctx, c.db, job.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ncUserID := nextcloud.UserIDFromClaims(user.Email, user.ExternalID)
|
||||
if _, err := c.nc.EnsurePrincipal(ctx, user.Email, user.ExternalID, user.Name); err != nil {
|
||||
return fmt.Errorf("nextcloud user: %w", err)
|
||||
}
|
||||
calPath, err := c.ensureMigrationCalendar(ctx, ncUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := LoadImportedItemStore(ctx, c.db, job.ID, job.CursorJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if delta && c.hasDeltaCursor(job, provider) {
|
||||
return c.importDelta(ctx, job, accessToken, provider, ncUserID, calPath, items, update)
|
||||
}
|
||||
|
||||
return c.importFull(ctx, job, accessToken, provider, ncUserID, calPath, delta, items, update)
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) hasDeltaCursor(job *Job, provider string) bool {
|
||||
if provider == "google" {
|
||||
return len(calendarSyncTokens(job.CursorJSON)) > 0
|
||||
}
|
||||
return len(calendarDeltaLinks(job.CursorJSON)) > 0
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) importFull(ctx context.Context, job *Job, accessToken, provider, ncUserID, calPath string, captureDelta bool, items *ImportedItemStore, update progressUpdater) error {
|
||||
imported, _ := job.StatsJSON["imported"].(float64)
|
||||
batch := 0
|
||||
|
||||
calIndex := int(jsonNumber(job.CursorJSON["calendarIndex"]))
|
||||
sourceCalendars, err := c.listSourceCalendars(ctx, accessToken, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(sourceCalendars) == 0 {
|
||||
job.StatsJSON["imported"] = imported
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
if calIndex >= len(sourceCalendars) {
|
||||
job.StatsJSON["imported"] = imported
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
sourceCal := sourceCalendars[calIndex]
|
||||
pageToken, _ := job.CursorJSON["pageToken"].(string)
|
||||
events, nextToken, syncToken, err := c.listSourceEvents(ctx, accessToken, provider, sourceCal, pageToken, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
for i := listIndex; i < len(events) && batch < mailImportBatchSize(); i++ {
|
||||
ev := events[i]
|
||||
if alreadyImported(items, ev.SourceID) {
|
||||
continue
|
||||
}
|
||||
if err := c.nc.CreateEvent(ctx, ncUserID, calPath, ev.ToNextcloudEvent(provider)); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, ev.SourceID, err.Error(), ""); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, ev.SourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
imported++
|
||||
batch++
|
||||
}
|
||||
job.StatsJSON["imported"] = imported
|
||||
|
||||
if listIndex+batch < len(events) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if nextToken != "" {
|
||||
job.CursorJSON["pageToken"] = nextToken
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "pageToken")
|
||||
|
||||
if captureDelta {
|
||||
if provider == "google" && syncToken != "" {
|
||||
setCalendarSyncToken(job.CursorJSON, sourceCal.ID, syncToken)
|
||||
}
|
||||
if provider != "google" {
|
||||
if link, err := c.bootstrapCalendarDelta(ctx, accessToken, sourceCal.ID); err == nil && link != "" {
|
||||
setCalendarDeltaLink(job.CursorJSON, sourceCal.ID, link)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
job.CursorJSON["calendarIndex"] = float64(calIndex + 1)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) importDelta(ctx context.Context, job *Job, accessToken, provider, ncUserID, calPath string, items *ImportedItemStore, update progressUpdater) error {
|
||||
calIndex := int(jsonNumber(job.CursorJSON["calendarIndex"]))
|
||||
sourceCalendars, err := c.listSourceCalendars(ctx, accessToken, provider)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if calIndex >= len(sourceCalendars) {
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
sourceCal := sourceCalendars[calIndex]
|
||||
imported, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
updated, _ := job.StatsJSON["delta_updated"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
|
||||
var events []sourceEvent
|
||||
var nextCursor string
|
||||
if provider == "google" {
|
||||
syncToken := calendarSyncTokens(job.CursorJSON)[sourceCal.ID]
|
||||
pageToken, _ := job.CursorJSON["pageToken"].(string)
|
||||
var syncTokenOut string
|
||||
events, nextCursor, syncTokenOut, err = c.listSourceEvents(ctx, accessToken, provider, sourceCal, pageToken, syncToken)
|
||||
if syncTokenOut != "" {
|
||||
setCalendarSyncToken(job.CursorJSON, sourceCal.ID, syncTokenOut)
|
||||
}
|
||||
} else {
|
||||
deltaLink := calendarDeltaLinks(job.CursorJSON)[sourceCal.ID]
|
||||
if deltaLink == "" {
|
||||
deltaLink, _ = job.CursorJSON["pageToken"].(string)
|
||||
}
|
||||
events, nextCursor, err = c.listMicrosoftCalendarDelta(ctx, accessToken, sourceCal.ID, deltaLink)
|
||||
if nextCursor != "" && strings.Contains(nextCursor, "delta") {
|
||||
setCalendarDeltaLink(job.CursorJSON, sourceCal.ID, nextCursor)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batch := 0
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
for i := listIndex; i < len(events) && batch < mailImportBatchSize(); i++ {
|
||||
ev := events[i]
|
||||
if ev.Deleted {
|
||||
eventPath := migrationEventPath(calPath, provider, ev.SourceID)
|
||||
if err := c.nc.DeleteEvent(ctx, ncUserID, eventPath); err != nil && !isDeleteNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if err := items.Unmark(ctx, ev.SourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
deleted++
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
wasUpdate, err := c.upsertEvent(ctx, ncUserID, calPath, provider, ev, items)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if wasUpdate {
|
||||
updated++
|
||||
} else if items.Has(ev.SourceID) {
|
||||
imported++
|
||||
} else {
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
}
|
||||
batch++
|
||||
}
|
||||
job.StatsJSON["delta_imported"] = imported
|
||||
job.StatsJSON["delta_updated"] = updated
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
|
||||
if listIndex+batch < len(events) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if nextCursor != "" {
|
||||
job.CursorJSON["pageToken"] = nextCursor
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "pageToken")
|
||||
job.CursorJSON["calendarIndex"] = float64(calIndex + 1)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) bootstrapCalendarDelta(ctx context.Context, accessToken, calID string) (string, error) {
|
||||
url := graphMicrosoftURL(c.userUPN, fmt.Sprintf("/calendars/%s/events/delta?$select=id,subject,body,start,end,isAllDay,location", url.PathEscape(calID)))
|
||||
body, err := apiGet(ctx, c.client, url, accessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parsed struct {
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed.DeltaLink != "" {
|
||||
return parsed.DeltaLink, nil
|
||||
}
|
||||
return parsed.NextLink, nil
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) listMicrosoftCalendarDelta(ctx context.Context, accessToken, calID, deltaLink string) ([]sourceEvent, string, error) {
|
||||
if deltaLink == "" {
|
||||
deltaLink = graphMicrosoftURL(c.userUPN, fmt.Sprintf("/calendars/%s/events/delta?$select=id,subject,body,start,end,isAllDay,location", url.PathEscape(calID)))
|
||||
}
|
||||
body, err := apiGet(ctx, c.client, deltaLink, accessToken)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []graphCalendarEvent `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
out := make([]sourceEvent, 0, len(parsed.Value))
|
||||
for _, item := range parsed.Value {
|
||||
out = append(out, item.toSourceEvent(calID))
|
||||
}
|
||||
next := parsed.NextLink
|
||||
if parsed.DeltaLink != "" && parsed.NextLink == "" {
|
||||
next = parsed.DeltaLink
|
||||
}
|
||||
return out, next, nil
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) ensureMigrationCalendar(ctx context.Context, ncUserID string) (string, error) {
|
||||
path := fmt.Sprintf("/remote.php/dav/calendars/%s/%s/", ncUserID, migrationCalendarID)
|
||||
if err := c.nc.CreateCalendar(ctx, ncUserID, migrationCalendarID, "Migration Import", "#1a73e8"); err != nil {
|
||||
msg := strings.ToLower(err.Error())
|
||||
if !strings.Contains(msg, "405") && !strings.Contains(msg, "409") && !strings.Contains(msg, "423") {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
type sourceCalendar struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
type sourceEvent struct {
|
||||
SourceID string
|
||||
Summary string
|
||||
Description string
|
||||
Location string
|
||||
Start time.Time
|
||||
End time.Time
|
||||
AllDay bool
|
||||
Deleted bool
|
||||
}
|
||||
|
||||
func (e sourceEvent) ToNextcloudEvent(provider string) *nextcloud.Event {
|
||||
uid := sanitizeMigrationUID(provider, e.SourceID)
|
||||
start := e.Start.UTC()
|
||||
end := e.End.UTC()
|
||||
if end.IsZero() || !end.After(start) {
|
||||
end = start.Add(time.Hour)
|
||||
}
|
||||
ev := &nextcloud.Event{
|
||||
UID: uid,
|
||||
Summary: e.Summary,
|
||||
Description: e.Description,
|
||||
Location: e.Location,
|
||||
AllDay: e.AllDay,
|
||||
}
|
||||
if e.AllDay {
|
||||
ev.Start = start.Format("20060102")
|
||||
ev.End = end.Format("20060102")
|
||||
} else {
|
||||
ev.Start = start.Format("20060102T150405Z")
|
||||
ev.End = end.Format("20060102T150405Z")
|
||||
}
|
||||
return ev
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) listSourceCalendars(ctx context.Context, accessToken, provider string) ([]sourceCalendar, error) {
|
||||
switch provider {
|
||||
case "google":
|
||||
body, err := apiGet(ctx, c.client, "https://www.googleapis.com/calendar/v3/users/me/calendarList?maxResults=100", accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var parsed struct {
|
||||
Items []struct {
|
||||
ID string `json:"id"`
|
||||
Summary string `json:"summary"`
|
||||
} `json:"items"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]sourceCalendar, 0, len(parsed.Items))
|
||||
for _, item := range parsed.Items {
|
||||
out = append(out, sourceCalendar{ID: item.ID, Name: item.Summary})
|
||||
}
|
||||
return out, nil
|
||||
default:
|
||||
body, err := apiGet(ctx, c.client, graphMicrosoftURL(c.userUPN, "/calendars?$top=100"), accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
} `json:"value"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make([]sourceCalendar, 0, len(parsed.Value))
|
||||
for _, item := range parsed.Value {
|
||||
out = append(out, sourceCalendar{ID: item.ID, Name: item.Name})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) listSourceEvents(ctx context.Context, accessToken, provider string, cal sourceCalendar, pageToken, syncToken string) ([]sourceEvent, string, string, error) {
|
||||
switch provider {
|
||||
case "google":
|
||||
listURL := fmt.Sprintf(
|
||||
"https://www.googleapis.com/calendar/v3/calendars/%s/events?maxResults=100&singleEvents=true&orderBy=startTime",
|
||||
url.PathEscape(cal.ID),
|
||||
)
|
||||
if syncToken != "" {
|
||||
listURL += "&syncToken=" + url.QueryEscape(syncToken) + "&showDeleted=true"
|
||||
} else if pageToken != "" {
|
||||
listURL += "&pageToken=" + url.QueryEscape(pageToken)
|
||||
}
|
||||
body, err := apiGet(ctx, c.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
var parsed struct {
|
||||
Items []googleCalendarEvent `json:"items"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
NextSyncToken string `json:"nextSyncToken"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
out := make([]sourceEvent, 0, len(parsed.Items))
|
||||
for _, item := range parsed.Items {
|
||||
out = append(out, item.toSourceEvent(cal.ID))
|
||||
}
|
||||
return out, parsed.NextPageToken, parsed.NextSyncToken, nil
|
||||
default:
|
||||
listURL := graphMicrosoftURL(c.userUPN, fmt.Sprintf("/calendars/%s/events?$top=100&$select=id,subject,body,start,end,isAllDay,location", url.PathEscape(cal.ID)))
|
||||
if pageToken != "" {
|
||||
listURL = pageToken
|
||||
}
|
||||
body, err := apiGet(ctx, c.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []graphCalendarEvent `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
out := make([]sourceEvent, 0, len(parsed.Value))
|
||||
for _, item := range parsed.Value {
|
||||
out = append(out, item.toSourceEvent(cal.ID))
|
||||
}
|
||||
return out, parsed.NextLink, "", nil
|
||||
}
|
||||
}
|
||||
|
||||
type googleCalendarEvent struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Summary string `json:"summary"`
|
||||
Description string `json:"description"`
|
||||
Location string `json:"location"`
|
||||
Start struct {
|
||||
DateTime string `json:"dateTime"`
|
||||
Date string `json:"date"`
|
||||
} `json:"start"`
|
||||
End struct {
|
||||
DateTime string `json:"dateTime"`
|
||||
Date string `json:"date"`
|
||||
} `json:"end"`
|
||||
}
|
||||
|
||||
func (e googleCalendarEvent) toSourceEvent(calID string) sourceEvent {
|
||||
allDay := e.Start.Date != ""
|
||||
start := parseFlexibleTime(e.Start.DateTime, e.Start.Date)
|
||||
end := parseFlexibleTime(e.End.DateTime, e.End.Date)
|
||||
return sourceEvent{
|
||||
SourceID: calID + ":" + e.ID,
|
||||
Summary: e.Summary,
|
||||
Description: e.Description,
|
||||
Location: e.Location,
|
||||
Start: start,
|
||||
End: end,
|
||||
AllDay: allDay,
|
||||
Deleted: e.Status == "cancelled",
|
||||
}
|
||||
}
|
||||
|
||||
type graphCalendarEvent struct {
|
||||
ID string `json:"id"`
|
||||
Removed *struct {
|
||||
Reason string `json:"reason"`
|
||||
} `json:"@removed"`
|
||||
Subject string `json:"subject"`
|
||||
Body struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"body"`
|
||||
IsAllDay bool `json:"isAllDay"`
|
||||
Location struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
} `json:"location"`
|
||||
Start graphDateTime `json:"start"`
|
||||
End graphDateTime `json:"end"`
|
||||
}
|
||||
|
||||
type graphDateTime struct {
|
||||
DateTime string `json:"dateTime"`
|
||||
Date string `json:"date"`
|
||||
}
|
||||
|
||||
func (e graphCalendarEvent) toSourceEvent(calID string) sourceEvent {
|
||||
allDay := e.IsAllDay || e.Start.Date != ""
|
||||
start := parseFlexibleTime(e.Start.DateTime, e.Start.Date)
|
||||
end := parseFlexibleTime(e.End.DateTime, e.End.Date)
|
||||
return sourceEvent{
|
||||
SourceID: calID + ":" + e.ID,
|
||||
Summary: e.Subject,
|
||||
Description: e.Body.Content,
|
||||
Location: e.Location.DisplayName,
|
||||
Start: start,
|
||||
End: end,
|
||||
AllDay: allDay,
|
||||
Deleted: e.Removed != nil,
|
||||
}
|
||||
}
|
||||
|
||||
func parseFlexibleTime(dateTime, date string) time.Time {
|
||||
if strings.TrimSpace(dateTime) != "" {
|
||||
if t, err := time.Parse(time.RFC3339, dateTime); err == nil {
|
||||
return t.UTC()
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(date) != "" {
|
||||
if t, err := time.Parse("2006-01-02", date); err == nil {
|
||||
return t.UTC()
|
||||
}
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (c *CalendarImporter) upsertEvent(
|
||||
ctx context.Context,
|
||||
ncUserID, calPath, provider string,
|
||||
ev sourceEvent,
|
||||
items *ImportedItemStore,
|
||||
) (updated bool, err error) {
|
||||
ncEv := ev.ToNextcloudEvent(provider)
|
||||
if alreadyImported(items, ev.SourceID) {
|
||||
eventPath := migrationEventPath(calPath, provider, ev.SourceID)
|
||||
if _, err := c.nc.UpdateEvent(ctx, ncUserID, eventPath, "", ncEv); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
if err := c.nc.CreateEvent(ctx, ncUserID, calPath, ncEv); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, ev.SourceID, err.Error(), ""); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err := items.MarkImported(ctx, ev.SourceID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
150
internal/migration/claim_email_match.go
Normal file
150
internal/migration/claim_email_match.go
Normal file
@ -0,0 +1,150 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
)
|
||||
|
||||
// ClaimIdentity holds SSO identity fields checked against a migration invite.
|
||||
type ClaimIdentity struct {
|
||||
Email string
|
||||
PreferredUsername string
|
||||
UPN string
|
||||
}
|
||||
|
||||
func ClaimIdentityFromAuth(c *auth.Claims) ClaimIdentity {
|
||||
if c == nil {
|
||||
return ClaimIdentity{}
|
||||
}
|
||||
return ClaimIdentity{
|
||||
Email: c.Email,
|
||||
PreferredUsername: c.PreferredUsername,
|
||||
UPN: c.UPN,
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeInviteEmail(email string) string {
|
||||
return strings.ToLower(strings.TrimSpace(email))
|
||||
}
|
||||
|
||||
func isEmailAddress(s string) bool {
|
||||
s = strings.TrimSpace(s)
|
||||
at := strings.LastIndex(s, "@")
|
||||
return at > 0 && at < len(s)-1
|
||||
}
|
||||
|
||||
func identityCandidateEmails(id ClaimIdentity) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
for _, raw := range []string{id.Email, id.PreferredUsername, id.UPN} {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" || !isEmailAddress(raw) {
|
||||
continue
|
||||
}
|
||||
norm := normalizeInviteEmail(raw)
|
||||
if _, ok := seen[norm]; ok {
|
||||
continue
|
||||
}
|
||||
seen[norm] = struct{}{}
|
||||
out = append(out, norm)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func normalizeEmailLocalPart(local string) string {
|
||||
local = strings.ToLower(strings.TrimSpace(local))
|
||||
if plus := strings.Index(local, "+"); plus >= 0 {
|
||||
local = local[:plus]
|
||||
}
|
||||
return strings.ReplaceAll(local, ".", "")
|
||||
}
|
||||
|
||||
func emailLocalAndDomain(email string) (local, domain string, ok bool) {
|
||||
email = normalizeInviteEmail(email)
|
||||
at := strings.LastIndex(email, "@")
|
||||
if at <= 0 || at == len(email)-1 {
|
||||
return "", "", false
|
||||
}
|
||||
return email[:at], email[at+1:], true
|
||||
}
|
||||
|
||||
func inviteMatchTargets(inviteEmail string, alternateEmails []string) []string {
|
||||
seen := make(map[string]struct{})
|
||||
var out []string
|
||||
add := func(e string) {
|
||||
e = normalizeInviteEmail(e)
|
||||
if e == "" || !isEmailAddress(e) {
|
||||
return
|
||||
}
|
||||
if _, ok := seen[e]; ok {
|
||||
return
|
||||
}
|
||||
seen[e] = struct{}{}
|
||||
out = append(out, e)
|
||||
}
|
||||
add(inviteEmail)
|
||||
for _, alt := range alternateEmails {
|
||||
add(alt)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func localPartAliasMatch(a, b string) bool {
|
||||
aLocal, aDomain, okA := emailLocalAndDomain(a)
|
||||
bLocal, bDomain, okB := emailLocalAndDomain(b)
|
||||
if !okA || !okB {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(aDomain, bDomain) {
|
||||
return false
|
||||
}
|
||||
return normalizeEmailLocalPart(aLocal) == normalizeEmailLocalPart(bLocal)
|
||||
}
|
||||
|
||||
// projectDomainUPNMatch accepts claims when the invite is on the hosted project domain
|
||||
// and a UPN-style identity (preferred_username / upn) shares the same mailbox local-part.
|
||||
// Typical Microsoft case: invite alice@acme.com, SSO preferred_username alice@tenant.onmicrosoft.com.
|
||||
func projectDomainUPNMatch(inviteEmail, projectDomain string, identity ClaimIdentity) bool {
|
||||
if projectDomain == "" {
|
||||
return false
|
||||
}
|
||||
projectDomain = strings.ToLower(strings.TrimSpace(projectDomain))
|
||||
invLocal, invDomain, ok := emailLocalAndDomain(inviteEmail)
|
||||
if !ok || !strings.EqualFold(invDomain, projectDomain) {
|
||||
return false
|
||||
}
|
||||
for _, raw := range []string{identity.PreferredUsername, identity.UPN} {
|
||||
candLocal, _, ok := emailLocalAndDomain(raw)
|
||||
if ok && strings.EqualFold(candLocal, invLocal) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// InviteEmailMatchesIdentity reports whether SSO identity may claim the invite.
|
||||
// projectDomain is the hosted mail domain when the migration project is domain-bound.
|
||||
func InviteEmailMatchesIdentity(inviteEmail string, alternateEmails []string, projectDomain string, identity ClaimIdentity) bool {
|
||||
targets := inviteMatchTargets(inviteEmail, alternateEmails)
|
||||
if len(targets) == 0 {
|
||||
return false
|
||||
}
|
||||
candidates := identityCandidateEmails(identity)
|
||||
if len(candidates) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, target := range targets {
|
||||
for _, candidate := range candidates {
|
||||
if candidate == target {
|
||||
return true
|
||||
}
|
||||
if localPartAliasMatch(target, candidate) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return projectDomainUPNMatch(inviteEmail, projectDomain, identity)
|
||||
}
|
||||
92
internal/migration/claim_email_match_test.go
Normal file
92
internal/migration/claim_email_match_test.go
Normal file
@ -0,0 +1,92 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestInviteEmailMatchesIdentityExact(t *testing.T) {
|
||||
id := ClaimIdentity{Email: "Alice@Acme.com"}
|
||||
if !InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
|
||||
t.Fatal("expected case-insensitive exact match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityPreferredUsername(t *testing.T) {
|
||||
id := ClaimIdentity{
|
||||
Email: "alice.smith@acme.com",
|
||||
PreferredUsername: "alice@acme.com",
|
||||
}
|
||||
if !InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
|
||||
t.Fatal("expected preferred_username match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityUPN(t *testing.T) {
|
||||
id := ClaimIdentity{
|
||||
Email: "alice.smith@acme.com",
|
||||
UPN: "alice@acme.com",
|
||||
}
|
||||
if !InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
|
||||
t.Fatal("expected upn match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityAlternateEmail(t *testing.T) {
|
||||
id := ClaimIdentity{Email: "alice.smith@acme.com"}
|
||||
if !InviteEmailMatchesIdentity("alice@acme.com", []string{"alice.smith@acme.com"}, "", id) {
|
||||
t.Fatal("expected alternate email match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityGmailDotAlias(t *testing.T) {
|
||||
id := ClaimIdentity{Email: "alice.smith@acme.com"}
|
||||
if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) {
|
||||
t.Fatal("expected exact match baseline")
|
||||
}
|
||||
id = ClaimIdentity{Email: "a.l.i.c.e.smith@acme.com"}
|
||||
if !InviteEmailMatchesIdentity("alice.smith@acme.com", nil, "", id) {
|
||||
t.Fatal("expected dot-insensitive local-part match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityPlusTag(t *testing.T) {
|
||||
id := ClaimIdentity{Email: "alice+tag@acme.com"}
|
||||
if !InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
|
||||
t.Fatal("expected plus-tag stripped match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityProjectDomainUPN(t *testing.T) {
|
||||
id := ClaimIdentity{
|
||||
Email: "alice.smith@acme.com",
|
||||
PreferredUsername: "alice@contoso.onmicrosoft.com",
|
||||
}
|
||||
if !InviteEmailMatchesIdentity("alice@acme.com", nil, "acme.com", id) {
|
||||
t.Fatal("expected project-domain UPN local-part match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityRejectsDifferentUserSameDomain(t *testing.T) {
|
||||
id := ClaimIdentity{Email: "bob@acme.com"}
|
||||
if InviteEmailMatchesIdentity("alice@acme.com", nil, "acme.com", id) {
|
||||
t.Fatal("expected reject for different local-part on same domain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityRejectsUnrelatedDomain(t *testing.T) {
|
||||
id := ClaimIdentity{Email: "alice@evil.com"}
|
||||
if InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
|
||||
t.Fatal("expected reject for different domain without alias")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityEmptyIdentity(t *testing.T) {
|
||||
if InviteEmailMatchesIdentity("alice@acme.com", nil, "", ClaimIdentity{}) {
|
||||
t.Fatal("expected reject for empty identity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInviteEmailMatchesIdentityIgnoresNonEmailPreferredUsername(t *testing.T) {
|
||||
id := ClaimIdentity{PreferredUsername: "alice"}
|
||||
if InviteEmailMatchesIdentity("alice@acme.com", nil, "", id) {
|
||||
t.Fatal("expected reject when preferred_username is not an email")
|
||||
}
|
||||
}
|
||||
54
internal/migration/contacts_delta_test.go
Normal file
54
internal/migration/contacts_delta_test.go
Normal file
@ -0,0 +1,54 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGooglePersonDeletedMetadata(t *testing.T) {
|
||||
person := googlePerson{
|
||||
ResourceName: "people/abc",
|
||||
Metadata: &struct{ Deleted bool `json:"deleted"` }{Deleted: true},
|
||||
}
|
||||
if person.Metadata == nil || !person.Metadata.Deleted {
|
||||
t.Fatal("expected deleted metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphContactRemoved(t *testing.T) {
|
||||
removed := struct {
|
||||
Reason string `json:"reason"`
|
||||
}{Reason: "deleted"}
|
||||
item := graphContact{ID: "c1", Removed: &removed}
|
||||
if item.Removed == nil {
|
||||
t.Fatal("expected removed marker")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGooglePersonToContact(t *testing.T) {
|
||||
person := googlePerson{
|
||||
ResourceName: "people/abc",
|
||||
Names: []struct{ DisplayName string `json:"displayName"` }{{DisplayName: "Alice"}},
|
||||
EmailAddresses: []struct {
|
||||
Value string `json:"value"`
|
||||
}{{Value: "Alice@Example.COM"}},
|
||||
}
|
||||
contact := googlePersonToContact("people/abc", person)
|
||||
if contact.FullName != "Alice" || contact.Email != "alice@example.com" {
|
||||
t.Fatalf("contact: %#v", contact)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphContactToNC(t *testing.T) {
|
||||
item := graphContact{
|
||||
ID: "c1",
|
||||
GivenName: "Bob",
|
||||
Surname: "Smith",
|
||||
MobilePhone: "+33123456789",
|
||||
CompanyName: "Acme",
|
||||
EmailAddresses: []struct {
|
||||
Address string `json:"address"`
|
||||
}{{Address: "bob@example.com"}},
|
||||
}
|
||||
contact := graphContactToNC("c1", item)
|
||||
if contact.FullName != "Bob Smith" || contact.Org != "Acme" {
|
||||
t.Fatalf("contact: %#v", contact)
|
||||
}
|
||||
}
|
||||
476
internal/migration/contacts_import.go
Normal file
476
internal/migration/contacts_import.go
Normal file
@ -0,0 +1,476 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
|
||||
const migrationContactsBookID = "migration-import"
|
||||
|
||||
type ContactsImporter struct {
|
||||
db *pgxpool.Pool
|
||||
nc *nextcloud.Client
|
||||
client *http.Client
|
||||
userUPN string
|
||||
}
|
||||
|
||||
func NewContactsImporter(db *pgxpool.Pool, nc *nextcloud.Client) *ContactsImporter {
|
||||
return &ContactsImporter{db: db, nc: nc, client: migrationHTTPClient()}
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) WithUserPrincipal(upn string) *ContactsImporter {
|
||||
c.userUPN = strings.TrimSpace(upn)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) WithHTTPClient(client *http.Client) *ContactsImporter {
|
||||
if client != nil {
|
||||
c.client = client
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) ImportBatch(ctx context.Context, job *Job, accessToken, provider string, delta bool, update progressUpdater) error {
|
||||
if c.nc == nil {
|
||||
return fmt.Errorf("nextcloud required for contacts migration")
|
||||
}
|
||||
user, err := resolveMigrationUser(ctx, c.db, job.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ncUserID := nextcloud.UserIDFromClaims(user.Email, user.ExternalID)
|
||||
if _, err := c.nc.EnsurePrincipal(ctx, user.Email, user.ExternalID, user.Name); err != nil {
|
||||
return fmt.Errorf("nextcloud user: %w", err)
|
||||
}
|
||||
bookPath := nextcloud.AddressBookPath(ncUserID, migrationContactsBookID)
|
||||
items, err := LoadImportedItemStore(ctx, c.db, job.ID, job.CursorJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if delta {
|
||||
if provider == "google" {
|
||||
if syncToken, _ := job.CursorJSON["syncToken"].(string); strings.TrimSpace(syncToken) != "" {
|
||||
return c.importGoogleDelta(ctx, job, accessToken, ncUserID, bookPath, items, update)
|
||||
}
|
||||
} else if deltaLink, _ := job.CursorJSON["deltaLink"].(string); strings.TrimSpace(deltaLink) != "" {
|
||||
return c.importMicrosoftDelta(ctx, job, accessToken, ncUserID, bookPath, deltaLink, items, update)
|
||||
} else {
|
||||
return c.bootstrapMicrosoftDelta(ctx, job, accessToken, ncUserID, bookPath, update)
|
||||
}
|
||||
}
|
||||
|
||||
switch provider {
|
||||
case "google":
|
||||
return c.importGoogleFull(ctx, job, accessToken, ncUserID, bookPath, delta, items, update)
|
||||
default:
|
||||
return c.importMicrosoftFull(ctx, job, accessToken, ncUserID, bookPath, delta, items, update)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) importGoogleFull(ctx context.Context, job *Job, accessToken, ncUserID, bookPath string, captureToken bool, items *ImportedItemStore, update progressUpdater) error {
|
||||
imported, _ := job.StatsJSON["imported"].(float64)
|
||||
batch := 0
|
||||
|
||||
pageToken, _ := job.CursorJSON["pageToken"].(string)
|
||||
listURL := "https://people.googleapis.com/v1/people/me/connections?pageSize=100&personFields=names,emailAddresses,phoneNumbers,organizations,metadata"
|
||||
if captureToken {
|
||||
listURL += "&requestSyncToken=true"
|
||||
}
|
||||
if pageToken != "" {
|
||||
listURL += "&pageToken=" + url.QueryEscape(pageToken)
|
||||
}
|
||||
body, err := apiGet(ctx, c.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Connections []googlePerson `json:"connections"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
NextSyncToken string `json:"nextSyncToken"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
for i := listIndex; i < len(listed.Connections) && batch < mailImportBatchSize(); i++ {
|
||||
person := listed.Connections[i]
|
||||
sourceID := strings.TrimSpace(person.ResourceName)
|
||||
if sourceID == "" {
|
||||
sourceID = fmt.Sprintf("google-person-%d", i)
|
||||
}
|
||||
if alreadyImported(items, sourceID) {
|
||||
continue
|
||||
}
|
||||
contact := googlePersonToContact(sourceID, person)
|
||||
if contact.Email == "" && contact.FullName == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := c.nc.CreateContact(ctx, ncUserID, bookPath, contact); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, sourceID, err.Error(), ""); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, sourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
imported++
|
||||
batch++
|
||||
}
|
||||
job.StatsJSON["imported"] = imported
|
||||
|
||||
if listIndex+batch < len(listed.Connections) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if listed.NextPageToken != "" {
|
||||
job.CursorJSON["pageToken"] = listed.NextPageToken
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "pageToken")
|
||||
if listed.NextSyncToken != "" {
|
||||
job.CursorJSON["syncToken"] = listed.NextSyncToken
|
||||
}
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) importGoogleDelta(ctx context.Context, job *Job, accessToken, ncUserID, bookPath string, items *ImportedItemStore, update progressUpdater) error {
|
||||
syncToken, _ := job.CursorJSON["syncToken"].(string)
|
||||
listURL := "https://people.googleapis.com/v1/people/me/connections?syncToken=" + url.QueryEscape(syncToken) +
|
||||
"&personFields=names,emailAddresses,phoneNumbers,organizations,metadata&requestSyncToken=true"
|
||||
body, err := apiGet(ctx, c.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Connections []googlePerson `json:"connections"`
|
||||
NextSyncToken string `json:"nextSyncToken"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
deltaCount, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
deltaUpdated, _ := job.StatsJSON["delta_updated"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
for _, person := range listed.Connections {
|
||||
sourceID := strings.TrimSpace(person.ResourceName)
|
||||
if sourceID == "" {
|
||||
continue
|
||||
}
|
||||
if person.Metadata != nil && person.Metadata.Deleted {
|
||||
contactPath := migrationContactPath(bookPath, "google", sourceID)
|
||||
if err := c.nc.DeleteContact(ctx, ncUserID, contactPath); err != nil && !isDeleteNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if err := items.Unmark(ctx, sourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
deleted++
|
||||
continue
|
||||
}
|
||||
contact := googlePersonToContact(sourceID, person)
|
||||
if contact.Email == "" && contact.FullName == "" {
|
||||
continue
|
||||
}
|
||||
updated, err := c.upsertContact(ctx, ncUserID, bookPath, "google", sourceID, contact, items)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if updated {
|
||||
deltaUpdated++
|
||||
} else if items.Has(sourceID) {
|
||||
deltaCount++
|
||||
} else {
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
}
|
||||
}
|
||||
if listed.NextSyncToken != "" {
|
||||
job.CursorJSON["syncToken"] = listed.NextSyncToken
|
||||
}
|
||||
job.StatsJSON["delta_imported"] = deltaCount
|
||||
job.StatsJSON["delta_updated"] = deltaUpdated
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) importMicrosoftFull(ctx context.Context, job *Job, accessToken, ncUserID, bookPath string, captureDelta bool, items *ImportedItemStore, update progressUpdater) error {
|
||||
imported, _ := job.StatsJSON["imported"].(float64)
|
||||
batch := 0
|
||||
|
||||
nextLink, _ := job.CursorJSON["nextLink"].(string)
|
||||
listURL := graphMicrosoftURL(c.userUPN, "/contacts?$top=100&$select=id,displayName,givenName,surname,emailAddresses,mobilePhone,businessPhones,companyName")
|
||||
if nextLink != "" {
|
||||
listURL = nextLink
|
||||
}
|
||||
body, err := apiGet(ctx, c.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Value []graphContact `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
for i := listIndex; i < len(listed.Value) && batch < mailImportBatchSize(); i++ {
|
||||
item := listed.Value[i]
|
||||
sourceID := strings.TrimSpace(item.ID)
|
||||
if sourceID == "" {
|
||||
sourceID = fmt.Sprintf("graph-contact-%d", i)
|
||||
}
|
||||
if alreadyImported(items, sourceID) {
|
||||
continue
|
||||
}
|
||||
contact := graphContactToNC(sourceID, item)
|
||||
if contact.Email == "" && contact.FullName == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := c.nc.CreateContact(ctx, ncUserID, bookPath, contact); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, sourceID, err.Error(), ""); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, sourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
imported++
|
||||
batch++
|
||||
}
|
||||
job.StatsJSON["imported"] = imported
|
||||
|
||||
if listIndex+batch < len(listed.Value) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if listed.NextLink != "" {
|
||||
job.CursorJSON["nextLink"] = listed.NextLink
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "nextLink")
|
||||
|
||||
if captureDelta {
|
||||
return c.bootstrapMicrosoftDelta(ctx, job, accessToken, ncUserID, bookPath, update)
|
||||
}
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) bootstrapMicrosoftDelta(ctx context.Context, job *Job, accessToken, ncUserID, bookPath string, update progressUpdater) error {
|
||||
body, err := apiGet(ctx, c.client, graphMicrosoftURL(c.userUPN, "/contacts/delta?$select=id,displayName,givenName,surname,emailAddresses,mobilePhone,businessPhones,companyName"), accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Value []graphContact `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
if listed.DeltaLink != "" {
|
||||
job.CursorJSON["deltaLink"] = listed.DeltaLink
|
||||
job.StatsJSON["phase"] = "delta_ready"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
if listed.NextLink != "" {
|
||||
job.CursorJSON["nextLink"] = listed.NextLink
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) importMicrosoftDelta(ctx context.Context, job *Job, accessToken, ncUserID, bookPath, deltaLink string, items *ImportedItemStore, update progressUpdater) error {
|
||||
body, err := apiGet(ctx, c.client, deltaLink, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Value []graphContact `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
deltaCount, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
deltaUpdated, _ := job.StatsJSON["delta_updated"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
for _, item := range listed.Value {
|
||||
sourceID := strings.TrimSpace(item.ID)
|
||||
if sourceID == "" {
|
||||
continue
|
||||
}
|
||||
if item.Removed != nil {
|
||||
contactPath := migrationContactPath(bookPath, "microsoft", sourceID)
|
||||
if err := c.nc.DeleteContact(ctx, ncUserID, contactPath); err != nil && !isDeleteNotFound(err) {
|
||||
return err
|
||||
}
|
||||
if err := items.Unmark(ctx, sourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
deleted++
|
||||
continue
|
||||
}
|
||||
contact := graphContactToNC(sourceID, item)
|
||||
if contact.Email == "" && contact.FullName == "" {
|
||||
continue
|
||||
}
|
||||
updated, err := c.upsertContact(ctx, ncUserID, bookPath, "microsoft", sourceID, contact, items)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if updated {
|
||||
deltaUpdated++
|
||||
} else if items.Has(sourceID) {
|
||||
deltaCount++
|
||||
} else {
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
}
|
||||
}
|
||||
if listed.NextLink != "" {
|
||||
job.CursorJSON["deltaLink"] = listed.NextLink
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
if listed.DeltaLink != "" {
|
||||
job.CursorJSON["deltaLink"] = listed.DeltaLink
|
||||
}
|
||||
job.StatsJSON["delta_imported"] = deltaCount
|
||||
job.StatsJSON["delta_updated"] = deltaUpdated
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
type googlePerson struct {
|
||||
ResourceName string `json:"resourceName"`
|
||||
Metadata *struct {
|
||||
Deleted bool `json:"deleted"`
|
||||
} `json:"metadata"`
|
||||
Names []struct {
|
||||
DisplayName string `json:"displayName"`
|
||||
} `json:"names"`
|
||||
EmailAddresses []struct {
|
||||
Value string `json:"value"`
|
||||
} `json:"emailAddresses"`
|
||||
PhoneNumbers []struct {
|
||||
Value string `json:"value"`
|
||||
} `json:"phoneNumbers"`
|
||||
Organizations []struct {
|
||||
Name string `json:"name"`
|
||||
} `json:"organizations"`
|
||||
}
|
||||
|
||||
type graphContact struct {
|
||||
ID string `json:"id"`
|
||||
Removed *struct {
|
||||
Reason string `json:"reason"`
|
||||
} `json:"@removed"`
|
||||
DisplayName string `json:"displayName"`
|
||||
GivenName string `json:"givenName"`
|
||||
Surname string `json:"surname"`
|
||||
MobilePhone string `json:"mobilePhone"`
|
||||
BusinessPhones []string `json:"businessPhones"`
|
||||
CompanyName string `json:"companyName"`
|
||||
EmailAddresses []struct {
|
||||
Address string `json:"address"`
|
||||
} `json:"emailAddresses"`
|
||||
}
|
||||
|
||||
func googlePersonToContact(sourceID string, p googlePerson) *nextcloud.Contact {
|
||||
name := ""
|
||||
if len(p.Names) > 0 {
|
||||
name = strings.TrimSpace(p.Names[0].DisplayName)
|
||||
}
|
||||
email := ""
|
||||
if len(p.EmailAddresses) > 0 {
|
||||
email = strings.ToLower(strings.TrimSpace(p.EmailAddresses[0].Value))
|
||||
}
|
||||
phone := ""
|
||||
if len(p.PhoneNumbers) > 0 {
|
||||
phone = strings.TrimSpace(p.PhoneNumbers[0].Value)
|
||||
}
|
||||
org := ""
|
||||
if len(p.Organizations) > 0 {
|
||||
org = strings.TrimSpace(p.Organizations[0].Name)
|
||||
}
|
||||
return &nextcloud.Contact{
|
||||
UID: sanitizeMigrationUID("google", sourceID),
|
||||
FullName: name,
|
||||
Email: email,
|
||||
Phone: phone,
|
||||
Org: org,
|
||||
}
|
||||
}
|
||||
|
||||
func graphContactToNC(sourceID string, c graphContact) *nextcloud.Contact {
|
||||
name := strings.TrimSpace(c.DisplayName)
|
||||
if name == "" {
|
||||
name = strings.TrimSpace(strings.TrimSpace(c.GivenName + " " + c.Surname))
|
||||
}
|
||||
email := ""
|
||||
if len(c.EmailAddresses) > 0 {
|
||||
email = strings.ToLower(strings.TrimSpace(c.EmailAddresses[0].Address))
|
||||
}
|
||||
phone := strings.TrimSpace(c.MobilePhone)
|
||||
if phone == "" && len(c.BusinessPhones) > 0 {
|
||||
phone = strings.TrimSpace(c.BusinessPhones[0])
|
||||
}
|
||||
return &nextcloud.Contact{
|
||||
UID: sanitizeMigrationUID("microsoft", sourceID),
|
||||
FullName: name,
|
||||
Email: email,
|
||||
Phone: phone,
|
||||
Org: strings.TrimSpace(c.CompanyName),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ContactsImporter) upsertContact(
|
||||
ctx context.Context,
|
||||
ncUserID, bookPath, provider, sourceID string,
|
||||
contact *nextcloud.Contact,
|
||||
items *ImportedItemStore,
|
||||
) (updated bool, err error) {
|
||||
if alreadyImported(items, sourceID) {
|
||||
contactPath := migrationContactPath(bookPath, provider, sourceID)
|
||||
if _, err := c.nc.UpdateContact(ctx, ncUserID, contactPath, "", contact); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
if _, err := c.nc.CreateContact(ctx, ncUserID, bookPath, contact); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, sourceID, err.Error(), ""); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
if err := items.MarkImported(ctx, sourceID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
142
internal/migration/cutover_dns.go
Normal file
142
internal/migration/cutover_dns.go
Normal file
@ -0,0 +1,142 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
)
|
||||
|
||||
// CutoverConfig controls DNS expectations during migration cutover.
|
||||
type CutoverConfig struct {
|
||||
ExpectedMXHosts []string
|
||||
RequireMX bool
|
||||
}
|
||||
|
||||
// CutoverResult is returned when a migration project enters cutover.
|
||||
type CutoverResult struct {
|
||||
Project Project `json:"project"`
|
||||
DNS hosted.DNSCheckReport `json:"dns"`
|
||||
}
|
||||
|
||||
var ErrCutoverMXNotReady = fmt.Errorf("migration cutover blocked: mx records not pointing to ultimail")
|
||||
|
||||
func (s *Service) PreflightCutoverDNS(ctx context.Context, projectID string, cfg CutoverConfig) (hosted.DNSCheckReport, error) {
|
||||
domainID, err := s.projectDomainID(ctx, projectID)
|
||||
if err != nil {
|
||||
return hosted.DNSCheckReport{}, err
|
||||
}
|
||||
if domainID == "" {
|
||||
return hosted.DNSCheckReport{
|
||||
Warnings: []string{"project has no linked mail domain; mx/txt checks skipped"},
|
||||
}, nil
|
||||
}
|
||||
if s.hosted == nil {
|
||||
return hosted.DNSCheckReport{}, fmt.Errorf("hosted mail service not configured")
|
||||
}
|
||||
_, report, err := s.hosted.CheckDomainDNS(ctx, domainID, cfg.ExpectedMXHosts)
|
||||
return report, err
|
||||
}
|
||||
|
||||
func (s *Service) StartCutover(ctx context.Context, projectID string) (CutoverResult, error) {
|
||||
return s.startCutover(ctx, projectID, s.cutover)
|
||||
}
|
||||
|
||||
func (s *Service) startCutover(ctx context.Context, projectID string, cfg CutoverConfig) (CutoverResult, error) {
|
||||
domainID, err := s.projectDomainID(ctx, projectID)
|
||||
if err != nil {
|
||||
return CutoverResult{}, err
|
||||
}
|
||||
|
||||
var report hosted.DNSCheckReport
|
||||
if domainID != "" {
|
||||
if s.hosted == nil {
|
||||
return CutoverResult{}, fmt.Errorf("hosted mail service not configured")
|
||||
}
|
||||
var checkErr error
|
||||
_, report, checkErr = s.hosted.CheckDomainDNS(ctx, domainID, cfg.ExpectedMXHosts)
|
||||
if checkErr != nil {
|
||||
return CutoverResult{}, checkErr
|
||||
}
|
||||
if report.TXTVerified {
|
||||
if _, _, err := s.hosted.VerifyDomainTXTRecord(ctx, domainID); err != nil {
|
||||
report.Warnings = append(report.Warnings, "auto txt verify: "+err.Error())
|
||||
}
|
||||
}
|
||||
if report.MXVerified && len(cfg.ExpectedMXHosts) > 0 {
|
||||
if _, _, err := s.hosted.VerifyDomainMXRecord(ctx, domainID, cfg.ExpectedMXHosts); err != nil {
|
||||
report.Warnings = append(report.Warnings, "auto mx verify: "+err.Error())
|
||||
} else {
|
||||
report.Warnings = append(report.Warnings, "mx verified and domain marked active")
|
||||
}
|
||||
} else if !report.MXVerified {
|
||||
if cfg.RequireMX {
|
||||
return CutoverResult{DNS: report}, ErrCutoverMXNotReady
|
||||
}
|
||||
report.Warnings = append(report.Warnings, "mx not pointing to ultimail yet; cutover flag set anyway")
|
||||
}
|
||||
} else {
|
||||
report.Warnings = append(report.Warnings, "no domain_id on project; configure mail domain before mx cutover")
|
||||
}
|
||||
|
||||
rawDNS, _ := json.Marshal(report)
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE migration_projects
|
||||
SET status = 'cutover', cutover_at = NOW(), delta_mode = true,
|
||||
cutover_dns_json = $2, updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
`, projectID, rawDNS)
|
||||
if err != nil {
|
||||
return CutoverResult{}, err
|
||||
}
|
||||
_, _ = s.db.Exec(ctx, `
|
||||
UPDATE migration_jobs
|
||||
SET status = 'pending', error = '', updated_at = NOW()
|
||||
WHERE project_id = $1::uuid AND status = 'completed'
|
||||
`, projectID)
|
||||
|
||||
sc := newProjectScanner()
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT `+projectSelectSQL("")+`
|
||||
FROM migration_projects WHERE id = $1::uuid
|
||||
`, projectID).Scan(sc.targets()...)
|
||||
if err != nil {
|
||||
return CutoverResult{}, err
|
||||
}
|
||||
return CutoverResult{Project: sc.result(), DNS: report}, nil
|
||||
}
|
||||
|
||||
func (s *Service) projectDomainID(ctx context.Context, projectID string) (string, error) {
|
||||
var domainID string
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(domain_id::text, '') FROM migration_projects WHERE id = $1::uuid
|
||||
`, projectID).Scan(&domainID)
|
||||
return domainID, err
|
||||
}
|
||||
|
||||
func ParseCutoverMXHosts(raw string, platformMailDomain, stalwartIMAPHost string) []string {
|
||||
var out []string
|
||||
for _, part := range strings.Split(raw, ",") {
|
||||
part = strings.ToLower(strings.TrimSpace(part))
|
||||
if part != "" {
|
||||
out = append(out, part)
|
||||
}
|
||||
}
|
||||
if len(out) > 0 {
|
||||
return out
|
||||
}
|
||||
platformMailDomain = strings.ToLower(strings.TrimSpace(platformMailDomain))
|
||||
if platformMailDomain != "" {
|
||||
out = append(out, "mail."+platformMailDomain)
|
||||
}
|
||||
stalwartIMAPHost = strings.ToLower(strings.TrimSpace(stalwartIMAPHost))
|
||||
if stalwartIMAPHost != "" && !strings.Contains(stalwartIMAPHost, ".") {
|
||||
return out
|
||||
}
|
||||
if stalwartIMAPHost != "" {
|
||||
out = append(out, stalwartIMAPHost)
|
||||
}
|
||||
return out
|
||||
}
|
||||
37
internal/migration/cutover_dns_test.go
Normal file
37
internal/migration/cutover_dns_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseCutoverMXHostsExplicit(t *testing.T) {
|
||||
got := ParseCutoverMXHosts(" Mail.Acme.Com , mx2.acme.com ", "", "")
|
||||
want := []string{"mail.acme.com", "mx2.acme.com"}
|
||||
if len(got) != len(want) {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
for i := range want {
|
||||
if got[i] != want[i] {
|
||||
t.Fatalf("got %v want %v", got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCutoverMXHostsFallbackPlatform(t *testing.T) {
|
||||
got := ParseCutoverMXHosts("", "Ultisuite.Fr", "")
|
||||
if len(got) != 1 || got[0] != "mail.ultisuite.fr" {
|
||||
t.Fatalf("got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCutoverMXHostsFallbackStalwart(t *testing.T) {
|
||||
got := ParseCutoverMXHosts("", "", "mail.hosted.example.com")
|
||||
if len(got) != 1 || got[0] != "mail.hosted.example.com" {
|
||||
t.Fatalf("got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCutoverMXHostsSkipsBareStalwartHost(t *testing.T) {
|
||||
got := ParseCutoverMXHosts("", "acme.fr", "stalwart")
|
||||
if len(got) != 1 || got[0] != "mail.acme.fr" {
|
||||
t.Fatalf("got %v", got)
|
||||
}
|
||||
}
|
||||
425
internal/migration/drive_delta.go
Normal file
425
internal/migration/drive_delta.go
Normal file
@ -0,0 +1,425 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (d *DriveImporter) hasDriveDeltaCursor(job *Job, provider string) bool {
|
||||
if provider == "google" {
|
||||
token, _ := job.CursorJSON["driveChangeToken"].(string)
|
||||
return strings.TrimSpace(token) != ""
|
||||
}
|
||||
link, _ := job.CursorJSON["driveDeltaLink"].(string)
|
||||
return strings.TrimSpace(link) != ""
|
||||
}
|
||||
|
||||
func (d *DriveImporter) bootstrapDriveDelta(ctx context.Context, accessToken, provider string, cursor map[string]any) error {
|
||||
switch provider {
|
||||
case "google":
|
||||
body, err := apiGet(ctx, d.client, "https://www.googleapis.com/drive/v3/changes/startPageToken?spaces=drive", accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var parsed struct {
|
||||
StartPageToken string `json:"startPageToken"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return err
|
||||
}
|
||||
if parsed.StartPageToken != "" {
|
||||
cursor["driveChangeToken"] = parsed.StartPageToken
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
link, err := d.walkMicrosoftDriveDelta(ctx, accessToken, graphMicrosoftURL(d.userUPN, "/drive/root/delta?$select=id,name,folder,file,size,parentReference,deleted"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if link != "" {
|
||||
cursor["driveDeltaLink"] = link
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DriveImporter) walkMicrosoftDriveDelta(ctx context.Context, accessToken, listURL string) (string, error) {
|
||||
for listURL != "" {
|
||||
body, err := apiGet(ctx, d.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parsed struct {
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed.DeltaLink != "" {
|
||||
return parsed.DeltaLink, nil
|
||||
}
|
||||
listURL = parsed.NextLink
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (d *DriveImporter) importDriveDelta(ctx context.Context, job *Job, accessToken, provider, ncUserID, root string, items *ImportedItemStore, update progressUpdater) error {
|
||||
switch provider {
|
||||
case "google":
|
||||
return d.importGoogleDriveDelta(ctx, job, accessToken, ncUserID, root, items, update)
|
||||
default:
|
||||
return d.importMicrosoftDriveDelta(ctx, job, accessToken, ncUserID, root, items, update)
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DriveImporter) importGoogleDriveDelta(ctx context.Context, job *Job, accessToken, ncUserID, root string, items *ImportedItemStore, update progressUpdater) error {
|
||||
pageToken, _ := job.CursorJSON["driveChangeToken"].(string)
|
||||
if pageToken == "" {
|
||||
return fmt.Errorf("google drive delta token missing")
|
||||
}
|
||||
|
||||
listURL := "https://www.googleapis.com/drive/v3/changes?pageSize=100&spaces=drive&includeRemoved=true&fields=" +
|
||||
url.QueryEscape("nextPageToken,newStartPageToken,changes(fileId,removed,file(id,name,mimeType,size,parents,trashed))") +
|
||||
"&pageToken=" + url.QueryEscape(pageToken)
|
||||
|
||||
body, err := apiGet(ctx, d.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var parsed struct {
|
||||
Changes []googleDriveChange `json:"changes"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
NewStartPageToken string `json:"newStartPageToken"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imported, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
exported, _ := job.StatsJSON["exported"].(float64)
|
||||
skipped, _ := job.StatsJSON["skipped"].(float64)
|
||||
batch := 0
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
|
||||
for i := listIndex; i < len(parsed.Changes) && batch < driveImportBatchSize(); i++ {
|
||||
change := parsed.Changes[i]
|
||||
if change.Removed || change.File == nil || change.File.Trashed {
|
||||
if err := d.deleteDriveItem(ctx, ncUserID, root, items, change.FileID); err != nil {
|
||||
return err
|
||||
}
|
||||
deleted++
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
item := googleFileToDriveItem(*change.File)
|
||||
relPath := d.resolveDriveRelPath(items, item)
|
||||
if err := d.uploadDriveItem(ctx, accessToken, ncUserID, root, relPath, item, items, &imported, &exported, &skipped, job.StatsJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
batch++
|
||||
}
|
||||
|
||||
job.StatsJSON["delta_imported"] = imported
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
job.StatsJSON["exported"] = exported
|
||||
job.StatsJSON["skipped"] = skipped
|
||||
|
||||
if listIndex+batch < len(parsed.Changes) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if parsed.NextPageToken != "" {
|
||||
job.CursorJSON["driveChangeToken"] = parsed.NextPageToken
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
if parsed.NewStartPageToken != "" {
|
||||
job.CursorJSON["driveChangeToken"] = parsed.NewStartPageToken
|
||||
}
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
type googleDriveChange struct {
|
||||
FileID string `json:"fileId"`
|
||||
Removed bool `json:"removed"`
|
||||
File *googleDriveFile `json:"file"`
|
||||
}
|
||||
|
||||
type googleDriveFile struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Size string `json:"size"`
|
||||
Parents []string `json:"parents"`
|
||||
Trashed bool `json:"trashed"`
|
||||
}
|
||||
|
||||
func googleFileToDriveItem(f googleDriveFile) driveItem {
|
||||
size := int64(0)
|
||||
if f.Size != "" {
|
||||
fmt.Sscan(f.Size, &size)
|
||||
}
|
||||
item := driveItem{
|
||||
ID: f.ID,
|
||||
Name: f.Name,
|
||||
IsFolder: f.MimeType == "application/vnd.google-apps.folder",
|
||||
Size: size,
|
||||
MimeType: f.MimeType,
|
||||
}
|
||||
if len(f.Parents) > 0 {
|
||||
item.ParentID = f.Parents[0]
|
||||
}
|
||||
if item.IsFolder {
|
||||
return item
|
||||
}
|
||||
if exportMime, ext, ok := googleWorkspaceExport(f.MimeType); ok {
|
||||
item.Export = true
|
||||
item.ExportMime = exportMime
|
||||
item.ExportExt = ext
|
||||
item.Name = driveExportFileName(f.Name, ext)
|
||||
} else {
|
||||
item.Download = "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(f.ID) + "?alt=media"
|
||||
}
|
||||
return item
|
||||
}
|
||||
|
||||
func (d *DriveImporter) importMicrosoftDriveDelta(ctx context.Context, job *Job, accessToken, ncUserID, root string, items *ImportedItemStore, update progressUpdater) error {
|
||||
deltaLink, _ := job.CursorJSON["driveDeltaLink"].(string)
|
||||
if deltaLink == "" {
|
||||
return fmt.Errorf("microsoft drive delta link missing")
|
||||
}
|
||||
|
||||
body, err := apiGet(ctx, d.client, deltaLink, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []graphDriveItem `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imported, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
skipped, _ := job.StatsJSON["skipped"].(float64)
|
||||
batch := 0
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
|
||||
for i := listIndex; i < len(parsed.Value) && batch < driveImportBatchSize(); i++ {
|
||||
item := parsed.Value[i]
|
||||
sourceID := strings.TrimSpace(item.ID)
|
||||
if sourceID == "" {
|
||||
continue
|
||||
}
|
||||
if item.Removed != nil || item.Deleted != nil {
|
||||
if err := d.deleteDriveItem(ctx, ncUserID, root, items, sourceID); err != nil {
|
||||
return err
|
||||
}
|
||||
deleted++
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
driveItem := graphDriveToItem(d.userUPN, item)
|
||||
relPath := d.resolveDriveRelPath(items, driveItem)
|
||||
if err := d.uploadDriveItem(ctx, accessToken, ncUserID, root, relPath, driveItem, items, &imported, nil, &skipped, job.StatsJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
batch++
|
||||
}
|
||||
|
||||
job.StatsJSON["delta_imported"] = imported
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
job.StatsJSON["skipped"] = skipped
|
||||
|
||||
if listIndex+batch < len(parsed.Value) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if parsed.NextLink != "" {
|
||||
job.CursorJSON["driveDeltaLink"] = parsed.NextLink
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
if parsed.DeltaLink != "" {
|
||||
job.CursorJSON["driveDeltaLink"] = parsed.DeltaLink
|
||||
}
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
type graphDriveItem struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Folder *struct{ ChildCount int `json:"childCount"` } `json:"folder"`
|
||||
File *struct{ MimeType string `json:"mimeType"` } `json:"file"`
|
||||
ParentReference *struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"parentReference"`
|
||||
Removed *struct {
|
||||
Reason string `json:"reason"`
|
||||
} `json:"@removed"`
|
||||
Deleted *struct{} `json:"deleted"`
|
||||
}
|
||||
|
||||
func graphDriveToItem(userUPN string, item graphDriveItem) driveItem {
|
||||
out := driveItem{
|
||||
ID: item.ID,
|
||||
Name: item.Name,
|
||||
Size: item.Size,
|
||||
}
|
||||
if item.ParentReference != nil {
|
||||
out.ParentID = item.ParentReference.ID
|
||||
}
|
||||
if item.Folder != nil {
|
||||
out.IsFolder = true
|
||||
return out
|
||||
}
|
||||
mime := ""
|
||||
if item.File != nil {
|
||||
mime = item.File.MimeType
|
||||
}
|
||||
out.MimeType = mime
|
||||
out.Download = graphMicrosoftURL(userUPN, "/drive/items/"+url.PathEscape(item.ID)+"/content")
|
||||
return out
|
||||
}
|
||||
|
||||
func (d *DriveImporter) resolveDriveRelPath(items *ImportedItemStore, item driveItem) string {
|
||||
if stored := items.Path(item.ID); stored != "" {
|
||||
return stored
|
||||
}
|
||||
parentRel := ""
|
||||
if item.ParentID != "" {
|
||||
parentRel = items.Path(item.ParentID)
|
||||
}
|
||||
return path.Join(parentRel, sanitizeDrivePath(item.Name))
|
||||
}
|
||||
|
||||
func (d *DriveImporter) uploadDriveItem(ctx context.Context, accessToken, ncUserID, root, relPath string, item driveItem, items *ImportedItemStore, imported, exported, skipped *float64, stats map[string]any) error {
|
||||
targetPath := path.Join(root, relPath)
|
||||
if item.IsFolder {
|
||||
if err := d.nc.CreateFolder(ctx, ncUserID, targetPath); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(stats, "failed")
|
||||
return nil
|
||||
}
|
||||
if err := items.MarkPath(ctx, item.ID, relPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if imported != nil {
|
||||
*imported++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if item.Export {
|
||||
content, contentType, fileName, err := d.downloadGoogleExport(ctx, accessToken, item)
|
||||
if err != nil {
|
||||
if skipped != nil {
|
||||
*skipped++
|
||||
}
|
||||
return items.MarkSkipped(ctx, item.ID, "export: "+err.Error(), relPath)
|
||||
}
|
||||
targetPath = path.Join(path.Dir(targetPath), fileName)
|
||||
relPath = path.Join(path.Dir(relPath), fileName)
|
||||
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(stats, "failed")
|
||||
return nil
|
||||
}
|
||||
if exported != nil {
|
||||
*exported++
|
||||
}
|
||||
if pdfMime, pdfExt, ok := googleSlidesPDFExport(item.MimeType); ok {
|
||||
pdfItem := item
|
||||
pdfItem.ExportMime = pdfMime
|
||||
pdfItem.ExportExt = pdfExt
|
||||
pdfContent, pdfType, pdfName, err := d.downloadGoogleExport(ctx, accessToken, pdfItem)
|
||||
if err == nil {
|
||||
pdfRel := path.Join(path.Dir(relPath), pdfName)
|
||||
pdfTarget := path.Join(root, pdfRel)
|
||||
if err := d.nc.Upload(ctx, ncUserID, pdfTarget, pdfContent, pdfType); err == nil {
|
||||
if err := items.MarkPath(ctx, item.ID+"_pdf", pdfRel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if item.Size > maxDriveFileBytes {
|
||||
if skipped != nil {
|
||||
*skipped++
|
||||
}
|
||||
reason := fmt.Sprintf("file exceeds %d byte limit", maxDriveFileBytes)
|
||||
return items.MarkSkipped(ctx, item.ID, reason, relPath)
|
||||
}
|
||||
content, contentType, err := d.downloadDriveFile(ctx, accessToken, item)
|
||||
if err != nil {
|
||||
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(stats, "failed")
|
||||
return nil
|
||||
}
|
||||
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil {
|
||||
if markErr := items.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(stats, "failed")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if err := items.MarkImported(ctx, item.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := items.MarkPath(ctx, item.ID, relPath); err != nil {
|
||||
return err
|
||||
}
|
||||
if imported != nil {
|
||||
*imported++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DriveImporter) deleteDriveItem(ctx context.Context, ncUserID, root string, items *ImportedItemStore, fileID string) error {
|
||||
if fileID == "" {
|
||||
return nil
|
||||
}
|
||||
for _, suffix := range []string{"", "_pdf"} {
|
||||
rel := items.Path(fileID + suffix)
|
||||
if rel == "" {
|
||||
continue
|
||||
}
|
||||
target := path.Join(root, rel)
|
||||
if err := d.nc.Delete(ctx, ncUserID, target); err != nil && !isDeleteNotFound(err) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := items.Unmark(ctx, fileID); err != nil {
|
||||
return err
|
||||
}
|
||||
return items.Unmark(ctx, fileID+"_pdf")
|
||||
}
|
||||
|
||||
func isDeleteNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "404") || strings.Contains(msg, "not found")
|
||||
}
|
||||
110
internal/migration/drive_delta_test.go
Normal file
110
internal/migration/drive_delta_test.go
Normal file
@ -0,0 +1,110 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBootstrapDriveDeltaGoogle(t *testing.T) {
|
||||
client := mockGoogleHTTPClient(t, func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.URL.Path, "/drive/v3/changes/startPageToken") {
|
||||
_, _ = w.Write([]byte(`{"startPageToken":"start-123"}`))
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
})
|
||||
|
||||
d := NewDriveImporter(nil, nil).WithHTTPClient(client)
|
||||
cursor := map[string]any{}
|
||||
if err := d.bootstrapDriveDelta(context.Background(), "token", "google", cursor); err != nil {
|
||||
t.Fatalf("bootstrap: %v", err)
|
||||
}
|
||||
if cursor["driveChangeToken"] != "start-123" {
|
||||
t.Fatalf("token = %v", cursor["driveChangeToken"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasDriveDeltaCursor(t *testing.T) {
|
||||
d := &DriveImporter{}
|
||||
job := &Job{CursorJSON: map[string]any{}}
|
||||
if d.hasDriveDeltaCursor(job, "google") {
|
||||
t.Fatal("expected no google token")
|
||||
}
|
||||
job.CursorJSON["driveChangeToken"] = "tok"
|
||||
if !d.hasDriveDeltaCursor(job, "google") {
|
||||
t.Fatal("expected google token")
|
||||
}
|
||||
job.CursorJSON = map[string]any{"driveDeltaLink": "https://graph.microsoft.com/delta"}
|
||||
if !d.hasDriveDeltaCursor(job, "microsoft") {
|
||||
t.Fatal("expected microsoft delta link")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleFileToDriveItem(t *testing.T) {
|
||||
folder := googleFileToDriveItem(googleDriveFile{
|
||||
ID: "f1", Name: "Docs", MimeType: "application/vnd.google-apps.folder", Parents: []string{"root"},
|
||||
})
|
||||
if !folder.IsFolder || folder.ParentID != "root" {
|
||||
t.Fatalf("folder: %#v", folder)
|
||||
}
|
||||
|
||||
slides := googleFileToDriveItem(googleDriveFile{
|
||||
ID: "s1", Name: "Deck", MimeType: "application/vnd.google-apps.presentation",
|
||||
})
|
||||
if !slides.Export || slides.ExportExt != ".pptx" {
|
||||
t.Fatalf("slides export: %#v", slides)
|
||||
}
|
||||
|
||||
binary := googleFileToDriveItem(googleDriveFile{
|
||||
ID: "b1", Name: "photo.png", MimeType: "image/png", Size: "1024",
|
||||
})
|
||||
if binary.Export || binary.Download == "" || binary.Size != 1024 {
|
||||
t.Fatalf("binary file: %#v", binary)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphDriveToItem(t *testing.T) {
|
||||
item := graphDriveToItem("", graphDriveItem{
|
||||
ID: "item-1",
|
||||
Name: "report.pdf",
|
||||
Size: 4096,
|
||||
File: &struct{ MimeType string `json:"mimeType"` }{MimeType: "application/pdf"},
|
||||
ParentReference: &struct {
|
||||
ID string `json:"id"`
|
||||
}{ID: "parent-1"},
|
||||
})
|
||||
if item.IsFolder || item.ParentID != "parent-1" || item.Download == "" {
|
||||
t.Fatalf("file item: %#v", item)
|
||||
}
|
||||
|
||||
folder := graphDriveToItem("", graphDriveItem{
|
||||
ID: "dir-1", Name: "Shared", Folder: &struct{ ChildCount int `json:"childCount"` }{ChildCount: 2},
|
||||
})
|
||||
if !folder.IsFolder {
|
||||
t.Fatalf("folder item: %#v", folder)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveDriveRelPath(t *testing.T) {
|
||||
d := &DriveImporter{}
|
||||
store := NewImportedItemStoreMemory()
|
||||
ctx := context.Background()
|
||||
if err := store.MarkPath(ctx, "parent", "Projects"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := store.MarkPath(ctx, "file-1", "Projects/old-name.docx"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if got := d.resolveDriveRelPath(store, driveItem{ID: "file-1", Name: "ignored.docx"}); got != "Projects/old-name.docx" {
|
||||
t.Fatalf("stored path wins: %q", got)
|
||||
}
|
||||
if got := d.resolveDriveRelPath(store, driveItem{ID: "new", Name: "readme.txt", ParentID: "parent"}); got != "Projects/readme.txt" {
|
||||
t.Fatalf("parent path join: %q", got)
|
||||
}
|
||||
if got := d.resolveDriveRelPath(store, driveItem{ID: "orphan", Name: "solo.txt"}); got != "solo.txt" {
|
||||
t.Fatalf("root file: %q", got)
|
||||
}
|
||||
}
|
||||
87
internal/migration/drive_helpers.go
Normal file
87
internal/migration/drive_helpers.go
Normal file
@ -0,0 +1,87 @@
|
||||
package migration
|
||||
|
||||
import "strings"
|
||||
|
||||
// googleWorkspaceExport maps Google native mime types to export targets.
|
||||
func googleWorkspaceExport(mimeType string) (exportMime, ext string, ok bool) {
|
||||
switch mimeType {
|
||||
case "application/vnd.google-apps.document":
|
||||
return "application/vnd.openxmlformats-officedocument.wordprocessingml.document", ".docx", true
|
||||
case "application/vnd.google-apps.spreadsheet":
|
||||
return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", ".xlsx", true
|
||||
case "application/vnd.google-apps.presentation":
|
||||
return "application/vnd.openxmlformats-officedocument.presentationml.presentation", ".pptx", true
|
||||
case "application/vnd.google-apps.drawing":
|
||||
return "application/pdf", ".pdf", true
|
||||
case "application/vnd.google-apps.script":
|
||||
return "application/vnd.google-apps.script+json", ".json", true
|
||||
case "application/vnd.google-apps.site":
|
||||
return "text/plain", ".txt", true
|
||||
default:
|
||||
return "", "", false
|
||||
}
|
||||
}
|
||||
|
||||
// googleSlidesPDFExport returns PDF export for Google Slides (companion copy).
|
||||
func googleSlidesPDFExport(mimeType string) (exportMime, ext string, ok bool) {
|
||||
if mimeType == "application/vnd.google-apps.presentation" {
|
||||
return "application/pdf", ".pdf", true
|
||||
}
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func driveExportFileName(name, ext string) string {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
name = "untitled"
|
||||
}
|
||||
if ext != "" && !strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
|
||||
return name + ext
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
type driveFolderRef struct {
|
||||
ID string
|
||||
Path string
|
||||
}
|
||||
|
||||
func readDriveFolderQueue(cursor map[string]any, provider string) []driveFolderRef {
|
||||
raw, _ := cursor["folderQueue"].([]any)
|
||||
out := make([]driveFolderRef, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
m, _ := item.(map[string]any)
|
||||
if m == nil {
|
||||
continue
|
||||
}
|
||||
id, _ := m["id"].(string)
|
||||
p, _ := m["path"].(string)
|
||||
if id != "" {
|
||||
out = append(out, driveFolderRef{ID: id, Path: p})
|
||||
}
|
||||
}
|
||||
if len(out) == 0 {
|
||||
if provider == "google" {
|
||||
return []driveFolderRef{{ID: "root", Path: ""}}
|
||||
}
|
||||
return []driveFolderRef{{ID: "root", Path: ""}}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func writeDriveFolderQueue(cursor map[string]any, queue []driveFolderRef) {
|
||||
raw := make([]any, 0, len(queue))
|
||||
for _, f := range queue {
|
||||
raw = append(raw, map[string]any{"id": f.ID, "path": f.Path})
|
||||
}
|
||||
cursor["folderQueue"] = raw
|
||||
}
|
||||
|
||||
func enqueueDriveFolder(queue []driveFolderRef, folder driveFolderRef) []driveFolderRef {
|
||||
for _, existing := range queue {
|
||||
if existing.ID == folder.ID {
|
||||
return queue
|
||||
}
|
||||
}
|
||||
return append(queue, folder)
|
||||
}
|
||||
95
internal/migration/drive_helpers_test.go
Normal file
95
internal/migration/drive_helpers_test.go
Normal file
@ -0,0 +1,95 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGoogleWorkspaceExport(t *testing.T) {
|
||||
mime, ext, ok := googleWorkspaceExport("application/vnd.google-apps.document")
|
||||
if !ok || ext != ".docx" || mime == "" {
|
||||
t.Fatalf("document export: mime=%q ext=%q ok=%v", mime, ext, ok)
|
||||
}
|
||||
_, _, ok = googleWorkspaceExport("application/vnd.google-apps.folder")
|
||||
if ok {
|
||||
t.Fatal("folder should not export")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveExportFileName(t *testing.T) {
|
||||
got := driveExportFileName("Report", ".docx")
|
||||
if got != "Report.docx" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
got = driveExportFileName("Report.docx", ".docx")
|
||||
if got != "Report.docx" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDriveFolderQueue(t *testing.T) {
|
||||
cursor := map[string]any{}
|
||||
queue := readDriveFolderQueue(cursor, "google")
|
||||
if len(queue) != 1 || queue[0].ID != "root" {
|
||||
t.Fatalf("default queue: %#v", queue)
|
||||
}
|
||||
queue = enqueueDriveFolder(queue, driveFolderRef{ID: "abc", Path: "Docs"})
|
||||
writeDriveFolderQueue(cursor, queue)
|
||||
if len(readDriveFolderQueue(cursor, "google")) != 2 {
|
||||
t.Fatal("expected 2 folders in queue")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleSlidesPDFExport(t *testing.T) {
|
||||
mime, ext, ok := googleSlidesPDFExport("application/vnd.google-apps.presentation")
|
||||
if !ok || ext != ".pdf" || mime != "application/pdf" {
|
||||
t.Fatalf("slides pdf export: mime=%q ext=%q ok=%v", mime, ext, ok)
|
||||
}
|
||||
_, _, ok = googleSlidesPDFExport("application/vnd.google-apps.document")
|
||||
if ok {
|
||||
t.Fatal("document should not have slides pdf export")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportedPathHelpers(t *testing.T) {
|
||||
store := NewImportedItemStoreMemory()
|
||||
ctx := context.Background()
|
||||
if err := store.MarkPath(ctx, "file-1", "Docs/report.docx"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := store.Path("file-1"); got != "Docs/report.docx" {
|
||||
t.Fatalf("imported path: got %q", got)
|
||||
}
|
||||
if err := store.Unmark(ctx, "file-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got := store.Path("file-1"); got != "" {
|
||||
t.Fatalf("expected empty after unmark, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGoogleCalendarCancelledEvent(t *testing.T) {
|
||||
ev := googleCalendarEvent{ID: "evt1", Status: "cancelled"}.toSourceEvent("primary")
|
||||
if !ev.Deleted {
|
||||
t.Fatal("cancelled event should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphCalendarRemovedEvent(t *testing.T) {
|
||||
removed := struct {
|
||||
Reason string `json:"reason"`
|
||||
}{Reason: "deleted"}
|
||||
ev := graphCalendarEvent{ID: "evt1", Removed: &removed}.toSourceEvent("cal1")
|
||||
if !ev.Deleted {
|
||||
t.Fatal("removed event should be deleted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCalendarSyncTokenHelpers(t *testing.T) {
|
||||
cursor := map[string]any{}
|
||||
setCalendarSyncToken(cursor, "primary", "token-1")
|
||||
tokens := calendarSyncTokens(cursor)
|
||||
if tokens["primary"] != "token-1" {
|
||||
t.Fatalf("got %#v", tokens)
|
||||
}
|
||||
}
|
||||
405
internal/migration/drive_import.go
Normal file
405
internal/migration/drive_import.go
Normal file
@ -0,0 +1,405 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
|
||||
const maxDriveFileBytes = 25 * 1024 * 1024
|
||||
|
||||
type DriveImporter struct {
|
||||
db *pgxpool.Pool
|
||||
nc *nextcloud.Client
|
||||
client *http.Client
|
||||
userUPN string
|
||||
}
|
||||
|
||||
func NewDriveImporter(db *pgxpool.Pool, nc *nextcloud.Client) *DriveImporter {
|
||||
return &DriveImporter{db: db, nc: nc, client: migrationHTTPClient()}
|
||||
}
|
||||
|
||||
func (d *DriveImporter) WithUserPrincipal(upn string) *DriveImporter {
|
||||
d.userUPN = strings.TrimSpace(upn)
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *DriveImporter) WithHTTPClient(c *http.Client) *DriveImporter {
|
||||
if c != nil {
|
||||
d.client = c
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
func (d *DriveImporter) ImportBatch(ctx context.Context, job *Job, accessToken, provider string, delta bool, update progressUpdater) error {
|
||||
if d.nc == nil {
|
||||
return fmt.Errorf("nextcloud required for drive migration")
|
||||
}
|
||||
user, err := resolveMigrationUser(ctx, d.db, job.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ncUserID := nextcloud.UserIDFromClaims(user.Email, user.ExternalID)
|
||||
if _, err := d.nc.EnsurePrincipal(ctx, user.Email, user.ExternalID, user.Name); err != nil {
|
||||
return fmt.Errorf("nextcloud user: %w", err)
|
||||
}
|
||||
root := fmt.Sprintf("/Migration/%s", provider)
|
||||
_ = d.nc.CreateFolder(ctx, ncUserID, root)
|
||||
store, err := LoadImportedItemStore(ctx, d.db, job.ID, job.CursorJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if delta && d.hasDriveDeltaCursor(job, provider) {
|
||||
return d.importDriveDelta(ctx, job, accessToken, provider, ncUserID, root, store, update)
|
||||
}
|
||||
|
||||
imported, _ := job.StatsJSON["imported"].(float64)
|
||||
skipped, _ := job.StatsJSON["skipped"].(float64)
|
||||
exported, _ := job.StatsJSON["exported"].(float64)
|
||||
batch := 0
|
||||
|
||||
queue := readDriveFolderQueue(job.CursorJSON, provider)
|
||||
folderIndex := int(jsonNumber(job.CursorJSON["folderIndex"]))
|
||||
if folderIndex >= len(queue) {
|
||||
if delta && !d.hasDriveDeltaCursor(job, provider) {
|
||||
if err := d.bootstrapDriveDelta(ctx, accessToken, provider, job.CursorJSON); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
job.StatsJSON["imported"] = imported
|
||||
job.StatsJSON["skipped"] = skipped
|
||||
job.StatsJSON["exported"] = exported
|
||||
if delta && d.hasDriveDeltaCursor(job, provider) {
|
||||
job.StatsJSON["phase"] = "delta_ready"
|
||||
} else {
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
}
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
current := queue[folderIndex]
|
||||
folderItems, nextCursor, subfolders, err := d.listDriveFolderItems(ctx, accessToken, provider, current, job.CursorJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
for i := listIndex; i < len(folderItems) && batch < driveImportBatchSize(); i++ {
|
||||
item := folderItems[i]
|
||||
if alreadyImported(store, item.ID) {
|
||||
continue
|
||||
}
|
||||
relPath := path.Join(current.Path, sanitizeDrivePath(item.Name))
|
||||
targetPath := path.Join(root, relPath)
|
||||
if item.IsFolder {
|
||||
if err := d.nc.CreateFolder(ctx, ncUserID, targetPath); err != nil {
|
||||
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
if err := store.MarkPath(ctx, item.ID, relPath); err != nil {
|
||||
return err
|
||||
}
|
||||
queue = enqueueDriveFolder(queue, driveFolderRef{ID: item.ID, Path: relPath})
|
||||
} else {
|
||||
if item.Export {
|
||||
content, contentType, fileName, err := d.downloadGoogleExport(ctx, accessToken, item)
|
||||
if err != nil {
|
||||
skipped++
|
||||
if err := store.MarkSkipped(ctx, item.ID, "export: "+err.Error(), relPath); err != nil {
|
||||
return err
|
||||
}
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
targetPath = path.Join(path.Dir(targetPath), fileName)
|
||||
relPath = path.Join(path.Dir(relPath), fileName)
|
||||
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil {
|
||||
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
exported++
|
||||
if pdfMime, pdfExt, ok := googleSlidesPDFExport(item.MimeType); ok {
|
||||
pdfItem := item
|
||||
pdfItem.ExportMime = pdfMime
|
||||
pdfItem.ExportExt = pdfExt
|
||||
if pdfContent, pdfType, pdfName, err := d.downloadGoogleExport(ctx, accessToken, pdfItem); err == nil {
|
||||
pdfRel := path.Join(path.Dir(relPath), pdfName)
|
||||
pdfTarget := path.Join(root, pdfRel)
|
||||
if err := d.nc.Upload(ctx, ncUserID, pdfTarget, pdfContent, pdfType); err == nil {
|
||||
if err := store.MarkPath(ctx, item.ID+"_pdf", pdfRel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if item.Size > maxDriveFileBytes {
|
||||
skipped++
|
||||
reason := fmt.Sprintf("file exceeds %d byte limit", maxDriveFileBytes)
|
||||
if err := store.MarkSkipped(ctx, item.ID, reason, relPath); err != nil {
|
||||
return err
|
||||
}
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
content, contentType, err := d.downloadDriveFile(ctx, accessToken, item)
|
||||
if err != nil {
|
||||
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
if err := d.nc.Upload(ctx, ncUserID, targetPath, content, contentType); err != nil {
|
||||
if markErr := store.MarkFailed(ctx, item.ID, err.Error(), relPath); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
if err := store.MarkImported(ctx, item.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if !item.IsFolder {
|
||||
if err := store.MarkPath(ctx, item.ID, relPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
imported++
|
||||
batch++
|
||||
}
|
||||
|
||||
for _, sub := range subfolders {
|
||||
relPath := path.Join(current.Path, sanitizeDrivePath(sub.Name))
|
||||
queue = enqueueDriveFolder(queue, driveFolderRef{ID: sub.ID, Path: relPath})
|
||||
}
|
||||
writeDriveFolderQueue(job.CursorJSON, queue)
|
||||
|
||||
job.StatsJSON["imported"] = imported
|
||||
job.StatsJSON["skipped"] = skipped
|
||||
job.StatsJSON["exported"] = exported
|
||||
|
||||
if listIndex+batch < len(folderItems) {
|
||||
job.CursorJSON["listIndex"] = float64(listIndex + batch)
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
|
||||
if nextCursor != "" {
|
||||
if provider == "google" {
|
||||
job.CursorJSON["pageToken"] = nextCursor
|
||||
} else {
|
||||
job.CursorJSON["nextLink"] = nextCursor
|
||||
}
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "pageToken")
|
||||
delete(job.CursorJSON, "nextLink")
|
||||
|
||||
job.CursorJSON["folderIndex"] = float64(folderIndex + 1)
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
type driveItem struct {
|
||||
ID string
|
||||
Name string
|
||||
ParentID string
|
||||
IsFolder bool
|
||||
Size int64
|
||||
MimeType string
|
||||
Download string
|
||||
Export bool
|
||||
ExportMime string
|
||||
ExportExt string
|
||||
}
|
||||
|
||||
type driveSubfolder struct {
|
||||
ID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func (d *DriveImporter) listDriveFolderItems(ctx context.Context, accessToken, provider string, folder driveFolderRef, cursor map[string]any) ([]driveItem, string, []driveSubfolder, error) {
|
||||
switch provider {
|
||||
case "google":
|
||||
pageToken, _ := cursor["pageToken"].(string)
|
||||
q := url.QueryEscape("'" + folder.ID + "' in parents and trashed=false")
|
||||
listURL := "https://www.googleapis.com/drive/v3/files?pageSize=100&fields=nextPageToken,files(id,name,mimeType,size)&q=" + q
|
||||
if pageToken != "" {
|
||||
listURL += "&pageToken=" + url.QueryEscape(pageToken)
|
||||
}
|
||||
body, err := apiGet(ctx, d.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
var parsed struct {
|
||||
Files []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
MimeType string `json:"mimeType"`
|
||||
Size string `json:"size"`
|
||||
} `json:"files"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
out := make([]driveItem, 0, len(parsed.Files))
|
||||
for _, f := range parsed.Files {
|
||||
size := int64(0)
|
||||
if f.Size != "" {
|
||||
fmt.Sscan(f.Size, &size)
|
||||
}
|
||||
item := driveItem{
|
||||
ID: f.ID,
|
||||
Name: f.Name,
|
||||
IsFolder: f.MimeType == "application/vnd.google-apps.folder",
|
||||
Size: size,
|
||||
MimeType: f.MimeType,
|
||||
}
|
||||
if item.IsFolder {
|
||||
out = append(out, item)
|
||||
continue
|
||||
}
|
||||
if exportMime, ext, ok := googleWorkspaceExport(f.MimeType); ok {
|
||||
item.Export = true
|
||||
item.ExportMime = exportMime
|
||||
item.ExportExt = ext
|
||||
item.Name = driveExportFileName(f.Name, ext)
|
||||
} else {
|
||||
item.Download = "https://www.googleapis.com/drive/v3/files/" + url.PathEscape(f.ID) + "?alt=media"
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
return out, parsed.NextPageToken, nil, nil
|
||||
default:
|
||||
nextLink, _ := cursor["nextLink"].(string)
|
||||
var listURL string
|
||||
if folder.ID == "root" {
|
||||
listURL = graphMicrosoftURL(d.userUPN, "/drive/root/children?$top=100&$select=id,name,folder,file,size")
|
||||
} else {
|
||||
listURL = graphMicrosoftURL(d.userUPN, "/drive/items/"+url.PathEscape(folder.ID)+"/children?$top=100&$select=id,name,folder,file,size")
|
||||
}
|
||||
if nextLink != "" {
|
||||
listURL = nextLink
|
||||
}
|
||||
body, err := apiGet(ctx, d.client, listURL, accessToken)
|
||||
if err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Folder *struct {
|
||||
ChildCount int `json:"childCount"`
|
||||
} `json:"folder"`
|
||||
File *struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
} `json:"file"`
|
||||
Size int64 `json:"size"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return nil, "", nil, err
|
||||
}
|
||||
out := make([]driveItem, 0, len(parsed.Value))
|
||||
var subs []driveSubfolder
|
||||
for _, f := range parsed.Value {
|
||||
if f.Folder != nil {
|
||||
out = append(out, driveItem{ID: f.ID, Name: f.Name, IsFolder: true})
|
||||
if f.Folder.ChildCount > 0 {
|
||||
subs = append(subs, driveSubfolder{ID: f.ID, Name: f.Name})
|
||||
}
|
||||
continue
|
||||
}
|
||||
mime := ""
|
||||
if f.File != nil {
|
||||
mime = f.File.MimeType
|
||||
}
|
||||
out = append(out, driveItem{
|
||||
ID: f.ID,
|
||||
Name: f.Name,
|
||||
Size: f.Size,
|
||||
MimeType: mime,
|
||||
Download: graphMicrosoftURL(d.userUPN, "/drive/items/"+url.PathEscape(f.ID)+"/content"),
|
||||
})
|
||||
}
|
||||
return out, parsed.NextLink, subs, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *DriveImporter) downloadDriveFile(ctx context.Context, accessToken string, item driveItem) (io.ReadCloser, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, item.Download, nil)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
resp, err := migrationDo(ctx, d.client, req)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("download %s: %w", item.Name, err)
|
||||
}
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = item.MimeType
|
||||
}
|
||||
if contentType == "" {
|
||||
contentType = "application/octet-stream"
|
||||
}
|
||||
return resp.Body, contentType, nil
|
||||
}
|
||||
|
||||
func (d *DriveImporter) downloadGoogleExport(ctx context.Context, accessToken string, item driveItem) (io.ReadCloser, string, string, error) {
|
||||
exportURL := fmt.Sprintf(
|
||||
"https://www.googleapis.com/drive/v3/files/%s/export?mimeType=%s",
|
||||
url.PathEscape(item.ID),
|
||||
url.QueryEscape(item.ExportMime),
|
||||
)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, exportURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
resp, err := migrationDo(ctx, d.client, req)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("export %s: %w", item.Name, err)
|
||||
}
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
contentType = item.ExportMime
|
||||
}
|
||||
return resp.Body, contentType, driveExportFileName(item.Name, item.ExportExt), nil
|
||||
}
|
||||
|
||||
func sanitizeDrivePath(name string) string {
|
||||
name = strings.TrimSpace(name)
|
||||
name = strings.ReplaceAll(name, "/", "-")
|
||||
name = strings.ReplaceAll(name, "\\", "-")
|
||||
if name == "" {
|
||||
return "untitled"
|
||||
}
|
||||
return name
|
||||
}
|
||||
285
internal/migration/gmail_attachments.go
Normal file
285
internal/migration/gmail_attachments.go
Normal file
@ -0,0 +1,285 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"mime"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/storage"
|
||||
)
|
||||
|
||||
type gmailAttachmentPart struct {
|
||||
Filename string
|
||||
MimeType string
|
||||
ContentID string
|
||||
IsInline bool
|
||||
Size int
|
||||
BodyData string
|
||||
AttachmentID string
|
||||
}
|
||||
|
||||
func extractGmailAttachmentParts(payload gmailPayload) []gmailAttachmentPart {
|
||||
var out []gmailAttachmentPart
|
||||
var walk func(gmailPayload)
|
||||
walk = func(node gmailPayload) {
|
||||
headers := gmailHeaderMap(node.Headers)
|
||||
mimeType := strings.TrimSpace(node.MimeType)
|
||||
if strings.HasPrefix(strings.ToLower(mimeType), "multipart/") {
|
||||
for _, part := range node.Parts {
|
||||
walk(part)
|
||||
}
|
||||
return
|
||||
}
|
||||
if isGmailAttachmentPart(mimeType, headers, node.Body) {
|
||||
out = append(out, buildGmailAttachmentPart(mimeType, headers, node.Body))
|
||||
}
|
||||
}
|
||||
walk(payload)
|
||||
if len(out) > limits.MaxAttachmentsPerMessage {
|
||||
return out[:limits.MaxAttachmentsPerMessage]
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func gmailHeaderMap(headers []gmailHeader) map[string]string {
|
||||
out := make(map[string]string, len(headers))
|
||||
for _, h := range headers {
|
||||
key := strings.ToLower(strings.TrimSpace(h.Name))
|
||||
if key != "" && out[key] == "" {
|
||||
out[key] = h.Value
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func isGmailAttachmentPart(mimeType string, headers map[string]string, body gmailBody) bool {
|
||||
if body.AttachmentID != "" {
|
||||
return true
|
||||
}
|
||||
disposition := strings.ToLower(headers["content-disposition"])
|
||||
if strings.HasPrefix(disposition, "attachment") {
|
||||
return true
|
||||
}
|
||||
contentID := normalizeGmailContentID(headers["content-id"])
|
||||
filename := gmailPartFilename(headers)
|
||||
if strings.HasPrefix(disposition, "inline") && (filename != "" || contentID != "") {
|
||||
return true
|
||||
}
|
||||
if contentID != "" && body.Data != "" {
|
||||
return true
|
||||
}
|
||||
lower := strings.ToLower(strings.TrimSpace(mimeType))
|
||||
if lower == "text/plain" || lower == "text/html" || lower == "" {
|
||||
return false
|
||||
}
|
||||
return filename != "" && body.Data != ""
|
||||
}
|
||||
|
||||
func buildGmailAttachmentPart(mimeType string, headers map[string]string, body gmailBody) gmailAttachmentPart {
|
||||
disposition := strings.ToLower(headers["content-disposition"])
|
||||
contentID := normalizeGmailContentID(headers["content-id"])
|
||||
filename := gmailPartFilename(headers)
|
||||
isInline := strings.HasPrefix(disposition, "inline") || (contentID != "" && !strings.HasPrefix(disposition, "attachment"))
|
||||
if filename == "" {
|
||||
filename = inlineGmailAttachmentFilename(contentID, mimeType)
|
||||
}
|
||||
size := body.Size
|
||||
if size <= 0 && body.Data != "" {
|
||||
size = len(body.Data)
|
||||
}
|
||||
return gmailAttachmentPart{
|
||||
Filename: filename,
|
||||
MimeType: normalizeAttachmentMimeType(mimeType),
|
||||
ContentID: contentID,
|
||||
IsInline: isInline,
|
||||
Size: size,
|
||||
BodyData: body.Data,
|
||||
AttachmentID: body.AttachmentID,
|
||||
}
|
||||
}
|
||||
|
||||
func gmailPartFilename(headers map[string]string) string {
|
||||
if name := strings.TrimSpace(headers["filename"]); name != "" {
|
||||
return name
|
||||
}
|
||||
disposition := headers["content-disposition"]
|
||||
if disposition == "" {
|
||||
return ""
|
||||
}
|
||||
_, params, err := mime.ParseMediaType(disposition)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(params["filename"])
|
||||
}
|
||||
|
||||
func normalizeGmailContentID(raw string) string {
|
||||
return strings.Trim(raw, "<> \t")
|
||||
}
|
||||
|
||||
func inlineGmailAttachmentFilename(contentID, mimeType string) string {
|
||||
base := "inline"
|
||||
if contentID != "" {
|
||||
base = strings.Map(func(r rune) rune {
|
||||
switch r {
|
||||
case '<', '>', '/', '\\', ':', '"', '\'', '?', '*':
|
||||
return '_'
|
||||
default:
|
||||
return r
|
||||
}
|
||||
}, contentID)
|
||||
}
|
||||
ext := extensionFromMimeType(mimeType)
|
||||
if ext != "" && !strings.HasSuffix(strings.ToLower(base), ext) {
|
||||
return base + ext
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
func extensionFromMimeType(mimeType string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(strings.Split(mimeType, ";")[0])) {
|
||||
case "image/jpeg", "image/jpg":
|
||||
return ".jpg"
|
||||
case "image/png":
|
||||
return ".png"
|
||||
case "image/gif":
|
||||
return ".gif"
|
||||
case "image/webp":
|
||||
return ".webp"
|
||||
case "application/pdf":
|
||||
return ".pdf"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeAttachmentMimeType(mimeType string) string {
|
||||
mimeType = strings.TrimSpace(mimeType)
|
||||
if mimeType == "" {
|
||||
return "application/octet-stream"
|
||||
}
|
||||
mediaType, _, err := mime.ParseMediaType(mimeType)
|
||||
if err != nil || mediaType == "" {
|
||||
return mimeType
|
||||
}
|
||||
return mediaType
|
||||
}
|
||||
|
||||
func (g *GmailImporter) storeGmailAttachments(
|
||||
ctx context.Context,
|
||||
userID, messageID, gmailID, accessToken string,
|
||||
payload gmailPayload,
|
||||
messageExisted bool,
|
||||
) error {
|
||||
if g.storage == nil {
|
||||
return nil
|
||||
}
|
||||
parts := extractGmailAttachmentParts(payload)
|
||||
if len(parts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var existingCount int
|
||||
var existingTotal int64
|
||||
if err := g.db.QueryRow(ctx, `
|
||||
SELECT COUNT(*)::int, COALESCE(SUM(size), 0)::bigint
|
||||
FROM attachments WHERE message_id = $1::uuid
|
||||
`, messageID).Scan(&existingCount, &existingTotal); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stored := 0
|
||||
for _, part := range parts {
|
||||
if messageExisted && gmailAttachmentExists(ctx, g.db, messageID, part) {
|
||||
continue
|
||||
}
|
||||
data, err := g.loadGmailAttachmentData(ctx, accessToken, gmailID, part)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if err := limits.ValidateAttachmentSize(int64(len(data))); err != nil {
|
||||
continue
|
||||
}
|
||||
if err := limits.ValidateAttachmentQuota(existingCount, existingTotal, int64(len(data))); err != nil {
|
||||
break
|
||||
}
|
||||
|
||||
objectKey := storage.MessageObjectKey(userID, messageID, part.Filename)
|
||||
if err := g.storage.Put(ctx, objectKey, bytes.NewReader(data), int64(len(data)), part.MimeType); err != nil {
|
||||
return err
|
||||
}
|
||||
bucket := g.attachBucket
|
||||
if bucket == "" {
|
||||
bucket = "mail-attachments"
|
||||
}
|
||||
_, err = g.db.Exec(ctx, `
|
||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5, $6, $7, $8, 'skipped')
|
||||
`, messageID, part.Filename, part.MimeType, len(data), bucket, objectKey, part.ContentID, part.IsInline)
|
||||
if err != nil {
|
||||
_ = g.storage.Delete(ctx, objectKey)
|
||||
return err
|
||||
}
|
||||
existingCount++
|
||||
existingTotal += int64(len(data))
|
||||
stored++
|
||||
}
|
||||
|
||||
if stored > 0 {
|
||||
_, err := g.db.Exec(ctx, `UPDATE messages SET has_attachments = true, updated_at = NOW() WHERE id = $1::uuid`, messageID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *GmailImporter) loadGmailAttachmentData(ctx context.Context, accessToken, gmailID string, part gmailAttachmentPart) ([]byte, error) {
|
||||
if part.AttachmentID != "" {
|
||||
url := fmt.Sprintf(
|
||||
"https://gmail.googleapis.com/gmail/v1/users/me/messages/%s/attachments/%s",
|
||||
gmailID,
|
||||
part.AttachmentID,
|
||||
)
|
||||
raw, err := g.apiGet(ctx, url, accessToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var parsed struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed.Data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []byte(decodeGmailBody(parsed.Data)), nil
|
||||
}
|
||||
if part.BodyData == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return []byte(decodeGmailBody(part.BodyData)), nil
|
||||
}
|
||||
|
||||
func gmailAttachmentExists(ctx context.Context, db *pgxpool.Pool, messageID string, part gmailAttachmentPart) bool {
|
||||
var count int
|
||||
if part.ContentID != "" {
|
||||
_ = db.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM attachments
|
||||
WHERE message_id = $1::uuid AND (content_id = $2 OR filename = $3)
|
||||
`, messageID, part.ContentID, part.Filename).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
_ = db.QueryRow(ctx, `
|
||||
SELECT COUNT(*) FROM attachments WHERE message_id = $1::uuid AND filename = $2
|
||||
`, messageID, part.Filename).Scan(&count)
|
||||
return count > 0
|
||||
}
|
||||
55
internal/migration/gmail_attachments_test.go
Normal file
55
internal/migration/gmail_attachments_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestExtractGmailAttachmentPartsInlineData(t *testing.T) {
|
||||
payload := gmailPayload{
|
||||
MimeType: "multipart/mixed",
|
||||
Parts: []gmailPayload{
|
||||
{
|
||||
MimeType: "text/plain",
|
||||
Body: gmailBody{Data: "dGV4dA=="},
|
||||
},
|
||||
{
|
||||
MimeType: "image/png",
|
||||
Headers: []gmailHeader{
|
||||
{Name: "Content-Disposition", Value: `attachment; filename="logo.png"`},
|
||||
},
|
||||
Body: gmailBody{Data: "aW1n", Size: 3},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts := extractGmailAttachmentParts(payload)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 attachment, got %d", len(parts))
|
||||
}
|
||||
if parts[0].Filename != "logo.png" || parts[0].MimeType != "image/png" {
|
||||
t.Fatalf("unexpected part: %#v", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractGmailAttachmentPartsAttachmentID(t *testing.T) {
|
||||
payload := gmailPayload{
|
||||
MimeType: "multipart/mixed",
|
||||
Parts: []gmailPayload{
|
||||
{
|
||||
MimeType: "application/pdf",
|
||||
Headers: []gmailHeader{
|
||||
{Name: "Content-Disposition", Value: `attachment; filename="report.pdf"`},
|
||||
},
|
||||
Body: gmailBody{AttachmentID: "ANGjdJ_test", Size: 4096},
|
||||
},
|
||||
},
|
||||
}
|
||||
parts := extractGmailAttachmentParts(payload)
|
||||
if len(parts) != 1 || parts[0].AttachmentID != "ANGjdJ_test" {
|
||||
t.Fatalf("unexpected parts: %#v", parts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodeGmailBodyBytes(t *testing.T) {
|
||||
got := decodeGmailBody("aGVsbG8=")
|
||||
if got != "hello" {
|
||||
t.Fatalf("decode = %q", got)
|
||||
}
|
||||
}
|
||||
30
internal/migration/gmail_delta_test.go
Normal file
30
internal/migration/gmail_delta_test.go
Normal file
@ -0,0 +1,30 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsGmailHistoryNotFound(t *testing.T) {
|
||||
if !isGmailHistoryNotFound(fmtError("gmail api 404 Not Found: historyId")) {
|
||||
t.Fatal("expected history 404")
|
||||
}
|
||||
if isGmailHistoryNotFound(fmtError("gmail api 403 Forbidden")) {
|
||||
t.Fatal("403 is not history not found")
|
||||
}
|
||||
}
|
||||
|
||||
func fmtError(msg string) error {
|
||||
return &simpleError{msg: msg}
|
||||
}
|
||||
|
||||
type simpleError struct{ msg string }
|
||||
|
||||
func (e *simpleError) Error() string { return e.msg }
|
||||
|
||||
func TestGraphMessageRemoved(t *testing.T) {
|
||||
removed := struct {
|
||||
Reason string `json:"reason"`
|
||||
}{Reason: "deleted"}
|
||||
msg := graphMessage{ID: "msg-1", Removed: &removed}
|
||||
if msg.Removed == nil {
|
||||
t.Fatal("expected removed marker")
|
||||
}
|
||||
}
|
||||
727
internal/migration/gmail_import.go
Normal file
727
internal/migration/gmail_import.go
Normal file
@ -0,0 +1,727 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/storage"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/threading"
|
||||
)
|
||||
|
||||
type GmailImporter struct {
|
||||
db *pgxpool.Pool
|
||||
client *http.Client
|
||||
storage *storage.Client
|
||||
attachBucket string
|
||||
}
|
||||
|
||||
func NewGmailImporter(db *pgxpool.Pool) *GmailImporter {
|
||||
return &GmailImporter{
|
||||
db: db,
|
||||
client: &http.Client{Timeout: 90 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GmailImporter) WithHTTPClient(c *http.Client) *GmailImporter {
|
||||
if c != nil {
|
||||
g.client = c
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GmailImporter) WithStorage(storage *storage.Client, bucket string) *GmailImporter {
|
||||
g.storage = storage
|
||||
g.attachBucket = strings.TrimSpace(bucket)
|
||||
return g
|
||||
}
|
||||
|
||||
type gmailMessage struct {
|
||||
ID string `json:"id"`
|
||||
ThreadID string `json:"threadId"`
|
||||
LabelIDs []string `json:"labelIds"`
|
||||
Snippet string `json:"snippet"`
|
||||
InternalDate string `json:"internalDate"`
|
||||
Payload gmailPayload `json:"payload"`
|
||||
}
|
||||
|
||||
type gmailPayload struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
Headers []gmailHeader `json:"headers"`
|
||||
Body gmailBody `json:"body"`
|
||||
Parts []gmailPayload `json:"parts"`
|
||||
}
|
||||
|
||||
type gmailHeader struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
type gmailBody struct {
|
||||
Size int `json:"size"`
|
||||
Data string `json:"data"`
|
||||
AttachmentID string `json:"attachmentId"`
|
||||
}
|
||||
|
||||
func (g *GmailImporter) ImportBatch(
|
||||
ctx context.Context,
|
||||
job *Job,
|
||||
accessToken string,
|
||||
delta bool,
|
||||
update func(status string, cursor, stats map[string]any, jobErr string) error,
|
||||
) error {
|
||||
accountID, err := g.resolveMailAccountID(ctx, job.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureDefaultMailFolders(ctx, g.db, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := LoadImportedItemStore(ctx, g.db, job.ID, job.CursorJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if delta {
|
||||
historyID, _ := job.CursorJSON["historyId"].(string)
|
||||
if historyID != "" {
|
||||
more, err := g.importHistory(ctx, job, accessToken, accountID, historyID, items)
|
||||
if err != nil {
|
||||
if isGmailHistoryNotFound(err) {
|
||||
if newID, fetchErr := g.fetchHistoryID(ctx, accessToken); fetchErr == nil && newID != "" {
|
||||
job.CursorJSON["historyId"] = newID
|
||||
delete(job.CursorJSON, "historyPageToken")
|
||||
job.StatsJSON["history_reset"] = float64(1)
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
if more {
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
}
|
||||
|
||||
pageToken, _ := job.CursorJSON["pageToken"].(string)
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
|
||||
listURL := "https://gmail.googleapis.com/gmail/v1/users/me/messages?maxResults=100"
|
||||
if pageToken != "" {
|
||||
listURL += "&pageToken=" + pageToken
|
||||
}
|
||||
body, err := g.apiGet(ctx, listURL, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Messages []struct{ ID string `json:"id"` } `json:"messages"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
ResultSizeEstimate int `json:"resultSizeEstimate"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
if listed.ResultSizeEstimate > 0 {
|
||||
job.StatsJSON["estimated_total"] = float64(listed.ResultSizeEstimate)
|
||||
}
|
||||
|
||||
imported, _ := job.StatsJSON["imported"].(float64)
|
||||
batch := 0
|
||||
for i := listIndex; i < len(listed.Messages) && batch < mailImportBatchSize(); i++ {
|
||||
msgID := listed.Messages[i].ID
|
||||
if alreadyImported(items, msgID) {
|
||||
listIndex = i + 1
|
||||
continue
|
||||
}
|
||||
created, err := g.importOne(ctx, accessToken, job.UserID, accountID, msgID)
|
||||
if err != nil {
|
||||
if markErr := items.MarkFailed(ctx, msgID, err.Error(), ""); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
listIndex = i + 1
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, msgID); err != nil {
|
||||
return err
|
||||
}
|
||||
if created {
|
||||
imported++
|
||||
}
|
||||
batch++
|
||||
listIndex = i + 1
|
||||
}
|
||||
job.StatsJSON["imported"] = imported
|
||||
job.CursorJSON["listIndex"] = float64(listIndex)
|
||||
|
||||
if listIndex < len(listed.Messages) {
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
// page complete
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
if listed.NextPageToken != "" {
|
||||
job.CursorJSON["pageToken"] = listed.NextPageToken
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
delete(job.CursorJSON, "pageToken")
|
||||
if hid, err := g.fetchHistoryID(ctx, accessToken); err == nil && hid != "" {
|
||||
job.CursorJSON["historyId"] = hid
|
||||
}
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func jsonNumber(v any) float64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
if f, ok := v.(float64); ok {
|
||||
return f
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (g *GmailImporter) importOne(ctx context.Context, accessToken, userID, accountID, gmailID string) (bool, error) {
|
||||
raw, err := g.apiGet(ctx, "https://gmail.googleapis.com/gmail/v1/users/me/messages/"+gmailID+"?format=full", accessToken)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var msg gmailMessage
|
||||
if err := json.Unmarshal(raw, &msg); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
remoteName, folderType := primaryGmailFolder(msg.LabelIDs)
|
||||
folderID, err := ensureMailFolder(ctx, g.db, accountID, displayFolderName(remoteName, folderType), remoteName, folderType)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
headers := indexHeaders(msg.Payload)
|
||||
subject := headers["subject"]
|
||||
fromJSON := parseAddressListJSON(headers["from"])
|
||||
toJSON := parseAddressListJSON(headers["to"])
|
||||
ccJSON := parseAddressListJSON(headers["cc"])
|
||||
replyToJSON := parseAddressListJSON(headers["reply-to"])
|
||||
rfcID := threading.NormalizeMessageID(headers["message-id"])
|
||||
if rfcID == "" {
|
||||
rfcID = threading.NormalizeMessageID("<gmail-" + gmailID + "@ultimail.migrated>")
|
||||
}
|
||||
inReplyTo := threading.NormalizeMessageID(headers["in-reply-to"])
|
||||
references := parseReferences(headers["references"])
|
||||
bodyText, bodyHTML := extractGmailBodies(msg.Payload)
|
||||
snippet := strings.TrimSpace(msg.Snippet)
|
||||
if snippet == "" {
|
||||
snippet = truncateRunes(bodyText, 200)
|
||||
}
|
||||
|
||||
date := parseMailDate(headers["date"])
|
||||
if msg.InternalDate != "" {
|
||||
if ms, err := parseInternalDate(msg.InternalDate); err == nil {
|
||||
date = ms
|
||||
}
|
||||
}
|
||||
|
||||
flags := gmailFlags(msg.LabelIDs)
|
||||
labels := gmailUserLabels(msg.LabelIDs)
|
||||
|
||||
uid := gmailUID(gmailID)
|
||||
var messageID string
|
||||
var existed bool
|
||||
_ = g.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM messages WHERE folder_id = $1 AND uid = $2)`, folderID, uid).Scan(&existed)
|
||||
|
||||
err = g.db.QueryRow(ctx, `
|
||||
INSERT INTO messages (
|
||||
account_id, folder_id, uid, message_id, subject,
|
||||
from_addr, to_addrs, cc_addrs, reply_to,
|
||||
date, snippet, body_text, body_html, flags, labels,
|
||||
in_reply_to, references_header
|
||||
)
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17)
|
||||
ON CONFLICT (folder_id, uid) DO UPDATE SET
|
||||
message_id = EXCLUDED.message_id,
|
||||
subject = EXCLUDED.subject,
|
||||
from_addr = EXCLUDED.from_addr,
|
||||
to_addrs = EXCLUDED.to_addrs,
|
||||
cc_addrs = EXCLUDED.cc_addrs,
|
||||
reply_to = EXCLUDED.reply_to,
|
||||
date = EXCLUDED.date,
|
||||
snippet = EXCLUDED.snippet,
|
||||
body_text = EXCLUDED.body_text,
|
||||
body_html = EXCLUDED.body_html,
|
||||
flags = EXCLUDED.flags,
|
||||
labels = EXCLUDED.labels,
|
||||
in_reply_to = EXCLUDED.in_reply_to,
|
||||
references_header = EXCLUDED.references_header,
|
||||
updated_at = NOW()
|
||||
RETURNING id
|
||||
`, accountID, folderID, uid, rfcID, subject,
|
||||
fromJSON, toJSON, ccJSON, replyToJSON,
|
||||
date, snippet, bodyText, sanitize.SanitizeHTML(bodyHTML), flags, labels,
|
||||
inReplyTo, references,
|
||||
).Scan(&messageID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := threading.ApplyMessageThread(ctx, g.db, accountID, messageID, rfcID, inReplyTo, references); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := g.storeGmailAttachments(ctx, userID, messageID, gmailID, accessToken, msg.Payload, existed); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !existed, nil
|
||||
}
|
||||
|
||||
func (g *GmailImporter) importHistory(ctx context.Context, job *Job, accessToken, accountID, historyID string, items *ImportedItemStore) (more bool, err error) {
|
||||
pageToken, _ := job.CursorJSON["historyPageToken"].(string)
|
||||
listURL := fmt.Sprintf(
|
||||
"https://gmail.googleapis.com/gmail/v1/users/me/history?startHistoryId=%s&maxResults=100&historyTypes=messageAdded&historyTypes=messageDeleted&historyTypes=labelAdded&historyTypes=labelRemoved",
|
||||
historyID,
|
||||
)
|
||||
if pageToken != "" {
|
||||
listURL += "&pageToken=" + pageToken
|
||||
}
|
||||
body, err := g.apiGet(ctx, listURL, accessToken)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var parsed struct {
|
||||
History []struct {
|
||||
MessagesAdded []struct {
|
||||
Message struct{ ID string `json:"id"` } `json:"message"`
|
||||
} `json:"messagesAdded"`
|
||||
MessagesDeleted []struct {
|
||||
Message struct{ ID string `json:"id"` } `json:"message"`
|
||||
} `json:"messagesDeleted"`
|
||||
LabelsAdded []struct {
|
||||
Message struct{ ID string `json:"id"` } `json:"message"`
|
||||
} `json:"labelsAdded"`
|
||||
LabelsRemoved []struct {
|
||||
Message struct{ ID string `json:"id"` } `json:"message"`
|
||||
} `json:"labelsRemoved"`
|
||||
} `json:"history"`
|
||||
NextPageToken string `json:"nextPageToken"`
|
||||
HistoryID string `json:"historyId"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
deltaCount, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
batch := 0
|
||||
listIndex := int(jsonNumber(job.CursorJSON["historyListIndex"]))
|
||||
|
||||
for hi := listIndex; hi < len(parsed.History) && batch < mailImportBatchSize(); hi++ {
|
||||
h := parsed.History[hi]
|
||||
for _, added := range h.MessagesAdded {
|
||||
if batch >= mailImportBatchSize() {
|
||||
job.CursorJSON["historyListIndex"] = float64(hi)
|
||||
return true, nil
|
||||
}
|
||||
msgID := added.Message.ID
|
||||
if alreadyImported(items, msgID) {
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
ok, err := g.importOne(ctx, accessToken, job.UserID, accountID, msgID)
|
||||
if err != nil {
|
||||
if markErr := items.MarkFailed(ctx, msgID, err.Error(), ""); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, msgID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if ok {
|
||||
deltaCount++
|
||||
}
|
||||
batch++
|
||||
}
|
||||
for _, removed := range h.MessagesDeleted {
|
||||
if batch >= mailImportBatchSize() {
|
||||
job.CursorJSON["historyListIndex"] = float64(hi)
|
||||
return true, nil
|
||||
}
|
||||
if err := g.deleteByGmailID(ctx, accountID, removed.Message.ID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
deleted++
|
||||
batch++
|
||||
}
|
||||
for _, labeled := range h.LabelsAdded {
|
||||
if batch >= mailImportBatchSize() {
|
||||
job.CursorJSON["historyListIndex"] = float64(hi)
|
||||
return true, nil
|
||||
}
|
||||
if _, err := g.importOne(ctx, accessToken, job.UserID, accountID, labeled.Message.ID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
deltaCount++
|
||||
batch++
|
||||
}
|
||||
for _, labeled := range h.LabelsRemoved {
|
||||
if batch >= mailImportBatchSize() {
|
||||
job.CursorJSON["historyListIndex"] = float64(hi)
|
||||
return true, nil
|
||||
}
|
||||
if _, err := g.importOne(ctx, accessToken, job.UserID, accountID, labeled.Message.ID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
deltaCount++
|
||||
batch++
|
||||
}
|
||||
}
|
||||
delete(job.CursorJSON, "historyListIndex")
|
||||
|
||||
job.StatsJSON["delta_imported"] = deltaCount
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
if parsed.NextPageToken != "" {
|
||||
job.CursorJSON["historyPageToken"] = parsed.NextPageToken
|
||||
if parsed.HistoryID != "" {
|
||||
job.CursorJSON["historyId"] = parsed.HistoryID
|
||||
}
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return true, nil
|
||||
}
|
||||
delete(job.CursorJSON, "historyPageToken")
|
||||
if parsed.HistoryID != "" {
|
||||
job.CursorJSON["historyId"] = parsed.HistoryID
|
||||
}
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (g *GmailImporter) deleteByGmailID(ctx context.Context, accountID, gmailID string) error {
|
||||
if strings.TrimSpace(gmailID) == "" {
|
||||
return nil
|
||||
}
|
||||
uid := gmailUID(gmailID)
|
||||
_, err := g.db.Exec(ctx, `DELETE FROM messages WHERE account_id = $1::uuid AND uid = $2`, accountID, uid)
|
||||
return err
|
||||
}
|
||||
|
||||
func isGmailHistoryNotFound(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "404") && strings.Contains(msg, "history")
|
||||
}
|
||||
|
||||
func (g *GmailImporter) fetchHistoryID(ctx context.Context, accessToken string) (string, error) {
|
||||
body, err := g.apiGet(ctx, "https://gmail.googleapis.com/gmail/v1/users/me/profile", accessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parsed struct {
|
||||
HistoryID string `json:"historyId"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return parsed.HistoryID, nil
|
||||
}
|
||||
|
||||
func (g *GmailImporter) apiGet(ctx context.Context, url, accessToken string) ([]byte, error) {
|
||||
raw, err := apiGet(ctx, g.client, url, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gmail api: %w", err)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func (g *GmailImporter) resolveMailAccountID(ctx context.Context, userID string) (string, error) {
|
||||
var accountID string
|
||||
err := g.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(
|
||||
(SELECT mail_account_id::text FROM mailboxes WHERE user_id = $1::uuid AND mail_account_id IS NOT NULL LIMIT 1),
|
||||
(SELECT id::text FROM mail_accounts WHERE user_id = $1::uuid AND is_active ORDER BY created_at LIMIT 1)
|
||||
)
|
||||
`, userID).Scan(&accountID)
|
||||
if err != nil || accountID == "" {
|
||||
return "", fmt.Errorf("no mail account for migration user")
|
||||
}
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
func ensureDefaultMailFolders(ctx context.Context, db *pgxpool.Pool, accountID string) error {
|
||||
defaults := []struct{ name, remote, ftype string }{
|
||||
{"Boîte de réception", "INBOX", "inbox"},
|
||||
{"Envoyés", "SENT", "sent"},
|
||||
{"Brouillons", "DRAFT", "drafts"},
|
||||
{"Corbeille", "TRASH", "trash"},
|
||||
{"Spam", "SPAM", "spam"},
|
||||
{"Archives", "ARCHIVE", "archive"},
|
||||
}
|
||||
for _, d := range defaults {
|
||||
if _, err := ensureMailFolder(ctx, db, accountID, d.name, d.remote, d.ftype); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureMailFolder(ctx context.Context, db *pgxpool.Pool, accountID, name, remoteName, folderType string) (string, error) {
|
||||
var folderID string
|
||||
err := db.QueryRow(ctx, `
|
||||
INSERT INTO mail_folders (account_id, name, remote_name, folder_type)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (account_id, remote_name) DO UPDATE
|
||||
SET name = EXCLUDED.name, folder_type = EXCLUDED.folder_type, updated_at = NOW()
|
||||
RETURNING id::text
|
||||
`, accountID, name, remoteName, folderType).Scan(&folderID)
|
||||
return folderID, err
|
||||
}
|
||||
|
||||
func displayFolderName(remote, folderType string) string {
|
||||
switch folderType {
|
||||
case "inbox":
|
||||
return "Boîte de réception"
|
||||
case "sent":
|
||||
return "Envoyés"
|
||||
case "drafts":
|
||||
return "Brouillons"
|
||||
case "trash":
|
||||
return "Corbeille"
|
||||
case "spam":
|
||||
return "Spam"
|
||||
case "archive":
|
||||
return "Archives"
|
||||
default:
|
||||
return remote
|
||||
}
|
||||
}
|
||||
|
||||
func primaryGmailFolder(labelIDs []string) (remoteName, folderType string) {
|
||||
priority := []struct{ label, remote, ftype string }{
|
||||
{"INBOX", "INBOX", "inbox"},
|
||||
{"SENT", "SENT", "sent"},
|
||||
{"DRAFT", "DRAFT", "drafts"},
|
||||
{"TRASH", "TRASH", "trash"},
|
||||
{"SPAM", "SPAM", "spam"},
|
||||
}
|
||||
set := make(map[string]struct{}, len(labelIDs))
|
||||
for _, l := range labelIDs {
|
||||
set[l] = struct{}{}
|
||||
}
|
||||
for _, p := range priority {
|
||||
if _, ok := set[p.label]; ok {
|
||||
return p.remote, p.ftype
|
||||
}
|
||||
}
|
||||
return "ARCHIVE", "archive"
|
||||
}
|
||||
|
||||
func gmailUserLabels(labelIDs []string) []string {
|
||||
system := map[string]struct{}{
|
||||
"INBOX": {}, "SENT": {}, "DRAFT": {}, "TRASH": {}, "SPAM": {},
|
||||
"STARRED": {}, "IMPORTANT": {}, "UNREAD": {}, "CATEGORY_PERSONAL": {},
|
||||
"CATEGORY_SOCIAL": {}, "CATEGORY_PROMOTIONS": {}, "CATEGORY_UPDATES": {},
|
||||
"CATEGORY_FORUMS": {},
|
||||
}
|
||||
out := make([]string, 0, len(labelIDs))
|
||||
for _, l := range labelIDs {
|
||||
if _, skip := system[l]; skip {
|
||||
continue
|
||||
}
|
||||
out = append(out, strings.ToLower(l))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func gmailFlags(labelIDs []string) []string {
|
||||
flags := []string{}
|
||||
unread := false
|
||||
for _, l := range labelIDs {
|
||||
switch l {
|
||||
case "UNREAD":
|
||||
unread = true
|
||||
case "STARRED":
|
||||
flags = append(flags, "\\Flagged")
|
||||
case "IMPORTANT":
|
||||
flags = append(flags, "important")
|
||||
}
|
||||
}
|
||||
if !unread {
|
||||
flags = append(flags, "\\Seen")
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
func gmailUID(gmailID string) int64 {
|
||||
h := fnv.New64a()
|
||||
_, _ = h.Write([]byte(gmailID))
|
||||
v := int64(h.Sum64() & 0x7fffffffffffffff)
|
||||
if v == 0 {
|
||||
return 1
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
func indexHeaders(p gmailPayload) map[string]string {
|
||||
out := map[string]string{}
|
||||
var walk func(gmailPayload)
|
||||
walk = func(node gmailPayload) {
|
||||
for _, h := range node.Headers {
|
||||
key := strings.ToLower(strings.TrimSpace(h.Name))
|
||||
if key != "" && out[key] == "" {
|
||||
out[key] = h.Value
|
||||
}
|
||||
}
|
||||
for _, part := range node.Parts {
|
||||
walk(part)
|
||||
}
|
||||
}
|
||||
walk(p)
|
||||
return out
|
||||
}
|
||||
|
||||
func extractGmailBodies(p gmailPayload) (text, html string) {
|
||||
var walk func(gmailPayload)
|
||||
walk = func(node gmailPayload) {
|
||||
if text == "" && node.MimeType == "text/plain" && node.Body.Data != "" {
|
||||
text = decodeGmailBody(node.Body.Data)
|
||||
}
|
||||
if html == "" && node.MimeType == "text/html" && node.Body.Data != "" {
|
||||
html = decodeGmailBody(node.Body.Data)
|
||||
}
|
||||
for _, part := range node.Parts {
|
||||
walk(part)
|
||||
}
|
||||
}
|
||||
walk(p)
|
||||
return text, html
|
||||
}
|
||||
|
||||
func decodeGmailBody(data string) string {
|
||||
data = strings.ReplaceAll(data, "-", "+")
|
||||
data = strings.ReplaceAll(data, "_", "/")
|
||||
raw, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(raw)
|
||||
}
|
||||
|
||||
func parseAddressListJSON(raw string) []byte {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return []byte("[]")
|
||||
}
|
||||
parts := splitAddresses(raw)
|
||||
type addr struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
out := make([]addr, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
name, email := parseSingleAddress(p)
|
||||
out = append(out, addr{Name: name, Email: email})
|
||||
}
|
||||
b, _ := json.Marshal(out)
|
||||
return b
|
||||
}
|
||||
|
||||
func splitAddresses(raw string) []string {
|
||||
return strings.Split(raw, ",")
|
||||
}
|
||||
|
||||
func parseSingleAddress(raw string) (name, email string) {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if i := strings.Index(raw, "<"); i >= 0 && strings.HasSuffix(raw, ">") {
|
||||
name = strings.Trim(strings.TrimSpace(raw[:i]), `"`)
|
||||
email = strings.Trim(raw[i+1:len(raw)-1], " <>")
|
||||
return name, strings.ToLower(email)
|
||||
}
|
||||
return "", strings.ToLower(raw)
|
||||
}
|
||||
|
||||
func parseReferences(raw string) []string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return []string{}
|
||||
}
|
||||
var out []string
|
||||
for _, part := range strings.Fields(raw) {
|
||||
if id := threading.NormalizeMessageID(part); id != "" {
|
||||
out = append(out, id)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseMailDate(raw string) time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return time.Now().UTC()
|
||||
}
|
||||
layouts := []string{time.RFC1123Z, time.RFC1123, time.RFC3339}
|
||||
for _, layout := range layouts {
|
||||
if t, err := time.Parse(layout, raw); err == nil {
|
||||
return t.UTC()
|
||||
}
|
||||
}
|
||||
return time.Now().UTC()
|
||||
}
|
||||
|
||||
func parseInternalDate(raw string) (time.Time, error) {
|
||||
var ms int64
|
||||
if _, err := fmt.Sscan(raw, &ms); err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
return time.UnixMilli(ms).UTC(), nil
|
||||
}
|
||||
|
||||
func truncateRunes(s string, n int) string {
|
||||
r := []rune(strings.TrimSpace(s))
|
||||
if len(r) <= n {
|
||||
return string(r)
|
||||
}
|
||||
return string(r[:n])
|
||||
}
|
||||
|
||||
func LinkHostedMailboxByEmail(ctx context.Context, db *pgxpool.Pool, userID, email string) error {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if email == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := db.Exec(ctx, `
|
||||
UPDATE mailboxes SET user_id = $1::uuid, updated_at = NOW()
|
||||
WHERE user_id IS NULL AND lower(local_part || '@' || (SELECT name FROM mail_domains d WHERE d.id = mailboxes.domain_id)) = $2
|
||||
`, userID, email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = db.Exec(ctx, `
|
||||
UPDATE mail_accounts ma SET user_id = $1::uuid, updated_at = NOW()
|
||||
FROM mailboxes mb
|
||||
JOIN mail_domains md ON md.id = mb.domain_id
|
||||
WHERE mb.mail_account_id = ma.id
|
||||
AND ma.user_id IS NULL
|
||||
AND mb.user_id = $1::uuid
|
||||
AND lower(mb.local_part || '@' || md.name) = $2
|
||||
`, userID, email)
|
||||
return err
|
||||
}
|
||||
|
||||
var _ = pgx.ErrNoRows
|
||||
57
internal/migration/gmail_import_test.go
Normal file
57
internal/migration/gmail_import_test.go
Normal file
@ -0,0 +1,57 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGmailUIDStable(t *testing.T) {
|
||||
a := gmailUID("18c4f2a1b2d3e4f5")
|
||||
b := gmailUID("18c4f2a1b2d3e4f5")
|
||||
if a != b {
|
||||
t.Fatalf("expected stable uid, got %d vs %d", a, b)
|
||||
}
|
||||
if a <= 0 {
|
||||
t.Fatalf("expected positive uid, got %d", a)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrimaryGmailFolder(t *testing.T) {
|
||||
remote, folderType := primaryGmailFolder([]string{"INBOX", "UNREAD"})
|
||||
if remote != "INBOX" || folderType != "inbox" {
|
||||
t.Fatalf("got %q / %q", remote, folderType)
|
||||
}
|
||||
|
||||
remote, folderType = primaryGmailFolder([]string{"SENT"})
|
||||
if remote != "SENT" || folderType != "sent" {
|
||||
t.Fatalf("got %q / %q", remote, folderType)
|
||||
}
|
||||
|
||||
remote, folderType = primaryGmailFolder([]string{"Label_123", "INBOX"})
|
||||
if remote != "INBOX" || folderType != "inbox" {
|
||||
t.Fatalf("inbox wins over label: got %q / %q", remote, folderType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseAddressListJSON(t *testing.T) {
|
||||
raw := parseAddressListJSON(`Alice <alice@example.com>, bob@example.com`)
|
||||
var addrs []map[string]string
|
||||
if err := json.Unmarshal(raw, &addrs); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(addrs) != 2 {
|
||||
t.Fatalf("expected 2 addresses, got %d", len(addrs))
|
||||
}
|
||||
if addrs[0]["email"] != "alice@example.com" {
|
||||
t.Fatalf("unexpected first email: %v", addrs[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDisplayFolderName(t *testing.T) {
|
||||
if got := displayFolderName("Label_42", "custom"); got != "Label_42" {
|
||||
t.Fatalf("custom label: %q", got)
|
||||
}
|
||||
if got := displayFolderName("INBOX", "inbox"); got != "Boîte de réception" {
|
||||
t.Fatalf("inbox display: %q", got)
|
||||
}
|
||||
}
|
||||
77
internal/migration/google_dwd.go
Normal file
77
internal/migration/google_dwd.go
Normal file
@ -0,0 +1,77 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/oauth2/google"
|
||||
"golang.org/x/oauth2/jwt"
|
||||
)
|
||||
|
||||
var googleDWDScopes = []string{
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
"https://www.googleapis.com/auth/contacts.readonly",
|
||||
}
|
||||
|
||||
// GoogleDWD mints access tokens via a service account with domain-wide delegation.
|
||||
type GoogleDWD struct {
|
||||
jwtConfig *jwt.Config
|
||||
}
|
||||
|
||||
func NewGoogleDWD(jsonKey string) (*GoogleDWD, error) {
|
||||
jsonKey = strings.TrimSpace(jsonKey)
|
||||
if jsonKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
conf, err := google.JWTConfigFromJSON([]byte(jsonKey), googleDWDScopes...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse google service account: %w", err)
|
||||
}
|
||||
return &GoogleDWD{jwtConfig: conf}, nil
|
||||
}
|
||||
|
||||
func (g *GoogleDWD) Enabled() bool {
|
||||
return g != nil && g.jwtConfig != nil
|
||||
}
|
||||
|
||||
func (g *GoogleDWD) AccessToken(ctx context.Context, subjectEmail string) (string, error) {
|
||||
if !g.Enabled() {
|
||||
return "", fmt.Errorf("google domain-wide delegation not configured")
|
||||
}
|
||||
subjectEmail = strings.ToLower(strings.TrimSpace(subjectEmail))
|
||||
if subjectEmail == "" {
|
||||
return "", fmt.Errorf("subject email required for domain-wide delegation")
|
||||
}
|
||||
conf := *g.jwtConfig
|
||||
conf.Subject = subjectEmail
|
||||
token, err := conf.TokenSource(ctx).Token()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("google dwd token: %w", err)
|
||||
}
|
||||
if token.AccessToken == "" {
|
||||
return "", fmt.Errorf("google dwd token empty")
|
||||
}
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
func validateServiceAccountJSON(raw string) error {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return nil
|
||||
}
|
||||
var probe struct {
|
||||
ClientEmail string `json:"client_email"`
|
||||
PrivateKey string `json:"private_key"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &probe); err != nil {
|
||||
return fmt.Errorf("invalid service account json: %w", err)
|
||||
}
|
||||
if probe.ClientEmail == "" || probe.PrivateKey == "" {
|
||||
return fmt.Errorf("service account json missing client_email or private_key")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
37
internal/migration/google_dwd_test.go
Normal file
37
internal/migration/google_dwd_test.go
Normal file
@ -0,0 +1,37 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNormalizeAuthMode(t *testing.T) {
|
||||
if got := NormalizeAuthMode("google", "google_dwd"); got != AuthModeGoogleDWD {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := NormalizeAuthMode("microsoft", "google_dwd"); got != AuthModeOAuth {
|
||||
t.Fatalf("microsoft ignores dwd: got %q", got)
|
||||
}
|
||||
if got := NormalizeAuthMode("microsoft", "microsoft_app"); got != AuthModeMicrosoftApp {
|
||||
t.Fatalf("microsoft app: got %q", got)
|
||||
}
|
||||
if got := NormalizeAuthMode("google", ""); got != AuthModeOAuth {
|
||||
t.Fatalf("default oauth: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateServiceAccountJSON(t *testing.T) {
|
||||
if err := validateServiceAccountJSON(""); err != nil {
|
||||
t.Fatalf("empty ok: %v", err)
|
||||
}
|
||||
if err := validateServiceAccountJSON(`{"client_email":"a@b.iam.gserviceaccount.com","private_key":"-----BEGIN PRIVATE KEY-----\nabc\n-----END PRIVATE KEY-----\n"}`); err != nil {
|
||||
t.Fatalf("valid json: %v", err)
|
||||
}
|
||||
if err := validateServiceAccountJSON(`{`); err == nil {
|
||||
t.Fatal("expected invalid json error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGoogleDWDEmpty(t *testing.T) {
|
||||
dwd, err := NewGoogleDWD("")
|
||||
if err != nil || dwd != nil {
|
||||
t.Fatalf("empty config: dwd=%v err=%v", dwd, err)
|
||||
}
|
||||
}
|
||||
531
internal/migration/graph_import.go
Normal file
531
internal/migration/graph_import.go
Normal file
@ -0,0 +1,531 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/threading"
|
||||
)
|
||||
|
||||
const graphMessageSelect = "id,subject,bodyPreview,body,from,toRecipients,ccRecipients,replyTo," +
|
||||
"receivedDateTime,sentDateTime,parentFolderId,isRead,flag,internetMessageId,internetMessageHeaders"
|
||||
|
||||
type GraphImporter struct {
|
||||
db *pgxpool.Pool
|
||||
client *http.Client
|
||||
baseURL string
|
||||
userUPN string
|
||||
folders map[string]graphFolderMeta
|
||||
}
|
||||
|
||||
type graphFolderMeta struct {
|
||||
RemoteName string
|
||||
FolderType string
|
||||
}
|
||||
|
||||
func NewGraphImporter(db *pgxpool.Pool) *GraphImporter {
|
||||
return &GraphImporter{
|
||||
db: db,
|
||||
client: &http.Client{Timeout: 90 * time.Second},
|
||||
folders: map[string]graphFolderMeta{},
|
||||
}
|
||||
}
|
||||
|
||||
func (g *GraphImporter) WithHTTPClient(c *http.Client) *GraphImporter {
|
||||
if c != nil {
|
||||
g.client = c
|
||||
}
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GraphImporter) WithUserPrincipal(upn string) *GraphImporter {
|
||||
g.userUPN = strings.TrimSpace(upn)
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GraphImporter) userBase() string {
|
||||
return graphUserBase(g.userUPN)
|
||||
}
|
||||
|
||||
func (g *GraphImporter) WithBaseURL(baseURL string) *GraphImporter {
|
||||
g.baseURL = strings.TrimRight(strings.TrimSpace(baseURL), "/")
|
||||
return g
|
||||
}
|
||||
|
||||
func (g *GraphImporter) graphURL(path string) string {
|
||||
if g.baseURL != "" {
|
||||
return g.baseURL + path
|
||||
}
|
||||
return "https://graph.microsoft.com" + path
|
||||
}
|
||||
|
||||
type graphMessage struct {
|
||||
ID string `json:"id"`
|
||||
Subject string `json:"subject"`
|
||||
BodyPreview string `json:"bodyPreview"`
|
||||
Body graphBody `json:"body"`
|
||||
From graphRecipient `json:"from"`
|
||||
ToRecipients []graphRecipient `json:"toRecipients"`
|
||||
CcRecipients []graphRecipient `json:"ccRecipients"`
|
||||
ReplyTo []graphRecipient `json:"replyTo"`
|
||||
ReceivedDateTime string `json:"receivedDateTime"`
|
||||
SentDateTime string `json:"sentDateTime"`
|
||||
ParentFolderID string `json:"parentFolderId"`
|
||||
IsRead bool `json:"isRead"`
|
||||
Flag graphFlag `json:"flag"`
|
||||
InternetMessageID string `json:"internetMessageId"`
|
||||
InternetMessageHeaders []graphHeader `json:"internetMessageHeaders"`
|
||||
Removed *struct {
|
||||
Reason string `json:"reason"`
|
||||
} `json:"@removed"`
|
||||
}
|
||||
|
||||
type graphBody struct {
|
||||
ContentType string `json:"contentType"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type graphRecipient struct {
|
||||
EmailAddress graphEmailAddress `json:"emailAddress"`
|
||||
}
|
||||
|
||||
type graphEmailAddress struct {
|
||||
Name string `json:"name"`
|
||||
Address string `json:"address"`
|
||||
}
|
||||
|
||||
type graphFlag struct {
|
||||
FlagStatus string `json:"flagStatus"`
|
||||
}
|
||||
|
||||
type graphHeader struct {
|
||||
Name string `json:"name"`
|
||||
Value string `json:"value"`
|
||||
}
|
||||
|
||||
func (g *GraphImporter) ImportBatch(
|
||||
ctx context.Context,
|
||||
job *Job,
|
||||
accessToken string,
|
||||
delta bool,
|
||||
update func(status string, cursor, stats map[string]any, jobErr string) error,
|
||||
) error {
|
||||
accountID, err := g.resolveMailAccountID(ctx, job.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := ensureDefaultMailFolders(ctx, g.db, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := g.ensureGraphFolders(ctx, accessToken); err != nil {
|
||||
return err
|
||||
}
|
||||
items, err := LoadImportedItemStore(ctx, g.db, job.ID, job.CursorJSON)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if delta {
|
||||
deltaLink, _ := job.CursorJSON["deltaLink"].(string)
|
||||
if deltaLink != "" {
|
||||
more, err := g.importDeltaPage(ctx, job, accessToken, accountID, deltaLink, items)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if more {
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
}
|
||||
|
||||
nextLink, _ := job.CursorJSON["nextLink"].(string)
|
||||
var listURL string
|
||||
if nextLink != "" {
|
||||
listURL = nextLink
|
||||
} else {
|
||||
listURL = g.graphURL(g.userBase()+"/messages?$top=100&$orderby="+url.QueryEscape("receivedDateTime desc")+"&$select="+graphMessageSelect)
|
||||
}
|
||||
|
||||
body, err := g.apiGet(ctx, listURL, accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var listed struct {
|
||||
Value []graphMessage `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &listed); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
imported, _ := job.StatsJSON["imported"].(float64)
|
||||
batch := 0
|
||||
listIndex := int(jsonNumber(job.CursorJSON["listIndex"]))
|
||||
for i := listIndex; i < len(listed.Value) && batch < mailImportBatchSize(); i++ {
|
||||
msg := listed.Value[i]
|
||||
if alreadyImported(items, msg.ID) {
|
||||
listIndex = i + 1
|
||||
continue
|
||||
}
|
||||
created, err := g.importOne(ctx, accountID, msg)
|
||||
if err != nil {
|
||||
if markErr := items.MarkFailed(ctx, msg.ID, err.Error(), ""); markErr != nil {
|
||||
return markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
batch++
|
||||
listIndex = i + 1
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, msg.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
if created {
|
||||
imported++
|
||||
}
|
||||
batch++
|
||||
listIndex = i + 1
|
||||
}
|
||||
job.StatsJSON["imported"] = imported
|
||||
job.CursorJSON["listIndex"] = float64(listIndex)
|
||||
|
||||
if listIndex < len(listed.Value) {
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
delete(job.CursorJSON, "listIndex")
|
||||
if listed.NextLink != "" {
|
||||
job.CursorJSON["nextLink"] = listed.NextLink
|
||||
return update("pending", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
delete(job.CursorJSON, "nextLink")
|
||||
|
||||
if delta {
|
||||
if listed.DeltaLink != "" {
|
||||
job.CursorJSON["deltaLink"] = listed.DeltaLink
|
||||
} else if link, err := g.initDeltaLink(ctx, accessToken); err == nil && link != "" {
|
||||
job.CursorJSON["deltaLink"] = link
|
||||
}
|
||||
}
|
||||
job.StatsJSON["phase"] = "imported"
|
||||
return update("completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (g *GraphImporter) importDeltaPage(ctx context.Context, job *Job, accessToken, accountID, deltaLink string, items *ImportedItemStore) (more bool, err error) {
|
||||
body, err := g.apiGet(ctx, deltaLink, accessToken)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []graphMessage `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return false, err
|
||||
}
|
||||
deltaCount, _ := job.StatsJSON["delta_imported"].(float64)
|
||||
deleted, _ := job.StatsJSON["delta_deleted"].(float64)
|
||||
for _, msg := range parsed.Value {
|
||||
if msg.Removed != nil {
|
||||
if err := g.deleteByGraphID(ctx, accountID, msg.ID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
deleted++
|
||||
continue
|
||||
}
|
||||
if alreadyImported(items, msg.ID) {
|
||||
continue
|
||||
}
|
||||
ok, err := g.importOne(ctx, accountID, msg)
|
||||
if err != nil {
|
||||
if markErr := items.MarkFailed(ctx, msg.ID, err.Error(), ""); markErr != nil {
|
||||
return false, markErr
|
||||
}
|
||||
incJobStat(job.StatsJSON, "failed")
|
||||
continue
|
||||
}
|
||||
if err := items.MarkImported(ctx, msg.ID); err != nil {
|
||||
return false, err
|
||||
}
|
||||
if ok {
|
||||
deltaCount++
|
||||
}
|
||||
}
|
||||
job.StatsJSON["delta_imported"] = deltaCount
|
||||
job.StatsJSON["delta_deleted"] = deleted
|
||||
if parsed.NextLink != "" {
|
||||
job.CursorJSON["deltaLink"] = parsed.NextLink
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return true, nil
|
||||
}
|
||||
if parsed.DeltaLink != "" {
|
||||
job.CursorJSON["deltaLink"] = parsed.DeltaLink
|
||||
}
|
||||
job.StatsJSON["phase"] = "delta"
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (g *GraphImporter) initDeltaLink(ctx context.Context, accessToken string) (string, error) {
|
||||
body, err := g.apiGet(ctx, g.graphURL(g.userBase()+"/messages/delta?$select=id"), accessToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
var parsed struct {
|
||||
DeltaLink string `json:"@odata.deltaLink"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if parsed.DeltaLink != "" {
|
||||
return parsed.DeltaLink, nil
|
||||
}
|
||||
return parsed.NextLink, nil
|
||||
}
|
||||
|
||||
func (g *GraphImporter) importOne(ctx context.Context, accountID string, msg graphMessage) (bool, error) {
|
||||
meta := g.folders[msg.ParentFolderID]
|
||||
if meta.RemoteName == "" {
|
||||
meta = graphFolderMeta{RemoteName: "ARCHIVE", FolderType: "archive"}
|
||||
}
|
||||
folderID, err := ensureMailFolder(ctx, g.db, accountID, displayFolderName(meta.RemoteName, meta.FolderType), meta.RemoteName, meta.FolderType)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
headers := indexGraphHeaders(msg.InternetMessageHeaders)
|
||||
rfcID := threading.NormalizeMessageID(msg.InternetMessageID)
|
||||
if rfcID == "" {
|
||||
rfcID = threading.NormalizeMessageID(headers["message-id"])
|
||||
}
|
||||
if rfcID == "" {
|
||||
rfcID = threading.NormalizeMessageID("<graph-" + msg.ID + "@ultimail.migrated>")
|
||||
}
|
||||
inReplyTo := threading.NormalizeMessageID(headers["in-reply-to"])
|
||||
references := parseReferences(headers["references"])
|
||||
|
||||
bodyText, bodyHTML := extractGraphBody(msg.Body)
|
||||
snippet := strings.TrimSpace(msg.BodyPreview)
|
||||
if snippet == "" {
|
||||
snippet = truncateRunes(bodyText, 200)
|
||||
}
|
||||
|
||||
date := parseGraphTime(msg.ReceivedDateTime)
|
||||
if date.IsZero() {
|
||||
date = parseGraphTime(msg.SentDateTime)
|
||||
}
|
||||
if date.IsZero() {
|
||||
date = time.Now().UTC()
|
||||
}
|
||||
|
||||
fromJSON := graphRecipientJSON(msg.From)
|
||||
toJSON := graphRecipientsJSON(msg.ToRecipients)
|
||||
ccJSON := graphRecipientsJSON(msg.CcRecipients)
|
||||
replyToJSON := graphRecipientsJSON(msg.ReplyTo)
|
||||
flags := graphFlags(msg.IsRead, msg.Flag.FlagStatus)
|
||||
|
||||
uid := remoteMessageUID(msg.ID)
|
||||
var messageID string
|
||||
var existed bool
|
||||
_ = g.db.QueryRow(ctx, `SELECT EXISTS(SELECT 1 FROM messages WHERE folder_id = $1 AND uid = $2)`, folderID, uid).Scan(&existed)
|
||||
|
||||
err = g.db.QueryRow(ctx, `
|
||||
INSERT INTO messages (
|
||||
account_id, folder_id, uid, message_id, subject,
|
||||
from_addr, to_addrs, cc_addrs, reply_to,
|
||||
date, snippet, body_text, body_html, flags, labels,
|
||||
in_reply_to, references_header
|
||||
)
|
||||
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17)
|
||||
ON CONFLICT (folder_id, uid) DO UPDATE SET
|
||||
message_id = EXCLUDED.message_id,
|
||||
subject = EXCLUDED.subject,
|
||||
from_addr = EXCLUDED.from_addr,
|
||||
to_addrs = EXCLUDED.to_addrs,
|
||||
cc_addrs = EXCLUDED.cc_addrs,
|
||||
reply_to = EXCLUDED.reply_to,
|
||||
date = EXCLUDED.date,
|
||||
snippet = EXCLUDED.snippet,
|
||||
body_text = EXCLUDED.body_text,
|
||||
body_html = EXCLUDED.body_html,
|
||||
flags = EXCLUDED.flags,
|
||||
in_reply_to = EXCLUDED.in_reply_to,
|
||||
references_header = EXCLUDED.references_header,
|
||||
updated_at = NOW()
|
||||
RETURNING id
|
||||
`, accountID, folderID, uid, rfcID, msg.Subject,
|
||||
fromJSON, toJSON, ccJSON, replyToJSON,
|
||||
date, snippet, bodyText, sanitize.SanitizeHTML(bodyHTML), flags, []string{},
|
||||
inReplyTo, references,
|
||||
).Scan(&messageID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if err := threading.ApplyMessageThread(ctx, g.db, accountID, messageID, rfcID, inReplyTo, references); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return !existed, nil
|
||||
}
|
||||
|
||||
func (g *GraphImporter) deleteByGraphID(ctx context.Context, accountID, graphID string) error {
|
||||
if strings.TrimSpace(graphID) == "" {
|
||||
return nil
|
||||
}
|
||||
uid := remoteMessageUID(graphID)
|
||||
_, err := g.db.Exec(ctx, `DELETE FROM messages WHERE account_id = $1::uuid AND uid = $2`, accountID, uid)
|
||||
return err
|
||||
}
|
||||
|
||||
func (g *GraphImporter) ensureGraphFolders(ctx context.Context, accessToken string) error {
|
||||
if len(g.folders) > 0 {
|
||||
return nil
|
||||
}
|
||||
body, err := g.apiGet(ctx, g.graphURL(g.userBase()+"/mailFolders?$top=100&$select=id,displayName,wellKnownName"), accessToken)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var parsed struct {
|
||||
Value []struct {
|
||||
ID string `json:"id"`
|
||||
DisplayName string `json:"displayName"`
|
||||
WellKnownName string `json:"wellKnownName"`
|
||||
} `json:"value"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &parsed); err != nil {
|
||||
return err
|
||||
}
|
||||
for _, f := range parsed.Value {
|
||||
remote, ftype := graphWellKnownFolder(f.WellKnownName, f.DisplayName)
|
||||
g.folders[f.ID] = graphFolderMeta{RemoteName: remote, FolderType: ftype}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func graphWellKnownFolder(wellKnown, displayName string) (remoteName, folderType string) {
|
||||
switch strings.ToLower(strings.TrimSpace(wellKnown)) {
|
||||
case "inbox":
|
||||
return "INBOX", "inbox"
|
||||
case "sentitems":
|
||||
return "SENT", "sent"
|
||||
case "drafts":
|
||||
return "DRAFT", "drafts"
|
||||
case "deleteditems":
|
||||
return "TRASH", "trash"
|
||||
case "junkemail":
|
||||
return "SPAM", "spam"
|
||||
case "archive":
|
||||
return "ARCHIVE", "archive"
|
||||
default:
|
||||
name := strings.TrimSpace(displayName)
|
||||
if name == "" {
|
||||
name = "CUSTOM"
|
||||
}
|
||||
return strings.ToUpper(strings.ReplaceAll(name, " ", "_")), "custom"
|
||||
}
|
||||
}
|
||||
|
||||
func graphFlags(isRead bool, flagStatus string) []string {
|
||||
flags := []string{}
|
||||
if isRead {
|
||||
flags = append(flags, "\\Seen")
|
||||
}
|
||||
if strings.EqualFold(flagStatus, "flagged") {
|
||||
flags = append(flags, "\\Flagged")
|
||||
}
|
||||
return flags
|
||||
}
|
||||
|
||||
func extractGraphBody(body graphBody) (text, html string) {
|
||||
content := body.Content
|
||||
switch strings.ToLower(body.ContentType) {
|
||||
case "html":
|
||||
html = content
|
||||
case "text":
|
||||
text = content
|
||||
default:
|
||||
text = content
|
||||
}
|
||||
return text, html
|
||||
}
|
||||
|
||||
func graphRecipientJSON(r graphRecipient) []byte {
|
||||
if strings.TrimSpace(r.EmailAddress.Address) == "" {
|
||||
return []byte("[]")
|
||||
}
|
||||
type addr struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
b, _ := json.Marshal([]addr{{Name: r.EmailAddress.Name, Email: strings.ToLower(r.EmailAddress.Address)}})
|
||||
return b
|
||||
}
|
||||
|
||||
func graphRecipientsJSON(recipients []graphRecipient) []byte {
|
||||
type addr struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
out := make([]addr, 0, len(recipients))
|
||||
for _, r := range recipients {
|
||||
email := strings.ToLower(strings.TrimSpace(r.EmailAddress.Address))
|
||||
if email == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, addr{Name: r.EmailAddress.Name, Email: email})
|
||||
}
|
||||
b, _ := json.Marshal(out)
|
||||
return b
|
||||
}
|
||||
|
||||
func indexGraphHeaders(headers []graphHeader) map[string]string {
|
||||
out := map[string]string{}
|
||||
for _, h := range headers {
|
||||
key := strings.ToLower(strings.TrimSpace(h.Name))
|
||||
if key != "" && out[key] == "" {
|
||||
out[key] = h.Value
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseGraphTime(raw string) time.Time {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return time.Time{}
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339Nano, raw); err == nil {
|
||||
return t.UTC()
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, raw); err == nil {
|
||||
return t.UTC()
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (g *GraphImporter) apiGet(ctx context.Context, url, accessToken string) ([]byte, error) {
|
||||
raw, err := apiGet(ctx, g.client, url, accessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("graph api: %w", err)
|
||||
}
|
||||
return raw, nil
|
||||
}
|
||||
|
||||
func (g *GraphImporter) resolveMailAccountID(ctx context.Context, userID string) (string, error) {
|
||||
importer := NewGmailImporter(g.db)
|
||||
return importer.resolveMailAccountID(ctx, userID)
|
||||
}
|
||||
|
||||
func remoteMessageUID(remoteID string) int64 {
|
||||
return gmailUID(remoteID)
|
||||
}
|
||||
49
internal/migration/graph_import_test.go
Normal file
49
internal/migration/graph_import_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGraphWellKnownFolder(t *testing.T) {
|
||||
remote, ftype := graphWellKnownFolder("inbox", "Inbox")
|
||||
if remote != "INBOX" || ftype != "inbox" {
|
||||
t.Fatalf("got %q / %q", remote, ftype)
|
||||
}
|
||||
|
||||
remote, ftype = graphWellKnownFolder("", "Projects")
|
||||
if remote != "PROJECTS" || ftype != "custom" {
|
||||
t.Fatalf("custom folder: got %q / %q", remote, ftype)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphFlags(t *testing.T) {
|
||||
flags := graphFlags(true, "notFlagged")
|
||||
if len(flags) != 1 || flags[0] != "\\Seen" {
|
||||
t.Fatalf("read flags: %v", flags)
|
||||
}
|
||||
flags = graphFlags(false, "flagged")
|
||||
if len(flags) != 1 || flags[0] != "\\Flagged" {
|
||||
t.Fatalf("flagged: %v", flags)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphRecipientsJSON(t *testing.T) {
|
||||
raw := graphRecipientsJSON([]graphRecipient{
|
||||
{EmailAddress: graphEmailAddress{Name: "Bob", Address: "bob@example.com"}},
|
||||
})
|
||||
if string(raw) != `[{"name":"Bob","email":"bob@example.com"}]` {
|
||||
t.Fatalf("unexpected json: %s", raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseGraphTime(t *testing.T) {
|
||||
tm := parseGraphTime("2024-05-01T12:34:56Z")
|
||||
if tm.IsZero() {
|
||||
t.Fatal("expected parsed time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoteMessageUIDMatchesGmailUID(t *testing.T) {
|
||||
id := "abc123"
|
||||
if remoteMessageUID(id) != gmailUID(id) {
|
||||
t.Fatal("uid helpers diverged")
|
||||
}
|
||||
}
|
||||
22
internal/migration/graph_user.go
Normal file
22
internal/migration/graph_user.go
Normal file
@ -0,0 +1,22 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func graphUserBase(userUPN string) string {
|
||||
userUPN = strings.TrimSpace(userUPN)
|
||||
if userUPN == "" {
|
||||
return "/v1.0/me"
|
||||
}
|
||||
return "/v1.0/users/" + url.PathEscape(userUPN)
|
||||
}
|
||||
|
||||
func graphMicrosoftURL(userUPN, suffix string) string {
|
||||
suffix = strings.TrimSpace(suffix)
|
||||
if suffix != "" && !strings.HasPrefix(suffix, "/") {
|
||||
suffix = "/" + suffix
|
||||
}
|
||||
return "https://graph.microsoft.com" + graphUserBase(userUPN) + suffix
|
||||
}
|
||||
186
internal/migration/http_retry.go
Normal file
186
internal/migration/http_retry.go
Normal file
@ -0,0 +1,186 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/observability"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultRateLimitMaxRetries = 6
|
||||
defaultRateLimitBaseDelay = 2 * time.Second
|
||||
defaultRateLimitMaxDelay = 2 * time.Minute
|
||||
)
|
||||
|
||||
// RateLimitConfig controls exponential backoff for migration provider API calls.
|
||||
type RateLimitConfig struct {
|
||||
MaxRetries int
|
||||
BaseDelay time.Duration
|
||||
MaxDelay time.Duration
|
||||
}
|
||||
|
||||
// RateLimitError is returned when a provider keeps responding with HTTP 429
|
||||
// after all configured retries are exhausted.
|
||||
type RateLimitError struct {
|
||||
Cause error
|
||||
RetryAfter time.Duration
|
||||
}
|
||||
|
||||
func (e *RateLimitError) Error() string {
|
||||
if e == nil || e.Cause == nil {
|
||||
return "migration api rate limited"
|
||||
}
|
||||
return e.Cause.Error()
|
||||
}
|
||||
|
||||
func (e *RateLimitError) Unwrap() error {
|
||||
if e == nil {
|
||||
return nil
|
||||
}
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// IsRateLimitError reports whether err is a terminal migration rate-limit error.
|
||||
func IsRateLimitError(err error) bool {
|
||||
var rl *RateLimitError
|
||||
return errors.As(err, &rl)
|
||||
}
|
||||
|
||||
var (
|
||||
rateLimitMu sync.RWMutex
|
||||
rateLimitConfig = RateLimitConfig{
|
||||
MaxRetries: defaultRateLimitMaxRetries,
|
||||
BaseDelay: defaultRateLimitBaseDelay,
|
||||
MaxDelay: defaultRateLimitMaxDelay,
|
||||
}
|
||||
)
|
||||
|
||||
// ConfigureRateLimit sets package-wide retry settings for migration HTTP calls.
|
||||
func ConfigureRateLimit(cfg RateLimitConfig) {
|
||||
rateLimitMu.Lock()
|
||||
defer rateLimitMu.Unlock()
|
||||
if cfg.MaxRetries > 0 {
|
||||
rateLimitConfig.MaxRetries = cfg.MaxRetries
|
||||
}
|
||||
if cfg.BaseDelay > 0 {
|
||||
rateLimitConfig.BaseDelay = cfg.BaseDelay
|
||||
}
|
||||
if cfg.MaxDelay > 0 {
|
||||
rateLimitConfig.MaxDelay = cfg.MaxDelay
|
||||
}
|
||||
}
|
||||
|
||||
func currentRateLimitConfig() RateLimitConfig {
|
||||
rateLimitMu.RLock()
|
||||
defer rateLimitMu.RUnlock()
|
||||
return rateLimitConfig
|
||||
}
|
||||
|
||||
func parseRetryAfter(v string) time.Duration {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
return 0
|
||||
}
|
||||
if secs, err := strconv.Atoi(v); err == nil && secs >= 0 {
|
||||
return time.Duration(secs) * time.Second
|
||||
}
|
||||
if t, err := http.ParseTime(v); err == nil {
|
||||
d := time.Until(t)
|
||||
if d < 0 {
|
||||
return 0
|
||||
}
|
||||
return d
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func rateLimitDelay(attempt int, retryAfter time.Duration) time.Duration {
|
||||
cfg := currentRateLimitConfig()
|
||||
delay := cfg.BaseDelay
|
||||
for i := 1; i < attempt; i++ {
|
||||
if delay >= cfg.MaxDelay {
|
||||
delay = cfg.MaxDelay
|
||||
break
|
||||
}
|
||||
delay *= 2
|
||||
}
|
||||
if retryAfter > delay {
|
||||
delay = retryAfter
|
||||
}
|
||||
if delay > cfg.MaxDelay {
|
||||
delay = cfg.MaxDelay
|
||||
}
|
||||
return delay
|
||||
}
|
||||
|
||||
func migrationDo(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) {
|
||||
if client == nil {
|
||||
client = migrationHTTPClient()
|
||||
}
|
||||
cfg := currentRateLimitConfig()
|
||||
var lastErr error
|
||||
var lastRetryAfter time.Duration
|
||||
|
||||
for attempt := 0; attempt <= cfg.MaxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
observability.IncMigrationRateLimitRetry()
|
||||
delay := rateLimitDelay(attempt, lastRetryAfter)
|
||||
slog.Default().Warn(
|
||||
"migration api rate limited, backing off",
|
||||
"component", "migration-http",
|
||||
"attempt", attempt,
|
||||
"max_retries", cfg.MaxRetries,
|
||||
"delay", delay.String(),
|
||||
"method", req.Method,
|
||||
"host", req.URL.Host,
|
||||
"path", req.URL.Path,
|
||||
)
|
||||
timer := time.NewTimer(delay)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
cloned := req.Clone(ctx)
|
||||
resp, err := client.Do(cloned)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusTooManyRequests {
|
||||
lastRetryAfter = parseRetryAfter(resp.Header.Get("Retry-After"))
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
lastErr = fmt.Errorf("api rate limited (429): %s", strings.TrimSpace(string(raw)))
|
||||
if attempt >= cfg.MaxRetries {
|
||||
return nil, &RateLimitError{Cause: lastErr, RetryAfter: lastRetryAfter}
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
raw, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("api %s: %s", resp.Status, strings.TrimSpace(string(raw)))
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
if lastErr == nil {
|
||||
lastErr = errors.New("migration api rate limited")
|
||||
}
|
||||
return nil, &RateLimitError{Cause: lastErr, RetryAfter: lastRetryAfter}
|
||||
}
|
||||
103
internal/migration/http_retry_test.go
Normal file
103
internal/migration/http_retry_test.go
Normal file
@ -0,0 +1,103 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestParseRetryAfter(t *testing.T) {
|
||||
if got := parseRetryAfter("30"); got != 30*time.Second {
|
||||
t.Fatalf("seconds = %v", got)
|
||||
}
|
||||
if got := parseRetryAfter(""); got != 0 {
|
||||
t.Fatalf("empty = %v", got)
|
||||
}
|
||||
future := time.Now().Add(45 * time.Second).UTC().Format(http.TimeFormat)
|
||||
if got := parseRetryAfter(future); got < 40*time.Second || got > 50*time.Second {
|
||||
t.Fatalf("http date = %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitDelayUsesRetryAfter(t *testing.T) {
|
||||
ConfigureRateLimit(RateLimitConfig{
|
||||
MaxRetries: 3,
|
||||
BaseDelay: 100 * time.Millisecond,
|
||||
MaxDelay: time.Second,
|
||||
})
|
||||
delay := rateLimitDelay(1, 500*time.Millisecond)
|
||||
if delay != 500*time.Millisecond {
|
||||
t.Fatalf("delay = %v", delay)
|
||||
}
|
||||
delay = rateLimitDelay(3, 0)
|
||||
if delay != 400*time.Millisecond {
|
||||
t.Fatalf("exponential delay = %v", delay)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIGETRetries429ThenSucceeds(t *testing.T) {
|
||||
ConfigureRateLimit(RateLimitConfig{
|
||||
MaxRetries: 5,
|
||||
BaseDelay: 5 * time.Millisecond,
|
||||
MaxDelay: 50 * time.Millisecond,
|
||||
})
|
||||
|
||||
calls := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
calls++
|
||||
if calls < 3 {
|
||||
w.Header().Set("Retry-After", "0")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":"rateLimitExceeded"}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
body, err := apiGet(context.Background(), srv.Client(), srv.URL, "token")
|
||||
if err != nil {
|
||||
t.Fatalf("apiGet: %v", err)
|
||||
}
|
||||
if string(body) != `{"ok":true}` {
|
||||
t.Fatalf("body = %q", body)
|
||||
}
|
||||
if calls != 3 {
|
||||
t.Fatalf("calls = %d", calls)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIGETReturnsRateLimitErrorAfterMaxRetries(t *testing.T) {
|
||||
ConfigureRateLimit(RateLimitConfig{
|
||||
MaxRetries: 2,
|
||||
BaseDelay: time.Millisecond,
|
||||
MaxDelay: 10 * time.Millisecond,
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Retry-After", "0")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = w.Write([]byte(`{"error":"quota"}`))
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
_, err := apiGet(context.Background(), srv.Client(), srv.URL, "token")
|
||||
if !IsRateLimitError(err) {
|
||||
t.Fatalf("expected RateLimitError, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerRateLimitErrorIsPending(t *testing.T) {
|
||||
if !IsRateLimitError(&RateLimitError{Cause: errTestRateLimit}) {
|
||||
t.Fatal("expected typed rate limit error")
|
||||
}
|
||||
}
|
||||
|
||||
var errTestRateLimit = &RateLimitError{Cause: errTestRateLimitCause{}}
|
||||
|
||||
type errTestRateLimitCause struct{}
|
||||
|
||||
func (errTestRateLimitCause) Error() string { return "api rate limited (429): quota" }
|
||||
59
internal/migration/httpmock_test.go
Normal file
59
internal/migration/httpmock_test.go
Normal file
@ -0,0 +1,59 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type hostRewriteTransport struct {
|
||||
mockBase string
|
||||
match func(host string) bool
|
||||
base http.RoundTripper
|
||||
}
|
||||
|
||||
func (rt *hostRewriteTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if rt.match(req.URL.Host) {
|
||||
mockURL, err := url.Parse(rt.mockBase)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.URL.Scheme = mockURL.Scheme
|
||||
req.URL.Host = mockURL.Host
|
||||
}
|
||||
base := rt.base
|
||||
if base == nil {
|
||||
base = http.DefaultTransport
|
||||
}
|
||||
return base.RoundTrip(req)
|
||||
}
|
||||
|
||||
func mockGoogleHTTPClient(t *testing.T, handler http.HandlerFunc) *http.Client {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
return &http.Client{
|
||||
Transport: &hostRewriteTransport{
|
||||
mockBase: srv.URL,
|
||||
match: func(host string) bool {
|
||||
return strings.Contains(host, "googleapis.com")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func mockGraphHTTPClient(t *testing.T, handler http.HandlerFunc) *http.Client {
|
||||
t.Helper()
|
||||
srv := httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
return &http.Client{
|
||||
Transport: &hostRewriteTransport{
|
||||
mockBase: srv.URL,
|
||||
match: func(host string) bool {
|
||||
return strings.Contains(host, "graph.microsoft.com")
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
142
internal/migration/import_helpers.go
Normal file
142
internal/migration/import_helpers.go
Normal file
@ -0,0 +1,142 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
)
|
||||
|
||||
type progressUpdater func(status string, cursor, stats map[string]any, jobErr string) error
|
||||
|
||||
type migrationUser struct {
|
||||
Email string
|
||||
ExternalID string
|
||||
Name string
|
||||
}
|
||||
|
||||
func resolveMigrationUser(ctx context.Context, db *pgxpool.Pool, userID string) (migrationUser, error) {
|
||||
var u migrationUser
|
||||
err := db.QueryRow(ctx, `
|
||||
SELECT COALESCE(email, ''), COALESCE(external_id, ''), COALESCE(name, '')
|
||||
FROM users WHERE id = $1::uuid
|
||||
`, userID).Scan(&u.Email, &u.ExternalID, &u.Name)
|
||||
if err != nil {
|
||||
return migrationUser{}, fmt.Errorf("migration user not found")
|
||||
}
|
||||
if u.Email == "" {
|
||||
return migrationUser{}, fmt.Errorf("migration user email missing")
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func migrationHTTPClient() *http.Client {
|
||||
return &http.Client{Timeout: 90 * time.Second}
|
||||
}
|
||||
|
||||
func apiGet(ctx context.Context, client *http.Client, url, accessToken string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
resp, err := migrationDo(ctx, client, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func alreadyImported(store *ImportedItemStore, id string) bool {
|
||||
if store == nil {
|
||||
return false
|
||||
}
|
||||
return store.Has(id)
|
||||
}
|
||||
|
||||
func calendarSyncTokens(cursor map[string]any) map[string]string {
|
||||
raw, _ := cursor["calendarSyncTokens"].(map[string]any)
|
||||
out := make(map[string]string, len(raw))
|
||||
for k, v := range raw {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
out[k] = s
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func setCalendarSyncToken(cursor map[string]any, calID, token string) {
|
||||
if calID == "" || token == "" {
|
||||
return
|
||||
}
|
||||
raw, _ := cursor["calendarSyncTokens"].(map[string]any)
|
||||
if raw == nil {
|
||||
raw = map[string]any{}
|
||||
cursor["calendarSyncTokens"] = raw
|
||||
}
|
||||
raw[calID] = token
|
||||
}
|
||||
|
||||
func calendarDeltaLinks(cursor map[string]any) map[string]string {
|
||||
raw, _ := cursor["calendarDeltaLinks"].(map[string]any)
|
||||
out := make(map[string]string, len(raw))
|
||||
for k, v := range raw {
|
||||
if s, ok := v.(string); ok && s != "" {
|
||||
out[k] = s
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func setCalendarDeltaLink(cursor map[string]any, calID, link string) {
|
||||
if calID == "" || link == "" {
|
||||
return
|
||||
}
|
||||
raw, _ := cursor["calendarDeltaLinks"].(map[string]any)
|
||||
if raw == nil {
|
||||
raw = map[string]any{}
|
||||
cursor["calendarDeltaLinks"] = raw
|
||||
}
|
||||
raw[calID] = link
|
||||
}
|
||||
|
||||
func migrationContactPath(bookPath, provider, sourceID string) string {
|
||||
uid := sanitizeMigrationUID(provider, sourceID)
|
||||
return bookPath + uid + ".vcf"
|
||||
}
|
||||
|
||||
func migrationEventPath(calPath, provider, sourceID string) string {
|
||||
uid := sanitizeMigrationUID(provider, sourceID)
|
||||
return calPath + uid + ".ics"
|
||||
}
|
||||
|
||||
func sanitizeMigrationUID(provider, sourceID string) string {
|
||||
sourceID = strings.TrimSpace(sourceID)
|
||||
sourceID = strings.ReplaceAll(sourceID, "/", "-")
|
||||
return provider + "-" + sourceID + "@ultimail.migrated"
|
||||
}
|
||||
|
||||
func applyOAuthToken(cred credentials.Credential, token *oauthToken) credentials.Credential {
|
||||
cred.AuthType = credentials.AuthOAuth2
|
||||
cred.AccessToken = token.AccessToken
|
||||
if token.RefreshToken != "" {
|
||||
cred.RefreshToken = token.RefreshToken
|
||||
}
|
||||
if !token.Expiry.IsZero() {
|
||||
cred.Expiry = token.Expiry.UTC()
|
||||
}
|
||||
return cred
|
||||
}
|
||||
|
||||
type oauthToken struct {
|
||||
AccessToken string
|
||||
RefreshToken string
|
||||
Expiry time.Time
|
||||
}
|
||||
35
internal/migration/import_helpers_test.go
Normal file
35
internal/migration/import_helpers_test.go
Normal file
@ -0,0 +1,35 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMigrationContactAndEventPaths(t *testing.T) {
|
||||
book := "/remote.php/dav/addressbooks/user%40test.local/migration-import/"
|
||||
contact := migrationContactPath(book, "google", "people/abc")
|
||||
if contact != book+"google-people-abc@ultimail.migrated.vcf" {
|
||||
t.Fatalf("contact path: %q", contact)
|
||||
}
|
||||
cal := "/remote.php/dav/calendars/user%40test.local/migration-import/"
|
||||
event := migrationEventPath(cal, "microsoft", "cal1:evt1")
|
||||
if event != cal+"microsoft-cal1:evt1@ultimail.migrated.ics" {
|
||||
t.Fatalf("event path: %q", event)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsDeleteNotFound(t *testing.T) {
|
||||
if isDeleteNotFound(nil) {
|
||||
t.Fatal("nil is not not-found")
|
||||
}
|
||||
if !isDeleteNotFound(errors.New("delete failed: 404")) {
|
||||
t.Fatal("expected 404 as not-found")
|
||||
}
|
||||
if !isDeleteNotFound(fmt.Errorf("not found")) {
|
||||
t.Fatal("expected not found message")
|
||||
}
|
||||
if isDeleteNotFound(errors.New("permission denied")) {
|
||||
t.Fatal("unexpected not-found")
|
||||
}
|
||||
}
|
||||
206
internal/migration/imported_items.go
Normal file
206
internal/migration/imported_items.go
Normal file
@ -0,0 +1,206 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// ImportedItemStore tracks imported source IDs and optional relative paths for a migration job.
|
||||
// Data lives in migration_imported_items instead of unbounded cursor_json maps.
|
||||
type ImportedItemStore struct {
|
||||
db *pgxpool.Pool
|
||||
jobID string
|
||||
done map[string]struct{} // imported or skipped — resume skips these
|
||||
paths map[string]string
|
||||
}
|
||||
|
||||
func NewImportedItemStoreMemory() *ImportedItemStore {
|
||||
return &ImportedItemStore{
|
||||
done: map[string]struct{}{},
|
||||
paths: map[string]string{},
|
||||
}
|
||||
}
|
||||
|
||||
func LoadImportedItemStore(ctx context.Context, db *pgxpool.Pool, jobID string, cursor map[string]any) (*ImportedItemStore, error) {
|
||||
store := &ImportedItemStore{
|
||||
db: db,
|
||||
jobID: jobID,
|
||||
done: map[string]struct{}{},
|
||||
paths: map[string]string{},
|
||||
}
|
||||
if db != nil && jobID != "" {
|
||||
rows, err := db.Query(ctx, `
|
||||
SELECT source_id, rel_path, status
|
||||
FROM migration_imported_items
|
||||
WHERE job_id = $1::uuid
|
||||
`, jobID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
var sourceID, relPath, status string
|
||||
if err := rows.Scan(&sourceID, &relPath, &status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isImportedItemDone(status) {
|
||||
store.done[sourceID] = struct{}{}
|
||||
}
|
||||
if relPath != "" {
|
||||
store.paths[sourceID] = relPath
|
||||
}
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := store.migrateLegacyCursor(ctx, cursor); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stripImportedCursorKeys(cursor)
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func isImportedItemDone(status string) bool {
|
||||
switch status {
|
||||
case "", ItemStatusImported, ItemStatusSkipped:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func stripImportedCursorKeys(cursor map[string]any) {
|
||||
if cursor == nil {
|
||||
return
|
||||
}
|
||||
delete(cursor, "imported_ids")
|
||||
delete(cursor, "imported_paths")
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) Has(id string) bool {
|
||||
_, ok := s.done[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) Path(id string) string {
|
||||
return s.paths[id]
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) MarkImported(ctx context.Context, id string) error {
|
||||
return s.upsertItem(ctx, id, "", ItemStatusImported, "")
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) MarkPath(ctx context.Context, id, relPath string) error {
|
||||
return s.upsertItem(ctx, id, relPath, ItemStatusImported, "")
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) MarkSkipped(ctx context.Context, id, reason, relPath string) error {
|
||||
return s.upsertItem(ctx, id, relPath, ItemStatusSkipped, reason)
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) MarkFailed(ctx context.Context, id, reason, relPath string) error {
|
||||
delete(s.done, id)
|
||||
delete(s.paths, id)
|
||||
if s.db == nil || s.jobID == "" || id == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO migration_imported_items (job_id, source_id, rel_path, status, reason)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5)
|
||||
ON CONFLICT (job_id, source_id) DO UPDATE
|
||||
SET rel_path = EXCLUDED.rel_path,
|
||||
status = EXCLUDED.status,
|
||||
reason = EXCLUDED.reason,
|
||||
imported_at = NOW()
|
||||
`, s.jobID, id, relPath, ItemStatusFailed, truncateReason(reason))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) upsertItem(ctx context.Context, id, relPath, status, reason string) error {
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
if isImportedItemDone(status) {
|
||||
s.done[id] = struct{}{}
|
||||
} else {
|
||||
delete(s.done, id)
|
||||
}
|
||||
if relPath != "" {
|
||||
s.paths[id] = relPath
|
||||
}
|
||||
if s.db == nil || s.jobID == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := s.db.Exec(ctx, `
|
||||
INSERT INTO migration_imported_items (job_id, source_id, rel_path, status, reason)
|
||||
VALUES ($1::uuid, $2, $3, $4, $5)
|
||||
ON CONFLICT (job_id, source_id) DO UPDATE
|
||||
SET rel_path = EXCLUDED.rel_path,
|
||||
status = EXCLUDED.status,
|
||||
reason = EXCLUDED.reason,
|
||||
imported_at = NOW()
|
||||
`, s.jobID, id, relPath, status, truncateReason(reason))
|
||||
return err
|
||||
}
|
||||
|
||||
func truncateReason(reason string) string {
|
||||
reason = strings.TrimSpace(reason)
|
||||
const maxLen = 2000
|
||||
if len(reason) <= maxLen {
|
||||
return reason
|
||||
}
|
||||
return reason[:maxLen]
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) Unmark(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return nil
|
||||
}
|
||||
delete(s.done, id)
|
||||
delete(s.paths, id)
|
||||
if s.db == nil || s.jobID == "" {
|
||||
return nil
|
||||
}
|
||||
_, err := s.db.Exec(ctx, `
|
||||
DELETE FROM migration_imported_items
|
||||
WHERE job_id = $1::uuid AND source_id = $2
|
||||
`, s.jobID, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *ImportedItemStore) migrateLegacyCursor(ctx context.Context, cursor map[string]any) error {
|
||||
if cursor == nil {
|
||||
return nil
|
||||
}
|
||||
rawIDs, _ := cursor["imported_ids"].(map[string]any)
|
||||
rawPaths, _ := cursor["imported_paths"].(map[string]any)
|
||||
if len(rawIDs) == 0 && len(rawPaths) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
seen := map[string]struct{}{}
|
||||
for id := range rawIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for id := range rawPaths {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
|
||||
for id := range seen {
|
||||
relPath, _ := rawPaths[id].(string)
|
||||
if relPath != "" {
|
||||
if err := s.MarkPath(ctx, id, relPath); err != nil {
|
||||
return fmt.Errorf("migrate imported path %q: %w", id, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := s.MarkImported(ctx, id); err != nil {
|
||||
return fmt.Errorf("migrate imported id %q: %w", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
109
internal/migration/imported_items_test.go
Normal file
109
internal/migration/imported_items_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestImportedItemStoreStatusMemory(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
store := NewImportedItemStoreMemory()
|
||||
|
||||
if err := store.MarkImported(ctx, "ok-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !store.Has("ok-1") {
|
||||
t.Fatal("expected imported id in done set")
|
||||
}
|
||||
|
||||
if err := store.MarkSkipped(ctx, "skip-1", "too large", "big.bin"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !store.Has("skip-1") {
|
||||
t.Fatal("expected skipped id in done set")
|
||||
}
|
||||
if got := store.Path("skip-1"); got != "big.bin" {
|
||||
t.Fatalf("path = %q", got)
|
||||
}
|
||||
|
||||
if err := store.MarkFailed(ctx, "fail-1", "upload error", ""); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if store.Has("fail-1") {
|
||||
t.Fatal("failed item should not be in done set")
|
||||
}
|
||||
|
||||
// retry success clears failure
|
||||
if err := store.MarkImported(ctx, "fail-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !store.Has("fail-1") {
|
||||
t.Fatal("expected retried id in done set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportedItemStoreMemory(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
store := NewImportedItemStoreMemory()
|
||||
|
||||
if err := store.MarkPath(ctx, "id-1", "Docs/a.docx"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !store.Has("id-1") {
|
||||
t.Fatal("expected imported id")
|
||||
}
|
||||
if got := store.Path("id-1"); got != "Docs/a.docx" {
|
||||
t.Fatalf("path = %q", got)
|
||||
}
|
||||
if err := store.Unmark(ctx, "id-1"); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if store.Has("id-1") {
|
||||
t.Fatal("expected id removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestImportedItemStoreMigratesLegacyCursor(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
cursor := map[string]any{
|
||||
"imported_ids": map[string]any{
|
||||
"file-1": true,
|
||||
},
|
||||
"imported_paths": map[string]any{
|
||||
"file-1": "Docs/report.docx",
|
||||
},
|
||||
}
|
||||
store, err := LoadImportedItemStore(ctx, nil, "", cursor)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !store.Has("file-1") {
|
||||
t.Fatal("expected migrated id")
|
||||
}
|
||||
if got := store.Path("file-1"); got != "Docs/report.docx" {
|
||||
t.Fatalf("path = %q", got)
|
||||
}
|
||||
if _, ok := cursor["imported_ids"]; ok {
|
||||
t.Fatal("expected imported_ids stripped from cursor")
|
||||
}
|
||||
if _, ok := cursor["imported_paths"]; ok {
|
||||
t.Fatal("expected imported_paths stripped from cursor")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeAuditStatusFilter(t *testing.T) {
|
||||
if got := normalizeAuditStatusFilter("failed"); got != ItemStatusFailed {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := normalizeAuditStatusFilter("bogus"); got != "" {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncJobStat(t *testing.T) {
|
||||
stats := map[string]any{}
|
||||
incJobStat(stats, "failed")
|
||||
incJobStat(stats, "failed")
|
||||
if stats["failed"] != float64(2) {
|
||||
t.Fatalf("failed = %v", stats["failed"])
|
||||
}
|
||||
}
|
||||
158
internal/migration/job_audit.go
Normal file
158
internal/migration/job_audit.go
Normal file
@ -0,0 +1,158 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
)
|
||||
|
||||
const (
|
||||
ItemStatusImported = "imported"
|
||||
ItemStatusFailed = "failed"
|
||||
ItemStatusSkipped = "skipped"
|
||||
)
|
||||
|
||||
// JobAuditItem is one row in a migration job item audit report.
|
||||
type JobAuditItem struct {
|
||||
SourceID string `json:"source_id"`
|
||||
RelPath string `json:"rel_path,omitempty"`
|
||||
Status string `json:"status"`
|
||||
Reason string `json:"reason,omitempty"`
|
||||
ImportedAt string `json:"imported_at"`
|
||||
}
|
||||
|
||||
// JobAuditSummary counts items by status for a migration job.
|
||||
type JobAuditSummary struct {
|
||||
Service string `json:"service"`
|
||||
Imported int64 `json:"imported"`
|
||||
Failed int64 `json:"failed"`
|
||||
Skipped int64 `json:"skipped"`
|
||||
Total int64 `json:"total"`
|
||||
ByStatus map[string]int64 `json:"by_status,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) verifyJobInProject(ctx context.Context, projectID, jobID string) (service string, err error) {
|
||||
err = s.db.QueryRow(ctx, `
|
||||
SELECT service FROM migration_jobs
|
||||
WHERE id = $1::uuid AND project_id = $2::uuid
|
||||
`, jobID, projectID).Scan(&service)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return "", fmt.Errorf("job not found")
|
||||
}
|
||||
return service, err
|
||||
}
|
||||
|
||||
func (s *Service) JobAuditSummary(ctx context.Context, projectID, jobID string) (JobAuditSummary, error) {
|
||||
service, err := s.verifyJobInProject(ctx, projectID, jobID)
|
||||
if err != nil {
|
||||
return JobAuditSummary{}, err
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT status, COUNT(*)
|
||||
FROM migration_imported_items
|
||||
WHERE job_id = $1::uuid
|
||||
GROUP BY status
|
||||
`, jobID)
|
||||
if err != nil {
|
||||
return JobAuditSummary{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
summary := JobAuditSummary{Service: service, ByStatus: map[string]int64{}}
|
||||
for rows.Next() {
|
||||
var status string
|
||||
var count int64
|
||||
if err := rows.Scan(&status, &count); err != nil {
|
||||
return JobAuditSummary{}, err
|
||||
}
|
||||
summary.ByStatus[status] = count
|
||||
summary.Total += count
|
||||
switch status {
|
||||
case ItemStatusImported:
|
||||
summary.Imported = count
|
||||
case ItemStatusFailed:
|
||||
summary.Failed = count
|
||||
case ItemStatusSkipped:
|
||||
summary.Skipped = count
|
||||
}
|
||||
}
|
||||
return summary, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) ListJobAudit(
|
||||
ctx context.Context,
|
||||
projectID, jobID, statusFilter string,
|
||||
params query.ListParams,
|
||||
) ([]JobAuditItem, query.PaginationMeta, error) {
|
||||
if _, err := s.verifyJobInProject(ctx, projectID, jobID); err != nil {
|
||||
return nil, query.PaginationMeta{}, err
|
||||
}
|
||||
|
||||
statusFilter = normalizeAuditStatusFilter(statusFilter)
|
||||
|
||||
var total int64
|
||||
countSQL := `SELECT COUNT(*) FROM migration_imported_items WHERE job_id = $1::uuid`
|
||||
countArgs := []any{jobID}
|
||||
if statusFilter != "" {
|
||||
countSQL += ` AND status = $2`
|
||||
countArgs = append(countArgs, statusFilter)
|
||||
}
|
||||
if err := s.db.QueryRow(ctx, countSQL, countArgs...).Scan(&total); err != nil {
|
||||
return nil, query.PaginationMeta{}, err
|
||||
}
|
||||
|
||||
listSQL := `
|
||||
SELECT source_id, rel_path, status, reason, imported_at::text
|
||||
FROM migration_imported_items
|
||||
WHERE job_id = $1::uuid
|
||||
`
|
||||
listArgs := []any{jobID}
|
||||
if statusFilter != "" {
|
||||
listSQL += ` AND status = $2`
|
||||
listArgs = append(listArgs, statusFilter)
|
||||
}
|
||||
listSQL += ` ORDER BY imported_at DESC, source_id ASC LIMIT $` + fmt.Sprint(len(listArgs)+1) +
|
||||
` OFFSET $` + fmt.Sprint(len(listArgs)+2)
|
||||
listArgs = append(listArgs, params.Limit(), params.Offset())
|
||||
|
||||
rows, err := s.db.Query(ctx, listSQL, listArgs...)
|
||||
if err != nil {
|
||||
return nil, query.PaginationMeta{}, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
out := make([]JobAuditItem, 0, params.Limit())
|
||||
for rows.Next() {
|
||||
var item JobAuditItem
|
||||
if err := rows.Scan(&item.SourceID, &item.RelPath, &item.Status, &item.Reason, &item.ImportedAt); err != nil {
|
||||
return nil, query.PaginationMeta{}, err
|
||||
}
|
||||
out = append(out, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, query.PaginationMeta{}, err
|
||||
}
|
||||
return out, params.Meta(&total), nil
|
||||
}
|
||||
|
||||
func normalizeAuditStatusFilter(raw string) string {
|
||||
switch raw {
|
||||
case ItemStatusImported, ItemStatusFailed, ItemStatusSkipped:
|
||||
return raw
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func incJobStat(stats map[string]any, key string) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
v, _ := stats[key].(float64)
|
||||
stats[key] = v + 1
|
||||
}
|
||||
143
internal/migration/microsoft_admin_consent.go
Normal file
143
internal/migration/microsoft_admin_consent.go
Normal file
@ -0,0 +1,143 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const adminConsentStatePrefix = "project:"
|
||||
|
||||
// MicrosoftAdminConsentRecord stores the outcome of a Microsoft admin consent redirect.
|
||||
type MicrosoftAdminConsentRecord struct {
|
||||
TenantID string
|
||||
ClientID string
|
||||
ProjectID string
|
||||
Granted bool
|
||||
ErrorCode string
|
||||
ErrorDescription string
|
||||
}
|
||||
|
||||
// MicrosoftAdminConsent is a persisted tenant-level admin consent row.
|
||||
type MicrosoftAdminConsent struct {
|
||||
TenantID string `json:"tenant_id"`
|
||||
ClientID string `json:"client_id"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
Granted bool `json:"granted"`
|
||||
ErrorCode string `json:"error_code,omitempty"`
|
||||
ErrorDescription string `json:"error_description,omitempty"`
|
||||
ConsentedAt string `json:"consented_at"`
|
||||
UpdatedAt string `json:"updated_at"`
|
||||
}
|
||||
|
||||
func EncodeAdminConsentState(projectID string) string {
|
||||
projectID = strings.TrimSpace(projectID)
|
||||
if projectID == "" {
|
||||
return ""
|
||||
}
|
||||
return adminConsentStatePrefix + projectID
|
||||
}
|
||||
|
||||
func ParseAdminConsentProjectID(state string) string {
|
||||
state = strings.TrimSpace(state)
|
||||
if !strings.HasPrefix(state, adminConsentStatePrefix) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(strings.TrimPrefix(state, adminConsentStatePrefix))
|
||||
}
|
||||
|
||||
func (s *Service) RecordMicrosoftAdminConsent(ctx context.Context, in MicrosoftAdminConsentRecord) error {
|
||||
if s.db == nil {
|
||||
return fmt.Errorf("database not configured")
|
||||
}
|
||||
tenantID := strings.TrimSpace(in.TenantID)
|
||||
clientID := strings.TrimSpace(in.ClientID)
|
||||
projectID := strings.TrimSpace(in.ProjectID)
|
||||
if tenantID == "" {
|
||||
return fmt.Errorf("tenant id required")
|
||||
}
|
||||
if clientID == "" {
|
||||
return fmt.Errorf("client id required")
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO migration_microsoft_admin_consents (
|
||||
tenant_id, client_id, project_id, granted, error_code, error_description
|
||||
) VALUES ($1, $2, NULLIF($3, '')::uuid, $4, $5, $6)
|
||||
ON CONFLICT (tenant_id, client_id) DO UPDATE SET
|
||||
project_id = COALESCE(EXCLUDED.project_id, migration_microsoft_admin_consents.project_id),
|
||||
granted = EXCLUDED.granted,
|
||||
error_code = EXCLUDED.error_code,
|
||||
error_description = EXCLUDED.error_description,
|
||||
consented_at = NOW(),
|
||||
updated_at = NOW()
|
||||
`, tenantID, clientID, projectID, in.Granted, in.ErrorCode, in.ErrorDescription)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if projectID != "" {
|
||||
if in.Granted {
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE migration_projects SET
|
||||
microsoft_tenant_id = $2,
|
||||
microsoft_admin_consent_at = NOW(),
|
||||
microsoft_admin_consent_error = '',
|
||||
updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
`, projectID, tenantID)
|
||||
} else {
|
||||
errMsg := strings.TrimSpace(in.ErrorDescription)
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(in.ErrorCode)
|
||||
}
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE migration_projects SET
|
||||
microsoft_tenant_id = CASE
|
||||
WHEN microsoft_tenant_id = '' THEN $2
|
||||
ELSE microsoft_tenant_id
|
||||
END,
|
||||
microsoft_admin_consent_error = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
`, projectID, tenantID, errMsg)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (s *Service) ListMicrosoftAdminConsents(ctx context.Context) ([]MicrosoftAdminConsent, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT tenant_id, client_id, COALESCE(project_id::text, ''), granted,
|
||||
error_code, error_description, consented_at::text, updated_at::text
|
||||
FROM migration_microsoft_admin_consents
|
||||
ORDER BY updated_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var out []MicrosoftAdminConsent
|
||||
for rows.Next() {
|
||||
var row MicrosoftAdminConsent
|
||||
if err := rows.Scan(
|
||||
&row.TenantID, &row.ClientID, &row.ProjectID, &row.Granted,
|
||||
&row.ErrorCode, &row.ErrorDescription, &row.ConsentedAt, &row.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
17
internal/migration/microsoft_admin_consent_test.go
Normal file
17
internal/migration/microsoft_admin_consent_test.go
Normal file
@ -0,0 +1,17 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestAdminConsentStateRoundTrip(t *testing.T) {
|
||||
const projectID = "550e8400-e29b-41d4-a716-446655440000"
|
||||
state := EncodeAdminConsentState(projectID)
|
||||
if state != adminConsentStatePrefix+projectID {
|
||||
t.Fatalf("encode: %q", state)
|
||||
}
|
||||
if got := ParseAdminConsentProjectID(state); got != projectID {
|
||||
t.Fatalf("parse: got %q", got)
|
||||
}
|
||||
if ParseAdminConsentProjectID("other") != "" {
|
||||
t.Fatal("expected empty for unrelated state")
|
||||
}
|
||||
}
|
||||
103
internal/migration/microsoft_app.go
Normal file
103
internal/migration/microsoft_app.go
Normal file
@ -0,0 +1,103 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/clientcredentials"
|
||||
)
|
||||
|
||||
// MicrosoftApp mints Graph access tokens via the client credentials flow.
|
||||
type MicrosoftApp struct {
|
||||
clientID string
|
||||
clientSecret string
|
||||
defaultTenant string
|
||||
tokenURL string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type MicrosoftAppConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
DefaultTenant string
|
||||
TokenURL string
|
||||
HTTPClient *http.Client
|
||||
}
|
||||
|
||||
func NewMicrosoftApp(cfg MicrosoftAppConfig) (*MicrosoftApp, error) {
|
||||
clientID := strings.TrimSpace(cfg.ClientID)
|
||||
clientSecret := strings.TrimSpace(cfg.ClientSecret)
|
||||
if clientID == "" || clientSecret == "" {
|
||||
return nil, nil
|
||||
}
|
||||
client := cfg.HTTPClient
|
||||
if client == nil {
|
||||
client = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &MicrosoftApp{
|
||||
clientID: clientID,
|
||||
clientSecret: clientSecret,
|
||||
defaultTenant: strings.TrimSpace(cfg.DefaultTenant),
|
||||
tokenURL: strings.TrimSpace(cfg.TokenURL),
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *MicrosoftApp) Enabled() bool {
|
||||
return m != nil && m.clientID != "" && m.clientSecret != ""
|
||||
}
|
||||
|
||||
func (m *MicrosoftApp) AccessToken(ctx context.Context, tenantID string) (string, error) {
|
||||
if !m.Enabled() {
|
||||
return "", fmt.Errorf("microsoft app-only auth not configured")
|
||||
}
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
tenantID = m.defaultTenant
|
||||
}
|
||||
if tenantID == "" {
|
||||
return "", fmt.Errorf("microsoft tenant id required for app-only auth")
|
||||
}
|
||||
tokenURL := m.tokenURL
|
||||
if tokenURL == "" {
|
||||
tokenURL = fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", url.PathEscape(tenantID))
|
||||
}
|
||||
cc := clientcredentials.Config{
|
||||
ClientID: m.clientID,
|
||||
ClientSecret: m.clientSecret,
|
||||
TokenURL: tokenURL,
|
||||
Scopes: []string{"https://graph.microsoft.com/.default"},
|
||||
}
|
||||
if m.client != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, m.client)
|
||||
}
|
||||
token, err := cc.Token(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("microsoft app token: %w", err)
|
||||
}
|
||||
if token.AccessToken == "" {
|
||||
return "", fmt.Errorf("microsoft app token empty")
|
||||
}
|
||||
return token.AccessToken, nil
|
||||
}
|
||||
|
||||
// WithClient overrides the HTTP client (tests).
|
||||
func (m *MicrosoftApp) WithClient(c *http.Client) *MicrosoftApp {
|
||||
if m != nil && c != nil {
|
||||
m.client = c
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func microsoftAppTokenURL(tenantID string) string {
|
||||
tenantID = strings.TrimSpace(tenantID)
|
||||
if tenantID == "" {
|
||||
tenantID = "common"
|
||||
}
|
||||
return fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", url.PathEscape(tenantID))
|
||||
}
|
||||
141
internal/migration/microsoft_app_test.go
Normal file
141
internal/migration/microsoft_app_test.go
Normal file
@ -0,0 +1,141 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNormalizeAuthModeMicrosoftApp(t *testing.T) {
|
||||
if got := NormalizeAuthMode("microsoft", "microsoft_app"); got != AuthModeMicrosoftApp {
|
||||
t.Fatalf("got %q", got)
|
||||
}
|
||||
if got := NormalizeAuthMode("google", "microsoft_app"); got != AuthModeOAuth {
|
||||
t.Fatalf("google ignores ms app: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsesUserOAuth(t *testing.T) {
|
||||
if UsesUserOAuth("google", AuthModeGoogleDWD) {
|
||||
t.Fatal("google dwd should skip user oauth")
|
||||
}
|
||||
if UsesUserOAuth("microsoft", AuthModeMicrosoftApp) {
|
||||
t.Fatal("microsoft app should skip user oauth")
|
||||
}
|
||||
if !UsesUserOAuth("microsoft", "oauth") {
|
||||
t.Fatal("microsoft oauth needs user oauth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGraphUserBase(t *testing.T) {
|
||||
if got := graphUserBase(""); got != "/v1.0/me" {
|
||||
t.Fatalf("empty upn: %q", got)
|
||||
}
|
||||
if got := graphUserBase("alice@contoso.com"); got != "/v1.0/users/alice@contoso.com" {
|
||||
t.Fatalf("encoded upn: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMicrosoftAppEmpty(t *testing.T) {
|
||||
app, err := NewMicrosoftApp(MicrosoftAppConfig{})
|
||||
if err != nil || app != nil {
|
||||
t.Fatalf("empty config: app=%v err=%v", app, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMicrosoftAppAccessToken(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("method %s", r.Method)
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if r.Form.Get("grant_type") != "client_credentials" {
|
||||
t.Fatalf("grant_type=%q", r.Form.Get("grant_type"))
|
||||
}
|
||||
auth := r.Header.Get("Authorization")
|
||||
if !strings.HasPrefix(auth, "Basic ") {
|
||||
t.Fatalf("expected basic auth, got %q", auth)
|
||||
}
|
||||
if !strings.Contains(r.Form.Get("scope"), "graph.microsoft.com/.default") {
|
||||
t.Fatalf("scope=%q", r.Form.Get("scope"))
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"app-token","token_type":"Bearer","expires_in":3600}`)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
app, err := NewMicrosoftApp(MicrosoftAppConfig{
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
TokenURL: srv.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
token, err := app.AccessToken(t.Context(), "tenant-123")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if token != "app-token" {
|
||||
t.Fatalf("token=%q", token)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMicrosoftAppRequiresTenant(t *testing.T) {
|
||||
app, err := NewMicrosoftApp(MicrosoftAppConfig{
|
||||
ClientID: "client-id",
|
||||
ClientSecret: "client-secret",
|
||||
TokenURL: "http://example.invalid/token",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := app.AccessToken(t.Context(), ""); err == nil {
|
||||
t.Fatal("expected tenant required error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWorkerAuthPathSelection(t *testing.T) {
|
||||
cases := []struct {
|
||||
provider string
|
||||
authMode string
|
||||
userOAuth bool
|
||||
}{
|
||||
{"google", AuthModeOAuth, true},
|
||||
{"google", AuthModeGoogleDWD, false},
|
||||
{"microsoft", AuthModeOAuth, true},
|
||||
{"microsoft", AuthModeMicrosoftApp, false},
|
||||
{"microsoft", "google_dwd", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := UsesUserOAuth(tc.provider, tc.authMode)
|
||||
if got != tc.userOAuth {
|
||||
t.Fatalf("%s/%s: got userOAuth=%v want %v", tc.provider, tc.authMode, got, tc.userOAuth)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMicrosoftAppTokenError(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{"error": "invalid_client"})
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
app, err := NewMicrosoftApp(MicrosoftAppConfig{
|
||||
ClientID: "bad",
|
||||
ClientSecret: "bad",
|
||||
TokenURL: srv.URL,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := app.AccessToken(t.Context(), "tenant"); err == nil {
|
||||
t.Fatal("expected token error")
|
||||
}
|
||||
}
|
||||
228
internal/migration/oauth.go
Normal file
228
internal/migration/oauth.go
Normal file
@ -0,0 +1,228 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const pendingKeyPrefix = "migration_oauth_pending:"
|
||||
const pendingTTL = 15 * time.Minute
|
||||
|
||||
var ErrUnknownState = errors.New("migration oauth state expired or unknown")
|
||||
var ErrProviderDisabled = errors.New("migration oauth provider not configured")
|
||||
|
||||
type Provider string
|
||||
|
||||
const (
|
||||
ProviderGoogle Provider = "google"
|
||||
ProviderMicrosoft Provider = "microsoft"
|
||||
)
|
||||
|
||||
type PendingOAuth struct {
|
||||
UserID string `json:"user_id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
Provider string `json:"provider"`
|
||||
InviteToken string `json:"invite_token,omitempty"`
|
||||
PKCEVerifier string `json:"pkce_verifier"`
|
||||
}
|
||||
|
||||
type OAuthConfig struct {
|
||||
GoogleClientID string
|
||||
GoogleClientSecret string
|
||||
MicrosoftClientID string
|
||||
MicrosoftSecret string
|
||||
MicrosoftTenant string
|
||||
RedirectURL string
|
||||
}
|
||||
|
||||
type OAuthService struct {
|
||||
cfg OAuthConfig
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewOAuthService(cfg OAuthConfig, rdb *redis.Client) *OAuthService {
|
||||
return &OAuthService{cfg: cfg, rdb: rdb}
|
||||
}
|
||||
|
||||
func (s *OAuthService) EnabledProviders() []string {
|
||||
var out []string
|
||||
if s.providerConfig(ProviderGoogle) != nil {
|
||||
out = append(out, string(ProviderGoogle))
|
||||
}
|
||||
if s.providerConfig(ProviderMicrosoft) != nil {
|
||||
out = append(out, string(ProviderMicrosoft))
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *OAuthService) Start(ctx context.Context, pending PendingOAuth, provider Provider) (authURL, state string, err error) {
|
||||
oauthCfg := s.providerConfig(provider)
|
||||
if oauthCfg == nil {
|
||||
return "", "", ErrProviderDisabled
|
||||
}
|
||||
verifier, challenge, err := newPKCE()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
state, err = randomState()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
pending.Provider = string(provider)
|
||||
pending.PKCEVerifier = verifier
|
||||
if err := s.savePending(ctx, state, pending); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
authURL = oauthCfg.AuthCodeURL(state,
|
||||
oauth2.AccessTypeOffline,
|
||||
oauth2.SetAuthURLParam("code_challenge", challenge),
|
||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||
oauth2.SetAuthURLParam("prompt", "consent"),
|
||||
)
|
||||
return authURL, state, nil
|
||||
}
|
||||
|
||||
func (s *OAuthService) Exchange(ctx context.Context, state, code string) (PendingOAuth, *oauth2.Token, []string, error) {
|
||||
pending, err := s.loadPending(ctx, state)
|
||||
if err != nil {
|
||||
return PendingOAuth{}, nil, nil, err
|
||||
}
|
||||
oauthCfg := s.providerConfig(Provider(pending.Provider))
|
||||
if oauthCfg == nil {
|
||||
return PendingOAuth{}, nil, nil, ErrProviderDisabled
|
||||
}
|
||||
token, err := oauthCfg.Exchange(ctx, code, oauth2.SetAuthURLParam("code_verifier", pending.PKCEVerifier))
|
||||
if err != nil {
|
||||
return PendingOAuth{}, nil, nil, fmt.Errorf("token exchange: %w", err)
|
||||
}
|
||||
_ = s.rdb.Del(ctx, pendingKeyPrefix+state).Err()
|
||||
return pending, token, oauthCfg.Scopes, nil
|
||||
}
|
||||
|
||||
func (s *OAuthService) Refresh(ctx context.Context, provider Provider, refreshToken string) (*oauth2.Token, error) {
|
||||
oauthCfg := s.providerConfig(provider)
|
||||
if oauthCfg == nil {
|
||||
return nil, ErrProviderDisabled
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return nil, fmt.Errorf("refresh token required")
|
||||
}
|
||||
token, err := oauthCfg.TokenSource(ctx, &oauth2.Token{RefreshToken: refreshToken}).Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// AdminConsentURL returns the Microsoft tenant admin consent URL for migration scopes.
|
||||
func (s *OAuthService) AdminConsentURL(tenant, state string) (string, error) {
|
||||
if s.cfg.MicrosoftClientID == "" || s.cfg.RedirectURL == "" {
|
||||
return "", ErrProviderDisabled
|
||||
}
|
||||
tenant = strings.TrimSpace(tenant)
|
||||
if tenant == "" {
|
||||
tenant = s.cfg.MicrosoftTenant
|
||||
}
|
||||
if tenant == "" {
|
||||
tenant = "common"
|
||||
}
|
||||
values := url.Values{}
|
||||
values.Set("client_id", s.cfg.MicrosoftClientID)
|
||||
values.Set("redirect_uri", s.cfg.RedirectURL)
|
||||
if state = strings.TrimSpace(state); state != "" {
|
||||
values.Set("state", state)
|
||||
}
|
||||
return fmt.Sprintf("https://login.microsoftonline.com/%s/adminconsent?%s", url.PathEscape(tenant), values.Encode()), nil
|
||||
}
|
||||
|
||||
func (s *OAuthService) MicrosoftClientID() string {
|
||||
return s.cfg.MicrosoftClientID
|
||||
}
|
||||
|
||||
func (s *OAuthService) providerConfig(provider Provider) *oauth2.Config {
|
||||
switch provider {
|
||||
case ProviderGoogle:
|
||||
if s.cfg.GoogleClientID == "" || s.cfg.GoogleClientSecret == "" || s.cfg.RedirectURL == "" {
|
||||
return nil
|
||||
}
|
||||
return &oauth2.Config{
|
||||
ClientID: s.cfg.GoogleClientID,
|
||||
ClientSecret: s.cfg.GoogleClientSecret,
|
||||
RedirectURL: s.cfg.RedirectURL,
|
||||
Scopes: []string{
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/calendar.readonly",
|
||||
"https://www.googleapis.com/auth/contacts.readonly",
|
||||
},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: "https://accounts.google.com/o/oauth2/v2/auth",
|
||||
TokenURL: "https://oauth2.googleapis.com/token",
|
||||
},
|
||||
}
|
||||
case ProviderMicrosoft:
|
||||
if s.cfg.MicrosoftClientID == "" || s.cfg.MicrosoftSecret == "" || s.cfg.RedirectURL == "" {
|
||||
return nil
|
||||
}
|
||||
tenant := s.cfg.MicrosoftTenant
|
||||
if tenant == "" {
|
||||
tenant = "common"
|
||||
}
|
||||
return &oauth2.Config{
|
||||
ClientID: s.cfg.MicrosoftClientID,
|
||||
ClientSecret: s.cfg.MicrosoftSecret,
|
||||
RedirectURL: s.cfg.RedirectURL,
|
||||
Scopes: []string{
|
||||
"offline_access",
|
||||
"User.Read",
|
||||
"Mail.Read",
|
||||
"Files.Read.All",
|
||||
"Calendars.Read",
|
||||
"Contacts.Read",
|
||||
},
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/authorize", tenant),
|
||||
TokenURL: fmt.Sprintf("https://login.microsoftonline.com/%s/oauth2/v2.0/token", tenant),
|
||||
},
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OAuthService) savePending(ctx context.Context, state string, pending PendingOAuth) error {
|
||||
if s.rdb == nil {
|
||||
return errors.New("oauth state store unavailable")
|
||||
}
|
||||
raw, err := json.Marshal(pending)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.rdb.Set(ctx, pendingKeyPrefix+state, raw, pendingTTL).Err()
|
||||
}
|
||||
|
||||
func (s *OAuthService) loadPending(ctx context.Context, state string) (PendingOAuth, error) {
|
||||
if s.rdb == nil {
|
||||
return PendingOAuth{}, errors.New("oauth state store unavailable")
|
||||
}
|
||||
raw, err := s.rdb.Get(ctx, pendingKeyPrefix+state).Bytes()
|
||||
if err != nil {
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return PendingOAuth{}, ErrUnknownState
|
||||
}
|
||||
return PendingOAuth{}, err
|
||||
}
|
||||
var pending PendingOAuth
|
||||
if err := json.Unmarshal(raw, &pending); err != nil {
|
||||
return PendingOAuth{}, err
|
||||
}
|
||||
return pending, nil
|
||||
}
|
||||
61
internal/migration/oauth_admin_test.go
Normal file
61
internal/migration/oauth_admin_test.go
Normal file
@ -0,0 +1,61 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdminConsentURL(t *testing.T) {
|
||||
svc := NewOAuthService(OAuthConfig{
|
||||
MicrosoftClientID: "client-id",
|
||||
MicrosoftSecret: "secret",
|
||||
MicrosoftTenant: "contoso.onmicrosoft.com",
|
||||
RedirectURL: "https://suite.example.com/api/v1/migration/oauth/callback",
|
||||
}, nil)
|
||||
url, err := svc.AdminConsentURL("", EncodeAdminConsentState("550e8400-e29b-41d4-a716-446655440000"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.Contains(url, "adminconsent") {
|
||||
t.Fatalf("expected adminconsent url, got %q", url)
|
||||
}
|
||||
if !strings.Contains(url, "client-id") {
|
||||
t.Fatalf("missing client id: %q", url)
|
||||
}
|
||||
if !strings.Contains(url, "state=project%3A") {
|
||||
t.Fatalf("missing project state: %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminConsentURLWithoutState(t *testing.T) {
|
||||
svc := NewOAuthService(OAuthConfig{
|
||||
MicrosoftClientID: "client-id",
|
||||
MicrosoftSecret: "secret",
|
||||
RedirectURL: "https://suite.example.com/api/v1/migration/oauth/callback",
|
||||
}, nil)
|
||||
url, err := svc.AdminConsentURL("contoso.onmicrosoft.com", "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if strings.Contains(url, "state=") {
|
||||
t.Fatalf("unexpected state in url: %q", url)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeMigrationUID(t *testing.T) {
|
||||
uid := sanitizeMigrationUID("google", "people/abc123")
|
||||
if uid != "google-people-abc123@ultimail.migrated" {
|
||||
t.Fatalf("got %q", uid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseFlexibleTime(t *testing.T) {
|
||||
tm := parseFlexibleTime("2024-05-01T10:00:00Z", "")
|
||||
if tm.IsZero() {
|
||||
t.Fatal("expected datetime parse")
|
||||
}
|
||||
tm = parseFlexibleTime("", "2024-05-01")
|
||||
if tm.IsZero() {
|
||||
t.Fatal("expected date parse")
|
||||
}
|
||||
}
|
||||
78
internal/migration/oauth_refresh.go
Normal file
78
internal/migration/oauth_refresh.go
Normal file
@ -0,0 +1,78 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
)
|
||||
|
||||
// RefreshCredential refreshes an expired migration OAuth token and persists it.
|
||||
func (s *Service) RefreshCredential(
|
||||
ctx context.Context,
|
||||
oauth *OAuthService,
|
||||
userID, projectID, provider string,
|
||||
cred credentials.Credential,
|
||||
) (credentials.Credential, error) {
|
||||
if oauth == nil {
|
||||
return cred, fmt.Errorf("migration oauth not configured")
|
||||
}
|
||||
if !cred.NeedsRefresh() {
|
||||
return cred, nil
|
||||
}
|
||||
if cred.RefreshToken == "" {
|
||||
return cred, fmt.Errorf("migration refresh token missing; re-run OAuth consent")
|
||||
}
|
||||
token, err := oauth.Refresh(ctx, Provider(provider), cred.RefreshToken)
|
||||
if err != nil {
|
||||
return cred, fmt.Errorf("refresh migration token: %w", err)
|
||||
}
|
||||
updated := applyOAuthTokenFromOAuth2(cred, token)
|
||||
if err := s.SaveCredential(ctx, userID, projectID, provider, updated); err != nil {
|
||||
return cred, err
|
||||
}
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// SaveCredential encrypts and stores migration OAuth credentials.
|
||||
func (s *Service) SaveCredential(ctx context.Context, userID, projectID, provider string, cred credentials.Credential) error {
|
||||
if s.creds == nil {
|
||||
return fmt.Errorf("credential manager not configured")
|
||||
}
|
||||
cred.AuthType = credentials.AuthOAuth2
|
||||
cred.OAuthProvider = provider
|
||||
enc, err := s.creds.EncryptCredential(cred)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var expiresAt *time.Time
|
||||
if !cred.Expiry.IsZero() {
|
||||
expiresAt = &cred.Expiry
|
||||
}
|
||||
_, err = s.db.Exec(ctx, `
|
||||
UPDATE migration_credentials SET
|
||||
encrypted_token = $4,
|
||||
expires_at = $5,
|
||||
revoked_at = NULL
|
||||
WHERE user_id = $1::uuid AND project_id = $2::uuid AND provider = $3
|
||||
`, userID, projectID, provider, enc, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) MicrosoftAdminConsentURL(tenant, projectID string) (string, error) {
|
||||
if s.oauth == nil {
|
||||
return "", fmt.Errorf("migration oauth not configured")
|
||||
}
|
||||
return s.oauth.AdminConsentURL(tenant, EncodeAdminConsentState(projectID))
|
||||
}
|
||||
|
||||
func applyOAuthTokenFromOAuth2(cred credentials.Credential, token *oauth2.Token) credentials.Credential {
|
||||
return applyOAuthToken(cred, &oauthToken{
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expiry: token.Expiry,
|
||||
})
|
||||
}
|
||||
93
internal/migration/onboarding.go
Normal file
93
internal/migration/onboarding.go
Normal file
@ -0,0 +1,93 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// OnboardingHints guides the user-facing migration setup flow.
|
||||
type OnboardingHints struct {
|
||||
NeedsUserOAuth bool `json:"needs_user_oauth"`
|
||||
OAuthProvider string `json:"oauth_provider,omitempty"`
|
||||
WaitingForAdmin bool `json:"waiting_for_admin"`
|
||||
WaitingReason string `json:"waiting_reason,omitempty"`
|
||||
HasMigrationCredentials bool `json:"has_migration_credentials"`
|
||||
NeedsMicrosoftAdminConsent bool `json:"needs_microsoft_admin_consent,omitempty"`
|
||||
}
|
||||
|
||||
func (s *Service) BuildOnboardingHints(ctx context.Context, userID string, proj Project, invite Invite) OnboardingHints {
|
||||
h := OnboardingHints{
|
||||
OAuthProvider: proj.SourceProvider,
|
||||
}
|
||||
|
||||
switch proj.Status {
|
||||
case "active", "cutover":
|
||||
// worker eligible
|
||||
default:
|
||||
h.WaitingForAdmin = true
|
||||
if proj.Status == "draft" {
|
||||
h.WaitingReason = "project_not_activated"
|
||||
} else {
|
||||
h.WaitingReason = "project_status_" + proj.Status
|
||||
}
|
||||
}
|
||||
|
||||
if UsesUserOAuth(proj.SourceProvider, proj.AuthMode) {
|
||||
h.NeedsUserOAuth = true
|
||||
if userID != "" {
|
||||
hasCred, err := s.hasMigrationCredential(ctx, userID, proj.ID, proj.SourceProvider)
|
||||
if err == nil {
|
||||
h.HasMigrationCredentials = hasCred
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if proj.SourceProvider == "microsoft" && proj.MicrosoftAdminConsentAt == nil {
|
||||
h.NeedsMicrosoftAdminConsent = true
|
||||
}
|
||||
|
||||
if invite.Status == "claimed" && h.WaitingForAdmin {
|
||||
h.WaitingReason = waitingReasonMessage(h.WaitingReason)
|
||||
}
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func waitingReasonMessage(code string) string {
|
||||
switch code {
|
||||
case "project_not_activated":
|
||||
return "project_not_activated"
|
||||
default:
|
||||
return code
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) hasMigrationCredential(ctx context.Context, userID, projectID, provider string) (bool, error) {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return false, nil
|
||||
}
|
||||
var exists bool
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT EXISTS(
|
||||
SELECT 1 FROM migration_credentials
|
||||
WHERE user_id = $1::uuid AND project_id = $2::uuid AND provider = $3 AND revoked_at IS NULL
|
||||
)
|
||||
`, userID, projectID, provider).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (s *Service) BuildInviteOnboardingHints(proj Project, invite Invite) OnboardingHints {
|
||||
h := OnboardingHints{OAuthProvider: proj.SourceProvider}
|
||||
if invite.Status == "claimed" {
|
||||
h.NeedsUserOAuth = UsesUserOAuth(proj.SourceProvider, proj.AuthMode)
|
||||
return h
|
||||
}
|
||||
if UsesUserOAuth(proj.SourceProvider, proj.AuthMode) {
|
||||
h.NeedsUserOAuth = true
|
||||
}
|
||||
if proj.SourceProvider == "microsoft" && proj.MicrosoftAdminConsentAt == nil {
|
||||
h.NeedsMicrosoftAdminConsent = true
|
||||
}
|
||||
return h
|
||||
}
|
||||
75
internal/migration/onboarding_test.go
Normal file
75
internal/migration/onboarding_test.go
Normal file
@ -0,0 +1,75 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestBuildOnboardingHintsGoogleDWD(t *testing.T) {
|
||||
s := &Service{}
|
||||
h := s.BuildOnboardingHints(t.Context(), "user-id", Project{
|
||||
ID: "p1",
|
||||
SourceProvider: "google",
|
||||
AuthMode: AuthModeGoogleDWD,
|
||||
Status: "active",
|
||||
}, Invite{Status: "claimed"})
|
||||
if h.NeedsUserOAuth {
|
||||
t.Fatal("google dwd should not need user oauth")
|
||||
}
|
||||
if h.WaitingForAdmin {
|
||||
t.Fatal("active project should not wait for admin")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOnboardingHintsDraftProject(t *testing.T) {
|
||||
s := &Service{}
|
||||
h := s.BuildOnboardingHints(t.Context(), "", Project{
|
||||
ID: "p1",
|
||||
SourceProvider: "google",
|
||||
AuthMode: "oauth",
|
||||
Status: "draft",
|
||||
}, Invite{Status: "claimed"})
|
||||
if !h.WaitingForAdmin || h.WaitingReason != "project_not_activated" {
|
||||
t.Fatalf("expected wait activate, got %#v", h)
|
||||
}
|
||||
if !h.NeedsUserOAuth {
|
||||
t.Fatal("oauth mode needs user oauth")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOnboardingHintsMicrosoftConsent(t *testing.T) {
|
||||
s := &Service{}
|
||||
h := s.BuildOnboardingHints(t.Context(), "", Project{
|
||||
ID: "p1",
|
||||
SourceProvider: "microsoft",
|
||||
AuthMode: "oauth",
|
||||
Status: "active",
|
||||
}, Invite{Status: "claimed"})
|
||||
if !h.NeedsMicrosoftAdminConsent {
|
||||
t.Fatal("expected ms admin consent hint")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOnboardingHintsMicrosoftApp(t *testing.T) {
|
||||
s := &Service{}
|
||||
h := s.BuildOnboardingHints(t.Context(), "user-id", Project{
|
||||
ID: "p1",
|
||||
SourceProvider: "microsoft",
|
||||
AuthMode: AuthModeMicrosoftApp,
|
||||
Status: "active",
|
||||
MicrosoftAdminConsentAt: strPtr("2026-01-01T00:00:00Z"),
|
||||
}, Invite{Status: "claimed"})
|
||||
if h.NeedsUserOAuth {
|
||||
t.Fatal("microsoft app should not need user oauth")
|
||||
}
|
||||
}
|
||||
|
||||
func strPtr(s string) *string { return &s }
|
||||
|
||||
func TestBuildInviteOnboardingHintsUnclaimed(t *testing.T) {
|
||||
s := &Service{}
|
||||
h := s.BuildInviteOnboardingHints(Project{
|
||||
SourceProvider: "google",
|
||||
AuthMode: AuthModeGoogleDWD,
|
||||
}, Invite{Status: "invited"})
|
||||
if h.NeedsUserOAuth {
|
||||
t.Fatal("dwd invite should not prompt oauth before claim")
|
||||
}
|
||||
}
|
||||
30
internal/migration/pkce.go
Normal file
30
internal/migration/pkce.go
Normal file
@ -0,0 +1,30 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
func base64URLEncode(b []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func sha256Sum(input string) ([]byte, error) {
|
||||
sum := sha256.Sum256([]byte(input))
|
||||
return sum[:], nil
|
||||
}
|
||||
|
||||
func newPKCE() (verifier, challenge string, err error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
verifier = base64.RawURLEncoding.EncodeToString(b)
|
||||
sum, err := sha256Sum(verifier)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
challenge = base64.RawURLEncoding.EncodeToString(sum)
|
||||
return verifier, challenge, nil
|
||||
}
|
||||
94
internal/migration/project_columns.go
Normal file
94
internal/migration/project_columns.go
Normal file
@ -0,0 +1,94 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
)
|
||||
|
||||
var projectSelectColumns = []string{
|
||||
"id::text",
|
||||
"COALESCE(domain_id::text, '')",
|
||||
"name",
|
||||
"source_provider",
|
||||
"auth_mode",
|
||||
"status",
|
||||
"cutover_at::text",
|
||||
"delta_mode",
|
||||
"created_at::text",
|
||||
"NULLIF(microsoft_tenant_id, '')",
|
||||
"microsoft_admin_consent_at::text",
|
||||
"COALESCE(NULLIF(microsoft_admin_consent_error, ''), '')",
|
||||
"cutover_dns_json",
|
||||
}
|
||||
|
||||
func projectSelectSQL(tablePrefix string) string {
|
||||
if tablePrefix != "" && !strings.HasSuffix(tablePrefix, ".") {
|
||||
tablePrefix += "."
|
||||
}
|
||||
cols := make([]string, len(projectSelectColumns))
|
||||
for i, col := range projectSelectColumns {
|
||||
cols[i] = tablePrefix + col
|
||||
}
|
||||
return strings.Join(cols, ", ")
|
||||
}
|
||||
|
||||
type projectScanner struct {
|
||||
project Project
|
||||
cutoverDNSRaw []byte
|
||||
}
|
||||
|
||||
func newProjectScanner() *projectScanner {
|
||||
return &projectScanner{}
|
||||
}
|
||||
|
||||
func (s *projectScanner) targets() []any {
|
||||
return []any{
|
||||
&s.project.ID, &s.project.DomainID, &s.project.Name, &s.project.SourceProvider,
|
||||
&s.project.AuthMode, &s.project.Status, &s.project.CutoverAt, &s.project.DeltaMode,
|
||||
&s.project.CreatedAt, &s.project.MicrosoftTenantID, &s.project.MicrosoftAdminConsentAt,
|
||||
&s.project.MicrosoftAdminConsentError, &s.cutoverDNSRaw,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *projectScanner) result() Project {
|
||||
applyCutoverDNS(&s.project, s.cutoverDNSRaw)
|
||||
return s.project
|
||||
}
|
||||
|
||||
func applyCutoverDNS(p *Project, raw []byte) {
|
||||
p.CutoverDNS = nil
|
||||
raw = bytesTrimSpace(raw)
|
||||
if len(raw) == 0 || string(raw) == "{}" || string(raw) == "null" {
|
||||
return
|
||||
}
|
||||
var report hosted.DNSCheckReport
|
||||
if err := json.Unmarshal(raw, &report); err != nil {
|
||||
return
|
||||
}
|
||||
if !dnsReportHasContent(report) {
|
||||
return
|
||||
}
|
||||
p.CutoverDNS = &report
|
||||
}
|
||||
|
||||
func dnsReportHasContent(r hosted.DNSCheckReport) bool {
|
||||
if strings.TrimSpace(r.Domain) != "" {
|
||||
return true
|
||||
}
|
||||
if len(r.Errors) > 0 || len(r.Warnings) > 0 {
|
||||
return true
|
||||
}
|
||||
if len(r.MXRecords) > 0 || len(r.TXTRecords) > 0 || len(r.ExpectedMX) > 0 {
|
||||
return true
|
||||
}
|
||||
if r.TXTVerified || r.MXVerified {
|
||||
return true
|
||||
}
|
||||
return strings.TrimSpace(r.TXTExpected) != ""
|
||||
}
|
||||
|
||||
func bytesTrimSpace(b []byte) []byte {
|
||||
return []byte(strings.TrimSpace(string(b)))
|
||||
}
|
||||
32
internal/migration/project_columns_test.go
Normal file
32
internal/migration/project_columns_test.go
Normal file
@ -0,0 +1,32 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
)
|
||||
|
||||
func TestApplyCutoverDNSEmptyObject(t *testing.T) {
|
||||
var p Project
|
||||
applyCutoverDNS(&p, []byte(`{}`))
|
||||
if p.CutoverDNS != nil {
|
||||
t.Fatal("expected nil cutover dns for empty object")
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCutoverDNSReport(t *testing.T) {
|
||||
var p Project
|
||||
applyCutoverDNS(&p, []byte(`{"domain":"acme.com","txt_verified":true,"mx_verified":false}`))
|
||||
if p.CutoverDNS == nil || p.CutoverDNS.Domain != "acme.com" {
|
||||
t.Fatalf("got %#v", p.CutoverDNS)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSReportHasContent(t *testing.T) {
|
||||
if dnsReportHasContent(hosted.DNSCheckReport{}) {
|
||||
t.Fatal("empty report should have no content")
|
||||
}
|
||||
if !dnsReportHasContent(hosted.DNSCheckReport{Warnings: []string{"no domain_id"}}) {
|
||||
t.Fatal("warnings should count as content")
|
||||
}
|
||||
}
|
||||
499
internal/migration/service.go
Normal file
499
internal/migration/service.go
Normal file
@ -0,0 +1,499 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInviteNotFound = errors.New("migration invite not found")
|
||||
ErrInviteClaimed = errors.New("migration invite already claimed")
|
||||
ErrEmailMismatch = errors.New("email does not match invite")
|
||||
ErrMigrationDomainNotActive = errors.New("migration project mail domain is not active")
|
||||
ErrMigrationDomainMismatch = errors.New("invite email domain does not match migration project domain")
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
db *pgxpool.Pool
|
||||
rdb *redis.Client
|
||||
creds *credentials.Manager
|
||||
hosted *hosted.Service
|
||||
oauth *OAuthService
|
||||
cutover CutoverConfig
|
||||
}
|
||||
|
||||
func NewService(db *pgxpool.Pool, rdb *redis.Client, creds *credentials.Manager, hostedSvc *hosted.Service, oauth *OAuthService) *Service {
|
||||
return &Service{db: db, rdb: rdb, creds: creds, hosted: hostedSvc, oauth: oauth}
|
||||
}
|
||||
|
||||
func (s *Service) SetCutoverConfig(cfg CutoverConfig) {
|
||||
s.cutover = cfg
|
||||
}
|
||||
|
||||
type Project struct {
|
||||
ID string `json:"id"`
|
||||
DomainID string `json:"domain_id,omitempty"`
|
||||
Name string `json:"name"`
|
||||
SourceProvider string `json:"source_provider"`
|
||||
AuthMode string `json:"auth_mode"`
|
||||
Status string `json:"status"`
|
||||
CutoverAt *string `json:"cutover_at,omitempty"`
|
||||
DeltaMode bool `json:"delta_mode"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
MicrosoftTenantID string `json:"microsoft_tenant_id,omitempty"`
|
||||
MicrosoftAdminConsentAt *string `json:"microsoft_admin_consent_at,omitempty"`
|
||||
MicrosoftAdminConsentError string `json:"microsoft_admin_consent_error,omitempty"`
|
||||
CutoverDNS *hosted.DNSCheckReport `json:"cutover_dns,omitempty"`
|
||||
}
|
||||
|
||||
type Invite struct {
|
||||
ID string `json:"id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
Email string `json:"email"`
|
||||
AlternateEmails []string `json:"alternate_emails,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Status string `json:"status"`
|
||||
ClaimedAt *string `json:"claimed_at,omitempty"`
|
||||
UserID string `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type Job struct {
|
||||
ID string `json:"id"`
|
||||
ProjectID string `json:"project_id"`
|
||||
UserID string `json:"user_id"`
|
||||
Service string `json:"service"`
|
||||
Status string `json:"status"`
|
||||
CursorJSON map[string]any `json:"cursor_json"`
|
||||
StatsJSON map[string]any `json:"stats_json"`
|
||||
Error string `json:"error,omitempty"`
|
||||
StartedAt *string `json:"started_at,omitempty"`
|
||||
CompletedAt *string `json:"completed_at,omitempty"`
|
||||
}
|
||||
|
||||
type UserStatus struct {
|
||||
Project Project `json:"project"`
|
||||
Invite Invite `json:"invite,omitempty"`
|
||||
Jobs []Job `json:"jobs"`
|
||||
Onboarding OnboardingHints `json:"onboarding"`
|
||||
}
|
||||
|
||||
func (s *Service) CreateProject(ctx context.Context, name, sourceProvider, domainID, authMode string) (Project, error) {
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return Project{}, fmt.Errorf("project name required")
|
||||
}
|
||||
sourceProvider = strings.ToLower(strings.TrimSpace(sourceProvider))
|
||||
if sourceProvider == "" {
|
||||
sourceProvider = "google"
|
||||
}
|
||||
authMode = NormalizeAuthMode(sourceProvider, authMode)
|
||||
sc := newProjectScanner()
|
||||
err := s.db.QueryRow(ctx, `
|
||||
INSERT INTO migration_projects (name, source_provider, domain_id, auth_mode)
|
||||
VALUES ($1, $2, NULLIF($3, '')::uuid, $4)
|
||||
RETURNING `+projectSelectSQL("")+`
|
||||
`, name, sourceProvider, domainID, authMode).Scan(sc.targets()...)
|
||||
return sc.result(), err
|
||||
}
|
||||
|
||||
func (s *Service) ListProjects(ctx context.Context) ([]Project, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT `+projectSelectSQL("")+`
|
||||
FROM migration_projects ORDER BY created_at DESC
|
||||
`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []Project
|
||||
for rows.Next() {
|
||||
sc := newProjectScanner()
|
||||
if err := rows.Scan(sc.targets()...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, sc.result())
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) CreateInvite(ctx context.Context, projectID, email string, alternateEmails []string) (Invite, error) {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if email == "" {
|
||||
return Invite{}, fmt.Errorf("email required")
|
||||
}
|
||||
alternates := normalizeAlternateEmails(email, alternateEmails)
|
||||
token, err := hosted.NewInviteToken()
|
||||
if err != nil {
|
||||
return Invite{}, err
|
||||
}
|
||||
var row Invite
|
||||
err = s.db.QueryRow(ctx, `
|
||||
INSERT INTO migration_invites (project_id, email, token, alternate_emails)
|
||||
VALUES ($1::uuid, $2, $3, $4)
|
||||
RETURNING id::text, project_id::text, email, token, status, claimed_at::text, COALESCE(user_id::text, ''), alternate_emails
|
||||
`, projectID, email, token, alternates).Scan(
|
||||
&row.ID, &row.ProjectID, &row.Email, &row.Token, &row.Status, &row.ClaimedAt, &row.UserID, &row.AlternateEmails,
|
||||
)
|
||||
return row, err
|
||||
}
|
||||
|
||||
func normalizeAlternateEmails(inviteEmail string, alternateEmails []string) []string {
|
||||
inviteEmail = normalizeInviteEmail(inviteEmail)
|
||||
seen := map[string]struct{}{inviteEmail: {}}
|
||||
var out []string
|
||||
for _, raw := range alternateEmails {
|
||||
email := normalizeInviteEmail(raw)
|
||||
if email == "" || !isEmailAddress(email) {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[email]; ok {
|
||||
continue
|
||||
}
|
||||
seen[email] = struct{}{}
|
||||
out = append(out, email)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *Service) ImportInvites(ctx context.Context, projectID string, emails []string) (int, error) {
|
||||
count := 0
|
||||
for _, email := range emails {
|
||||
email = strings.ToLower(strings.TrimSpace(email))
|
||||
if email == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := s.CreateInvite(ctx, projectID, email, nil); err != nil {
|
||||
return count, err
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetInviteByToken(ctx context.Context, token string) (Invite, Project, error) {
|
||||
var inv Invite
|
||||
sc := newProjectScanner()
|
||||
scanArgs := append([]any{
|
||||
&inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID, &inv.AlternateEmails,
|
||||
}, sc.targets()...)
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT i.id::text, i.project_id::text, i.email, i.status, i.claimed_at::text, COALESCE(i.user_id::text, ''), i.alternate_emails,
|
||||
`+projectSelectSQL("p")+`
|
||||
FROM migration_invites i
|
||||
JOIN migration_projects p ON p.id = i.project_id
|
||||
WHERE i.token = $1
|
||||
`, token).Scan(scanArgs...)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return Invite{}, Project{}, ErrInviteNotFound
|
||||
}
|
||||
return inv, sc.result(), err
|
||||
}
|
||||
|
||||
func (s *Service) ClaimInvite(ctx context.Context, token, userID string, identity ClaimIdentity, displayName, password string) (UserStatus, error) {
|
||||
inv, proj, err := s.GetInviteByToken(ctx, token)
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
if inv.Status == "claimed" {
|
||||
return UserStatus{}, ErrInviteClaimed
|
||||
}
|
||||
|
||||
projectDomain := ""
|
||||
var hostedDomain *hosted.DomainRow
|
||||
if strings.TrimSpace(proj.DomainID) != "" && s.hosted != nil {
|
||||
domain, err := s.hosted.GetDomain(ctx, proj.DomainID)
|
||||
if err != nil {
|
||||
return UserStatus{}, fmt.Errorf("migration domain: %w", err)
|
||||
}
|
||||
hostedDomain = &domain
|
||||
projectDomain = domain.Name
|
||||
}
|
||||
if !InviteEmailMatchesIdentity(inv.Email, inv.AlternateEmails, projectDomain, identity) {
|
||||
return UserStatus{}, ErrEmailMismatch
|
||||
}
|
||||
mailboxEmail := normalizeInviteEmail(inv.Email)
|
||||
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, `
|
||||
UPDATE migration_invites
|
||||
SET status = 'claimed', claimed_at = NOW(), user_id = $1::uuid
|
||||
WHERE id = $2::uuid AND status = 'invited'
|
||||
`, userID, inv.ID)
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
|
||||
if s.hosted != nil {
|
||||
provision := hosted.ProvisionMailboxInput{
|
||||
UserID: userID,
|
||||
Email: mailboxEmail,
|
||||
DisplayName: displayName,
|
||||
Password: password,
|
||||
QuotaBytes: 0,
|
||||
}
|
||||
if hostedDomain != nil {
|
||||
at := strings.LastIndex(mailboxEmail, "@")
|
||||
if at <= 0 || !strings.EqualFold(mailboxEmail[at+1:], hostedDomain.Name) {
|
||||
return UserStatus{}, ErrMigrationDomainMismatch
|
||||
}
|
||||
if hostedDomain.Status != "active" && !hostedDomain.IsPlatformDomain {
|
||||
return UserStatus{}, ErrMigrationDomainNotActive
|
||||
}
|
||||
provision.DomainID = proj.DomainID
|
||||
}
|
||||
_, err = s.hosted.ProvisionMailbox(ctx, provision)
|
||||
if err != nil {
|
||||
if errors.Is(err, hosted.ErrDomainNotActive) {
|
||||
return UserStatus{}, ErrMigrationDomainNotActive
|
||||
}
|
||||
if !errors.Is(err, hosted.ErrAddressTaken) {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
services := []string{"mail", "contacts", "calendar", "drive"}
|
||||
for _, svc := range services {
|
||||
_, err = tx.Exec(ctx, `
|
||||
INSERT INTO migration_jobs (project_id, user_id, service, status)
|
||||
VALUES ($1::uuid, $2::uuid, $3, 'pending')
|
||||
ON CONFLICT (project_id, user_id, service) DO NOTHING
|
||||
`, proj.ID, userID, svc)
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
|
||||
return s.GetUserStatus(ctx, userID, proj.ID)
|
||||
}
|
||||
|
||||
func (s *Service) StoreMigrationToken(ctx context.Context, userID, projectID, provider string, token *oauth2.Token, scopes []string) error {
|
||||
if s.creds == nil {
|
||||
return fmt.Errorf("credential manager not configured")
|
||||
}
|
||||
payload, err := json.Marshal(map[string]any{
|
||||
"access_token": token.AccessToken,
|
||||
"refresh_token": token.RefreshToken,
|
||||
"expiry": token.Expiry.UTC().Format(time.RFC3339),
|
||||
"token_type": token.TokenType,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
enc, err := s.creds.EncryptCredential(credentials.Credential{
|
||||
AuthType: credentials.AuthOAuth2,
|
||||
AccessToken: token.AccessToken,
|
||||
RefreshToken: token.RefreshToken,
|
||||
Expiry: token.Expiry,
|
||||
OAuthProvider: provider,
|
||||
})
|
||||
if err != nil {
|
||||
_ = payload
|
||||
return err
|
||||
}
|
||||
var expiresAt *time.Time
|
||||
if !token.Expiry.IsZero() {
|
||||
expiresAt = &token.Expiry
|
||||
}
|
||||
_, err = s.db.Exec(ctx, `
|
||||
INSERT INTO migration_credentials (user_id, project_id, provider, encrypted_token, scopes, expires_at)
|
||||
VALUES ($1::uuid, $2::uuid, $3, $4, $5, $6)
|
||||
ON CONFLICT (user_id, project_id, provider) DO UPDATE SET
|
||||
encrypted_token = EXCLUDED.encrypted_token,
|
||||
scopes = EXCLUDED.scopes,
|
||||
expires_at = EXCLUDED.expires_at,
|
||||
revoked_at = NULL
|
||||
`, userID, projectID, provider, enc, scopes, expiresAt)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) GetUserStatus(ctx context.Context, userID, projectID string) (UserStatus, error) {
|
||||
sc := newProjectScanner()
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT `+projectSelectSQL("")+`
|
||||
FROM migration_projects WHERE id = $1::uuid
|
||||
`, projectID).Scan(sc.targets()...)
|
||||
proj := sc.result()
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
|
||||
var inv Invite
|
||||
_ = s.db.QueryRow(ctx, `
|
||||
SELECT id::text, project_id::text, email, status, claimed_at::text, COALESCE(user_id::text, '')
|
||||
FROM migration_invites WHERE project_id = $1::uuid AND user_id = $2::uuid
|
||||
`, projectID, userID).Scan(
|
||||
&inv.ID, &inv.ProjectID, &inv.Email, &inv.Status, &inv.ClaimedAt, &inv.UserID,
|
||||
)
|
||||
|
||||
jobs, err := s.listJobs(ctx, projectID, userID)
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
return UserStatus{
|
||||
Project: proj,
|
||||
Invite: inv,
|
||||
Jobs: jobs,
|
||||
Onboarding: s.BuildOnboardingHints(ctx, userID, proj, inv),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) GetActiveUserStatus(ctx context.Context, userID string) (UserStatus, error) {
|
||||
var projectID string
|
||||
err := s.db.QueryRow(ctx, `
|
||||
SELECT project_id::text FROM migration_invites
|
||||
WHERE user_id = $1::uuid AND status = 'claimed'
|
||||
ORDER BY claimed_at DESC NULLS LAST LIMIT 1
|
||||
`, userID).Scan(&projectID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return UserStatus{}, nil
|
||||
}
|
||||
if err != nil {
|
||||
return UserStatus{}, err
|
||||
}
|
||||
return s.GetUserStatus(ctx, userID, projectID)
|
||||
}
|
||||
|
||||
func (s *Service) listJobs(ctx context.Context, projectID, userID string) ([]Job, error) {
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id::text, project_id::text, user_id::text, service, status,
|
||||
cursor_json, stats_json, error, started_at::text, completed_at::text
|
||||
FROM migration_jobs
|
||||
WHERE project_id = $1::uuid AND user_id = $2::uuid
|
||||
ORDER BY service ASC
|
||||
`, projectID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var out []Job
|
||||
for rows.Next() {
|
||||
var row Job
|
||||
var cursorRaw, statsRaw []byte
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
||||
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
||||
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
||||
if row.CursorJSON == nil {
|
||||
row.CursorJSON = map[string]any{}
|
||||
}
|
||||
if row.StatsJSON == nil {
|
||||
row.StatsJSON = map[string]any{}
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) PendingJobs(ctx context.Context, limit int) ([]Job, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT j.id::text, j.project_id::text, j.user_id::text, j.service, j.status,
|
||||
j.cursor_json, j.stats_json, j.error, j.started_at::text, j.completed_at::text
|
||||
FROM migration_jobs j
|
||||
JOIN migration_projects p ON p.id = j.project_id
|
||||
WHERE j.status IN ('pending', 'running')
|
||||
AND p.status IN ('active', 'cutover')
|
||||
ORDER BY j.updated_at ASC
|
||||
LIMIT $1
|
||||
`, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
return scanJobs(rows)
|
||||
}
|
||||
|
||||
func scanJobs(rows pgx.Rows) ([]Job, error) {
|
||||
var out []Job
|
||||
for rows.Next() {
|
||||
var row Job
|
||||
var cursorRaw, statsRaw []byte
|
||||
if err := rows.Scan(
|
||||
&row.ID, &row.ProjectID, &row.UserID, &row.Service, &row.Status,
|
||||
&cursorRaw, &statsRaw, &row.Error, &row.StartedAt, &row.CompletedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_ = json.Unmarshal(cursorRaw, &row.CursorJSON)
|
||||
_ = json.Unmarshal(statsRaw, &row.StatsJSON)
|
||||
if row.CursorJSON == nil {
|
||||
row.CursorJSON = map[string]any{}
|
||||
}
|
||||
if row.StatsJSON == nil {
|
||||
row.StatsJSON = map[string]any{}
|
||||
}
|
||||
out = append(out, row)
|
||||
}
|
||||
return out, rows.Err()
|
||||
}
|
||||
|
||||
func (s *Service) UpdateJobProgress(ctx context.Context, jobID, status string, cursor, stats map[string]any, jobErr string) error {
|
||||
cursorRaw, _ := json.Marshal(cursor)
|
||||
statsRaw, _ := json.Marshal(stats)
|
||||
_, err := s.db.Exec(ctx, `
|
||||
UPDATE migration_jobs SET
|
||||
status = $2,
|
||||
cursor_json = $3,
|
||||
stats_json = $4,
|
||||
error = $5,
|
||||
started_at = COALESCE(started_at, CASE WHEN $2 = 'running' THEN NOW() ELSE NULL END),
|
||||
completed_at = CASE WHEN $2 IN ('completed', 'failed') THEN NOW() ELSE completed_at END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
`, jobID, status, cursorRaw, statsRaw, jobErr)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *Service) ActivateProject(ctx context.Context, projectID string) (Project, error) {
|
||||
sc := newProjectScanner()
|
||||
err := s.db.QueryRow(ctx, `
|
||||
UPDATE migration_projects SET status = 'active', updated_at = NOW()
|
||||
WHERE id = $1::uuid
|
||||
RETURNING `+projectSelectSQL("")+`
|
||||
`, projectID).Scan(sc.targets()...)
|
||||
return sc.result(), err
|
||||
}
|
||||
|
||||
func (s *Service) LookupUserID(ctx context.Context, externalID string) (string, error) {
|
||||
var userID string
|
||||
err := s.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID)
|
||||
return userID, err
|
||||
}
|
||||
|
||||
func randomState() (string, error) {
|
||||
b := make([]byte, 24)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
11
internal/migration/test_exports.go
Normal file
11
internal/migration/test_exports.go
Normal file
@ -0,0 +1,11 @@
|
||||
package migration
|
||||
|
||||
// GmailUIDForTest exposes gmailUID for integration tests.
|
||||
func GmailUIDForTest(gmailID string) int64 {
|
||||
return gmailUID(gmailID)
|
||||
}
|
||||
|
||||
// RemoteMessageUIDForTest exposes remoteMessageUID for integration tests.
|
||||
func RemoteMessageUIDForTest(graphID string) int64 {
|
||||
return remoteMessageUID(graphID)
|
||||
}
|
||||
285
internal/migration/worker.go
Normal file
285
internal/migration/worker.go
Normal file
@ -0,0 +1,285 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
mailstorage "github.com/ultisuite/ulti-backend/internal/mail/storage"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/observability"
|
||||
)
|
||||
|
||||
type Worker struct {
|
||||
db *pgxpool.Pool
|
||||
svc *Service
|
||||
oauth *OAuthService
|
||||
creds *credentials.Manager
|
||||
googleDWD *GoogleDWD
|
||||
microsoftApp *MicrosoftApp
|
||||
nc *nextcloud.Client
|
||||
storage *mailstorage.Client
|
||||
attachBucket string
|
||||
concurrency int
|
||||
jobLimit int
|
||||
logger *slog.Logger
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// WorkerConfig tunes migration worker parallelism and job pickup.
|
||||
type WorkerConfig struct {
|
||||
Concurrency int
|
||||
JobLimit int
|
||||
}
|
||||
|
||||
func NewWorker(db *pgxpool.Pool, svc *Service, oauth *OAuthService, creds *credentials.Manager, googleDWD *GoogleDWD, microsoftApp *MicrosoftApp, nc *nextcloud.Client, storage *mailstorage.Client, attachBucket string, cfg WorkerConfig) *Worker {
|
||||
concurrency := cfg.Concurrency
|
||||
if concurrency <= 0 {
|
||||
concurrency = 1
|
||||
}
|
||||
jobLimit := cfg.JobLimit
|
||||
if jobLimit <= 0 {
|
||||
jobLimit = concurrency * 3
|
||||
if jobLimit < 5 {
|
||||
jobLimit = 5
|
||||
}
|
||||
}
|
||||
return &Worker{
|
||||
db: db,
|
||||
svc: svc,
|
||||
oauth: oauth,
|
||||
creds: creds,
|
||||
googleDWD: googleDWD,
|
||||
microsoftApp: microsoftApp,
|
||||
nc: nc,
|
||||
storage: storage,
|
||||
attachBucket: attachBucket,
|
||||
concurrency: concurrency,
|
||||
jobLimit: jobLimit,
|
||||
logger: slog.Default().With("component", "migration-worker"),
|
||||
client: &http.Client{Timeout: 60 * time.Second},
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) Start(ctx context.Context, interval time.Duration) {
|
||||
if interval <= 0 {
|
||||
interval = 30 * time.Second
|
||||
}
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
w.tick(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Worker) tick(ctx context.Context) {
|
||||
jobs, err := w.svc.PendingJobs(ctx, w.jobLimit)
|
||||
if err != nil {
|
||||
w.logger.Error("list pending migration jobs", "error", err)
|
||||
return
|
||||
}
|
||||
observability.SetMigrationPendingJobs(len(jobs))
|
||||
if len(jobs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
sem := make(chan struct{}, w.concurrency)
|
||||
var wg sync.WaitGroup
|
||||
for _, job := range jobs {
|
||||
wg.Add(1)
|
||||
sem <- struct{}{}
|
||||
go func(job Job) {
|
||||
defer wg.Done()
|
||||
defer func() { <-sem }()
|
||||
if _, err := w.processJob(ctx, job); err != nil {
|
||||
w.logger.Error("migration job failed", "job_id", job.ID, "service", job.Service, "error", err)
|
||||
}
|
||||
}(job)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (w *Worker) processJob(ctx context.Context, job Job) (string, error) {
|
||||
start := time.Now()
|
||||
outcome := "unknown"
|
||||
defer func() {
|
||||
observability.ObserveMigrationJob(job.Service, outcome, time.Since(start))
|
||||
}()
|
||||
if err := w.svc.UpdateJobProgress(ctx, job.ID, "running", job.CursorJSON, job.StatsJSON, ""); err != nil {
|
||||
outcome = "failed"
|
||||
return outcome, err
|
||||
}
|
||||
|
||||
var provider string
|
||||
var delta bool
|
||||
var authMode string
|
||||
err := w.db.QueryRow(ctx, `
|
||||
SELECT p.source_provider, p.delta_mode, p.auth_mode
|
||||
FROM migration_projects p WHERE p.id = $1::uuid
|
||||
`, job.ProjectID).Scan(&provider, &delta, &authMode)
|
||||
if err != nil {
|
||||
outcome = "failed"
|
||||
return outcome, err
|
||||
}
|
||||
|
||||
accessToken, graphUserUPN, err := w.loadAccessToken(ctx, job, provider, authMode)
|
||||
if err != nil {
|
||||
outcome = "failed"
|
||||
_ = w.svc.UpdateJobProgress(ctx, job.ID, "failed", job.CursorJSON, job.StatsJSON, err.Error())
|
||||
return outcome, err
|
||||
}
|
||||
|
||||
var lastStatus string
|
||||
update := func(status string, cursor, stats map[string]any, jobErr string) error {
|
||||
lastStatus = status
|
||||
return w.svc.UpdateJobProgress(ctx, job.ID, status, cursor, stats, jobErr)
|
||||
}
|
||||
|
||||
var procErr error
|
||||
var selfManaged bool
|
||||
switch job.Service {
|
||||
case "mail":
|
||||
selfManaged = true
|
||||
if provider == "google" {
|
||||
procErr = NewGmailImporter(w.db).WithStorage(w.storage, w.attachBucket).ImportBatch(ctx, &job, accessToken, delta, update)
|
||||
} else {
|
||||
procErr = NewGraphImporter(w.db).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, delta, update)
|
||||
}
|
||||
case "contacts":
|
||||
selfManaged = true
|
||||
procErr = NewContactsImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update)
|
||||
case "calendar":
|
||||
selfManaged = true
|
||||
procErr = NewCalendarImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update)
|
||||
case "drive":
|
||||
selfManaged = true
|
||||
procErr = NewDriveImporter(w.db, w.nc).WithUserPrincipal(graphUserUPN).ImportBatch(ctx, &job, accessToken, provider, delta, update)
|
||||
default:
|
||||
procErr = fmt.Errorf("unknown service %q", job.Service)
|
||||
}
|
||||
|
||||
if procErr != nil {
|
||||
if IsRateLimitError(procErr) {
|
||||
if job.StatsJSON == nil {
|
||||
job.StatsJSON = map[string]any{}
|
||||
}
|
||||
job.StatsJSON["rate_limited"] = true
|
||||
job.StatsJSON["rate_limit_at"] = time.Now().UTC().Format(time.RFC3339)
|
||||
outcome = "rate_limited"
|
||||
return outcome, w.svc.UpdateJobProgress(ctx, job.ID, "pending", job.CursorJSON, job.StatsJSON, procErr.Error())
|
||||
}
|
||||
outcome = "failed"
|
||||
return outcome, w.svc.UpdateJobProgress(ctx, job.ID, "failed", job.CursorJSON, job.StatsJSON, procErr.Error())
|
||||
}
|
||||
if selfManaged {
|
||||
switch lastStatus {
|
||||
case "completed":
|
||||
outcome = "completed"
|
||||
case "pending":
|
||||
outcome = "pending"
|
||||
default:
|
||||
outcome = "completed"
|
||||
}
|
||||
return outcome, nil
|
||||
}
|
||||
outcome = "completed"
|
||||
return outcome, w.svc.UpdateJobProgress(ctx, job.ID, "completed", job.CursorJSON, job.StatsJSON, "")
|
||||
}
|
||||
|
||||
func (w *Worker) loadAccessToken(ctx context.Context, job Job, provider, authMode string) (accessToken, graphUserUPN string, err error) {
|
||||
if provider == "google" && authMode == AuthModeGoogleDWD {
|
||||
if w.googleDWD == nil || !w.googleDWD.Enabled() {
|
||||
return "", "", fmt.Errorf("google domain-wide delegation not configured")
|
||||
}
|
||||
email, err := w.inviteEmail(ctx, job.ProjectID, job.UserID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
token, err := w.googleDWD.AccessToken(ctx, email)
|
||||
return token, "", err
|
||||
}
|
||||
if provider == "microsoft" && authMode == AuthModeMicrosoftApp {
|
||||
if w.microsoftApp == nil || !w.microsoftApp.Enabled() {
|
||||
return "", "", fmt.Errorf("microsoft app-only auth not configured")
|
||||
}
|
||||
tenantID, err := w.projectMicrosoftTenant(ctx, job.ProjectID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
email, err := w.inviteEmail(ctx, job.ProjectID, job.UserID)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
token, err := w.microsoftApp.AccessToken(ctx, tenantID)
|
||||
return token, email, err
|
||||
}
|
||||
cred, err := w.loadToken(ctx, job.UserID, job.ProjectID, provider)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
if w.oauth != nil && cred.NeedsRefresh() {
|
||||
cred, err = w.svc.RefreshCredential(ctx, w.oauth, job.UserID, job.ProjectID, provider, cred)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
}
|
||||
return cred.AccessToken, "", nil
|
||||
}
|
||||
|
||||
func (w *Worker) projectMicrosoftTenant(ctx context.Context, projectID string) (string, error) {
|
||||
var tenantID string
|
||||
err := w.db.QueryRow(ctx, `
|
||||
SELECT COALESCE(NULLIF(microsoft_tenant_id, ''), '')
|
||||
FROM migration_projects WHERE id = $1::uuid
|
||||
`, projectID).Scan(&tenantID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("migration project tenant lookup: %w", err)
|
||||
}
|
||||
if strings.TrimSpace(tenantID) == "" {
|
||||
return "", fmt.Errorf("microsoft tenant id missing: complete admin consent first")
|
||||
}
|
||||
return tenantID, nil
|
||||
}
|
||||
|
||||
func (w *Worker) inviteEmail(ctx context.Context, projectID, userID string) (string, error) {
|
||||
var email string
|
||||
err := w.db.QueryRow(ctx, `
|
||||
SELECT email FROM migration_invites
|
||||
WHERE project_id = $1::uuid AND user_id = $2::uuid
|
||||
ORDER BY claimed_at DESC NULLS LAST LIMIT 1
|
||||
`, projectID, userID).Scan(&email)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("migration invite email missing for domain-wide delegation")
|
||||
}
|
||||
return email, nil
|
||||
}
|
||||
|
||||
func (w *Worker) loadToken(ctx context.Context, userID, projectID, provider string) (credentials.Credential, error) {
|
||||
var blob []byte
|
||||
err := w.db.QueryRow(ctx, `
|
||||
SELECT encrypted_token FROM migration_credentials
|
||||
WHERE user_id = $1::uuid AND project_id = $2::uuid AND provider = $3 AND revoked_at IS NULL
|
||||
`, userID, projectID, provider).Scan(&blob)
|
||||
if err != nil {
|
||||
return credentials.Credential{}, fmt.Errorf("migration credentials missing: run OAuth consent first")
|
||||
}
|
||||
cred, err := w.creds.DecryptCredential(blob)
|
||||
if err != nil {
|
||||
return credentials.Credential{}, err
|
||||
}
|
||||
cred.AuthType = credentials.AuthOAuth2
|
||||
cred.OAuthProvider = provider
|
||||
return cred, nil
|
||||
}
|
||||
26
internal/migration/worker_test.go
Normal file
26
internal/migration/worker_test.go
Normal file
@ -0,0 +1,26 @@
|
||||
package migration
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestWorkerConfigDefaults(t *testing.T) {
|
||||
w := NewWorker(nil, nil, nil, nil, nil, nil, nil, nil, "", WorkerConfig{})
|
||||
if w.concurrency != 1 {
|
||||
t.Fatalf("concurrency = %d", w.concurrency)
|
||||
}
|
||||
if w.jobLimit != 5 {
|
||||
t.Fatalf("jobLimit = %d", w.jobLimit)
|
||||
}
|
||||
|
||||
w = NewWorker(nil, nil, nil, nil, nil, nil, nil, nil, "", WorkerConfig{Concurrency: 4})
|
||||
if w.concurrency != 4 {
|
||||
t.Fatalf("concurrency = %d", w.concurrency)
|
||||
}
|
||||
if w.jobLimit != 12 {
|
||||
t.Fatalf("jobLimit = %d", w.jobLimit)
|
||||
}
|
||||
|
||||
w = NewWorker(nil, nil, nil, nil, nil, nil, nil, nil, "", WorkerConfig{Concurrency: 2, JobLimit: 8})
|
||||
if w.jobLimit != 8 {
|
||||
t.Fatalf("jobLimit = %d", w.jobLimit)
|
||||
}
|
||||
}
|
||||
@ -77,6 +77,27 @@ var (
|
||||
Name: "ultid_webhook_payload_truncated_total",
|
||||
Help: "Total number of webhook payloads truncated in logs.",
|
||||
})
|
||||
|
||||
migrationJobsProcessedTotal = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Name: "ultid_migration_jobs_processed_total",
|
||||
Help: "Total number of migration worker job runs.",
|
||||
}, []string{"service", "outcome"})
|
||||
|
||||
migrationJobDurationSeconds = promauto.NewHistogramVec(prometheus.HistogramOpts{
|
||||
Name: "ultid_migration_job_duration_seconds",
|
||||
Help: "Migration worker job run duration in seconds.",
|
||||
Buckets: []float64{0.1, 0.25, 0.5, 1, 2, 5, 10, 30, 60, 120, 300},
|
||||
}, []string{"service", "outcome"})
|
||||
|
||||
migrationPendingJobs = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "ultid_migration_pending_jobs",
|
||||
Help: "Migration jobs picked up on the latest worker tick.",
|
||||
})
|
||||
|
||||
migrationRateLimitRetriesTotal = promauto.NewCounter(prometheus.CounterOpts{
|
||||
Name: "ultid_migration_rate_limit_retries_total",
|
||||
Help: "Total number of migration provider API 429 retries.",
|
||||
})
|
||||
)
|
||||
|
||||
type metricsResponseWriter struct {
|
||||
@ -142,3 +163,22 @@ func IncWebhookDeadLetter() {
|
||||
func IncWebhookPayloadTruncated() {
|
||||
webhookPayloadTruncatedTotal.Inc()
|
||||
}
|
||||
|
||||
func ObserveMigrationJob(service, outcome string, duration time.Duration) {
|
||||
if service == "" {
|
||||
service = "unknown"
|
||||
}
|
||||
if outcome == "" {
|
||||
outcome = "unknown"
|
||||
}
|
||||
migrationJobsProcessedTotal.WithLabelValues(service, outcome).Inc()
|
||||
migrationJobDurationSeconds.WithLabelValues(service, outcome).Observe(duration.Seconds())
|
||||
}
|
||||
|
||||
func SetMigrationPendingJobs(count int) {
|
||||
migrationPendingJobs.Set(float64(count))
|
||||
}
|
||||
|
||||
func IncMigrationRateLimitRetry() {
|
||||
migrationRateLimitRetriesTotal.Inc()
|
||||
}
|
||||
|
||||
75
internal/provision/authentik.go
Normal file
75
internal/provision/authentik.go
Normal file
@ -0,0 +1,75 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func authorizeProvision(r *http.Request, secret string) bool {
|
||||
if secret == "" {
|
||||
return false
|
||||
}
|
||||
if r.Header.Get("X-Provision-Secret") == secret {
|
||||
return true
|
||||
}
|
||||
if r.URL.Query().Get("secret") == secret {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type authentikWebhookPayload struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
ExternalID string `json:"external_id"`
|
||||
Sub string `json:"sub"`
|
||||
User struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
UUID string `json:"uuid"`
|
||||
PK int `json:"pk"`
|
||||
} `json:"user"`
|
||||
}
|
||||
|
||||
func decodeProvisionBody(r *http.Request) (provisionUserRequest, error) {
|
||||
var ak authentikWebhookPayload
|
||||
if err := json.NewDecoder(r.Body).Decode(&ak); err != nil {
|
||||
return provisionUserRequest{}, err
|
||||
}
|
||||
|
||||
req := provisionUserRequest{
|
||||
Email: firstNonEmpty(ak.Email, ak.User.Email),
|
||||
Username: firstNonEmpty(ak.Username, ak.User.Username),
|
||||
Name: firstNonEmpty(ak.Name, ak.User.Name),
|
||||
Password: ak.Password,
|
||||
ExternalID: firstNonEmpty(ak.ExternalID, ak.Sub, ak.User.UUID),
|
||||
}
|
||||
if req.ExternalID == "" && ak.User.PK > 0 {
|
||||
req.ExternalID = strconv.Itoa(ak.User.PK)
|
||||
}
|
||||
normalizeProvisionRequest(&req)
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func normalizeProvisionRequest(dst *provisionUserRequest) {
|
||||
if dst.Email == "" {
|
||||
dst.Email = strings.ToLower(strings.TrimSpace(dst.Username))
|
||||
}
|
||||
if dst.Name == "" {
|
||||
dst.Name = dst.Email
|
||||
}
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
151
internal/provision/handler.go
Normal file
151
internal/provision/handler.go
Normal file
@ -0,0 +1,151 @@
|
||||
package provision
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
"github.com/ultisuite/ulti-backend/internal/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/users"
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
secret string
|
||||
platformDomain string
|
||||
hosted *hosted.Service
|
||||
nc *nextcloud.Client
|
||||
db *pgxpool.Pool
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewHandler(secret, platformDomain string, hostedSvc *hosted.Service, nc *nextcloud.Client, db *pgxpool.Pool) *Handler {
|
||||
return &Handler{
|
||||
secret: strings.TrimSpace(secret),
|
||||
platformDomain: strings.ToLower(strings.TrimSpace(platformDomain)),
|
||||
hosted: hostedSvc,
|
||||
nc: nc,
|
||||
db: db,
|
||||
logger: slog.Default().With("component", "provision"),
|
||||
}
|
||||
}
|
||||
|
||||
type provisionUserRequest struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
Password string `json:"password"`
|
||||
ExternalID string `json:"external_id"`
|
||||
}
|
||||
|
||||
func (h *Handler) ProvisionUser(w http.ResponseWriter, r *http.Request) {
|
||||
if h.secret == "" {
|
||||
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, "provision_disabled", "provisioning webhook not configured", nil)
|
||||
return
|
||||
}
|
||||
if !authorizeProvision(r, h.secret) {
|
||||
apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "invalid provision secret", nil)
|
||||
return
|
||||
}
|
||||
|
||||
req, err := decodeProvisionBody(r)
|
||||
if err != nil {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "invalid_json", "invalid request body", nil)
|
||||
return
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(req.Email))
|
||||
if email == "" {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "validation_error", "email required", nil)
|
||||
return
|
||||
}
|
||||
|
||||
if h.platformDomain != "" && !strings.Contains(email, "@") {
|
||||
email = email + "@" + h.platformDomain
|
||||
} else if h.platformDomain != "" && !strings.HasSuffix(email, "@"+h.platformDomain) {
|
||||
local := strings.Split(email, "@")[0]
|
||||
email = local + "@" + h.platformDomain
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
var userID string
|
||||
externalID := strings.TrimSpace(req.ExternalID)
|
||||
if externalID != "" {
|
||||
err := h.db.QueryRow(ctx, `SELECT id::text FROM users WHERE external_id = $1`, externalID).Scan(&userID)
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
userID, err = users.EnsureUser(ctx, h.db, &auth.Claims{
|
||||
Sub: externalID,
|
||||
Email: email,
|
||||
Name: req.Name,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
h.logger.Error("ensure user", "error", err)
|
||||
apiresponse.WriteError(w, r, http.StatusInternalServerError, "internal_error", "failed to provision user", nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
result, err := h.hosted.ProvisionMailbox(ctx, hosted.ProvisionMailboxInput{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
DisplayName: req.Name,
|
||||
Password: req.Password,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, hosted.ErrAddressTaken) {
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "address_taken", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
h.logger.Error("provision mailbox", "error", err, "email", email)
|
||||
apiresponse.WriteError(w, r, http.StatusConflict, "provision_failed", err.Error(), nil)
|
||||
return
|
||||
}
|
||||
|
||||
if userID != "" {
|
||||
_ = migration.LinkHostedMailboxByEmail(ctx, h.db, userID, email)
|
||||
}
|
||||
|
||||
if h.nc != nil && userID != "" && externalID != "" {
|
||||
if _, err := h.nc.EnsurePrincipal(ctx, email, externalID, req.Name); err != nil {
|
||||
h.logger.Warn("nextcloud provision", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
|
||||
"user_id": userID,
|
||||
"email": email,
|
||||
"mailbox_id": result.Mailbox.ID,
|
||||
"mail_account_id": result.MailAccountID,
|
||||
})
|
||||
}
|
||||
|
||||
// CheckAddress validates local part availability (Authentik expression policy or public API).
|
||||
func (h *Handler) CheckAddress(w http.ResponseWriter, r *http.Request) {
|
||||
if h.hosted == nil {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"available": true})
|
||||
return
|
||||
}
|
||||
local := strings.TrimSpace(r.URL.Query().Get("local"))
|
||||
domain := strings.TrimSpace(r.URL.Query().Get("domain"))
|
||||
if domain == "" {
|
||||
domain = h.platformDomain
|
||||
}
|
||||
if local == "" || domain == "" {
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, "validation_error", "local and domain required", nil)
|
||||
return
|
||||
}
|
||||
available, err := h.hosted.IsAddressAvailable(r.Context(), domain, local)
|
||||
if err != nil {
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"available": false, "reason": err.Error()})
|
||||
return
|
||||
}
|
||||
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"available": available})
|
||||
}
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
@ -22,6 +23,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/docs"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/drive"
|
||||
mailapi "github.com/ultisuite/ulti-backend/internal/api/mail"
|
||||
migrationapi "github.com/ultisuite/ulti-backend/internal/api/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/mail/drivebridge"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/mail/sendguard"
|
||||
meetapi "github.com/ultisuite/ulti-backend/internal/api/meet"
|
||||
@ -40,6 +42,10 @@ import (
|
||||
mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap"
|
||||
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/hosted"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/stalwart"
|
||||
"github.com/ultisuite/ulti-backend/internal/migration"
|
||||
"github.com/ultisuite/ulti-backend/internal/provision"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/smtp"
|
||||
mailstorage "github.com/ultisuite/ulti-backend/internal/mail/storage"
|
||||
@ -208,6 +214,80 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
RedirectURL: oauthRedirect,
|
||||
}, rdb)
|
||||
|
||||
stlwClient := stalwart.NewClient(stalwart.Config{
|
||||
Enabled: cfg.StalwartEnabled,
|
||||
BaseURL: cfg.StalwartAPIURL,
|
||||
APIKey: cfg.StalwartAPIKey,
|
||||
IMAPHost: cfg.StalwartIMAPHost,
|
||||
IMAPPort: cfg.StalwartIMAPPort,
|
||||
IMAPTLS: cfg.StalwartIMAPTLS,
|
||||
SMTPHost: cfg.StalwartSMTPHost,
|
||||
SMTPPort: cfg.StalwartSMTPPort,
|
||||
SMTPTLS: cfg.StalwartSMTPTLS,
|
||||
})
|
||||
hostedSvc := hosted.NewService(pool, stlwClient, credentialManager)
|
||||
if cfg.PlatformMailDomain != "" {
|
||||
if _, err := hostedSvc.EnsurePlatformDomain(workerCtx, cfg.PlatformMailDomain); err != nil {
|
||||
slog.Warn("platform mail domain bootstrap failed", "domain", cfg.PlatformMailDomain, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
migrationOAuthRedirect := cfg.MigrationOAuthRedirectURL
|
||||
if migrationOAuthRedirect == "" {
|
||||
migrationOAuthRedirect = fmt.Sprintf("http://localhost:%d/api/v1/migration/oauth/callback", cfg.Port)
|
||||
if cfg.Domain != "" && cfg.Domain != "localhost" {
|
||||
migrationOAuthRedirect = fmt.Sprintf("https://%s/api/v1/migration/oauth/callback", cfg.Domain)
|
||||
}
|
||||
}
|
||||
migrationOAuthSvc := migration.NewOAuthService(migration.OAuthConfig{
|
||||
GoogleClientID: firstNonEmpty(cfg.MigrationGoogleOAuthClientID, cfg.MailGoogleOAuthClientID),
|
||||
GoogleClientSecret: firstNonEmpty(cfg.MigrationGoogleOAuthClientSecret, cfg.MailGoogleOAuthClientSecret),
|
||||
MicrosoftClientID: firstNonEmpty(cfg.MigrationMicrosoftOAuthClientID, cfg.MailMicrosoftOAuthClientID),
|
||||
MicrosoftSecret: firstNonEmpty(cfg.MigrationMicrosoftOAuthSecret, cfg.MailMicrosoftOAuthSecret),
|
||||
MicrosoftTenant: firstNonEmpty(cfg.MigrationMicrosoftOAuthTenant, cfg.MailMicrosoftOAuthTenant),
|
||||
RedirectURL: migrationOAuthRedirect,
|
||||
}, rdb)
|
||||
migrationSvc := migration.NewService(pool, rdb, credentialManager, hostedSvc, migrationOAuthSvc)
|
||||
migrationSvc.SetCutoverConfig(migration.CutoverConfig{
|
||||
ExpectedMXHosts: migration.ParseCutoverMXHosts(
|
||||
cfg.MigrationCutoverMXHosts,
|
||||
cfg.PlatformMailDomain,
|
||||
cfg.StalwartIMAPHost,
|
||||
),
|
||||
RequireMX: cfg.MigrationCutoverRequireMX,
|
||||
})
|
||||
googleDWD, err := migration.NewGoogleDWD(cfg.MigrationGoogleServiceAccountJSON)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("google dwd: %w", err)
|
||||
}
|
||||
microsoftApp, err := migration.NewMicrosoftApp(migration.MicrosoftAppConfig{
|
||||
ClientID: firstNonEmpty(cfg.MigrationMicrosoftOAuthClientID, cfg.MailMicrosoftOAuthClientID),
|
||||
ClientSecret: firstNonEmpty(cfg.MigrationMicrosoftOAuthSecret, cfg.MailMicrosoftOAuthSecret),
|
||||
DefaultTenant: firstNonEmpty(cfg.MigrationMicrosoftOAuthTenant, cfg.MailMicrosoftOAuthTenant),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("microsoft app: %w", err)
|
||||
}
|
||||
migration.ConfigureRateLimit(migration.RateLimitConfig{
|
||||
MaxRetries: cfg.MigrationRateLimitMaxRetries,
|
||||
BaseDelay: cfg.MigrationRateLimitBaseDelay,
|
||||
MaxDelay: cfg.MigrationRateLimitMaxDelay,
|
||||
})
|
||||
migration.ConfigureImportBatch(migration.ImportBatchConfig{
|
||||
Mail: cfg.MigrationImportBatchSize,
|
||||
Drive: cfg.MigrationDriveBatchSize,
|
||||
})
|
||||
if !opts.WithoutWorkers {
|
||||
go migration.NewWorker(
|
||||
pool, migrationSvc, migrationOAuthSvc, credentialManager, googleDWD, microsoftApp, ncClient,
|
||||
attachmentStorage, cfg.MailAttachmentsBucket,
|
||||
migration.WorkerConfig{
|
||||
Concurrency: cfg.MigrationWorkerConcurrency,
|
||||
JobLimit: cfg.MigrationWorkerJobLimit,
|
||||
},
|
||||
).Start(workerCtx, cfg.MigrationWorkerInterval)
|
||||
}
|
||||
|
||||
orgPolicyLoader := orgpolicy.NewLoader(pool, cfg)
|
||||
fileScanner := filescan.NewScanner(orgPolicyLoader, slog.Default())
|
||||
|
||||
@ -238,6 +318,9 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
|
||||
sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst)
|
||||
mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender)
|
||||
mailHandler.SetHostedService(hostedSvc)
|
||||
migrationHandler := migrationapi.NewHandler(migrationSvc, migrationOAuthSvc, cfg.MailAppURL)
|
||||
provisionHandler := provision.NewHandler(cfg.ProvisionWebhookSecret, cfg.PlatformMailDomain, hostedSvc, ncClient, pool)
|
||||
mailHandler.SetFileScanner(fileScanner)
|
||||
if syncWorker != nil {
|
||||
mailHandler.SetAccountSync(syncWorker)
|
||||
@ -265,6 +348,10 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
|
||||
r.Get("/ws", hub.HandleWS)
|
||||
r.Get("/api/v1/mail/accounts/oauth/callback", mailHandler.OAuthCallback)
|
||||
r.Get("/api/v1/migration/oauth/callback", migrationHandler.OAuthCallback)
|
||||
r.Get("/api/v1/mail/addresses/check", mailHandler.CheckAddressAvailability)
|
||||
r.Get("/api/v1/migration/invite", migrationHandler.GetInvite)
|
||||
r.Post("/internal/provision/user", provisionHandler.ProvisionUser)
|
||||
|
||||
var driveHandler *drive.Handler
|
||||
var driveSvc *drive.Service
|
||||
@ -325,7 +412,10 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
r.Use(middleware.EnforceApiTokenPolicy())
|
||||
|
||||
r.Mount("/api/v1/users", usersapi.NewHandler(pool).Routes())
|
||||
r.Mount("/api/v1/admin", admin.NewHandler(pool, auditLogger, cfg, ncClient).Routes())
|
||||
adminHandler := admin.NewHandler(pool, auditLogger, cfg, ncClient)
|
||||
adminHandler.SetHostedService(hostedSvc)
|
||||
adminHandler.SetMigrationService(migrationSvc)
|
||||
r.Mount("/api/v1/admin", adminHandler.Routes())
|
||||
if driveHandler != nil {
|
||||
r.Mount("/api/v1/drive", driveHandler.Routes())
|
||||
}
|
||||
@ -333,6 +423,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middleware.RequireFullAccount)
|
||||
r.Mount("/api/v1/mail", mailHandler.Routes())
|
||||
r.Mount("/api/v1/migration", migrationHandler.Routes())
|
||||
r.Get("/api/v1/search", search.NewHandler(pool, search.Options{
|
||||
Nextcloud: ncClient,
|
||||
Engine: cfg.SearchEngine,
|
||||
@ -400,3 +491,12 @@ func closeOwned(ownsPool bool, pool *pgxpool.Pool, ownsRedis bool, rdb *redis.Cl
|
||||
_ = rdb.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func firstNonEmpty(values ...string) string {
|
||||
for _, v := range values {
|
||||
if strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/migration"
|
||||
)
|
||||
|
||||
// ProvisionEmail returns the email stored for a newly provisioned user.
|
||||
@ -58,5 +59,6 @@ func EnsureUser(ctx context.Context, db *pgxpool.Pool, claims *auth.Claims) (str
|
||||
return "", fmt.Errorf("bootstrap platform admin: %w", err)
|
||||
}
|
||||
}
|
||||
_ = migration.LinkHostedMailboxByEmail(ctx, db, userID, email)
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
6
migrations/000040_hosted_mail_migration.down.sql
Normal file
6
migrations/000040_hosted_mail_migration.down.sql
Normal file
@ -0,0 +1,6 @@
|
||||
DROP TABLE IF EXISTS migration_credentials;
|
||||
DROP TABLE IF EXISTS migration_jobs;
|
||||
DROP TABLE IF EXISTS migration_invites;
|
||||
DROP TABLE IF EXISTS migration_projects;
|
||||
DROP TABLE IF EXISTS mailboxes;
|
||||
DROP TABLE IF EXISTS mail_domains;
|
||||
96
migrations/000040_hosted_mail_migration.up.sql
Normal file
96
migrations/000040_hosted_mail_migration.up.sql
Normal file
@ -0,0 +1,96 @@
|
||||
-- Hosted mail domains and mailboxes (Stalwart provisioning)
|
||||
CREATE TABLE mail_domains (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
status TEXT NOT NULL DEFAULT 'pending_verification',
|
||||
verification_token TEXT NOT NULL DEFAULT '',
|
||||
dkim_selector TEXT NOT NULL DEFAULT '',
|
||||
dkim_public_key TEXT NOT NULL DEFAULT '',
|
||||
stalwart_domain_id TEXT NOT NULL DEFAULT '',
|
||||
is_platform_domain BOOLEAN NOT NULL DEFAULT false,
|
||||
mx_verified_at TIMESTAMPTZ,
|
||||
txt_verified_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_mail_domains_status ON mail_domains(status);
|
||||
|
||||
CREATE TABLE mailboxes (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
domain_id UUID NOT NULL REFERENCES mail_domains(id) ON DELETE CASCADE,
|
||||
local_part TEXT NOT NULL,
|
||||
user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
mail_account_id UUID REFERENCES mail_accounts(id) ON DELETE SET NULL,
|
||||
stalwart_account_id TEXT NOT NULL DEFAULT '',
|
||||
quota_bytes BIGINT NOT NULL DEFAULT 5368709120,
|
||||
status TEXT NOT NULL DEFAULT 'active',
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(domain_id, local_part)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_mailboxes_user ON mailboxes(user_id);
|
||||
CREATE INDEX idx_mailboxes_mail_account ON mailboxes(mail_account_id);
|
||||
|
||||
-- Migration projects (Google Workspace / Microsoft 365 transitions)
|
||||
CREATE TABLE migration_projects (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
domain_id UUID REFERENCES mail_domains(id) ON DELETE SET NULL,
|
||||
name TEXT NOT NULL DEFAULT '',
|
||||
source_provider TEXT NOT NULL DEFAULT 'google',
|
||||
status TEXT NOT NULL DEFAULT 'draft',
|
||||
cutover_at TIMESTAMPTZ,
|
||||
delta_mode BOOLEAN NOT NULL DEFAULT false,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX idx_migration_projects_status ON migration_projects(status);
|
||||
|
||||
CREATE TABLE migration_invites (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE,
|
||||
email TEXT NOT NULL,
|
||||
token TEXT NOT NULL UNIQUE,
|
||||
status TEXT NOT NULL DEFAULT 'invited',
|
||||
claimed_at TIMESTAMPTZ,
|
||||
user_id UUID REFERENCES users(id) ON DELETE SET NULL,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(project_id, email)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_migration_invites_token ON migration_invites(token);
|
||||
CREATE INDEX idx_migration_invites_email ON migration_invites(email);
|
||||
|
||||
CREATE TABLE migration_jobs (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE,
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
service TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'pending',
|
||||
cursor_json JSONB NOT NULL DEFAULT '{}',
|
||||
stats_json JSONB NOT NULL DEFAULT '{}',
|
||||
error TEXT NOT NULL DEFAULT '',
|
||||
started_at TIMESTAMPTZ,
|
||||
completed_at TIMESTAMPTZ,
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(project_id, user_id, service)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_migration_jobs_status ON migration_jobs(status);
|
||||
|
||||
CREATE TABLE migration_credentials (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE,
|
||||
project_id UUID NOT NULL REFERENCES migration_projects(id) ON DELETE CASCADE,
|
||||
provider TEXT NOT NULL,
|
||||
encrypted_token BYTEA NOT NULL,
|
||||
scopes TEXT[] NOT NULL DEFAULT '{}',
|
||||
expires_at TIMESTAMPTZ,
|
||||
revoked_at TIMESTAMPTZ,
|
||||
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(user_id, project_id, provider)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_migration_credentials_user ON migration_credentials(user_id);
|
||||
2
migrations/000041_migration_auth_mode.down.sql
Normal file
2
migrations/000041_migration_auth_mode.down.sql
Normal file
@ -0,0 +1,2 @@
|
||||
ALTER TABLE migration_projects DROP CONSTRAINT IF EXISTS migration_projects_auth_mode_check;
|
||||
ALTER TABLE migration_projects DROP COLUMN IF EXISTS auth_mode;
|
||||
7
migrations/000041_migration_auth_mode.up.sql
Normal file
7
migrations/000041_migration_auth_mode.up.sql
Normal file
@ -0,0 +1,7 @@
|
||||
-- Migration auth mode: per-user OAuth vs Google domain-wide delegation
|
||||
ALTER TABLE migration_projects
|
||||
ADD COLUMN auth_mode TEXT NOT NULL DEFAULT 'oauth';
|
||||
|
||||
ALTER TABLE migration_projects
|
||||
ADD CONSTRAINT migration_projects_auth_mode_check
|
||||
CHECK (auth_mode IN ('oauth', 'google_dwd'));
|
||||
1
migrations/000042_migration_imported_items.down.sql
Normal file
1
migrations/000042_migration_imported_items.down.sql
Normal file
@ -0,0 +1 @@
|
||||
DROP TABLE IF EXISTS migration_imported_items;
|
||||
11
migrations/000042_migration_imported_items.up.sql
Normal file
11
migrations/000042_migration_imported_items.up.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- Track per-job imported source IDs and drive paths outside cursor_json.
|
||||
CREATE TABLE migration_imported_items (
|
||||
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
|
||||
job_id UUID NOT NULL REFERENCES migration_jobs(id) ON DELETE CASCADE,
|
||||
source_id TEXT NOT NULL,
|
||||
rel_path TEXT NOT NULL DEFAULT '',
|
||||
imported_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
UNIQUE(job_id, source_id)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_migration_imported_items_job ON migration_imported_items(job_id);
|
||||
1
migrations/000043_migration_cutover_dns.down.sql
Normal file
1
migrations/000043_migration_cutover_dns.down.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TABLE migration_projects DROP COLUMN IF EXISTS cutover_dns_json;
|
||||
3
migrations/000043_migration_cutover_dns.up.sql
Normal file
3
migrations/000043_migration_cutover_dns.up.sql
Normal file
@ -0,0 +1,3 @@
|
||||
-- Persist last migration cutover DNS check for admin dashboards.
|
||||
ALTER TABLE migration_projects
|
||||
ADD COLUMN IF NOT EXISTS cutover_dns_json JSONB NOT NULL DEFAULT '{}';
|
||||
@ -0,0 +1,6 @@
|
||||
ALTER TABLE migration_projects
|
||||
DROP COLUMN IF EXISTS microsoft_admin_consent_error,
|
||||
DROP COLUMN IF EXISTS microsoft_admin_consent_at,
|
||||
DROP COLUMN IF EXISTS microsoft_tenant_id;
|
||||
|
||||
DROP TABLE IF EXISTS migration_microsoft_admin_consents;
|
||||
20
migrations/000044_migration_microsoft_admin_consent.up.sql
Normal file
20
migrations/000044_migration_microsoft_admin_consent.up.sql
Normal file
@ -0,0 +1,20 @@
|
||||
-- Persist Microsoft tenant admin consent for migration OAuth app registration.
|
||||
CREATE TABLE IF NOT EXISTS migration_microsoft_admin_consents (
|
||||
tenant_id TEXT NOT NULL,
|
||||
client_id TEXT NOT NULL,
|
||||
project_id UUID REFERENCES migration_projects(id) ON DELETE SET NULL,
|
||||
granted BOOLEAN NOT NULL DEFAULT false,
|
||||
error_code TEXT NOT NULL DEFAULT '',
|
||||
error_description TEXT NOT NULL DEFAULT '',
|
||||
consented_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
PRIMARY KEY (tenant_id, client_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_migration_ms_admin_consents_project
|
||||
ON migration_microsoft_admin_consents(project_id);
|
||||
|
||||
ALTER TABLE migration_projects
|
||||
ADD COLUMN IF NOT EXISTS microsoft_tenant_id TEXT NOT NULL DEFAULT '',
|
||||
ADD COLUMN IF NOT EXISTS microsoft_admin_consent_at TIMESTAMPTZ,
|
||||
ADD COLUMN IF NOT EXISTS microsoft_admin_consent_error TEXT NOT NULL DEFAULT '';
|
||||
@ -0,0 +1,2 @@
|
||||
ALTER TABLE migration_invites
|
||||
DROP COLUMN IF EXISTS alternate_emails;
|
||||
@ -0,0 +1,2 @@
|
||||
ALTER TABLE migration_invites
|
||||
ADD COLUMN alternate_emails TEXT[] NOT NULL DEFAULT '{}';
|
||||
8
migrations/000046_migration_item_audit.down.sql
Normal file
8
migrations/000046_migration_item_audit.down.sql
Normal file
@ -0,0 +1,8 @@
|
||||
DROP INDEX IF EXISTS idx_migration_imported_items_job_status;
|
||||
|
||||
ALTER TABLE migration_imported_items
|
||||
DROP CONSTRAINT IF EXISTS migration_imported_items_status_check;
|
||||
|
||||
ALTER TABLE migration_imported_items
|
||||
DROP COLUMN IF EXISTS reason,
|
||||
DROP COLUMN IF EXISTS status;
|
||||
11
migrations/000046_migration_item_audit.up.sql
Normal file
11
migrations/000046_migration_item_audit.up.sql
Normal file
@ -0,0 +1,11 @@
|
||||
-- Per-item migration audit: track success, failure, and skip with reason.
|
||||
ALTER TABLE migration_imported_items
|
||||
ADD COLUMN status TEXT NOT NULL DEFAULT 'imported',
|
||||
ADD COLUMN reason TEXT NOT NULL DEFAULT '';
|
||||
|
||||
ALTER TABLE migration_imported_items
|
||||
ADD CONSTRAINT migration_imported_items_status_check
|
||||
CHECK (status IN ('imported', 'failed', 'skipped'));
|
||||
|
||||
CREATE INDEX idx_migration_imported_items_job_status
|
||||
ON migration_imported_items(job_id, status);
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user