- Added endpoints for listing and importing migration rosters. - Introduced audit export functionality for migration jobs in CSV and NDJSON formats. - Implemented tenant mismatch validation for Microsoft migration claims. - Enhanced error handling for email claiming and migration processes. - Added integration tests for roster import and claim workflows.
260 lines
8.2 KiB
Go
260 lines
8.2 KiB
Go
package migrationapi
|
|
|
|
import (
	"errors"
	"log/slog"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/chi/v5"

	"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
	"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
	"github.com/ultisuite/ulti-backend/internal/api/middleware"
	migr "github.com/ultisuite/ulti-backend/internal/migration"
)
|
|
|
|
// maxMigrationRequestBody is the body-size limit handed to apivalidate.DecodeJSON
// by the migration endpoints that accept a JSON body (1 MiB, assuming the
// helper counts bytes).
const maxMigrationRequestBody = 1024 * 1024
|
|
|
|
type Handler struct {
|
|
svc *migr.Service
|
|
oauth *migr.OAuthService
|
|
appURL string
|
|
logger *slog.Logger
|
|
}
|
|
|
|
// NewHandler builds a Handler around the migration service and the (possibly
// nil) OAuth service. appURL has surrounding whitespace and any trailing
// slashes removed so redirect paths can be appended directly. The logger is
// slog.Default tagged with component=migration-api.
func NewHandler(svc *migr.Service, oauth *migr.OAuthService, appURL string) *Handler {
	base := strings.TrimSpace(appURL)
	base = strings.TrimRight(base, "/")

	h := &Handler{svc: svc, oauth: oauth, appURL: base}
	h.logger = slog.Default().With("component", "migration-api")
	return h
}
|
|
|
|
// Routes returns a chi router exposing the migration endpoints:
//
//	POST /claim            -> ClaimInvite
//	GET  /status           -> GetStatus
//	GET  /oauth/providers  -> ListOAuthProviders
//	POST /oauth/start      -> StartOAuth
//
// OAuthCallback and GetInvite are not registered here; presumably the caller
// mounts them on separate routes (confirm).
func (h *Handler) Routes() chi.Router {
	endpoints := []struct {
		method  string
		pattern string
		handler http.HandlerFunc
	}{
		{http.MethodPost, "/claim", h.ClaimInvite},
		{http.MethodGet, "/status", h.GetStatus},
		{http.MethodGet, "/oauth/providers", h.ListOAuthProviders},
		{http.MethodPost, "/oauth/start", h.StartOAuth},
	}

	router := chi.NewRouter()
	for _, ep := range endpoints {
		router.Method(ep.method, ep.pattern, ep.handler)
	}
	return router
}
|
|
|
|
// OAuthCallback handles the provider redirect that completes a migration
// OAuth flow.
//
// Microsoft admin-consent redirects are handed to
// handleMicrosoftAdminConsentCallback. Everything else is treated as an
// authorization-code callback:
//   - the state/code pair is exchanged via the OAuth service;
//   - the resulting token is stored for the pending user/project;
//   - the browser is redirected to the frontend onboarding page with
//     oauth=success or oauth=error.
//
// On success, the pending invite token (if any) is echoed back as a "token"
// query parameter.
func (h *Handler) OAuthCallback(w http.ResponseWriter, r *http.Request) {
	if handled, err := h.handleMicrosoftAdminConsentCallback(w, r); handled {
		if err != nil {
			h.logger.Error("microsoft admin consent", "error", err)
		}
		return
	}

	errorRedirect := h.appURL + "/onboard/migration?oauth=error"

	q := r.URL.Query()
	state := q.Get("state")
	code := q.Get("code")
	if state == "" || code == "" {
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	// The OAuth service is optional (see the admin-consent path); without it
	// there is nothing to exchange the code with.
	if h.oauth == nil {
		h.logger.Error("oauth exchange", "error", migr.ErrProviderDisabled)
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	pending, token, scopes, err := h.oauth.Exchange(r.Context(), state, code)
	if err != nil {
		h.logger.Error("oauth exchange", "error", err)
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	if err := h.svc.StoreMigrationToken(r.Context(), pending.UserID, pending.ProjectID, pending.Provider, token, scopes); err != nil {
		h.logger.Error("store migration token", "error", err)
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	redirect := h.appURL + "/onboard/migration?oauth=success"
	if pending.InviteToken != "" {
		// The invite token originates from the StartOAuth request body and is
		// not guaranteed to be URL-safe; escape it so it cannot break the URL
		// or smuggle in extra query parameters.
		redirect += "&token=" + url.QueryEscape(pending.InviteToken)
	}
	http.Redirect(w, r, redirect, http.StatusFound)
}
|
|
|
|
// handleMicrosoftAdminConsentCallback recognizes and completes a Microsoft
// admin-consent redirect.
//
// The request is treated as an admin-consent callback when either:
//   - admin_consent equals "True" (case-insensitive), or
//   - it carries both a state that migr.ParseAdminConsentProjectID maps to a
//     non-empty project ID and a non-empty "error" parameter.
//
// Otherwise it returns (false, nil) and writes nothing, so the caller can fall
// back to the regular OAuth code flow.
//
// When it returns handled=true, a redirect to the mail-domains admin settings
// page has always been written. The returned error is only for the caller to
// log.
func (h *Handler) handleMicrosoftAdminConsentCallback(w http.ResponseWriter, r *http.Request) (bool, error) {
	q := r.URL.Query()
	projectID := migr.ParseAdminConsentProjectID(q.Get("state"))
	granted := strings.EqualFold(q.Get("admin_consent"), "True")
	oauthErr := strings.TrimSpace(q.Get("error"))
	if !granted && (projectID == "" || oauthErr == "") {
		return false, nil
	}

	settingsURL := h.appURL + "/admin/settings/mail-domains?microsoft_admin_consent="
	errorRedirect := settingsURL + "error"

	if h.oauth == nil {
		// The caller stops processing once handled is true, so the browser must
		// still be redirected here; otherwise it would receive an empty 200.
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return true, migr.ErrProviderDisabled
	}

	record := migr.MicrosoftAdminConsentRecord{
		TenantID:         q.Get("tenant"),
		ClientID:         h.oauth.MicrosoftClientID(),
		ProjectID:        projectID,
		Granted:          granted,
		ErrorCode:        oauthErr,
		ErrorDescription: q.Get("error_description"),
	}
	if err := h.svc.RecordMicrosoftAdminConsent(r.Context(), record); err != nil {
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return true, err
	}

	redirect := settingsURL + "success"
	if !granted {
		redirect = errorRedirect
	}
	if tenant := strings.TrimSpace(record.TenantID); tenant != "" {
		redirect += "&tenant=" + url.QueryEscape(tenant)
	}
	if projectID != "" {
		redirect += "&project_id=" + url.QueryEscape(projectID)
	}
	http.Redirect(w, r, redirect, http.StatusFound)
	return true, nil
}
|
|
|
|
// GetInvite looks up a migration invite by the "token" query parameter.
//
// It responds with the invite, its project, and the onboarding hints the
// service builds for them. A missing or blank token yields a validation error
// on the "token" field. Any lookup error is reported as 404 invite_not_found,
// carrying the error's message.
func (h *Handler) GetInvite(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	token := strings.TrimSpace(r.URL.Query().Get("token"))
	if len(token) == 0 {
		detail := apivalidate.FieldDetail{Field: "token", Message: "token required"}
		apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(detail))
		return
	}

	invite, project, lookupErr := h.svc.GetInviteByToken(ctx, token)
	if lookupErr != nil {
		apiresponse.WriteError(w, r, http.StatusNotFound, "invite_not_found", lookupErr.Error(), nil)
		return
	}

	payload := map[string]any{
		"invite":     invite,
		"project":    project,
		"onboarding": h.svc.BuildInviteOnboardingHints(project, invite),
	}
	apiresponse.WriteJSON(w, http.StatusOK, payload)
}
|
|
|
|
// claimRequest is the JSON body accepted by ClaimInvite.
type claimRequest struct {
	// Token identifies the invite being claimed.
	Token string `json:"token"`
	// Password is forwarded unchanged to the service's ClaimInvite.
	Password string `json:"password"`
}
|
|
|
|
// ClaimInvite lets the authenticated user claim a migration invite.
//
// The JSON body is a claimRequest. On success it responds 200 with the status
// returned by the service. Known claim failures map to stable error codes:
// 404 invite_not_found for a missing invite, 400 with a specific code for the
// other known sentinels, and 400 claim_failed for anything else.
// Unauthenticated requests get 401; a failed user lookup gets 500.
func (h *Handler) ClaimInvite(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
		return
	}

	var req claimRequest
	if err := apivalidate.DecodeJSON(w, r, maxMigrationRequestBody, &req); err != nil {
		return
	}

	userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
	if err != nil {
		apivalidate.WriteInternal(w, r)
		return
	}

	status, err := h.svc.ClaimInvite(r.Context(), req.Token, userID, migr.ClaimIdentityFromAuth(claims), claims.Name, req.Password)
	if err != nil {
		code := http.StatusBadRequest
		errCode := "claim_failed"
		// errors.Is also matches sentinels the service wraps with %w, which a
		// plain == comparison would miss (falling through to claim_failed).
		switch {
		case errors.Is(err, migr.ErrInviteNotFound):
			code = http.StatusNotFound
			errCode = "invite_not_found"
		case errors.Is(err, migr.ErrInviteClaimed):
			errCode = "invite_already_claimed"
		case errors.Is(err, migr.ErrEmailMismatch):
			errCode = "email_mismatch"
		case errors.Is(err, migr.ErrTenantMismatch):
			errCode = "tenant_mismatch"
		case errors.Is(err, migr.ErrMigrationDomainNotActive):
			errCode = "migration_domain_not_active"
		case errors.Is(err, migr.ErrMigrationDomainMismatch):
			errCode = "migration_domain_mismatch"
		}
		apiresponse.WriteError(w, r, code, errCode, err.Error(), nil)
		return
	}

	apiresponse.WriteJSON(w, http.StatusOK, status)
}
|
|
|
|
// GetStatus responds 200 with the authenticated user's active migration
// status as returned by the service. It responds 401 when no auth claims are
// present and 500 when either the user lookup or the status lookup fails.
func (h *Handler) GetStatus(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	claims := middleware.ClaimsFromContext(ctx)
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
		return
	}

	uid, lookupErr := h.svc.LookupUserID(ctx, claims.Sub)
	if lookupErr != nil {
		apivalidate.WriteInternal(w, r)
		return
	}

	current, statusErr := h.svc.GetActiveUserStatus(ctx, uid)
	if statusErr != nil {
		apivalidate.WriteInternal(w, r)
		return
	}

	apiresponse.WriteJSON(w, http.StatusOK, current)
}
|
|
|
|
// ListOAuthProviders responds 200 with {"providers": ...}, the providers the
// OAuth service reports as enabled.
//
// NOTE(review): unlike the admin-consent callback, this does not guard against
// a nil h.oauth. Confirm that EnabledProviders is safe on a nil receiver.
func (h *Handler) ListOAuthProviders(w http.ResponseWriter, r *http.Request) {
	enabled := h.oauth.EnabledProviders()
	body := map[string]any{"providers": enabled}
	apiresponse.WriteJSON(w, http.StatusOK, body)
}
|
|
|
|
// startOAuthRequest is the JSON body accepted by StartOAuth. Every field is
// optional.
type startOAuthRequest struct {
	// Provider names the OAuth provider; lower-cased and trimmed by the
	// handler. An empty value selects Google.
	Provider string `json:"provider"`
	// InviteToken optionally links the flow to an invite. It is carried
	// through the OAuth round trip.
	InviteToken string `json:"invite_token"`
	// ProjectID explicitly selects the target project.
	ProjectID string `json:"project_id"`
}
|
|
|
|
// StartOAuth begins a provider OAuth flow for the authenticated user. It
// responds 200 with {"auth_url": ...} returned by the OAuth service.
//
// The target project is resolved in this order:
//  1. the request's project_id;
//  2. otherwise, the project of the invite named by invite_token. That invite
//     must be unassigned or assigned to the caller (404 if the lookup fails,
//     403 if it belongs to someone else);
//  3. otherwise, the caller's active migration project. If there is none, the
//     response is a validation error on project_id.
//
// The provider defaults to Google when omitted. Failures from the OAuth
// service are reported as 400 oauth_start_failed.
func (h *Handler) StartOAuth(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	claims := middleware.ClaimsFromContext(ctx)
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
		return
	}

	var body startOAuthRequest
	if err := apivalidate.DecodeJSON(w, r, maxMigrationRequestBody, &body); err != nil {
		return
	}

	uid, err := h.svc.LookupUserID(ctx, claims.Sub)
	if err != nil {
		apivalidate.WriteInternal(w, r)
		return
	}

	// NOTE(review): an explicit project_id is used as-is, and the invite
	// ownership check below is skipped in that case even though invite_token
	// is still carried into the pending OAuth state. Confirm the service layer
	// verifies the caller may act on that project.
	projectID := body.ProjectID
	if projectID == "" && body.InviteToken != "" {
		inv, proj, lookupErr := h.svc.GetInviteByToken(ctx, body.InviteToken)
		switch {
		case lookupErr != nil:
			apiresponse.WriteError(w, r, http.StatusNotFound, "invite_not_found", lookupErr.Error(), nil)
			return
		case inv.UserID != "" && inv.UserID != uid:
			apiresponse.WriteError(w, r, http.StatusForbidden, "forbidden", "invite belongs to another user", nil)
			return
		}
		projectID = proj.ID
	}

	if projectID == "" {
		st, statusErr := h.svc.GetActiveUserStatus(ctx, uid)
		if statusErr != nil || st.Project.ID == "" {
			detail := apivalidate.FieldDetail{Field: "project_id", Message: "project_id required"}
			apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(detail))
			return
		}
		projectID = st.Project.ID
	}

	var provider migr.Provider = migr.ProviderGoogle
	if requested := strings.ToLower(strings.TrimSpace(body.Provider)); requested != "" {
		provider = migr.Provider(requested)
	}

	pending := migr.PendingOAuth{
		UserID:      uid,
		ProjectID:   projectID,
		InviteToken: body.InviteToken,
	}
	authURL, _, err := h.oauth.Start(ctx, pending, provider)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, "oauth_start_failed", err.Error(), nil)
		return
	}

	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"auth_url": authURL})
}
|