diff --git a/.env.example b/.env.example index 61f3a29..04a9c3a 100644 --- a/.env.example +++ b/.env.example @@ -255,3 +255,8 @@ SEARCH_ENGINE=postgres # TYPESENSE_URL=http://typesense:8108 # TYPESENSE_API_KEY={{TYPESENSE_API_KEY}} # TYPESENSE_COLLECTION=ulti + +# ----------------------------------------------------------------------------- +# VirusTotal (optional env fallback; prefer admin Settings > File policies) +# ----------------------------------------------------------------------------- +# VIRUSTOTAL_API_KEY= diff --git a/internal/api/admin/org_settings.go b/internal/api/admin/org_settings.go index e7a11ae..c6e8570 100644 --- a/internal/api/admin/org_settings.go +++ b/internal/api/admin/org_settings.go @@ -51,6 +51,7 @@ func defaultOrgPolicy() map[string]any { "external_sharing": "authenticated", "default_link_expiry_days": 30, "virus_scan_enabled": false, + "virustotal_api_key": "", "retention_trash_days": 30, }, "llm": map[string]any{ @@ -174,6 +175,7 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any { {"onlyoffice", "jwt_secret"}, {"search", "meilisearch_api_key"}, {"search", "typesense_api_key"}, + {"file_policies", "virustotal_api_key"}, } for _, p := range secretPaths { patchSection, _ := patch[p.section].(map[string]any) @@ -296,6 +298,7 @@ func maskOrgPolicy(policy map[string]any) map[string]any { maskStringField(cloned, "onlyoffice", "jwt_secret") maskStringField(cloned, "search", "meilisearch_api_key") maskStringField(cloned, "search", "typesense_api_key") + maskStringField(cloned, "file_policies", "virustotal_api_key") if llm, ok := cloned["llm"].(map[string]any); ok { if providers, ok := llm["providers"].([]any); ok { for i, p := range providers { @@ -374,6 +377,9 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string "typesense_api_key": map[string]any{ "configured": secretConfigured(policy, "search", "typesense_api_key") || strings.TrimSpace(cfg.TypesenseKey) != "", }, + "virustotal_api_key": map[string]any{ + "configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "", + }, } } diff --git a/internal/api/drive/automation_hooks.go b/internal/api/drive/automation_hooks.go index 55bae0f..b9d967c 100644 --- a/internal/api/drive/automation_hooks.go +++ b/internal/api/drive/automation_hooks.go @@ -5,6 +5,7 @@ import ( "path" "github.com/ultisuite/ulti-backend/internal/automation" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/nextcloud" ) @@ -13,6 +14,10 @@ func (s *Service) SetAutomation(d driveAutomation) { s.automation = d } +func (s *Service) SetFileScanner(scanner *filescan.Scanner) { + s.scanner = scanner +} + func (s *Service) afterDriveFileEvent(ctx context.Context, externalUserID string, trigger rules.TriggerType, filePath string, isFolder bool) { normalized := nextcloud.NormalizeClientPath(filePath) s.notifyFileChanged(externalUserID, normalized) diff --git a/internal/api/drive/handlers.go b/internal/api/drive/handlers.go index 99d9713..73559b8 100644 --- a/internal/api/drive/handlers.go +++ b/internal/api/drive/handlers.go @@ -733,6 +733,8 @@ func writeDriveError(w http.ResponseWriter, r *http.Request, err error) { apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "forbidden", nil) case errors.Is(err, ErrQuotaExceeded): apiresponse.WriteError(w, r, http.StatusInsufficientStorage, "drive.quota_exceeded", "quota exceeded", nil) + case errors.Is(err, ErrMalware): + apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "drive.malware_detected", "malware detected in file", nil) case errors.Is(err, ErrInvalid): apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid request body", nil) default: diff --git a/internal/api/drive/public_service.go b/internal/api/drive/public_service.go index 3952bd4..52b9baf 100644 --- a/internal/api/drive/public_service.go +++ b/internal/api/drive/public_service.go @@ -1,11 +1,14 @@ package drive import ( + "bytes" "context" + "errors" "io" "path" "strings" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/nextcloud" ) @@ -17,7 +20,18 @@ func (s *Service) UploadPublicShare(ctx context.Context, token, filePath, passwo if !nextcloud.PublicShareCanCreate(perms) && !nextcloud.PublicShareCanUpdate(perms) { return ErrForbidden } - if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, body, contentType)); err != nil { + reader := body + if s.scanner != nil { + data, _, err := s.scanner.ScanReader(ctx, filePath, body, -1) + if err != nil { + if errors.Is(err, filescan.ErrMalicious) { + return ErrMalware + } + return err + } + reader = bytes.NewReader(data) + } + if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, reader, contentType)); err != nil { return err } s.recordPublicShareAccess(ctx, token) diff --git a/internal/api/drive/service.go b/internal/api/drive/service.go index 0b05f69..2631cac 100644 --- a/internal/api/drive/service.go +++ b/internal/api/drive/service.go @@ -18,6 +18,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/automation" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/publicshare" @@ -30,6 +31,7 @@ var ( ErrForbidden = errors.New("forbidden") ErrQuotaExceeded = errors.New("quota exceeded") ErrInvalid = errors.New("invalid request") + ErrMalware = errors.New("malware detected") ) type Service struct { @@ -37,6 +39,7 @@ type Service struct { hub *realtime.Hub db *pgxpool.Pool automation driveAutomation + scanner *filescan.Scanner maxUploadBytes int64 quotaReserveByte int64 } @@ -184,7 +187,18 @@ func (s *Service) Upload(ctx context.Context, userID, path string, body io.Reade if err := s.ensureQuota(ctx, userID, contentLength); err != nil { return err } - return mapDriveError(s.nc.Upload(ctx, userID, path, body, contentType)) + reader := body + if s.scanner != nil { + data, _, err := s.scanner.ScanReader(ctx, path, body, contentLength) + if err != nil { + if errors.Is(err, filescan.ErrMalicious) { + return ErrMalware + } + return err + } + reader = bytes.NewReader(data) + } + return mapDriveError(s.nc.Upload(ctx, userID, path, reader, contentType)) } func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath string, chunk ChunkUpload, body io.Reader, contentType string) error { @@ -202,6 +216,31 @@ func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath if err := mapDriveError(s.nc.AssembleChunks(ctx, userID, uploadID, targetPath, chunk.TotalSize)); err != nil { return err } + if s.scanner != nil { + if err := s.scanAssembledUpload(ctx, userID, targetPath); err != nil { + return err + } + } + return nil +} + +func (s *Service) scanAssembledUpload(ctx context.Context, userID, targetPath string) error { + body, _, err := s.nc.Download(ctx, userID, targetPath) + if err != nil { + return mapDriveError(err) + } + defer body.Close() + + data, scanResult, err := s.scanner.ScanReader(ctx, targetPath, body, -1) + if err != nil { + if errors.Is(err, filescan.ErrMalicious) { + _ = s.nc.Delete(ctx, userID, targetPath) + return ErrMalware + } + return err + } + _ = data + _ = scanResult return nil } diff --git a/internal/api/mail/attachments.go b/internal/api/mail/attachments.go index fb42f26..550dc2b 100644 --- a/internal/api/mail/attachments.go +++ b/internal/api/mail/attachments.go @@ -1,6 +1,7 @@ package mail import ( + "bytes" "context" "encoding/json" "errors" @@ -39,7 +40,7 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag } rows, err := s.db.Query(ctx, ` - SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, '') + SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, ''), virus_scan_status FROM attachments WHERE message_id = $1 ORDER BY created_at ASC `, messageID) @@ -50,15 +51,16 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag out := make([]map[string]any, 0) for rows.Next() { - var id, filename, contentType, contentID, drivePath string + var id, filename, contentType, contentID, drivePath, virusScanStatus string var size int64 var isInline bool - if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath); err != nil { + if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath, &virusScanStatus); err != nil { return nil, err } entry := map[string]any{ "id": id, "filename": filename, "content_type": contentType, "size": size, "is_inline": isInline, + "virus_scan_status": virusScanStatus, } if contentID != "" { entry["content_id"] = contentID @@ -142,16 +144,28 @@ func (s *Service) UploadMessageAttachment( } objectKey := storage.MessageObjectKey(userID, messageID, filename) - if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil { + scanStatus := "skipped" + putReader := reader + putSize := size + if s.scanner != nil { + data, result, err := s.scanner.ScanReader(ctx, filename, reader, size) + if err != nil { + return "", err + } + scanStatus = result.Status + putReader = bytes.NewReader(data) + putSize = int64(len(data)) + } + if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil { return "", err } var id string err = s.db.QueryRow(ctx, ` - INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id - `, messageID, filename, contentType, size, s.storageBucket(), objectKey, contentID, isInline).Scan(&id) + `, messageID, filename, contentType, putSize, s.storageBucket(), objectKey, contentID, isInline, scanStatus).Scan(&id) if err != nil { _ = s.storage.Delete(ctx, objectKey) return "", err @@ -224,7 +238,17 @@ func (s *Service) UploadDraftAttachment( } objectKey := storage.DraftObjectKey(userID, draftID, filename) - if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil { + putReader := reader + putSize := size + if s.scanner != nil { + data, _, err := s.scanner.ScanReader(ctx, filename, reader, size) + if err != nil { + return "", err + } + putReader = bytes.NewReader(data) + putSize = int64(len(data)) + } + if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil { return "", err } @@ -237,13 +261,13 @@ func (s *Service) UploadDraftAttachment( for _, ref := range refs { totalSize += ref.Size } - if err := limits.ValidateAttachmentQuota(len(refs), totalSize, size); err != nil { + if err := limits.ValidateAttachmentQuota(len(refs), totalSize, putSize); err != nil { return "", err } attID := uuid.NewString() refs = append(refs, draftAttachmentRef{ - ID: attID, Filename: filename, ContentType: contentType, Size: size, + ID: attID, Filename: filename, ContentType: contentType, Size: putSize, S3Bucket: s.storageBucket(), S3Key: objectKey, ContentID: contentID, IsInline: isInline, }) diff --git a/internal/api/mail/drive_save.go b/internal/api/mail/drive_save.go index 4118962..3bb1fbb 100644 --- a/internal/api/mail/drive_save.go +++ b/internal/api/mail/drive_save.go @@ -6,6 +6,7 @@ import ( "github.com/jackc/pgx/v5" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/nextcloud" ) @@ -15,6 +16,10 @@ func (s *Service) SetDriveUploader(uploader DriveUploader) { s.driveUploader = uploader } +func (s *Service) SetFileScanner(scanner *filescan.Scanner) { + s.scanner = scanner +} + func (s *Service) SaveAttachmentToDrive( ctx context.Context, externalID, email, sub, displayName, messageID, attachmentID, folderPath string, diff --git a/internal/api/mail/handlers.go b/internal/api/mail/handlers.go index a325d77..730fe14 100644 --- a/internal/api/mail/handlers.go +++ b/internal/api/mail/handlers.go @@ -13,6 +13,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/api/mail/sendguard" "github.com/ultisuite/ulti-backend/internal/api/middleware" "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/limits" mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth" @@ -42,6 +43,13 @@ func (h *Handler) SetDriveUploader(uploader DriveUploader) { } } +// SetFileScanner wires VirusTotal scanning for mail attachment uploads. +func (h *Handler) SetFileScanner(scanner *filescan.Scanner) { + if s, ok := h.svc.(*Service); ok { + s.SetFileScanner(scanner) + } +} + func NewHandlerWithService(svc ServiceAPI) *Handler { return &Handler{ svc: svc, diff --git a/internal/api/mail/handlers_attachments.go b/internal/api/mail/handlers_attachments.go index 0b0645b..8859c91 100644 --- a/internal/api/mail/handlers_attachments.go +++ b/internal/api/mail/handlers_attachments.go @@ -15,6 +15,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/api/apivalidate" driveapi "github.com/ultisuite/ulti-backend/internal/api/drive" "github.com/ultisuite/ulti-backend/internal/api/middleware" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/mail/limits" ) @@ -342,6 +343,9 @@ func writeAttachmentUploadError(w http.ResponseWriter, r *http.Request, err erro case errors.Is(err, limits.ErrTooManyAttachments): apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "too many attachments", nil) return true + case errors.Is(err, ErrMalware), errors.Is(err, filescan.ErrMalicious): + apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "mail.malware_detected", "malware detected in attachment", nil) + return true default: return false } diff --git a/internal/api/mail/service.go b/internal/api/mail/service.go index c559a17..94748ac 100644 --- a/internal/api/mail/service.go +++ b/internal/api/mail/service.go @@ -11,6 +11,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/api/query" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/imap" "github.com/ultisuite/ulti-backend/internal/mail/sanitize" @@ -28,6 +29,7 @@ var ( ErrInvalidAccountCredentials = errors.New("account credentials invalid") ErrInvalidFolderScope = errors.New("invalid folder scope") ErrFolderHasChildren = errors.New("folder has children") + ErrMalware = filescan.ErrMalicious ) type Service struct { @@ -37,6 +39,7 @@ type Service struct { storage *storage.Client attachmentsBucket string driveUploader DriveUploader + scanner *filescan.Scanner logger *slog.Logger } diff --git a/internal/config/config.go b/internal/config/config.go index 9d5dfc6..71ac13c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -116,6 +116,9 @@ type Config struct { TypesenseKey string TypesenseCollection string + // VirusTotal (optional env fallback for org file_policies.virustotal_api_key) + VirusTotalAPIKey string + // Observability HealthNextcloudURL string HealthImmichURL string @@ -221,6 +224,8 @@ func Load() (*Config, error) { TypesenseKey: secrets.Env("TYPESENSE_API_KEY"), TypesenseCollection: envOrDefault("TYPESENSE_COLLECTION", "ulti"), + VirusTotalAPIKey: secrets.Env("VIRUSTOTAL_API_KEY"), + HealthNextcloudURL: envOrDefault("HEALTH_NEXTCLOUD_URL", joinURL(envOrDefault("NEXTCLOUD_URL", "http://nextcloud:80"), "/status.php")), HealthImmichURL: envOrDefault("HEALTH_IMMICH_URL", joinURL(envOrDefault("IMMICH_API_URL", "http://immich-server:2283/api"), "/server-info/ping")), HealthJitsiURL: envOrDefault("HEALTH_JITSI_URL", defaultHealthJitsiURL(envOrDefault("JITSI_PUBLIC_URL", "https://localhost/meet"))), diff --git a/internal/filescan/scanner.go b/internal/filescan/scanner.go new file mode 100644 index 0000000..6b7907c --- /dev/null +++ b/internal/filescan/scanner.go @@ -0,0 +1,144 @@ +package filescan + +import ( + "bytes" + "context" + "errors" + "io" + "log/slog" + "os" + "path/filepath" + + "github.com/ultisuite/ulti-backend/internal/orgpolicy" + "github.com/ultisuite/ulti-backend/internal/virustotal" +) + +const memoryBufferLimit = 32 * 1024 * 1024 + +// Result holds scan outcome for persistence. +type Result struct { + Status string // clean | skipped +} + +// PolicyLoader supplies org file policies at runtime. +type PolicyLoader interface { + FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error) + ScanEnabled(ctx context.Context) (bool, string, error) +} + +// Scanner coordinates org policy and VirusTotal scanning. +type Scanner struct { + policies PolicyLoader + logger *slog.Logger +} + +func NewScanner(policies PolicyLoader, logger *slog.Logger) *Scanner { + if logger == nil { + logger = slog.Default() + } + return &Scanner{policies: policies, logger: logger} +} + +// ScanReader reads all bytes from r (up to maxBytes), optionally scans, returns data for storage. +func (s *Scanner) ScanReader(ctx context.Context, filename string, r io.Reader, size int64) ([]byte, Result, error) { + fp, err := s.policies.FilePolicies(ctx) + if err != nil { + return nil, Result{}, err + } + + maxBytes := fp.MaxUploadBytes + if maxBytes <= 0 { + maxBytes = 512 * 1024 * 1024 + } + if maxBytes > virustotalMaxBytes() { + maxBytes = virustotalMaxBytes() + } + + data, err := readLimited(r, size, maxBytes) + if err != nil { + return nil, Result{}, err + } + + enabled := fp.VirusScanEnabled && fp.VirusTotalAPIKey != "" + if !enabled { + return data, Result{Status: "skipped"}, nil + } + + vt := virustotal.NewScanner(fp.VirusTotalAPIKey, s.logger) + scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data)) + if err != nil { + if errors.Is(err, virustotal.ErrMalicious) { + return nil, Result{Status: "malicious"}, virustotal.ErrMalicious + } + return nil, Result{}, err + } + return data, Result{Status: scanResult.Status}, nil +} + +// ScanBytes scans pre-loaded bytes when policy is already known. +func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte) (Result, error) { + enabled, apiKey, err := s.policies.ScanEnabled(ctx) + if err != nil { + return Result{}, err + } + if !enabled { + return Result{Status: "skipped"}, nil + } + + vt := virustotal.NewScanner(apiKey, s.logger) + scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data)) + if err != nil { + if errors.Is(err, virustotal.ErrMalicious) { + return Result{Status: "malicious"}, virustotal.ErrMalicious + } + return Result{}, err + } + return Result{Status: scanResult.Status}, nil +} + +func readLimited(r io.Reader, size int64, maxBytes int64) ([]byte, error) { + if size >= 0 && size <= memoryBufferLimit { + limited := io.LimitReader(r, maxBytes+1) + data, err := io.ReadAll(limited) + if err != nil { + return nil, err + } + if int64(len(data)) > maxBytes { + return nil, errors.New("file exceeds max upload size") + } + return data, nil + } + + tmp, err := os.CreateTemp("", "ulti-scan-*") + if err != nil { + return nil, err + } + tmpPath := tmp.Name() + defer os.Remove(tmpPath) + + written, err := io.Copy(tmp, io.LimitReader(r, maxBytes+1)) + if err != nil { + tmp.Close() + return nil, err + } + if err := tmp.Close(); err != nil { + return nil, err + } + if written > maxBytes { + return nil, errors.New("file exceeds max upload size") + } + + return os.ReadFile(filepath.Clean(tmpPath)) +} + +func virustotalMaxBytes() int64 { + return 650 * 1024 * 1024 +} + +// ErrMalicious re-exports virustotal.ErrMalicious for handlers. +var ErrMalicious = virustotal.ErrMalicious + +// NopReader helper for tests. +func NopReader(data []byte) io.Reader { + return bytes.NewReader(data) +} diff --git a/internal/filescan/scanner_test.go b/internal/filescan/scanner_test.go new file mode 100644 index 0000000..e2b3cd5 --- /dev/null +++ b/internal/filescan/scanner_test.go @@ -0,0 +1,40 @@ +package filescan + +import ( + "bytes" + "context" + "testing" + + "github.com/ultisuite/ulti-backend/internal/orgpolicy" +) + +type stubPolicyLoader struct { + fp orgpolicy.FilePolicies +} + +func (s stubPolicyLoader) FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error) { + return s.fp, nil +} + +func (s stubPolicyLoader) ScanEnabled(ctx context.Context) (bool, string, error) { + if s.fp.VirusScanEnabled && s.fp.VirusTotalAPIKey != "" { + return true, s.fp.VirusTotalAPIKey, nil + } + return false, "", nil +} + +func TestScannerDisabledReturnsSkipped(t *testing.T) { + scanner := &Scanner{ + policies: stubPolicyLoader{fp: orgpolicy.FilePolicies{VirusScanEnabled: false}}, + } + data, result, err := scanner.ScanReader(context.Background(), "test.txt", bytes.NewReader([]byte("hello")), 5) + if err != nil { + t.Fatalf("ScanReader: %v", err) + } + if result.Status != "skipped" { + t.Fatalf("status = %q, want skipped", result.Status) + } + if string(data) != "hello" { + t.Fatalf("data = %q", data) + } +} diff --git a/internal/integrationtest/admin/org_settings_test.go b/internal/integrationtest/admin/org_settings_test.go index 8886aae..4fbc629 100644 --- a/internal/integrationtest/admin/org_settings_test.go +++ b/internal/integrationtest/admin/org_settings_test.go @@ -57,3 +57,60 @@ func TestAdminOrgSettings(t *testing.T) { t.Fatalf("default_mail_gib = %v, want 10", updatedStorage["default_mail_gib"]) } } + +func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) { + h := integrationtest.RequireHarness(t) + adminClient, _ := integrationtest.RequireAdminClient(t, h) + + putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{ + "policy": map[string]any{ + "file_policies": map[string]any{ + "virus_scan_enabled": true, + "virustotal_api_key": "vt-test-secret-key", + }, + }, + }) + integrationtest.FailIf(err, t, "put org settings with virustotal key") + integrationtest.FailUnlessStatus(t, putResp, 200) + + var afterPut map[string]any + integrationtest.DecodeJSON(t, putResp, &afterPut) + secrets, ok := afterPut["secrets"].(map[string]any) + if !ok { + t.Fatalf("missing secrets: %#v", afterPut) + } + vtSecret, ok := secrets["virustotal_api_key"].(map[string]any) + if !ok || vtSecret["configured"] != true { + t.Fatalf("virustotal_api_key not configured: %#v", secrets) + } + filePolicies, ok := afterPut["policy"].(map[string]any)["file_policies"].(map[string]any) + if !ok { + t.Fatalf("missing file_policies") + } + if filePolicies["virustotal_api_key"] != "" { + t.Fatalf("virustotal_api_key should be masked on GET, got %q", filePolicies["virustotal_api_key"]) + } + + // Empty patch must preserve stored secret. + preserveResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{ + "policy": map[string]any{ + "file_policies": map[string]any{ + "virus_scan_enabled": false, + "virustotal_api_key": "", + }, + }, + }) + integrationtest.FailIf(err, t, "preserve virustotal key") + integrationtest.FailUnlessStatus(t, preserveResp, 200) + + var preserved map[string]any + integrationtest.DecodeJSON(t, preserveResp, &preserved) + preservedSecrets, ok := preserved["secrets"].(map[string]any) + if !ok { + t.Fatalf("missing secrets after preserve") + } + vtPreserved, ok := preservedSecrets["virustotal_api_key"].(map[string]any) + if !ok || vtPreserved["configured"] != true { + t.Fatalf("virustotal key not preserved: %#v", preservedSecrets) + } +} diff --git a/internal/mail/imap/sync.go b/internal/mail/imap/sync.go index b9d3ffe..d701e1e 100644 --- a/internal/mail/imap/sync.go +++ b/internal/mail/imap/sync.go @@ -16,6 +16,7 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/connect" + "github.com/ultisuite/ulti-backend/internal/filescan" mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth" "github.com/ultisuite/ulti-backend/internal/mail/limits" "github.com/ultisuite/ulti-backend/internal/mail/rules" @@ -38,6 +39,7 @@ type SyncDeps struct { Rules *rules.Engine Automation MailAutomation Hub *realtime.Hub + FileScanner *filescan.Scanner } type SyncWorker struct { @@ -48,6 +50,7 @@ type SyncWorker struct { oauth *mailoauth.Service storage *storage.Client attachBucket string + scanner *filescan.Scanner pipeline *syncPipeline } @@ -60,6 +63,7 @@ func NewSyncWorker(db *pgxpool.Pool, interval time.Duration, credManager *creden oauth: oauthSvc, storage: deps.Storage, attachBucket: deps.AttachBucket, + scanner: deps.FileScanner, pipeline: newSyncPipeline(db, deps.Rules, deps.Automation, deps.Hub), } } @@ -575,14 +579,30 @@ func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID str if messageExisted && attachmentPartExists(ctx, w.db, messageID, part) { continue } + + scanStatus := "skipped" + partData := part.Data + if w.scanner != nil { + result, err := w.scanner.ScanBytes(ctx, part.Filename, part.Data) + if err != nil { + if errors.Is(err, filescan.ErrMalicious) { + w.logger.Warn("imap attachment skipped: malware detected", + "message_id", messageID, "filename", part.Filename) + continue + } + return err + } + scanStatus = result.Status + } + objectKey := storage.MessageObjectKey(userID, messageID, part.Filename) - if err := w.storage.Put(ctx, objectKey, bytes.NewReader(part.Data), int64(len(part.Data)), part.ContentType); err != nil { + if err := w.storage.Put(ctx, objectKey, bytes.NewReader(partData), int64(len(partData)), part.ContentType); err != nil { return err } _, err := w.db.Exec(ctx, ` - INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - `, messageID, part.Filename, part.ContentType, len(part.Data), bucket, objectKey, part.ContentID, part.IsInline) + INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) + `, messageID, part.Filename, part.ContentType, len(partData), bucket, objectKey, part.ContentID, part.IsInline, scanStatus) if err != nil { _ = w.storage.Delete(ctx, objectKey) return err diff --git a/internal/orgpolicy/loader.go b/internal/orgpolicy/loader.go new file mode 100644 index 0000000..9514834 --- /dev/null +++ b/internal/orgpolicy/loader.go @@ -0,0 +1,135 @@ +package orgpolicy + +import ( + "context" + "encoding/json" + "strings" + "sync" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/ultisuite/ulti-backend/internal/config" +) + +const orgSettingsSingletonID = 1 + +const defaultMaxUploadMiB = 512 + +// FilePolicies holds runtime file upload policy from org settings. +type FilePolicies struct { + VirusScanEnabled bool + VirusTotalAPIKey string + MaxUploadBytes int64 +} + +type Loader struct { + db *pgxpool.Pool + cfg *config.Config + + mu sync.Mutex + cached FilePolicies + cachedAt time.Time + ttl time.Duration +} + +func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader { + return &Loader{ + db: db, + cfg: cfg, + ttl: 60 * time.Second, + } +} + +func (l *Loader) FilePolicies(ctx context.Context) (FilePolicies, error) { + l.mu.Lock() + if !l.cachedAt.IsZero() && time.Since(l.cachedAt) < l.ttl { + out := l.cached + l.mu.Unlock() + return out, nil + } + l.mu.Unlock() + + fp, err := l.loadFilePolicies(ctx) + if err != nil { + return FilePolicies{}, err + } + + l.mu.Lock() + l.cached = fp + l.cachedAt = time.Now() + l.mu.Unlock() + return fp, nil +} + +func (l *Loader) loadFilePolicies(ctx context.Context) (FilePolicies, error) { + var raw []byte + err := l.db.QueryRow(ctx, ` + SELECT settings FROM org_settings WHERE id = $1 + `, orgSettingsSingletonID).Scan(&raw) + if err != nil && err != pgx.ErrNoRows { + return FilePolicies{}, err + } + + stored := map[string]any{} + if len(raw) > 0 { + if err := json.Unmarshal(raw, &stored); err != nil { + return FilePolicies{}, err + } + } + + filePolicies, _ := stored["file_policies"].(map[string]any) + enabled := boolValue(filePolicies["virus_scan_enabled"]) + apiKey := stringValue(filePolicies["virustotal_api_key"]) + if strings.TrimSpace(apiKey) == "" && l.cfg != nil { + apiKey = strings.TrimSpace(l.cfg.VirusTotalAPIKey) + } + + maxMiB := int64(defaultMaxUploadMiB) + switch v := filePolicies["max_upload_mib"].(type) { + case float64: + if v > 0 { + maxMiB = int64(v) + } + case int: + if v > 0 { + maxMiB = int64(v) + } + case int64: + if v > 0 { + maxMiB = v + } + } + + return FilePolicies{ + VirusScanEnabled: enabled, + VirusTotalAPIKey: apiKey, + MaxUploadBytes: maxMiB * 1024 * 1024, + }, nil +} + +func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) { + fp, err := l.FilePolicies(ctx) + if err != nil { + return false, "", err + } + if !fp.VirusScanEnabled || strings.TrimSpace(fp.VirusTotalAPIKey) == "" { + return false, "", nil + } + return true, fp.VirusTotalAPIKey, nil +} + +func boolValue(v any) bool { + switch t := v.(type) { + case bool: + return t + default: + return false + } +} + +func stringValue(v any) string { + s, _ := v.(string) + return s +} diff --git a/internal/server/bootstrap.go b/internal/server/bootstrap.go index 20d50c4..07f7468 100644 --- a/internal/server/bootstrap.go +++ b/internal/server/bootstrap.go @@ -32,6 +32,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/authentik" "github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/config" + "github.com/ultisuite/ulti-backend/internal/filescan" "github.com/ultisuite/ulti-backend/internal/httpcors" mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials" imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap" @@ -43,6 +44,7 @@ import ( "github.com/ultisuite/ulti-backend/internal/meet" "github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/observability" + "github.com/ultisuite/ulti-backend/internal/orgpolicy" "github.com/ultisuite/ulti-backend/internal/photos" "github.com/ultisuite/ulti-backend/internal/realtime" "github.com/ultisuite/ulti-backend/internal/search" @@ -203,6 +205,9 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { RedirectURL: oauthRedirect, }, rdb) + orgPolicyLoader := orgpolicy.NewLoader(pool, cfg) + fileScanner := filescan.NewScanner(orgPolicyLoader, slog.Default()) + var syncWorker *imapsync.SyncWorker if !opts.WithoutWorkers { syncWorker = imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{ @@ -210,6 +215,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { AttachBucket: cfg.MailAttachmentsBucket, Automation: autoDispatcher, Hub: hub, + FileScanner: fileScanner, }) go syncWorker.Start(workerCtx) } @@ -229,6 +235,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst) mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender) + mailHandler.SetFileScanner(fileScanner) if syncWorker != nil { mailHandler.SetAccountSync(syncWorker) } @@ -262,6 +269,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) { if ncClient != nil { driveSvc = drive.NewService(ncClient, hub, pool) driveSvc.SetAutomation(autoDispatcher) + driveSvc.SetFileScanner(fileScanner) driveHandler = drive.NewHandlerWithService(driveSvc) mailHandler.SetDriveUploader(&drivebridge.Bridge{Svc: driveSvc}) contactsHandler = contacts.NewHandler(ncClient, pool) diff --git a/internal/virustotal/client.go b/internal/virustotal/client.go new file mode 100644 index 0000000..bba5eb2 --- /dev/null +++ b/internal/virustotal/client.go @@ -0,0 +1,267 @@ +package virustotal + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "strings" + "time" +) + +const ( + defaultBaseURL = "https://www.virustotal.com/api/v3" + directUploadLimit = 32 * 1024 * 1024 + maxUploadLimit = 650 * 1024 * 1024 +) + +type Client struct { + apiKey string + baseURL string + httpClient *http.Client +} + +func NewClient(apiKey string) *Client { + return &Client{ + apiKey: apiKey, + baseURL: defaultBaseURL, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + +func (c *Client) apiURL(path string) string { + return strings.TrimRight(c.baseURL, "/") + path +} + +type analysisStats struct { + Malicious int `json:"malicious"` + Suspicious int `json:"suspicious"` +} + +type fileReport struct { + Data struct { + Attributes struct { + LastAnalysisStats analysisStats `json:"last_analysis_stats"` + } `json:"attributes"` + } `json:"data"` +} + +type analysisResponse struct { + Data struct { + Attributes struct { + Status string `json:"status"` + Stats analysisStats `json:"stats"` + } `json:"attributes"` + } `json:"data"` +} + +type uploadResponse struct { + Data struct { + ID string `json:"id"` + } `json:"data"` +} + +type uploadURLResponse struct { + Data string `json:"data"` +} + +func (c *Client) headers() http.Header { + h := http.Header{} + h.Set("Accept", "application/json") + h.Set("x-apikey", c.apiKey) + return h +} + +func (c *Client) lookupFile(ctx context.Context, sha256 string) (analysisStats, bool, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/"+sha256), nil) + if err != nil { + return analysisStats{}, false, err + } + req.Header = c.headers() + + res, err := c.httpClient.Do(req) + if err != nil { + return analysisStats{}, false, err + } + defer res.Body.Close() + + if res.StatusCode == http.StatusNotFound { + return analysisStats{}, false, nil + } + if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests { + return analysisStats{}, false, fmt.Errorf("virustotal lookup unavailable: %d", res.StatusCode) + } + if res.StatusCode >= 400 { + return analysisStats{}, false, nil + } + + var report fileReport + if err := json.NewDecoder(res.Body).Decode(&report); err != nil { + return analysisStats{}, false, err + } + return report.Data.Attributes.LastAnalysisStats, true, nil +} + +func (c *Client) uploadFile(ctx context.Context, data []byte, filename string) (string, error) { + if len(data) <= directUploadLimit { + return c.uploadDirect(ctx, data, filename) + } + return c.uploadLarge(ctx, data, filename) +} + +func (c *Client) uploadDirect(ctx context.Context, data []byte, filename string) (string, error) { + var body bytes.Buffer + w := multipart.NewWriter(&body) + part, err := w.CreateFormFile("file", filename) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := w.Close(); err != nil { + return "", err + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL("/files"), &body) + if err != nil { + return "", err + } + req.Header = c.headers() + req.Header.Set("Content-Type", w.FormDataContentType()) + + res, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + b, _ := io.ReadAll(io.LimitReader(res.Body, 4096)) + return "", fmt.Errorf("virustotal upload failed: %d %s", res.StatusCode, string(b)) + } + + var out uploadResponse + if err := json.NewDecoder(res.Body).Decode(&out); err != nil { + return "", err + } + return out.Data.ID, nil +} + +func (c *Client) uploadLarge(ctx context.Context, data []byte, filename string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/upload_url"), nil) + if err != nil { + return "", err + } + req.Header = c.headers() + + res, err := c.httpClient.Do(req) + if err != nil { + return "", err + } + defer res.Body.Close() + + if res.StatusCode >= 400 { + b, _ := io.ReadAll(io.LimitReader(res.Body, 4096)) + return "", fmt.Errorf("virustotal upload_url failed: %d %s", res.StatusCode, string(b)) + } + + var urlResp uploadURLResponse + if err := json.NewDecoder(res.Body).Decode(&urlResp); err != nil { + return "", err + } + if urlResp.Data == "" { + return "", fmt.Errorf("virustotal upload_url empty") + } + + var body bytes.Buffer + w := multipart.NewWriter(&body) + part, err := w.CreateFormFile("file", filename) + if err != nil { + return "", err + } + if _, err := part.Write(data); err != nil { + return "", err + } + if err := w.Close(); err != nil { + return "", err + } + + upReq, err := http.NewRequestWithContext(ctx, http.MethodPost, urlResp.Data, &body) + if err != nil { + return "", err + } + upReq.Header = c.headers() + upReq.Header.Set("Content-Type", w.FormDataContentType()) + + upRes, err := c.httpClient.Do(upReq) + if err != nil { + return "", err + } + defer upRes.Body.Close() + + if upRes.StatusCode >= 400 { + b, _ := io.ReadAll(io.LimitReader(upRes.Body, 4096)) + return "", fmt.Errorf("virustotal large upload failed: %d %s", upRes.StatusCode, string(b)) + } + + var out uploadResponse + if err := json.NewDecoder(upRes.Body).Decode(&out); err != nil { + return "", err + } + return out.Data.ID, nil +} + +func (c *Client) pollAnalysis(ctx context.Context, analysisID string, timeout time.Duration) (analysisStats, error) { + deadline := time.Now().Add(timeout) + for { + if ctx.Err() != nil { + return analysisStats{}, ctx.Err() + } + if time.Now().After(deadline) { + return analysisStats{}, fmt.Errorf("virustotal analysis timeout") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/analyses/"+analysisID), nil) + if err != nil { + return analysisStats{}, err + } + req.Header = c.headers() + + res, err := c.httpClient.Do(req) + if err != nil { + return analysisStats{}, err + } + + var out analysisResponse + decodeErr := json.NewDecoder(res.Body).Decode(&out) + res.Body.Close() + if decodeErr != nil { + return analysisStats{}, decodeErr + } + + if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests { + return analysisStats{}, fmt.Errorf("virustotal poll unavailable: %d", res.StatusCode) + } + + status := out.Data.Attributes.Status + if status == "completed" { + return out.Data.Attributes.Stats, nil + } + + select { + case <-ctx.Done(): + return analysisStats{}, ctx.Err() + case <-time.After(2 * time.Second): + } + } +} + +func statsMalicious(stats analysisStats) bool { + return stats.Malicious > 0 +} diff --git a/internal/virustotal/errors.go b/internal/virustotal/errors.go new file mode 100644 index 0000000..0f6d7e1 --- /dev/null +++ b/internal/virustotal/errors.go @@ -0,0 +1,7 @@ +package virustotal + +import "errors" + +var ( + ErrMalicious = errors.New("malware detected") +) diff --git a/internal/virustotal/scan.go b/internal/virustotal/scan.go new file mode 100644 index 0000000..f66623d --- /dev/null +++ b/internal/virustotal/scan.go @@ -0,0 +1,106 @@ +package virustotal + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "log/slog" + "time" +) + +const defaultScanTimeout = 60 * time.Second + +// ScanResult is the outcome of a VirusTotal scan attempt. +type ScanResult struct { + Status string // clean | skipped +} + +// Scanner wraps Client with fail-open behavior on API errors. +type Scanner struct { + client *Client + logger *slog.Logger + timeout time.Duration +} + +func NewScanner(apiKey string, logger *slog.Logger) *Scanner { + if logger == nil { + logger = slog.Default() + } + return &Scanner{ + client: NewClient(apiKey), + logger: logger, + timeout: defaultScanTimeout, + } +} + +// ScanBytes scans file content. Returns ErrMalicious if detected. +// On VT unavailability, returns skipped (fail-open). +func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte, sha256Hex string) (ScanResult, error) { + if len(data) == 0 { + return ScanResult{Status: "skipped"}, nil + } + if len(data) > maxUploadLimit { + s.logger.Warn("virustotal scan skipped: file too large for VT", "filename", filename, "size", len(data)) + return ScanResult{Status: "skipped"}, nil + } + + if sha256Hex == "" { + sum := sha256.Sum256(data) + sha256Hex = hex.EncodeToString(sum[:]) + } + + stats, found, err := s.client.lookupFile(ctx, sha256Hex) + if err != nil { + s.logger.Warn("virustotal lookup failed, skipping scan", "filename", filename, "error", err) + return ScanResult{Status: "skipped"}, nil + } + if found { + if statsMalicious(stats) { + return ScanResult{}, ErrMalicious + } + return ScanResult{Status: "clean"}, nil + } + + analysisID, err := s.client.uploadFile(ctx, data, safeFilename(filename)) + if err != nil { + s.logger.Warn("virustotal upload failed, skipping scan", "filename", filename, "error", err) + return ScanResult{Status: "skipped"}, nil + } + + scanCtx, cancel := context.WithTimeout(ctx, s.timeout) + defer cancel() + + stats, err = s.client.pollAnalysis(scanCtx, analysisID, s.timeout) + if err != nil { + s.logger.Warn("virustotal analysis failed, skipping scan", "filename", filename, "error", err) + return ScanResult{Status: "skipped"}, nil + } + + if statsMalicious(stats) { + return ScanResult{}, ErrMalicious + } + return ScanResult{Status: "clean"}, nil +} + +func safeFilename(name string) string { + if name == "" { + return "upload" + } + base := name + if len(base) > 200 { + base = base[:200] + } + return base +} + +// SHA256Hex returns hex-encoded SHA-256 of data. +func SHA256Hex(data []byte) string { + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + +// FormatAnalysisID is a helper for logging. +func FormatAnalysisID(id string) string { + return fmt.Sprintf("vt:%s", id) +} diff --git a/internal/virustotal/scan_test.go b/internal/virustotal/scan_test.go new file mode 100644 index 0000000..c4f68d4 --- /dev/null +++ b/internal/virustotal/scan_test.go @@ -0,0 +1,109 @@ +package virustotal + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestScannerLookupMalicious(t *testing.T) { + sha := sha256.Sum256([]byte("evil")) + shaHex := hex.EncodeToString(sha[:]) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/files/"+shaHex) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "attributes": map[string]any{ + "last_analysis_stats": map[string]any{ + "malicious": 2, + "suspicious": 0, + }, + }, + }, + }) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + + sc := NewScanner("test-key", nil) + sc.client.baseURL = srv.URL + "/api/v3" + + _, err := sc.ScanBytes(context.Background(), "evil.bin", []byte("evil"), shaHex) + if err == nil { + t.Fatal("expected ErrMalicious") + } + if err != ErrMalicious { + t.Fatalf("err = %v, want ErrMalicious", err) + } +} + +func TestScannerUploadAndPollClean(t *testing.T) { + pollCount := 0 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/"+SHA256Hex([]byte("clean"))): + http.NotFound(w, r) + case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/files"): + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{"id": "analysis-1"}, + }) + case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/analyses/analysis-1"): + pollCount++ + status := "queued" + if pollCount >= 2 { + status = "completed" + } + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "attributes": map[string]any{ + "status": status, + "stats": map[string]any{ + "malicious": 0, + "suspicious": 0, + }, + }, + }, + }) + default: + http.NotFound(w, r) + } + })) + defer srv.Close() + + sc := NewScanner("test-key", nil) + sc.client.baseURL = srv.URL + "/api/v3" + + result, err := sc.ScanBytes(context.Background(), "clean.txt", []byte("clean"), "") + if err != nil { + t.Fatalf("ScanBytes: %v", err) + } + if result.Status != "clean" { + t.Fatalf("status = %q, want clean", result.Status) + } +} + +func TestScannerFailOpenOnUnavailable(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "down", http.StatusServiceUnavailable) + })) + defer srv.Close() + + sc := NewScanner("test-key", nil) + sc.client.baseURL = srv.URL + "/api/v3" + + result, err := sc.ScanBytes(context.Background(), "file.bin", []byte("payload"), "") + if err != nil { + t.Fatalf("expected fail-open, got err %v", err) + } + if result.Status != "skipped" { + t.Fatalf("status = %q, want skipped", result.Status) + } +} diff --git a/migrations/000034_attachment_virus_scan.down.sql b/migrations/000034_attachment_virus_scan.down.sql new file mode 100644 index 0000000..7b321e6 --- /dev/null +++ b/migrations/000034_attachment_virus_scan.down.sql @@ -0,0 +1 @@ +ALTER TABLE attachments DROP COLUMN IF EXISTS virus_scan_status; diff --git a/migrations/000034_attachment_virus_scan.up.sql b/migrations/000034_attachment_virus_scan.up.sql new file mode 100644 index 0000000..f3c8a99 --- /dev/null +++ b/migrations/000034_attachment_virus_scan.up.sql @@ -0,0 +1,3 @@ +ALTER TABLE attachments + ADD COLUMN virus_scan_status TEXT NOT NULL DEFAULT 'skipped' + CHECK (virus_scan_status IN ('clean', 'skipped', 'malicious'));