package migrationapi

import (
	"errors"
	"log/slog"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/chi/v5"
	"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
	"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
	"github.com/ultisuite/ulti-backend/internal/api/middleware"
	migr "github.com/ultisuite/ulti-backend/internal/migration"
)

// maxMigrationRequestBody is the body size limit (1 MiB) handed to
// apivalidate.DecodeJSON for the JSON endpoints in this file.
const maxMigrationRequestBody = 1 << 20

// Handler serves the migration onboarding HTTP endpoints.
type Handler struct {
	// svc handles invites, user/project status and migration-token storage.
	svc *migr.Service
	// oauth drives provider OAuth flows. A nil value is treated as
	// migr.ErrProviderDisabled by the callback handlers.
	oauth *migr.OAuthService
	// appURL is the frontend base URL (trimmed, no trailing slash) that all
	// browser redirects are built on.
	appURL string
	logger *slog.Logger
}

// NewHandler builds a Handler. appURL is whitespace-trimmed and stripped of
// trailing slashes so redirect paths can be appended directly.
func NewHandler(svc *migr.Service, oauth *migr.OAuthService, appURL string) *Handler {
	return &Handler{
		svc:    svc,
		oauth:  oauth,
		appURL: strings.TrimRight(strings.TrimSpace(appURL), "/"),
		logger: slog.Default().With("component", "migration-api"),
	}
}

// Routes returns a router with the claim, status, OAuth provider listing and
// OAuth start endpoints. OAuthCallback and GetInvite are not mounted here;
// presumably the caller registers them separately — confirm at the call site.
func (h *Handler) Routes() chi.Router {
	r := chi.NewRouter()
	r.Post("/claim", h.ClaimInvite)
	r.Get("/status", h.GetStatus)
	r.Get("/oauth/providers", h.ListOAuthProviders)
	r.Post("/oauth/start", h.StartOAuth)
	return r
}

// OAuthCallback completes a provider OAuth redirect.
//
// It first lets handleMicrosoftAdminConsentCallback claim the request. If
// that does not handle it, the state/code pair is exchanged for a token. The
// token is stored for the pending user/project, and the browser is redirected
// to /onboard/migration with oauth=success (plus the escaped invite token when
// one was attached) or oauth=error.
func (h *Handler) OAuthCallback(w http.ResponseWriter, r *http.Request) {
	if handled, err := h.handleMicrosoftAdminConsentCallback(w, r); handled {
		if err != nil {
			h.logger.Error("microsoft admin consent", "error", err)
		}
		return
	}

	errorRedirect := h.appURL + "/onboard/migration?oauth=error"

	q := r.URL.Query()
	state := q.Get("state")
	code := q.Get("code")
	if state == "" || code == "" {
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	if h.oauth == nil {
		// No OAuth service configured: fail like an exchange error instead
		// of dereferencing a nil service.
		h.logger.Error("oauth exchange", "error", migr.ErrProviderDisabled)
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	pending, token, scopes, err := h.oauth.Exchange(r.Context(), state, code)
	if err != nil {
		h.logger.Error("oauth exchange", "error", err)
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	if err := h.svc.StoreMigrationToken(r.Context(), pending.UserID, pending.ProjectID, pending.Provider, token, scopes); err != nil {
		h.logger.Error("store migration token", "error", err)
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return
	}

	redirect := h.appURL + "/onboard/migration?oauth=success"
	if pending.InviteToken != "" {
		// The token is echoed into a query string; escape it so it cannot
		// inject or truncate parameters.
		redirect += "&token=" + url.QueryEscape(pending.InviteToken)
	}
	http.Redirect(w, r, redirect, http.StatusFound)
}

// handleMicrosoftAdminConsentCallback detects and records the outcome of a
// Microsoft admin-consent redirect.
//
// It returns handled=false, having written nothing, when consent was not
// granted and either the state yields no project ID or no error was
// reported. In that case the caller should continue with the regular OAuth
// flow. When handled is true, a redirect to the mail-domains settings page
// has been written. A non-nil error is returned only for the caller to log.
func (h *Handler) handleMicrosoftAdminConsentCallback(w http.ResponseWriter, r *http.Request) (bool, error) {
	q := r.URL.Query()
	projectID := migr.ParseAdminConsentProjectID(q.Get("state"))
	granted := strings.EqualFold(q.Get("admin_consent"), "True")
	oauthErr := strings.TrimSpace(q.Get("error"))
	if !granted && (projectID == "" || oauthErr == "") {
		return false, nil
	}

	const settingsPath = "/admin/settings/mail-domains?microsoft_admin_consent="
	errorRedirect := h.appURL + settingsPath + "error"

	if h.oauth == nil {
		// The request is ours to answer, so the browser must still be sent
		// somewhere rather than left with an empty response.
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return true, migr.ErrProviderDisabled
	}

	record := migr.MicrosoftAdminConsentRecord{
		TenantID:         q.Get("tenant"),
		ClientID:         h.oauth.MicrosoftClientID(),
		ProjectID:        projectID,
		Granted:          granted,
		ErrorCode:        oauthErr,
		ErrorDescription: q.Get("error_description"),
	}
	if err := h.svc.RecordMicrosoftAdminConsent(r.Context(), record); err != nil {
		http.Redirect(w, r, errorRedirect, http.StatusFound)
		return true, err
	}

	redirect := h.appURL + settingsPath + "success"
	if !granted {
		redirect = errorRedirect
	}
	if tenant := strings.TrimSpace(record.TenantID); tenant != "" {
		redirect += "&tenant=" + url.QueryEscape(tenant)
	}
	if projectID != "" {
		redirect += "&project_id=" + url.QueryEscape(projectID)
	}
	http.Redirect(w, r, redirect, http.StatusFound)
	return true, nil
}

// GetInvite looks up an invite by the "token" query parameter and responds
// with the invite, its project and onboarding hints.
//
// It responds 400 when the token is missing. Any lookup error is reported as
// 404 invite_not_found, with the error text as the message.
func (h *Handler) GetInvite(w http.ResponseWriter, r *http.Request) {
	token := strings.TrimSpace(r.URL.Query().Get("token"))
	if token == "" {
		apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
			Field:   "token",
			Message: "token required",
		}))
		return
	}
	inv, proj, err := h.svc.GetInviteByToken(r.Context(), token)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusNotFound, "invite_not_found", err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
		"invite":     inv,
		"project":    proj,
		"onboarding": h.svc.BuildInviteOnboardingHints(proj, inv),
	})
}

// claimRequest is the JSON body accepted by ClaimInvite.
type claimRequest struct {
	Token    string `json:"token"`
	Password string `json:"password"`
}

// ClaimInvite claims an invite for the authenticated user and responds with
// the resulting status.
//
// Known service errors are mapped to stable error codes; note that the error
// text is returned to the client as-is. Any other claim error yields
// 400 claim_failed.
func (h *Handler) ClaimInvite(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
		return
	}
	var req claimRequest
	if err := apivalidate.DecodeJSON(w, r, maxMigrationRequestBody, &req); err != nil {
		// DecodeJSON has already written the error response.
		return
	}
	userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
	if err != nil {
		apivalidate.WriteInternal(w, r)
		return
	}

	status, err := h.svc.ClaimInvite(r.Context(), req.Token, userID, migr.ClaimIdentityFromAuth(claims), claims.Name, req.Password)
	if err != nil {
		httpStatus, errCode := http.StatusBadRequest, "claim_failed"
		// errors.Is (rather than ==) keeps the mapping working if the
		// service wraps its sentinel errors with extra context.
		switch {
		case errors.Is(err, migr.ErrInviteNotFound):
			httpStatus, errCode = http.StatusNotFound, "invite_not_found"
		case errors.Is(err, migr.ErrInviteClaimed):
			errCode = "invite_already_claimed"
		case errors.Is(err, migr.ErrEmailMismatch):
			errCode = "email_mismatch"
		case errors.Is(err, migr.ErrMigrationDomainNotActive):
			errCode = "migration_domain_not_active"
		case errors.Is(err, migr.ErrMigrationDomainMismatch):
			errCode = "migration_domain_mismatch"
		}
		apiresponse.WriteError(w, r, httpStatus, errCode, err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, status)
}

// GetStatus responds with the authenticated user's active migration status.
// Lookup failures are reported as internal errors.
func (h *Handler) GetStatus(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
		return
	}
	userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
	if err != nil {
		apivalidate.WriteInternal(w, r)
		return
	}
	status, err := h.svc.GetActiveUserStatus(r.Context(), userID)
	if err != nil {
		apivalidate.WriteInternal(w, r)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, status)
}

// ListOAuthProviders responds with {"providers": ...} listing the enabled
// OAuth providers.
//
// NOTE(review): h.oauth may be nil (handleMicrosoftAdminConsentCallback
// guards for it). This call relies on EnabledProviders tolerating a nil
// receiver — confirm.
func (h *Handler) ListOAuthProviders(w http.ResponseWriter, r *http.Request) {
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"providers": h.oauth.EnabledProviders()})
}

// startOAuthRequest is the JSON body accepted by StartOAuth. All fields are
// optional; see StartOAuth for how defaults are resolved.
type startOAuthRequest struct {
	Provider    string `json:"provider"`
	InviteToken string `json:"invite_token"`
	ProjectID   string `json:"project_id"`
}

// StartOAuth begins a provider OAuth flow for the authenticated user and
// responds with {"auth_url": ...}.
//
// The project is resolved in this order:
//  1. the explicit project_id;
//  2. the invite's project, when invite_token is given and the invite is
//     unassigned or belongs to this user;
//  3. the user's active migration project.
//
// The provider defaults to Google when empty.
func (h *Handler) StartOAuth(w http.ResponseWriter, r *http.Request) {
	claims := middleware.ClaimsFromContext(r.Context())
	if claims == nil {
		apiresponse.WriteError(w, r, http.StatusUnauthorized, "unauthorized", "authentication required", nil)
		return
	}
	var req startOAuthRequest
	if err := apivalidate.DecodeJSON(w, r, maxMigrationRequestBody, &req); err != nil {
		// DecodeJSON has already written the error response.
		return
	}
	userID, err := h.svc.LookupUserID(r.Context(), claims.Sub)
	if err != nil {
		apivalidate.WriteInternal(w, r)
		return
	}

	// NOTE(review): an explicit project_id is trusted as-is here (no
	// membership check in this handler); verify the service layer enforces it.
	projectID := req.ProjectID
	if projectID == "" && req.InviteToken != "" {
		inv, proj, err := h.svc.GetInviteByToken(r.Context(), req.InviteToken)
		if err != nil {
			apiresponse.WriteError(w, r, http.StatusNotFound, "invite_not_found", err.Error(), nil)
			return
		}
		if inv.UserID != "" && inv.UserID != userID {
			apiresponse.WriteError(w, r, http.StatusForbidden, "forbidden", "invite belongs to another user", nil)
			return
		}
		projectID = proj.ID
	}
	if projectID == "" {
		// Any status lookup failure (including internal errors) is reported
		// as a missing project_id.
		st, err := h.svc.GetActiveUserStatus(r.Context(), userID)
		if err != nil || st.Project.ID == "" {
			apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
				Field:   "project_id",
				Message: "project_id required",
			}))
			return
		}
		projectID = st.Project.ID
	}

	provider := migr.Provider(strings.ToLower(strings.TrimSpace(req.Provider)))
	if provider == "" {
		provider = migr.ProviderGoogle
	}

	if h.oauth == nil {
		// Same response shape as a failed Start, without dereferencing a
		// nil service.
		apiresponse.WriteError(w, r, http.StatusBadRequest, "oauth_start_failed", migr.ErrProviderDisabled.Error(), nil)
		return
	}
	authURL, _, err := h.oauth.Start(r.Context(), migr.PendingOAuth{
		UserID:      userID,
		ProjectID:   projectID,
		InviteToken: req.InviteToken,
	}, provider)
	if err != nil {
		apiresponse.WriteError(w, r, http.StatusBadRequest, "oauth_start_failed", err.Error(), nil)
		return
	}
	apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"auth_url": authURL})
}