package mail

import (
	"bytes"
	"context"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"sort"
	"strconv"
	"testing"

	"github.com/go-chi/chi/v5"

	"github.com/ultisuite/ulti-backend/internal/api/middleware"
	"github.com/ultisuite/ulti-backend/internal/api/query"
	"github.com/ultisuite/ulti-backend/internal/auth"
)

// draftFakeService is an in-memory test double for the draft-related service
// methods. It embeds fakeMailService and overrides the draft operations.
type draftFakeService struct {
	fakeMailService
	drafts  map[string]map[string]any // stored drafts keyed by draft ID
	created []draftRequest            // copy of every request passed to CreateDraft, in call order
	nextID  int                       // number of successful CreateDraft calls so far
}

// newDraftFakeService returns a fake pre-seeded with a single draft whose ID
// is "draft-1".
func newDraftFakeService() *draftFakeService {
	return &draftFakeService{
		fakeMailService: *newFakeMailService(),
		drafts: map[string]map[string]any{
			"draft-1": {
				"id":         "draft-1",
				"account_id": "acc-1",
				"subject":    "Draft subject",
				"body_text":  "Draft body",
				"to":         json.RawMessage(`[]`),
				"cc":         json.RawMessage(`[]`),
				"bcc":        json.RawMessage(`[]`),
			},
		},
	}
}

// ListDrafts returns every stored draft, ordered by draft ID, together with
// pagination metadata built from params and the total draft count. It returns
// ErrUserNotProvisioned when externalID is not testExternalID.
func (f *draftFakeService) ListDrafts(_ context.Context, externalID string, params query.ListParams) (DraftsList, error) {
	if externalID != testExternalID {
		return DraftsList{}, ErrUserNotProvisioned
	}

	// Map iteration order is unspecified, so collect and sort the IDs first
	// to keep the listing stable once more than one draft is stored.
	ids := make([]string, 0, len(f.drafts))
	for id := range f.drafts {
		ids = append(ids, id)
	}
	sort.Strings(ids)

	drafts := make([]map[string]any, 0, len(ids))
	for _, id := range ids {
		drafts = append(drafts, f.drafts[id])
	}

	total := int64(len(drafts))
	return DraftsList{
		Drafts:     drafts,
		Pagination: params.Meta(&total),
	}, nil
}

// GetDraft returns the stored draft with the given ID. It returns
// ErrUserNotProvisioned when externalID is not testExternalID and ErrNotFound
// when no draft has that ID.
func (f *draftFakeService) GetDraft(_ context.Context, externalID, draftID string) (map[string]any, error) {
	if externalID != testExternalID {
		return nil, ErrUserNotProvisioned
	}
	draft, ok := f.drafts[draftID]
	if !ok {
		return nil, ErrNotFound
	}
	return draft, nil
}

// CreateDraft records a copy of req in f.created, stores a new draft built
// from it, and returns the new draft's ID. The first draft created gets the
// ID "draft-new"; the n-th (n >= 2) gets "draft-new-<n>". It returns
// ErrAccountNotFound when userID is not testUserID.
func (f *draftFakeService) CreateDraft(_ context.Context, userID string, req *draftRequest) (string, error) {
	if userID != testUserID {
		return "", ErrAccountNotFound
	}

	f.nextID++
	id := "draft-new"
	if f.nextID > 1 {
		// Derive the suffix from the counter so a third or later create does
		// not reuse "draft-new-2" and overwrite an earlier draft.
		id = "draft-new-" + strconv.Itoa(f.nextID)
	}

	f.created = append(f.created, *req)
	f.drafts[id] = map[string]any{
		"id":         id,
		"account_id": req.AccountID,
		"subject":    req.Subject,
		"body_text":  req.BodyText,
		"to":         req.To,
	}
	return id, nil
}

// UpdateDraft reports whether the draft exists for the test user. It returns
// ErrNotFound when externalID is not testExternalID or the draft is unknown.
// The request payload is not applied to the stored draft.
func (f *draftFakeService) UpdateDraft(_ context.Context, externalID, draftID string, req *draftRequest) error {
	if externalID != testExternalID {
		return ErrNotFound
	}
	if _, ok := f.drafts[draftID]; !ok {
		return ErrNotFound
	}
	return nil
}

// DeleteDraft removes the draft with the given ID from the fake. It returns
// ErrNotFound when externalID is not testExternalID or the draft is unknown.
func (f *draftFakeService) DeleteDraft(_ context.Context, externalID, draftID string) error {
	if externalID != testExternalID {
		return ErrNotFound
	}
	if _, ok := f.drafts[draftID]; !ok {
		return ErrNotFound
	}
	delete(f.drafts, draftID)
	return nil
}

// newTestDraftsRouter wires the draft handlers backed by svc onto a chi
// router that injects test claims (Sub set to testExternalID) into every
// request.
func newTestDraftsRouter(svc ServiceAPI) http.Handler {
	h := NewHandlerWithService(svc)

	r := chi.NewRouter()
	r.Use(middleware.WithTestClaims(&auth.Claims{
		Sub:   testExternalID,
		Email: "user@example.com",
	}))

	r.Get("/drafts", h.ListDrafts)
	r.Post("/drafts", h.CreateDraft)
	r.Get("/drafts/{draftID}", h.GetDraft)
	r.Put("/drafts/{draftID}", h.UpdateDraft)
	r.Delete("/drafts/{draftID}", h.DeleteDraft)
	return r
}

// TestListDrafts checks that GET /drafts responds 200 with the single seeded
// draft.
func TestListDrafts(t *testing.T) {
	svc := newDraftFakeService()
	router := newTestDraftsRouter(svc)

	req := httptest.NewRequest(http.MethodGet, "/drafts", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
	}

	var body DraftsList
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		t.Fatalf("decode body: %v", err)
	}
	if len(body.Drafts) != 1 {
		t.Fatalf("drafts len = %d, want 1", len(body.Drafts))
	}
	if body.Drafts[0]["id"] != "draft-1" {
		t.Fatalf("draft id = %v", body.Drafts[0]["id"])
	}
}

// TestCreateDraft checks that POST /drafts responds 201 with the new draft ID
// and that the decoded request reaches the service.
func TestCreateDraft(t *testing.T) {
	svc := newDraftFakeService()
	router := newTestDraftsRouter(svc)

	payload := map[string]any{
		"account_id": "acc-1",
		"subject":    "New draft",
		"body_text":  "Hello draft",
	}
	body, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal payload: %v", err)
	}

	req := httptest.NewRequest(http.MethodPost, "/drafts", bytes.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusCreated {
		t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusCreated, rec.Body.String())
	}

	var resp map[string]string
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode body: %v", err)
	}
	if resp["id"] != "draft-new" {
		t.Fatalf("response id = %q", resp["id"])
	}
	if len(svc.created) != 1 {
		t.Fatalf("created count = %d, want 1", len(svc.created))
	}
	if svc.created[0].Subject != "New draft" {
		t.Fatalf("created subject = %q", svc.created[0].Subject)
	}
}