feat(scan): add VirusTotal upload antivirus
Some checks failed
CI / Go tests (push) Has been cancelled
CI / Integration tests (push) Has been cancelled
CI / DB migrations (push) Has been cancelled

Admin-stored API key with env fallback; scan drive/mail/IMAP uploads.
Fail-open if VT down, 422 on malware; migration for virus_scan_status.
This commit is contained in:
R3D347HR4Y 2026-06-07 22:05:27 +02:00
parent f67c109f2f
commit b90edf317c
24 changed files with 1033 additions and 16 deletions

View File

@ -255,3 +255,8 @@ SEARCH_ENGINE=postgres
# TYPESENSE_URL=http://typesense:8108 # TYPESENSE_URL=http://typesense:8108
# TYPESENSE_API_KEY={{TYPESENSE_API_KEY}} # TYPESENSE_API_KEY={{TYPESENSE_API_KEY}}
# TYPESENSE_COLLECTION=ulti # TYPESENSE_COLLECTION=ulti
# -----------------------------------------------------------------------------
# VirusTotal (optional env fallback; prefer admin Settings > File policies)
# -----------------------------------------------------------------------------
# VIRUSTOTAL_API_KEY=

View File

@ -51,6 +51,7 @@ func defaultOrgPolicy() map[string]any {
"external_sharing": "authenticated", "external_sharing": "authenticated",
"default_link_expiry_days": 30, "default_link_expiry_days": 30,
"virus_scan_enabled": false, "virus_scan_enabled": false,
"virustotal_api_key": "",
"retention_trash_days": 30, "retention_trash_days": 30,
}, },
"llm": map[string]any{ "llm": map[string]any{
@ -174,6 +175,7 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any {
{"onlyoffice", "jwt_secret"}, {"onlyoffice", "jwt_secret"},
{"search", "meilisearch_api_key"}, {"search", "meilisearch_api_key"},
{"search", "typesense_api_key"}, {"search", "typesense_api_key"},
{"file_policies", "virustotal_api_key"},
} }
for _, p := range secretPaths { for _, p := range secretPaths {
patchSection, _ := patch[p.section].(map[string]any) patchSection, _ := patch[p.section].(map[string]any)
@ -296,6 +298,7 @@ func maskOrgPolicy(policy map[string]any) map[string]any {
maskStringField(cloned, "onlyoffice", "jwt_secret") maskStringField(cloned, "onlyoffice", "jwt_secret")
maskStringField(cloned, "search", "meilisearch_api_key") maskStringField(cloned, "search", "meilisearch_api_key")
maskStringField(cloned, "search", "typesense_api_key") maskStringField(cloned, "search", "typesense_api_key")
maskStringField(cloned, "file_policies", "virustotal_api_key")
if llm, ok := cloned["llm"].(map[string]any); ok { if llm, ok := cloned["llm"].(map[string]any); ok {
if providers, ok := llm["providers"].([]any); ok { if providers, ok := llm["providers"].([]any); ok {
for i, p := range providers { for i, p := range providers {
@ -374,6 +377,9 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string
"typesense_api_key": map[string]any{ "typesense_api_key": map[string]any{
"configured": secretConfigured(policy, "search", "typesense_api_key") || strings.TrimSpace(cfg.TypesenseKey) != "", "configured": secretConfigured(policy, "search", "typesense_api_key") || strings.TrimSpace(cfg.TypesenseKey) != "",
}, },
"virustotal_api_key": map[string]any{
"configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "",
},
} }
} }

View File

@ -5,6 +5,7 @@ import (
"path" "path"
"github.com/ultisuite/ulti-backend/internal/automation" "github.com/ultisuite/ulti-backend/internal/automation"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/mail/rules"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
) )
@ -13,6 +14,10 @@ func (s *Service) SetAutomation(d driveAutomation) {
s.automation = d s.automation = d
} }
func (s *Service) SetFileScanner(scanner *filescan.Scanner) {
s.scanner = scanner
}
func (s *Service) afterDriveFileEvent(ctx context.Context, externalUserID string, trigger rules.TriggerType, filePath string, isFolder bool) { func (s *Service) afterDriveFileEvent(ctx context.Context, externalUserID string, trigger rules.TriggerType, filePath string, isFolder bool) {
normalized := nextcloud.NormalizeClientPath(filePath) normalized := nextcloud.NormalizeClientPath(filePath)
s.notifyFileChanged(externalUserID, normalized) s.notifyFileChanged(externalUserID, normalized)

View File

@ -733,6 +733,8 @@ func writeDriveError(w http.ResponseWriter, r *http.Request, err error) {
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "forbidden", nil) apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "forbidden", nil)
case errors.Is(err, ErrQuotaExceeded): case errors.Is(err, ErrQuotaExceeded):
apiresponse.WriteError(w, r, http.StatusInsufficientStorage, "drive.quota_exceeded", "quota exceeded", nil) apiresponse.WriteError(w, r, http.StatusInsufficientStorage, "drive.quota_exceeded", "quota exceeded", nil)
case errors.Is(err, ErrMalware):
apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "drive.malware_detected", "malware detected in file", nil)
case errors.Is(err, ErrInvalid): case errors.Is(err, ErrInvalid):
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid request body", nil) apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid request body", nil)
default: default:

View File

@ -1,11 +1,14 @@
package drive package drive
import ( import (
"bytes"
"context" "context"
"errors"
"io" "io"
"path" "path"
"strings" "strings"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
) )
@ -17,7 +20,18 @@ func (s *Service) UploadPublicShare(ctx context.Context, token, filePath, passwo
if !nextcloud.PublicShareCanCreate(perms) && !nextcloud.PublicShareCanUpdate(perms) { if !nextcloud.PublicShareCanCreate(perms) && !nextcloud.PublicShareCanUpdate(perms) {
return ErrForbidden return ErrForbidden
} }
if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, body, contentType)); err != nil { reader := body
if s.scanner != nil {
data, _, err := s.scanner.ScanReader(ctx, filePath, body, -1)
if err != nil {
if errors.Is(err, filescan.ErrMalicious) {
return ErrMalware
}
return err
}
reader = bytes.NewReader(data)
}
if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, reader, contentType)); err != nil {
return err return err
} }
s.recordPublicShareAccess(ctx, token) s.recordPublicShareAccess(ctx, token)

View File

@ -18,6 +18,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/automation" "github.com/ultisuite/ulti-backend/internal/automation"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/mail/rules"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
"github.com/ultisuite/ulti-backend/internal/publicshare" "github.com/ultisuite/ulti-backend/internal/publicshare"
@ -30,6 +31,7 @@ var (
ErrForbidden = errors.New("forbidden") ErrForbidden = errors.New("forbidden")
ErrQuotaExceeded = errors.New("quota exceeded") ErrQuotaExceeded = errors.New("quota exceeded")
ErrInvalid = errors.New("invalid request") ErrInvalid = errors.New("invalid request")
ErrMalware = errors.New("malware detected")
) )
type Service struct { type Service struct {
@ -37,6 +39,7 @@ type Service struct {
hub *realtime.Hub hub *realtime.Hub
db *pgxpool.Pool db *pgxpool.Pool
automation driveAutomation automation driveAutomation
scanner *filescan.Scanner
maxUploadBytes int64 maxUploadBytes int64
quotaReserveByte int64 quotaReserveByte int64
} }
@ -184,7 +187,18 @@ func (s *Service) Upload(ctx context.Context, userID, path string, body io.Reade
if err := s.ensureQuota(ctx, userID, contentLength); err != nil { if err := s.ensureQuota(ctx, userID, contentLength); err != nil {
return err return err
} }
return mapDriveError(s.nc.Upload(ctx, userID, path, body, contentType)) reader := body
if s.scanner != nil {
data, _, err := s.scanner.ScanReader(ctx, path, body, contentLength)
if err != nil {
if errors.Is(err, filescan.ErrMalicious) {
return ErrMalware
}
return err
}
reader = bytes.NewReader(data)
}
return mapDriveError(s.nc.Upload(ctx, userID, path, reader, contentType))
} }
func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath string, chunk ChunkUpload, body io.Reader, contentType string) error { func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath string, chunk ChunkUpload, body io.Reader, contentType string) error {
@ -202,6 +216,31 @@ func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath
if err := mapDriveError(s.nc.AssembleChunks(ctx, userID, uploadID, targetPath, chunk.TotalSize)); err != nil { if err := mapDriveError(s.nc.AssembleChunks(ctx, userID, uploadID, targetPath, chunk.TotalSize)); err != nil {
return err return err
} }
if s.scanner != nil {
if err := s.scanAssembledUpload(ctx, userID, targetPath); err != nil {
return err
}
}
return nil
}
func (s *Service) scanAssembledUpload(ctx context.Context, userID, targetPath string) error {
body, _, err := s.nc.Download(ctx, userID, targetPath)
if err != nil {
return mapDriveError(err)
}
defer body.Close()
data, scanResult, err := s.scanner.ScanReader(ctx, targetPath, body, -1)
if err != nil {
if errors.Is(err, filescan.ErrMalicious) {
_ = s.nc.Delete(ctx, userID, targetPath)
return ErrMalware
}
return err
}
_ = data
_ = scanResult
return nil return nil
} }

View File

@ -1,6 +1,7 @@
package mail package mail
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"errors" "errors"
@ -39,7 +40,7 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag
} }
rows, err := s.db.Query(ctx, ` rows, err := s.db.Query(ctx, `
SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, '') SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, ''), virus_scan_status
FROM attachments WHERE message_id = $1 FROM attachments WHERE message_id = $1
ORDER BY created_at ASC ORDER BY created_at ASC
`, messageID) `, messageID)
@ -50,15 +51,16 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag
out := make([]map[string]any, 0) out := make([]map[string]any, 0)
for rows.Next() { for rows.Next() {
var id, filename, contentType, contentID, drivePath string var id, filename, contentType, contentID, drivePath, virusScanStatus string
var size int64 var size int64
var isInline bool var isInline bool
if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath); err != nil { if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath, &virusScanStatus); err != nil {
return nil, err return nil, err
} }
entry := map[string]any{ entry := map[string]any{
"id": id, "filename": filename, "content_type": contentType, "id": id, "filename": filename, "content_type": contentType,
"size": size, "is_inline": isInline, "size": size, "is_inline": isInline,
"virus_scan_status": virusScanStatus,
} }
if contentID != "" { if contentID != "" {
entry["content_id"] = contentID entry["content_id"] = contentID
@ -142,16 +144,28 @@ func (s *Service) UploadMessageAttachment(
} }
objectKey := storage.MessageObjectKey(userID, messageID, filename) objectKey := storage.MessageObjectKey(userID, messageID, filename)
if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil { scanStatus := "skipped"
putReader := reader
putSize := size
if s.scanner != nil {
data, result, err := s.scanner.ScanReader(ctx, filename, reader, size)
if err != nil {
return "", err
}
scanStatus = result.Status
putReader = bytes.NewReader(data)
putSize = int64(len(data))
}
if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil {
return "", err return "", err
} }
var id string var id string
err = s.db.QueryRow(ctx, ` err = s.db.QueryRow(ctx, `
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline) INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
RETURNING id RETURNING id
`, messageID, filename, contentType, size, s.storageBucket(), objectKey, contentID, isInline).Scan(&id) `, messageID, filename, contentType, putSize, s.storageBucket(), objectKey, contentID, isInline, scanStatus).Scan(&id)
if err != nil { if err != nil {
_ = s.storage.Delete(ctx, objectKey) _ = s.storage.Delete(ctx, objectKey)
return "", err return "", err
@ -224,7 +238,17 @@ func (s *Service) UploadDraftAttachment(
} }
objectKey := storage.DraftObjectKey(userID, draftID, filename) objectKey := storage.DraftObjectKey(userID, draftID, filename)
if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil { putReader := reader
putSize := size
if s.scanner != nil {
data, _, err := s.scanner.ScanReader(ctx, filename, reader, size)
if err != nil {
return "", err
}
putReader = bytes.NewReader(data)
putSize = int64(len(data))
}
if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil {
return "", err return "", err
} }
@ -237,13 +261,13 @@ func (s *Service) UploadDraftAttachment(
for _, ref := range refs { for _, ref := range refs {
totalSize += ref.Size totalSize += ref.Size
} }
if err := limits.ValidateAttachmentQuota(len(refs), totalSize, size); err != nil { if err := limits.ValidateAttachmentQuota(len(refs), totalSize, putSize); err != nil {
return "", err return "", err
} }
attID := uuid.NewString() attID := uuid.NewString()
refs = append(refs, draftAttachmentRef{ refs = append(refs, draftAttachmentRef{
ID: attID, Filename: filename, ContentType: contentType, Size: size, ID: attID, Filename: filename, ContentType: contentType, Size: putSize,
S3Bucket: s.storageBucket(), S3Key: objectKey, S3Bucket: s.storageBucket(), S3Key: objectKey,
ContentID: contentID, IsInline: isInline, ContentID: contentID, IsInline: isInline,
}) })

View File

@ -6,6 +6,7 @@ import (
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
) )
@ -15,6 +16,10 @@ func (s *Service) SetDriveUploader(uploader DriveUploader) {
s.driveUploader = uploader s.driveUploader = uploader
} }
func (s *Service) SetFileScanner(scanner *filescan.Scanner) {
s.scanner = scanner
}
func (s *Service) SaveAttachmentToDrive( func (s *Service) SaveAttachmentToDrive(
ctx context.Context, ctx context.Context,
externalID, email, sub, displayName, messageID, attachmentID, folderPath string, externalID, email, sub, displayName, messageID, attachmentID, folderPath string,

View File

@ -13,6 +13,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/mail/sendguard" "github.com/ultisuite/ulti-backend/internal/api/mail/sendguard"
"github.com/ultisuite/ulti-backend/internal/api/middleware" "github.com/ultisuite/ulti-backend/internal/api/middleware"
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/limits" "github.com/ultisuite/ulti-backend/internal/mail/limits"
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth" mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
@ -42,6 +43,13 @@ func (h *Handler) SetDriveUploader(uploader DriveUploader) {
} }
} }
// SetFileScanner wires VirusTotal scanning for mail attachment uploads.
func (h *Handler) SetFileScanner(scanner *filescan.Scanner) {
if s, ok := h.svc.(*Service); ok {
s.SetFileScanner(scanner)
}
}
func NewHandlerWithService(svc ServiceAPI) *Handler { func NewHandlerWithService(svc ServiceAPI) *Handler {
return &Handler{ return &Handler{
svc: svc, svc: svc,

View File

@ -15,6 +15,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/api/apivalidate" "github.com/ultisuite/ulti-backend/internal/api/apivalidate"
driveapi "github.com/ultisuite/ulti-backend/internal/api/drive" driveapi "github.com/ultisuite/ulti-backend/internal/api/drive"
"github.com/ultisuite/ulti-backend/internal/api/middleware" "github.com/ultisuite/ulti-backend/internal/api/middleware"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/mail/limits" "github.com/ultisuite/ulti-backend/internal/mail/limits"
) )
@ -342,6 +343,9 @@ func writeAttachmentUploadError(w http.ResponseWriter, r *http.Request, err erro
case errors.Is(err, limits.ErrTooManyAttachments): case errors.Is(err, limits.ErrTooManyAttachments):
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "too many attachments", nil) apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "too many attachments", nil)
return true return true
case errors.Is(err, ErrMalware), errors.Is(err, filescan.ErrMalicious):
apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "mail.malware_detected", "malware detected in attachment", nil)
return true
default: default:
return false return false
} }

View File

@ -11,6 +11,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/api/query" "github.com/ultisuite/ulti-backend/internal/api/query"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/imap" "github.com/ultisuite/ulti-backend/internal/mail/imap"
"github.com/ultisuite/ulti-backend/internal/mail/sanitize" "github.com/ultisuite/ulti-backend/internal/mail/sanitize"
@ -28,6 +29,7 @@ var (
ErrInvalidAccountCredentials = errors.New("account credentials invalid") ErrInvalidAccountCredentials = errors.New("account credentials invalid")
ErrInvalidFolderScope = errors.New("invalid folder scope") ErrInvalidFolderScope = errors.New("invalid folder scope")
ErrFolderHasChildren = errors.New("folder has children") ErrFolderHasChildren = errors.New("folder has children")
ErrMalware = filescan.ErrMalicious
) )
type Service struct { type Service struct {
@ -37,6 +39,7 @@ type Service struct {
storage *storage.Client storage *storage.Client
attachmentsBucket string attachmentsBucket string
driveUploader DriveUploader driveUploader DriveUploader
scanner *filescan.Scanner
logger *slog.Logger logger *slog.Logger
} }

View File

@ -116,6 +116,9 @@ type Config struct {
TypesenseKey string TypesenseKey string
TypesenseCollection string TypesenseCollection string
// VirusTotal (optional env fallback for org file_policies.virustotal_api_key)
VirusTotalAPIKey string
// Observability // Observability
HealthNextcloudURL string HealthNextcloudURL string
HealthImmichURL string HealthImmichURL string
@ -221,6 +224,8 @@ func Load() (*Config, error) {
TypesenseKey: secrets.Env("TYPESENSE_API_KEY"), TypesenseKey: secrets.Env("TYPESENSE_API_KEY"),
TypesenseCollection: envOrDefault("TYPESENSE_COLLECTION", "ulti"), TypesenseCollection: envOrDefault("TYPESENSE_COLLECTION", "ulti"),
VirusTotalAPIKey: secrets.Env("VIRUSTOTAL_API_KEY"),
HealthNextcloudURL: envOrDefault("HEALTH_NEXTCLOUD_URL", joinURL(envOrDefault("NEXTCLOUD_URL", "http://nextcloud:80"), "/status.php")), HealthNextcloudURL: envOrDefault("HEALTH_NEXTCLOUD_URL", joinURL(envOrDefault("NEXTCLOUD_URL", "http://nextcloud:80"), "/status.php")),
HealthImmichURL: envOrDefault("HEALTH_IMMICH_URL", joinURL(envOrDefault("IMMICH_API_URL", "http://immich-server:2283/api"), "/server-info/ping")), HealthImmichURL: envOrDefault("HEALTH_IMMICH_URL", joinURL(envOrDefault("IMMICH_API_URL", "http://immich-server:2283/api"), "/server-info/ping")),
HealthJitsiURL: envOrDefault("HEALTH_JITSI_URL", defaultHealthJitsiURL(envOrDefault("JITSI_PUBLIC_URL", "https://localhost/meet"))), HealthJitsiURL: envOrDefault("HEALTH_JITSI_URL", defaultHealthJitsiURL(envOrDefault("JITSI_PUBLIC_URL", "https://localhost/meet"))),

View File

@ -0,0 +1,144 @@
package filescan
import (
"bytes"
"context"
"errors"
"io"
"log/slog"
"os"
"path/filepath"
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
"github.com/ultisuite/ulti-backend/internal/virustotal"
)
const memoryBufferLimit = 32 * 1024 * 1024
// Result holds scan outcome for persistence.
type Result struct {
Status string // clean | skipped
}
// PolicyLoader supplies org file policies at runtime.
type PolicyLoader interface {
FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error)
ScanEnabled(ctx context.Context) (bool, string, error)
}
// Scanner coordinates org policy and VirusTotal scanning.
type Scanner struct {
policies PolicyLoader
logger *slog.Logger
}
func NewScanner(policies PolicyLoader, logger *slog.Logger) *Scanner {
if logger == nil {
logger = slog.Default()
}
return &Scanner{policies: policies, logger: logger}
}
// ScanReader reads all bytes from r (up to maxBytes), optionally scans, returns data for storage.
func (s *Scanner) ScanReader(ctx context.Context, filename string, r io.Reader, size int64) ([]byte, Result, error) {
fp, err := s.policies.FilePolicies(ctx)
if err != nil {
return nil, Result{}, err
}
maxBytes := fp.MaxUploadBytes
if maxBytes <= 0 {
maxBytes = 512 * 1024 * 1024
}
if maxBytes > virustotalMaxBytes() {
maxBytes = virustotalMaxBytes()
}
data, err := readLimited(r, size, maxBytes)
if err != nil {
return nil, Result{}, err
}
enabled := fp.VirusScanEnabled && fp.VirusTotalAPIKey != ""
if !enabled {
return data, Result{Status: "skipped"}, nil
}
vt := virustotal.NewScanner(fp.VirusTotalAPIKey, s.logger)
scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data))
if err != nil {
if errors.Is(err, virustotal.ErrMalicious) {
return nil, Result{Status: "malicious"}, virustotal.ErrMalicious
}
return nil, Result{}, err
}
return data, Result{Status: scanResult.Status}, nil
}
// ScanBytes scans pre-loaded bytes when policy is already known.
func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte) (Result, error) {
enabled, apiKey, err := s.policies.ScanEnabled(ctx)
if err != nil {
return Result{}, err
}
if !enabled {
return Result{Status: "skipped"}, nil
}
vt := virustotal.NewScanner(apiKey, s.logger)
scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data))
if err != nil {
if errors.Is(err, virustotal.ErrMalicious) {
return Result{Status: "malicious"}, virustotal.ErrMalicious
}
return Result{}, err
}
return Result{Status: scanResult.Status}, nil
}
func readLimited(r io.Reader, size int64, maxBytes int64) ([]byte, error) {
if size >= 0 && size <= memoryBufferLimit {
limited := io.LimitReader(r, maxBytes+1)
data, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if int64(len(data)) > maxBytes {
return nil, errors.New("file exceeds max upload size")
}
return data, nil
}
tmp, err := os.CreateTemp("", "ulti-scan-*")
if err != nil {
return nil, err
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
written, err := io.Copy(tmp, io.LimitReader(r, maxBytes+1))
if err != nil {
tmp.Close()
return nil, err
}
if err := tmp.Close(); err != nil {
return nil, err
}
if written > maxBytes {
return nil, errors.New("file exceeds max upload size")
}
return os.ReadFile(filepath.Clean(tmpPath))
}
func virustotalMaxBytes() int64 {
return 650 * 1024 * 1024
}
// ErrMalicious re-exports virustotal.ErrMalicious for handlers.
var ErrMalicious = virustotal.ErrMalicious
// NopReader helper for tests.
func NopReader(data []byte) io.Reader {
return bytes.NewReader(data)
}

View File

@ -0,0 +1,40 @@
package filescan
import (
"bytes"
"context"
"testing"
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
)
type stubPolicyLoader struct {
fp orgpolicy.FilePolicies
}
func (s stubPolicyLoader) FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error) {
return s.fp, nil
}
func (s stubPolicyLoader) ScanEnabled(ctx context.Context) (bool, string, error) {
if s.fp.VirusScanEnabled && s.fp.VirusTotalAPIKey != "" {
return true, s.fp.VirusTotalAPIKey, nil
}
return false, "", nil
}
func TestScannerDisabledReturnsSkipped(t *testing.T) {
scanner := &Scanner{
policies: stubPolicyLoader{fp: orgpolicy.FilePolicies{VirusScanEnabled: false}},
}
data, result, err := scanner.ScanReader(context.Background(), "test.txt", bytes.NewReader([]byte("hello")), 5)
if err != nil {
t.Fatalf("ScanReader: %v", err)
}
if result.Status != "skipped" {
t.Fatalf("status = %q, want skipped", result.Status)
}
if string(data) != "hello" {
t.Fatalf("data = %q", data)
}
}

View File

@ -57,3 +57,60 @@ func TestAdminOrgSettings(t *testing.T) {
t.Fatalf("default_mail_gib = %v, want 10", updatedStorage["default_mail_gib"]) t.Fatalf("default_mail_gib = %v, want 10", updatedStorage["default_mail_gib"])
} }
} }
func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) {
h := integrationtest.RequireHarness(t)
adminClient, _ := integrationtest.RequireAdminClient(t, h)
putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
"policy": map[string]any{
"file_policies": map[string]any{
"virus_scan_enabled": true,
"virustotal_api_key": "vt-test-secret-key",
},
},
})
integrationtest.FailIf(err, t, "put org settings with virustotal key")
integrationtest.FailUnlessStatus(t, putResp, 200)
var afterPut map[string]any
integrationtest.DecodeJSON(t, putResp, &afterPut)
secrets, ok := afterPut["secrets"].(map[string]any)
if !ok {
t.Fatalf("missing secrets: %#v", afterPut)
}
vtSecret, ok := secrets["virustotal_api_key"].(map[string]any)
if !ok || vtSecret["configured"] != true {
t.Fatalf("virustotal_api_key not configured: %#v", secrets)
}
filePolicies, ok := afterPut["policy"].(map[string]any)["file_policies"].(map[string]any)
if !ok {
t.Fatalf("missing file_policies")
}
if filePolicies["virustotal_api_key"] != "" {
t.Fatalf("virustotal_api_key should be masked on GET, got %q", filePolicies["virustotal_api_key"])
}
// Empty patch must preserve stored secret.
preserveResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
"policy": map[string]any{
"file_policies": map[string]any{
"virus_scan_enabled": false,
"virustotal_api_key": "",
},
},
})
integrationtest.FailIf(err, t, "preserve virustotal key")
integrationtest.FailUnlessStatus(t, preserveResp, 200)
var preserved map[string]any
integrationtest.DecodeJSON(t, preserveResp, &preserved)
preservedSecrets, ok := preserved["secrets"].(map[string]any)
if !ok {
t.Fatalf("missing secrets after preserve")
}
vtPreserved, ok := preservedSecrets["virustotal_api_key"].(map[string]any)
if !ok || vtPreserved["configured"] != true {
t.Fatalf("virustotal key not preserved: %#v", preservedSecrets)
}
}

View File

@ -16,6 +16,7 @@ import (
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/mail/credentials" "github.com/ultisuite/ulti-backend/internal/mail/credentials"
"github.com/ultisuite/ulti-backend/internal/mail/connect" "github.com/ultisuite/ulti-backend/internal/mail/connect"
"github.com/ultisuite/ulti-backend/internal/filescan"
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth" mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
"github.com/ultisuite/ulti-backend/internal/mail/limits" "github.com/ultisuite/ulti-backend/internal/mail/limits"
"github.com/ultisuite/ulti-backend/internal/mail/rules" "github.com/ultisuite/ulti-backend/internal/mail/rules"
@ -38,6 +39,7 @@ type SyncDeps struct {
Rules *rules.Engine Rules *rules.Engine
Automation MailAutomation Automation MailAutomation
Hub *realtime.Hub Hub *realtime.Hub
FileScanner *filescan.Scanner
} }
type SyncWorker struct { type SyncWorker struct {
@ -48,6 +50,7 @@ type SyncWorker struct {
oauth *mailoauth.Service oauth *mailoauth.Service
storage *storage.Client storage *storage.Client
attachBucket string attachBucket string
scanner *filescan.Scanner
pipeline *syncPipeline pipeline *syncPipeline
} }
@ -60,6 +63,7 @@ func NewSyncWorker(db *pgxpool.Pool, interval time.Duration, credManager *creden
oauth: oauthSvc, oauth: oauthSvc,
storage: deps.Storage, storage: deps.Storage,
attachBucket: deps.AttachBucket, attachBucket: deps.AttachBucket,
scanner: deps.FileScanner,
pipeline: newSyncPipeline(db, deps.Rules, deps.Automation, deps.Hub), pipeline: newSyncPipeline(db, deps.Rules, deps.Automation, deps.Hub),
} }
} }
@ -575,14 +579,30 @@ func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID str
if messageExisted && attachmentPartExists(ctx, w.db, messageID, part) { if messageExisted && attachmentPartExists(ctx, w.db, messageID, part) {
continue continue
} }
scanStatus := "skipped"
partData := part.Data
if w.scanner != nil {
result, err := w.scanner.ScanBytes(ctx, part.Filename, part.Data)
if err != nil {
if errors.Is(err, filescan.ErrMalicious) {
w.logger.Warn("imap attachment skipped: malware detected",
"message_id", messageID, "filename", part.Filename)
continue
}
return err
}
scanStatus = result.Status
}
objectKey := storage.MessageObjectKey(userID, messageID, part.Filename) objectKey := storage.MessageObjectKey(userID, messageID, part.Filename)
if err := w.storage.Put(ctx, objectKey, bytes.NewReader(part.Data), int64(len(part.Data)), part.ContentType); err != nil { if err := w.storage.Put(ctx, objectKey, bytes.NewReader(partData), int64(len(partData)), part.ContentType); err != nil {
return err return err
} }
_, err := w.db.Exec(ctx, ` _, err := w.db.Exec(ctx, `
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline) INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
`, messageID, part.Filename, part.ContentType, len(part.Data), bucket, objectKey, part.ContentID, part.IsInline) `, messageID, part.Filename, part.ContentType, len(partData), bucket, objectKey, part.ContentID, part.IsInline, scanStatus)
if err != nil { if err != nil {
_ = w.storage.Delete(ctx, objectKey) _ = w.storage.Delete(ctx, objectKey)
return err return err

View File

@ -0,0 +1,135 @@
package orgpolicy
import (
"context"
"encoding/json"
"strings"
"sync"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/ultisuite/ulti-backend/internal/config"
)
const orgSettingsSingletonID = 1
const defaultMaxUploadMiB = 512
// FilePolicies holds runtime file upload policy from org settings.
type FilePolicies struct {
VirusScanEnabled bool
VirusTotalAPIKey string
MaxUploadBytes int64
}
type Loader struct {
db *pgxpool.Pool
cfg *config.Config
mu sync.Mutex
cached FilePolicies
cachedAt time.Time
ttl time.Duration
}
func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader {
return &Loader{
db: db,
cfg: cfg,
ttl: 60 * time.Second,
}
}
func (l *Loader) FilePolicies(ctx context.Context) (FilePolicies, error) {
l.mu.Lock()
if !l.cachedAt.IsZero() && time.Since(l.cachedAt) < l.ttl {
out := l.cached
l.mu.Unlock()
return out, nil
}
l.mu.Unlock()
fp, err := l.loadFilePolicies(ctx)
if err != nil {
return FilePolicies{}, err
}
l.mu.Lock()
l.cached = fp
l.cachedAt = time.Now()
l.mu.Unlock()
return fp, nil
}
func (l *Loader) loadFilePolicies(ctx context.Context) (FilePolicies, error) {
var raw []byte
err := l.db.QueryRow(ctx, `
SELECT settings FROM org_settings WHERE id = $1
`, orgSettingsSingletonID).Scan(&raw)
if err != nil && err != pgx.ErrNoRows {
return FilePolicies{}, err
}
stored := map[string]any{}
if len(raw) > 0 {
if err := json.Unmarshal(raw, &stored); err != nil {
return FilePolicies{}, err
}
}
filePolicies, _ := stored["file_policies"].(map[string]any)
enabled := boolValue(filePolicies["virus_scan_enabled"])
apiKey := stringValue(filePolicies["virustotal_api_key"])
if strings.TrimSpace(apiKey) == "" && l.cfg != nil {
apiKey = strings.TrimSpace(l.cfg.VirusTotalAPIKey)
}
maxMiB := int64(defaultMaxUploadMiB)
switch v := filePolicies["max_upload_mib"].(type) {
case float64:
if v > 0 {
maxMiB = int64(v)
}
case int:
if v > 0 {
maxMiB = int64(v)
}
case int64:
if v > 0 {
maxMiB = v
}
}
return FilePolicies{
VirusScanEnabled: enabled,
VirusTotalAPIKey: apiKey,
MaxUploadBytes: maxMiB * 1024 * 1024,
}, nil
}
func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) {
fp, err := l.FilePolicies(ctx)
if err != nil {
return false, "", err
}
if !fp.VirusScanEnabled || strings.TrimSpace(fp.VirusTotalAPIKey) == "" {
return false, "", nil
}
return true, fp.VirusTotalAPIKey, nil
}
func boolValue(v any) bool {
switch t := v.(type) {
case bool:
return t
default:
return false
}
}
func stringValue(v any) string {
s, _ := v.(string)
return s
}

View File

@ -32,6 +32,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/authentik" "github.com/ultisuite/ulti-backend/internal/authentik"
"github.com/ultisuite/ulti-backend/internal/auth" "github.com/ultisuite/ulti-backend/internal/auth"
"github.com/ultisuite/ulti-backend/internal/config" "github.com/ultisuite/ulti-backend/internal/config"
"github.com/ultisuite/ulti-backend/internal/filescan"
"github.com/ultisuite/ulti-backend/internal/httpcors" "github.com/ultisuite/ulti-backend/internal/httpcors"
mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials" mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials"
imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap" imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap"
@ -43,6 +44,7 @@ import (
"github.com/ultisuite/ulti-backend/internal/meet" "github.com/ultisuite/ulti-backend/internal/meet"
"github.com/ultisuite/ulti-backend/internal/nextcloud" "github.com/ultisuite/ulti-backend/internal/nextcloud"
"github.com/ultisuite/ulti-backend/internal/observability" "github.com/ultisuite/ulti-backend/internal/observability"
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
"github.com/ultisuite/ulti-backend/internal/photos" "github.com/ultisuite/ulti-backend/internal/photos"
"github.com/ultisuite/ulti-backend/internal/realtime" "github.com/ultisuite/ulti-backend/internal/realtime"
"github.com/ultisuite/ulti-backend/internal/search" "github.com/ultisuite/ulti-backend/internal/search"
@ -203,6 +205,9 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
RedirectURL: oauthRedirect, RedirectURL: oauthRedirect,
}, rdb) }, rdb)
orgPolicyLoader := orgpolicy.NewLoader(pool, cfg)
fileScanner := filescan.NewScanner(orgPolicyLoader, slog.Default())
var syncWorker *imapsync.SyncWorker var syncWorker *imapsync.SyncWorker
if !opts.WithoutWorkers { if !opts.WithoutWorkers {
syncWorker = imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{ syncWorker = imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{
@ -210,6 +215,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
AttachBucket: cfg.MailAttachmentsBucket, AttachBucket: cfg.MailAttachmentsBucket,
Automation: autoDispatcher, Automation: autoDispatcher,
Hub: hub, Hub: hub,
FileScanner: fileScanner,
}) })
go syncWorker.Start(workerCtx) go syncWorker.Start(workerCtx)
} }
@ -229,6 +235,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst) sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst)
mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender) mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender)
mailHandler.SetFileScanner(fileScanner)
if syncWorker != nil { if syncWorker != nil {
mailHandler.SetAccountSync(syncWorker) mailHandler.SetAccountSync(syncWorker)
} }
@ -262,6 +269,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
if ncClient != nil { if ncClient != nil {
driveSvc = drive.NewService(ncClient, hub, pool) driveSvc = drive.NewService(ncClient, hub, pool)
driveSvc.SetAutomation(autoDispatcher) driveSvc.SetAutomation(autoDispatcher)
driveSvc.SetFileScanner(fileScanner)
driveHandler = drive.NewHandlerWithService(driveSvc) driveHandler = drive.NewHandlerWithService(driveSvc)
mailHandler.SetDriveUploader(&drivebridge.Bridge{Svc: driveSvc}) mailHandler.SetDriveUploader(&drivebridge.Bridge{Svc: driveSvc})
contactsHandler = contacts.NewHandler(ncClient, pool) contactsHandler = contacts.NewHandler(ncClient, pool)

View File

@ -0,0 +1,267 @@
package virustotal
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"strings"
"time"
)
const (
defaultBaseURL = "https://www.virustotal.com/api/v3"
directUploadLimit = 32 * 1024 * 1024
maxUploadLimit = 650 * 1024 * 1024
)
type Client struct {
apiKey string
baseURL string
httpClient *http.Client
}
func NewClient(apiKey string) *Client {
return &Client{
apiKey: apiKey,
baseURL: defaultBaseURL,
httpClient: &http.Client{
Timeout: 120 * time.Second,
},
}
}
func (c *Client) apiURL(path string) string {
return strings.TrimRight(c.baseURL, "/") + path
}
type analysisStats struct {
Malicious int `json:"malicious"`
Suspicious int `json:"suspicious"`
}
type fileReport struct {
Data struct {
Attributes struct {
LastAnalysisStats analysisStats `json:"last_analysis_stats"`
} `json:"attributes"`
} `json:"data"`
}
type analysisResponse struct {
Data struct {
Attributes struct {
Status string `json:"status"`
Stats analysisStats `json:"stats"`
} `json:"attributes"`
} `json:"data"`
}
type uploadResponse struct {
Data struct {
ID string `json:"id"`
} `json:"data"`
}
type uploadURLResponse struct {
Data string `json:"data"`
}
func (c *Client) headers() http.Header {
h := http.Header{}
h.Set("Accept", "application/json")
h.Set("x-apikey", c.apiKey)
return h
}
func (c *Client) lookupFile(ctx context.Context, sha256 string) (analysisStats, bool, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/"+sha256), nil)
if err != nil {
return analysisStats{}, false, err
}
req.Header = c.headers()
res, err := c.httpClient.Do(req)
if err != nil {
return analysisStats{}, false, err
}
defer res.Body.Close()
if res.StatusCode == http.StatusNotFound {
return analysisStats{}, false, nil
}
if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests {
return analysisStats{}, false, fmt.Errorf("virustotal lookup unavailable: %d", res.StatusCode)
}
if res.StatusCode >= 400 {
return analysisStats{}, false, nil
}
var report fileReport
if err := json.NewDecoder(res.Body).Decode(&report); err != nil {
return analysisStats{}, false, err
}
return report.Data.Attributes.LastAnalysisStats, true, nil
}
func (c *Client) uploadFile(ctx context.Context, data []byte, filename string) (string, error) {
if len(data) <= directUploadLimit {
return c.uploadDirect(ctx, data, filename)
}
return c.uploadLarge(ctx, data, filename)
}
func (c *Client) uploadDirect(ctx context.Context, data []byte, filename string) (string, error) {
var body bytes.Buffer
w := multipart.NewWriter(&body)
part, err := w.CreateFormFile("file", filename)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := w.Close(); err != nil {
return "", err
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL("/files"), &body)
if err != nil {
return "", err
}
req.Header = c.headers()
req.Header.Set("Content-Type", w.FormDataContentType())
res, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
b, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
return "", fmt.Errorf("virustotal upload failed: %d %s", res.StatusCode, string(b))
}
var out uploadResponse
if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
return "", err
}
return out.Data.ID, nil
}
func (c *Client) uploadLarge(ctx context.Context, data []byte, filename string) (string, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/upload_url"), nil)
if err != nil {
return "", err
}
req.Header = c.headers()
res, err := c.httpClient.Do(req)
if err != nil {
return "", err
}
defer res.Body.Close()
if res.StatusCode >= 400 {
b, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
return "", fmt.Errorf("virustotal upload_url failed: %d %s", res.StatusCode, string(b))
}
var urlResp uploadURLResponse
if err := json.NewDecoder(res.Body).Decode(&urlResp); err != nil {
return "", err
}
if urlResp.Data == "" {
return "", fmt.Errorf("virustotal upload_url empty")
}
var body bytes.Buffer
w := multipart.NewWriter(&body)
part, err := w.CreateFormFile("file", filename)
if err != nil {
return "", err
}
if _, err := part.Write(data); err != nil {
return "", err
}
if err := w.Close(); err != nil {
return "", err
}
upReq, err := http.NewRequestWithContext(ctx, http.MethodPost, urlResp.Data, &body)
if err != nil {
return "", err
}
upReq.Header = c.headers()
upReq.Header.Set("Content-Type", w.FormDataContentType())
upRes, err := c.httpClient.Do(upReq)
if err != nil {
return "", err
}
defer upRes.Body.Close()
if upRes.StatusCode >= 400 {
b, _ := io.ReadAll(io.LimitReader(upRes.Body, 4096))
return "", fmt.Errorf("virustotal large upload failed: %d %s", upRes.StatusCode, string(b))
}
var out uploadResponse
if err := json.NewDecoder(upRes.Body).Decode(&out); err != nil {
return "", err
}
return out.Data.ID, nil
}
func (c *Client) pollAnalysis(ctx context.Context, analysisID string, timeout time.Duration) (analysisStats, error) {
deadline := time.Now().Add(timeout)
for {
if ctx.Err() != nil {
return analysisStats{}, ctx.Err()
}
if time.Now().After(deadline) {
return analysisStats{}, fmt.Errorf("virustotal analysis timeout")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/analyses/"+analysisID), nil)
if err != nil {
return analysisStats{}, err
}
req.Header = c.headers()
res, err := c.httpClient.Do(req)
if err != nil {
return analysisStats{}, err
}
var out analysisResponse
decodeErr := json.NewDecoder(res.Body).Decode(&out)
res.Body.Close()
if decodeErr != nil {
return analysisStats{}, decodeErr
}
if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests {
return analysisStats{}, fmt.Errorf("virustotal poll unavailable: %d", res.StatusCode)
}
status := out.Data.Attributes.Status
if status == "completed" {
return out.Data.Attributes.Stats, nil
}
select {
case <-ctx.Done():
return analysisStats{}, ctx.Err()
case <-time.After(2 * time.Second):
}
}
}
func statsMalicious(stats analysisStats) bool {
return stats.Malicious > 0
}

View File

@ -0,0 +1,7 @@
package virustotal
import "errors"
var (
ErrMalicious = errors.New("malware detected")
)

106
internal/virustotal/scan.go Normal file
View File

@ -0,0 +1,106 @@
package virustotal
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"log/slog"
"time"
)
const defaultScanTimeout = 60 * time.Second
// ScanResult is the outcome of a VirusTotal scan attempt.
type ScanResult struct {
Status string // clean | skipped
}
// Scanner wraps Client with fail-open behavior on API errors.
type Scanner struct {
client *Client
logger *slog.Logger
timeout time.Duration
}
func NewScanner(apiKey string, logger *slog.Logger) *Scanner {
if logger == nil {
logger = slog.Default()
}
return &Scanner{
client: NewClient(apiKey),
logger: logger,
timeout: defaultScanTimeout,
}
}
// ScanBytes scans file content. Returns ErrMalicious if detected.
// On VT unavailability, returns skipped (fail-open).
func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte, sha256Hex string) (ScanResult, error) {
if len(data) == 0 {
return ScanResult{Status: "skipped"}, nil
}
if len(data) > maxUploadLimit {
s.logger.Warn("virustotal scan skipped: file too large for VT", "filename", filename, "size", len(data))
return ScanResult{Status: "skipped"}, nil
}
if sha256Hex == "" {
sum := sha256.Sum256(data)
sha256Hex = hex.EncodeToString(sum[:])
}
stats, found, err := s.client.lookupFile(ctx, sha256Hex)
if err != nil {
s.logger.Warn("virustotal lookup failed, skipping scan", "filename", filename, "error", err)
return ScanResult{Status: "skipped"}, nil
}
if found {
if statsMalicious(stats) {
return ScanResult{}, ErrMalicious
}
return ScanResult{Status: "clean"}, nil
}
analysisID, err := s.client.uploadFile(ctx, data, safeFilename(filename))
if err != nil {
s.logger.Warn("virustotal upload failed, skipping scan", "filename", filename, "error", err)
return ScanResult{Status: "skipped"}, nil
}
scanCtx, cancel := context.WithTimeout(ctx, s.timeout)
defer cancel()
stats, err = s.client.pollAnalysis(scanCtx, analysisID, s.timeout)
if err != nil {
s.logger.Warn("virustotal analysis failed, skipping scan", "filename", filename, "error", err)
return ScanResult{Status: "skipped"}, nil
}
if statsMalicious(stats) {
return ScanResult{}, ErrMalicious
}
return ScanResult{Status: "clean"}, nil
}
func safeFilename(name string) string {
if name == "" {
return "upload"
}
base := name
if len(base) > 200 {
base = base[:200]
}
return base
}
// SHA256Hex returns hex-encoded SHA-256 of data.
func SHA256Hex(data []byte) string {
sum := sha256.Sum256(data)
return hex.EncodeToString(sum[:])
}
// FormatAnalysisID is a helper for logging.
func FormatAnalysisID(id string) string {
return fmt.Sprintf("vt:%s", id)
}

View File

@ -0,0 +1,109 @@
package virustotal
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestScannerLookupMalicious(t *testing.T) {
sha := sha256.Sum256([]byte("evil"))
shaHex := hex.EncodeToString(sha[:])
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/files/"+shaHex) {
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{
"attributes": map[string]any{
"last_analysis_stats": map[string]any{
"malicious": 2,
"suspicious": 0,
},
},
},
})
return
}
http.NotFound(w, r)
}))
defer srv.Close()
sc := NewScanner("test-key", nil)
sc.client.baseURL = srv.URL + "/api/v3"
_, err := sc.ScanBytes(context.Background(), "evil.bin", []byte("evil"), shaHex)
if err == nil {
t.Fatal("expected ErrMalicious")
}
if err != ErrMalicious {
t.Fatalf("err = %v, want ErrMalicious", err)
}
}
func TestScannerUploadAndPollClean(t *testing.T) {
pollCount := 0
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/"+SHA256Hex([]byte("clean"))):
http.NotFound(w, r)
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/files"):
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{"id": "analysis-1"},
})
case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/analyses/analysis-1"):
pollCount++
status := "queued"
if pollCount >= 2 {
status = "completed"
}
_ = json.NewEncoder(w).Encode(map[string]any{
"data": map[string]any{
"attributes": map[string]any{
"status": status,
"stats": map[string]any{
"malicious": 0,
"suspicious": 0,
},
},
},
})
default:
http.NotFound(w, r)
}
}))
defer srv.Close()
sc := NewScanner("test-key", nil)
sc.client.baseURL = srv.URL + "/api/v3"
result, err := sc.ScanBytes(context.Background(), "clean.txt", []byte("clean"), "")
if err != nil {
t.Fatalf("ScanBytes: %v", err)
}
if result.Status != "clean" {
t.Fatalf("status = %q, want clean", result.Status)
}
}
func TestScannerFailOpenOnUnavailable(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "down", http.StatusServiceUnavailable)
}))
defer srv.Close()
sc := NewScanner("test-key", nil)
sc.client.baseURL = srv.URL + "/api/v3"
result, err := sc.ScanBytes(context.Background(), "file.bin", []byte("payload"), "")
if err != nil {
t.Fatalf("expected fail-open, got err %v", err)
}
if result.Status != "skipped" {
t.Fatalf("status = %q, want skipped", result.Status)
}
}

View File

@ -0,0 +1 @@
ALTER TABLE attachments DROP COLUMN IF EXISTS virus_scan_status;

View File

@ -0,0 +1,3 @@
ALTER TABLE attachments
ADD COLUMN virus_scan_status TEXT NOT NULL DEFAULT 'skipped'
CHECK (virus_scan_status IN ('clean', 'skipped', 'malicious'));