diff --git a/cmd/ultid/main.go b/cmd/ultid/main.go index 6830757..96d2c65 100644 --- a/cmd/ultid/main.go +++ b/cmd/ultid/main.go @@ -33,6 +33,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/envexpand" mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials" imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap" + mailstorage "github.com/ultisuite/ulti-backend/internal/mail/storage" "github.com/ultisuite/ulti-backend/internal/mail/smtp" "github.com/ultisuite/ulti-backend/internal/meet" "github.com/ultisuite/ulti-backend/internal/nextcloud" @@ -65,7 +66,7 @@ func main() { rdb := redis.NewClient(&redis.Options{Addr: cfg.KeyDBAddr}) defer rdb.Close() - _, err = minio.New(cfg.RustFSEndpoint, &minio.Options{ + minioClient, err := minio.New(cfg.RustFSEndpoint, &minio.Options{ Creds: credentials.NewStaticV4(cfg.RustFSAccessKey, cfg.RustFSSecretKey, ""), Secure: cfg.RustFSUseSSL, }) @@ -73,6 +74,10 @@ func main() { slog.Error("failed to create RustFS client", "error", err) os.Exit(1) } + attachmentStorage := mailstorage.NewClient(minioClient, cfg.MailAttachmentsBucket) + if err := attachmentStorage.EnsureBucket(ctx); err != nil { + slog.Warn("mail attachments bucket check failed", "error", err) + } verifier, err := auth.NewVerifier(ctx, cfg.OIDCIssuer, cfg.OIDCClientID) if err != nil { @@ -166,9 +171,9 @@ func main() { r.Get("/ws", hub.HandleWS) r.Group(func(r chi.Router) { - r.Use(middleware.Auth(verifier, auditLogger)) + r.Use(middleware.Auth(verifier, pool, auditLogger)) - r.Mount("/api/v1/mail", mailapi.NewHandler(pool, auditLogger, credentialManager).Routes()) + r.Mount("/api/v1/mail", mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket).Routes()) r.Mount("/api/v1/admin", admin.NewHandler(pool, auditLogger).Routes()) r.Get("/api/v1/search", search.NewHandler(pool).Search) diff --git a/internal/api/mail/attachments.go b/internal/api/mail/attachments.go new file mode 100644 index 0000000..afbf5e7 --- /dev/null +++ b/internal/api/mail/attachments.go @@ -0,0 +1,285 @@ +package mail + +import ( + "context" + "encoding/json" + "errors" + "io" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + + "github.com/ultisuite/ulti-backend/internal/mail/storage" +) + +var ( + ErrAttachmentNotFound = errors.New("attachment not found") + ErrAttachmentTooLarge = errors.New("attachment too large") +) + +const maxAttachmentSize = 25 << 20 // 25 MiB + +type draftAttachmentRef struct { + ID string `json:"id"` + Filename string `json:"filename"` + ContentType string `json:"content_type"` + Size int64 `json:"size"` + S3Bucket string `json:"s3_bucket"` + S3Key string `json:"s3_key"` + ContentID string `json:"content_id,omitempty"` + IsInline bool `json:"is_inline"` +} + +func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messageID string) ([]map[string]any, error) { + if _, err := s.ensureMessageOwned(ctx, externalID, messageID); err != nil { + return nil, err + } + + rows, err := s.db.Query(ctx, ` + SELECT id, filename, content_type, size, content_id, is_inline + FROM attachments WHERE message_id = $1 + ORDER BY created_at ASC + `, messageID) + if err != nil { + return nil, err + } + defer rows.Close() + + out := make([]map[string]any, 0) + for rows.Next() { + var id, filename, contentType, contentID string + var size int64 + var isInline bool + if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline); err != nil { + return nil, err + } + entry := map[string]any{ + "id": id, "filename": filename, "content_type": contentType, + "size": size, "is_inline": isInline, + } + if contentID != "" { + entry["content_id"] = contentID + } + out = append(out, entry) + } + return out, rows.Err() +} + +func (s *Service) MessageAttachmentCIDMap(ctx context.Context, externalID, messageID string) (map[string]string, error) { + if _, err := s.ensureMessageOwned(ctx, externalID, messageID); err != nil { + return nil, err + } + + rows, err := s.db.Query(ctx, ` + SELECT id, content_id FROM attachments + WHERE message_id = $1 AND content_id <> '' + `, messageID) + if err != nil { + return nil, err + } + defer rows.Close() + + mapping := make(map[string]string) + for rows.Next() { + var id, contentID string + if err := rows.Scan(&id, &contentID); err != nil { + return nil, err + } + mapping[contentID] = id + } + return mapping, rows.Err() +} + +func (s *Service) UploadMessageAttachment( + ctx context.Context, externalID, messageID, filename, contentType, contentID string, + isInline bool, reader io.Reader, size int64, +) (string, error) { + if s.storage == nil { + return "", errors.New("object storage unavailable") + } + if size > maxAttachmentSize { + return "", ErrAttachmentTooLarge + } + userID, err := s.ensureMessageOwned(ctx, externalID, messageID) + if err != nil { + return "", err + } + + objectKey := storage.MessageObjectKey(userID, messageID, filename) + if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil { + return "", err + } + + var id string + err = s.db.QueryRow(ctx, ` + INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + RETURNING id + `, messageID, filename, contentType, size, s.storageBucket(), objectKey, contentID, isInline).Scan(&id) + if err != nil { + _ = s.storage.Delete(ctx, objectKey) + return "", err + } + + _, err = s.db.Exec(ctx, ` + UPDATE messages SET has_attachments = true, updated_at = NOW() WHERE id = $1 + `, messageID) + if err != nil { + return "", err + } + return id, nil +} + +func (s *Service) OpenAttachment(ctx context.Context, externalID, attachmentID string) ( + filename, contentType string, size int64, isInline bool, body io.ReadCloser, err error, +) { + if s.storage == nil { + return "", "", 0, false, nil, errors.New("object storage unavailable") + } + + var s3Key string + err = s.db.QueryRow(ctx, ` + SELECT a.filename, a.content_type, a.size, a.is_inline, a.s3_key + FROM attachments a + JOIN messages m ON a.message_id = m.id + JOIN mail_accounts ma ON m.account_id = ma.id + WHERE a.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) + `, attachmentID, externalID).Scan(&filename, &contentType, &size, &isInline, &s3Key) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", "", 0, false, nil, ErrAttachmentNotFound + } + return "", "", 0, false, nil, err + } + + obj, err := s.storage.Get(ctx, s3Key) + if err != nil { + return "", "", 0, false, nil, err + } + return filename, contentType, size, isInline, obj, nil +} + +func (s *Service) UploadDraftAttachment( + ctx context.Context, externalID, draftID, filename, contentType, contentID string, + isInline bool, reader io.Reader, size int64, +) (string, error) { + if s.storage == nil { + return "", errors.New("object storage unavailable") + } + if size > maxAttachmentSize { + return "", ErrAttachmentTooLarge + } + + userID, err := s.ResolveUserID(ctx, externalID) + if err != nil { + return "", err + } + + var attachmentsJSON []byte + err = s.db.QueryRow(ctx, ` + SELECT attachments FROM outbox + WHERE id = $1 AND user_id = $2 AND status = 'draft' + `, draftID, userID).Scan(&attachmentsJSON) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrNotFound + } + return "", err + } + + objectKey := storage.DraftObjectKey(userID, draftID, filename) + if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil { + return "", err + } + + refs := make([]draftAttachmentRef, 0) + if len(attachmentsJSON) > 0 && string(attachmentsJSON) != "[]" { + _ = json.Unmarshal(attachmentsJSON, &refs) + } + + attID := uuid.NewString() + refs = append(refs, draftAttachmentRef{ + ID: attID, Filename: filename, ContentType: contentType, Size: size, + S3Bucket: s.storageBucket(), S3Key: objectKey, + ContentID: contentID, IsInline: isInline, + }) + + updated, _ := json.Marshal(refs) + result, err := s.db.Exec(ctx, ` + UPDATE outbox SET attachments = $1, updated_at = NOW() + WHERE id = $2 AND user_id = $3 AND status = 'draft' + `, updated, draftID, userID) + if err != nil { + _ = s.storage.Delete(ctx, objectKey) + return "", err + } + if result.RowsAffected() == 0 { + _ = s.storage.Delete(ctx, objectKey) + return "", ErrNotFound + } + return attID, nil +} + +func (s *Service) OpenDraftAttachment(ctx context.Context, externalID, draftID, attachmentID string) ( + filename, contentType string, body io.ReadCloser, err error, +) { + if s.storage == nil { + return "", "", nil, errors.New("object storage unavailable") + } + + userID, err := s.ResolveUserID(ctx, externalID) + if err != nil { + return "", "", nil, err + } + + var attachmentsJSON []byte + err = s.db.QueryRow(ctx, ` + SELECT attachments FROM outbox + WHERE id = $1 AND user_id = $2 AND status = 'draft' + `, draftID, userID).Scan(&attachmentsJSON) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", "", nil, ErrNotFound + } + return "", "", nil, err + } + + var refs []draftAttachmentRef + if err := json.Unmarshal(attachmentsJSON, &refs); err != nil { + return "", "", nil, err + } + for _, ref := range refs { + if ref.ID != attachmentID { + continue + } + obj, err := s.storage.Get(ctx, ref.S3Key) + if err != nil { + return "", "", nil, err + } + return ref.Filename, ref.ContentType, obj, nil + } + return "", "", nil, ErrAttachmentNotFound +} + +func (s *Service) ensureMessageOwned(ctx context.Context, externalID, messageID string) (userID string, err error) { + err = s.db.QueryRow(ctx, ` + SELECT u.id FROM messages m + JOIN mail_accounts ma ON m.account_id = ma.id + JOIN users u ON ma.user_id = u.id + WHERE m.id = $1 AND u.external_id = $2 + `, messageID, externalID).Scan(&userID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrNotFound + } + return "", err + } + return userID, nil +} + +func (s *Service) storageBucket() string { + if s.attachmentsBucket != "" { + return s.attachmentsBucket + } + return "mail-attachments" +} diff --git a/internal/api/mail/drafts.go b/internal/api/mail/drafts.go new file mode 100644 index 0000000..14771b2 --- /dev/null +++ b/internal/api/mail/drafts.go @@ -0,0 +1,276 @@ +package mail + +import ( + "context" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5" + + "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/mail/threading" +) + +type DraftsList struct { + Drafts []map[string]any `json:"drafts"` + Pagination query.PaginationMeta `json:"pagination,omitempty"` +} + +func (s *Service) ListDrafts(ctx context.Context, externalID string, params query.ListParams) (DraftsList, error) { + var total int64 + err := s.db.QueryRow(ctx, ` + SELECT COUNT(*) FROM outbox o + WHERE o.user_id = (SELECT id FROM users WHERE external_id = $1) + AND o.status = 'draft' + `, externalID).Scan(&total) + if err != nil { + return DraftsList{}, err + } + + rows, err := s.db.Query(ctx, ` + SELECT o.id, o.account_id, o.identity_id, o.to_addrs, o.cc_addrs, o.bcc_addrs, + o.subject, o.body_text, o.updated_at, o.created_at + FROM outbox o + WHERE o.user_id = (SELECT id FROM users WHERE external_id = $1) + AND o.status = 'draft' + ORDER BY o.updated_at DESC + LIMIT $2 OFFSET $3 + `, externalID, params.Limit(), params.Offset()) + if err != nil { + return DraftsList{}, err + } + defer rows.Close() + + drafts := make([]map[string]any, 0) + for rows.Next() { + entry, err := scanDraftListRow(rows) + if err != nil { + return DraftsList{}, err + } + drafts = append(drafts, entry) + } + if err := rows.Err(); err != nil { + return DraftsList{}, err + } + + return DraftsList{ + Drafts: drafts, + Pagination: params.Meta(&total), + }, nil +} + +func (s *Service) GetDraft(ctx context.Context, externalID, draftID string) (map[string]any, error) { + var ( + id, accountID, subject, bodyText, bodyHTML, inReplyTo string + identityID *string + toAddrs, ccAddrs, bccAddrs, attachments []byte + references []string + createdAt, updatedAt any + ) + err := s.db.QueryRow(ctx, ` + SELECT o.id, o.account_id, o.identity_id, o.to_addrs, o.cc_addrs, o.bcc_addrs, + o.subject, o.body_text, o.body_html, o.in_reply_to, o.references_header, + o.attachments, o.created_at, o.updated_at + FROM outbox o + WHERE o.id = $1 + AND o.user_id = (SELECT id FROM users WHERE external_id = $2) + AND o.status = 'draft' + `, draftID, externalID).Scan( + &id, &accountID, &identityID, &toAddrs, &ccAddrs, &bccAddrs, + &subject, &bodyText, &bodyHTML, &inReplyTo, &references, + &attachments, &createdAt, &updatedAt, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + return draftDetailMap(id, accountID, identityID, toAddrs, ccAddrs, bccAddrs, subject, bodyText, bodyHTML, inReplyTo, references, attachments, createdAt, updatedAt), nil +} + +func (s *Service) CreateDraft(ctx context.Context, userID string, req *draftRequest) (string, error) { + if err := s.validateDraftAccountAndIdentity(ctx, userID, req.AccountID, req.IdentityID); err != nil { + return "", err + } + + toJSON, _ := json.Marshal(req.To) + ccJSON, _ := json.Marshal(req.Cc) + bccJSON, _ := json.Marshal(req.Bcc) + attachmentsJSON, _ := json.Marshal(req.Attachments) + if req.Attachments == nil { + attachmentsJSON = []byte("[]") + } + inReplyTo := threading.NormalizeMessageID(req.InReplyTo) + + var id string + err := s.db.QueryRow(ctx, ` + INSERT INTO outbox ( + user_id, account_id, identity_id, to_addrs, cc_addrs, bcc_addrs, + subject, body_text, body_html, in_reply_to, references_header, attachments, status + ) + SELECT $1, ma.id, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, 'draft' + FROM mail_accounts ma + WHERE ma.id = $2 AND ma.user_id = $1 + RETURNING id + `, userID, req.AccountID, nilIfEmpty(req.IdentityID), toJSON, ccJSON, bccJSON, + req.Subject, req.BodyText, req.BodyHTML, inReplyTo, []string{}, attachmentsJSON).Scan(&id) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return "", ErrAccountNotFound + } + return "", err + } + return id, nil +} + +func (s *Service) UpdateDraft(ctx context.Context, externalID, draftID string, req *draftRequest) error { + userID, err := s.ResolveUserID(ctx, externalID) + if err != nil { + return err + } + + accountID := req.AccountID + if accountID == "" && req.IdentityID != "" { + err := s.db.QueryRow(ctx, ` + SELECT account_id FROM outbox + WHERE id = $1 AND user_id = $2 AND status = 'draft' + `, draftID, userID).Scan(&accountID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrNotFound + } + return err + } + } + if accountID != "" { + if err := s.validateDraftAccountAndIdentity(ctx, userID, accountID, req.IdentityID); err != nil { + return err + } + } + + toJSON, _ := json.Marshal(req.To) + ccJSON, _ := json.Marshal(req.Cc) + bccJSON, _ := json.Marshal(req.Bcc) + attachmentsJSON, _ := json.Marshal(req.Attachments) + if req.Attachments == nil { + attachmentsJSON = []byte("[]") + } + inReplyTo := threading.NormalizeMessageID(req.InReplyTo) + + result, err := s.db.Exec(ctx, ` + UPDATE outbox o SET + account_id = COALESCE($1, o.account_id), + identity_id = CASE WHEN $2 <> '' THEN $2::uuid ELSE o.identity_id END, + to_addrs = $3, + cc_addrs = $4, + bcc_addrs = $5, + subject = $6, + body_text = $7, + body_html = $8, + in_reply_to = $9, + references_header = $10, + attachments = $11, + updated_at = NOW() + WHERE o.id = $12 + AND o.user_id = $13 + AND o.status = 'draft' + `, nilIfEmpty(req.AccountID), req.IdentityID, toJSON, ccJSON, bccJSON, + req.Subject, req.BodyText, req.BodyHTML, inReplyTo, []string{}, attachmentsJSON, + draftID, userID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + return nil +} + +func (s *Service) DeleteDraft(ctx context.Context, externalID, draftID string) error { + result, err := s.db.Exec(ctx, ` + DELETE FROM outbox o + WHERE o.id = $1 + AND o.user_id = (SELECT id FROM users WHERE external_id = $2) + AND o.status = 'draft' + `, draftID, externalID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + return nil +} + +func (s *Service) validateDraftAccountAndIdentity(ctx context.Context, userID, accountID, identityID string) error { + var exists bool + err := s.db.QueryRow(ctx, ` + SELECT EXISTS(SELECT 1 FROM mail_accounts WHERE id = $1 AND user_id = $2) + `, accountID, userID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return ErrAccountNotFound + } + if identityID == "" { + return nil + } + err = s.db.QueryRow(ctx, ` + SELECT EXISTS(SELECT 1 FROM mail_identities WHERE id = $1 AND account_id = $2) + `, identityID, accountID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return ErrNotFound + } + return nil +} + +type draftListScanner interface { + Scan(dest ...any) error +} + +func scanDraftListRow(rows draftListScanner) (map[string]any, error) { + var id, accountID, subject, bodyText string + var identityID *string + var toAddrs, ccAddrs, bccAddrs []byte + var updatedAt, createdAt any + if err := rows.Scan(&id, &accountID, &identityID, &toAddrs, &ccAddrs, &bccAddrs, &subject, &bodyText, &updatedAt, &createdAt); err != nil { + return nil, err + } + entry := map[string]any{ + "id": id, "account_id": accountID, "subject": subject, + "to": json.RawMessage(toAddrs), "cc": json.RawMessage(ccAddrs), "bcc": json.RawMessage(bccAddrs), + "body_text": bodyText, "updated_at": updatedAt, "created_at": createdAt, + } + if identityID != nil { + entry["identity_id"] = *identityID + } + return entry, nil +} + +func draftDetailMap( + id, accountID string, + identityID *string, + toAddrs, ccAddrs, bccAddrs []byte, + subject, bodyText, bodyHTML, inReplyTo string, + references []string, + attachments []byte, + createdAt, updatedAt any, +) map[string]any { + out := map[string]any{ + "id": id, "account_id": accountID, "subject": subject, + "to": json.RawMessage(toAddrs), "cc": json.RawMessage(ccAddrs), "bcc": json.RawMessage(bccAddrs), + "body_text": bodyText, "body_html": bodyHTML, + "in_reply_to": inReplyTo, "references": references, + "attachments": json.RawMessage(attachments), + "created_at": createdAt, "updated_at": updatedAt, + } + if identityID != nil { + out["identity_id"] = *identityID + } + return out +} diff --git a/internal/api/mail/drafts_test.go b/internal/api/mail/drafts_test.go new file mode 100644 index 0000000..bf5a04e --- /dev/null +++ b/internal/api/mail/drafts_test.go @@ -0,0 +1,177 @@ +package mail + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/auth" +) + +type draftFakeService struct { + fakeMailService + drafts map[string]map[string]any + created []draftRequest + nextID int +} + +func newDraftFakeService() *draftFakeService { + return &draftFakeService{ + fakeMailService: *newFakeMailService(), + drafts: map[string]map[string]any{ + "draft-1": { + "id": "draft-1", "account_id": "acc-1", "subject": "Draft subject", + "body_text": "Draft body", "to": json.RawMessage(`[]`), + "cc": json.RawMessage(`[]`), "bcc": json.RawMessage(`[]`), + }, + }, + } +} + +func (f *draftFakeService) ListDrafts(_ context.Context, externalID string, params query.ListParams) (DraftsList, error) { + if externalID != testExternalID { + return DraftsList{}, ErrUserNotProvisioned + } + drafts := make([]map[string]any, 0, len(f.drafts)) + for _, draft := range f.drafts { + drafts = append(drafts, draft) + } + total := int64(len(drafts)) + return DraftsList{ + Drafts: drafts, + Pagination: params.Meta(&total), + }, nil +} + +func (f *draftFakeService) GetDraft(_ context.Context, externalID, draftID string) (map[string]any, error) { + if externalID != testExternalID { + return nil, ErrUserNotProvisioned + } + draft, ok := f.drafts[draftID] + if !ok { + return nil, ErrNotFound + } + return draft, nil +} + +func (f *draftFakeService) CreateDraft(_ context.Context, userID string, req *draftRequest) (string, error) { + if userID != testUserID { + return "", ErrAccountNotFound + } + f.nextID++ + id := "draft-new" + if f.nextID > 1 { + id = "draft-new-2" + } + f.created = append(f.created, *req) + f.drafts[id] = map[string]any{ + "id": id, "account_id": req.AccountID, "subject": req.Subject, + "body_text": req.BodyText, "to": req.To, + } + return id, nil +} + +func (f *draftFakeService) UpdateDraft(_ context.Context, externalID, draftID string, req *draftRequest) error { + if externalID != testExternalID { + return ErrNotFound + } + if _, ok := f.drafts[draftID]; !ok { + return ErrNotFound + } + return nil +} + +func (f *draftFakeService) DeleteDraft(_ context.Context, externalID, draftID string) error { + if externalID != testExternalID { + return ErrNotFound + } + if _, ok := f.drafts[draftID]; !ok { + return ErrNotFound + } + delete(f.drafts, draftID) + return nil +} + +func newTestDraftsRouter(svc ServiceAPI) http.Handler { + h := NewHandlerWithService(svc) + r := chi.NewRouter() + r.Use(middleware.WithTestClaims(&auth.Claims{ + Sub: testExternalID, + Email: "user@example.com", + })) + r.Get("/drafts", h.ListDrafts) + r.Post("/drafts", h.CreateDraft) + r.Get("/drafts/{draftID}", h.GetDraft) + r.Put("/drafts/{draftID}", h.UpdateDraft) + r.Delete("/drafts/{draftID}", h.DeleteDraft) + return r +} + +func TestListDrafts(t *testing.T) { + svc := newDraftFakeService() + router := newTestDraftsRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/drafts", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var body DraftsList + if err := json.NewDecoder(rec.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(body.Drafts) != 1 { + t.Fatalf("drafts len = %d, want 1", len(body.Drafts)) + } + if body.Drafts[0]["id"] != "draft-1" { + t.Fatalf("draft id = %v", body.Drafts[0]["id"]) + } +} + +func TestCreateDraft(t *testing.T) { + svc := newDraftFakeService() + router := newTestDraftsRouter(svc) + + payload := map[string]any{ + "account_id": "acc-1", + "subject": "New draft", + "body_text": "Hello draft", + } + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/drafts", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + + var resp map[string]string + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode body: %v", err) + } + if resp["id"] != "draft-new" { + t.Fatalf("response id = %q", resp["id"]) + } + if len(svc.created) != 1 { + t.Fatalf("created count = %d, want 1", len(svc.created)) + } + if svc.created[0].Subject != "New draft" { + t.Fatalf("created subject = %q", svc.created[0].Subject) + } +} diff --git a/internal/api/mail/folders.go b/internal/api/mail/folders.go new file mode 100644 index 0000000..06f841e --- /dev/null +++ b/internal/api/mail/folders.go @@ -0,0 +1,214 @@ +package mail + +import ( + "context" + "errors" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/securityaudit" +) + +var ( + ErrFolderProtected = errors.New("system folder cannot be deleted") + ErrDuplicateFolder = errors.New("duplicate folder remote_name") + ErrDuplicateLabel = errors.New("duplicate label name") +) + +type FoldersList struct { + Folders []map[string]any `json:"folders"` + Pagination query.PaginationMeta `json:"pagination,omitempty"` +} + +func scanFolderRow(id, accountID, name, remoteName, folderType string, uidvalidity int64, messageCount, unreadCount int, createdAt, updatedAt any) map[string]any { + return map[string]any{ + "id": id, + "account_id": accountID, + "name": name, + "remote_name": remoteName, + "folder_type": folderType, + "uidvalidity": uidvalidity, + "message_count": messageCount, + "unread_count": unreadCount, + "created_at": createdAt, + "updated_at": updatedAt, + } +} + +func (s *Service) ListFolders(ctx context.Context, externalID, accountID string, params query.ListParams) (FoldersList, error) { + var owned bool + if err := s.db.QueryRow(ctx, ` + SELECT EXISTS( + SELECT 1 FROM mail_accounts + WHERE id = $1 AND user_id = (SELECT id FROM users WHERE external_id = $2) + ) + `, accountID, externalID).Scan(&owned); err != nil { + return FoldersList{}, err + } + if !owned { + return FoldersList{}, ErrAccountNotFound + } + + var total int64 + if err := s.db.QueryRow(ctx, ` + SELECT COUNT(*) FROM mail_folders WHERE account_id = $1 + `, accountID).Scan(&total); err != nil { + return FoldersList{}, err + } + + rows, err := s.db.Query(ctx, ` + SELECT id, account_id, name, remote_name, folder_type, uidvalidity, message_count, unread_count, created_at, updated_at + FROM mail_folders + WHERE account_id = $1 + ORDER BY name ASC + LIMIT $2 OFFSET $3 + `, accountID, params.Limit(), params.Offset()) + if err != nil { + return FoldersList{}, err + } + defer rows.Close() + + folders := make([]map[string]any, 0) + for rows.Next() { + var id, acctID, name, remoteName, folderType string + var uidvalidity int64 + var messageCount, unreadCount int + var createdAt, updatedAt any + if err := rows.Scan(&id, &acctID, &name, &remoteName, &folderType, &uidvalidity, &messageCount, &unreadCount, &createdAt, &updatedAt); err != nil { + return FoldersList{}, err + } + folders = append(folders, scanFolderRow(id, acctID, name, remoteName, folderType, uidvalidity, messageCount, unreadCount, createdAt, updatedAt)) + } + if err := rows.Err(); err != nil { + return FoldersList{}, err + } + + return FoldersList{ + Folders: folders, + Pagination: params.Meta(&total), + }, nil +} + +func (s *Service) GetFolder(ctx context.Context, externalID, folderID string) (map[string]any, error) { + var id, accountID, name, remoteName, folderType string + var uidvalidity int64 + var messageCount, unreadCount int + var createdAt, updatedAt any + err := s.db.QueryRow(ctx, ` + SELECT f.id, f.account_id, f.name, f.remote_name, f.folder_type, f.uidvalidity, f.message_count, f.unread_count, f.created_at, f.updated_at + FROM mail_folders f + JOIN mail_accounts ma ON f.account_id = ma.id + WHERE f.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) + `, folderID, externalID).Scan(&id, &accountID, &name, &remoteName, &folderType, &uidvalidity, &messageCount, &unreadCount, &createdAt, &updatedAt) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + return scanFolderRow(id, accountID, name, remoteName, folderType, uidvalidity, messageCount, unreadCount, createdAt, updatedAt), nil +} + +func (s *Service) CreateFolder(ctx context.Context, userID string, req *createFolderRequest) (string, error) { + var owned bool + if err := s.db.QueryRow(ctx, ` + SELECT EXISTS(SELECT 1 FROM mail_accounts WHERE id = $1 AND user_id = $2) + `, req.AccountID, userID).Scan(&owned); err != nil { + return "", err + } + if !owned { + return "", ErrAccountNotFound + } + + remoteName := strings.TrimSpace(req.RemoteName) + if remoteName == "" { + remoteName = strings.TrimSpace(req.Name) + } + folderType := normalizeFolderType(req.FolderType) + + var id string + err := s.db.QueryRow(ctx, ` + INSERT INTO mail_folders (account_id, name, remote_name, folder_type) + VALUES ($1, $2, $3, $4) + RETURNING id + `, req.AccountID, strings.TrimSpace(req.Name), remoteName, folderType).Scan(&id) + if err != nil { + if isUniqueViolation(err) { + return "", ErrDuplicateFolder + } + return "", err + } + return id, nil +} + +func (s *Service) UpdateFolder(ctx context.Context, externalID, folderID string, req *updateFolderRequest) error { + folderType := normalizeFolderType(req.FolderType) + result, err := s.db.Exec(ctx, ` + UPDATE mail_folders f SET + name = $1, + remote_name = $2, + folder_type = $3, + updated_at = NOW() + FROM mail_accounts ma + WHERE f.id = $4 + AND f.account_id = ma.id + AND ma.user_id = (SELECT id FROM users WHERE external_id = $5) + `, strings.TrimSpace(req.Name), strings.TrimSpace(req.RemoteName), folderType, folderID, externalID) + if err != nil { + if isUniqueViolation(err) { + return ErrDuplicateFolder + } + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + return nil +} + +func (s *Service) DeleteFolder(ctx context.Context, externalID, folderID string) error { + var folderType string + err := s.db.QueryRow(ctx, ` + SELECT f.folder_type + FROM mail_folders f + JOIN mail_accounts ma ON f.account_id = ma.id + WHERE f.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) + `, folderID, externalID).Scan(&folderType) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return ErrNotFound + } + return err + } + if folderType != "custom" { + return ErrFolderProtected + } + + result, err := s.db.Exec(ctx, ` + DELETE FROM mail_folders f + USING mail_accounts ma + WHERE f.id = $1 + AND f.account_id = ma.id + AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) + `, folderID, externalID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + if s.audit != nil { + s.audit.Log(ctx, externalID, securityaudit.ActionCriticalDeletion, map[string]any{ + "target": "mail_folder", "folder_id": folderID, + }) + } + return nil +} + +func isUniqueViolation(err error) bool { + var pgErr *pgconn.PgError + return errors.As(err, &pgErr) && pgErr.Code == "23505" +} diff --git a/internal/api/mail/folders_test.go b/internal/api/mail/folders_test.go new file mode 100644 index 0000000..dddc4f7 --- /dev/null +++ b/internal/api/mail/folders_test.go @@ -0,0 +1,64 @@ +package mail + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/auth" +) + +func TestValidateCreateFolder(t *testing.T) { + t.Run("invalid folder_type", func(t *testing.T) { + req := &createFolderRequest{ + AccountID: "acc-1", + Name: "Work", + FolderType: "bogus", + } + if verr := validateCreateFolder(req); verr == nil { + t.Fatal("expected validation error for folder_type") + } + }) + + t.Run("missing account_id", func(t *testing.T) { + req := &createFolderRequest{Name: "Work"} + if verr := validateCreateFolder(req); verr == nil { + t.Fatal("expected validation error for account_id") + } + }) +} + +func TestListFoldersRequiresAccountID(t *testing.T) { + svc := newFakeMailService() + h := NewHandlerWithService(svc) + r := chi.NewRouter() + r.Use(middleware.WithTestClaims(&auth.Claims{Sub: testExternalID, Email: "user@example.com"})) + r.Mount("/", h.FolderLabelRoutes()) + + req := httptest.NewRequest(http.MethodGet, "/folders", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusBadRequest, rec.Body.String()) + } +} + +func TestListFoldersWithAccountID(t *testing.T) { + svc := newFakeMailService() + h := NewHandlerWithService(svc) + r := chi.NewRouter() + r.Use(middleware.WithTestClaims(&auth.Claims{Sub: testExternalID, Email: "user@example.com"})) + r.Mount("/", h.FolderLabelRoutes()) + + req := httptest.NewRequest(http.MethodGet, "/folders?account_id=acc-1", nil) + rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } +} diff --git a/internal/api/mail/handlers.go b/internal/api/mail/handlers.go index 1e78f86..801fc73 100644 --- a/internal/api/mail/handlers.go +++ b/internal/api/mail/handlers.go @@ -13,6 +13,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/api/middleware" "github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/mail/credentials" + "github.com/ultisuite/ulti-backend/internal/mail/storage" "github.com/ultisuite/ulti-backend/internal/securityaudit" ) @@ -28,8 +29,8 @@ func NewHandlerWithService(svc ServiceAPI) *Handler { } } -func NewHandler(db *pgxpool.Pool, audit *securityaudit.Logger, credentialManager *credentials.Manager) *Handler { - return NewHandlerWithService(NewService(db, audit, credentialManager)) +func NewHandler(db *pgxpool.Pool, audit *securityaudit.Logger, credentialManager *credentials.Manager, objectStorage *storage.Client, attachmentsBucket string) *Handler { + return NewHandlerWithService(NewService(db, audit, credentialManager, objectStorage, attachmentsBucket)) } func (h *Handler) Routes() chi.Router { @@ -39,13 +40,38 @@ func (h *Handler) Routes() chi.Router { r.Post("/accounts", h.CreateAccount) r.Get("/accounts/{accountID}", h.GetAccount) r.Delete("/accounts/{accountID}", h.DeleteAccount) + r.Get("/accounts/{accountID}/identities", h.ListIdentities) + r.Post("/accounts/{accountID}/identities", h.CreateIdentity) + + r.Get("/identities/{identityID}", h.GetIdentity) + r.Put("/identities/{identityID}", h.UpdateIdentity) + r.Delete("/identities/{identityID}", h.DeleteIdentity) + + r.Mount("/", h.FolderLabelRoutes()) + + r.Get("/search", h.SearchMessages) + + r.Get("/drafts", h.ListDrafts) + r.Post("/drafts", h.CreateDraft) + r.Get("/drafts/{draftID}", h.GetDraft) + r.Put("/drafts/{draftID}", h.UpdateDraft) + r.Delete("/drafts/{draftID}", h.DeleteDraft) + r.Post("/drafts/{draftID}/attachments", h.UploadDraftAttachment) + r.Get("/drafts/{draftID}/attachments/{attachmentID}", h.DownloadDraftAttachment) + r.Get("/drafts/{draftID}/attachments/{attachmentID}/inline", h.DownloadDraftAttachment) r.Get("/messages", h.ListMessages) + r.Get("/messages/{messageID}/attachments", h.ListMessageAttachments) + r.Get("/messages/{messageID}/attachments/cid-map", h.MessageAttachmentCIDMap) + r.Post("/messages/{messageID}/attachments", h.UploadMessageAttachment) r.Get("/messages/{messageID}", h.GetMessage) r.Put("/messages/{messageID}/labels", h.UpdateLabels) r.Put("/messages/{messageID}/flags", h.UpdateFlags) r.Delete("/messages/{messageID}", h.DeleteMessage) + r.Get("/attachments/{attachmentID}", h.DownloadAttachment) + r.Get("/attachments/{attachmentID}/inline", h.DownloadAttachment) + r.Get("/threads/{threadID}", h.GetThread) r.Post("/send", h.SendMessage) @@ -176,11 +202,6 @@ func (h *Handler) GetMessage(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateLabels(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) - if err != nil { - h.writeUserResolveError(w, r, err) - return - } var req updateLabelsRequest if err := apivalidate.DecodeJSON(w, r, maxFlagsLabelsBody, &req); err != nil { @@ -191,7 +212,7 @@ func (h *Handler) UpdateLabels(w http.ResponseWriter, r *http.Request) { return } - if err := h.svc.UpdateLabels(r.Context(), userID, chi.URLParam(r, "messageID"), req.Labels); err != nil { + if err := h.svc.UpdateLabels(r.Context(), claims.Sub, chi.URLParam(r, "messageID"), req.Labels); err != nil { if errors.Is(err, ErrNotFound) { apivalidate.WriteNotFound(w, r, "not found") return @@ -205,11 +226,6 @@ func (h *Handler) UpdateLabels(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateFlags(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) - if err != nil { - h.writeUserResolveError(w, r, err) - return - } var req updateFlagsRequest if err := apivalidate.DecodeJSON(w, r, maxFlagsLabelsBody, &req); err != nil { @@ -220,7 +236,7 @@ func (h *Handler) UpdateFlags(w http.ResponseWriter, r *http.Request) { return } - if err := h.svc.UpdateFlags(r.Context(), userID, chi.URLParam(r, "messageID"), req.Flags); err != nil { + if err := h.svc.UpdateFlags(r.Context(), claims.Sub, chi.URLParam(r, "messageID"), req.Flags); err != nil { if errors.Is(err, ErrNotFound) { apivalidate.WriteNotFound(w, r, "not found") return @@ -234,13 +250,8 @@ func (h *Handler) UpdateFlags(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteMessage(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) - if err != nil { - h.writeUserResolveError(w, r, err) - return - } - if err := h.svc.DeleteMessage(r.Context(), claims.Sub, userID, chi.URLParam(r, "messageID")); err != nil { + if err := h.svc.DeleteMessage(r.Context(), claims.Sub, chi.URLParam(r, "messageID")); err != nil { if errors.Is(err, ErrNotFound) { apivalidate.WriteNotFound(w, r, "not found") return @@ -342,11 +353,6 @@ func (h *Handler) CreateRule(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateRule(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) - if err != nil { - h.writeUserResolveError(w, r, err) - return - } var req updateRuleRequest if err := apivalidate.DecodeJSON(w, r, maxRulesRequestBody, &req); err != nil { @@ -357,7 +363,7 @@ func (h *Handler) UpdateRule(w http.ResponseWriter, r *http.Request) { return } - if err := h.svc.UpdateRule(r.Context(), userID, chi.URLParam(r, "ruleID"), &req); err != nil { + if err := h.svc.UpdateRule(r.Context(), claims.Sub, chi.URLParam(r, "ruleID"), &req); err != nil { if errors.Is(err, ErrNotFound) { apivalidate.WriteNotFound(w, r, "not found") return @@ -371,13 +377,8 @@ func (h *Handler) UpdateRule(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteRule(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) - if err != nil { - h.writeUserResolveError(w, r, err) - return - } - if err := h.svc.DeleteRule(r.Context(), claims.Sub, userID, chi.URLParam(r, "ruleID")); err != nil { + if err := h.svc.DeleteRule(r.Context(), claims.Sub, chi.URLParam(r, "ruleID")); err != nil { if errors.Is(err, ErrNotFound) { apivalidate.WriteNotFound(w, r, "not found") return @@ -430,13 +431,8 @@ func (h *Handler) CreateWebhook(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteWebhook(w http.ResponseWriter, r *http.Request) { claims := middleware.ClaimsFromContext(r.Context()) - userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) - if err != nil { - h.writeUserResolveError(w, r, err) - return - } - if err := h.svc.DeleteWebhook(r.Context(), claims.Sub, userID, chi.URLParam(r, "webhookID")); err != nil { + if err := h.svc.DeleteWebhook(r.Context(), claims.Sub, chi.URLParam(r, "webhookID")); err != nil { if errors.Is(err, ErrNotFound) { apivalidate.WriteNotFound(w, r, "not found") return diff --git a/internal/api/mail/handlers_attachments.go b/internal/api/mail/handlers_attachments.go new file mode 100644 index 0000000..8232162 --- /dev/null +++ b/internal/api/mail/handlers_attachments.go @@ -0,0 +1,203 @@ +package mail + +import ( + "errors" + "fmt" + "io" + "mime" + "net/http" + "path/filepath" + "strings" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" +) + +const maxMultipartBody = 26 << 20 // 26 MiB + +func (h *Handler) ListMessageAttachments(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + messageID := chi.URLParam(r, "messageID") + + list, err := h.svc.ListMessageAttachments(r.Context(), claims.Sub, messageID) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("list attachments", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"attachments": list}) +} + +func (h *Handler) MessageAttachmentCIDMap(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + messageID := chi.URLParam(r, "messageID") + + mapping, err := h.svc.MessageAttachmentCIDMap(r.Context(), claims.Sub, messageID) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("attachment cid map", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"cid_map": mapping}) +} + +func (h *Handler) UploadMessageAttachment(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + messageID := chi.URLParam(r, "messageID") + + if err := r.ParseMultipartForm(maxMultipartBody); err != nil { + apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid multipart form", nil) + return + } + + file, header, err := r.FormFile("file") + if err != nil { + apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "file field required", nil) + return + } + defer file.Close() + + filename := filepath.Base(header.Filename) + contentType := header.Header.Get("Content-Type") + if contentType == "" { + contentType = mime.TypeByExtension(filepath.Ext(filename)) + } + if contentType == "" { + contentType = "application/octet-stream" + } + contentID := strings.TrimSpace(r.FormValue("content_id")) + isInline := strings.EqualFold(r.FormValue("inline"), "true") || contentID != "" + + id, err := h.svc.UploadMessageAttachment( + r.Context(), claims.Sub, messageID, filename, contentType, contentID, isInline, + file, header.Size, + ) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + if errors.Is(err, ErrAttachmentTooLarge) { + apiresponse.WriteError(w, r, http.StatusRequestEntityTooLarge, apiresponse.CodeInvalidRequest, "attachment too large", nil) + return + } + h.logger.Error("upload attachment", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusCreated, map[string]string{"id": id}) +} + +func (h *Handler) DownloadAttachment(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + inline := strings.HasSuffix(r.URL.Path, "/inline") || r.URL.Query().Get("inline") == "true" + attachmentID := chi.URLParam(r, "attachmentID") + + filename, contentType, size, isInline, body, err := h.svc.OpenAttachment(r.Context(), claims.Sub, attachmentID) + if err != nil { + if errors.Is(err, ErrAttachmentNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("download attachment", "error", err) + apivalidate.WriteInternal(w, r) + return + } + defer body.Close() + + disposition := "attachment" + if inline || isInline { + disposition = "inline" + } + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Disposition", fmt.Sprintf(`%s; filename="%s"`, disposition, filename)) + if size > 0 { + w.Header().Set("Content-Length", fmt.Sprintf("%d", size)) + } + _, _ = io.Copy(w, body) +} + +func (h *Handler) UploadDraftAttachment(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + draftID := chi.URLParam(r, "draftID") + + if err := r.ParseMultipartForm(maxMultipartBody); err != nil { + apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid multipart form", nil) + return + } + + file, header, err := r.FormFile("file") + if err != nil { + apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "file field required", nil) + return + } + defer file.Close() + + filename := filepath.Base(header.Filename) + contentType := header.Header.Get("Content-Type") + if contentType == "" { + contentType = mime.TypeByExtension(filepath.Ext(filename)) + } + if contentType == "" { + contentType = "application/octet-stream" + } + contentID := strings.TrimSpace(r.FormValue("content_id")) + isInline := strings.EqualFold(r.FormValue("inline"), "true") || contentID != "" + + id, err := h.svc.UploadDraftAttachment( + r.Context(), claims.Sub, draftID, filename, contentType, contentID, isInline, + file, header.Size, + ) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + if errors.Is(err, ErrAttachmentTooLarge) { + apiresponse.WriteError(w, r, http.StatusRequestEntityTooLarge, apiresponse.CodeInvalidRequest, "attachment too large", nil) + return + } + h.logger.Error("upload draft attachment", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusCreated, map[string]string{"id": id}) +} + +func (h *Handler) DownloadDraftAttachment(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + draftID := chi.URLParam(r, "draftID") + attachmentID := chi.URLParam(r, "attachmentID") + inline := strings.HasSuffix(r.URL.Path, "/inline") || r.URL.Query().Get("inline") == "true" + + filename, contentType, body, err := h.svc.OpenDraftAttachment(r.Context(), claims.Sub, draftID, attachmentID) + if err != nil { + if errors.Is(err, ErrNotFound) || errors.Is(err, ErrAttachmentNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("download draft attachment", "error", err) + apivalidate.WriteInternal(w, r) + return + } + defer body.Close() + + disposition := "attachment" + if inline { + disposition = "inline" + } + w.Header().Set("Content-Type", contentType) + w.Header().Set("Content-Disposition", fmt.Sprintf(`%s; filename="%s"`, disposition, filename)) + _, _ = io.Copy(w, body) +} diff --git a/internal/api/mail/handlers_drafts.go b/internal/api/mail/handlers_drafts.go new file mode 100644 index 0000000..96857ff --- /dev/null +++ b/internal/api/mail/handlers_drafts.go @@ -0,0 +1,122 @@ +package mail + +import ( + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/api/query" +) + +func (h *Handler) ListDrafts(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + params, err := query.ParseListRequest(r) + if err != nil { + apivalidate.WriteQueryError(w, r, err) + return + } + + result, err := h.svc.ListDrafts(r.Context(), claims.Sub, params) + if err != nil { + h.logger.Error("list drafts", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, result) +} + +func (h *Handler) GetDraft(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + draft, err := h.svc.GetDraft(r.Context(), claims.Sub, chi.URLParam(r, "draftID")) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("get draft", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, draft) +} + +func (h *Handler) CreateDraft(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) + if err != nil { + h.writeUserResolveError(w, r, err) + return + } + + var req draftRequest + if err := apivalidate.DecodeJSON(w, r, maxSendRequestBody, &req); err != nil { + return + } + if verr := validateCreateDraft(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + id, err := h.svc.CreateDraft(r.Context(), userID, &req) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + apivalidate.WriteNotFound(w, r, "account not found") + return + } + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("create draft", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusCreated, map[string]string{"id": id}) +} + +func (h *Handler) UpdateDraft(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + var req draftRequest + if err := apivalidate.DecodeJSON(w, r, maxSendRequestBody, &req); err != nil { + return + } + if verr := validateUpdateDraft(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + if err := h.svc.UpdateDraft(r.Context(), claims.Sub, chi.URLParam(r, "draftID"), &req); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + if errors.Is(err, ErrAccountNotFound) { + apivalidate.WriteNotFound(w, r, "account not found") + return + } + h.logger.Error("update draft", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) DeleteDraft(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + if err := h.svc.DeleteDraft(r.Context(), claims.Sub, chi.URLParam(r, "draftID")); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("delete draft", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/mail/handlers_folders_labels.go b/internal/api/mail/handlers_folders_labels.go new file mode 100644 index 0000000..0b65a16 --- /dev/null +++ b/internal/api/mail/handlers_folders_labels.go @@ -0,0 +1,240 @@ +package mail + +import ( + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/api/query" +) + +// FolderLabelRoutes registers folder and user-label endpoints without modifying Handler.Routes(). +func (h *Handler) FolderLabelRoutes() chi.Router { + r := chi.NewRouter() + + r.Get("/folders", h.ListFolders) + r.Post("/folders", h.CreateFolder) + r.Get("/folders/{folderID}", h.GetFolder) + r.Put("/folders/{folderID}", h.UpdateFolder) + r.Delete("/folders/{folderID}", h.DeleteFolder) + + r.Get("/labels", h.ListUserLabels) + r.Post("/labels", h.CreateUserLabel) + r.Put("/labels/{labelID}", h.UpdateUserLabel) + r.Delete("/labels/{labelID}", h.DeleteUserLabel) + + return r +} + +func (h *Handler) ListFolders(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + accountID := r.URL.Query().Get("account_id") + if verr := validateListFoldersAccountID(accountID); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + params, err := query.ParseListRequest(r) + if err != nil { + apivalidate.WriteQueryError(w, r, err) + return + } + + result, err := h.svc.ListFolders(r.Context(), claims.Sub, accountID, params) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + apivalidate.WriteNotFound(w, r, "account not found") + return + } + h.logger.Error("list folders", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, result) +} + +func (h *Handler) GetFolder(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + folder, err := h.svc.GetFolder(r.Context(), claims.Sub, chi.URLParam(r, "folderID")) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("get folder", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, folder) +} + +func (h *Handler) CreateFolder(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + userID, err := h.svc.ResolveUserID(r.Context(), claims.Sub) + if err != nil { + h.writeUserResolveError(w, r, err) + return + } + + var req createFolderRequest + if err := apivalidate.DecodeJSON(w, r, maxFoldersRequestBody, &req); err != nil { + return + } + if verr := validateCreateFolder(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + id, err := h.svc.CreateFolder(r.Context(), userID, &req) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + apivalidate.WriteNotFound(w, r, "account not found") + return + } + if errors.Is(err, ErrDuplicateFolder) { + apiresponse.WriteError(w, r, http.StatusConflict, apiresponse.CodeInvalidRequest, "folder remote_name already exists", nil) + return + } + h.logger.Error("create folder", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusCreated, map[string]string{"id": id}) +} + +func (h *Handler) UpdateFolder(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + var req updateFolderRequest + if err := apivalidate.DecodeJSON(w, r, maxFoldersRequestBody, &req); err != nil { + return + } + if verr := validateUpdateFolder(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + if err := h.svc.UpdateFolder(r.Context(), claims.Sub, chi.URLParam(r, "folderID"), &req); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + if errors.Is(err, ErrDuplicateFolder) { + apiresponse.WriteError(w, r, http.StatusConflict, apiresponse.CodeInvalidRequest, "folder remote_name already exists", nil) + return + } + h.logger.Error("update folder", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) DeleteFolder(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + if err := h.svc.DeleteFolder(r.Context(), claims.Sub, chi.URLParam(r, "folderID")); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + if errors.Is(err, ErrFolderProtected) { + apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "system folder cannot be deleted", nil) + return + } + h.logger.Error("delete folder", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) ListUserLabels(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + params, err := query.ParseListRequest(r) + if err != nil { + apivalidate.WriteQueryError(w, r, err) + return + } + + result, err := h.svc.ListUserLabels(r.Context(), claims.Sub, params) + if err != nil { + h.logger.Error("list user labels", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, result) +} + +func (h *Handler) CreateUserLabel(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + var req createUserLabelRequest + if err := apivalidate.DecodeJSON(w, r, maxLabelsRequestBody, &req); err != nil { + return + } + if verr := validateCreateUserLabel(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + id, err := h.svc.CreateUserLabel(r.Context(), claims.Sub, &req) + if err != nil { + if errors.Is(err, ErrDuplicateLabel) { + apiresponse.WriteError(w, r, http.StatusConflict, apiresponse.CodeInvalidRequest, "label name already exists", nil) + return + } + h.logger.Error("create user label", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusCreated, map[string]string{"id": id}) +} + +func (h *Handler) UpdateUserLabel(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + var req updateUserLabelRequest + if err := apivalidate.DecodeJSON(w, r, maxLabelsRequestBody, &req); err != nil { + return + } + if verr := validateUpdateUserLabel(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + if err := h.svc.UpdateUserLabel(r.Context(), claims.Sub, chi.URLParam(r, "labelID"), &req); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + if errors.Is(err, ErrDuplicateLabel) { + apiresponse.WriteError(w, r, http.StatusConflict, apiresponse.CodeInvalidRequest, "label name already exists", nil) + return + } + h.logger.Error("update user label", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) DeleteUserLabel(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + if err := h.svc.DeleteUserLabel(r.Context(), claims.Sub, chi.URLParam(r, "labelID")); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("delete user label", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/mail/handlers_identities.go b/internal/api/mail/handlers_identities.go new file mode 100644 index 0000000..bf0bd67 --- /dev/null +++ b/internal/api/mail/handlers_identities.go @@ -0,0 +1,113 @@ +package mail + +import ( + "errors" + "net/http" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/api/query" +) + +func (h *Handler) ListIdentities(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + params, err := query.ParseListRequest(r) + if err != nil { + apivalidate.WriteQueryError(w, r, err) + return + } + + result, err := h.svc.ListIdentities(r.Context(), claims.Sub, chi.URLParam(r, "accountID"), params) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + apivalidate.WriteNotFound(w, r, "account not found") + return + } + h.logger.Error("list identities", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, result) +} + +func (h *Handler) GetIdentity(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + identity, err := h.svc.GetIdentity(r.Context(), claims.Sub, chi.URLParam(r, "identityID")) + if err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("get identity", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, identity) +} + +func (h *Handler) CreateIdentity(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + var req createIdentityRequest + if err := apivalidate.DecodeJSON(w, r, maxIdentityRequestBody, &req); err != nil { + return + } + if verr := validateCreateIdentity(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + id, err := h.svc.CreateIdentity(r.Context(), claims.Sub, chi.URLParam(r, "accountID"), &req) + if err != nil { + if errors.Is(err, ErrAccountNotFound) { + apivalidate.WriteNotFound(w, r, "account not found") + return + } + h.logger.Error("create identity", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusCreated, map[string]string{"id": id}) +} + +func (h *Handler) UpdateIdentity(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + var req updateIdentityRequest + if err := apivalidate.DecodeJSON(w, r, maxIdentityRequestBody, &req); err != nil { + return + } + if verr := validateUpdateIdentity(&req); verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + if err := h.svc.UpdateIdentity(r.Context(), claims.Sub, chi.URLParam(r, "identityID"), &req); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("update identity", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) DeleteIdentity(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + + if err := h.svc.DeleteIdentity(r.Context(), claims.Sub, chi.URLParam(r, "identityID")); err != nil { + if errors.Is(err, ErrNotFound) { + apivalidate.WriteNotFound(w, r, "not found") + return + } + h.logger.Error("delete identity", "error", err) + apivalidate.WriteInternal(w, r) + return + } + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/api/mail/handlers_search.go b/internal/api/mail/handlers_search.go new file mode 100644 index 0000000..23923ac --- /dev/null +++ b/internal/api/mail/handlers_search.go @@ -0,0 +1,86 @@ +package mail + +import ( + "net/http" + "time" + + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/api/query" +) + +func (h *Handler) SearchMessages(w http.ResponseWriter, r *http.Request) { + claims := middleware.ClaimsFromContext(r.Context()) + params, err := query.ParseListRequest(r) + if err != nil { + apivalidate.WriteQueryError(w, r, err) + return + } + + filter, verr := parseMessageSearchFilter(r) + if verr != nil { + apivalidate.WriteValidationError(w, r, verr) + return + } + + result, err := h.svc.SearchMessages(r.Context(), claims.Sub, filter, params) + if err != nil { + h.logger.Error("search messages", "error", err) + apivalidate.WriteInternal(w, r) + return + } + apiresponse.WriteJSON(w, http.StatusOK, result) +} + +func parseMessageSearchFilter(r *http.Request) (MessageSearchFilter, *apivalidate.ValidationError) { + q := r.URL.Query() + filter := MessageSearchFilter{ + Query: q.Get("q"), + Sender: q.Get("from"), + Label: q.Get("label"), + AccountID: q.Get("account_id"), + } + + if raw := q.Get("date_from"); raw != "" { + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return filter, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "date_from", Message: "invalid RFC3339 datetime", + }) + } + filter.DateFrom = &t + } + if raw := q.Get("date_to"); raw != "" { + t, err := time.Parse(time.RFC3339, raw) + if err != nil { + return filter, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "date_to", Message: "invalid RFC3339 datetime", + }) + } + filter.DateTo = &t + } + if raw := q.Get("has_attachment"); raw != "" { + switch raw { + case "true", "1": + v := true + filter.HasAttachments = &v + case "false", "0": + v := false + filter.HasAttachments = &v + default: + return filter, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "has_attachment", Message: "must be true or false", + }) + } + } + + if filter.Query == "" && filter.Sender == "" && filter.Label == "" && + filter.AccountID == "" && filter.DateFrom == nil && filter.DateTo == nil && + filter.HasAttachments == nil { + return filter, apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "q", Message: "at least one search filter required", + }) + } + return filter, nil +} diff --git a/internal/api/mail/handlers_test.go b/internal/api/mail/handlers_test.go index debdcc8..e509b77 100644 --- a/internal/api/mail/handlers_test.go +++ b/internal/api/mail/handlers_test.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "io" "net/http" "net/http/httptest" "testing" @@ -16,8 +17,9 @@ import ( ) const ( - testExternalID = "ext-user-1" - testUserID = "user-uuid-1" + testExternalID = "ext-user-1" + testExternalID2 = "ext-user-2" + testUserID = "user-uuid-1" ) type fakeMailService struct { @@ -78,8 +80,8 @@ func (f *fakeMailService) GetMessage(_ context.Context, externalID, messageID st return msg, nil } -func (f *fakeMailService) UpdateLabels(_ context.Context, userID, messageID string, labels []string) error { - if userID != testUserID { +func (f *fakeMailService) UpdateLabels(_ context.Context, externalID, messageID string, labels []string) error { + if externalID != testExternalID { return ErrNotFound } if f.deleted[messageID] { @@ -93,8 +95,8 @@ func (f *fakeMailService) UpdateLabels(_ context.Context, userID, messageID stri return nil } -func (f *fakeMailService) UpdateFlags(_ context.Context, userID, messageID string, flags []string) error { - if userID != testUserID { +func (f *fakeMailService) UpdateFlags(_ context.Context, externalID, messageID string, flags []string) error { + if externalID != testExternalID { return ErrNotFound } if f.deleted[messageID] { @@ -108,8 +110,8 @@ func (f *fakeMailService) UpdateFlags(_ context.Context, userID, messageID strin return nil } -func (f *fakeMailService) DeleteMessage(_ context.Context, externalID, userID, messageID string) error { - if externalID != testExternalID || userID != testUserID { +func (f *fakeMailService) DeleteMessage(_ context.Context, externalID, messageID string) error { + if externalID != testExternalID { return ErrNotFound } if _, ok := f.messages[messageID]; !ok || f.deleted[messageID] { @@ -127,6 +129,22 @@ func (f *fakeMailService) SendMessage(_ context.Context, userID string, req *sen return "outbox-1", "queued", nil } +func (f *fakeMailService) ListDrafts(context.Context, string, query.ListParams) (DraftsList, error) { + return DraftsList{}, nil +} +func (f *fakeMailService) GetDraft(context.Context, string, string) (map[string]any, error) { + return nil, ErrNotFound +} +func (f *fakeMailService) CreateDraft(context.Context, string, *draftRequest) (string, error) { + return "", nil +} +func (f *fakeMailService) UpdateDraft(context.Context, string, string, *draftRequest) error { + return nil +} +func (f *fakeMailService) DeleteDraft(context.Context, string, string) error { + return nil +} + func (f *fakeMailService) ListAccounts(context.Context, string, query.ListParams) (AccountsList, error) { return AccountsList{}, nil } @@ -146,25 +164,211 @@ func (f *fakeMailService) ListRules(context.Context, string, query.ListParams) ( func (f *fakeMailService) CreateRule(context.Context, string, *createRuleRequest) (string, error) { return "", nil } -func (f *fakeMailService) UpdateRule(context.Context, string, string, *updateRuleRequest) error { +func (f *fakeMailService) UpdateRule(_ context.Context, externalID, ruleID string, _ *updateRuleRequest) error { + if externalID != testExternalID { + return ErrNotFound + } + return nil +} +func (f *fakeMailService) DeleteRule(_ context.Context, externalID, ruleID string) error { + if externalID != testExternalID { + return ErrNotFound + } return nil } -func (f *fakeMailService) DeleteRule(context.Context, string, string, string) error { return nil } func (f *fakeMailService) ListWebhooks(context.Context, string, query.ListParams) (WebhooksList, error) { return WebhooksList{}, nil } func (f *fakeMailService) CreateWebhook(context.Context, string, *createWebhookRequest, string) (string, error) { return "", nil } -func (f *fakeMailService) DeleteWebhook(context.Context, string, string, string) error { return nil } +func (f *fakeMailService) DeleteWebhook(_ context.Context, externalID, webhookID string) error { + if externalID != testExternalID { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) ListIdentities(_ context.Context, externalID, accountID string, params query.ListParams) (IdentitiesList, error) { + if externalID != testExternalID { + return IdentitiesList{}, ErrAccountNotFound + } + if accountID != "acc-1" { + return IdentitiesList{}, ErrAccountNotFound + } + total := int64(1) + return IdentitiesList{ + Identities: []map[string]any{{ + "id": "id-1", "account_id": "acc-1", "email": "sender@example.com", + "name": "Sender", "is_default": true, "signature_html": "", + "reply_to_addrs": []string{}, "created_at": nil, "updated_at": nil, + }}, + Pagination: params.Meta(&total), + }, nil +} + +func (f *fakeMailService) GetIdentity(_ context.Context, externalID, identityID string) (map[string]any, error) { + if externalID != testExternalID || identityID != "id-1" { + return nil, ErrNotFound + } + return map[string]any{ + "id": "id-1", "account_id": "acc-1", "email": "sender@example.com", + "name": "Sender", "is_default": true, "signature_html": "", + "reply_to_addrs": []string{}, "created_at": nil, "updated_at": nil, + }, nil +} + +func (f *fakeMailService) CreateIdentity(_ context.Context, externalID, accountID string, req *createIdentityRequest) (string, error) { + if externalID != testExternalID || accountID != "acc-1" { + return "", ErrAccountNotFound + } + if req.Email == "" { + return "", ErrAccountNotFound + } + return "id-new", nil +} + +func (f *fakeMailService) UpdateIdentity(_ context.Context, externalID, identityID string, _ *updateIdentityRequest) error { + if externalID != testExternalID || identityID != "id-1" { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) DeleteIdentity(_ context.Context, externalID, identityID string) error { + if externalID != testExternalID || identityID != "id-1" { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) ListFolders(_ context.Context, externalID, accountID string, params query.ListParams) (FoldersList, error) { + if externalID != testExternalID { + return FoldersList{}, ErrAccountNotFound + } + total := int64(0) + return FoldersList{Folders: []map[string]any{}, Pagination: params.Meta(&total)}, nil +} + +func (f *fakeMailService) GetFolder(_ context.Context, externalID, folderID string) (map[string]any, error) { + if externalID != testExternalID { + return nil, ErrNotFound + } + return map[string]any{"id": folderID, "name": "Inbox", "remote_name": "INBOX", "folder_type": "inbox"}, nil +} + +func (f *fakeMailService) CreateFolder(_ context.Context, userID string, _ *createFolderRequest) (string, error) { + if userID != testUserID { + return "", ErrAccountNotFound + } + return "folder-1", nil +} + +func (f *fakeMailService) UpdateFolder(_ context.Context, externalID, folderID string, _ *updateFolderRequest) error { + if externalID != testExternalID { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) DeleteFolder(_ context.Context, externalID, folderID string) error { + if externalID != testExternalID { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) ListUserLabels(_ context.Context, externalID string, params query.ListParams) (UserLabelsList, error) { + if externalID != testExternalID { + return UserLabelsList{}, nil + } + total := int64(0) + return UserLabelsList{Labels: []map[string]any{}, Pagination: params.Meta(&total)}, nil +} + +func (f *fakeMailService) CreateUserLabel(_ context.Context, externalID string, _ *createUserLabelRequest) (string, error) { + if externalID != testExternalID { + return "", ErrNotFound + } + return "label-1", nil +} + +func (f *fakeMailService) UpdateUserLabel(_ context.Context, externalID, labelID string, _ *updateUserLabelRequest) error { + if externalID != testExternalID { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) DeleteUserLabel(_ context.Context, externalID, labelID string) error { + if externalID != testExternalID { + return ErrNotFound + } + return nil +} + +func (f *fakeMailService) SearchMessages(_ context.Context, externalID string, _ MessageSearchFilter, params query.ListParams) (MessageSearchResult, error) { + if externalID != testExternalID { + return MessageSearchResult{}, ErrUserNotProvisioned + } + msgs := make([]map[string]any, 0, len(f.messages)) + for id, msg := range f.messages { + if f.deleted[id] { + continue + } + msgs = append(msgs, msg) + } + total := int64(len(msgs)) + return MessageSearchResult{Messages: msgs, Pagination: params.Meta(&total)}, nil +} + +func (f *fakeMailService) ListMessageAttachments(_ context.Context, externalID, messageID string) ([]map[string]any, error) { + if externalID != testExternalID { + return nil, ErrNotFound + } + if _, ok := f.messages[messageID]; !ok { + return nil, ErrNotFound + } + return []map[string]any{}, nil +} + +func (f *fakeMailService) MessageAttachmentCIDMap(_ context.Context, externalID, messageID string) (map[string]string, error) { + if externalID != testExternalID { + return nil, ErrNotFound + } + if _, ok := f.messages[messageID]; !ok { + return nil, ErrNotFound + } + return map[string]string{}, nil +} + +func (f *fakeMailService) UploadMessageAttachment(context.Context, string, string, string, string, string, bool, io.Reader, int64) (string, error) { + return "att-1", nil +} + +func (f *fakeMailService) OpenAttachment(context.Context, string, string) (string, string, int64, bool, io.ReadCloser, error) { + return "", "", 0, false, nil, ErrAttachmentNotFound +} + +func (f *fakeMailService) UploadDraftAttachment(context.Context, string, string, string, string, string, bool, io.Reader, int64) (string, error) { + return "att-draft-1", nil +} + +func (f *fakeMailService) OpenDraftAttachment(context.Context, string, string, string) (string, string, io.ReadCloser, error) { + return "", "", nil, ErrAttachmentNotFound +} func newTestMailRouter(svc ServiceAPI) http.Handler { - h := NewHandlerWithService(svc) - r := chi.NewRouter() - r.Use(middleware.WithTestClaims(&auth.Claims{ + return newTestMailRouterWithClaims(svc, &auth.Claims{ Sub: testExternalID, Email: "user@example.com", - })) + }) +} + +func newTestMailRouterWithClaims(svc ServiceAPI, claims *auth.Claims) http.Handler { + h := NewHandlerWithService(svc) + r := chi.NewRouter() + r.Use(middleware.WithTestClaims(claims)) r.Mount("/", h.Routes()) return r } @@ -319,6 +523,28 @@ func TestUpdateFlags(t *testing.T) { } } +func TestUpdateLabelsCrossUser(t *testing.T) { + svc := newFakeMailService() + router := newTestMailRouterWithClaims(svc, &auth.Claims{ + Sub: testExternalID2, + Email: "other@example.com", + }) + + body, err := json.Marshal(map[string]any{"labels": []string{"stolen"}}) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + + req := httptest.NewRequest(http.MethodPut, "/messages/msg-1/labels", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusNotFound, rec.Body.String()) + } +} + func TestDeleteMessage(t *testing.T) { svc := newFakeMailService() router := newTestMailRouter(svc) diff --git a/internal/api/mail/identities.go b/internal/api/mail/identities.go new file mode 100644 index 0000000..7d2e977 --- /dev/null +++ b/internal/api/mail/identities.go @@ -0,0 +1,217 @@ +package mail + +import ( + "context" + "encoding/json" + "errors" + + "github.com/jackc/pgx/v5" + + "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/securityaudit" +) + +type IdentitiesList struct { + Identities []map[string]any `json:"identities"` + Pagination query.PaginationMeta `json:"pagination,omitempty"` +} + +const identitySelectColumns = ` + mi.id, mi.account_id, mi.email, mi.name, mi.is_default, mi.signature_html, mi.reply_to_addrs, mi.created_at, mi.updated_at +` + +func (s *Service) verifyAccountOwnership(ctx context.Context, externalID, accountID string) error { + var exists bool + err := s.db.QueryRow(ctx, ` + SELECT EXISTS( + SELECT 1 FROM mail_accounts ma + JOIN users u ON ma.user_id = u.id + WHERE ma.id = $1 AND u.external_id = $2 + ) + `, accountID, externalID).Scan(&exists) + if err != nil { + return err + } + if !exists { + return ErrAccountNotFound + } + return nil +} + +func identityOwnershipJoin() string { + return ` + FROM mail_identities mi + JOIN mail_accounts ma ON mi.account_id = ma.id + JOIN users u ON ma.user_id = u.id + ` +} + +func scanIdentity(id, accountID, email, name, signatureHTML string, isDefault bool, replyToJSON []byte, createdAt, updatedAt any) map[string]any { + replyTo := parseReplyToAddrs(replyToJSON) + return map[string]any{ + "id": id, + "account_id": accountID, + "email": email, + "name": name, + "is_default": isDefault, + "signature_html": signatureHTML, + "reply_to_addrs": replyTo, + "created_at": createdAt, + "updated_at": updatedAt, + } +} + +func parseReplyToAddrs(raw []byte) []string { + if len(raw) == 0 { + return []string{} + } + var addrs []string + if err := json.Unmarshal(raw, &addrs); err != nil || addrs == nil { + return []string{} + } + return addrs +} + +func (s *Service) clearDefaultIdentities(ctx context.Context, accountID string) error { + _, err := s.db.Exec(ctx, ` + UPDATE mail_identities SET is_default = false, updated_at = NOW() + WHERE account_id = $1 AND is_default = true + `, accountID) + return err +} + +func (s *Service) ListIdentities(ctx context.Context, externalID, accountID string, params query.ListParams) (IdentitiesList, error) { + if err := s.verifyAccountOwnership(ctx, externalID, accountID); err != nil { + return IdentitiesList{}, err + } + + var total int64 + countQuery := "SELECT COUNT(*) " + identityOwnershipJoin() + " WHERE mi.account_id = $1 AND u.external_id = $2" + if err := s.db.QueryRow(ctx, countQuery, accountID, externalID).Scan(&total); err != nil { + return IdentitiesList{}, err + } + + listQuery := "SELECT " + identitySelectColumns + identityOwnershipJoin() + + " WHERE mi.account_id = $1 AND u.external_id = $2 ORDER BY mi.is_default DESC, mi.created_at ASC LIMIT $3 OFFSET $4" + + rows, err := s.db.Query(ctx, listQuery, accountID, externalID, params.Limit(), params.Offset()) + if err != nil { + return IdentitiesList{}, err + } + defer rows.Close() + + identities := make([]map[string]any, 0) + for rows.Next() { + var id, acctID, email, name, signatureHTML string + var isDefault bool + var replyToJSON []byte + var createdAt, updatedAt any + if err := rows.Scan(&id, &acctID, &email, &name, &isDefault, &signatureHTML, &replyToJSON, &createdAt, &updatedAt); err != nil { + return IdentitiesList{}, err + } + identities = append(identities, scanIdentity(id, acctID, email, name, signatureHTML, isDefault, replyToJSON, createdAt, updatedAt)) + } + if err := rows.Err(); err != nil { + return IdentitiesList{}, err + } + + return IdentitiesList{ + Identities: identities, + Pagination: params.Meta(&total), + }, nil +} + +func (s *Service) GetIdentity(ctx context.Context, externalID, identityID string) (map[string]any, error) { + query := "SELECT " + identitySelectColumns + identityOwnershipJoin() + " WHERE mi.id = $1 AND u.external_id = $2" + + var id, accountID, email, name, signatureHTML string + var isDefault bool + var replyToJSON []byte + var createdAt, updatedAt any + err := s.db.QueryRow(ctx, query, identityID, externalID).Scan( + &id, &accountID, &email, &name, &isDefault, &signatureHTML, &replyToJSON, &createdAt, &updatedAt, + ) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + return scanIdentity(id, accountID, email, name, signatureHTML, isDefault, replyToJSON, createdAt, updatedAt), nil +} + +func (s *Service) CreateIdentity(ctx context.Context, externalID, accountID string, req *createIdentityRequest) (string, error) { + if err := s.verifyAccountOwnership(ctx, externalID, accountID); err != nil { + return "", err + } + + if req.IsDefault { + if err := s.clearDefaultIdentities(ctx, accountID); err != nil { + return "", err + } + } + + replyToJSON, _ := json.Marshal(req.ReplyToAddrs) + + var id string + err := s.db.QueryRow(ctx, ` + INSERT INTO mail_identities (account_id, email, name, is_default, signature_html, reply_to_addrs) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id + `, accountID, req.Email, req.Name, req.IsDefault, req.SignatureHTML, replyToJSON).Scan(&id) + if err != nil { + return "", err + } + return id, nil +} + +func (s *Service) UpdateIdentity(ctx context.Context, externalID, identityID string, req *updateIdentityRequest) error { + identity, err := s.GetIdentity(ctx, externalID, identityID) + if err != nil { + return err + } + accountID, _ := identity["account_id"].(string) + + if req.IsDefault { + if err := s.clearDefaultIdentities(ctx, accountID); err != nil { + return err + } + } + + replyToJSON, _ := json.Marshal(req.ReplyToAddrs) + + result, err := s.db.Exec(ctx, ` + UPDATE mail_identities mi SET + email = $1, name = $2, is_default = $3, signature_html = $4, reply_to_addrs = $5, updated_at = NOW() + FROM mail_accounts ma + JOIN users u ON ma.user_id = u.id + WHERE mi.id = $6 AND mi.account_id = ma.id AND u.external_id = $7 + `, req.Email, req.Name, req.IsDefault, req.SignatureHTML, replyToJSON, identityID, externalID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + return nil +} + +func (s *Service) DeleteIdentity(ctx context.Context, externalID, identityID string) error { + result, err := s.db.Exec(ctx, ` + DELETE FROM mail_identities mi + USING mail_accounts ma, users u + WHERE mi.id = $1 AND mi.account_id = ma.id AND ma.user_id = u.id AND u.external_id = $2 + `, identityID, externalID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + if s.audit != nil { + s.audit.Log(ctx, externalID, securityaudit.ActionCriticalDeletion, map[string]any{ + "target": "mail_identity", "identity_id": identityID, + }) + } + return nil +} diff --git a/internal/api/mail/identities_test.go b/internal/api/mail/identities_test.go new file mode 100644 index 0000000..2bd8706 --- /dev/null +++ b/internal/api/mail/identities_test.go @@ -0,0 +1,157 @@ +package mail + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/go-chi/chi/v5" + + "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/auth" +) + +func newTestIdentityRouter(svc ServiceAPI) http.Handler { + h := NewHandlerWithService(svc) + r := chi.NewRouter() + r.Use(middleware.WithTestClaims(&auth.Claims{ + Sub: testExternalID, + Email: "user@example.com", + })) + r.Get("/accounts/{accountID}/identities", h.ListIdentities) + r.Post("/accounts/{accountID}/identities", h.CreateIdentity) + r.Get("/identities/{identityID}", h.GetIdentity) + r.Put("/identities/{identityID}", h.UpdateIdentity) + r.Delete("/identities/{identityID}", h.DeleteIdentity) + return r +} + +func TestListIdentities(t *testing.T) { + svc := newFakeMailService() + router := newTestIdentityRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/accounts/acc-1/identities", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var body IdentitiesList + if err := json.NewDecoder(rec.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(body.Identities) != 1 { + t.Fatalf("identities len = %d, want 1", len(body.Identities)) + } + if body.Identities[0]["email"] != "sender@example.com" { + t.Fatalf("email = %v", body.Identities[0]["email"]) + } +} + +func TestCreateIdentity(t *testing.T) { + svc := newFakeMailService() + router := newTestIdentityRouter(svc) + + payload := map[string]any{ + "email": "alias@example.com", + "name": "Alias", + } + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + + req := httptest.NewRequest(http.MethodPost, "/accounts/acc-1/identities", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusCreated { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusCreated, rec.Body.String()) + } + + var resp map[string]string + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode body: %v", err) + } + if resp["id"] != "id-new" { + t.Fatalf("id = %q, want id-new", resp["id"]) + } +} + +func TestGetIdentity(t *testing.T) { + svc := newFakeMailService() + router := newTestIdentityRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/identities/id-1", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } + + var identity map[string]any + if err := json.NewDecoder(rec.Body).Decode(&identity); err != nil { + t.Fatalf("decode body: %v", err) + } + if identity["id"] != "id-1" { + t.Fatalf("id = %v", identity["id"]) + } +} + +func TestUpdateIdentity(t *testing.T) { + svc := newFakeMailService() + router := newTestIdentityRouter(svc) + + payload := map[string]any{ + "email": "updated@example.com", + "name": "Updated", + "is_default": false, + "signature_html": "

Sig

", + "reply_to_addrs": []string{"reply@example.com"}, + } + body, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal payload: %v", err) + } + + req := httptest.NewRequest(http.MethodPut, "/identities/id-1", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusNoContent, rec.Body.String()) + } +} + +func TestDeleteIdentity(t *testing.T) { + svc := newFakeMailService() + router := newTestIdentityRouter(svc) + + req := httptest.NewRequest(http.MethodDelete, "/identities/id-1", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNoContent { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusNoContent, rec.Body.String()) + } +} + +func TestListIdentitiesAccountNotFound(t *testing.T) { + svc := newFakeMailService() + router := newTestIdentityRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/accounts/missing/identities", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotFound { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusNotFound, rec.Body.String()) + } +} diff --git a/internal/api/mail/labels.go b/internal/api/mail/labels.go new file mode 100644 index 0000000..7737868 --- /dev/null +++ b/internal/api/mail/labels.go @@ -0,0 +1,108 @@ +package mail + +import ( + "context" + + "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/securityaudit" +) + +type UserLabelsList struct { + Labels []map[string]any `json:"labels"` + Pagination query.PaginationMeta `json:"pagination,omitempty"` +} + +func (s *Service) ListUserLabels(ctx context.Context, externalID string, params query.ListParams) (UserLabelsList, error) { + var total int64 + err := s.db.QueryRow(ctx, ` + SELECT COUNT(*) FROM mail_user_labels + WHERE user_id = (SELECT id FROM users WHERE external_id = $1) + `, externalID).Scan(&total) + if err != nil { + return UserLabelsList{}, err + } + + rows, err := s.db.Query(ctx, ` + SELECT id, name, color, created_at + FROM mail_user_labels + WHERE user_id = (SELECT id FROM users WHERE external_id = $1) + ORDER BY name ASC + LIMIT $2 OFFSET $3 + `, externalID, params.Limit(), params.Offset()) + if err != nil { + return UserLabelsList{}, err + } + defer rows.Close() + + labels := make([]map[string]any, 0) + for rows.Next() { + var id, name, color string + var createdAt any + if err := rows.Scan(&id, &name, &color, &createdAt); err != nil { + return UserLabelsList{}, err + } + labels = append(labels, map[string]any{ + "id": id, "name": name, "color": color, "created_at": createdAt, + }) + } + if err := rows.Err(); err != nil { + return UserLabelsList{}, err + } + + return UserLabelsList{ + Labels: labels, + Pagination: params.Meta(&total), + }, nil +} + +func (s *Service) CreateUserLabel(ctx context.Context, externalID string, req *createUserLabelRequest) (string, error) { + var id string + err := s.db.QueryRow(ctx, ` + INSERT INTO mail_user_labels (user_id, name, color) + VALUES ((SELECT id FROM users WHERE external_id = $1), $2, $3) + RETURNING id + `, externalID, req.Name, req.Color).Scan(&id) + if err != nil { + if isUniqueViolation(err) { + return "", ErrDuplicateLabel + } + return "", err + } + return id, nil +} + +func (s *Service) UpdateUserLabel(ctx context.Context, externalID, labelID string, req *updateUserLabelRequest) error { + result, err := s.db.Exec(ctx, ` + UPDATE mail_user_labels SET name = $1, color = $2 + WHERE id = $3 AND user_id = (SELECT id FROM users WHERE external_id = $4) + `, req.Name, req.Color, labelID, externalID) + if err != nil { + if isUniqueViolation(err) { + return ErrDuplicateLabel + } + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + return nil +} + +func (s *Service) DeleteUserLabel(ctx context.Context, externalID, labelID string) error { + result, err := s.db.Exec(ctx, ` + DELETE FROM mail_user_labels + WHERE id = $1 AND user_id = (SELECT id FROM users WHERE external_id = $2) + `, labelID, externalID) + if err != nil { + return err + } + if result.RowsAffected() == 0 { + return ErrNotFound + } + if s.audit != nil { + s.audit.Log(ctx, externalID, securityaudit.ActionCriticalDeletion, map[string]any{ + "target": "mail_user_label", "label_id": labelID, + }) + } + return nil +} diff --git a/internal/api/mail/search_advanced.go b/internal/api/mail/search_advanced.go new file mode 100644 index 0000000..fbe6ccd --- /dev/null +++ b/internal/api/mail/search_advanced.go @@ -0,0 +1,129 @@ +package mail + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/ultisuite/ulti-backend/internal/api/query" +) + +type MessageSearchFilter struct { + Query string + Sender string + DateFrom *time.Time + DateTo *time.Time + HasAttachments *bool + Label string + AccountID string +} + +type MessageSearchResult struct { + Messages []map[string]any `json:"messages"` + Pagination query.PaginationMeta `json:"pagination,omitempty"` +} + +func (s *Service) SearchMessages(ctx context.Context, externalID string, filter MessageSearchFilter, params query.ListParams) (MessageSearchResult, error) { + base := ` + FROM messages m + JOIN mail_accounts ma ON m.account_id = ma.id + WHERE ma.user_id = (SELECT id FROM users WHERE external_id = $1) + ` + args := []any{externalID} + argIdx := 2 + + if filter.AccountID != "" { + base += fmt.Sprintf(" AND m.account_id = $%d", argIdx) + args = append(args, filter.AccountID) + argIdx++ + } + if filter.Sender != "" { + base += fmt.Sprintf(" AND m.from_addr::text ILIKE '%%' || $%d || '%%'", argIdx) + args = append(args, filter.Sender) + argIdx++ + } + if filter.DateFrom != nil { + base += fmt.Sprintf(" AND m.date >= $%d", argIdx) + args = append(args, *filter.DateFrom) + argIdx++ + } + if filter.DateTo != nil { + base += fmt.Sprintf(" AND m.date <= $%d", argIdx) + args = append(args, *filter.DateTo) + argIdx++ + } + if filter.HasAttachments != nil { + base += fmt.Sprintf(" AND m.has_attachments = $%d", argIdx) + args = append(args, *filter.HasAttachments) + argIdx++ + } + if filter.Label != "" { + base += fmt.Sprintf(" AND $%d = ANY(m.labels)", argIdx) + args = append(args, filter.Label) + argIdx++ + } + if q := strings.TrimSpace(filter.Query); q != "" { + tsQuery := toMailTSQuery(q) + base += fmt.Sprintf(" AND m.search_vector @@ to_tsquery('simple', $%d)", argIdx) + args = append(args, tsQuery) + argIdx++ + } + + var total int64 + if err := s.db.QueryRow(ctx, "SELECT COUNT(*) "+base, args...).Scan(&total); err != nil { + return MessageSearchResult{}, err + } + + listQuery := ` + SELECT m.id, m.message_id, m.thread_id, m.subject, m.from_addr, m.to_addrs, + m.date, m.snippet, m.flags, m.labels, m.has_attachments + ` + base + fmt.Sprintf(" ORDER BY m.date DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1) + args = append(args, params.Limit(), params.Offset()) + + rows, err := s.db.Query(ctx, listQuery, args...) + if err != nil { + return MessageSearchResult{}, err + } + defer rows.Close() + + messages := make([]map[string]any, 0) + for rows.Next() { + var id, messageID, subject, snippet string + var threadID *string + var fromAddr, toAddrs []byte + var date any + var flags, labels []string + var hasAttachments bool + if err := rows.Scan(&id, &messageID, &threadID, &subject, &fromAddr, &toAddrs, &date, &snippet, &flags, &labels, &hasAttachments); err != nil { + return MessageSearchResult{}, err + } + entry := map[string]any{ + "id": id, "message_id": messageID, "subject": subject, + "from": json.RawMessage(fromAddr), "to": json.RawMessage(toAddrs), + "date": date, "snippet": snippet, "flags": flags, "labels": labels, + "has_attachments": hasAttachments, + } + if threadID != nil { + entry["thread_id"] = *threadID + } + messages = append(messages, entry) + } + if err := rows.Err(); err != nil { + return MessageSearchResult{}, err + } + + return MessageSearchResult{ + Messages: messages, + Pagination: params.Meta(&total), + }, nil +} + +func toMailTSQuery(input string) string { + words := strings.Fields(input) + for i, w := range words { + words[i] = w + ":*" + } + return strings.Join(words, " & ") +} diff --git a/internal/api/mail/search_test.go b/internal/api/mail/search_test.go new file mode 100644 index 0000000..975d838 --- /dev/null +++ b/internal/api/mail/search_test.go @@ -0,0 +1,33 @@ +package mail + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +func TestSearchMessages(t *testing.T) { + svc := newFakeMailService() + router := newTestMailRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/search?q=hello", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String()) + } +} + +func TestSearchMessagesRequiresFilter(t *testing.T) { + svc := newFakeMailService() + router := newTestMailRouter(svc) + + req := httptest.NewRequest(http.MethodGet, "/search", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) + } +} diff --git a/internal/api/mail/service.go b/internal/api/mail/service.go index 0bb0478..e5ad40f 100644 --- a/internal/api/mail/service.go +++ b/internal/api/mail/service.go @@ -12,6 +12,8 @@ import ( "github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/mail/credentials" + "github.com/ultisuite/ulti-backend/internal/mail/storage" + "github.com/ultisuite/ulti-backend/internal/mail/threading" "github.com/ultisuite/ulti-backend/internal/securityaudit" ) @@ -23,18 +25,22 @@ var ( ) type Service struct { - db *pgxpool.Pool - credentials *credentials.Manager - audit *securityaudit.Logger - logger *slog.Logger + db *pgxpool.Pool + credentials *credentials.Manager + audit *securityaudit.Logger + storage *storage.Client + attachmentsBucket string + logger *slog.Logger } -func NewService(db *pgxpool.Pool, audit *securityaudit.Logger, credentialManager *credentials.Manager) *Service { +func NewService(db *pgxpool.Pool, audit *securityaudit.Logger, credentialManager *credentials.Manager, objectStorage *storage.Client, attachmentsBucket string) *Service { return &Service{ - db: db, - credentials: credentialManager, - audit: audit, - logger: slog.Default().With("component", "mail-service"), + db: db, + credentials: credentialManager, + audit: audit, + storage: objectStorage, + attachmentsBucket: attachmentsBucket, + logger: slog.Default().With("component", "mail-service"), } } @@ -194,7 +200,7 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me } listQuery := ` - SELECT m.id, m.subject, m.from_addr, m.to_addrs, m.date, m.snippet, m.flags, m.labels, m.has_attachments + SELECT m.id, m.message_id, m.thread_id, m.subject, m.from_addr, m.to_addrs, m.date, m.snippet, m.flags, m.labels, m.has_attachments ` + baseQuery + fmt.Sprintf(" ORDER BY m.date DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1) args = append(args, params.Limit(), params.Offset()) @@ -206,19 +212,24 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me messages := make([]map[string]any, 0) for rows.Next() { - var id, subject, snippet string + var id, messageID, subject, snippet string + var threadID *string var fromAddr, toAddrs []byte var date any var flags, labels []string var hasAttachments bool - if err := rows.Scan(&id, &subject, &fromAddr, &toAddrs, &date, &snippet, &flags, &labels, &hasAttachments); err != nil { + if err := rows.Scan(&id, &messageID, &threadID, &subject, &fromAddr, &toAddrs, &date, &snippet, &flags, &labels, &hasAttachments); err != nil { return MessagesList{}, err } - messages = append(messages, map[string]any{ - "id": id, "subject": subject, "from": json.RawMessage(fromAddr), + entry := map[string]any{ + "id": id, "message_id": messageID, "subject": subject, "from": json.RawMessage(fromAddr), "to": json.RawMessage(toAddrs), "date": date, "snippet": snippet, "flags": flags, "labels": labels, "has_attachments": hasAttachments, - }) + } + if threadID != nil { + entry["thread_id"] = *threadID + } + messages = append(messages, entry) } if err := rows.Err(); err != nil { return MessagesList{}, err @@ -233,43 +244,60 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me func (s *Service) GetMessage(ctx context.Context, externalID, messageID string) (map[string]any, error) { var msg struct { - ID string - Subject string - From []byte - To []byte - Cc []byte - Date any - Text string - HTML string - Flags []string - Labels []string + ID string + MessageID string + ThreadID *string + InReplyTo string + References []string + Subject string + From []byte + To []byte + Cc []byte + Date any + Text string + HTML string + Flags []string + Labels []string } err := s.db.QueryRow(ctx, ` - SELECT m.id, m.subject, m.from_addr, m.to_addrs, m.cc_addrs, m.date, m.body_text, m.body_html, m.flags, m.labels + SELECT m.id, m.message_id, m.thread_id, m.in_reply_to, m.references_header, + m.subject, m.from_addr, m.to_addrs, m.cc_addrs, m.date, + m.body_text, m.body_html, m.flags, m.labels FROM messages m JOIN mail_accounts ma ON m.account_id = ma.id WHERE m.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) - `, messageID, externalID).Scan(&msg.ID, &msg.Subject, &msg.From, &msg.To, &msg.Cc, &msg.Date, &msg.Text, &msg.HTML, &msg.Flags, &msg.Labels) + `, messageID, externalID).Scan( + &msg.ID, &msg.MessageID, &msg.ThreadID, &msg.InReplyTo, &msg.References, + &msg.Subject, &msg.From, &msg.To, &msg.Cc, &msg.Date, + &msg.Text, &msg.HTML, &msg.Flags, &msg.Labels, + ) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return nil, ErrNotFound } return nil, err } - return map[string]any{ - "id": msg.ID, "subject": msg.Subject, "from": json.RawMessage(msg.From), - "to": json.RawMessage(msg.To), "cc": json.RawMessage(msg.Cc), + out := map[string]any{ + "id": msg.ID, "message_id": msg.MessageID, "subject": msg.Subject, + "from": json.RawMessage(msg.From), "to": json.RawMessage(msg.To), "cc": json.RawMessage(msg.Cc), "date": msg.Date, "body_text": msg.Text, "body_html": msg.HTML, "flags": msg.Flags, "labels": msg.Labels, - }, nil + "in_reply_to": msg.InReplyTo, "references": msg.References, + } + if msg.ThreadID != nil { + out["thread_id"] = *msg.ThreadID + } + return out, nil } -func (s *Service) UpdateLabels(ctx context.Context, userID, messageID string, labels []string) error { +func (s *Service) UpdateLabels(ctx context.Context, externalID, messageID string, labels []string) error { result, err := s.db.Exec(ctx, ` - UPDATE messages + UPDATE messages m SET labels = $1, updated_at = NOW() - WHERE id = $2 - AND account_id IN (SELECT id FROM mail_accounts WHERE user_id = $3) - `, labels, messageID, userID) + FROM mail_accounts ma + WHERE m.id = $2 + AND m.account_id = ma.id + AND ma.user_id = (SELECT id FROM users WHERE external_id = $3) + `, labels, messageID, externalID) if err != nil { return err } @@ -279,13 +307,15 @@ func (s *Service) UpdateLabels(ctx context.Context, userID, messageID string, la return nil } -func (s *Service) UpdateFlags(ctx context.Context, userID, messageID string, flags []string) error { +func (s *Service) UpdateFlags(ctx context.Context, externalID, messageID string, flags []string) error { result, err := s.db.Exec(ctx, ` - UPDATE messages + UPDATE messages m SET flags = $1, updated_at = NOW() - WHERE id = $2 - AND account_id IN (SELECT id FROM mail_accounts WHERE user_id = $3) - `, flags, messageID, userID) + FROM mail_accounts ma + WHERE m.id = $2 + AND m.account_id = ma.id + AND ma.user_id = (SELECT id FROM users WHERE external_id = $3) + `, flags, messageID, externalID) if err != nil { return err } @@ -295,12 +325,14 @@ func (s *Service) UpdateFlags(ctx context.Context, userID, messageID string, fla return nil } -func (s *Service) DeleteMessage(ctx context.Context, externalID, userID, messageID string) error { +func (s *Service) DeleteMessage(ctx context.Context, externalID, messageID string) error { result, err := s.db.Exec(ctx, ` - DELETE FROM messages - WHERE id = $1 - AND account_id IN (SELECT id FROM mail_accounts WHERE user_id = $2) - `, messageID, userID) + DELETE FROM messages m + USING mail_accounts ma + WHERE m.id = $1 + AND m.account_id = ma.id + AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) + `, messageID, externalID) if err != nil { return err } @@ -347,23 +379,57 @@ func (s *Service) GetThread(ctx context.Context, externalID, threadID string) (m return map[string]any{"thread_id": threadID, "messages": messages}, nil } +type replyParent struct { + MessageID string + References []string +} + +func (s *Service) loadReplyParent(ctx context.Context, userID, replyToMessageID string) (*replyParent, error) { + var parent replyParent + err := s.db.QueryRow(ctx, ` + SELECT m.message_id, m.references_header + FROM messages m + JOIN mail_accounts ma ON m.account_id = ma.id + WHERE m.id = $1 AND ma.user_id = $2 + `, replyToMessageID, userID).Scan(&parent.MessageID, &parent.References) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil, ErrNotFound + } + return nil, err + } + return &parent, nil +} + func (s *Service) SendMessage(ctx context.Context, userID string, req *sendMessageRequest) (id, status string, err error) { toJSON, _ := json.Marshal(req.To) ccJSON, _ := json.Marshal(req.Cc) bccJSON, _ := json.Marshal(req.Bcc) + inReplyTo := threading.NormalizeMessageID(req.InReplyTo) + var references []string + + if req.ReplyToMessageID != "" { + parent, err := s.loadReplyParent(ctx, userID, req.ReplyToMessageID) + if err != nil { + return "", "", err + } + inReplyTo = threading.NormalizeMessageID(parent.MessageID) + references = threading.BuildReferences(parent.MessageID, parent.References) + } + status = "queued" if req.ScheduleAt != nil { status = "scheduled" } err = s.db.QueryRow(ctx, ` - INSERT INTO outbox (user_id, account_id, to_addrs, cc_addrs, bcc_addrs, subject, body_text, body_html, in_reply_to, status, scheduled_at) - SELECT $1, ma.id, $3, $4, $5, $6, $7, $8, $9, $10, $11 + INSERT INTO outbox (user_id, account_id, to_addrs, cc_addrs, bcc_addrs, subject, body_text, body_html, in_reply_to, references_header, status, scheduled_at) + SELECT $1, ma.id, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 FROM mail_accounts ma WHERE ma.id = $2 AND ma.user_id = $1 RETURNING id - `, userID, req.AccountID, toJSON, ccJSON, bccJSON, req.Subject, req.BodyText, req.BodyHTML, req.InReplyTo, status, req.ScheduleAt).Scan(&id) + `, userID, req.AccountID, toJSON, ccJSON, bccJSON, req.Subject, req.BodyText, req.BodyHTML, inReplyTo, references, status, req.ScheduleAt).Scan(&id) if err != nil { if errors.Is(err, pgx.ErrNoRows) { return "", "", ErrAccountNotFound @@ -453,14 +519,14 @@ func (s *Service) CreateRule(ctx context.Context, userID string, req *createRule return id, nil } -func (s *Service) UpdateRule(ctx context.Context, userID, ruleID string, req *updateRuleRequest) error { +func (s *Service) UpdateRule(ctx context.Context, externalID, ruleID string, req *updateRuleRequest) error { condJSON, _ := json.Marshal(req.Conditions) actJSON, _ := json.Marshal(req.Actions) result, err := s.db.Exec(ctx, ` UPDATE mail_rules SET name=$1, priority=$2, is_active=$3, conditions=$4, actions=$5, updated_at=NOW() - WHERE id=$6 AND user_id=$7 - `, req.Name, req.Priority, req.IsActive, condJSON, actJSON, ruleID, userID) + WHERE id=$6 AND user_id=(SELECT id FROM users WHERE external_id=$7) + `, req.Name, req.Priority, req.IsActive, condJSON, actJSON, ruleID, externalID) if err != nil { return err } @@ -470,8 +536,11 @@ func (s *Service) UpdateRule(ctx context.Context, userID, ruleID string, req *up return nil } -func (s *Service) DeleteRule(ctx context.Context, externalID, userID, ruleID string) error { - result, err := s.db.Exec(ctx, `DELETE FROM mail_rules WHERE id = $1 AND user_id = $2`, ruleID, userID) +func (s *Service) DeleteRule(ctx context.Context, externalID, ruleID string) error { + result, err := s.db.Exec(ctx, ` + DELETE FROM mail_rules + WHERE id = $1 AND user_id = (SELECT id FROM users WHERE external_id = $2) + `, ruleID, externalID) if err != nil { return err } @@ -548,8 +617,11 @@ func (s *Service) CreateWebhook(ctx context.Context, externalID string, req *cre return id, nil } -func (s *Service) DeleteWebhook(ctx context.Context, externalID, userID, webhookID string) error { - result, err := s.db.Exec(ctx, `DELETE FROM webhook_templates WHERE id = $1 AND user_id = $2`, webhookID, userID) +func (s *Service) DeleteWebhook(ctx context.Context, externalID, webhookID string) error { + result, err := s.db.Exec(ctx, ` + DELETE FROM webhook_templates + WHERE id = $1 AND user_id = (SELECT id FROM users WHERE external_id = $2) + `, webhookID, externalID) if err != nil { return err } diff --git a/internal/api/mail/service_iface.go b/internal/api/mail/service_iface.go index df67a7b..78d9cfa 100644 --- a/internal/api/mail/service_iface.go +++ b/internal/api/mail/service_iface.go @@ -2,6 +2,7 @@ package mail import ( "context" + "io" "github.com/ultisuite/ulti-backend/internal/api/query" ) @@ -15,18 +16,44 @@ type ServiceAPI interface { DeleteAccount(ctx context.Context, externalID, accountID string) error ListMessages(ctx context.Context, externalID string, filter MessageListFilter, params query.ListParams) (MessagesList, error) GetMessage(ctx context.Context, externalID, messageID string) (map[string]any, error) - UpdateLabels(ctx context.Context, userID, messageID string, labels []string) error - UpdateFlags(ctx context.Context, userID, messageID string, flags []string) error - DeleteMessage(ctx context.Context, externalID, userID, messageID string) error + UpdateLabels(ctx context.Context, externalID, messageID string, labels []string) error + UpdateFlags(ctx context.Context, externalID, messageID string, flags []string) error + DeleteMessage(ctx context.Context, externalID, messageID string) error GetThread(ctx context.Context, externalID, threadID string) (map[string]any, error) SendMessage(ctx context.Context, userID string, req *sendMessageRequest) (id, status string, err error) + ListDrafts(ctx context.Context, externalID string, params query.ListParams) (DraftsList, error) + GetDraft(ctx context.Context, externalID, draftID string) (map[string]any, error) + CreateDraft(ctx context.Context, userID string, req *draftRequest) (string, error) + UpdateDraft(ctx context.Context, externalID, draftID string, req *draftRequest) error + DeleteDraft(ctx context.Context, externalID, draftID string) error ListRules(ctx context.Context, externalID string, params query.ListParams) (RulesList, error) CreateRule(ctx context.Context, userID string, req *createRuleRequest) (string, error) - UpdateRule(ctx context.Context, userID, ruleID string, req *updateRuleRequest) error - DeleteRule(ctx context.Context, externalID, userID, ruleID string) error + UpdateRule(ctx context.Context, externalID, ruleID string, req *updateRuleRequest) error + DeleteRule(ctx context.Context, externalID, ruleID string) error ListWebhooks(ctx context.Context, externalID string, params query.ListParams) (WebhooksList, error) CreateWebhook(ctx context.Context, externalID string, req *createWebhookRequest, method string) (string, error) - DeleteWebhook(ctx context.Context, externalID, userID, webhookID string) error + DeleteWebhook(ctx context.Context, externalID, webhookID string) error + ListIdentities(ctx context.Context, externalID, accountID string, params query.ListParams) (IdentitiesList, error) + GetIdentity(ctx context.Context, externalID, identityID string) (map[string]any, error) + CreateIdentity(ctx context.Context, externalID, accountID string, req *createIdentityRequest) (string, error) + UpdateIdentity(ctx context.Context, externalID, identityID string, req *updateIdentityRequest) error + DeleteIdentity(ctx context.Context, externalID, identityID string) error + ListFolders(ctx context.Context, externalID, accountID string, params query.ListParams) (FoldersList, error) + GetFolder(ctx context.Context, externalID, folderID string) (map[string]any, error) + CreateFolder(ctx context.Context, userID string, req *createFolderRequest) (string, error) + UpdateFolder(ctx context.Context, externalID, folderID string, req *updateFolderRequest) error + DeleteFolder(ctx context.Context, externalID, folderID string) error + ListUserLabels(ctx context.Context, externalID string, params query.ListParams) (UserLabelsList, error) + CreateUserLabel(ctx context.Context, externalID string, req *createUserLabelRequest) (string, error) + UpdateUserLabel(ctx context.Context, externalID, labelID string, req *updateUserLabelRequest) error + DeleteUserLabel(ctx context.Context, externalID, labelID string) error + SearchMessages(ctx context.Context, externalID string, filter MessageSearchFilter, params query.ListParams) (MessageSearchResult, error) + ListMessageAttachments(ctx context.Context, externalID, messageID string) ([]map[string]any, error) + MessageAttachmentCIDMap(ctx context.Context, externalID, messageID string) (map[string]string, error) + UploadMessageAttachment(ctx context.Context, externalID, messageID, filename, contentType, contentID string, isInline bool, reader io.Reader, size int64) (string, error) + OpenAttachment(ctx context.Context, externalID, attachmentID string) (filename, contentType string, size int64, isInline bool, body io.ReadCloser, err error) + UploadDraftAttachment(ctx context.Context, externalID, draftID, filename, contentType, contentID string, isInline bool, reader io.Reader, size int64) (string, error) + OpenDraftAttachment(ctx context.Context, externalID, draftID, attachmentID string) (filename, contentType string, body io.ReadCloser, err error) } var _ ServiceAPI = (*Service)(nil) diff --git a/internal/api/mail/validate.go b/internal/api/mail/validate.go index b33ad13..c15140f 100644 --- a/internal/api/mail/validate.go +++ b/internal/api/mail/validate.go @@ -181,15 +181,16 @@ func validateCreateAccount(req *createAccountRequest) *apivalidate.ValidationErr } type sendMessageRequest struct { - AccountID string `json:"account_id"` - To []string `json:"to"` - Cc []string `json:"cc"` - Bcc []string `json:"bcc"` - Subject string `json:"subject"` - BodyText string `json:"body_text"` - BodyHTML string `json:"body_html"` - InReplyTo string `json:"in_reply_to"` - ScheduleAt *string `json:"schedule_at"` + AccountID string `json:"account_id"` + To []string `json:"to"` + Cc []string `json:"cc"` + Bcc []string `json:"bcc"` + Subject string `json:"subject"` + BodyText string `json:"body_text"` + BodyHTML string `json:"body_html"` + InReplyTo string `json:"in_reply_to"` + ReplyToMessageID string `json:"reply_to_message_id"` + ScheduleAt *string `json:"schedule_at"` } func validateSendMessage(req *sendMessageRequest) *apivalidate.ValidationError { diff --git a/internal/api/mail/validate_drafts.go b/internal/api/mail/validate_drafts.go new file mode 100644 index 0000000..e80f4ae --- /dev/null +++ b/internal/api/mail/validate_drafts.go @@ -0,0 +1,92 @@ +package mail + +import ( + "encoding/json" + "strconv" + "strings" + + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" +) + +type draftRequest struct { + AccountID string `json:"account_id"` + IdentityID string `json:"identity_id"` + To []string `json:"to"` + Cc []string `json:"cc"` + Bcc []string `json:"bcc"` + Subject string `json:"subject"` + BodyText string `json:"body_text"` + BodyHTML string `json:"body_html"` + InReplyTo string `json:"in_reply_to"` + Attachments any `json:"attachments"` +} + +func validateDraftRecipients(req *draftRequest) []apivalidate.FieldDetail { + var details []apivalidate.FieldDetail + for i, addr := range req.To { + if d := validateRecipient(addr); d != nil { + d.Field = "to[" + strconv.Itoa(i) + "]" + details = append(details, *d) + } + } + for i, addr := range req.Cc { + if d := validateRecipient(addr); d != nil { + d.Field = "cc[" + strconv.Itoa(i) + "]" + details = append(details, *d) + } + } + for i, addr := range req.Bcc { + if d := validateRecipient(addr); d != nil { + d.Field = "bcc[" + strconv.Itoa(i) + "]" + details = append(details, *d) + } + } + return details +} + +func validateDraftContent(req *draftRequest) []apivalidate.FieldDetail { + var details []apivalidate.FieldDetail + if len(req.Subject) > maxSubjectLen { + details = append(details, apivalidate.FieldDetail{Field: "subject", Message: "too long"}) + } + if len(req.BodyText) > maxBodyField { + details = append(details, apivalidate.FieldDetail{Field: "body_text", Message: "too long"}) + } + if len(req.BodyHTML) > maxBodyField { + details = append(details, apivalidate.FieldDetail{Field: "body_html", Message: "too long"}) + } + if req.InReplyTo != "" && len(req.InReplyTo) > 998 { + details = append(details, apivalidate.FieldDetail{Field: "in_reply_to", Message: "too long"}) + } + if req.Attachments != nil { + if b, err := json.Marshal(req.Attachments); err != nil { + details = append(details, apivalidate.FieldDetail{Field: "attachments", Message: "invalid"}) + } else if len(b) > maxSendRequestBody { + details = append(details, apivalidate.FieldDetail{Field: "attachments", Message: "too large"}) + } + } + return details +} + +func validateCreateDraft(req *draftRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if strings.TrimSpace(req.AccountID) == "" { + details = append(details, apivalidate.FieldDetail{Field: "account_id", Message: "required"}) + } + details = append(details, validateDraftRecipients(req)...) + details = append(details, validateDraftContent(req)...) + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} + +func validateUpdateDraft(req *draftRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + details = append(details, validateDraftRecipients(req)...) + details = append(details, validateDraftContent(req)...) + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} diff --git a/internal/api/mail/validate_folders_labels.go b/internal/api/mail/validate_folders_labels.go new file mode 100644 index 0000000..4233fbb --- /dev/null +++ b/internal/api/mail/validate_folders_labels.go @@ -0,0 +1,192 @@ +package mail + +import ( + "strings" + + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" +) + +const ( + maxFoldersRequestBody = 32 << 10 // 32 KiB + maxLabelsRequestBody = 8 << 10 // 8 KiB + maxFolderName = 256 + maxRemoteName = 512 + maxLabelName = 128 + maxLabelColor = 32 +) + +var allowedFolderTypes = map[string]struct{}{ + "inbox": {}, + "sent": {}, + "drafts": {}, + "trash": {}, + "archive": {}, + "spam": {}, + "custom": {}, +} + +type createFolderRequest struct { + AccountID string `json:"account_id"` + Name string `json:"name"` + RemoteName string `json:"remote_name"` + FolderType string `json:"folder_type"` +} + +type updateFolderRequest struct { + Name string `json:"name"` + RemoteName string `json:"remote_name"` + FolderType string `json:"folder_type"` +} + +type createUserLabelRequest struct { + Name string `json:"name"` + Color string `json:"color"` +} + +type updateUserLabelRequest struct { + Name string `json:"name"` + Color string `json:"color"` +} + +func validateFolderType(field, folderType string) *apivalidate.FieldDetail { + folderType = strings.TrimSpace(strings.ToLower(folderType)) + if folderType == "" { + return nil + } + if _, ok := allowedFolderTypes[folderType]; !ok { + return &apivalidate.FieldDetail{Field: field, Message: "invalid"} + } + return nil +} + +func normalizeFolderType(folderType string) string { + folderType = strings.TrimSpace(strings.ToLower(folderType)) + if folderType == "" { + return "custom" + } + return folderType +} + +func validateFolderName(field, name string, required bool) *apivalidate.FieldDetail { + name = strings.TrimSpace(name) + if name == "" { + if required { + return &apivalidate.FieldDetail{Field: field, Message: "required"} + } + return nil + } + if len(name) > maxFolderName || containsNewline(name) { + return &apivalidate.FieldDetail{Field: field, Message: "invalid"} + } + return nil +} + +func validateRemoteName(field, remoteName string, required bool) *apivalidate.FieldDetail { + remoteName = strings.TrimSpace(remoteName) + if remoteName == "" { + if required { + return &apivalidate.FieldDetail{Field: field, Message: "required"} + } + return nil + } + if len(remoteName) > maxRemoteName || containsNewline(remoteName) { + return &apivalidate.FieldDetail{Field: field, Message: "invalid"} + } + return nil +} + +func validateCreateFolder(req *createFolderRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if strings.TrimSpace(req.AccountID) == "" { + details = append(details, apivalidate.FieldDetail{Field: "account_id", Message: "required"}) + } + if d := validateFolderName("name", req.Name, true); d != nil { + details = append(details, *d) + } + if d := validateRemoteName("remote_name", req.RemoteName, false); d != nil { + details = append(details, *d) + } + if d := validateFolderType("folder_type", req.FolderType); d != nil { + details = append(details, *d) + } + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} + +func validateUpdateFolder(req *updateFolderRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if d := validateFolderName("name", req.Name, true); d != nil { + details = append(details, *d) + } + if d := validateRemoteName("remote_name", req.RemoteName, true); d != nil { + details = append(details, *d) + } + if d := validateFolderType("folder_type", req.FolderType); d != nil { + details = append(details, *d) + } + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} + +func validateLabelName(field, name string, required bool) *apivalidate.FieldDetail { + name = strings.TrimSpace(name) + if name == "" { + if required { + return &apivalidate.FieldDetail{Field: field, Message: "required"} + } + return nil + } + if len(name) > maxLabelName || containsNewline(name) { + return &apivalidate.FieldDetail{Field: field, Message: "invalid"} + } + return nil +} + +func validateLabelColor(color string) *apivalidate.FieldDetail { + color = strings.TrimSpace(color) + if len(color) > maxLabelColor || containsNewline(color) { + return &apivalidate.FieldDetail{Field: "color", Message: "invalid"} + } + return nil +} + +func validateCreateUserLabel(req *createUserLabelRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if d := validateLabelName("name", req.Name, true); d != nil { + details = append(details, *d) + } + if d := validateLabelColor(req.Color); d != nil { + details = append(details, *d) + } + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} + +func validateUpdateUserLabel(req *updateUserLabelRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if d := validateLabelName("name", req.Name, true); d != nil { + details = append(details, *d) + } + if d := validateLabelColor(req.Color); d != nil { + details = append(details, *d) + } + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} + +func validateListFoldersAccountID(accountID string) *apivalidate.ValidationError { + if strings.TrimSpace(accountID) == "" { + return apivalidate.NewValidationError(apivalidate.FieldDetail{ + Field: "account_id", Message: "required", + }) + } + return nil +} diff --git a/internal/api/mail/validate_identities.go b/internal/api/mail/validate_identities.go new file mode 100644 index 0000000..5e1a7fc --- /dev/null +++ b/internal/api/mail/validate_identities.go @@ -0,0 +1,91 @@ +package mail + +import ( + "strconv" + "strings" + + "github.com/ultisuite/ulti-backend/internal/api/apivalidate" +) + +const ( + maxIdentityRequestBody = 256 << 10 // 256 KiB + maxIdentityName = 128 + maxSignatureHTML = 64 << 10 // 64 KiB + maxReplyToAddrs = 10 +) + +type createIdentityRequest struct { + Email string `json:"email"` + Name string `json:"name"` + IsDefault bool `json:"is_default"` + SignatureHTML string `json:"signature_html"` + ReplyToAddrs []string `json:"reply_to_addrs"` +} + +type updateIdentityRequest struct { + Email string `json:"email"` + Name string `json:"name"` + IsDefault bool `json:"is_default"` + SignatureHTML string `json:"signature_html"` + ReplyToAddrs []string `json:"reply_to_addrs"` +} + +func validateReplyToAddrs(field string, addrs []string) []apivalidate.FieldDetail { + if len(addrs) > maxReplyToAddrs { + return []apivalidate.FieldDetail{{ + Field: field, Message: "too many entries", + }} + } + var details []apivalidate.FieldDetail + for i, addr := range addrs { + if d := validateRecipient(addr); d != nil { + d.Field = field + "[" + strconv.Itoa(i) + "]" + details = append(details, *d) + } + } + return details +} + +func validateCreateIdentity(req *createIdentityRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if d := validateEmailField("email", req.Email); d != nil { + details = append(details, *d) + } + if req.Name != "" && len(req.Name) > maxIdentityName { + details = append(details, apivalidate.FieldDetail{Field: "name", Message: "too long"}) + } + if len(req.SignatureHTML) > maxSignatureHTML { + details = append(details, apivalidate.FieldDetail{Field: "signature_html", Message: "too long"}) + } + if req.ReplyToAddrs == nil { + req.ReplyToAddrs = []string{} + } + details = append(details, validateReplyToAddrs("reply_to_addrs", req.ReplyToAddrs)...) + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} + +func validateUpdateIdentity(req *updateIdentityRequest) *apivalidate.ValidationError { + var details []apivalidate.FieldDetail + if d := validateEmailField("email", req.Email); d != nil { + details = append(details, *d) + } + if strings.TrimSpace(req.Name) == "" { + details = append(details, apivalidate.FieldDetail{Field: "name", Message: "required"}) + } else if len(req.Name) > maxIdentityName { + details = append(details, apivalidate.FieldDetail{Field: "name", Message: "too long"}) + } + if len(req.SignatureHTML) > maxSignatureHTML { + details = append(details, apivalidate.FieldDetail{Field: "signature_html", Message: "too long"}) + } + if req.ReplyToAddrs == nil { + req.ReplyToAddrs = []string{} + } + details = append(details, validateReplyToAddrs("reply_to_addrs", req.ReplyToAddrs)...) + if len(details) == 0 { + return nil + } + return apivalidate.NewValidationError(details...) +} diff --git a/internal/api/middleware/auth.go b/internal/api/middleware/auth.go index eeb6ff2..7e70680 100644 --- a/internal/api/middleware/auth.go +++ b/internal/api/middleware/auth.go @@ -2,19 +2,23 @@ package middleware import ( "context" + "log/slog" "net/http" "strings" + "github.com/jackc/pgx/v5/pgxpool" + "github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/securityaudit" + "github.com/ultisuite/ulti-backend/internal/users" ) type ctxKey string const claimsKey ctxKey = "claims" -func Auth(verifier *auth.Verifier, audit *securityaudit.Logger) func(http.Handler) http.Handler { +func Auth(verifier *auth.Verifier, db *pgxpool.Pool, audit *securityaudit.Logger) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if verifier == nil { @@ -68,6 +72,14 @@ func Auth(verifier *auth.Verifier, audit *securityaudit.Logger) func(http.Handle return } + if db != nil { + if _, err := users.EnsureUser(r.Context(), db, claims); err != nil { + slog.Error("provision user", "sub", claims.Sub, "error", err) + apiresponse.WriteError(w, r, http.StatusInternalServerError, apiresponse.CodeInternal, "failed to provision user", nil) + return + } + } + if audit != nil { audit.Log(r.Context(), claims.Sub, securityaudit.ActionLogin, map[string]any{ "email": claims.Email, diff --git a/internal/mail/imap/parse.go b/internal/mail/imap/parse.go index f6beb4b..41b19a9 100644 --- a/internal/mail/imap/parse.go +++ b/internal/mail/imap/parse.go @@ -10,6 +10,8 @@ import ( "strings" imapTypes "github.com/emersion/go-imap/v2" + + "github.com/ultisuite/ulti-backend/internal/mail/threading" ) type EmailAddress struct { @@ -91,3 +93,16 @@ func parseMultipart(r io.Reader, boundary string) (text string, html string) { } return text, html } + +func parseThreadHeaders(raw []byte) (references []string, inReplyTo string) { + if len(raw) == 0 { + return nil, "" + } + msg, err := mail.ReadMessage(bytes.NewReader(raw)) + if err != nil { + return nil, "" + } + refs := msg.Header.Get("References") + irt := strings.TrimSpace(msg.Header.Get("In-Reply-To")) + return threading.ParseMessageIDs(refs), threading.NormalizeMessageID(irt) +} diff --git a/internal/mail/imap/sync.go b/internal/mail/imap/sync.go index 4a56efe..5fa091c 100644 --- a/internal/mail/imap/sync.go +++ b/internal/mail/imap/sync.go @@ -10,8 +10,10 @@ import ( "github.com/emersion/go-imap/v2" "github.com/emersion/go-imap/v2/imapclient" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/mail/credentials" + "github.com/ultisuite/ulti-backend/internal/mail/threading" "github.com/ultisuite/ulti-backend/internal/observability" ) @@ -254,13 +256,36 @@ func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMe bodyText, bodyHTML := parseBody(bodyContent) snippet := truncate(bodyText, 200) - _, err := w.db.Exec(ctx, ` - INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, cc_addrs, date, snippet, body_text, body_html, flags, in_reply_to) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14) - ON CONFLICT (folder_id, uid) DO NOTHING - `, accountID, folderID, uid, envelope.MessageID, envelope.Subject, - fromAddr, toAddrs, ccAddrs, envelope.Date, snippet, bodyText, bodyHTML, flagStrs, strings.Join(envelope.InReplyTo, " ")) + headerRefs, headerInReplyTo := parseThreadHeaders(bodyContent) + inReplyTo := headerInReplyTo + if inReplyTo == "" && len(envelope.InReplyTo) > 0 { + inReplyTo = threading.NormalizeMessageID(envelope.InReplyTo[0]) + } + references := headerRefs + if len(references) == 0 { + references = threading.ParseMessageIDs(strings.Join(envelope.InReplyTo, " ")) + } + var rowID string + err := w.db.QueryRow(ctx, ` + INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, cc_addrs, date, snippet, body_text, body_html, flags, in_reply_to, references_header) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) + ON CONFLICT (folder_id, uid) DO NOTHING + RETURNING id + `, accountID, folderID, uid, envelope.MessageID, envelope.Subject, + fromAddr, toAddrs, ccAddrs, envelope.Date, snippet, bodyText, bodyHTML, flagStrs, inReplyTo, references).Scan(&rowID) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return nil + } + return err + } + + threadID, err := threading.AssignThreadID(ctx, w.db, accountID, inReplyTo, references) + if err != nil { + return err + } + _, err = w.db.Exec(ctx, `UPDATE messages SET thread_id = $1, updated_at = NOW() WHERE id = $2`, threadID, rowID) return err } diff --git a/internal/mail/smtp/outbox.go b/internal/mail/smtp/outbox.go index 8cf4538..e196897 100644 --- a/internal/mail/smtp/outbox.go +++ b/internal/mail/smtp/outbox.go @@ -38,8 +38,8 @@ func (p *OutboxProcessor) Start(ctx context.Context) { p.logger.Info("outbox processor stopped") return case <-ticker.C: - p.processQueue(ctx) p.processScheduled(ctx) + p.processQueue(ctx) p.updateQueueDepth(ctx) } } @@ -88,15 +88,16 @@ func (p *OutboxProcessor) processQueue(ctx context.Context) { // Get the from address var fromEmail string - _ = p.db.QueryRow(ctx, ` + if err := p.db.QueryRow(ctx, ` SELECT mi.email FROM mail_identities mi JOIN mail_accounts ma ON mi.account_id = ma.id WHERE ma.id = $1 AND mi.is_default = true LIMIT 1 - `, accountID).Scan(&fromEmail) - - if fromEmail == "" { - _ = p.db.QueryRow(ctx, `SELECT email FROM mail_accounts WHERE id = $1`, accountID).Scan(&fromEmail) + `, accountID).Scan(&fromEmail); err != nil || fromEmail == "" { + if err := p.db.QueryRow(ctx, `SELECT email FROM mail_accounts WHERE id = $1`, accountID).Scan(&fromEmail); err != nil { + p.logger.Error("resolve from address", "outbox_id", id, "account_id", accountID, "error", err) + continue + } } req := &SendRequest{ @@ -115,27 +116,38 @@ func (p *OutboxProcessor) processQueue(ctx context.Context) { if err := p.sender.Send(ctx, req); err != nil { p.logger.Error("send failed", "outbox_id", id, "error", err) observability.IncOutboxProcessed("error") - _, _ = p.db.Exec(ctx, ` + if _, execErr := p.db.Exec(ctx, ` UPDATE outbox SET status = 'queued', retry_count = retry_count + 1, error = $2, updated_at = NOW() WHERE id = $1 - `, id, err.Error()) + `, id, err.Error()); execErr != nil { + p.logger.Error("failed to mark outbox retry", "outbox_id", id, "error", execErr) + } } else { observability.IncOutboxProcessed("success") - _, _ = p.db.Exec(ctx, ` + if _, execErr := p.db.Exec(ctx, ` UPDATE outbox SET status = 'sent', sent_at = NOW(), updated_at = NOW() WHERE id = $1 - `, id) + `, id); execErr != nil { + p.logger.Error("failed to mark outbox sent", "outbox_id", id, "error", execErr) + } } } + if err := rows.Err(); err != nil { + p.logger.Error("iterate outbox rows", "error", err) + } } func (p *OutboxProcessor) processScheduled(ctx context.Context) { - _, err := p.db.Exec(ctx, ` + result, err := p.db.Exec(ctx, ` UPDATE outbox SET status = 'queued', updated_at = NOW() - WHERE status = 'queued' AND scheduled_at IS NOT NULL AND scheduled_at <= NOW() + WHERE status = 'scheduled' AND scheduled_at IS NOT NULL AND scheduled_at <= NOW() `) if err != nil { p.logger.Error("failed to process scheduled", "error", err) + return + } + if n := result.RowsAffected(); n > 0 { + p.logger.Info("promoted scheduled outbox rows", "count", n) } } diff --git a/internal/mail/smtp/outbox_test.go b/internal/mail/smtp/outbox_test.go index 67a79a3..9cc02d5 100644 --- a/internal/mail/smtp/outbox_test.go +++ b/internal/mail/smtp/outbox_test.go @@ -2,6 +2,7 @@ package smtp import ( "reflect" + "strings" "testing" ) @@ -35,6 +36,20 @@ func TestParseJSONAddresses_empty(t *testing.T) { } } +func TestScheduledPromotionSQLUsesScheduledStatus(t *testing.T) { + const want = "status = 'scheduled'" + sql := ` + UPDATE outbox SET status = 'queued', updated_at = NOW() + WHERE status = 'scheduled' AND scheduled_at IS NOT NULL AND scheduled_at <= NOW() + ` + if !strings.Contains(sql, want) { + t.Fatalf("scheduled promotion SQL must filter %q", want) + } + if strings.Contains(sql, "status = 'queued' AND scheduled_at") { + t.Fatal("scheduled promotion must not match queued rows with scheduled_at") + } +} + func TestParseJSONAddresses_invalid(t *testing.T) { got := parseJSONAddresses([]byte(`not-json`)) if got != nil { diff --git a/internal/mail/smtp/sender.go b/internal/mail/smtp/sender.go index 1108d5b..74c10ec 100644 --- a/internal/mail/smtp/sender.go +++ b/internal/mail/smtp/sender.go @@ -2,6 +2,8 @@ package smtp import ( "context" + "crypto/rand" + "encoding/hex" "errors" "fmt" "log/slog" @@ -95,6 +97,7 @@ func buildMessage(req *SendRequest) string { } b.WriteString(fmt.Sprintf("Subject: %s\r\n", req.Subject)) b.WriteString(fmt.Sprintf("Date: %s\r\n", time.Now().Format(time.RFC1123Z))) + b.WriteString(fmt.Sprintf("Message-ID: %s\r\n", generateMessageID(req.From))) b.WriteString("MIME-Version: 1.0\r\n") if req.InReplyTo != "" { @@ -123,6 +126,18 @@ func buildMessage(req *SendRequest) string { return b.String() } +func generateMessageID(from string) string { + domain := "ultimail.local" + if i := strings.LastIndex(from, "@"); i >= 0 && i < len(from)-1 { + domain = from[i+1:] + } + token := make([]byte, 16) + if _, err := rand.Read(token); err != nil { + token = []byte(fmt.Sprintf("%d", time.Now().UnixNano())) + } + return fmt.Sprintf("<%s@%s>", hex.EncodeToString(token), domain) +} + func (s *Sender) parseCredentials(creds []byte) (string, string, error) { if len(creds) == 0 { return "", "", errors.New("missing credentials") diff --git a/internal/mail/storage/client.go b/internal/mail/storage/client.go new file mode 100644 index 0000000..c2464fd --- /dev/null +++ b/internal/mail/storage/client.go @@ -0,0 +1,59 @@ +package storage + +import ( + "context" + "fmt" + "io" + "net/url" + "time" + + "github.com/google/uuid" + "github.com/minio/minio-go/v7" +) + +type Client struct { + mc *minio.Client + bucket string +} + +func NewClient(mc *minio.Client, bucket string) *Client { + return &Client{mc: mc, bucket: bucket} +} + +func (c *Client) EnsureBucket(ctx context.Context) error { + exists, err := c.mc.BucketExists(ctx, c.bucket) + if err != nil { + return err + } + if !exists { + return c.mc.MakeBucket(ctx, c.bucket, minio.MakeBucketOptions{}) + } + return nil +} + +func (c *Client) Put(ctx context.Context, objectKey string, reader io.Reader, size int64, contentType string) error { + _, err := c.mc.PutObject(ctx, c.bucket, objectKey, reader, size, minio.PutObjectOptions{ + ContentType: contentType, + }) + return err +} + +func (c *Client) Get(ctx context.Context, objectKey string) (*minio.Object, error) { + return c.mc.GetObject(ctx, c.bucket, objectKey, minio.GetObjectOptions{}) +} + +func (c *Client) Delete(ctx context.Context, objectKey string) error { + return c.mc.RemoveObject(ctx, c.bucket, objectKey, minio.RemoveObjectOptions{}) +} + +func (c *Client) PresignedGet(ctx context.Context, objectKey string, expiry time.Duration) (*url.URL, error) { + return c.mc.PresignedGetObject(ctx, c.bucket, objectKey, expiry, nil) +} + +func MessageObjectKey(userID, messageID, filename string) string { + return fmt.Sprintf("%s/messages/%s/%s/%s", userID, messageID, uuid.NewString(), filename) +} + +func DraftObjectKey(userID, draftID, filename string) string { + return fmt.Sprintf("%s/drafts/%s/%s/%s", userID, draftID, uuid.NewString(), filename) +} diff --git a/internal/mail/threading/threading.go b/internal/mail/threading/threading.go new file mode 100644 index 0000000..b69cc55 --- /dev/null +++ b/internal/mail/threading/threading.go @@ -0,0 +1,122 @@ +package threading + +import ( + "context" + "errors" + "regexp" + "strings" + + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +var messageIDToken = regexp.MustCompile(`<[^>]+>`) + +// NormalizeMessageID returns a canonical angle-bracket Message-ID when possible. +func NormalizeMessageID(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + if strings.HasPrefix(raw, "<") && strings.HasSuffix(raw, ">") { + return raw + } + return "<" + strings.Trim(raw, "<>") + ">" +} + +// ParseMessageIDs extracts Message-IDs from a References or In-Reply-To header value. +func ParseMessageIDs(header string) []string { + header = strings.TrimSpace(header) + if header == "" { + return nil + } + matches := messageIDToken.FindAllString(header, -1) + if len(matches) == 0 { + if id := NormalizeMessageID(header); id != "" { + return []string{id} + } + return nil + } + seen := make(map[string]struct{}, len(matches)) + out := make([]string, 0, len(matches)) + for _, m := range matches { + id := NormalizeMessageID(m) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + return out +} + +// BuildReferences returns the References chain for a reply (ancestors + parent). +func BuildReferences(parentMessageID string, parentReferences []string) []string { + parentMessageID = NormalizeMessageID(parentMessageID) + if parentMessageID == "" { + return nil + } + seen := make(map[string]struct{}, len(parentReferences)+1) + out := make([]string, 0, len(parentReferences)+1) + for _, ref := range parentReferences { + id := NormalizeMessageID(ref) + if id == "" { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + out = append(out, id) + } + if _, ok := seen[parentMessageID]; !ok { + out = append(out, parentMessageID) + } + return out +} + +func candidateMessageIDs(inReplyTo string, references []string) []string { + seen := make(map[string]struct{}) + var out []string + add := func(id string) { + id = NormalizeMessageID(id) + if id == "" { + return + } + if _, ok := seen[id]; ok { + return + } + seen[id] = struct{}{} + out = append(out, id) + } + add(inReplyTo) + for _, ref := range references { + add(ref) + } + return out +} + +// AssignThreadID picks an existing thread for the account or allocates a new one. +func AssignThreadID(ctx context.Context, db *pgxpool.Pool, accountID, inReplyTo string, references []string) (string, error) { + ids := candidateMessageIDs(inReplyTo, references) + if len(ids) > 0 { + var threadID *uuid.UUID + err := db.QueryRow(ctx, ` + SELECT thread_id FROM messages + WHERE account_id = $1 AND message_id = ANY($2) AND thread_id IS NOT NULL + ORDER BY date ASC + LIMIT 1 + `, accountID, ids).Scan(&threadID) + if err == nil && threadID != nil { + return threadID.String(), nil + } + if err != nil && !errors.Is(err, pgx.ErrNoRows) { + return "", err + } + } + return uuid.New().String(), nil +} diff --git a/internal/mail/threading/threading_test.go b/internal/mail/threading/threading_test.go new file mode 100644 index 0000000..69f4704 --- /dev/null +++ b/internal/mail/threading/threading_test.go @@ -0,0 +1,45 @@ +package threading + +import ( + "reflect" + "testing" +) + +func TestNormalizeMessageID(t *testing.T) { + tests := []struct { + in, want string + }{ + {"", ""}, + {"", ""}, + {"a@b", ""}, + } + for _, tc := range tests { + if got := NormalizeMessageID(tc.in); got != tc.want { + t.Fatalf("NormalizeMessageID(%q) = %q, want %q", tc.in, got, tc.want) + } + } +} + +func TestParseMessageIDs(t *testing.T) { + got := ParseMessageIDs(" ") + want := []string{"", ""} + if !reflect.DeepEqual(got, want) { + t.Fatalf("ParseMessageIDs() = %v, want %v", got, want) + } +} + +func TestBuildReferences(t *testing.T) { + got := BuildReferences("", []string{"", ""}) + want := []string{"", "", ""} + if !reflect.DeepEqual(got, want) { + t.Fatalf("BuildReferences() = %v, want %v", got, want) + } +} + +func TestBuildReferences_dedupesParent(t *testing.T) { + got := BuildReferences("", []string{"", ""}) + want := []string{"", ""} + if !reflect.DeepEqual(got, want) { + t.Fatalf("BuildReferences() = %v, want %v", got, want) + } +} diff --git a/internal/users/provision.go b/internal/users/provision.go new file mode 100644 index 0000000..410e037 --- /dev/null +++ b/internal/users/provision.go @@ -0,0 +1,51 @@ +package users + +import ( + "context" + "fmt" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/ultisuite/ulti-backend/internal/auth" +) + +// ProvisionEmail returns the email stored for a newly provisioned user. +func ProvisionEmail(claims *auth.Claims) string { + if claims == nil { + return "" + } + email := strings.TrimSpace(claims.Email) + if email != "" { + return email + } + return claims.Sub + "@unknown.ultimail.local" +} + +// EnsureUser inserts or updates the Ultimail user row for OIDC claims and returns the internal UUID. +func EnsureUser(ctx context.Context, db *pgxpool.Pool, claims *auth.Claims) (string, error) { + if db == nil { + return "", fmt.Errorf("database not configured") + } + if claims == nil || strings.TrimSpace(claims.Sub) == "" { + return "", fmt.Errorf("missing subject claim") + } + + email := ProvisionEmail(claims) + name := strings.TrimSpace(claims.Name) + + var userID string + err := db.QueryRow(ctx, ` + INSERT INTO users (external_id, email, name) + VALUES ($1, $2, $3) + ON CONFLICT (external_id) DO UPDATE SET + email = EXCLUDED.email, + name = EXCLUDED.name, + updated_at = NOW() + RETURNING id + `, claims.Sub, email, name).Scan(&userID) + if err != nil { + return "", fmt.Errorf("provision user: %w", err) + } + return userID, nil +} diff --git a/internal/users/provision_test.go b/internal/users/provision_test.go new file mode 100644 index 0000000..d246077 --- /dev/null +++ b/internal/users/provision_test.go @@ -0,0 +1,21 @@ +package users + +import ( + "testing" + + "github.com/ultisuite/ulti-backend/internal/auth" +) + +func TestProvisionEmail_fromClaim(t *testing.T) { + got := ProvisionEmail(&auth.Claims{Sub: "sub-1", Email: "a@b.com"}) + if got != "a@b.com" { + t.Fatalf("ProvisionEmail() = %q, want a@b.com", got) + } +} + +func TestProvisionEmail_fallback(t *testing.T) { + got := ProvisionEmail(&auth.Claims{Sub: "sub-1"}) + if got != "sub-1@unknown.ultimail.local" { + t.Fatalf("ProvisionEmail() = %q", got) + } +} diff --git a/migrations/000006_outbox_scheduled_index.down.sql b/migrations/000006_outbox_scheduled_index.down.sql new file mode 100644 index 0000000..a3527ec --- /dev/null +++ b/migrations/000006_outbox_scheduled_index.down.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS idx_outbox_scheduled; +CREATE INDEX idx_outbox_scheduled ON outbox(scheduled_at) + WHERE scheduled_at IS NOT NULL AND status = 'queued'; diff --git a/migrations/000006_outbox_scheduled_index.up.sql b/migrations/000006_outbox_scheduled_index.up.sql new file mode 100644 index 0000000..a835119 --- /dev/null +++ b/migrations/000006_outbox_scheduled_index.up.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS idx_outbox_scheduled; +CREATE INDEX idx_outbox_scheduled ON outbox(scheduled_at) + WHERE scheduled_at IS NOT NULL AND status = 'scheduled'; diff --git a/migrations/000007_mail_api_extensions.down.sql b/migrations/000007_mail_api_extensions.down.sql new file mode 100644 index 0000000..cb515d0 --- /dev/null +++ b/migrations/000007_mail_api_extensions.down.sql @@ -0,0 +1,5 @@ +DROP TABLE IF EXISTS mail_user_labels; + +ALTER TABLE mail_identities + DROP COLUMN IF EXISTS reply_to_addrs, + DROP COLUMN IF EXISTS updated_at; diff --git a/migrations/000007_mail_api_extensions.up.sql b/migrations/000007_mail_api_extensions.up.sql new file mode 100644 index 0000000..8e7bbba --- /dev/null +++ b/migrations/000007_mail_api_extensions.up.sql @@ -0,0 +1,14 @@ +ALTER TABLE mail_identities + ADD COLUMN reply_to_addrs JSONB NOT NULL DEFAULT '[]', + ADD COLUMN updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(); + +CREATE TABLE mail_user_labels ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v4(), + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + name TEXT NOT NULL, + color TEXT NOT NULL DEFAULT '', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(user_id, name) +); + +CREATE INDEX idx_mail_user_labels_user ON mail_user_labels(user_id); diff --git a/project-plan/checklist-execution.md b/project-plan/checklist-execution.md index ec61910..5ee7320 100644 --- a/project-plan/checklist-execution.md +++ b/project-plan/checklist-execution.md @@ -70,19 +70,19 @@ Objectif: transformer état actuel (partiellement implémenté) vers produit fon #### Correctifs prioritaires -- [ ] Corriger logique outbox "scheduled" vs "queued" pour envoi planifié. -- [ ] Vérifier ownership sur `UpdateLabels`, `UpdateFlags`, `DeleteMessage`, `UpdateRule`, `DeleteRule`, `DeleteWebhook`. -- [ ] Corriger flux création utilisateur manquant (external_id OIDC absent -> échec sous-requêtes user_id). -- [ ] Ajouter gestion robuste erreurs SQL (`rows.Scan`, `Exec` result check, `rows.Err`). -- [ ] Corriger cohérence threading (`thread_id`, `references`, `in_reply_to`). +- [x] Corriger logique outbox "scheduled" vs "queued" pour envoi planifié. +- [x] Vérifier ownership sur `UpdateLabels`, `UpdateFlags`, `DeleteMessage`, `UpdateRule`, `DeleteRule`, `DeleteWebhook`. +- [x] Corriger flux création utilisateur manquant (external_id OIDC absent -> échec sous-requêtes user_id). +- [x] Ajouter gestion robuste erreurs SQL (`rows.Scan`, `Exec` result check, `rows.Err`). +- [x] Corriger cohérence threading (`thread_id`, `references`, `in_reply_to`). #### Implémentation manquante -- [ ] Endpoint brouillons (create/update/delete/list). -- [ ] Endpoint pièces jointes (upload/download/inline/cid mapping). -- [ ] Endpoint dossiers/labels (CRUD + mapping IMAP flags/folders). -- [ ] Endpoint recherche avancée (filtres expéditeur, date, attachment, label, account). -- [ ] Endpoint identities (alias/from/reply-to/signature par compte). +- [x] Endpoint brouillons (create/update/delete/list). +- [x] Endpoint pièces jointes (upload/download/inline/cid mapping). +- [x] Endpoint dossiers/labels (CRUD + mapping IMAP flags/folders). +- [x] Endpoint recherche avancée (filtres expéditeur, date, attachment, label, account). +- [x] Endpoint identities (alias/from/reply-to/signature par compte). #### Hardening