This commit is contained in:
R3D347HR4Y 2026-05-25 13:52:27 +02:00
parent 665201627b
commit cd0a80f5e8
65 changed files with 3785 additions and 167 deletions

View File

@ -0,0 +1,125 @@
package main
import (
"context"
"flag"
"fmt"
"log/slog"
"os"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/api/mail"
"github.com/ultisuite/ulti-backend/internal/dbmigrate"
"github.com/ultisuite/ulti-backend/internal/envexpand"
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
)
func main() {
accountID := flag.String("account", "", "mail account UUID (optional; all accounts if empty)")
dryRun := flag.Bool("dry-run", false, "scan only, do not write updates")
flag.Parse()
for _, path := range []string{".env", "../.env"} {
_ = envexpand.ApplyFile(path)
}
dbURL := os.Getenv("DATABASE_URL")
if dbURL == "" {
slog.Error("DATABASE_URL is required")
os.Exit(1)
}
ctx := context.Background()
if err := dbmigrate.Up(dbURL); err != nil {
slog.Error("migration failed", "error", err)
os.Exit(1)
}
pool, err := pgxpool.New(ctx, dbURL)
if err != nil {
slog.Error("db connect failed", "error", err)
os.Exit(1)
}
defer pool.Close()
if *dryRun {
scanned, changed, err := scanBodies(ctx, pool, *accountID)
if err != nil {
slog.Error("scan failed", "error", err)
os.Exit(1)
}
fmt.Printf("dry-run: scanned=%d would_update=%d\n", scanned, changed)
return
}
if *accountID != "" {
svc := mail.NewService(pool, nil, nil, nil, "")
result, err := svc.ResanitizeAccountBodiesByID(ctx, *accountID)
if err != nil {
slog.Error("resanitize failed", "account_id", *accountID, "error", err)
os.Exit(1)
}
fmt.Printf("account=%s scanned=%d updated=%d\n", *accountID, result.Scanned, result.Updated)
return
}
rows, err := pool.Query(ctx, `SELECT id FROM mail_accounts WHERE is_active = true ORDER BY created_at`)
if err != nil {
slog.Error("list accounts failed", "error", err)
os.Exit(1)
}
defer rows.Close()
svc := mail.NewService(pool, nil, nil, nil, "")
var totalScanned, totalUpdated int
for rows.Next() {
var id string
if err := rows.Scan(&id); err != nil {
slog.Error("scan account id failed", "error", err)
os.Exit(1)
}
result, err := svc.ResanitizeAccountBodiesByID(ctx, id)
if err != nil {
slog.Error("resanitize failed", "account_id", id, "error", err)
os.Exit(1)
}
fmt.Printf("account=%s scanned=%d updated=%d\n", id, result.Scanned, result.Updated)
totalScanned += result.Scanned
totalUpdated += result.Updated
}
if err := rows.Err(); err != nil {
slog.Error("list accounts failed", "error", err)
os.Exit(1)
}
fmt.Printf("done: accounts scanned_messages=%d updated=%d\n", totalScanned, totalUpdated)
}
func scanBodies(ctx context.Context, pool *pgxpool.Pool, accountID string) (scanned, changed int, err error) {
query := `
SELECT id, body_html FROM messages
WHERE body_html <> ''`
args := []any{}
if accountID != "" {
query += ` AND account_id = $1`
args = append(args, accountID)
}
rows, err := pool.Query(ctx, query, args...)
if err != nil {
return 0, 0, err
}
defer rows.Close()
for rows.Next() {
var id, body string
if err := rows.Scan(&id, &body); err != nil {
return scanned, changed, err
}
scanned++
if sanitize.SanitizeHTML(body) != body {
changed++
}
}
return scanned, changed, rows.Err()
}

View File

@ -122,7 +122,8 @@ func main() {
// Nextcloud client (nil if disabled) // Nextcloud client (nil if disabled)
var ncClient *nextcloud.Client var ncClient *nextcloud.Client
if cfg.NextcloudEnabled { if cfg.NextcloudEnabled {
ncClient = nextcloud.NewClient(cfg.NextcloudURL, cfg.NCAdminUser, cfg.NCAdminPass) ncClient = nextcloud.NewClient(cfg.NextcloudURL, cfg.NCAdminUser, cfg.NCAdminPass).
WithDAVCredentials(nextcloud.NewDAVCredentialStore(pool, credentialManager))
slog.Info("nextcloud enabled", "url", cfg.NextcloudURL) slog.Info("nextcloud enabled", "url", cfg.NextcloudURL)
} }
@ -163,12 +164,13 @@ func main() {
}, rdb) }, rdb)
// Start background workers // Start background workers
go imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{ syncWorker := imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{
Storage: attachmentStorage, Storage: attachmentStorage,
AttachBucket: cfg.MailAttachmentsBucket, AttachBucket: cfg.MailAttachmentsBucket,
Rules: rulesEngine, Rules: rulesEngine,
Hub: hub, Hub: hub,
}).Start(ctx) })
go syncWorker.Start(ctx)
sender := smtp.NewSender(pool, credentialManager, mailOAuthSvc) sender := smtp.NewSender(pool, credentialManager, mailOAuthSvc)
smtpCircuit := smtp.NewCircuitBreaker(cfg.MailSMTPCircuitFailures, cfg.MailSMTPCircuitCooldown) smtpCircuit := smtp.NewCircuitBreaker(cfg.MailSMTPCircuitFailures, cfg.MailSMTPCircuitCooldown)
@ -182,7 +184,8 @@ func main() {
).Start(ctx) ).Start(ctx)
sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst) sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst)
mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL) mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender)
mailHandler.SetAccountSync(syncWorker)
// Router // Router
r := chi.NewRouter() r := chi.NewRouter()

View File

@ -6,6 +6,7 @@ Blueprints in `blueprints/` are mounted into Authentik at `/blueprints/custom` a
|---------|------| |---------|------|
| `01-ulti-enrollment.yaml` | Inscription self-service (`ulti-enrollment`) | | `01-ulti-enrollment.yaml` | Inscription self-service (`ulti-enrollment`) |
| `02-ulti-brand.yaml` | Branding Ultimail + lien « Créer un compte » sur login | | `02-ulti-brand.yaml` | Branding Ultimail + lien « Créer un compte » sur login |
| `03-ulti-suite-groups.yaml` | Claim OIDC `groups` (RBAC contacts/calendar/drive/photos) |
| `ulti-oidc.yaml` | App OIDC Ultimail | | `ulti-oidc.yaml` | App OIDC Ultimail |
| `nextcloud-oidc.yaml` | App OIDC Nextcloud | | `nextcloud-oidc.yaml` | App OIDC Nextcloud |

View File

@ -0,0 +1,36 @@
# Ultimail — claim OIDC `groups` pour RBAC backend (contacts, calendar, drive, photos)
version: 1
metadata:
name: Ultimail suite groups
labels:
blueprints.goauthentik.io/instantiate: "true"
entries:
- model: authentik_providers_oauth2.scopemapping
id: ulti-suite-groups-mapping
identifiers:
name: ulti-suite-groups
attrs:
name: ulti-suite-groups
scope_name: profile
description: Suite RBAC groups for Ultimail API
expression: |
return {
"groups": [
"role:user",
"contacts:write",
"calendar:write",
"drive:write",
"photos:write",
],
}
- model: authentik_providers_oauth2.oauth2provider
identifiers:
name: ulti-backend-provider
attrs:
property_mappings:
- !Find [authentik_providers_oauth2.scopemapping, [scope_name, openid]]
- !Find [authentik_providers_oauth2.scopemapping, [scope_name, email]]
- !Find [authentik_providers_oauth2.scopemapping, [scope_name, profile]]
- !Find [authentik_providers_oauth2.scopemapping, [scope_name, offline_access]]
- !KeyOf ulti-suite-groups-mapping

View File

@ -92,12 +92,13 @@ services:
restart: unless-stopped restart: unless-stopped
command: server command: server
environment: environment:
AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY} # Required at compose parse time — empty ${VAR} would override env_file with "".
AUTHENTIK_POSTGRESQL__HOST: ${AUTHENTIK_POSTGRESQL__HOST} AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY:?Set AUTHENTIK_SECRET_KEY in .env and use ./deploy/compose-up.sh}
AUTHENTIK_POSTGRESQL__USER: ${AUTHENTIK_POSTGRESQL__USER} AUTHENTIK_POSTGRESQL__HOST: ${AUTHENTIK_POSTGRESQL__HOST:-postgres}
AUTHENTIK_POSTGRESQL__PASSWORD: ${AUTHENTIK_POSTGRESQL__PASSWORD} AUTHENTIK_POSTGRESQL__USER: ${AUTHENTIK_POSTGRESQL__USER:?Set AUTHENTIK_POSTGRESQL__USER in .env}
AUTHENTIK_POSTGRESQL__NAME: ${AUTHENTIK_POSTGRESQL__NAME} AUTHENTIK_POSTGRESQL__PASSWORD: ${AUTHENTIK_POSTGRESQL__PASSWORD:?Set AUTHENTIK_POSTGRESQL__PASSWORD in .env}
AUTHENTIK_REDIS__HOST: ${AUTHENTIK_REDIS__HOST} AUTHENTIK_POSTGRESQL__NAME: ${AUTHENTIK_POSTGRESQL__NAME:-authentik}
AUTHENTIK_REDIS__HOST: ${AUTHENTIK_REDIS__HOST:-keydb}
AUTHENTIK_WEB__PATH: /auth/ AUTHENTIK_WEB__PATH: /auth/
AUTHENTIK_HOST: http://${DOMAIN:-localhost} AUTHENTIK_HOST: http://${DOMAIN:-localhost}
env_file: ../.env.resolved env_file: ../.env.resolved
@ -127,12 +128,12 @@ services:
restart: unless-stopped restart: unless-stopped
command: worker command: worker
environment: environment:
AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY} AUTHENTIK_SECRET_KEY: ${AUTHENTIK_SECRET_KEY:?Set AUTHENTIK_SECRET_KEY in .env and use ./deploy/compose-up.sh}
AUTHENTIK_POSTGRESQL__HOST: ${AUTHENTIK_POSTGRESQL__HOST} AUTHENTIK_POSTGRESQL__HOST: ${AUTHENTIK_POSTGRESQL__HOST:-postgres}
AUTHENTIK_POSTGRESQL__USER: ${AUTHENTIK_POSTGRESQL__USER} AUTHENTIK_POSTGRESQL__USER: ${AUTHENTIK_POSTGRESQL__USER:?Set AUTHENTIK_POSTGRESQL__USER in .env}
AUTHENTIK_POSTGRESQL__PASSWORD: ${AUTHENTIK_POSTGRESQL__PASSWORD} AUTHENTIK_POSTGRESQL__PASSWORD: ${AUTHENTIK_POSTGRESQL__PASSWORD:?Set AUTHENTIK_POSTGRESQL__PASSWORD in .env}
AUTHENTIK_POSTGRESQL__NAME: ${AUTHENTIK_POSTGRESQL__NAME} AUTHENTIK_POSTGRESQL__NAME: ${AUTHENTIK_POSTGRESQL__NAME:-authentik}
AUTHENTIK_REDIS__HOST: ${AUTHENTIK_REDIS__HOST} AUTHENTIK_REDIS__HOST: ${AUTHENTIK_REDIS__HOST:-keydb}
AUTHENTIK_WEB__PATH: /auth/ AUTHENTIK_WEB__PATH: /auth/
AUTHENTIK_HOST: http://${DOMAIN:-localhost} AUTHENTIK_HOST: http://${DOMAIN:-localhost}
env_file: ../.env.resolved env_file: ../.env.resolved

View File

@ -14,6 +14,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/apivalidate" "github.com/ultisuite/ulti-backend/internal/api/apivalidate"
"github.com/ultisuite/ulti-backend/internal/api/middleware" "github.com/ultisuite/ulti-backend/internal/api/middleware"
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
"github.com/ultisuite/ulti-backend/internal/permission" "github.com/ultisuite/ulti-backend/internal/permission"
) )
@ -48,12 +49,38 @@ func (h *Handler) Routes() chi.Router {
return r return r
} }
func (h *Handler) nextcloudUser(w http.ResponseWriter, r *http.Request, claims *auth.Claims) (string, bool) {
userID, err := h.svc.EnsureNextcloudUser(r.Context(), claims)
if err != nil {
h.logger.Error("ensure nextcloud user", "error", err, "sub", claims.Sub, "email", claims.Email)
apivalidate.WriteInternal(w, r)
return "", false
}
return userID, true
}
func (h *Handler) writeContactServiceError(w http.ResponseWriter, r *http.Request, op string, err error) {
if errors.Is(err, nextcloud.ErrPrincipalNotFound) {
apiresponse.WriteError(w, r, http.StatusNotFound, "contact_book_not_found", "contacts address book not found for user", nil)
return
}
if errors.Is(err, nextcloud.ErrDAVCredentialsMissing) {
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, "contacts_unavailable", "contacts backend credentials need refresh; retry shortly", nil)
return
}
h.logger.Error(op, "error", err)
apivalidate.WriteInternal(w, r)
}
func (h *Handler) ListAddressBooks(w http.ResponseWriter, r *http.Request) { func (h *Handler) ListAddressBooks(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
books, err := h.svc.ListAddressBooks(r.Context(), claims.Sub) ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
books, err := h.svc.ListAddressBooks(r.Context(), ncUser)
if err != nil { if err != nil {
h.logger.Error("list address books", "error", err) h.writeContactServiceError(w, r, "list address books", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"address_books": books}) apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"address_books": books})
@ -61,21 +88,24 @@ func (h *Handler) ListAddressBooks(w http.ResponseWriter, r *http.Request) {
func (h *Handler) SyncContacts(w http.ResponseWriter, r *http.Request) { func (h *Handler) SyncContacts(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
syncToken, verr := validateSyncToken(r.URL.Query().Get("sync_token")) syncToken, verr := validateSyncToken(r.URL.Query().Get("sync_token"))
if verr != nil { if verr != nil {
apivalidate.WriteValidationError(w, r, verr) apivalidate.WriteValidationError(w, r, verr)
return return
} }
result, err := h.svc.SyncContacts(r.Context(), claims.Sub, chi.URLParam(r, "bookID"), syncToken) result, err := h.svc.SyncContacts(r.Context(), ncUser, chi.URLParam(r, "bookID"), syncToken)
if err != nil { if err != nil {
if errors.Is(err, nextcloud.ErrSyncTokenInvalid) { if errors.Is(err, nextcloud.ErrSyncTokenInvalid) {
apiresponse.WriteError(w, r, http.StatusConflict, "sync_token_invalid", apiresponse.WriteError(w, r, http.StatusConflict, "sync_token_invalid",
"sync token is no longer valid; omit sync_token to perform a full resync", nil) "sync token is no longer valid; omit sync_token to perform a full resync", nil)
return return
} }
h.logger.Error("sync contacts", "error", err) h.writeContactServiceError(w, r, "sync contacts", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, result) apiresponse.WriteJSON(w, http.StatusOK, result)
@ -83,16 +113,19 @@ func (h *Handler) SyncContacts(w http.ResponseWriter, r *http.Request) {
func (h *Handler) ListContacts(w http.ResponseWriter, r *http.Request) { func (h *Handler) ListContacts(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
params, err := query.ParseListRequest(r) params, err := query.ParseListRequest(r)
if err != nil { if err != nil {
apivalidate.WriteQueryError(w, r, err) apivalidate.WriteQueryError(w, r, err)
return return
} }
result, err := h.svc.ListContacts(r.Context(), claims.Sub, chi.URLParam(r, "bookID"), params) result, err := h.svc.ListContacts(r.Context(), ncUser, chi.URLParam(r, "bookID"), params)
if err != nil { if err != nil {
h.logger.Error("list contacts", "error", err) h.writeContactServiceError(w, r, "list contacts", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, result) apiresponse.WriteJSON(w, http.StatusOK, result)
@ -100,6 +133,10 @@ func (h *Handler) ListContacts(w http.ResponseWriter, r *http.Request) {
func (h *Handler) SearchContacts(w http.ResponseWriter, r *http.Request) { func (h *Handler) SearchContacts(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
params, err := query.ParseListRequest(r) params, err := query.ParseListRequest(r)
if err != nil { if err != nil {
apivalidate.WriteQueryError(w, r, err) apivalidate.WriteQueryError(w, r, err)
@ -112,10 +149,9 @@ func (h *Handler) SearchContacts(w http.ResponseWriter, r *http.Request) {
} }
q := r.URL.Query().Get("q") q := r.URL.Query().Get("q")
result, err := h.svc.SearchContacts(r.Context(), claims.Sub, bookID, q, params) result, err := h.svc.SearchContacts(r.Context(), ncUser, bookID, q, params)
if err != nil { if err != nil {
h.logger.Error("search contacts", "error", err) h.writeContactServiceError(w, r, "search contacts", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, result) apiresponse.WriteJSON(w, http.StatusOK, result)
@ -123,6 +159,10 @@ func (h *Handler) SearchContacts(w http.ResponseWriter, r *http.Request) {
func (h *Handler) CreateContact(w http.ResponseWriter, r *http.Request) { func (h *Handler) CreateContact(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
var contact nextcloud.Contact var contact nextcloud.Contact
if err := apivalidate.DecodeJSON(w, r, maxRequestBody, &contact); err != nil { if err := apivalidate.DecodeJSON(w, r, maxRequestBody, &contact); err != nil {
@ -133,9 +173,8 @@ func (h *Handler) CreateContact(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := h.svc.CreateContact(r.Context(), claims.Sub, chi.URLParam(r, "bookID"), &contact); err != nil { if err := h.svc.CreateContact(r.Context(), ncUser, chi.URLParam(r, "bookID"), &contact); err != nil {
h.logger.Error("create contact", "error", err) h.writeContactServiceError(w, r, "create contact", err)
apivalidate.WriteInternal(w, r)
return return
} }
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
@ -143,6 +182,10 @@ func (h *Handler) CreateContact(w http.ResponseWriter, r *http.Request) {
func (h *Handler) UpdateContact(w http.ResponseWriter, r *http.Request) { func (h *Handler) UpdateContact(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
contactPath := strings.TrimSuffix(chi.URLParam(r, "*"), "/") contactPath := strings.TrimSuffix(chi.URLParam(r, "*"), "/")
if verr := validateDeletePath(contactPath); verr != nil { if verr := validateDeletePath(contactPath); verr != nil {
apivalidate.WriteValidationError(w, r, verr) apivalidate.WriteValidationError(w, r, verr)
@ -163,14 +206,13 @@ func (h *Handler) UpdateContact(w http.ResponseWriter, r *http.Request) {
return return
} }
etag, err := h.svc.UpdateContact(r.Context(), claims.Sub, contactPath, ifMatch, &contact) etag, err := h.svc.UpdateContact(r.Context(), ncUser, contactPath, ifMatch, &contact)
if err != nil { if err != nil {
if errors.Is(err, nextcloud.ErrETagMismatch) { if errors.Is(err, nextcloud.ErrETagMismatch) {
apiresponse.WriteError(w, r, http.StatusPreconditionFailed, "etag_mismatch", "etag does not match current resource version", nil) apiresponse.WriteError(w, r, http.StatusPreconditionFailed, "etag_mismatch", "etag does not match current resource version", nil)
return return
} }
h.logger.Error("update contact", "error", err) h.writeContactServiceError(w, r, "update contact", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"etag": etag}) apiresponse.WriteJSON(w, http.StatusOK, map[string]any{"etag": etag})
@ -178,6 +220,10 @@ func (h *Handler) UpdateContact(w http.ResponseWriter, r *http.Request) {
func (h *Handler) MergeDuplicateContacts(w http.ResponseWriter, r *http.Request) { func (h *Handler) MergeDuplicateContacts(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
var req MergeDuplicatesRequest var req MergeDuplicatesRequest
if r.ContentLength > 0 { if r.ContentLength > 0 {
@ -186,10 +232,9 @@ func (h *Handler) MergeDuplicateContacts(w http.ResponseWriter, r *http.Request)
} }
} }
result, err := h.svc.MergeDuplicates(r.Context(), claims.Sub, chi.URLParam(r, "bookID"), req) result, err := h.svc.MergeDuplicates(r.Context(), ncUser, chi.URLParam(r, "bookID"), req)
if err != nil { if err != nil {
h.logger.Error("merge duplicate contacts", "error", err) h.writeContactServiceError(w, r, "merge duplicate contacts", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, result) apiresponse.WriteJSON(w, http.StatusOK, result)
@ -215,6 +260,10 @@ func (h *Handler) GetInteractionsByEmail(w http.ResponseWriter, r *http.Request)
func (h *Handler) GetContactInteractions(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetContactInteractions(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
contactPath := strings.TrimSuffix(chi.URLParam(r, "*"), "/") contactPath := strings.TrimSuffix(chi.URLParam(r, "*"), "/")
if verr := validateDeletePath(contactPath); verr != nil { if verr := validateDeletePath(contactPath); verr != nil {
apivalidate.WriteValidationError(w, r, verr) apivalidate.WriteValidationError(w, r, verr)
@ -233,7 +282,7 @@ func (h *Handler) GetContactInteractions(w http.ResponseWriter, r *http.Request)
limit = val limit = val
} }
result, err := h.svc.ContactInteractionsByPath(r.Context(), claims.Sub, contactPath, limit) result, err := h.svc.ContactInteractionsByPath(r.Context(), ncUser, contactPath, limit)
if err != nil { if err != nil {
if errors.Is(err, ErrContactEmailMissing) { if errors.Is(err, ErrContactEmailMissing) {
apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{ apivalidate.WriteValidationError(w, r, apivalidate.NewValidationError(apivalidate.FieldDetail{
@ -241,8 +290,7 @@ func (h *Handler) GetContactInteractions(w http.ResponseWriter, r *http.Request)
})) }))
return return
} }
h.logger.Error("contact interactions by path", "error", err) h.writeContactServiceError(w, r, "contact interactions by path", err)
apivalidate.WriteInternal(w, r)
return return
} }
apiresponse.WriteJSON(w, http.StatusOK, result) apiresponse.WriteJSON(w, http.StatusOK, result)
@ -250,14 +298,17 @@ func (h *Handler) GetContactInteractions(w http.ResponseWriter, r *http.Request)
func (h *Handler) DeleteContact(w http.ResponseWriter, r *http.Request) { func (h *Handler) DeleteContact(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
ncUser, ok := h.nextcloudUser(w, r, claims)
if !ok {
return
}
contactPath := chi.URLParam(r, "*") contactPath := chi.URLParam(r, "*")
if verr := validateDeletePath(contactPath); verr != nil { if verr := validateDeletePath(contactPath); verr != nil {
apivalidate.WriteValidationError(w, r, verr) apivalidate.WriteValidationError(w, r, verr)
return return
} }
if err := h.svc.DeleteContact(r.Context(), claims.Sub, contactPath); err != nil { if err := h.svc.DeleteContact(r.Context(), ncUser, contactPath); err != nil {
h.logger.Error("delete contact", "error", err) h.writeContactServiceError(w, r, "delete contact", err)
apivalidate.WriteInternal(w, r)
return return
} }
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)

View File

@ -12,6 +12,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/api/paginate" "github.com/ultisuite/ulti-backend/internal/api/paginate"
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
) )
@ -24,8 +25,15 @@ func NewService(nc *nextcloud.Client, db *pgxpool.Pool) *Service {
return &Service{nc: nc, db: db} return &Service{nc: nc, db: db}
} }
func (s *Service) EnsureNextcloudUser(ctx context.Context, claims *auth.Claims) (string, error) {
if s.nc == nil {
return "", fmt.Errorf("nextcloud unavailable")
}
return s.nc.EnsurePrincipal(ctx, claims.Email, claims.Sub, claims.Name)
}
func bookPath(userID, bookID string) string { func bookPath(userID, bookID string) string {
return "/remote.php/dav/addressbooks/users/" + userID + "/" + bookID + "/" return nextcloud.AddressBookPath(userID, bookID)
} }
func (s *Service) ListAddressBooks(ctx context.Context, userID string) ([]nextcloud.AddressBook, error) { func (s *Service) ListAddressBooks(ctx context.Context, userID string) ([]nextcloud.AddressBook, error) {

View File

@ -0,0 +1,40 @@
package mail
import (
"fmt"
"strings"
"github.com/google/uuid"
)
var systemFolderSlugs = map[string]string{
"inbox": "inbox",
"sent": "sent",
"drafts": "drafts",
"trash": "trash",
"archive": "archive",
"spam": "spam",
}
// folderFilterClause builds a SQL fragment that resolves a folder query param to
// mail_folders rows. System slugs (e.g. "inbox") match folder_type; UUIDs match
// folder id; everything else matches display name case-insensitively.
func folderFilterClause(folder string, argIdx int) (clause string, arg any, ok bool) {
folder = strings.TrimSpace(folder)
if folder == "" {
return "", nil, false
}
if _, err := uuid.Parse(folder); err == nil {
return fmt.Sprintf(" AND m.folder_id = $%d", argIdx), folder, true
}
if folderType, known := systemFolderSlugs[strings.ToLower(folder)]; known {
return fmt.Sprintf(
" AND m.folder_id IN (SELECT id FROM mail_folders WHERE folder_type = $%d AND account_id = m.account_id)",
argIdx,
), folderType, true
}
return fmt.Sprintf(
" AND m.folder_id IN (SELECT id FROM mail_folders WHERE LOWER(name) = LOWER($%d) AND account_id = m.account_id)",
argIdx,
), folder, true
}

View File

@ -0,0 +1,56 @@
package mail
import (
"testing"
"github.com/google/uuid"
)
func TestFolderFilterClause(t *testing.T) {
id := uuid.NewString()
tests := []struct {
name string
folder string
wantOK bool
wantArg any
wantSQL string
}{
{name: "empty", folder: "", wantOK: false},
{name: "inbox slug", folder: "inbox", wantOK: true, wantArg: "inbox", wantSQL: "folder_type"},
{name: "Inbox slug", folder: "Inbox", wantOK: true, wantArg: "inbox", wantSQL: "folder_type"},
{name: "uuid", folder: id, wantOK: true, wantArg: id, wantSQL: "m.folder_id = $1"},
{name: "custom name", folder: "Factures", wantOK: true, wantArg: "Factures", wantSQL: "LOWER(name)"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
clause, arg, ok := folderFilterClause(tc.folder, 1)
if ok != tc.wantOK {
t.Fatalf("ok = %v, want %v", ok, tc.wantOK)
}
if !tc.wantOK {
return
}
if arg != tc.wantArg {
t.Fatalf("arg = %v, want %v", arg, tc.wantArg)
}
if !stringsContains(clause, tc.wantSQL) {
t.Fatalf("clause = %q, want fragment %q", clause, tc.wantSQL)
}
})
}
}
func stringsContains(s, sub string) bool {
return len(s) >= len(sub) && (s == sub || len(sub) == 0 || indexOf(s, sub) >= 0)
}
func indexOf(s, sub string) int {
for i := 0; i+len(sub) <= len(s); i++ {
if s[i:i+len(sub)] == sub {
return i
}
}
return -1
}

View File

@ -22,10 +22,17 @@ import (
type Handler struct { type Handler struct {
svc ServiceAPI svc ServiceAPI
mailSender MailSender
logger *slog.Logger logger *slog.Logger
sendLimiter *sendguard.RateLimiter sendLimiter *sendguard.RateLimiter
oauth *mailoauth.Service oauth *mailoauth.Service
appURL string appURL string
accountSync AccountSyncTrigger
}
// SetAccountSync wires the IMAP sync worker for on-demand account sync.
func (h *Handler) SetAccountSync(trigger AccountSyncTrigger) {
h.accountSync = trigger
} }
func NewHandlerWithService(svc ServiceAPI) *Handler { func NewHandlerWithService(svc ServiceAPI) *Handler {
@ -44,8 +51,10 @@ func NewHandler(
sendLimiter *sendguard.RateLimiter, sendLimiter *sendguard.RateLimiter,
oauthSvc *mailoauth.Service, oauthSvc *mailoauth.Service,
appURL string, appURL string,
mailSender MailSender,
) *Handler { ) *Handler {
h := NewHandlerWithService(NewService(db, audit, credentialManager, objectStorage, attachmentsBucket)) h := NewHandlerWithService(NewService(db, audit, credentialManager, objectStorage, attachmentsBucket))
h.mailSender = mailSender
h.sendLimiter = sendLimiter h.sendLimiter = sendLimiter
h.oauth = oauthSvc h.oauth = oauthSvc
h.appURL = appURL h.appURL = appURL
@ -74,6 +83,8 @@ func (h *Handler) Routes() chi.Router {
r.Get("/accounts/{accountID}", h.GetAccount) r.Get("/accounts/{accountID}", h.GetAccount)
r.Put("/accounts/{accountID}", h.UpdateAccount) r.Put("/accounts/{accountID}", h.UpdateAccount)
r.Delete("/accounts/{accountID}", h.DeleteAccount) r.Delete("/accounts/{accountID}", h.DeleteAccount)
r.Post("/accounts/{accountID}/resanitize-bodies", h.ResanitizeAccountBodies)
r.Post("/accounts/{accountID}/sync", h.SyncAccountNow)
r.Get("/accounts/{accountID}/identities", h.ListIdentities) r.Get("/accounts/{accountID}/identities", h.ListIdentities)
r.Post("/accounts/{accountID}/identities", h.CreateIdentity) r.Post("/accounts/{accountID}/identities", h.CreateIdentity)
@ -104,6 +115,7 @@ func (h *Handler) Routes() chi.Router {
r.Get("/messages/{messageID}/attachments", h.ListMessageAttachments) r.Get("/messages/{messageID}/attachments", h.ListMessageAttachments)
r.Get("/messages/{messageID}/attachments/cid-map", h.MessageAttachmentCIDMap) r.Get("/messages/{messageID}/attachments/cid-map", h.MessageAttachmentCIDMap)
r.Post("/messages/{messageID}/attachments", h.UploadMessageAttachment) r.Post("/messages/{messageID}/attachments", h.UploadMessageAttachment)
r.Post("/messages/{messageID}/list-unsubscribe-mailto", h.SendListUnsubscribeMailto)
r.Get("/messages/{messageID}", h.GetMessage) r.Get("/messages/{messageID}", h.GetMessage)
r.Put("/messages/{messageID}/labels", h.UpdateLabels) r.Put("/messages/{messageID}/labels", h.UpdateLabels)
r.Put("/messages/{messageID}/flags", h.UpdateFlags) r.Put("/messages/{messageID}/flags", h.UpdateFlags)
@ -365,6 +377,38 @@ func (h *Handler) DeleteMessage(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent) w.WriteHeader(http.StatusNoContent)
} }
func (h *Handler) SendListUnsubscribeMailto(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
messageID := chi.URLParam(r, "messageID")
if h.mailSender == nil {
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, apiresponse.CodeInternal, "mail send unavailable", nil)
return
}
target, err := h.svc.SendMailtoListUnsubscribe(r.Context(), claims.Sub, messageID, h.mailSender)
if err != nil {
switch {
case errors.Is(err, ErrNotFound):
apivalidate.WriteNotFound(w, r, "not found")
case errors.Is(err, ErrListUnsubscribeNoMailto):
apiresponse.WriteError(w, r, http.StatusConflict, apiresponse.CodeInvalidRequest, err.Error(), nil)
case errors.Is(err, ErrListUnsubscribeUnavailable):
apiresponse.WriteError(w, r, http.StatusConflict, apiresponse.CodeInvalidRequest, "no mailto list-unsubscribe", nil)
default:
h.logger.Error("list-unsubscribe mailto send", "message_id", messageID, "error", err)
apivalidate.WriteInternal(w, r)
}
return
}
apiresponse.WriteJSON(w, http.StatusOK, map[string]any{
"sent": true,
"mailto": target.Address,
"subject": target.Subject,
})
}
func (h *Handler) GetThread(w http.ResponseWriter, r *http.Request) { func (h *Handler) GetThread(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
result, err := h.svc.GetThread(r.Context(), claims.Sub, chi.URLParam(r, "threadID")) result, err := h.svc.GetThread(r.Context(), claims.Sub, chi.URLParam(r, "threadID"))

View File

@ -0,0 +1,71 @@
package mail
import (
"context"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
"github.com/ultisuite/ulti-backend/internal/api/middleware"
)
// AccountSyncTrigger runs an immediate IMAP sync for one mail account.
type AccountSyncTrigger interface {
SyncAccountForUser(ctx context.Context, externalID, accountID string) error
}
func (h *Handler) ResanitizeAccountBodies(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
accountID := chi.URLParam(r, "accountID")
if d := validateAccountUUID(accountID); d != nil {
apivalidate.WriteNotFound(w, r, "not found")
return
}
result, err := h.svc.ResanitizeAccountBodies(r.Context(), claims.Sub, accountID)
if err != nil {
if errors.Is(err, ErrAccountNotFound) {
apivalidate.WriteNotFound(w, r, "not found")
return
}
h.logger.Error("resanitize account bodies", "account_id", accountID, "error", err)
apivalidate.WriteInternal(w, r)
return
}
apiresponse.WriteJSON(w, http.StatusOK, result)
}
func (h *Handler) SyncAccountNow(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context())
accountID := chi.URLParam(r, "accountID")
if d := validateAccountUUID(accountID); d != nil {
apivalidate.WriteNotFound(w, r, "not found")
return
}
if h.accountSync == nil {
apiresponse.WriteError(w, r, http.StatusServiceUnavailable, "sync_unavailable", "mail sync is not configured", nil)
return
}
if _, err := h.svc.GetAccount(r.Context(), claims.Sub, accountID); err != nil {
if errors.Is(err, ErrNotFound) {
apivalidate.WriteNotFound(w, r, "not found")
return
}
h.logger.Error("load account for sync", "account_id", accountID, "error", err)
apivalidate.WriteInternal(w, r)
return
}
if err := h.accountSync.SyncAccountForUser(r.Context(), claims.Sub, accountID); err != nil {
h.logger.Error("sync account", "account_id", accountID, "error", err)
apiresponse.WriteError(w, r, http.StatusBadGateway, "sync_failed", "imap sync failed", nil)
return
}
apiresponse.WriteJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}

View File

@ -2,6 +2,8 @@ package mail
import ( import (
"net/http" "net/http"
"net/url"
"strings"
"time" "time"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/api/apiresponse"
@ -12,7 +14,7 @@ import (
func (h *Handler) SearchMessages(w http.ResponseWriter, r *http.Request) { func (h *Handler) SearchMessages(w http.ResponseWriter, r *http.Request) {
claims := middleware.ClaimsFromContext(r.Context()) claims := middleware.ClaimsFromContext(r.Context())
params, err := query.ParseListRequest(r) params, err := query.ParseList(stripNonDateListRangeKeys(r.URL.Query()))
if err != nil { if err != nil {
apivalidate.WriteQueryError(w, r, err) apivalidate.WriteQueryError(w, r, err)
return return
@ -37,7 +39,7 @@ func parseMessageSearchFilter(r *http.Request) (MessageSearchFilter, *apivalidat
q := r.URL.Query() q := r.URL.Query()
filter := MessageSearchFilter{ filter := MessageSearchFilter{
Query: q.Get("q"), Query: q.Get("q"),
Sender: q.Get("from"), Sender: parseSearchSender(q),
Label: q.Get("label"), Label: q.Get("label"),
AccountID: q.Get("account_id"), AccountID: q.Get("account_id"),
} }
@ -84,3 +86,38 @@ func parseMessageSearchFilter(r *http.Request) (MessageSearchFilter, *apivalidat
} }
return filter, nil return filter, nil
} }
// stripNonDateListRangeKeys removes from/to when they are sender/recipient filters,
// not YYYY-MM-DD list date bounds (shared param names on /mail/search).
func stripNonDateListRangeKeys(values url.Values) url.Values {
out := values
clone := make(url.Values, len(values))
for k, vv := range values {
clone[k] = append([]string(nil), vv...)
}
out = clone
for _, key := range []string{"from", "to"} {
raw := strings.TrimSpace(out.Get(key))
if raw == "" {
continue
}
if _, err := query.ParseDate(raw); err != nil {
out.Del(key)
}
}
return out
}
func parseSearchSender(q url.Values) string {
if s := strings.TrimSpace(q.Get("sender")); s != "" {
return s
}
from := strings.TrimSpace(q.Get("from"))
if from == "" {
return ""
}
if _, err := query.ParseDate(from); err != nil {
return from
}
return ""
}

View File

@ -17,6 +17,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/listunsubscribe"
"github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/mail/rules"
) )
@ -130,6 +131,10 @@ func (f *fakeMailService) ListMessages(_ context.Context, externalID string, _ M
}, nil }, nil
} }
func (f *fakeMailService) SendMailtoListUnsubscribe(context.Context, string, string, MailSender) (*listunsubscribe.Mailto, error) {
return nil, ErrListUnsubscribeUnavailable
}
func (f *fakeMailService) GetMessage(_ context.Context, externalID, messageID string) (map[string]any, error) { func (f *fakeMailService) GetMessage(_ context.Context, externalID, messageID string) (map[string]any, error) {
if externalID != testExternalID { if externalID != testExternalID {
return nil, ErrUserNotProvisioned return nil, ErrUserNotProvisioned
@ -300,6 +305,9 @@ func (f *fakeMailService) CredentialForConnectionTest(context.Context, string, *
return credentials.Credential{AuthType: credentials.AuthPassword, Username: "u", Password: "p"}, nil return credentials.Credential{AuthType: credentials.AuthPassword, Username: "u", Password: "p"}, nil
} }
func (f *fakeMailService) DeleteAccount(context.Context, string, string) error { return nil } func (f *fakeMailService) DeleteAccount(context.Context, string, string) error { return nil }
func (f *fakeMailService) ResanitizeAccountBodies(context.Context, string, string) (ResanitizeBodiesResult, error) {
return ResanitizeBodiesResult{}, nil
}
func (f *fakeMailService) GetThread(context.Context, string, string) (map[string]any, error) { func (f *fakeMailService) GetThread(context.Context, string, string) (map[string]any, error) {
return map[string]any{"messages": []any{}}, nil return map[string]any{"messages": []any{}}, nil
} }
@ -322,7 +330,7 @@ func (f *fakeMailService) DeleteRule(_ context.Context, externalID, ruleID strin
return nil return nil
} }
func (f *fakeMailService) SimulateRule(_ context.Context, externalID string, req *simulateRuleRequest) (rules.SimulationResult, error) { func (f *fakeMailService) SimulateRule(_ context.Context, externalID string, req *simulateRuleRequest) (any, error) {
if externalID != testExternalID { if externalID != testExternalID {
return rules.SimulationResult{}, ErrUserNotProvisioned return rules.SimulationResult{}, ErrUserNotProvisioned
} }

View File

@ -0,0 +1,102 @@
package mail
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/ultisuite/ulti-backend/internal/mail/listunsubscribe"
"github.com/ultisuite/ulti-backend/internal/mail/smtp"
)
// MailSender sends immediately without outbox persistence.
type MailSender interface {
Send(ctx context.Context, req *smtp.SendRequest) error
}
var (
ErrListUnsubscribeUnavailable = errors.New("list-unsubscribe mailto not available")
ErrListUnsubscribeNoMailto = errors.New("list-unsubscribe has no mailto target")
)
type messageAuthInfo struct {
ListUnsubscribe string `json:"list_unsubscribe"`
}
// SendMailtoListUnsubscribe sends the RFC 2369 mailto unsubscribe without outbox or sent copy.
func (s *Service) SendMailtoListUnsubscribe(
ctx context.Context,
externalID, messageID string,
sender MailSender,
) (*listunsubscribe.Mailto, error) {
if sender == nil {
return nil, errors.New("mail sender not configured")
}
var accountID string
var authRaw []byte
err := s.db.QueryRow(ctx, `
SELECT m.account_id, m.auth_info
FROM messages m
JOIN mail_accounts ma ON m.account_id = ma.id
WHERE m.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2)
`, messageID, externalID).Scan(&accountID, &authRaw)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return nil, ErrNotFound
}
return nil, err
}
var auth messageAuthInfo
if len(authRaw) > 0 {
_ = json.Unmarshal(authRaw, &auth)
}
parsed := listunsubscribe.Parse(auth.ListUnsubscribe)
if parsed.Mailto == nil {
if parsed.HTTP != "" {
return nil, fmt.Errorf("%w: use http url", ErrListUnsubscribeNoMailto)
}
return nil, ErrListUnsubscribeUnavailable
}
fromEmail, err := s.resolveAccountFromEmail(ctx, accountID)
if err != nil {
return nil, err
}
m := parsed.Mailto
req := &smtp.SendRequest{
AccountID: accountID,
From: fromEmail,
To: []string{m.Address},
Subject: m.Subject,
BodyText: m.Body,
}
if err := sender.Send(ctx, req); err != nil {
return nil, err
}
return m, nil
}
func (s *Service) resolveAccountFromEmail(ctx context.Context, accountID string) (string, error) {
var fromEmail string
err := s.db.QueryRow(ctx, `
SELECT mi.email FROM mail_identities mi
JOIN mail_accounts ma ON mi.account_id = ma.id
WHERE ma.id = $1 AND mi.is_default = true
LIMIT 1
`, accountID).Scan(&fromEmail)
if err == nil && fromEmail != "" {
return fromEmail, nil
}
if err := s.db.QueryRow(ctx, `SELECT email FROM mail_accounts WHERE id = $1`, accountID).Scan(&fromEmail); err != nil {
return "", err
}
if fromEmail == "" {
return "", errors.New("account has no from address")
}
return fromEmail, nil
}

View File

@ -0,0 +1,76 @@
package mail
import (
"context"
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
)
const resanitizeBatchSize = 200
type ResanitizeBodiesResult struct {
Scanned int `json:"scanned"`
Updated int `json:"updated"`
}
// ResanitizeAccountBodies re-applies email HTML sanitization to stored messages.
func (s *Service) ResanitizeAccountBodies(ctx context.Context, externalID, accountID string) (ResanitizeBodiesResult, error) {
if err := s.verifyAccountOwnership(ctx, externalID, accountID); err != nil {
return ResanitizeBodiesResult{}, err
}
return s.ResanitizeAccountBodiesByID(ctx, accountID)
}
// ResanitizeAccountBodiesByID re-sanitizes messages without an ownership check (CLI/admin).
func (s *Service) ResanitizeAccountBodiesByID(ctx context.Context, accountID string) (ResanitizeBodiesResult, error) {
var result ResanitizeBodiesResult
var lastID string
for {
rows, err := s.db.Query(ctx, `
SELECT id, body_html
FROM messages
WHERE account_id = $1
AND body_html <> ''
AND ($2 = '' OR id > $2::uuid)
ORDER BY id
LIMIT $3
`, accountID, lastID, resanitizeBatchSize)
if err != nil {
return result, err
}
batchCount := 0
for rows.Next() {
var id, bodyHTML string
if err := rows.Scan(&id, &bodyHTML); err != nil {
rows.Close()
return result, err
}
batchCount++
result.Scanned++
lastID = id
sanitized := sanitize.SanitizeHTML(bodyHTML)
if sanitized == bodyHTML {
continue
}
if _, err := s.db.Exec(ctx, `
UPDATE messages SET body_html = $2, updated_at = NOW() WHERE id = $1
`, id, sanitized); err != nil {
rows.Close()
return result, err
}
result.Updated++
}
if err := rows.Err(); err != nil {
return result, err
}
rows.Close()
if batchCount < resanitizeBatchSize {
break
}
}
return result, nil
}

View File

@ -8,6 +8,7 @@ import (
"time" "time"
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/mail/imap"
) )
type MessageSearchFilter struct { type MessageSearchFilter struct {
@ -102,7 +103,7 @@ func (s *Service) SearchMessages(ctx context.Context, externalID string, filter
entry := map[string]any{ entry := map[string]any{
"id": id, "message_id": messageID, "subject": subject, "id": id, "message_id": messageID, "subject": subject,
"from": json.RawMessage(fromAddr), "to": json.RawMessage(toAddrs), "from": json.RawMessage(fromAddr), "to": json.RawMessage(toAddrs),
"date": date, "snippet": snippet, "flags": flags, "labels": labels, "date": date, "snippet": imap.RepairSnippet(snippet), "flags": flags, "labels": labels,
"has_attachments": hasAttachments, "has_attachments": hasAttachments,
} }
if threadID != nil { if threadID != nil {

View File

@ -19,6 +19,32 @@ func TestSearchMessages(t *testing.T) {
} }
} }
func TestSearchMessagesBySender(t *testing.T) {
svc := newFakeMailService()
router := newTestMailRouter(svc)
req := httptest.NewRequest(http.MethodGet, "/search?sender=alice@example.com", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
}
func TestSearchMessagesFromEmailParam(t *testing.T) {
svc := newFakeMailService()
router := newTestMailRouter(svc)
req := httptest.NewRequest(http.MethodGet, "/search?from=alice@example.com", nil)
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d; body = %s", rec.Code, http.StatusOK, rec.Body.String())
}
}
func TestSearchMessagesRequiresFilter(t *testing.T) { func TestSearchMessagesRequiresFilter(t *testing.T) {
svc := newFakeMailService() svc := newFakeMailService()
router := newTestMailRouter(svc) router := newTestMailRouter(svc)

View File

@ -12,6 +12,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/imap"
"github.com/ultisuite/ulti-backend/internal/mail/sanitize" "github.com/ultisuite/ulti-backend/internal/mail/sanitize"
"github.com/ultisuite/ulti-backend/internal/mail/storage" "github.com/ultisuite/ulti-backend/internal/mail/storage"
"github.com/ultisuite/ulti-backend/internal/mail/threading" "github.com/ultisuite/ulti-backend/internal/mail/threading"
@ -189,9 +190,9 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me
args = append(args, filter.AccountID) args = append(args, filter.AccountID)
argIdx++ argIdx++
} }
if filter.Folder != "" { if clause, arg, ok := folderFilterClause(filter.Folder, argIdx); ok {
baseQuery += fmt.Sprintf(" AND m.folder_id = (SELECT id FROM mail_folders WHERE name = $%d AND account_id = m.account_id LIMIT 1)", argIdx) baseQuery += clause
args = append(args, filter.Folder) args = append(args, arg)
argIdx++ argIdx++
} }
@ -202,7 +203,8 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me
} }
listQuery := ` listQuery := `
SELECT m.id, m.message_id, m.thread_id, m.subject, m.from_addr, m.to_addrs, m.date, m.snippet, m.flags, m.labels, m.has_attachments SELECT m.id, m.message_id, m.thread_id, m.subject, m.from_addr, m.to_addrs, m.date, m.snippet, m.flags, m.labels, m.has_attachments,
left(m.body_text, 8192), left(m.body_html, 8192)
` + baseQuery + fmt.Sprintf(" ORDER BY m.date DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1) ` + baseQuery + fmt.Sprintf(" ORDER BY m.date DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
args = append(args, params.Limit(), params.Offset()) args = append(args, params.Limit(), params.Offset())
@ -217,16 +219,26 @@ func (s *Service) ListMessages(ctx context.Context, externalID string, filter Me
var id, messageID, subject, snippet string var id, messageID, subject, snippet string
var threadID *string var threadID *string
var fromAddr, toAddrs []byte var fromAddr, toAddrs []byte
var bodyTextSample, bodyHTMLSample string
var date any var date any
var flags, labels []string var flags, labels []string
var hasAttachments bool var hasAttachments bool
if err := rows.Scan(&id, &messageID, &threadID, &subject, &fromAddr, &toAddrs, &date, &snippet, &flags, &labels, &hasAttachments); err != nil { if err := rows.Scan(&id, &messageID, &threadID, &subject, &fromAddr, &toAddrs, &date, &snippet, &flags, &labels, &hasAttachments, &bodyTextSample, &bodyHTMLSample); err != nil {
return MessagesList{}, err return MessagesList{}, err
} }
bodyTextSample, bodyHTMLSample = imap.RepairStoredBodies(bodyTextSample, bodyHTMLSample)
preview := imap.RepairSnippet(imap.SnippetFromBodies(bodyTextSample, bodyHTMLSample, 200))
if preview == "" {
preview = imap.RepairSnippet(snippet)
}
entry := map[string]any{ entry := map[string]any{
"id": id, "message_id": messageID, "subject": subject, "from": json.RawMessage(fromAddr), "id": id, "message_id": messageID,
"to": json.RawMessage(toAddrs), "date": date, "snippet": snippet, "subject": imap.RepairSubject(subject, bodyTextSample, bodyHTMLSample, nil),
"flags": flags, "labels": labels, "has_attachments": hasAttachments, "from": json.RawMessage(fromAddr),
"to": json.RawMessage(toAddrs),
"date": date,
"snippet": preview,
"flags": flags, "labels": labels, "has_attachments": hasAttachments,
} }
if threadID != nil { if threadID != nil {
entry["thread_id"] = *threadID entry["thread_id"] = *threadID
@ -255,6 +267,8 @@ func (s *Service) GetMessage(ctx context.Context, externalID, messageID string)
From []byte From []byte
To []byte To []byte
Cc []byte Cc []byte
ReplyTo []byte
AuthInfo []byte
Date any Date any
Text string Text string
HTML string HTML string
@ -263,13 +277,13 @@ func (s *Service) GetMessage(ctx context.Context, externalID, messageID string)
} }
err := s.db.QueryRow(ctx, ` err := s.db.QueryRow(ctx, `
SELECT m.id, m.message_id, m.thread_id, m.in_reply_to, m.references_header, SELECT m.id, m.message_id, m.thread_id, m.in_reply_to, m.references_header,
m.subject, m.from_addr, m.to_addrs, m.cc_addrs, m.date, m.subject, m.from_addr, m.to_addrs, m.cc_addrs, m.reply_to, m.auth_info, m.date,
m.body_text, m.body_html, m.flags, m.labels m.body_text, m.body_html, m.flags, m.labels
FROM messages m JOIN mail_accounts ma ON m.account_id = ma.id FROM messages m JOIN mail_accounts ma ON m.account_id = ma.id
WHERE m.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) WHERE m.id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2)
`, messageID, externalID).Scan( `, messageID, externalID).Scan(
&msg.ID, &msg.MessageID, &msg.ThreadID, &msg.InReplyTo, &msg.References, &msg.ID, &msg.MessageID, &msg.ThreadID, &msg.InReplyTo, &msg.References,
&msg.Subject, &msg.From, &msg.To, &msg.Cc, &msg.Date, &msg.Subject, &msg.From, &msg.To, &msg.Cc, &msg.ReplyTo, &msg.AuthInfo, &msg.Date,
&msg.Text, &msg.HTML, &msg.Flags, &msg.Labels, &msg.Text, &msg.HTML, &msg.Flags, &msg.Labels,
) )
if err != nil { if err != nil {
@ -278,10 +292,20 @@ func (s *Service) GetMessage(ctx context.Context, externalID, messageID string)
} }
return nil, err return nil, err
} }
bodyText, bodyHTML := imap.RepairStoredBodies(msg.Text, msg.HTML)
subject := imap.RepairSubject(msg.Subject, bodyText, bodyHTML, nil)
repairedSnippet := imap.RepairSnippet(imap.SnippetFromBodies(bodyText, bodyHTML, 200))
if bodyText != msg.Text || bodyHTML != msg.HTML || subject != msg.Subject {
_, _ = s.db.Exec(ctx, `
UPDATE messages SET body_text = $1, body_html = $2, snippet = $3, subject = $4, updated_at = NOW()
WHERE id = $5
`, bodyText, bodyHTML, repairedSnippet, subject, msg.ID)
}
out := map[string]any{ out := map[string]any{
"id": msg.ID, "message_id": msg.MessageID, "subject": msg.Subject, "id": msg.ID, "message_id": msg.MessageID, "subject": subject,
"from": json.RawMessage(msg.From), "to": json.RawMessage(msg.To), "cc": json.RawMessage(msg.Cc), "from": json.RawMessage(msg.From), "to": json.RawMessage(msg.To), "cc": json.RawMessage(msg.Cc),
"date": msg.Date, "body_text": msg.Text, "body_html": sanitize.SanitizeHTML(msg.HTML), "reply_to": json.RawMessage(msg.ReplyTo), "auth_info": json.RawMessage(msg.AuthInfo),
"date": msg.Date, "body_text": bodyText, "body_html": sanitize.SanitizeHTML(bodyHTML),
"flags": msg.Flags, "labels": msg.Labels, "flags": msg.Flags, "labels": msg.Labels,
"in_reply_to": msg.InReplyTo, "references": msg.References, "in_reply_to": msg.InReplyTo, "references": msg.References,
} }
@ -351,7 +375,7 @@ func (s *Service) DeleteMessage(ctx context.Context, externalID, messageID strin
func (s *Service) GetThread(ctx context.Context, externalID, threadID string) (map[string]any, error) { func (s *Service) GetThread(ctx context.Context, externalID, threadID string) (map[string]any, error) {
rows, err := s.db.Query(ctx, ` rows, err := s.db.Query(ctx, `
SELECT m.id, m.subject, m.from_addr, m.date, m.snippet, m.flags SELECT m.id, m.subject, m.from_addr, m.to_addrs, m.cc_addrs, m.date, m.snippet, m.flags, m.labels
FROM messages m JOIN mail_accounts ma ON m.account_id = ma.id FROM messages m JOIN mail_accounts ma ON m.account_id = ma.id
WHERE m.thread_id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2) WHERE m.thread_id = $1 AND ma.user_id = (SELECT id FROM users WHERE external_id = $2)
ORDER BY m.date ASC ORDER BY m.date ASC
@ -364,15 +388,16 @@ func (s *Service) GetThread(ctx context.Context, externalID, threadID string) (m
messages := make([]map[string]any, 0) messages := make([]map[string]any, 0)
for rows.Next() { for rows.Next() {
var id, subject, snippet string var id, subject, snippet string
var from []byte var from, toAddrs, ccAddrs []byte
var date any var date any
var flags []string var flags, labels []string
if err := rows.Scan(&id, &subject, &from, &date, &snippet, &flags); err != nil { if err := rows.Scan(&id, &subject, &from, &toAddrs, &ccAddrs, &date, &snippet, &flags, &labels); err != nil {
return nil, err return nil, err
} }
messages = append(messages, map[string]any{ messages = append(messages, map[string]any{
"id": id, "subject": subject, "from": json.RawMessage(from), "id": id, "subject": subject, "from": json.RawMessage(from),
"date": date, "snippet": snippet, "flags": flags, "to": json.RawMessage(toAddrs), "cc": json.RawMessage(ccAddrs),
"date": date, "snippet": snippet, "flags": flags, "labels": labels,
}) })
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@ -487,7 +512,7 @@ func (s *Service) ListRules(ctx context.Context, externalID string, params query
} }
rows, err := s.db.Query(ctx, ` rows, err := s.db.Query(ctx, `
SELECT id, name, priority, conditions, actions, is_active, match_count SELECT id, name, priority, conditions, actions, is_active, match_count, rule_kind, workflow
FROM mail_rules WHERE user_id = (SELECT id FROM users WHERE external_id = $1) FROM mail_rules WHERE user_id = (SELECT id FROM users WHERE external_id = $1)
ORDER BY priority ASC ORDER BY priority ASC
LIMIT $2 OFFSET $3 LIMIT $2 OFFSET $3
@ -502,15 +527,18 @@ func (s *Service) ListRules(ctx context.Context, externalID string, params query
var id, name string var id, name string
var priority int var priority int
var conditions, actions []byte var conditions, actions []byte
var workflow []byte
var ruleKind string
var isActive bool var isActive bool
var matchCount int64 var matchCount int64
if err := rows.Scan(&id, &name, &priority, &conditions, &actions, &isActive, &matchCount); err != nil { if err := rows.Scan(&id, &name, &priority, &conditions, &actions, &isActive, &matchCount, &ruleKind, &workflow); err != nil {
return RulesList{}, err return RulesList{}, err
} }
rules = append(rules, map[string]any{ rules = append(rules, map[string]any{
"id": id, "name": name, "priority": priority, "id": id, "name": name, "priority": priority,
"conditions": json.RawMessage(conditions), "actions": json.RawMessage(actions), "conditions": json.RawMessage(conditions), "actions": json.RawMessage(actions),
"is_active": isActive, "match_count": matchCount, "is_active": isActive, "match_count": matchCount,
"rule_kind": ruleKind, "workflow": json.RawMessage(workflow),
}) })
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
@ -525,7 +553,18 @@ func (s *Service) ListRules(ctx context.Context, externalID string, params query
func (s *Service) CreateRule(ctx context.Context, userID string, req *createRuleRequest) (string, error) { func (s *Service) CreateRule(ctx context.Context, userID string, req *createRuleRequest) (string, error) {
condJSON, _ := json.Marshal(req.Conditions) condJSON, _ := json.Marshal(req.Conditions)
if req.Conditions == nil {
condJSON = []byte("[]")
}
actJSON, _ := json.Marshal(req.Actions) actJSON, _ := json.Marshal(req.Actions)
if req.Actions == nil {
actJSON = []byte("[]")
}
wfJSON, _ := json.Marshal(req.Workflow)
ruleKind := req.RuleKind
if ruleKind == "" {
ruleKind = "rule"
}
if req.AccountID != "" { if req.AccountID != "" {
var exists bool var exists bool
@ -542,10 +581,10 @@ func (s *Service) CreateRule(ctx context.Context, userID string, req *createRule
var id string var id string
err := s.db.QueryRow(ctx, ` err := s.db.QueryRow(ctx, `
INSERT INTO mail_rules (user_id, account_id, name, priority, conditions, actions) INSERT INTO mail_rules (user_id, account_id, name, priority, conditions, actions, rule_kind, workflow)
VALUES ($1, $2, $3, $4, $5, $6) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
RETURNING id RETURNING id
`, userID, nilIfEmpty(req.AccountID), req.Name, req.Priority, condJSON, actJSON).Scan(&id) `, userID, nilIfEmpty(req.AccountID), req.Name, req.Priority, condJSON, actJSON, ruleKind, wfJSON).Scan(&id)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -554,12 +593,23 @@ func (s *Service) CreateRule(ctx context.Context, userID string, req *createRule
func (s *Service) UpdateRule(ctx context.Context, externalID, ruleID string, req *updateRuleRequest) error { func (s *Service) UpdateRule(ctx context.Context, externalID, ruleID string, req *updateRuleRequest) error {
condJSON, _ := json.Marshal(req.Conditions) condJSON, _ := json.Marshal(req.Conditions)
if req.Conditions == nil {
condJSON = []byte("[]")
}
actJSON, _ := json.Marshal(req.Actions) actJSON, _ := json.Marshal(req.Actions)
if req.Actions == nil {
actJSON = []byte("[]")
}
wfJSON, _ := json.Marshal(req.Workflow)
ruleKind := req.RuleKind
if ruleKind == "" {
ruleKind = "rule"
}
result, err := s.db.Exec(ctx, ` result, err := s.db.Exec(ctx, `
UPDATE mail_rules SET name=$1, priority=$2, is_active=$3, conditions=$4, actions=$5, updated_at=NOW() UPDATE mail_rules SET name=$1, priority=$2, is_active=$3, conditions=$4, actions=$5, rule_kind=$6, workflow=$7, updated_at=NOW()
WHERE id=$6 AND user_id=(SELECT id FROM users WHERE external_id=$7) WHERE id=$8 AND user_id=(SELECT id FROM users WHERE external_id=$9)
`, req.Name, req.Priority, req.IsActive, condJSON, actJSON, ruleID, externalID) `, req.Name, req.Priority, req.IsActive, condJSON, actJSON, ruleKind, wfJSON, ruleID, externalID)
if err != nil { if err != nil {
return err return err
} }

View File

@ -7,7 +7,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/mail/listunsubscribe"
) )
// ServiceAPI is the mail handler service boundary. *Service implements it in production. // ServiceAPI is the mail handler service boundary. *Service implements it in production.
@ -26,8 +26,10 @@ type ServiceAPI interface {
UpdateAccount(ctx context.Context, externalID, accountID string, req *updateAccountRequest) error UpdateAccount(ctx context.Context, externalID, accountID string, req *updateAccountRequest) error
CredentialForConnectionTest(ctx context.Context, externalID string, req *testAccountRequest) (credentials.Credential, error) CredentialForConnectionTest(ctx context.Context, externalID string, req *testAccountRequest) (credentials.Credential, error)
DeleteAccount(ctx context.Context, externalID, accountID string) error DeleteAccount(ctx context.Context, externalID, accountID string) error
ResanitizeAccountBodies(ctx context.Context, externalID, accountID string) (ResanitizeBodiesResult, error)
ListMessages(ctx context.Context, externalID string, filter MessageListFilter, params query.ListParams) (MessagesList, error) ListMessages(ctx context.Context, externalID string, filter MessageListFilter, params query.ListParams) (MessagesList, error)
GetMessage(ctx context.Context, externalID, messageID string) (map[string]any, error) GetMessage(ctx context.Context, externalID, messageID string) (map[string]any, error)
SendMailtoListUnsubscribe(ctx context.Context, externalID, messageID string, sender MailSender) (*listunsubscribe.Mailto, error)
UpdateLabels(ctx context.Context, externalID, messageID string, labels []string) error UpdateLabels(ctx context.Context, externalID, messageID string, labels []string) error
UpdateFlags(ctx context.Context, externalID, messageID string, flags []string) error UpdateFlags(ctx context.Context, externalID, messageID string, flags []string) error
DeleteMessage(ctx context.Context, externalID, messageID string) error DeleteMessage(ctx context.Context, externalID, messageID string) error
@ -45,7 +47,7 @@ type ServiceAPI interface {
CreateRule(ctx context.Context, userID string, req *createRuleRequest) (string, error) CreateRule(ctx context.Context, userID string, req *createRuleRequest) (string, error)
UpdateRule(ctx context.Context, externalID, ruleID string, req *updateRuleRequest) error UpdateRule(ctx context.Context, externalID, ruleID string, req *updateRuleRequest) error
DeleteRule(ctx context.Context, externalID, ruleID string) error DeleteRule(ctx context.Context, externalID, ruleID string) error
SimulateRule(ctx context.Context, externalID string, req *simulateRuleRequest) (rules.SimulationResult, error) SimulateRule(ctx context.Context, externalID string, req *simulateRuleRequest) (any, error)
ListWebhooks(ctx context.Context, externalID string, params query.ListParams) (WebhooksList, error) ListWebhooks(ctx context.Context, externalID string, params query.ListParams) (WebhooksList, error)
CreateWebhook(ctx context.Context, externalID string, req *createWebhookRequest, method string, maxRetries int) (string, error) CreateWebhook(ctx context.Context, externalID string, req *createWebhookRequest, method string, maxRetries int) (string, error)
UpdateWebhook(ctx context.Context, externalID, webhookID string, req *updateWebhookRequest, method string, maxRetries int) error UpdateWebhook(ctx context.Context, externalID, webhookID string, req *updateWebhookRequest, method string, maxRetries int) error

View File

@ -10,12 +10,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/mail/rules"
) )
func (s *Service) SimulateRule(ctx context.Context, externalID string, req *simulateRuleRequest) (rules.SimulationResult, error) { func (s *Service) SimulateRule(ctx context.Context, externalID string, req *simulateRuleRequest) (any, error) {
conditions, actions, err := s.resolveSimulateRule(ctx, externalID, req)
if err != nil {
return rules.SimulationResult{}, err
}
msg := &rules.Message{ msg := &rules.Message{
ID: "simulation", ID: "simulation",
From: req.Message.From, From: req.Message.From,
@ -23,38 +18,74 @@ func (s *Service) SimulateRule(ctx context.Context, externalID string, req *simu
Subject: req.Message.Subject, Subject: req.Message.Subject,
BodyText: req.Message.BodyText, BodyText: req.Message.BodyText,
HasAttachments: req.Message.HasAttachments, HasAttachments: req.Message.HasAttachments,
Labels: req.Message.Labels,
} }
engine := rules.NewEngine(s.db) engine := rules.NewEngine(s.db)
wf, conditions, actions, err := s.resolveSimulateRulePayload(ctx, externalID, req)
if err != nil {
return nil, err
}
if wf != nil && len(wf.Nodes) > 0 {
var userID string
_ = s.db.QueryRow(ctx, `SELECT id FROM users WHERE external_id = $1`, externalID).Scan(&userID)
return engine.SimulateWorkflow(ctx, userID, wf, msg, &rules.EventContext{Type: rules.TriggerMessageReceived}), nil
}
return engine.SimulateRule(ctx, conditions, actions, msg), nil return engine.SimulateRule(ctx, conditions, actions, msg), nil
} }
func (s *Service) resolveSimulateRule(ctx context.Context, externalID string, req *simulateRuleRequest) ([]rules.Condition, []rules.Action, error) { func (s *Service) resolveSimulateRulePayload(ctx context.Context, externalID string, req *simulateRuleRequest) (*rules.Workflow, []rules.Condition, []rules.Action, error) {
if req.RuleID != "" { if req.RuleID != "" {
var condJSON, actJSON []byte var condJSON, actJSON, wfJSON []byte
err := s.db.QueryRow(ctx, ` err := s.db.QueryRow(ctx, `
SELECT conditions, actions SELECT conditions, actions, workflow
FROM mail_rules FROM mail_rules
WHERE id = $1 AND user_id = (SELECT id FROM users WHERE external_id = $2) WHERE id = $1 AND user_id = (SELECT id FROM users WHERE external_id = $2)
`, req.RuleID, externalID).Scan(&condJSON, &actJSON) `, req.RuleID, externalID).Scan(&condJSON, &actJSON, &wfJSON)
if err != nil { if err != nil {
if errors.Is(err, pgx.ErrNoRows) { if errors.Is(err, pgx.ErrNoRows) {
return nil, nil, ErrNotFound return nil, nil, nil, ErrNotFound
} }
return nil, nil, err return nil, nil, nil, err
}
wf, err := rules.ParseWorkflow(wfJSON)
if err != nil {
return nil, nil, nil, err
}
if wf != nil && len(wf.Nodes) > 0 {
return wf, nil, nil, nil
}
conditions, actions, err := unmarshalRuleConditionsActions(condJSON, actJSON)
return nil, conditions, actions, err
}
if req.Rule.Workflow != nil {
wfJSON, err := json.Marshal(req.Rule.Workflow)
if err != nil {
return nil, nil, nil, err
}
wf, err := rules.ParseWorkflow(wfJSON)
if err != nil {
return nil, nil, nil, err
}
if wf != nil && len(wf.Nodes) > 0 {
return wf, nil, nil, nil
} }
return unmarshalRuleConditionsActions(condJSON, actJSON)
} }
condJSON, err := json.Marshal(req.Rule.Conditions) condJSON, err := json.Marshal(req.Rule.Conditions)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, nil, err
} }
actJSON, err := json.Marshal(req.Rule.Actions) actJSON, err := json.Marshal(req.Rule.Actions)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, nil, err
} }
return unmarshalRuleConditionsActions(condJSON, actJSON) conditions, actions, err := unmarshalRuleConditionsActions(condJSON, actJSON)
return nil, conditions, actions, err
} }
func unmarshalRuleConditionsActions(condJSON, actJSON []byte) ([]rules.Condition, []rules.Action, error) { func unmarshalRuleConditionsActions(condJSON, actJSON []byte) ([]rules.Condition, []rules.Action, error) {

View File

@ -486,8 +486,10 @@ type createRuleRequest struct {
Name string `json:"name"` Name string `json:"name"`
AccountID string `json:"account_id"` AccountID string `json:"account_id"`
Priority int `json:"priority"` Priority int `json:"priority"`
RuleKind string `json:"rule_kind"`
Conditions any `json:"conditions"` Conditions any `json:"conditions"`
Actions any `json:"actions"` Actions any `json:"actions"`
Workflow any `json:"workflow"`
} }
func validateCreateRule(req *createRuleRequest) *apivalidate.ValidationError { func validateCreateRule(req *createRuleRequest) *apivalidate.ValidationError {
@ -497,11 +499,17 @@ func validateCreateRule(req *createRuleRequest) *apivalidate.ValidationError {
} else if len(req.Name) > maxRuleName { } else if len(req.Name) > maxRuleName {
details = append(details, apivalidate.FieldDetail{Field: "name", Message: "too long"}) details = append(details, apivalidate.FieldDetail{Field: "name", Message: "too long"})
} }
if req.Conditions == nil { hasWorkflow := req.Workflow != nil
details = append(details, apivalidate.FieldDetail{Field: "conditions", Message: "required"}) if !hasWorkflow {
if req.Conditions == nil {
details = append(details, apivalidate.FieldDetail{Field: "conditions", Message: "required"})
}
if req.Actions == nil {
details = append(details, apivalidate.FieldDetail{Field: "actions", Message: "required"})
}
} }
if req.Actions == nil { if req.RuleKind != "" && req.RuleKind != "rule" && req.RuleKind != "function" {
details = append(details, apivalidate.FieldDetail{Field: "actions", Message: "required"}) details = append(details, apivalidate.FieldDetail{Field: "rule_kind", Message: "invalid"})
} }
if len(details) == 0 { if len(details) == 0 {
return nil return nil
@ -513,8 +521,10 @@ type updateRuleRequest struct {
Name string `json:"name"` Name string `json:"name"`
Priority int `json:"priority"` Priority int `json:"priority"`
IsActive bool `json:"is_active"` IsActive bool `json:"is_active"`
RuleKind string `json:"rule_kind"`
Conditions any `json:"conditions"` Conditions any `json:"conditions"`
Actions any `json:"actions"` Actions any `json:"actions"`
Workflow any `json:"workflow"`
} }
type simulateRuleSampleMessage struct { type simulateRuleSampleMessage struct {
@ -523,11 +533,13 @@ type simulateRuleSampleMessage struct {
Subject string `json:"subject"` Subject string `json:"subject"`
BodyText string `json:"body_text"` BodyText string `json:"body_text"`
HasAttachments bool `json:"has_attachments"` HasAttachments bool `json:"has_attachments"`
Labels []string `json:"labels,omitempty"`
} }
type simulateRuleInlineRule struct { type simulateRuleInlineRule struct {
Conditions any `json:"conditions"` Conditions any `json:"conditions"`
Actions any `json:"actions"` Actions any `json:"actions"`
Workflow any `json:"workflow"`
} }
type simulateRuleRequest struct { type simulateRuleRequest struct {
@ -550,10 +562,10 @@ func validateSimulateRule(req *simulateRuleRequest) *apivalidate.ValidationError
details = append(details, apivalidate.FieldDetail{Field: "rule_id", Message: "rule_id or rule required"}) details = append(details, apivalidate.FieldDetail{Field: "rule_id", Message: "rule_id or rule required"})
} }
if hasInlineRule { if hasInlineRule {
if req.Rule.Conditions == nil { if req.Rule.Conditions == nil && req.Rule.Workflow == nil {
details = append(details, apivalidate.FieldDetail{Field: "rule.conditions", Message: "required"}) details = append(details, apivalidate.FieldDetail{Field: "rule.conditions", Message: "required"})
} }
if req.Rule.Actions == nil { if req.Rule.Actions == nil && req.Rule.Workflow == nil {
details = append(details, apivalidate.FieldDetail{Field: "rule.actions", Message: "required"}) details = append(details, apivalidate.FieldDetail{Field: "rule.actions", Message: "required"})
} }
} }
@ -570,12 +582,15 @@ func validateUpdateRule(req *updateRuleRequest) *apivalidate.ValidationError {
} else if len(req.Name) > maxRuleName { } else if len(req.Name) > maxRuleName {
details = append(details, apivalidate.FieldDetail{Field: "name", Message: "too long"}) details = append(details, apivalidate.FieldDetail{Field: "name", Message: "too long"})
} }
if req.Conditions == nil { if req.Conditions == nil && req.Workflow == nil {
details = append(details, apivalidate.FieldDetail{Field: "conditions", Message: "required"}) details = append(details, apivalidate.FieldDetail{Field: "conditions", Message: "required"})
} }
if req.Actions == nil { if req.Actions == nil && req.Workflow == nil {
details = append(details, apivalidate.FieldDetail{Field: "actions", Message: "required"}) details = append(details, apivalidate.FieldDetail{Field: "actions", Message: "required"})
} }
if req.RuleKind != "" && req.RuleKind != "rule" && req.RuleKind != "function" {
details = append(details, apivalidate.FieldDetail{Field: "rule_kind", Message: "invalid"})
}
if len(details) == 0 { if len(details) == 0 {
return nil return nil
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/apiresponse" "github.com/ultisuite/ulti-backend/internal/api/apiresponse"
"github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/permission"
"github.com/ultisuite/ulti-backend/internal/securityaudit" "github.com/ultisuite/ulti-backend/internal/securityaudit"
"github.com/ultisuite/ulti-backend/internal/users" "github.com/ultisuite/ulti-backend/internal/users"
) )
@ -71,6 +72,7 @@ func Auth(verifier *auth.Verifier, db *pgxpool.Pool, audit *securityaudit.Logger
} }
return return
} }
claims.Groups = permission.WithSuiteDefaults(claims.Groups)
if db != nil { if db != nil {
if _, err := users.EnsureUser(r.Context(), db, claims); err != nil { if _, err := users.EnsureUser(r.Context(), db, claims); err != nil {

View File

@ -12,7 +12,7 @@ import (
const ( const (
DefaultPage = 1 DefaultPage = 1
DefaultPageSize = 50 DefaultPageSize = 50
MaxPageSize = 200 MaxPageSize = 500
dateLayout = "2006-01-02" dateLayout = "2006-01-02"
) )

View File

@ -92,7 +92,7 @@ func TestParseList_invalidPageSize(t *testing.T) {
}{ }{
{"zero", "0"}, {"zero", "0"},
{"negative", "-5"}, {"negative", "-5"},
{"too_large", "201"}, {"too_large", "501"},
{"non_numeric", "large"}, {"non_numeric", "large"},
} }
@ -138,7 +138,7 @@ func TestParseList_invalidDates(t *testing.T) {
func TestParseList_multipleErrors(t *testing.T) { func TestParseList_multipleErrors(t *testing.T) {
_, err := ParseList(url.Values{ _, err := ParseList(url.Values{
"page": {"0"}, "page": {"0"},
"page_size": {"500"}, "page_size": {"501"},
"from": {"bad-date"}, "from": {"bad-date"},
}) })
var verr *ValidationError var verr *ValidationError

View File

@ -0,0 +1,187 @@
package imap
import (
"io"
"mime/quotedprintable"
"strings"
"unicode"
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
)
const minBareBase64Len = 24
// RepairStoredBodies fixes bodies stored as raw MIME, quoted-printable, or base64.
func RepairStoredBodies(text, html string) (string, string) {
text, html = repairRawMIME(text, html)
text = decodeBareQuotedPrintableIfNeeded(text)
html = decodeBareQuotedPrintableIfNeeded(html)
text = decodeBareBase64IfNeeded(text)
html = decodeBareBase64IfNeeded(html)
text = stripPlainTextPreheaderPadding(text)
return text, html
}
func repairRawMIME(text, html string) (string, string) {
if !looksLikeRawMIME(text) && !looksLikeRawMIME(html) {
return text, html
}
raw := text
if raw == "" {
raw = html
}
t, h := parseBody([]byte(raw))
if t == "" && h == "" {
return text, html
}
if looksLikeRawMIME(t) || looksLikeRawMIME(h) {
return text, html
}
return t, h
}
// RepairSnippet fixes list/search previews stored as undecoded base64 or raw MIME.
func RepairSnippet(snippet string) string {
if snippet == "" {
return snippet
}
if decoded := decodeBareQuotedPrintableIfNeeded(snippet); decoded != snippet {
snippet = decoded
}
if decoded := decodeBareBase64IfNeeded(snippet); decoded != snippet {
snippet = decoded
}
snippet = stripPlainTextPreheaderPadding(snippet)
if looksLikeRawMIME(snippet) {
t, h, ok := parseEmbeddedMIME([]byte(snippet))
if ok {
return SnippetFromBodies(t, h, 200)
}
}
return snippet
}
// SnippetFromBodies builds a short preview from repaired plain/html bodies.
func SnippetFromBodies(text, html string, maxLen int) string {
text = strings.TrimSpace(text)
if text != "" {
return truncate(text, maxLen)
}
html = strings.TrimSpace(stripHTMLForSnippet(html))
if html != "" {
return truncate(html, maxLen)
}
return ""
}
func stripPlainTextPreheaderPadding(text string) string {
return sanitize.StripInvisibleTextRuns(text)
}
func stripHTMLForSnippet(html string) string {
if html == "" {
return ""
}
html = sanitize.StripHiddenEmailHTML(html)
var b strings.Builder
inTag := false
for _, r := range html {
switch {
case r == '<':
inTag = true
case r == '>':
inTag = false
case !inTag && r != '\r':
if r == '\n' {
if b.Len() > 0 && b.String()[b.Len()-1] != ' ' {
b.WriteRune(' ')
}
} else if !unicode.IsControl(r) {
b.WriteRune(r)
}
}
}
return sanitize.StripInvisibleTextRuns(strings.Join(strings.Fields(b.String()), " "))
}
func decodeBareQuotedPrintableIfNeeded(s string) string {
if s == "" || !looksLikeQuotedPrintable(s) {
return s
}
decoded, err := io.ReadAll(quotedprintable.NewReader(strings.NewReader(s)))
if err != nil || len(decoded) == 0 || !isMostlyReadableText(decoded) {
return s
}
return string(decoded)
}
func looksLikeQuotedPrintable(s string) bool {
if strings.Contains(s, "=\r\n") || strings.Contains(s, "=\n") {
return true
}
if strings.Contains(s, "=3D") || strings.Contains(s, "=C3=") || strings.Contains(s, "=E2=") {
return true
}
return len(qpHexSeqRE.FindAllString(s, -1)) >= 3
}
func decodeBareBase64IfNeeded(s string) string {
if s == "" {
return s
}
trimmed := strings.TrimSpace(s)
if len(trimmed) < minBareBase64Len {
return s
}
clean := stripBase64Whitespace(trimmed)
if !isLikelyBase64(clean) {
return s
}
decoded, err := decodeBase64Body([]byte(clean))
if err != nil || len(decoded) == 0 || !isMostlyReadableText(decoded) {
return s
}
return string(decoded)
}
func stripBase64Whitespace(s string) string {
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
switch r {
case '\r', '\n', ' ', '\t':
continue
default:
b.WriteRune(r)
}
}
return b.String()
}
func isLikelyBase64(s string) bool {
if len(s) < minBareBase64Len || len(s)%4 != 0 {
return false
}
for _, r := range s {
switch {
case r >= 'A' && r <= 'Z', r >= 'a' && r <= 'z', r >= '0' && r <= '9', r == '+', r == '/', r == '=':
continue
default:
return false
}
}
return strings.Contains(s, "=") || len(s) >= 32
}
func isMostlyReadableText(b []byte) bool {
if len(b) == 0 {
return false
}
printable := 0
for _, c := range b {
if c == '\n' || c == '\r' || c == '\t' || (c >= 32 && c < 127) || c >= 0xc0 {
printable++
}
}
return float64(printable)/float64(len(b)) >= 0.85
}

View File

@ -0,0 +1,100 @@
package imap
import (
"strings"
"testing"
)
func TestDecodeBareBase64IfNeeded_samsungMessage(t *testing.T) {
const encoded = "U0FNU1VORwpSw6lzZXJ2w6kgYXV4IHByb2Zlc3Npb25uZWxzCgrigIoKVm9zIMOpcXVpcGVzIG9u\r\n" +
"dCBiZXNvaW4KZGUgc29sdXRpb25zIG1vYmlsZXMKZXQgcm9idXN0ZXMuCgpTYW1zdW5nIFBybyBy\r\n" +
"w6lwb25kIGF1eCBtw6l0aWVycyBkZSBsYSBjb25zdHJ1"
decoded := decodeBareBase64IfNeeded(encoded)
if decoded == encoded {
t.Fatal("expected base64 decode")
}
if !strings.HasPrefix(decoded, "SAMSUNG") {
t.Fatalf("decoded = %q", decoded)
}
if !strings.Contains(decoded, "professionnels") {
t.Fatalf("decoded = %q, want utf-8 text", decoded)
}
}
func TestDecodeBareQuotedPrintableIfNeeded_frenchMarketing(t *testing.T) {
const qp = "Hello = Eliott,\n\nNous pouvons faire appara=C3=AEtre votre marque en premi=C3=A8re =\n" +
" position dans les Google Suggests"
decoded := decodeBareQuotedPrintableIfNeeded(qp)
if decoded == qp {
t.Fatal("expected quoted-printable decode")
}
if !strings.Contains(decoded, "apparaître") {
t.Fatalf("decoded = %q, want apparaître", decoded)
}
if strings.Contains(decoded, "=C3=") {
t.Fatalf("still contains qp escapes: %q", decoded)
}
}
func TestRepairStoredBodies_quotedPrintableHTML(t *testing.T) {
const qpHTML = `<html><body><div style=3D"color: rgb(65, 65, 65);">appara=C3=AEtre</div></body></html>`
_, html := RepairStoredBodies("", qpHTML)
if !strings.Contains(html, "apparaître") {
t.Fatalf("html = %q", html)
}
if strings.Contains(html, "=3D") {
t.Fatal("html still quoted-printable encoded")
}
}
func TestRepairSnippet_truncatedBase64Preview(t *testing.T) {
snippet := truncate(
"U0FNU1VORwpSw6lzZXJ2w6kgYXV4IHByb2Zlc3Npb25uZWxzCgrigIoKVm9zIMOpcXVpcGVzIG9u"+
"dCBiZXNvaW4KZGUgc29sdXRpb25zIG1vYmlsZXMKZXQgcm9idXN0ZXMuCgpTYW1zdW5nIFBybyBy"+
"w6lwb25kIGF1eCBtw6l0aWVycyBkZSBsYSBjb25zdHJ1Y3Rpb24uCgrigIoKPiBEw6ljb3V2cmly",
200,
)
repaired := RepairSnippet(snippet)
if !strings.HasPrefix(repaired, "SAMSUNG") {
t.Fatalf("snippet = %q, want decoded preview", repaired)
}
}
func TestRepairStoredBodies_base64HTML(t *testing.T) {
const encodedHTML = "PCFET0NUWVBFIGh0bWwgUFVCTElDICItLy9XM0MvL0RURCBYSFRNTCAxLjAgVHJhbnNpdGlvbmFs"
_, html := RepairStoredBodies("", encodedHTML)
if !strings.HasPrefix(html, "<!DOCTYPE") && !strings.HasPrefix(html, "<!doctype") {
t.Fatalf("html = %q, want decoded doctype", html)
}
}
func TestRepairStoredBodies_rawMIMEInDB(t *testing.T) {
raw := string(buildMultipartMessage(t, "alternative", []mimePart{
{
contentType: "text/plain; charset=utf-8",
body: []byte("Stored fix"),
transferEnc: "base64",
},
}))
repairedText, repairedHTML := RepairStoredBodies(raw, "")
if repairedText != "Stored fix" {
t.Fatalf("repaired text = %q, want Stored fix", repairedText)
}
if repairedHTML != "" {
t.Fatalf("repaired html = %q, want empty", repairedHTML)
}
}
func TestRepairSubject_brokenSymbolUsesBodyFallback(t *testing.T) {
qpHTML := `<html><body><p>Hello Eliott, Nous pouvons faire apparaître votre marque.</p></body></html>`
subject := RepairSubject("▱", "", qpHTML, nil)
if subjectLooksBroken(subject) {
t.Fatalf("subject = %q, want readable fallback", subject)
}
if !strings.Contains(subject, "Hello") {
t.Fatalf("subject = %q", subject)
}
}

View File

@ -108,3 +108,11 @@ func mailboxLeaf(mailbox string) string {
} }
return leaf return leaf
} }
// FolderDerivedLabels returns Ultimail labels inferred from IMAP mailbox path/name.
func FolderDerivedLabels(mailbox string) []string {
if strings.ToLower(mailboxLeaf(mailbox)) == "important" {
return []string{"important"}
}
return nil
}

View File

@ -0,0 +1,150 @@
package imap
import (
"bytes"
"encoding/json"
"net/mail"
"regexp"
"strings"
"github.com/emersion/go-imap/v2"
)
// MessageAuthInfo is persisted in messages.auth_info (JSON).
type MessageAuthInfo struct {
MailedBy string `json:"mailed_by,omitempty"`
SignedBy string `json:"signed_by,omitempty"`
DKIMPass *bool `json:"dkim_pass,omitempty"`
TLS bool `json:"tls,omitempty"`
ListUnsubscribe string `json:"list_unsubscribe,omitempty"`
}
var (
dkimDomainRe = regexp.MustCompile(`(?i)header\.d=([^\s;]+)`)
dkimSigDRe = regexp.MustCompile(`(?i)\bd=([^;\s]+)`)
returnPathRe = regexp.MustCompile(`(?i)<([^>]+)>`)
receivedFromRe = regexp.MustCompile(`(?i)from\s+([^\s;(\[]+)`)
)
func parseMessageMeta(raw []byte, envelope *imap.Envelope) (replyToJSON, authJSON []byte) {
auth := MessageAuthInfo{}
replyTo := replyAddresses(envelope, raw)
if len(raw) > 0 {
msg, err := mail.ReadMessage(bytes.NewReader(raw))
if err == nil {
if len(replyTo) == 0 {
replyTo = parseAddressListHeader(msg.Header.Get("Reply-To"))
}
auth.ListUnsubscribe = strings.TrimSpace(msg.Header.Get("List-Unsubscribe"))
mergeAuthFromHeaders(&auth, msg)
}
}
if auth.MailedBy == "" && len(envelope.From) > 0 {
auth.MailedBy = domainFromAddr(envelope.From[0].Addr())
}
if auth.SignedBy == "" && auth.MailedBy != "" {
auth.SignedBy = auth.MailedBy
}
authJSON, _ = json.Marshal(auth)
replyToJSON, _ = json.Marshal(replyTo)
return replyToJSON, authJSON
}
func replyAddresses(envelope *imap.Envelope, raw []byte) []EmailAddress {
if len(envelope.ReplyTo) > 0 {
return imapAddressesToEmail(envelope.ReplyTo)
}
if len(raw) == 0 {
return nil
}
msg, err := mail.ReadMessage(bytes.NewReader(raw))
if err != nil {
return nil
}
return parseAddressListHeader(msg.Header.Get("Reply-To"))
}
func imapAddressesToEmail(addrs []imap.Address) []EmailAddress {
out := make([]EmailAddress, 0, len(addrs))
for _, a := range addrs {
out = append(out, EmailAddress{Name: a.Name, Address: a.Addr()})
}
return out
}
func parseAddressListHeader(header string) []EmailAddress {
header = strings.TrimSpace(header)
if header == "" {
return nil
}
parsed, err := mail.ParseAddressList(header)
if err != nil {
return nil
}
out := make([]EmailAddress, 0, len(parsed))
for _, a := range parsed {
out = append(out, EmailAddress{Name: a.Name, Address: a.Address})
}
return out
}
func headerValues(h mail.Header, key string) []string {
return h[key]
}
func mergeAuthFromHeaders(auth *MessageAuthInfo, msg *mail.Message) {
for _, line := range headerValues(msg.Header, "Authentication-Results") {
lower := strings.ToLower(line)
if strings.Contains(lower, "dkim=pass") {
pass := true
auth.DKIMPass = &pass
if m := dkimDomainRe.FindStringSubmatch(line); len(m) > 1 && auth.SignedBy == "" {
auth.SignedBy = strings.Trim(m[1], `"'`)
}
} else if strings.Contains(lower, "dkim=fail") {
fail := false
auth.DKIMPass = &fail
}
if strings.Contains(lower, "tls=1") || strings.Contains(lower, "version=tls") {
auth.TLS = true
}
}
if auth.SignedBy == "" {
if sig := msg.Header.Get("DKIM-Signature"); sig != "" {
if m := dkimSigDRe.FindStringSubmatch(sig); len(m) > 1 {
auth.SignedBy = strings.Trim(m[1], `"'`)
}
}
}
if rp := msg.Header.Get("Return-Path"); rp != "" {
if m := returnPathRe.FindStringSubmatch(rp); len(m) > 1 {
auth.MailedBy = domainFromAddr(m[1])
}
}
for _, recv := range headerValues(msg.Header, "Received") {
lower := strings.ToLower(recv)
if strings.Contains(lower, "esmtps") || strings.Contains(lower, "tls") {
auth.TLS = true
}
if auth.MailedBy == "" {
if m := receivedFromRe.FindStringSubmatch(recv); len(m) > 1 {
auth.MailedBy = domainFromAddr(m[1])
}
}
}
}
func domainFromAddr(addr string) string {
addr = strings.Trim(addr, "<>")
if i := strings.LastIndex(addr, "@"); i >= 0 && i < len(addr)-1 {
return strings.ToLower(addr[i+1:])
}
host := strings.TrimSpace(addr)
if strings.Contains(host, ".") {
return strings.ToLower(host)
}
return ""
}

View File

@ -0,0 +1,33 @@
package imap
import (
"bytes"
"net/mail"
"strings"
"testing"
)
func Test_mergeAuthFromHeaders_dkimAndTLS(t *testing.T) {
raw := strings.Join([]string{
"From: Sender <sender@example.com>",
"Authentication-Results: mx.example.com; dkim=pass header.d=mail.example.com",
"Received: from mail.example.com (mail.example.com [1.2.3.4]) by mx with ESMTPS",
"",
"Body",
}, "\r\n")
msg, err := mail.ReadMessage(bytes.NewReader([]byte(raw)))
if err != nil {
t.Fatal(err)
}
var auth MessageAuthInfo
mergeAuthFromHeaders(&auth, msg)
if auth.DKIMPass == nil || !*auth.DKIMPass {
t.Fatalf("dkim_pass = %v, want true", auth.DKIMPass)
}
if auth.SignedBy != "mail.example.com" {
t.Fatalf("signed_by = %q", auth.SignedBy)
}
if !auth.TLS {
t.Fatal("expected tls true")
}
}

View File

@ -7,6 +7,7 @@ import (
"mime" "mime"
"mime/multipart" "mime/multipart"
"net/mail" "net/mail"
"regexp"
"strings" "strings"
imapTypes "github.com/emersion/go-imap/v2" imapTypes "github.com/emersion/go-imap/v2"
@ -19,6 +20,8 @@ type EmailAddress struct {
Address string `json:"address"` Address string `json:"address"`
} }
var mimeBoundaryParamRE = regexp.MustCompile(`(?i)boundary\s*=\s*"?([^";\s]+)"?`)
func addressesToJSON(addrs []imapTypes.Address) []byte { func addressesToJSON(addrs []imapTypes.Address) []byte {
result := make([]EmailAddress, 0, len(addrs)) result := make([]EmailAddress, 0, len(addrs))
for _, a := range addrs { for _, a := range addrs {
@ -36,9 +39,33 @@ func parseBody(raw []byte) (text string, html string) {
return "", "" return "", ""
} }
text, html = parseBodyFromRFC822(raw)
if text != "" || html != "" {
if !looksLikeRawMIME(text) && !looksLikeRawMIME(html) {
return finalizeDecodedBody(text), finalizeDecodedBody(html)
}
}
if t, h, ok := parseEmbeddedMIME(raw); ok {
return finalizeDecodedBody(t), finalizeDecodedBody(h)
}
if text != "" || html != "" {
return finalizeDecodedBody(text), finalizeDecodedBody(html)
}
fallback := string(raw)
return finalizeDecodedBody(fallback), ""
}
func finalizeDecodedBody(s string) string {
s = decodeBareQuotedPrintableIfNeeded(s)
return decodeBareBase64IfNeeded(s)
}
func parseBodyFromRFC822(raw []byte) (text string, html string) {
msg, err := mail.ReadMessage(bytes.NewReader(raw)) msg, err := mail.ReadMessage(bytes.NewReader(raw))
if err != nil { if err != nil {
return string(raw), "" return "", ""
} }
contentType := msg.Header.Get("Content-Type") contentType := msg.Header.Get("Content-Type")
@ -48,7 +75,7 @@ func parseBody(raw []byte) (text string, html string) {
mediaType, params, err := mime.ParseMediaType(contentType) mediaType, params, err := mime.ParseMediaType(contentType)
if err != nil { if err != nil {
body, _ := io.ReadAll(msg.Body) body, _ := readDecodedBody(msg.Body, msg.Header.Get("Content-Transfer-Encoding"))
return string(body), "" return string(body), ""
} }
@ -56,14 +83,23 @@ func parseBody(raw []byte) (text string, html string) {
return parseMultipart(msg.Body, params["boundary"]) return parseMultipart(msg.Body, params["boundary"])
} }
body, _ := io.ReadAll(msg.Body) body, _ := readDecodedBody(msg.Body, msg.Header.Get("Content-Transfer-Encoding"))
if mediaType == "text/html" { if mediaType == "text/html" {
return "", string(body) return "", string(body)
} }
return string(body), "" outText := string(body)
if looksLikeEmbeddedMIME(raw) {
if t, h, ok := parseEmbeddedMIME(raw); ok {
return t, h
}
}
return outText, ""
} }
func parseMultipart(r io.Reader, boundary string) (text string, html string) { func parseMultipart(r io.Reader, boundary string) (text string, html string) {
if boundary == "" {
return "", ""
}
mr := multipart.NewReader(r, boundary) mr := multipart.NewReader(r, boundary)
for { for {
part, err := mr.NextPart() part, err := mr.NextPart()
@ -76,17 +112,21 @@ func parseMultipart(r io.Reader, boundary string) (text string, html string) {
switch { switch {
case mediaType == "text/plain": case mediaType == "text/plain":
body, _ := io.ReadAll(part) body, _ := readDecodedBody(part, part.Header.Get("Content-Transfer-Encoding"))
text = string(body) if text == "" {
text = string(body)
}
case mediaType == "text/html": case mediaType == "text/html":
body, _ := io.ReadAll(part) body, _ := readDecodedBody(part, part.Header.Get("Content-Transfer-Encoding"))
html = string(body) if len(body) > 0 {
html = string(body)
}
case strings.HasPrefix(mediaType, "multipart/"): case strings.HasPrefix(mediaType, "multipart/"):
t, h := parseMultipart(part, params["boundary"]) t, h := parseMultipart(part, params["boundary"])
if text == "" { if text == "" {
text = t text = t
} }
if html == "" { if html == "" && h != "" {
html = h html = h
} }
} }
@ -94,6 +134,105 @@ func parseMultipart(r io.Reader, boundary string) (text string, html string) {
return text, html return text, html
} }
func readDecodedBody(r io.Reader, transferEncoding string) ([]byte, error) {
data, err := io.ReadAll(r)
if err != nil {
return nil, err
}
return decodePartBody(transferEncoding, data)
}
func parseEmbeddedMIME(raw []byte) (text string, html string, ok bool) {
if !looksLikeEmbeddedMIME(raw) {
return "", "", false
}
boundary := boundaryFromMIMEBytes(raw)
if boundary == "" {
return "", "", false
}
text, html = parseMultipart(bytes.NewReader(raw), boundary)
if text == "" && html == "" {
return "", "", false
}
if looksLikeRawMIME(text) || looksLikeRawMIME(html) {
return "", "", false
}
return text, html, true
}
func looksLikeEmbeddedMIME(raw []byte) bool {
s := string(raw)
if !strings.Contains(s, "Content-Type:") {
return false
}
return strings.Contains(s, "Content-Transfer-Encoding:") ||
strings.Contains(strings.ToLower(s), "multipart/") ||
strings.Contains(s, "This is a multi-part message in MIME format")
}
func looksLikeRawMIME(s string) bool {
if s == "" {
return false
}
if !strings.Contains(s, "Content-Type:") {
return false
}
return strings.Contains(s, "Content-Transfer-Encoding:") ||
strings.Contains(s, "--") && strings.Contains(strings.ToLower(s), "multipart")
}
func boundaryFromMIMEBytes(raw []byte) string {
if m := mimeBoundaryParamRE.FindSubmatch(raw); len(m) >= 3 {
return strings.Trim(string(m[2]), `"`)
}
return detectBoundaryDelimiter(raw)
}
func detectBoundaryDelimiter(raw []byte) string {
for _, line := range bytes.Split(raw, []byte("\n")) {
line = bytes.TrimSpace(line)
if len(line) < 4 || line[0] != '-' || line[1] != '-' {
continue
}
if line[len(line)-1] == '-' && line[len(line)-2] == '-' {
continue
}
b := strings.TrimPrefix(string(line), "--")
b = strings.TrimSpace(b)
if b != "" && !strings.Contains(b, " ") {
return b
}
}
return ""
}
func parseFromHeader(raw []byte) []EmailAddress {
if len(raw) == 0 {
return nil
}
msg, err := mail.ReadMessage(bytes.NewReader(raw))
if err != nil {
return nil
}
fromHdr := strings.TrimSpace(msg.Header.Get("From"))
if fromHdr == "" {
return nil
}
parsed, err := mail.ParseAddressList(fromHdr)
if err != nil || len(parsed) == 0 {
if id := threading.NormalizeMessageID(fromHdr); strings.Contains(fromHdr, "@") {
addr := strings.Trim(id, "<>")
return []EmailAddress{{Address: addr}}
}
return nil
}
out := make([]EmailAddress, 0, len(parsed))
for _, a := range parsed {
out = append(out, EmailAddress{Name: a.Name, Address: a.Address})
}
return out
}
func parseThreadHeaders(raw []byte) (references []string, inReplyTo string) { func parseThreadHeaders(raw []byte) (references []string, inReplyTo string) {
if len(raw) == 0 { if len(raw) == 0 {
return nil, "" return nil, ""
@ -106,3 +245,7 @@ func parseThreadHeaders(raw []byte) (references []string, inReplyTo string) {
irt := strings.TrimSpace(msg.Header.Get("In-Reply-To")) irt := strings.TrimSpace(msg.Header.Get("In-Reply-To"))
return threading.ParseMessageIDs(refs), threading.NormalizeMessageID(irt) return threading.ParseMessageIDs(refs), threading.NormalizeMessageID(irt)
} }
func toValidUTF8(s string) string {
return strings.ToValidUTF8(s, "")
}

View File

@ -0,0 +1,75 @@
package imap
import (
"strings"
"testing"
)
func TestParseBody_multipartAlternativeBase64(t *testing.T) {
raw := buildMultipartMessage(t, "alternative", []mimePart{
{
contentType: "text/plain; charset=utf-8",
body: []byte("SAMSUNG\nRéserver aux professionnels"),
transferEnc: "base64",
},
{
contentType: "text/html; charset=utf-8",
body: []byte("<p>SAMSUNG</p><p>Réserver aux professionnels</p>"),
transferEnc: "base64",
},
})
text, html := parseBody(raw)
if !strings.Contains(text, "SAMSUNG") {
t.Fatalf("text = %q, want decoded plain text", text)
}
if !strings.Contains(text, "professionnels") {
t.Fatalf("text = %q, want utf-8 decoded content", text)
}
if !strings.Contains(html, "<p>SAMSUNG</p>") {
t.Fatalf("html = %q, want decoded html", html)
}
if looksLikeRawMIME(text) || looksLikeRawMIME(html) {
t.Fatal("parseBody returned raw MIME")
}
}
func TestParseBody_headerlessMultipartBase64(t *testing.T) {
withHeaders := buildMultipartMessage(t, "alternative", []mimePart{
{
contentType: "text/plain; charset=utf-8",
body: []byte("Hello MIME"),
transferEnc: "base64",
},
})
// Drop RFC822 headers — simulates IMAP body fetch without outer Content-Type.
idx := strings.Index(string(withHeaders), "\r\n\r\n")
if idx < 0 {
t.Fatal("missing header/body separator")
}
raw := withHeaders[idx+4:]
text, _ := parseBody(raw)
if text != "Hello MIME" {
t.Fatalf("text = %q, want Hello MIME", text)
}
}
func TestParseBody_singlePartBase64(t *testing.T) {
var b strings.Builder
b.WriteString("From: a@b.com\r\n")
b.WriteString("To: c@d.com\r\n")
b.WriteString("Subject: test\r\n")
b.WriteString("Content-Type: text/plain; charset=utf-8\r\n")
b.WriteString("Content-Transfer-Encoding: base64\r\n")
b.WriteString("\r\n")
b.WriteString("SGVsbG8gYmFzZTY0") // "Hello base64"
text, html := parseBody([]byte(b.String()))
if text != "Hello base64" {
t.Fatalf("text = %q, want Hello base64", text)
}
if html != "" {
t.Fatalf("html = %q, want empty", html)
}
}

View File

@ -70,11 +70,14 @@ func (p *syncPipeline) loadRuleMessage(ctx context.Context, messageID string) (*
subject string subject string
bodyText string bodyText string
hasAtt bool hasAtt bool
accountID string
folderID *string
labels []string
) )
err := p.db.QueryRow(ctx, ` err := p.db.QueryRow(ctx, `
SELECT from_addr, to_addrs, subject, body_text, has_attachments SELECT from_addr, to_addrs, subject, body_text, has_attachments, account_id, folder_id, labels
FROM messages WHERE id = $1 FROM messages WHERE id = $1
`, messageID).Scan(&fromJSON, &toJSON, &subject, &bodyText, &hasAtt) `, messageID).Scan(&fromJSON, &toJSON, &subject, &bodyText, &hasAtt, &accountID, &folderID, &labels)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -82,14 +85,20 @@ func (p *syncPipeline) loadRuleMessage(ctx context.Context, messageID string) (*
from := firstAddressString(fromJSON) from := firstAddressString(fromJSON)
to := addressListStrings(toJSON) to := addressListStrings(toJSON)
return &rules.Message{ msg := &rules.Message{
ID: messageID, ID: messageID,
From: from, From: from,
To: to, To: to,
Subject: subject, Subject: subject,
BodyText: bodyText, BodyText: bodyText,
HasAttachments: hasAtt, HasAttachments: hasAtt,
}, nil AccountID: accountID,
Labels: labels,
}
if folderID != nil {
msg.FolderID = *folderID
}
return msg, nil
} }
func firstAddressString(fromJSON []byte) string { func firstAddressString(fromJSON []byte) string {

View File

@ -0,0 +1,100 @@
package imap
import (
"bytes"
"mime"
"net/mail"
"regexp"
"strings"
"unicode"
)
var (
htmlTitleRE = regexp.MustCompile(`(?is)<title[^>]*>\s*([^<]+?)\s*</title>`)
qpHexSeqRE = regexp.MustCompile(`=[0-9A-Fa-f]{2}`)
)
var mimeWordDecoder mime.WordDecoder
// RepairSubject decodes RFC 2047 / broken envelope subjects using headers or body fallbacks.
func RepairSubject(subject string, bodyText, bodyHTML string, raw []byte) string {
if s := decodeMIMEHeaderValue(subject); !subjectLooksBroken(s) {
return s
}
if len(raw) > 0 {
if hdr := subjectFromRawMessage(raw); hdr != "" && !subjectLooksBroken(hdr) {
return hdr
}
}
decodedHTML := decodeBareQuotedPrintableIfNeeded(bodyHTML)
decodedText := decodeBareQuotedPrintableIfNeeded(bodyText)
if t := extractSubjectFromHTML(decodedHTML); t != "" {
return t
}
if fallback := subjectFromBodyFallback(decodedText, decodedHTML); fallback != "" {
return fallback
}
return decodeMIMEHeaderValue(subject)
}
func decodeMIMEHeaderValue(s string) string {
s = strings.TrimSpace(s)
if s == "" {
return s
}
dec, err := mimeWordDecoder.DecodeHeader(s)
if err != nil {
return toValidUTF8(s)
}
return toValidUTF8(dec)
}
func subjectFromRawMessage(raw []byte) string {
msg, err := mail.ReadMessage(bytes.NewReader(raw))
if err != nil {
return ""
}
return decodeMIMEHeaderValue(msg.Header.Get("Subject"))
}
func extractSubjectFromHTML(html string) string {
html = strings.TrimSpace(html)
if html == "" {
return ""
}
if m := htmlTitleRE.FindStringSubmatch(html); len(m) > 1 {
t := strings.TrimSpace(m[1])
if t != "" && !subjectLooksBroken(t) {
return t
}
}
return ""
}
func subjectFromBodyFallback(text, html string) string {
plain := strings.TrimSpace(text)
if plain == "" {
plain = strings.TrimSpace(stripHTMLForSnippet(html))
}
if plain == "" || subjectLooksBroken(plain) {
return ""
}
if idx := strings.IndexAny(plain, ".\n\r"); idx >= 15 && idx <= 100 {
return truncate(strings.TrimSpace(plain[:idx]), 120)
}
return truncate(plain, 120)
}
func subjectLooksBroken(s string) bool {
s = strings.TrimSpace(s)
if s == "" {
return true
}
letters := 0
for _, r := range s {
if unicode.IsLetter(r) || unicode.IsNumber(r) {
letters++
}
}
return letters < 2
}

View File

@ -3,6 +3,7 @@ package imap
import ( import (
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"log/slog" "log/slog"
@ -123,6 +124,29 @@ func (w *SyncWorker) syncAllAccounts(ctx context.Context) error {
return nil return nil
} }
// SyncAccountForUser triggers an immediate IMAP sync for a single owned account.
func (w *SyncWorker) SyncAccountForUser(ctx context.Context, externalID, accountID string) error {
var (
host string
port int
useTLS bool
creds []byte
)
err := w.db.QueryRow(ctx, `
SELECT ma.imap_host, ma.imap_port, ma.imap_tls, ma.credentials
FROM mail_accounts ma
JOIN users u ON ma.user_id = u.id
WHERE ma.id = $1 AND u.external_id = $2 AND ma.is_active = true
`, accountID, externalID).Scan(&host, &port, &useTLS, &creds)
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
return fmt.Errorf("account not found")
}
return err
}
return w.syncAccount(ctx, accountID, host, port, useTLS, creds)
}
func (w *SyncWorker) syncAccount(ctx context.Context, accountID, host string, port int, useTLS bool, creds []byte) error { func (w *SyncWorker) syncAccount(ctx context.Context, accountID, host string, port int, useTLS bool, creds []byte) error {
userID, err := w.accountUserID(ctx, accountID) userID, err := w.accountUserID(ctx, accountID)
if err != nil { if err != nil {
@ -166,6 +190,8 @@ func (w *SyncWorker) syncAccount(ctx context.Context, accountID, host string, po
} }
} }
w.tagImportantFolderMessages(ctx, accountID)
_, err = w.db.Exec(ctx, `UPDATE mail_accounts SET last_sync_at = NOW() WHERE id = $1`, accountID) _, err = w.db.Exec(ctx, `UPDATE mail_accounts SET last_sync_at = NOW() WHERE id = $1`, accountID)
return err return err
} }
@ -222,18 +248,19 @@ func (w *SyncWorker) syncFolder(ctx context.Context, client *imapclient.Client,
} }
lastUID := prevState.LastUID lastUID := prevState.LastUID
derivedLabels := FolderDerivedLabels(folderName)
if lastUID > 0 { if lastUID > 0 {
if err := w.fetchAndProcess(ctx, client, accountID, userID, folderID, lastUID+1, 0, false); err != nil { if err := w.fetchAndProcess(ctx, client, accountID, userID, folderID, lastUID+1, 0, false, derivedLabels); err != nil {
return err return err
} }
} else { } else {
if err := w.fetchAndProcess(ctx, client, accountID, userID, folderID, 1, 0, false); err != nil { if err := w.fetchAndProcess(ctx, client, accountID, userID, folderID, 1, 0, false, derivedLabels); err != nil {
return err return err
} }
} }
if selectData.HighestModSeq > 0 && prevState.HighestModSeq > 0 && selectData.HighestModSeq > prevState.HighestModSeq { if selectData.HighestModSeq > 0 && prevState.HighestModSeq > 0 && selectData.HighestModSeq > prevState.HighestModSeq {
if err := w.fetchAndProcess(ctx, client, accountID, userID, folderID, 1, prevState.HighestModSeq, true); err != nil { if err := w.fetchAndProcess(ctx, client, accountID, userID, folderID, 1, prevState.HighestModSeq, true, derivedLabels); err != nil {
w.logger.Warn("condstore incremental fetch failed", "folder", folderName, "error", err) w.logger.Warn("condstore incremental fetch failed", "folder", folderName, "error", err)
} }
} }
@ -248,7 +275,7 @@ func (w *SyncWorker) syncFolder(ctx context.Context, client *imapclient.Client,
return saveFolderSyncState(ctx, w.db, folderID, selectData.UIDValidity, selectData.HighestModSeq, maxUID, int(selectData.NumMessages)) return saveFolderSyncState(ctx, w.db, folderID, selectData.UIDValidity, selectData.HighestModSeq, maxUID, int(selectData.NumMessages))
} }
func (w *SyncWorker) fetchAndProcess(ctx context.Context, client *imapclient.Client, accountID, userID, folderID string, fromUID uint32, changedSince uint64, updatesOnly bool) error { func (w *SyncWorker) fetchAndProcess(ctx context.Context, client *imapclient.Client, accountID, userID, folderID string, fromUID uint32, changedSince uint64, updatesOnly bool, derivedLabels []string) error {
seqSet := imap.UIDSet{} seqSet := imap.UIDSet{}
seqSet.AddRange(imap.UID(fromUID), imap.UID(0)) seqSet.AddRange(imap.UID(fromUID), imap.UID(0))
@ -269,7 +296,7 @@ func (w *SyncWorker) fetchAndProcess(ctx context.Context, client *imapclient.Cli
if msg == nil { if msg == nil {
break break
} }
kind, messageID, err := w.processMessage(ctx, msg, accountID, userID, folderID, updatesOnly) kind, messageID, err := w.processMessage(ctx, msg, accountID, userID, folderID, updatesOnly, derivedLabels)
if err != nil { if err != nil {
w.logger.Error("process message failed", "folder_id", folderID, "error", err) w.logger.Error("process message failed", "folder_id", folderID, "error", err)
continue continue
@ -338,7 +365,7 @@ func uidSetToMap(set imap.NumSet) map[uint32]bool {
return out return out
} }
func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMessageData, accountID, userID, folderID string, updatesOnly bool) (kind, messageID string, err error) { func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMessageData, accountID, userID, folderID string, updatesOnly bool, derivedLabels []string) (kind, messageID string, err error) {
var envelope *imap.Envelope var envelope *imap.Envelope
var uid imap.UID var uid imap.UID
var flags []imap.Flag var flags []imap.Flag
@ -378,11 +405,33 @@ func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMe
} }
flagStrs := flagsToStrings(flags) flagStrs := flagsToStrings(flags)
fromAddr := addressesToJSON(envelope.From) fromList := envelope.From
if len(fromList) == 0 {
fromList = envelope.Sender
}
fromAddr := addressesToJSON(fromList)
if len(fromList) == 0 {
if hdrFrom := parseFromHeader(bodyContent); len(hdrFrom) > 0 {
b, _ := json.Marshal(hdrFrom)
fromAddr = b
}
}
if isEmptyFromJSON(fromAddr) {
var folderType string
_ = w.db.QueryRow(ctx, `SELECT folder_type FROM mail_folders WHERE id = $1`, folderID).Scan(&folderType)
if folderType == "sent" {
if acctFrom, err := w.accountFromJSON(ctx, accountID); err == nil && len(acctFrom) > 0 {
fromAddr = acctFrom
}
}
}
toAddrs := addressesToJSON(envelope.To) toAddrs := addressesToJSON(envelope.To)
ccAddrs := addressesToJSON(envelope.Cc) ccAddrs := addressesToJSON(envelope.Cc)
bodyText, bodyHTML := parseBody(bodyContent) bodyText, bodyHTML := parseBody(bodyContent)
snippet := truncate(bodyText, 200) snippet := truncate(bodyText, 200)
if snippet == "" && bodyHTML != "" {
snippet = SnippetFromBodies(bodyText, bodyHTML, 200)
}
headerRefs, headerInReplyTo := parseThreadHeaders(bodyContent) headerRefs, headerInReplyTo := parseThreadHeaders(bodyContent)
inReplyTo := headerInReplyTo inReplyTo := headerInReplyTo
@ -390,24 +439,34 @@ func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMe
inReplyTo = threading.NormalizeMessageID(envelope.InReplyTo[0]) inReplyTo = threading.NormalizeMessageID(envelope.InReplyTo[0])
} }
references := headerRefs references := headerRefs
if len(references) == 0 { if references == nil {
references = threading.ParseMessageIDs(strings.Join(envelope.InReplyTo, " ")) references = []string{}
} }
rfcMessageID := threading.NormalizeMessageID(envelope.MessageID)
replyToJSON, authJSON := parseMessageMeta(bodyContent, envelope)
subject := RepairSubject(envelope.Subject, bodyText, bodyHTML, bodyContent)
snippet = toValidUTF8(snippet)
bodyText = toValidUTF8(bodyText)
bodyHTML = toValidUTF8(bodyHTML)
var existed bool var existed bool
_ = w.db.QueryRow(ctx, ` _ = w.db.QueryRow(ctx, `
SELECT EXISTS(SELECT 1 FROM messages WHERE folder_id = $1 AND uid = $2) SELECT EXISTS(SELECT 1 FROM messages WHERE folder_id = $1 AND uid = $2)
`, folderID, uid).Scan(&existed) `, folderID, uid).Scan(&existed)
err = w.db.QueryRow(ctx, ` err = w.db.QueryRow(ctx, `
INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, cc_addrs, date, snippet, body_text, body_html, flags, in_reply_to, references_header) INSERT INTO messages (account_id, folder_id, uid, message_id, subject, from_addr, to_addrs, cc_addrs, reply_to, auth_info, date, snippet, body_text, body_html, flags, in_reply_to, references_header)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17)
ON CONFLICT (folder_id, uid) DO UPDATE SET ON CONFLICT (folder_id, uid) DO UPDATE SET
message_id = EXCLUDED.message_id, message_id = EXCLUDED.message_id,
subject = EXCLUDED.subject, subject = EXCLUDED.subject,
from_addr = EXCLUDED.from_addr, from_addr = EXCLUDED.from_addr,
to_addrs = EXCLUDED.to_addrs, to_addrs = EXCLUDED.to_addrs,
cc_addrs = EXCLUDED.cc_addrs, cc_addrs = EXCLUDED.cc_addrs,
reply_to = EXCLUDED.reply_to,
auth_info = EXCLUDED.auth_info,
date = EXCLUDED.date, date = EXCLUDED.date,
snippet = EXCLUDED.snippet, snippet = EXCLUDED.snippet,
body_text = EXCLUDED.body_text, body_text = EXCLUDED.body_text,
@ -417,20 +476,29 @@ func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMe
references_header = EXCLUDED.references_header, references_header = EXCLUDED.references_header,
updated_at = NOW() updated_at = NOW()
RETURNING id RETURNING id
`, accountID, folderID, uid, envelope.MessageID, envelope.Subject, `, accountID, folderID, uid, rfcMessageID, subject,
fromAddr, toAddrs, ccAddrs, envelope.Date, snippet, bodyText, bodyHTML, flagStrs, inReplyTo, references, fromAddr, toAddrs, ccAddrs, replyToJSON, authJSON, envelope.Date, snippet, bodyText, bodyHTML, flagStrs, inReplyTo, references,
).Scan(&messageID) ).Scan(&messageID)
if err != nil { if err != nil {
return "", "", err return "", "", err
} }
threadID, err := threading.AssignThreadID(ctx, w.db, accountID, inReplyTo, references) if err := threading.ApplyMessageThread(ctx, w.db, accountID, messageID, rfcMessageID, inReplyTo, references); err != nil {
if err != nil {
return "", "", err return "", "", err
} }
_, err = w.db.Exec(ctx, `UPDATE messages SET thread_id = $1, updated_at = NOW() WHERE id = $2`, threadID, messageID)
if err != nil { if len(derivedLabels) > 0 {
return "", "", err if _, err := w.db.Exec(ctx, `
UPDATE messages
SET labels = (
SELECT COALESCE(array_agg(DISTINCT elem), '{}')
FROM unnest(COALESCE(labels, '{}') || $1::text[]) AS elem
),
updated_at = NOW()
WHERE id = $2
`, derivedLabels, messageID); err != nil {
return "", "", err
}
} }
if err := w.storeAttachments(ctx, userID, messageID, bodyContent, existed); err != nil { if err := w.storeAttachments(ctx, userID, messageID, bodyContent, existed); err != nil {
@ -446,6 +514,25 @@ func (w *SyncWorker) processMessage(ctx context.Context, msg *imapclient.FetchMe
return "created", messageID, nil return "created", messageID, nil
} }
func (w *SyncWorker) tagImportantFolderMessages(ctx context.Context, accountID string) {
_, err := w.db.Exec(ctx, `
UPDATE messages m
SET labels = (
SELECT COALESCE(array_agg(DISTINCT elem), '{}')
FROM unnest(COALESCE(m.labels, '{}') || ARRAY['important']) AS elem
),
updated_at = NOW()
FROM mail_folders mf
WHERE m.folder_id = mf.id
AND m.account_id = $1
AND LOWER(mf.name) = 'important'
AND NOT (COALESCE(m.labels, '{}') @> ARRAY['important'])
`, accountID)
if err != nil {
w.logger.Warn("tag important folder messages failed", "account_id", accountID, "error", err)
}
}
func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID string, raw []byte, messageExisted bool) error { func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID string, raw []byte, messageExisted bool) error {
if w.storage == nil || len(raw) == 0 { if w.storage == nil || len(raw) == 0 {
return nil return nil
@ -490,6 +577,34 @@ func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID str
return err return err
} }
func isEmptyFromJSON(fromAddr []byte) bool {
if len(fromAddr) == 0 || string(fromAddr) == "[]" || string(fromAddr) == "null" {
return true
}
var addrs []EmailAddress
if err := json.Unmarshal(fromAddr, &addrs); err != nil {
return true
}
for _, a := range addrs {
if strings.TrimSpace(a.Address) != "" || strings.TrimSpace(a.Name) != "" {
return false
}
}
return true
}
func (w *SyncWorker) accountFromJSON(ctx context.Context, accountID string) ([]byte, error) {
var email, name string
err := w.db.QueryRow(ctx, `SELECT email, name FROM mail_accounts WHERE id = $1`, accountID).Scan(&email, &name)
if err != nil {
return nil, err
}
if strings.TrimSpace(email) == "" {
return nil, nil
}
return json.Marshal([]EmailAddress{{Name: name, Address: email}})
}
func flagsToStrings(flags []imap.Flag) []string { func flagsToStrings(flags []imap.Flag) []string {
out := make([]string, len(flags)) out := make([]string, len(flags))
for i, f := range flags { for i, f := range flags {

View File

@ -0,0 +1,113 @@
package listunsubscribe
import (
"net/mail"
"net/url"
"strings"
)
// Mailto holds a one-click mailto unsubscribe target.
type Mailto struct {
Address string
Subject string
Body string
}
// Parsed from List-Unsubscribe (RFC 2369).
type Parsed struct {
Mailto *Mailto
HTTP string
}
func splitHeaderParts(raw string) []string {
var parts []string
var cur strings.Builder
depth := 0
for _, r := range raw {
switch r {
case '<':
depth++
cur.WriteRune(r)
case '>':
if depth > 0 {
depth--
}
cur.WriteRune(r)
case ',':
if depth == 0 {
if p := strings.TrimSpace(cur.String()); p != "" {
parts = append(parts, p)
}
cur.Reset()
continue
}
cur.WriteRune(r)
default:
cur.WriteRune(r)
}
}
if p := strings.TrimSpace(cur.String()); p != "" {
parts = append(parts, p)
}
return parts
}
func unwrapAngle(s string) string {
s = strings.TrimSpace(s)
if strings.HasPrefix(s, "<") && strings.HasSuffix(s, ">") {
return strings.TrimSpace(s[1 : len(s)-1])
}
return s
}
// ParseMailtoURL parses mailto:user@host?subject=...&body=...
func ParseMailtoURL(raw string) (*Mailto, bool) {
raw = unwrapAngle(strings.TrimSpace(raw))
if raw == "" {
return nil, false
}
if !strings.HasPrefix(strings.ToLower(raw), "mailto:") {
return nil, false
}
u, err := url.Parse(raw)
if err != nil {
return nil, false
}
addr := strings.TrimSpace(u.Opaque)
if addr == "" {
addr = strings.TrimSpace(strings.TrimPrefix(u.Path, "/"))
}
if addr == "" {
return nil, false
}
if _, err := mail.ParseAddress(addr); err != nil {
// bare addr@host
if !strings.Contains(addr, "@") {
return nil, false
}
}
return &Mailto{
Address: addr,
Subject: u.Query().Get("subject"),
Body: u.Query().Get("body"),
}, true
}
// Parse extracts mailto and https targets from a List-Unsubscribe header value.
func Parse(listUnsubscribe string) Parsed {
out := Parsed{}
for _, part := range splitHeaderParts(listUnsubscribe) {
part = unwrapAngle(part)
if m, ok := ParseMailtoURL(part); ok && out.Mailto == nil {
out.Mailto = m
continue
}
lower := strings.ToLower(part)
if strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://") {
if out.HTTP == "" {
out.HTTP = part
}
}
}
return out
}

View File

@ -0,0 +1,23 @@
package listunsubscribe
import "testing"
func TestParse_mailtoHeader(t *testing.T) {
got := Parse("<mailto:opposition@vertical-mail.com>")
if got.Mailto == nil || got.Mailto.Address != "opposition@vertical-mail.com" {
t.Fatalf("Parse() mailto = %+v", got.Mailto)
}
}
func TestParse_mailtoWithHttp(t *testing.T) {
got := Parse("<mailto:a@b.com?subject=unsub>, <https://example.com/unsub>")
if got.Mailto == nil || got.Mailto.Address != "a@b.com" {
t.Fatalf("mailto = %+v", got.Mailto)
}
if got.Mailto.Subject != "unsub" {
t.Fatalf("subject = %q", got.Mailto.Subject)
}
if got.HTTP != "https://example.com/unsub" {
t.Fatalf("http = %q", got.HTTP)
}
}

View File

@ -5,6 +5,7 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"log/slog" "log/slog"
"regexp"
"strings" "strings"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
@ -48,7 +49,7 @@ type Rule struct {
type Condition struct { type Condition struct {
Field string `json:"field"` // from, to, subject, body, has_attachment Field string `json:"field"` // from, to, subject, body, has_attachment
Operator string `json:"operator"` // contains, equals, starts_with, ends_with, matches Operator string `json:"operator"` // contains, equals, starts_with, ends_with, regex, has, not_has, ...
Value string `json:"value"` Value string `json:"value"`
} }
@ -66,11 +67,14 @@ type ActionResult struct {
type Message struct { type Message struct {
ID string `json:"id"` ID string `json:"id"`
AccountID string `json:"account_id,omitempty"`
FolderID string `json:"folder_id,omitempty"`
From string `json:"from"` From string `json:"from"`
To []string `json:"to"` To []string `json:"to"`
Subject string `json:"subject"` Subject string `json:"subject"`
BodyText string `json:"body_text"` BodyText string `json:"body_text"`
HasAttachments bool `json:"has_attachments"` HasAttachments bool `json:"has_attachments"`
Labels []string `json:"labels,omitempty"`
} }
func (e *Engine) Evaluate(ctx context.Context, userID string, msg *Message) error { func (e *Engine) Evaluate(ctx context.Context, userID string, msg *Message) error {
@ -78,10 +82,14 @@ func (e *Engine) Evaluate(ctx context.Context, userID string, msg *Message) erro
} }
func (e *Engine) EvaluateMessage(ctx context.Context, userID string, msg *Message) error { func (e *Engine) EvaluateMessage(ctx context.Context, userID string, msg *Message) error {
return e.EvaluateMessageEvent(ctx, userID, msg, &EventContext{Type: TriggerMessageReceived})
}
func (e *Engine) EvaluateMessageEvent(ctx context.Context, userID string, msg *Message, evt *EventContext) error {
rows, err := e.db.Query(ctx, ` rows, err := e.db.Query(ctx, `
SELECT id, name, conditions, actions SELECT id, name, conditions, actions, workflow
FROM mail_rules FROM mail_rules
WHERE user_id = $1 AND is_active = true WHERE user_id = $1 AND is_active = true AND rule_kind = 'rule'
ORDER BY priority ASC ORDER BY priority ASC
`, userID) `, userID)
if err != nil { if err != nil {
@ -95,25 +103,51 @@ func (e *Engine) EvaluateMessage(ctx context.Context, userID string, msg *Messag
name string name string
condJSON []byte condJSON []byte
actJSON []byte actJSON []byte
wfJSON []byte
) )
if err := rows.Scan(&ruleID, &name, &condJSON, &actJSON); err != nil { if err := rows.Scan(&ruleID, &name, &condJSON, &actJSON, &wfJSON); err != nil {
e.logger.Error("scan rule", "error", err) e.logger.Error("scan rule", "error", err)
continue continue
} }
var conditions []Condition wf, err := ParseWorkflow(wfJSON)
var actions []Action if err != nil {
json.Unmarshal(condJSON, &conditions) e.logger.Error("parse workflow", "rule_id", ruleID, "error", err)
json.Unmarshal(actJSON, &actions) continue
if matchesAll(conditions, msg) {
e.logger.Info("rule matched", "rule_id", ruleID, "rule_name", name, "message_id", msg.ID)
results := e.executeRuleActions(ctx, ruleID, actions, msg)
if err := e.recordRuleExecution(ctx, ruleID, msg.ID, results); err != nil {
e.logger.Error("record rule execution", "rule_id", ruleID, "message_id", msg.ID, "error", err)
}
e.db.Exec(ctx, `UPDATE mail_rules SET match_count = match_count + 1 WHERE id = $1`, ruleID)
} }
var results []ActionResult
if wf != nil && len(wf.Nodes) > 0 {
if !matchesTriggers(wf.Triggers, msg, evt) {
continue
}
startID := wf.findStartNode()
if startID == "" {
e.logger.Error("workflow missing start", "rule_id", ruleID)
continue
}
execCtx := newExecContext(msg, userID, wf.Variables)
if err := e.walkWorkflow(ctx, userID, msg, wf, startID, execCtx, 0); err != nil {
e.logger.Error("execute workflow", "rule_id", ruleID, "error", err)
}
results = execCtx.Results
} else {
var conditions []Condition
var actions []Action
json.Unmarshal(condJSON, &conditions)
json.Unmarshal(actJSON, &actions)
if !matchesAll(conditions, msg) {
continue
}
results = e.executeRuleActions(ctx, ruleID, actions, msg)
}
e.logger.Info("rule matched", "rule_id", ruleID, "rule_name", name, "message_id", msg.ID)
if err := e.recordRuleExecution(ctx, ruleID, msg.ID, results); err != nil {
e.logger.Error("record rule execution", "rule_id", ruleID, "message_id", msg.ID, "error", err)
}
e.db.Exec(ctx, `UPDATE mail_rules SET match_count = match_count + 1 WHERE id = $1`, ruleID)
} }
return nil return nil
@ -181,6 +215,18 @@ func matchesAll(conditions []Condition, msg *Message) bool {
} }
func matchCondition(cond Condition, msg *Message) bool { func matchCondition(cond Condition, msg *Message) bool {
if cond.Field == "label" {
has := messageHasLabel(msg, cond.Value)
switch cond.Operator {
case "has":
return has
case "not_has":
return !has
default:
return false
}
}
var fieldValue string var fieldValue string
switch cond.Field { switch cond.Field {
case "from": case "from":
@ -201,6 +247,21 @@ func matchCondition(cond Condition, msg *Message) bool {
return false return false
} }
switch cond.Operator {
case "regex":
re, err := regexp.Compile(cond.Value)
if err != nil {
return false
}
return re.MatchString(fieldValue)
case "not_regex":
re, err := regexp.Compile(cond.Value)
if err != nil {
return false
}
return !re.MatchString(fieldValue)
}
fieldLower := strings.ToLower(fieldValue) fieldLower := strings.ToLower(fieldValue)
valueLower := strings.ToLower(cond.Value) valueLower := strings.ToLower(cond.Value)
@ -220,6 +281,19 @@ func matchCondition(cond Condition, msg *Message) bool {
} }
} }
func messageHasLabel(msg *Message, label string) bool {
labelLower := strings.ToLower(strings.TrimSpace(label))
if labelLower == "" {
return false
}
for _, l := range msg.Labels {
if strings.ToLower(l) == labelLower {
return true
}
}
return false
}
func messageToWebhookContext(msg *Message) *webhooks.MessageContext { func messageToWebhookContext(msg *Message) *webhooks.MessageContext {
senderName, senderEmail := parseFromAddress(msg.From) senderName, senderEmail := parseFromAddress(msg.From)
return &webhooks.MessageContext{ return &webhooks.MessageContext{
@ -286,6 +360,38 @@ func (e *Engine) executeAction(ctx context.Context, action Action, msg *Message)
WHERE id = $1 WHERE id = $1
`, msg.ID) `, msg.ID)
return err return err
case "remove_label":
_, err := e.db.Exec(ctx, `
UPDATE messages SET labels = array_remove(labels, $1), updated_at = NOW()
WHERE id = $2
`, action.Value, msg.ID)
return err
case "mark_important":
_, err := e.db.Exec(ctx, `
UPDATE messages SET flags = array_append(flags, '\Flagged'), updated_at = NOW()
WHERE id = $1 AND NOT ('\Flagged' = ANY(flags))
`, msg.ID)
return err
case "mark_spam":
_, err := e.db.Exec(ctx, `
UPDATE messages SET labels = (
SELECT array_agg(DISTINCT x) FROM unnest(array_append(labels, 'SPAM')) AS x
), updated_at = NOW()
WHERE id = $1
`, msg.ID)
return err
case "star":
_, err := e.db.Exec(ctx, `
UPDATE messages SET flags = array_append(flags, '\Flagged'), updated_at = NOW()
WHERE id = $1 AND NOT ('\Flagged' = ANY(flags))
`, msg.ID)
return err
case "notify":
e.logger.Info("notification action", "message_id", msg.ID, "body", action.Value)
return nil
case "reply", "send_mail", "forward":
e.logger.Info("deferred mail action", "type", action.Type, "message_id", msg.ID, "value", action.Value)
return nil
case "webhook": case "webhook":
if e.webhookExec == nil { if e.webhookExec == nil {
return fmt.Errorf("webhook executor not configured") return fmt.Errorf("webhook executor not configured")

View File

@ -17,6 +17,7 @@ func testMessage() *Message {
Subject: "Invoice Q1", Subject: "Invoice Q1",
BodyText: "Please review the attached invoice.", BodyText: "Please review the attached invoice.",
HasAttachments: true, HasAttachments: true,
Labels: []string{"work", "finance"},
} }
} }
@ -37,7 +38,12 @@ func TestMatchCondition_fieldsAndOperators(t *testing.T) {
{"has_attachment false", Condition{Field: "has_attachment", Operator: "equals", Value: "false"}, false}, {"has_attachment false", Condition{Field: "has_attachment", Operator: "equals", Value: "false"}, false},
{"not_contains", Condition{Field: "subject", Operator: "not_contains", Value: "spam"}, true}, {"not_contains", Condition{Field: "subject", Operator: "not_contains", Value: "spam"}, true},
{"unknown field", Condition{Field: "unknown", Operator: "contains", Value: "x"}, false}, {"unknown field", Condition{Field: "unknown", Operator: "contains", Value: "x"}, false},
{"unknown operator", Condition{Field: "subject", Operator: "matches", Value: "Invoice"}, false}, {"unknown operator", Condition{Field: "subject", Operator: "unknown_op", Value: "Invoice"}, false},
{"regex match", Condition{Field: "subject", Operator: "regex", Value: `(?i)invoice`}, true},
{"regex no match", Condition{Field: "subject", Operator: "regex", Value: `^Spam`}, false},
{"not_regex", Condition{Field: "subject", Operator: "not_regex", Value: `^Spam`}, true},
{"label has", Condition{Field: "label", Operator: "has", Value: "work"}, true},
{"label not_has", Condition{Field: "label", Operator: "not_has", Value: "spam"}, true},
} }
for _, tt := range tests { for _, tt := range tests {
@ -81,11 +87,11 @@ func TestMatchesAll(t *testing.T) {
func TestExecuteAction_unknownType(t *testing.T) { func TestExecuteAction_unknownType(t *testing.T) {
e := &Engine{} e := &Engine{}
err := e.executeAction(context.Background(), Action{Type: "forward", Value: "x@example.com"}, &Message{ID: "msg-1"}) err := e.executeAction(context.Background(), Action{Type: "unknown_action", Value: "x@example.com"}, &Message{ID: "msg-1"})
if err == nil { if err == nil {
t.Fatal("executeAction() error = nil, want unknown action type error") t.Fatal("executeAction() error = nil, want unknown action type error")
} }
if !strings.Contains(err.Error(), "unknown action type: forward") { if !strings.Contains(err.Error(), "unknown action type: unknown_action") {
t.Fatalf("executeAction() error = %v, want unknown action type", err) t.Fatalf("executeAction() error = %v, want unknown action type", err)
} }
} }

View File

@ -37,7 +37,7 @@ func (e *Engine) simulateActions(ctx context.Context, actions []Action, msg *Mes
func (e *Engine) simulateAction(ctx context.Context, action Action, msg *Message) SimulatedActionResult { func (e *Engine) simulateAction(ctx context.Context, action Action, msg *Message) SimulatedActionResult {
switch action.Type { switch action.Type {
case "label", "move", "archive", "delete", "mark_read": case "label", "move", "archive", "delete", "mark_read", "remove_label", "mark_important", "mark_spam", "star", "notify", "reply", "send_mail", "forward":
return SimulatedActionResult{ return SimulatedActionResult{
ActionResult: ActionResult{Type: action.Type, Value: action.Value, OK: true}, ActionResult: ActionResult{Type: action.Type, Value: action.Value, OK: true},
} }

View File

@ -0,0 +1,240 @@
package rules
import (
"encoding/json"
"fmt"
)
const WorkflowVersion = 1
type RuleKind string
const (
RuleKindRule RuleKind = "rule"
RuleKindFunction RuleKind = "function"
)
type TriggerType string
const (
TriggerMessageReceived TriggerType = "message_received"
TriggerLabelAdded TriggerType = "label_added"
TriggerLabelRemoved TriggerType = "label_removed"
)
type Trigger struct {
Type TriggerType `json:"type"`
FolderID string `json:"folder_id,omitempty"`
Label string `json:"label,omitempty"`
AccountID string `json:"account_id,omitempty"`
}
type TriggerGroup struct {
Operator string `json:"operator"` // "or"
Groups []TriggerAnd `json:"groups"`
}
type TriggerAnd struct {
Operator string `json:"operator"` // "and"
Items []Trigger `json:"items"`
}
type ExecVariable struct {
Name string `json:"name"`
Type string `json:"type"` // string, number, boolean
Default string `json:"default,omitempty"`
}
type WorkflowNode struct {
ID string `json:"id"`
Type string `json:"type"`
Position json.RawMessage `json:"position,omitempty"`
Data json.RawMessage `json:"data"`
}
type WorkflowEdge struct {
ID string `json:"id"`
Source string `json:"source"`
Target string `json:"target"`
SourceHandle string `json:"sourceHandle,omitempty"`
}
type Workflow struct {
Version int `json:"version"`
Kind RuleKind `json:"kind"`
Triggers TriggerGroup `json:"triggers"`
Variables []ExecVariable `json:"variables,omitempty"`
Nodes []WorkflowNode `json:"nodes"`
Edges []WorkflowEdge `json:"edges"`
}
type ConditionNodeData struct {
Field string `json:"field"`
Operator string `json:"operator"`
Value string `json:"value"`
}
type LabelCheckNodeData struct {
Label string `json:"label"`
Operator string `json:"operator"` // has, not_has
}
type SwitchCase struct {
Value string `json:"value"`
Label string `json:"label,omitempty"`
}
type SwitchNodeData struct {
Field string `json:"field"`
Cases []SwitchCase `json:"cases"`
}
type LLMCheckNodeData struct {
Prompt string `json:"prompt"`
Provider string `json:"provider,omitempty"`
Model string `json:"model,omitempty"`
}
type ActionItem struct {
Type string `json:"type"`
Value string `json:"value"`
}
type ActionsNodeData struct {
Actions []ActionItem `json:"actions"`
}
type SetVarNodeData struct {
Name string `json:"name"`
Value string `json:"value"`
}
type CallRuleNodeData struct {
RuleID string `json:"rule_id"`
}
type EventContext struct {
Type TriggerType
FolderID string
Label string
}
func ParseWorkflow(raw []byte) (*Workflow, error) {
if len(raw) == 0 || string(raw) == "null" {
return nil, nil
}
var wf Workflow
if err := json.Unmarshal(raw, &wf); err != nil {
return nil, fmt.Errorf("parse workflow: %w", err)
}
if wf.Version == 0 {
wf.Version = WorkflowVersion
}
if wf.Kind == "" {
wf.Kind = RuleKindRule
}
return &wf, nil
}
func (wf *Workflow) nodeMap() map[string]WorkflowNode {
m := make(map[string]WorkflowNode, len(wf.Nodes))
for _, n := range wf.Nodes {
m[n.ID] = n
}
return m
}
func (wf *Workflow) outgoingEdges(nodeID string) []WorkflowEdge {
var out []WorkflowEdge
for _, e := range wf.Edges {
if e.Source == nodeID {
out = append(out, e)
}
}
return out
}
func (wf *Workflow) nextNode(nodeID, handle string) string {
for _, e := range wf.Edges {
if e.Source == nodeID && e.SourceHandle == handle {
return e.Target
}
}
return ""
}
func (wf *Workflow) nextDefault(nodeID string) string {
for _, e := range wf.Edges {
if e.Source == nodeID && e.SourceHandle == "" {
return e.Target
}
}
return ""
}
func (wf *Workflow) findStartNode() string {
for _, n := range wf.Nodes {
if n.Type == "start" {
return n.ID
}
}
return ""
}
func matchesTriggers(triggers TriggerGroup, msg *Message, evt *EventContext) bool {
if len(triggers.Groups) == 0 {
return true
}
for _, group := range triggers.Groups {
if matchesTriggerAnd(group, msg, evt) {
return true
}
}
return false
}
func matchesTriggerAnd(group TriggerAnd, msg *Message, evt *EventContext) bool {
if len(group.Items) == 0 {
return true
}
for _, t := range group.Items {
if !matchTrigger(t, msg, evt) {
return false
}
}
return true
}
func matchTrigger(t Trigger, msg *Message, evt *EventContext) bool {
switch t.Type {
case TriggerMessageReceived:
if evt != nil && evt.Type != TriggerMessageReceived && evt.Type != "" {
return false
}
if t.AccountID != "" && msg.AccountID != "" && t.AccountID != msg.AccountID {
return false
}
if t.FolderID != "" && msg.FolderID != "" && t.FolderID != msg.FolderID {
return false
}
return true
case TriggerLabelAdded:
if evt == nil || evt.Type != TriggerLabelAdded {
return false
}
if t.Label != "" && t.Label != evt.Label {
return false
}
return true
case TriggerLabelRemoved:
if evt == nil || evt.Type != TriggerLabelRemoved {
return false
}
if t.Label != "" && t.Label != evt.Label {
return false
}
return true
default:
return false
}
}

View File

@ -0,0 +1,291 @@
package rules
import (
"context"
"encoding/json"
"fmt"
"strings"
)
type ExecContext struct {
Variables map[string]string
Message *Message
UserID string
Results []ActionResult
}
func newExecContext(msg *Message, userID string, vars []ExecVariable) *ExecContext {
m := make(map[string]string, len(vars))
for _, v := range vars {
m[v.Name] = v.Default
}
return &ExecContext{
Variables: m,
Message: msg,
UserID: userID,
Results: make([]ActionResult, 0),
}
}
func (e *Engine) ExecuteWorkflow(ctx context.Context, userID string, msg *Message, wf *Workflow, evt *EventContext) ([]ActionResult, error) {
if wf == nil {
return nil, nil
}
if wf.Kind == RuleKindFunction {
return e.runWorkflowGraph(ctx, userID, msg, wf, newExecContext(msg, userID, wf.Variables))
}
if !matchesTriggers(wf.Triggers, msg, evt) {
return nil, nil
}
startID := wf.findStartNode()
if startID == "" {
return nil, fmt.Errorf("workflow missing start node")
}
execCtx := newExecContext(msg, userID, wf.Variables)
if err := e.walkWorkflow(ctx, userID, msg, wf, startID, execCtx, 0); err != nil {
return execCtx.Results, err
}
return execCtx.Results, nil
}
const maxWorkflowDepth = 32
func (e *Engine) walkWorkflow(ctx context.Context, userID string, msg *Message, wf *Workflow, nodeID string, execCtx *ExecContext, depth int) error {
if depth > maxWorkflowDepth {
return fmt.Errorf("workflow depth exceeded")
}
if nodeID == "" {
return nil
}
nodes := wf.nodeMap()
node, ok := nodes[nodeID]
if !ok {
return fmt.Errorf("unknown node: %s", nodeID)
}
switch node.Type {
case "start":
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, depth+1)
case "label_check":
var data LabelCheckNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("label_check node data: %w", err)
}
cond := Condition{Field: "label", Operator: "has", Value: data.Label}
if data.Operator == "not_has" {
cond.Operator = "not_has"
}
handle := "false"
if matchCondition(cond, msg) {
handle = "true"
}
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextNode(nodeID, handle), execCtx, depth+1)
case "condition":
var data ConditionNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("condition node data: %w", err)
}
cond := Condition{Field: data.Field, Operator: data.Operator, Value: interpolateValue(data.Value, execCtx)}
handle := "false"
if matchCondition(cond, msg) {
handle = "true"
}
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextNode(nodeID, handle), execCtx, depth+1)
case "switch":
var data SwitchNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("switch node data: %w", err)
}
fieldVal := workflowFieldValue(data.Field, msg, execCtx)
handle := "default"
for i, c := range data.Cases {
if strings.EqualFold(fieldVal, c.Value) {
handle = fmt.Sprintf("case-%d", i)
break
}
}
next := wf.nextNode(nodeID, handle)
if next == "" {
next = wf.nextNode(nodeID, "default")
}
return e.walkWorkflow(ctx, userID, msg, wf, next, execCtx, depth+1)
case "llm_check":
var data LLMCheckNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("llm_check node data: %w", err)
}
handle := "false"
if e.evaluateLLMCheck(ctx, data, msg, execCtx) {
handle = "true"
}
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextNode(nodeID, handle), execCtx, depth+1)
case "actions":
var data ActionsNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("actions node data: %w", err)
}
for _, item := range data.Actions {
action := Action{Type: item.Type, Value: interpolateValue(item.Value, execCtx)}
err := e.executeAction(ctx, action, msg)
result := actionResultFrom(action, err)
execCtx.Results = append(execCtx.Results, result)
if err != nil {
e.logger.Error("workflow action failed", "action", action.Type, "error", err)
}
}
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, depth+1)
case "set_var":
var data SetVarNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("set_var node data: %w", err)
}
execCtx.Variables[data.Name] = interpolateValue(data.Value, execCtx)
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, depth+1)
case "call_function", "call_rule":
var data CallRuleNodeData
if err := json.Unmarshal(node.Data, &data); err != nil {
return fmt.Errorf("call_rule node data: %w", err)
}
if err := e.invokeSubWorkflow(ctx, userID, msg, data.RuleID, execCtx, depth+1); err != nil {
return err
}
return e.walkWorkflow(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, depth+1)
case "end":
return nil
default:
return fmt.Errorf("unknown node type: %s", node.Type)
}
}
func (e *Engine) invokeSubWorkflow(ctx context.Context, userID string, msg *Message, ruleID string, parent *ExecContext, depth int) error {
if depth > maxWorkflowDepth {
return fmt.Errorf("workflow call depth exceeded")
}
var (
wfJSON []byte
ruleKind string
isActive bool
)
err := e.db.QueryRow(ctx, `
SELECT workflow, rule_kind, is_active
FROM mail_rules
WHERE id = $1 AND user_id = $2
`, ruleID, userID).Scan(&wfJSON, &ruleKind, &isActive)
if err != nil {
return fmt.Errorf("load sub-rule %s: %w", ruleID, err)
}
if !isActive {
return nil
}
wf, err := ParseWorkflow(wfJSON)
if err != nil {
return err
}
if wf == nil {
return fmt.Errorf("sub-rule %s has no workflow", ruleID)
}
childCtx := &ExecContext{
Variables: copyVars(parent.Variables),
Message: msg,
UserID: userID,
Results: parent.Results,
}
startID := wf.findStartNode()
if startID == "" {
return fmt.Errorf("sub-rule %s missing start node", ruleID)
}
return e.walkWorkflow(ctx, userID, msg, wf, startID, childCtx, depth)
}
func copyVars(src map[string]string) map[string]string {
dst := make(map[string]string, len(src))
for k, v := range src {
dst[k] = v
}
return dst
}
func workflowFieldValue(field string, msg *Message, execCtx *ExecContext) string {
if strings.HasPrefix(field, "$") {
name := strings.TrimPrefix(field, "$")
if v, ok := execCtx.Variables[name]; ok {
return v
}
return ""
}
switch field {
case "from":
return msg.From
case "to":
return strings.Join(msg.To, ", ")
case "subject":
return msg.Subject
case "body":
return msg.BodyText
case "has_attachment":
if msg.HasAttachments {
return "true"
}
return "false"
case "label":
return strings.Join(msg.Labels, ", ")
default:
return ""
}
}
func interpolateValue(template string, execCtx *ExecContext) string {
if !strings.Contains(template, "{{") {
return template
}
out := template
for name, val := range execCtx.Variables {
out = strings.ReplaceAll(out, "{{"+name+"}}", val)
}
if strings.Contains(out, "{{") && execCtx.Message != nil {
out = strings.ReplaceAll(out, "{{subject}}", execCtx.Message.Subject)
out = strings.ReplaceAll(out, "{{from}}", execCtx.Message.From)
}
return out
}
func (e *Engine) evaluateLLMCheck(ctx context.Context, data LLMCheckNodeData, msg *Message, execCtx *ExecContext) bool {
_ = ctx
prompt := interpolateValue(data.Prompt, execCtx)
promptLower := strings.ToLower(prompt)
if strings.Contains(promptLower, "spam") {
subjectLower := strings.ToLower(msg.Subject)
bodyLower := strings.ToLower(msg.BodyText)
return strings.Contains(subjectLower, "spam") || strings.Contains(bodyLower, "spam") ||
strings.Contains(subjectLower, "viagra") || strings.Contains(bodyLower, "lottery")
}
if strings.Contains(promptLower, "important") || strings.Contains(promptLower, "urgent") {
subjectLower := strings.ToLower(msg.Subject)
return strings.Contains(subjectLower, "urgent") || strings.Contains(subjectLower, "important") ||
strings.Contains(subjectLower, "asap")
}
return false
}
func (e *Engine) runWorkflowGraph(ctx context.Context, userID string, msg *Message, wf *Workflow, execCtx *ExecContext) ([]ActionResult, error) {
startID := wf.findStartNode()
if startID == "" {
return nil, fmt.Errorf("function workflow missing start node")
}
if err := e.walkWorkflow(ctx, userID, msg, wf, startID, execCtx, 0); err != nil {
return execCtx.Results, err
}
return execCtx.Results, nil
}

View File

@ -0,0 +1,139 @@
package rules
import (
"context"
"encoding/json"
"fmt"
)
type WorkflowSimulationStep struct {
NodeID string `json:"node_id"`
NodeType string `json:"node_type"`
Handle string `json:"handle,omitempty"`
}
type WorkflowSimulationResult struct {
Matched bool `json:"matched"`
Steps []WorkflowSimulationStep `json:"steps,omitempty"`
Actions []SimulatedActionResult `json:"actions,omitempty"`
}
func (e *Engine) SimulateWorkflow(ctx context.Context, userID string, wf *Workflow, msg *Message, evt *EventContext) WorkflowSimulationResult {
if wf == nil || len(wf.Nodes) == 0 {
return WorkflowSimulationResult{Matched: false}
}
if wf.Kind != RuleKindFunction && !matchesTriggers(wf.Triggers, msg, evt) {
return WorkflowSimulationResult{Matched: false}
}
startID := wf.findStartNode()
if startID == "" {
return WorkflowSimulationResult{Matched: false}
}
execCtx := newExecContext(msg, userID, wf.Variables)
steps := make([]WorkflowSimulationStep, 0)
e.simulateWalk(ctx, userID, msg, wf, startID, execCtx, &steps, 0)
simActions := make([]SimulatedActionResult, 0, len(execCtx.Results))
for _, r := range execCtx.Results {
simActions = append(simActions, SimulatedActionResult{ActionResult: r})
}
return WorkflowSimulationResult{
Matched: true,
Steps: steps,
Actions: simActions,
}
}
func (e *Engine) simulateWalk(ctx context.Context, userID string, msg *Message, wf *Workflow, nodeID string, execCtx *ExecContext, steps *[]WorkflowSimulationStep, depth int) {
if depth > maxWorkflowDepth || nodeID == "" {
return
}
nodes := wf.nodeMap()
node, ok := nodes[nodeID]
if !ok {
return
}
switch node.Type {
case "start":
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type})
e.simulateWalk(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, steps, depth+1)
case "condition":
var data ConditionNodeData
json.Unmarshal(node.Data, &data)
handle := "false"
if matchCondition(Condition{Field: data.Field, Operator: data.Operator, Value: interpolateValue(data.Value, execCtx)}, msg) {
handle = "true"
}
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type, Handle: handle})
e.simulateWalk(ctx, userID, msg, wf, wf.nextNode(nodeID, handle), execCtx, steps, depth+1)
case "label_check":
var data LabelCheckNodeData
json.Unmarshal(node.Data, &data)
op := "has"
if data.Operator == "not_has" {
op = "not_has"
}
handle := "false"
if matchCondition(Condition{Field: "label", Operator: op, Value: data.Label}, msg) {
handle = "true"
}
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type, Handle: handle})
e.simulateWalk(ctx, userID, msg, wf, wf.nextNode(nodeID, handle), execCtx, steps, depth+1)
case "switch":
var data SwitchNodeData
json.Unmarshal(node.Data, &data)
fieldVal := workflowFieldValue(data.Field, msg, execCtx)
handle := "default"
for i, c := range data.Cases {
if fieldVal == c.Value {
handle = fmt.Sprintf("case-%d", i)
break
}
}
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type, Handle: handle})
next := wf.nextNode(nodeID, handle)
if next == "" {
next = wf.nextNode(nodeID, "default")
}
e.simulateWalk(ctx, userID, msg, wf, next, execCtx, steps, depth+1)
case "llm_check":
var data LLMCheckNodeData
json.Unmarshal(node.Data, &data)
handle := "false"
if e.evaluateLLMCheck(ctx, data, msg, execCtx) {
handle = "true"
}
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type, Handle: handle})
e.simulateWalk(ctx, userID, msg, wf, wf.nextNode(nodeID, handle), execCtx, steps, depth+1)
case "actions":
var data ActionsNodeData
json.Unmarshal(node.Data, &data)
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type})
for _, item := range data.Actions {
action := Action{Type: item.Type, Value: interpolateValue(item.Value, execCtx)}
execCtx.Results = append(execCtx.Results, e.simulateAction(ctx, action, msg).ActionResult)
}
e.simulateWalk(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, steps, depth+1)
case "set_var":
var data SetVarNodeData
json.Unmarshal(node.Data, &data)
execCtx.Variables[data.Name] = interpolateValue(data.Value, execCtx)
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type})
e.simulateWalk(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, steps, depth+1)
case "call_function", "call_rule":
var data CallRuleNodeData
json.Unmarshal(node.Data, &data)
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type})
e.simulateWalk(ctx, userID, msg, wf, wf.nextDefault(nodeID), execCtx, steps, depth+1)
case "end":
*steps = append(*steps, WorkflowSimulationStep{NodeID: nodeID, NodeType: node.Type})
}
}

View File

@ -0,0 +1,59 @@
package sanitize
import (
"regexp"
"github.com/microcosm-cc/bluemonday"
)
var (
styleType = regexp.MustCompile(`(?i)^text\/css$`)
cssJSURL = regexp.MustCompile(`(?i)url\s*\(\s*['"]?javascript:[^)]*\)`)
)
// emailPolicy preserves HTML email layout (inline styles, <style>, tables) while
// stripping scripts and event handlers. Display happens in a sandboxed iframe.
func emailPolicy() *bluemonday.Policy {
p := bluemonday.UGCPolicy()
// Full documents and email structure
p.AllowElements("html", "head", "body", "title")
p.AllowElements("font", "main", "nav", "header", "footer")
// Inline styles + <style> blocks (requires AllowUnsafe for tag content)
p.AllowAttrs("type").Matching(styleType).OnElements("style")
p.AllowAttrs("style").Globally()
p.AllowStyling()
p.AllowUnsafe(true)
p.AllowElementsContent("style")
// Legacy table / font attributes common in newsletters
p.AllowAttrs("bgcolor", "color").OnElements("basefont", "font", "hr", "td", "table", "tr", "th")
p.AllowAttrs("border").Matching(bluemonday.Integer).OnElements("img", "table")
p.AllowAttrs("cellpadding", "cellspacing").Matching(bluemonday.Integer).OnElements("table")
p.AllowAttrs("width", "height", "align", "valign", "background", "colspan", "rowspan").
OnElements("table", "tbody", "tr", "td", "th", "thead", "tfoot", "colgroup", "col", "div", "p", "img")
// External CSS (resolved by the client when remote content is allowed)
p.AllowAttrs("rel", "href").OnElements("link")
p.AllowAttrs("type").Matching(styleType).OnElements("link")
p.AllowRelativeURLs(true)
p.AllowDataURIImages()
p.AllowURLSchemes("cid")
// Lazy-load / responsive images common in newsletters
p.AllowAttrs("srcset").OnElements("img", "source")
p.AllowAttrs("loading", "decoding", "sizes").OnElements("img", "source")
p.AllowAttrs(
"data-src", "data-original", "data-lazy-src", "data-srcset",
"data-href", "data-url", "data-image", "data-bg",
).OnElements("img", "source")
return p
}
var policy = emailPolicy()
func stripUnsafeCSSURLs(html string) string {
return cssJSURL.ReplaceAllString(html, "url(about:blank)")
}

View File

@ -0,0 +1,40 @@
package sanitize
import (
"strings"
"testing"
)
func TestSanitizeHTML_preservesImgLazyAndSrcset(t *testing.T) {
in := `<img data-src="https://cdn.example.com/hero.png" alt="hero">` +
`<img srcset="https://cdn.example.com/a.png 1x" src="https://cdn.example.com/f.png">`
got := SanitizeHTML(in)
if !strings.Contains(got, `data-src="https://cdn.example.com/hero.png"`) {
t.Fatalf("expected data-src preserved, got %q", got)
}
if !strings.Contains(got, `src="https://cdn.example.com/hero.png"`) {
t.Fatalf("expected lazy src promoted to src, got %q", got)
}
if !strings.Contains(got, `srcset="https://cdn.example.com/a.png 1x"`) {
t.Fatalf("expected srcset preserved, got %q", got)
}
if !strings.Contains(got, `src="https://cdn.example.com/f.png"`) {
t.Fatalf("expected src preserved, got %q", got)
}
}
func TestSanitizeHTML_preservesCidImageSrc(t *testing.T) {
in := `<img src="cid:logo@mail" alt="logo">`
got := SanitizeHTML(in)
if !strings.Contains(got, `src="cid:logo@mail"`) {
t.Fatalf("expected cid src preserved, got %q", got)
}
}
func TestSanitizeHTML_preservesRelativeImgSrc(t *testing.T) {
in := `<img src="/campaign/logo.png" alt="logo">`
got := SanitizeHTML(in)
if !strings.Contains(got, `src="/campaign/logo.png"`) {
t.Fatalf("expected relative src preserved, got %q", got)
}
}

View File

@ -0,0 +1,119 @@
package sanitize
import (
"bytes"
"regexp"
"strings"
"golang.org/x/net/html"
)
// StripHiddenEmailHTML removes invisible preheader / preview blocks common in marketing mail.
// Must run before bluemonday, which strips display:none styles and would expose padding text.
func StripHiddenEmailHTML(raw string) string {
if raw == "" {
return raw
}
doc, err := html.Parse(strings.NewReader(raw))
if err != nil {
return stripHiddenEmailHTMLRegex(raw)
}
var remove []*html.Node
var walk func(*html.Node)
walk = func(n *html.Node) {
if n.Type == html.ElementNode && shouldStripHiddenElement(n) {
remove = append(remove, n)
return
}
for c := n.FirstChild; c != nil; c = c.NextSibling {
walk(c)
}
}
walk(doc)
for _, n := range remove {
if n.Parent != nil {
n.Parent.RemoveChild(n)
}
}
var buf bytes.Buffer
if err := html.Render(&buf, doc); err != nil {
return stripHiddenEmailHTMLRegex(raw)
}
return buf.String()
}
func shouldStripHiddenElement(n *html.Node) bool {
if n.Type != html.ElementNode {
return false
}
if attrVal(n, "hidden") != "" {
return true
}
if strings.EqualFold(attrVal(n, "aria-hidden"), "true") {
return true
}
class := strings.ToLower(attrVal(n, "class"))
if strings.Contains(class, "mcnpreviewtext") ||
strings.Contains(class, "preheader") ||
strings.Contains(class, "preview-text") {
return true
}
style := strings.ToLower(attrVal(n, "style"))
if style == "" {
return false
}
styleCompact := strings.ReplaceAll(style, " ", "")
return strings.Contains(styleCompact, "display:none") ||
strings.Contains(styleCompact, "mso-hide:all") ||
strings.Contains(styleCompact, "max-height:0") ||
strings.Contains(styleCompact, "opacity:0") ||
strings.Contains(styleCompact, "font-size:0") ||
strings.Contains(styleCompact, "visibility:hidden") ||
strings.Contains(styleCompact, "overflow:hidden") && strings.Contains(styleCompact, "max-height:0")
}
func attrVal(n *html.Node, key string) string {
for _, a := range n.Attr {
if strings.EqualFold(a.Key, key) {
return a.Val
}
}
return ""
}
func stripHiddenEmailHTMLRegex(raw string) string {
patterns := []*regexp.Regexp{
regexp.MustCompile(`(?is)<span[^>]*class="[^"]*mcnPreviewText[^"]*"[^>]*>.*?</span>`),
regexp.MustCompile(`(?is)<div[^>]*style="[^"]*display\s*:\s*none[^"]*"[^>]*>.*?</div>`),
}
out := raw
for _, re := range patterns {
out = re.ReplaceAllString(out, "")
}
return out
}
func isInvisiblePaddingRune(r rune) bool {
switch r {
case '\u034f', '\u200b', '\u200c', '\u200d', '\u200e', '\u200f', '\ufeff', '\u00a0', '\u2007':
return true
default:
return false
}
}
// StripInvisibleTextRuns removes repeated invisible Unicode padding from plain text previews.
func StripInvisibleTextRuns(s string) string {
if s == "" {
return s
}
var b strings.Builder
b.Grow(len(s))
for _, r := range s {
if isInvisiblePaddingRune(r) {
continue
}
b.WriteRune(r)
}
return strings.Join(strings.Fields(b.String()), " ")
}

View File

@ -0,0 +1,41 @@
package sanitize
import (
"strings"
"testing"
)
func TestStripHiddenEmailHTML_removesMailchimpPreheaderPadding(t *testing.T) {
const padding = "͏ \u200c \u00a0 \u2007 "
in := `<div style="display: none; max-height: 0px; overflow: hidden;">` + strings.Repeat(padding, 20) + `</div>` +
`<p>All motor files are now structured by model.</p>`
out := StripHiddenEmailHTML(in)
if strings.Contains(out, "\u034f") || strings.Contains(out, "\u200c") {
t.Fatalf("invisible padding remains: %q", out)
}
if !strings.Contains(out, "All motor files") {
t.Fatalf("visible content removed: %q", out)
}
}
func TestSanitizeHTML_stripsHiddenPreheader(t *testing.T) {
in := `<div style="display: none;">hidden junk</div><p>Hello world</p>`
got := SanitizeHTML(in)
if strings.Contains(got, "hidden junk") {
t.Fatalf("hidden preheader leaked after sanitize: %q", got)
}
if !strings.Contains(got, "Hello world") {
t.Fatalf("visible content missing: %q", got)
}
}
func TestStripInvisibleTextRuns(t *testing.T) {
in := "Hello " + strings.Repeat("\u034f\u200c\u00a0\u2007 ", 30) + "world"
got := StripInvisibleTextRuns(in)
if strings.Contains(got, "\u034f") {
t.Fatalf("padding remains: %q", got)
}
if got != "Hello world" {
t.Fatalf("got %q", got)
}
}

View File

@ -1,12 +1,45 @@
package sanitize package sanitize
import "github.com/microcosm-cc/bluemonday" import "regexp"
var policy = bluemonday.UGCPolicy() var (
imgTagRe = regexp.MustCompile(`(?i)<img\b([^>]*)>`)
imgSrcAttrRe = regexp.MustCompile(`(?i)\bsrc\s*=`)
imgLazySrcRe = regexp.MustCompile(
`(?i)\bdata-(?:src|original|lazy-src|url|image|href)\s*=\s*("([^"]*)"|'([^']*)')`,
)
)
// promoteEmailImageSources copies lazy-load URLs into src when newsletters omit src.
func promoteEmailImageSources(html string) string {
return imgTagRe.ReplaceAllStringFunc(html, func(tag string) string {
attrs := imgTagRe.FindStringSubmatch(tag)
if len(attrs) < 2 || imgSrcAttrRe.MatchString(attrs[1]) {
return tag
}
lazy := imgLazySrcRe.FindStringSubmatch(attrs[1])
if len(lazy) < 4 {
return tag
}
var quote, url string
if lazy[2] != "" {
quote, url = `"`, lazy[2]
} else {
quote, url = `'`, lazy[3]
}
if url == "" {
return tag
}
return "<img" + attrs[1] + ` src=` + quote + url + quote + ">"
})
}
func SanitizeHTML(html string) string { func SanitizeHTML(html string) string {
if html == "" { if html == "" {
return "" return ""
} }
return policy.Sanitize(html) html = StripHiddenEmailHTML(html)
html = policy.Sanitize(html)
html = promoteEmailImageSources(html)
return stripUnsafeCSSURLs(html)
} }

View File

@ -35,6 +35,40 @@ func TestSanitizeHTML_preservesSafeContent(t *testing.T) {
} }
} }
func TestSanitizeHTML_preservesEmailStyles(t *testing.T) {
in := `<style type="text/css">.title{font-family:Arial,sans-serif;color:#c00;}</style>` +
`<table width="600"><tr><td class="title" style="font-size:16px;">Promo</td></tr></table>`
got := SanitizeHTML(in)
if !strings.Contains(got, "font-family:Arial") {
t.Fatalf("expected style block preserved, got %q", got)
}
if !strings.Contains(got, `class="title"`) {
t.Fatalf("expected class preserved, got %q", got)
}
if !strings.Contains(got, `style="font-size:16px`) {
t.Fatalf("expected inline style preserved, got %q", got)
}
}
func TestSanitizeHTML_stripsJavascriptInCSS(t *testing.T) {
in := `<style>.x{background:url(javascript:alert(1))}</style><p class="x">Y</p>`
got := SanitizeHTML(in)
if strings.Contains(strings.ToLower(got), "javascript:") {
t.Fatalf("expected javascript css url stripped, got %q", got)
}
if !strings.Contains(got, `<p class="x">Y</p>`) {
t.Fatalf("expected content preserved, got %q", got)
}
}
func TestSanitizeHTML_preservesStylesheetLink(t *testing.T) {
in := `<link rel="stylesheet" href="https://cdn.example.com/campaign.css"><p>Hi</p>`
got := SanitizeHTML(in)
if !strings.Contains(got, `href="https://cdn.example.com/campaign.css"`) {
t.Fatalf("expected stylesheet link preserved, got %q", got)
}
}
func TestSanitizeHTML_empty(t *testing.T) { func TestSanitizeHTML_empty(t *testing.T) {
if got := SanitizeHTML(""); got != "" { if got := SanitizeHTML(""); got != "" {
t.Fatalf("expected empty string, got %q", got) t.Fatalf("expected empty string, got %q", got)

View File

@ -0,0 +1,44 @@
package threading
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
)
// ApplyMessageThread assigns thread_id for one message and propagates to direct replies.
func ApplyMessageThread(
ctx context.Context,
db *pgxpool.Pool,
accountID, rowID, rfcMessageID, inReplyTo string,
references []string,
) error {
threadID, err := AssignThreadID(ctx, db, accountID, inReplyTo, references)
if err != nil {
return err
}
if _, err := db.Exec(ctx, `
UPDATE messages SET thread_id = $1::uuid, updated_at = NOW() WHERE id = $2
`, threadID, rowID); err != nil {
return err
}
return propagateThreadToReplies(ctx, db, accountID, threadID, rfcMessageID)
}
func propagateThreadToReplies(ctx context.Context, db *pgxpool.Pool, accountID, threadID, rfcMessageID string) error {
rfcMessageID = NormalizeMessageID(rfcMessageID)
if rfcMessageID == "" {
return nil
}
_, err := db.Exec(ctx, `
UPDATE messages
SET thread_id = $1::uuid, updated_at = NOW()
WHERE account_id = $2
AND thread_id IS DISTINCT FROM $1::uuid
AND (
in_reply_to = $3
OR $3 = ANY(references_header)
)
`, threadID, accountID, rfcMessageID)
return err
}

View File

@ -36,6 +36,13 @@ func TestBuildReferences(t *testing.T) {
} }
} }
func TestNormalizeMessageID_imapEnvelopeWithoutBrackets(t *testing.T) {
// go-imap Envelope.MessageID is documented without angle brackets.
if got := NormalizeMessageID("abc@host.test"); got != "<abc@host.test>" {
t.Fatalf("NormalizeMessageID() = %q, want %q", got, "<abc@host.test>")
}
}
func TestBuildReferences_dedupesParent(t *testing.T) { func TestBuildReferences_dedupesParent(t *testing.T) {
got := BuildReferences("<b@y>", []string{"<a@x>", "<b@y>"}) got := BuildReferences("<b@y>", []string{"<a@x>", "<b@y>"})
want := []string{"<a@x>", "<b@y>"} want := []string{"<a@x>", "<b@y>"}

View File

@ -13,6 +13,7 @@ type Client struct {
httpClient *http.Client httpClient *http.Client
adminUser string adminUser string
adminPass string adminPass string
credStore *DAVCredentialStore
} }
func NewClient(baseURL, adminUser, adminPass string) *Client { func NewClient(baseURL, adminUser, adminPass string) *Client {
@ -26,6 +27,14 @@ func NewClient(baseURL, adminUser, adminPass string) *Client {
} }
} }
func (c *Client) WithDAVCredentials(store *DAVCredentialStore) *Client {
if c == nil {
return nil
}
c.credStore = store
return c
}
func (c *Client) doRequest(ctx context.Context, method, path string, body io.Reader, headers map[string]string) (*http.Response, error) { func (c *Client) doRequest(ctx context.Context, method, path string, body io.Reader, headers map[string]string) (*http.Response, error) {
url := c.baseURL + path url := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, url, body) req, err := http.NewRequestWithContext(ctx, method, url, body)
@ -43,20 +52,43 @@ func (c *Client) doRequest(ctx context.Context, method, path string, body io.Rea
} }
func (c *Client) DoAsUser(ctx context.Context, method, path string, body io.Reader, userID string, headers map[string]string) (*http.Response, error) { func (c *Client) DoAsUser(ctx context.Context, method, path string, body io.Reader, userID string, headers map[string]string) (*http.Response, error) {
token, err := c.userDAVToken(ctx, userID)
if err != nil {
return nil, err
}
url := c.baseURL + path url := c.baseURL + path
req, err := http.NewRequestWithContext(ctx, method, url, body) req, err := http.NewRequestWithContext(ctx, method, url, body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
req.SetBasicAuth(c.adminUser, c.adminPass) req.SetBasicAuth(userID, token)
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("X-NC-User", userID)
for k, v := range headers { for k, v := range headers {
req.Header.Set(k, v) req.Header.Set(k, v)
} }
return c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusUnauthorized && c.credStore != nil {
_ = c.credStore.DeleteToken(ctx, userID)
resp.Body.Close()
return nil, ErrDAVCredentialsMissing
}
return resp, nil
}
func (c *Client) userDAVToken(ctx context.Context, userID string) (string, error) {
if c.credStore == nil {
return "", fmt.Errorf("nextcloud dav credentials store not configured")
}
token, err := c.credStore.GetToken(ctx, userID)
if err != nil {
return "", err
}
return token, nil
} }
func (c *Client) WebDAVPath(userID, path string) string { func (c *Client) WebDAVPath(userID, path string) string {

View File

@ -38,7 +38,7 @@ type ContactSyncResult struct {
} }
func (c *Client) ListAddressBooks(ctx context.Context, userID string) ([]AddressBook, error) { func (c *Client) ListAddressBooks(ctx context.Context, userID string) ([]AddressBook, error) {
path := fmt.Sprintf("/remote.php/dav/addressbooks/users/%s/", userID) path := addressBookHomePath(userID)
body := `<?xml version="1.0" encoding="UTF-8"?> body := `<?xml version="1.0" encoding="UTF-8"?>
<d:propfind xmlns:d="DAV:"> <d:propfind xmlns:d="DAV:">
<d:prop> <d:prop>
@ -56,7 +56,14 @@ func (c *Client) ListAddressBooks(ctx context.Context, userID string) ([]Address
} }
defer resp.Body.Close() defer resp.Body.Close()
return parseAddressBookList(resp.Body, path) raw, err := readResponseBody(resp)
if err != nil {
return nil, err
}
if err := davResponseError(raw, resp.StatusCode); err != nil {
return nil, err
}
return parseAddressBookList(strings.NewReader(string(raw)), path)
} }
func (c *Client) ListContacts(ctx context.Context, userID, bookPath string) ([]Contact, error) { func (c *Client) ListContacts(ctx context.Context, userID, bookPath string) ([]Contact, error) {
@ -77,7 +84,14 @@ func (c *Client) ListContacts(ctx context.Context, userID, bookPath string) ([]C
} }
defer resp.Body.Close() defer resp.Body.Close()
return parseContactList(resp.Body) raw, err := readResponseBody(resp)
if err != nil {
return nil, err
}
if err := davResponseError(raw, resp.StatusCode); err != nil {
return nil, err
}
return parseContactList(strings.NewReader(string(raw)))
} }
func (c *Client) SyncContacts(ctx context.Context, userID, bookPath, syncToken string) (ContactSyncResult, error) { func (c *Client) SyncContacts(ctx context.Context, userID, bookPath, syncToken string) (ContactSyncResult, error) {
@ -201,7 +215,14 @@ func (c *Client) SearchContacts(ctx context.Context, userID, bookPath, query str
} }
defer resp.Body.Close() defer resp.Body.Close()
return parseContactList(resp.Body) raw, err := readResponseBody(resp)
if err != nil {
return nil, err
}
if err := davResponseError(raw, resp.StatusCode); err != nil {
return nil, err
}
return parseContactList(strings.NewReader(string(raw)))
} }
func buildVCard(contact *Contact) string { func buildVCard(contact *Contact) string {
@ -231,9 +252,11 @@ func parseAddressBookList(body io.Reader, basePath string) ([]AddressBook, error
return nil, err return nil, err
} }
basePath = normalizeDAVHref(basePath)
books := make([]AddressBook, 0) books := make([]AddressBook, 0)
for _, r := range ms.Responses { for _, r := range ms.Responses {
if r.Href == basePath { href := normalizeDAVHref(r.Href)
if href == basePath {
continue continue
} }
name := r.Propstat.Prop.DisplayName name := r.Propstat.Prop.DisplayName
@ -241,14 +264,22 @@ func parseAddressBookList(body io.Reader, basePath string) ([]AddressBook, error
continue continue
} }
books = append(books, AddressBook{ books = append(books, AddressBook{
ID: strings.TrimSuffix(strings.TrimPrefix(r.Href, basePath), "/"), ID: strings.TrimSuffix(strings.TrimPrefix(href, basePath), "/"),
DisplayName: name, DisplayName: name,
Path: r.Href, Path: href,
}) })
} }
return books, nil return books, nil
} }
func normalizeDAVHref(href string) string {
href = strings.TrimSpace(href)
if strings.HasPrefix(href, "/cloud/") {
return strings.TrimPrefix(href, "/cloud")
}
return href
}
func buildSyncCollectionRequest(syncToken string) string { func buildSyncCollectionRequest(syncToken string) string {
var b strings.Builder var b strings.Builder
b.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`) b.WriteString(`<?xml version="1.0" encoding="UTF-8"?>`)

View File

@ -0,0 +1,69 @@
package nextcloud
import (
"context"
"errors"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
)
var ErrDAVCredentialsMissing = errors.New("nextcloud dav credentials missing")
type DAVCredentialStore struct {
db *pgxpool.Pool
enc *credentials.Manager
}
func NewDAVCredentialStore(db *pgxpool.Pool, enc *credentials.Manager) *DAVCredentialStore {
if db == nil || enc == nil {
return nil
}
return &DAVCredentialStore{db: db, enc: enc}
}
func (s *DAVCredentialStore) GetToken(ctx context.Context, ncUserID string) (string, error) {
if s == nil {
return "", ErrDAVCredentialsMissing
}
var blob []byte
err := s.db.QueryRow(ctx, `
SELECT dav_token FROM nextcloud_dav_credentials WHERE nc_user_id = $1
`, ncUserID).Scan(&blob)
if errors.Is(err, pgx.ErrNoRows) {
return "", ErrDAVCredentialsMissing
}
if err != nil {
return "", err
}
_, token, err := s.enc.Decrypt(blob)
return token, err
}
func (s *DAVCredentialStore) SaveToken(ctx context.Context, ncUserID, token string) error {
if s == nil {
return fmt.Errorf("nextcloud dav credential store unavailable")
}
blob, err := s.enc.Encrypt(ncUserID, token)
if err != nil {
return err
}
_, err = s.db.Exec(ctx, `
INSERT INTO nextcloud_dav_credentials (nc_user_id, dav_token, updated_at)
VALUES ($1, $2, NOW())
ON CONFLICT (nc_user_id) DO UPDATE
SET dav_token = EXCLUDED.dav_token, updated_at = NOW()
`, ncUserID, blob)
return err
}
func (s *DAVCredentialStore) DeleteToken(ctx context.Context, ncUserID string) error {
if s == nil {
return nil
}
_, err := s.db.Exec(ctx, `DELETE FROM nextcloud_dav_credentials WHERE nc_user_id = $1`, ncUserID)
return err
}

291
internal/nextcloud/users.go Normal file
View File

@ -0,0 +1,291 @@
package nextcloud
import (
"context"
"crypto/rand"
"encoding/json"
"errors"
"fmt"
"io"
"math/big"
"net/http"
"net/url"
"strings"
)
var ErrPrincipalNotFound = errors.New("nextcloud principal not found")
// UserIDFromClaims returns the Nextcloud account id aligned with user_oidc mapping-uid
// (preferred_username / enrollment email), not the opaque OIDC sub.
func UserIDFromClaims(email, sub string) string {
email = strings.TrimSpace(strings.ToLower(email))
if email != "" {
return email
}
return strings.TrimSpace(sub)
}
// EnsurePrincipal provisions a Nextcloud user and CardDAV app credentials.
func (c *Client) EnsurePrincipal(ctx context.Context, email, sub, displayName string) (string, error) {
if c.credStore == nil {
return "", fmt.Errorf("nextcloud dav credentials store not configured")
}
userID := UserIDFromClaims(email, sub)
if userID == "" {
return "", fmt.Errorf("nextcloud user id is empty")
}
token, err := c.credStore.GetToken(ctx, userID)
if err == nil && token != "" {
return userID, nil
}
exists, err := c.userExists(ctx, userID)
if err != nil {
return "", err
}
provisionEmail := strings.TrimSpace(email)
if provisionEmail == "" {
provisionEmail = userID
}
name := strings.TrimSpace(displayName)
if name == "" {
name = provisionEmail
}
loginPassword, err := generateNextcloudPassword()
if err != nil {
return "", err
}
if !exists {
if err := c.createUser(ctx, userID, provisionEmail, name, loginPassword); err != nil {
return "", err
}
} else if err := c.setUserPassword(ctx, userID, loginPassword); err != nil {
return "", err
}
appPassword, err := c.createAppPassword(ctx, userID, loginPassword)
if err != nil {
return "", err
}
if err := c.credStore.SaveToken(ctx, userID, appPassword); err != nil {
return "", err
}
return userID, nil
}
func (c *Client) userExists(ctx context.Context, userID string) (bool, error) {
path := fmt.Sprintf("/ocs/v1.php/cloud/users/%s?format=json", url.PathEscape(userID))
resp, err := c.doRequest(ctx, http.MethodGet, path, nil, map[string]string{
"OCS-APIRequest": "true",
"Accept": "application/json",
})
if err != nil {
return false, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
return false, nil
}
if resp.StatusCode != http.StatusOK {
return false, &HTTPStatusError{Operation: "get user", StatusCode: resp.StatusCode}
}
var payload struct {
OCS struct {
Meta struct {
StatusCode int `json:"statuscode"`
} `json:"meta"`
} `json:"ocs"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return false, err
}
return payload.OCS.Meta.StatusCode == 100, nil
}
func (c *Client) setUserPassword(ctx context.Context, userID, password string) error {
form := url.Values{}
form.Set("key", "password")
form.Set("value", password)
path := fmt.Sprintf("/ocs/v1.php/cloud/users/%s", url.PathEscape(userID))
resp, err := c.doRequest(ctx, http.MethodPut, path, strings.NewReader(form.Encode()), map[string]string{
"OCS-APIRequest": "true",
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
})
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return &HTTPStatusError{Operation: "set user password", StatusCode: resp.StatusCode}
}
var payload struct {
OCS struct {
Meta struct {
Status string `json:"status"`
StatusCode int `json:"statuscode"`
} `json:"meta"`
} `json:"ocs"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return err
}
if strings.EqualFold(payload.OCS.Meta.Status, "ok") || payload.OCS.Meta.StatusCode == 100 {
return nil
}
return fmt.Errorf("set nextcloud user password failed with status %d", payload.OCS.Meta.StatusCode)
}
func (c *Client) createAppPassword(ctx context.Context, userID, loginPassword string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.baseURL+"/ocs/v2.php/core/getapppassword?format=json", nil)
if err != nil {
return "", err
}
req.SetBasicAuth(userID, loginPassword)
req.Header.Set("OCS-APIRequest", "true")
req.Header.Set("Accept", "application/json")
resp, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", &HTTPStatusError{Operation: "create app password", StatusCode: resp.StatusCode}
}
var payload struct {
OCS struct {
Data struct {
AppPassword string `json:"apppassword"`
} `json:"data"`
} `json:"ocs"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return "", err
}
token := strings.TrimSpace(payload.OCS.Data.AppPassword)
if token == "" {
return "", fmt.Errorf("nextcloud app password response empty")
}
return token, nil
}
func (c *Client) createUser(ctx context.Context, userID, email, displayName, password string) error {
form := url.Values{}
form.Set("userid", userID)
form.Set("password", password)
form.Set("email", email)
form.Set("displayName", displayName)
resp, err := c.doRequest(ctx, http.MethodPost, "/ocs/v1.php/cloud/users?format=json", strings.NewReader(form.Encode()), map[string]string{
"OCS-APIRequest": "true",
"Content-Type": "application/x-www-form-urlencoded",
"Accept": "application/json",
})
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return &HTTPStatusError{Operation: "create user", StatusCode: resp.StatusCode}
}
var payload struct {
OCS struct {
Meta struct {
Status string `json:"status"`
StatusCode int `json:"statuscode"`
Message string `json:"message"`
} `json:"meta"`
} `json:"ocs"`
}
if err := json.NewDecoder(resp.Body).Decode(&payload); err != nil {
return err
}
if strings.EqualFold(payload.OCS.Meta.Status, "ok") || payload.OCS.Meta.StatusCode == 100 {
return nil
}
if payload.OCS.Meta.Message != "" {
return fmt.Errorf("create nextcloud user: %s", payload.OCS.Meta.Message)
}
return fmt.Errorf("create nextcloud user failed with status %d", payload.OCS.Meta.StatusCode)
}
func generateNextcloudPassword() (string, error) {
const (
lower = "abcdefghijklmnopqrstuvwxyz"
upper = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
digits = "0123456789"
symbols = "!@#$%^&*()-_=+"
all = lower + upper + digits + symbols
)
pick := func(chars string) (byte, error) {
n, err := rand.Int(rand.Reader, big.NewInt(int64(len(chars))))
if err != nil {
return 0, err
}
return chars[n.Int64()], nil
}
out := make([]byte, 32)
required := []string{lower, upper, digits, symbols}
for i, chars := range required {
b, err := pick(chars)
if err != nil {
return "", err
}
out[i] = b
}
for i := len(required); i < len(out); i++ {
b, err := pick(all)
if err != nil {
return "", err
}
out[i] = b
}
for i := len(out) - 1; i > 0; i-- {
j, err := rand.Int(rand.Reader, big.NewInt(int64(i+1)))
if err != nil {
return "", err
}
out[i], out[j.Int64()] = out[j.Int64()], out[i]
}
return string(out), nil
}
func readResponseBody(resp *http.Response) ([]byte, error) {
raw, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
return raw, nil
}
func davResponseError(raw []byte, statusCode int) error {
if statusCode == http.StatusNotFound {
return ErrPrincipalNotFound
}
if statusCode != http.StatusMultiStatus && statusCode != http.StatusOK {
if strings.Contains(string(raw), "<d:error") {
return &HTTPStatusError{Operation: "carddav", StatusCode: statusCode}
}
return &HTTPStatusError{Operation: "carddav", StatusCode: statusCode}
}
return nil
}
func addressBookHomePath(userID string) string {
return fmt.Sprintf("/remote.php/dav/addressbooks/users/%s/", url.PathEscape(userID))
}
// AddressBookPath returns the CardDAV collection path for a user's address book.
func AddressBookPath(userID, bookID string) string {
return addressBookHomePath(userID) + url.PathEscape(bookID) + "/"
}

View File

@ -0,0 +1,42 @@
package nextcloud
import "testing"
func TestUserIDFromClaimsPrefersEmail(t *testing.T) {
got := UserIDFromClaims("User@Example.com", "opaque-sub")
if got != "user@example.com" {
t.Fatalf("UserIDFromClaims() = %q", got)
}
}
func TestUserIDFromClaimsFallbackSub(t *testing.T) {
got := UserIDFromClaims("", "opaque-sub")
if got != "opaque-sub" {
t.Fatalf("UserIDFromClaims() = %q", got)
}
}
func TestNormalizeDAVHref(t *testing.T) {
got := normalizeDAVHref("/cloud/remote.php/dav/addressbooks/users/alice/contacts/")
want := "/remote.php/dav/addressbooks/users/alice/contacts/"
if got != want {
t.Fatalf("normalizeDAVHref() = %q, want %q", got, want)
}
}
func TestDavResponseErrorNotFound(t *testing.T) {
raw := []byte(`<?xml version="1.0"?><d:error xmlns:d="DAV:"><d:message>missing</d:message></d:error>`)
if err := davResponseError(raw, 404); err != ErrPrincipalNotFound {
t.Fatalf("davResponseError() = %v", err)
}
}
func TestGenerateNextcloudPassword(t *testing.T) {
pw, err := generateNextcloudPassword()
if err != nil {
t.Fatal(err)
}
if len(pw) < 24 {
t.Fatalf("password too short: %d", len(pw))
}
}

View File

@ -60,6 +60,66 @@ func levelRank(l Level) int {
return int(l) return int(l)
} }
var suiteResources = []Resource{
ResourceContacts,
ResourceCalendar,
ResourceDrive,
ResourcePhotos,
}
func hasAnyResourcePermission(groups []string) bool {
for _, g := range groups {
g = strings.ToLower(strings.TrimSpace(g))
for _, resource := range suiteResources {
if strings.HasPrefix(g, string(resource)+":") {
return true
}
}
}
return false
}
// WithSuiteDefaults grants standard suite read/write when the token carries no
// resource-scoped groups. Mail endpoints stay open; CardDAV/CalDAV modules rely
// on this until Authentik emits explicit RBAC groups on every account.
func WithSuiteDefaults(groups []string) []string {
if hasAnyResourcePermission(groups) {
return groups
}
defaults := []string{
string(RoleUser),
string(ResourceContacts) + ":write",
string(ResourceCalendar) + ":write",
string(ResourceDrive) + ":write",
string(ResourcePhotos) + ":write",
}
seen := make(map[string]struct{}, len(groups)+len(defaults))
out := make([]string, 0, len(groups)+len(defaults))
for _, g := range groups {
g = strings.TrimSpace(g)
if g == "" {
continue
}
key := strings.ToLower(g)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, g)
}
for _, g := range defaults {
key := strings.ToLower(g)
if _, ok := seen[key]; ok {
continue
}
seen[key] = struct{}{}
out = append(out, g)
}
return out
}
// AdminScope is a fine-grained admin API permission with read < write ordering. // AdminScope is a fine-grained admin API permission with read < write ordering.
type AdminScope int type AdminScope int

View File

@ -57,6 +57,42 @@ func TestHasPermissionResourceAdmin(t *testing.T) {
} }
} }
func TestWithSuiteDefaultsEmptyGroups(t *testing.T) {
groups := WithSuiteDefaults(nil)
if !HasRole(groups, RoleUser) {
t.Fatal("expected role:user")
}
if !HasPermission(groups, ResourceContacts, LevelWrite) {
t.Fatal("expected contacts write")
}
if !HasPermission(groups, ResourceCalendar, LevelWrite) {
t.Fatal("expected calendar write")
}
}
func TestWithSuiteDefaultsPreservesExplicitResource(t *testing.T) {
groups := WithSuiteDefaults([]string{"contacts:read"})
if !HasPermission(groups, ResourceContacts, LevelRead) {
t.Fatal("expected contacts read")
}
if HasPermission(groups, ResourceContacts, LevelWrite) {
t.Fatal("contacts:read must not be upgraded to write")
}
if HasPermission(groups, ResourceDrive, LevelRead) {
t.Fatal("must not grant drive when contacts-only group is present")
}
}
func TestWithSuiteDefaultsUserRoleOnly(t *testing.T) {
groups := WithSuiteDefaults([]string{"role:user"})
if !HasPermission(groups, ResourceContacts, LevelWrite) {
t.Fatal("role:user without resource groups should get suite defaults")
}
}
func TestHasPermissionIsolation(t *testing.T) { func TestHasPermissionIsolation(t *testing.T) {
groups := []string{"contacts:read"} groups := []string{"contacts:read"}

View File

@ -0,0 +1,4 @@
DROP INDEX IF EXISTS idx_mail_rules_kind;
ALTER TABLE mail_rules DROP COLUMN IF EXISTS workflow;
ALTER TABLE mail_rules DROP COLUMN IF EXISTS rule_kind;

View File

@ -0,0 +1,11 @@
ALTER TABLE mail_rules
ADD COLUMN IF NOT EXISTS rule_kind TEXT NOT NULL DEFAULT 'rule';
ALTER TABLE mail_rules
ADD COLUMN IF NOT EXISTS workflow JSONB;
CREATE INDEX IF NOT EXISTS idx_mail_rules_kind
ON mail_rules(user_id, rule_kind);
COMMENT ON COLUMN mail_rules.rule_kind IS 'rule = triggered automation, function = reusable subroutine';
COMMENT ON COLUMN mail_rules.workflow IS 'Graphical workflow definition (triggers, nodes, edges, variables)';

View File

@ -0,0 +1 @@
-- Irreversible: Message-ID normalization cannot be safely reversed.

View File

@ -0,0 +1,28 @@
-- Canonicalize RFC Message-IDs so threading lookups (angle-bracket form) match stored values.
UPDATE messages
SET message_id = '<' || trim(both '<>' from trim(message_id)) || '>',
updated_at = NOW()
WHERE message_id <> ''
AND message_id NOT LIKE '<%>';
-- Re-link split threads from in_reply_to / references (repeat for nested replies).
DO $$
DECLARE
i INT;
n BIGINT;
BEGIN
FOR i IN 1..8 LOOP
UPDATE messages child
SET thread_id = parent.thread_id, updated_at = NOW()
FROM messages parent
WHERE child.account_id = parent.account_id
AND parent.thread_id IS NOT NULL
AND child.thread_id IS DISTINCT FROM parent.thread_id
AND (
(child.in_reply_to <> '' AND child.in_reply_to = parent.message_id)
OR parent.message_id = ANY(child.references_header)
);
GET DIAGNOSTICS n = ROW_COUNT;
EXIT WHEN n = 0;
END LOOP;
END $$;

View File

@ -0,0 +1 @@
ALTER TABLE messages DROP COLUMN IF EXISTS auth_info;

View File

@ -0,0 +1,2 @@
ALTER TABLE messages
ADD COLUMN IF NOT EXISTS auth_info JSONB NOT NULL DEFAULT '{}';

View File

@ -0,0 +1 @@
DROP TABLE IF EXISTS nextcloud_dav_credentials;

View File

@ -0,0 +1,8 @@
CREATE TABLE nextcloud_dav_credentials (
nc_user_id TEXT PRIMARY KEY,
dav_token BYTEA NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX idx_nextcloud_dav_credentials_updated_at ON nextcloud_dav_credentials (updated_at);