feat(scan): add VirusTotal upload antivirus
Admin-stored API key with env fallback; scan drive/mail/IMAP uploads. Fail-open if VT down, 422 on malware; migration for virus_scan_status.
This commit is contained in:
parent
f67c109f2f
commit
b90edf317c
@ -255,3 +255,8 @@ SEARCH_ENGINE=postgres
|
||||
# TYPESENSE_URL=http://typesense:8108
|
||||
# TYPESENSE_API_KEY={{TYPESENSE_API_KEY}}
|
||||
# TYPESENSE_COLLECTION=ulti
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# VirusTotal (optional env fallback; prefer admin Settings > File policies)
|
||||
# -----------------------------------------------------------------------------
|
||||
# VIRUSTOTAL_API_KEY=
|
||||
|
||||
@ -51,6 +51,7 @@ func defaultOrgPolicy() map[string]any {
|
||||
"external_sharing": "authenticated",
|
||||
"default_link_expiry_days": 30,
|
||||
"virus_scan_enabled": false,
|
||||
"virustotal_api_key": "",
|
||||
"retention_trash_days": 30,
|
||||
},
|
||||
"llm": map[string]any{
|
||||
@ -174,6 +175,7 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any {
|
||||
{"onlyoffice", "jwt_secret"},
|
||||
{"search", "meilisearch_api_key"},
|
||||
{"search", "typesense_api_key"},
|
||||
{"file_policies", "virustotal_api_key"},
|
||||
}
|
||||
for _, p := range secretPaths {
|
||||
patchSection, _ := patch[p.section].(map[string]any)
|
||||
@ -296,6 +298,7 @@ func maskOrgPolicy(policy map[string]any) map[string]any {
|
||||
maskStringField(cloned, "onlyoffice", "jwt_secret")
|
||||
maskStringField(cloned, "search", "meilisearch_api_key")
|
||||
maskStringField(cloned, "search", "typesense_api_key")
|
||||
maskStringField(cloned, "file_policies", "virustotal_api_key")
|
||||
if llm, ok := cloned["llm"].(map[string]any); ok {
|
||||
if providers, ok := llm["providers"].([]any); ok {
|
||||
for i, p := range providers {
|
||||
@ -374,6 +377,9 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string
|
||||
"typesense_api_key": map[string]any{
|
||||
"configured": secretConfigured(policy, "search", "typesense_api_key") || strings.TrimSpace(cfg.TypesenseKey) != "",
|
||||
},
|
||||
"virustotal_api_key": map[string]any{
|
||||
"configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
||||
"path"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/automation"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
@ -13,6 +14,10 @@ func (s *Service) SetAutomation(d driveAutomation) {
|
||||
s.automation = d
|
||||
}
|
||||
|
||||
func (s *Service) SetFileScanner(scanner *filescan.Scanner) {
|
||||
s.scanner = scanner
|
||||
}
|
||||
|
||||
func (s *Service) afterDriveFileEvent(ctx context.Context, externalUserID string, trigger rules.TriggerType, filePath string, isFolder bool) {
|
||||
normalized := nextcloud.NormalizeClientPath(filePath)
|
||||
s.notifyFileChanged(externalUserID, normalized)
|
||||
|
||||
@ -733,6 +733,8 @@ func writeDriveError(w http.ResponseWriter, r *http.Request, err error) {
|
||||
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "forbidden", nil)
|
||||
case errors.Is(err, ErrQuotaExceeded):
|
||||
apiresponse.WriteError(w, r, http.StatusInsufficientStorage, "drive.quota_exceeded", "quota exceeded", nil)
|
||||
case errors.Is(err, ErrMalware):
|
||||
apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "drive.malware_detected", "malware detected in file", nil)
|
||||
case errors.Is(err, ErrInvalid):
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid request body", nil)
|
||||
default:
|
||||
|
||||
@ -1,11 +1,14 @@
|
||||
package drive
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
|
||||
@ -17,7 +20,18 @@ func (s *Service) UploadPublicShare(ctx context.Context, token, filePath, passwo
|
||||
if !nextcloud.PublicShareCanCreate(perms) && !nextcloud.PublicShareCanUpdate(perms) {
|
||||
return ErrForbidden
|
||||
}
|
||||
if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, body, contentType)); err != nil {
|
||||
reader := body
|
||||
if s.scanner != nil {
|
||||
data, _, err := s.scanner.ScanReader(ctx, filePath, body, -1)
|
||||
if err != nil {
|
||||
if errors.Is(err, filescan.ErrMalicious) {
|
||||
return ErrMalware
|
||||
}
|
||||
return err
|
||||
}
|
||||
reader = bytes.NewReader(data)
|
||||
}
|
||||
if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, reader, contentType)); err != nil {
|
||||
return err
|
||||
}
|
||||
s.recordPublicShareAccess(ctx, token)
|
||||
|
||||
@ -18,6 +18,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/automation"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/publicshare"
|
||||
@ -30,6 +31,7 @@ var (
|
||||
ErrForbidden = errors.New("forbidden")
|
||||
ErrQuotaExceeded = errors.New("quota exceeded")
|
||||
ErrInvalid = errors.New("invalid request")
|
||||
ErrMalware = errors.New("malware detected")
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@ -37,6 +39,7 @@ type Service struct {
|
||||
hub *realtime.Hub
|
||||
db *pgxpool.Pool
|
||||
automation driveAutomation
|
||||
scanner *filescan.Scanner
|
||||
maxUploadBytes int64
|
||||
quotaReserveByte int64
|
||||
}
|
||||
@ -184,7 +187,18 @@ func (s *Service) Upload(ctx context.Context, userID, path string, body io.Reade
|
||||
if err := s.ensureQuota(ctx, userID, contentLength); err != nil {
|
||||
return err
|
||||
}
|
||||
return mapDriveError(s.nc.Upload(ctx, userID, path, body, contentType))
|
||||
reader := body
|
||||
if s.scanner != nil {
|
||||
data, _, err := s.scanner.ScanReader(ctx, path, body, contentLength)
|
||||
if err != nil {
|
||||
if errors.Is(err, filescan.ErrMalicious) {
|
||||
return ErrMalware
|
||||
}
|
||||
return err
|
||||
}
|
||||
reader = bytes.NewReader(data)
|
||||
}
|
||||
return mapDriveError(s.nc.Upload(ctx, userID, path, reader, contentType))
|
||||
}
|
||||
|
||||
func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath string, chunk ChunkUpload, body io.Reader, contentType string) error {
|
||||
@ -202,6 +216,31 @@ func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath
|
||||
if err := mapDriveError(s.nc.AssembleChunks(ctx, userID, uploadID, targetPath, chunk.TotalSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.scanner != nil {
|
||||
if err := s.scanAssembledUpload(ctx, userID, targetPath); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) scanAssembledUpload(ctx context.Context, userID, targetPath string) error {
|
||||
body, _, err := s.nc.Download(ctx, userID, targetPath)
|
||||
if err != nil {
|
||||
return mapDriveError(err)
|
||||
}
|
||||
defer body.Close()
|
||||
|
||||
data, scanResult, err := s.scanner.ScanReader(ctx, targetPath, body, -1)
|
||||
if err != nil {
|
||||
if errors.Is(err, filescan.ErrMalicious) {
|
||||
_ = s.nc.Delete(ctx, userID, targetPath)
|
||||
return ErrMalware
|
||||
}
|
||||
return err
|
||||
}
|
||||
_ = data
|
||||
_ = scanResult
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@ -39,7 +40,7 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag
|
||||
}
|
||||
|
||||
rows, err := s.db.Query(ctx, `
|
||||
SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, '')
|
||||
SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, ''), virus_scan_status
|
||||
FROM attachments WHERE message_id = $1
|
||||
ORDER BY created_at ASC
|
||||
`, messageID)
|
||||
@ -50,15 +51,16 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag
|
||||
|
||||
out := make([]map[string]any, 0)
|
||||
for rows.Next() {
|
||||
var id, filename, contentType, contentID, drivePath string
|
||||
var id, filename, contentType, contentID, drivePath, virusScanStatus string
|
||||
var size int64
|
||||
var isInline bool
|
||||
if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath); err != nil {
|
||||
if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath, &virusScanStatus); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry := map[string]any{
|
||||
"id": id, "filename": filename, "content_type": contentType,
|
||||
"size": size, "is_inline": isInline,
|
||||
"virus_scan_status": virusScanStatus,
|
||||
}
|
||||
if contentID != "" {
|
||||
entry["content_id"] = contentID
|
||||
@ -142,16 +144,28 @@ func (s *Service) UploadMessageAttachment(
|
||||
}
|
||||
|
||||
objectKey := storage.MessageObjectKey(userID, messageID, filename)
|
||||
if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil {
|
||||
scanStatus := "skipped"
|
||||
putReader := reader
|
||||
putSize := size
|
||||
if s.scanner != nil {
|
||||
data, result, err := s.scanner.ScanReader(ctx, filename, reader, size)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
scanStatus = result.Status
|
||||
putReader = bytes.NewReader(data)
|
||||
putSize = int64(len(data))
|
||||
}
|
||||
if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var id string
|
||||
err = s.db.QueryRow(ctx, `
|
||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id
|
||||
`, messageID, filename, contentType, size, s.storageBucket(), objectKey, contentID, isInline).Scan(&id)
|
||||
`, messageID, filename, contentType, putSize, s.storageBucket(), objectKey, contentID, isInline, scanStatus).Scan(&id)
|
||||
if err != nil {
|
||||
_ = s.storage.Delete(ctx, objectKey)
|
||||
return "", err
|
||||
@ -224,7 +238,17 @@ func (s *Service) UploadDraftAttachment(
|
||||
}
|
||||
|
||||
objectKey := storage.DraftObjectKey(userID, draftID, filename)
|
||||
if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil {
|
||||
putReader := reader
|
||||
putSize := size
|
||||
if s.scanner != nil {
|
||||
data, _, err := s.scanner.ScanReader(ctx, filename, reader, size)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
putReader = bytes.NewReader(data)
|
||||
putSize = int64(len(data))
|
||||
}
|
||||
if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@ -237,13 +261,13 @@ func (s *Service) UploadDraftAttachment(
|
||||
for _, ref := range refs {
|
||||
totalSize += ref.Size
|
||||
}
|
||||
if err := limits.ValidateAttachmentQuota(len(refs), totalSize, size); err != nil {
|
||||
if err := limits.ValidateAttachmentQuota(len(refs), totalSize, putSize); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
attID := uuid.NewString()
|
||||
refs = append(refs, draftAttachmentRef{
|
||||
ID: attID, Filename: filename, ContentType: contentType, Size: size,
|
||||
ID: attID, Filename: filename, ContentType: contentType, Size: putSize,
|
||||
S3Bucket: s.storageBucket(), S3Key: objectKey,
|
||||
ContentID: contentID, IsInline: isInline,
|
||||
})
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
)
|
||||
|
||||
@ -15,6 +16,10 @@ func (s *Service) SetDriveUploader(uploader DriveUploader) {
|
||||
s.driveUploader = uploader
|
||||
}
|
||||
|
||||
func (s *Service) SetFileScanner(scanner *filescan.Scanner) {
|
||||
s.scanner = scanner
|
||||
}
|
||||
|
||||
func (s *Service) SaveAttachmentToDrive(
|
||||
ctx context.Context,
|
||||
externalID, email, sub, displayName, messageID, attachmentID, folderPath string,
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/mail/sendguard"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
||||
@ -42,6 +43,13 @@ func (h *Handler) SetDriveUploader(uploader DriveUploader) {
|
||||
}
|
||||
}
|
||||
|
||||
// SetFileScanner wires VirusTotal scanning for mail attachment uploads.
|
||||
func (h *Handler) SetFileScanner(scanner *filescan.Scanner) {
|
||||
if s, ok := h.svc.(*Service); ok {
|
||||
s.SetFileScanner(scanner)
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandlerWithService(svc ServiceAPI) *Handler {
|
||||
return &Handler{
|
||||
svc: svc,
|
||||
|
||||
@ -15,6 +15,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||
driveapi "github.com/ultisuite/ulti-backend/internal/api/drive"
|
||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||
)
|
||||
|
||||
@ -342,6 +343,9 @@ func writeAttachmentUploadError(w http.ResponseWriter, r *http.Request, err erro
|
||||
case errors.Is(err, limits.ErrTooManyAttachments):
|
||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "too many attachments", nil)
|
||||
return true
|
||||
case errors.Is(err, ErrMalware), errors.Is(err, filescan.ErrMalicious):
|
||||
apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "mail.malware_detected", "malware detected in attachment", nil)
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
@ -11,6 +11,7 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/imap"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
|
||||
@ -28,6 +29,7 @@ var (
|
||||
ErrInvalidAccountCredentials = errors.New("account credentials invalid")
|
||||
ErrInvalidFolderScope = errors.New("invalid folder scope")
|
||||
ErrFolderHasChildren = errors.New("folder has children")
|
||||
ErrMalware = filescan.ErrMalicious
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
@ -37,6 +39,7 @@ type Service struct {
|
||||
storage *storage.Client
|
||||
attachmentsBucket string
|
||||
driveUploader DriveUploader
|
||||
scanner *filescan.Scanner
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
|
||||
@ -116,6 +116,9 @@ type Config struct {
|
||||
TypesenseKey string
|
||||
TypesenseCollection string
|
||||
|
||||
// VirusTotal (optional env fallback for org file_policies.virustotal_api_key)
|
||||
VirusTotalAPIKey string
|
||||
|
||||
// Observability
|
||||
HealthNextcloudURL string
|
||||
HealthImmichURL string
|
||||
@ -221,6 +224,8 @@ func Load() (*Config, error) {
|
||||
TypesenseKey: secrets.Env("TYPESENSE_API_KEY"),
|
||||
TypesenseCollection: envOrDefault("TYPESENSE_COLLECTION", "ulti"),
|
||||
|
||||
VirusTotalAPIKey: secrets.Env("VIRUSTOTAL_API_KEY"),
|
||||
|
||||
HealthNextcloudURL: envOrDefault("HEALTH_NEXTCLOUD_URL", joinURL(envOrDefault("NEXTCLOUD_URL", "http://nextcloud:80"), "/status.php")),
|
||||
HealthImmichURL: envOrDefault("HEALTH_IMMICH_URL", joinURL(envOrDefault("IMMICH_API_URL", "http://immich-server:2283/api"), "/server-info/ping")),
|
||||
HealthJitsiURL: envOrDefault("HEALTH_JITSI_URL", defaultHealthJitsiURL(envOrDefault("JITSI_PUBLIC_URL", "https://localhost/meet"))),
|
||||
|
||||
144
internal/filescan/scanner.go
Normal file
144
internal/filescan/scanner.go
Normal file
@ -0,0 +1,144 @@
|
||||
package filescan
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||
"github.com/ultisuite/ulti-backend/internal/virustotal"
|
||||
)
|
||||
|
||||
const memoryBufferLimit = 32 * 1024 * 1024
|
||||
|
||||
// Result holds scan outcome for persistence.
|
||||
type Result struct {
|
||||
Status string // clean | skipped
|
||||
}
|
||||
|
||||
// PolicyLoader supplies org file policies at runtime.
|
||||
type PolicyLoader interface {
|
||||
FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error)
|
||||
ScanEnabled(ctx context.Context) (bool, string, error)
|
||||
}
|
||||
|
||||
// Scanner coordinates org policy and VirusTotal scanning.
|
||||
type Scanner struct {
|
||||
policies PolicyLoader
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
func NewScanner(policies PolicyLoader, logger *slog.Logger) *Scanner {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Scanner{policies: policies, logger: logger}
|
||||
}
|
||||
|
||||
// ScanReader reads all bytes from r (up to maxBytes), optionally scans, returns data for storage.
|
||||
func (s *Scanner) ScanReader(ctx context.Context, filename string, r io.Reader, size int64) ([]byte, Result, error) {
|
||||
fp, err := s.policies.FilePolicies(ctx)
|
||||
if err != nil {
|
||||
return nil, Result{}, err
|
||||
}
|
||||
|
||||
maxBytes := fp.MaxUploadBytes
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = 512 * 1024 * 1024
|
||||
}
|
||||
if maxBytes > virustotalMaxBytes() {
|
||||
maxBytes = virustotalMaxBytes()
|
||||
}
|
||||
|
||||
data, err := readLimited(r, size, maxBytes)
|
||||
if err != nil {
|
||||
return nil, Result{}, err
|
||||
}
|
||||
|
||||
enabled := fp.VirusScanEnabled && fp.VirusTotalAPIKey != ""
|
||||
if !enabled {
|
||||
return data, Result{Status: "skipped"}, nil
|
||||
}
|
||||
|
||||
vt := virustotal.NewScanner(fp.VirusTotalAPIKey, s.logger)
|
||||
scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data))
|
||||
if err != nil {
|
||||
if errors.Is(err, virustotal.ErrMalicious) {
|
||||
return nil, Result{Status: "malicious"}, virustotal.ErrMalicious
|
||||
}
|
||||
return nil, Result{}, err
|
||||
}
|
||||
return data, Result{Status: scanResult.Status}, nil
|
||||
}
|
||||
|
||||
// ScanBytes scans pre-loaded bytes when policy is already known.
|
||||
func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte) (Result, error) {
|
||||
enabled, apiKey, err := s.policies.ScanEnabled(ctx)
|
||||
if err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
if !enabled {
|
||||
return Result{Status: "skipped"}, nil
|
||||
}
|
||||
|
||||
vt := virustotal.NewScanner(apiKey, s.logger)
|
||||
scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data))
|
||||
if err != nil {
|
||||
if errors.Is(err, virustotal.ErrMalicious) {
|
||||
return Result{Status: "malicious"}, virustotal.ErrMalicious
|
||||
}
|
||||
return Result{}, err
|
||||
}
|
||||
return Result{Status: scanResult.Status}, nil
|
||||
}
|
||||
|
||||
func readLimited(r io.Reader, size int64, maxBytes int64) ([]byte, error) {
|
||||
if size >= 0 && size <= memoryBufferLimit {
|
||||
limited := io.LimitReader(r, maxBytes+1)
|
||||
data, err := io.ReadAll(limited)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(data)) > maxBytes {
|
||||
return nil, errors.New("file exceeds max upload size")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
tmp, err := os.CreateTemp("", "ulti-scan-*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
defer os.Remove(tmpPath)
|
||||
|
||||
written, err := io.Copy(tmp, io.LimitReader(r, maxBytes+1))
|
||||
if err != nil {
|
||||
tmp.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written > maxBytes {
|
||||
return nil, errors.New("file exceeds max upload size")
|
||||
}
|
||||
|
||||
return os.ReadFile(filepath.Clean(tmpPath))
|
||||
}
|
||||
|
||||
func virustotalMaxBytes() int64 {
|
||||
return 650 * 1024 * 1024
|
||||
}
|
||||
|
||||
// ErrMalicious re-exports virustotal.ErrMalicious for handlers.
|
||||
var ErrMalicious = virustotal.ErrMalicious
|
||||
|
||||
// NopReader helper for tests.
|
||||
func NopReader(data []byte) io.Reader {
|
||||
return bytes.NewReader(data)
|
||||
}
|
||||
40
internal/filescan/scanner_test.go
Normal file
40
internal/filescan/scanner_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package filescan
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||
)
|
||||
|
||||
type stubPolicyLoader struct {
|
||||
fp orgpolicy.FilePolicies
|
||||
}
|
||||
|
||||
func (s stubPolicyLoader) FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error) {
|
||||
return s.fp, nil
|
||||
}
|
||||
|
||||
func (s stubPolicyLoader) ScanEnabled(ctx context.Context) (bool, string, error) {
|
||||
if s.fp.VirusScanEnabled && s.fp.VirusTotalAPIKey != "" {
|
||||
return true, s.fp.VirusTotalAPIKey, nil
|
||||
}
|
||||
return false, "", nil
|
||||
}
|
||||
|
||||
func TestScannerDisabledReturnsSkipped(t *testing.T) {
|
||||
scanner := &Scanner{
|
||||
policies: stubPolicyLoader{fp: orgpolicy.FilePolicies{VirusScanEnabled: false}},
|
||||
}
|
||||
data, result, err := scanner.ScanReader(context.Background(), "test.txt", bytes.NewReader([]byte("hello")), 5)
|
||||
if err != nil {
|
||||
t.Fatalf("ScanReader: %v", err)
|
||||
}
|
||||
if result.Status != "skipped" {
|
||||
t.Fatalf("status = %q, want skipped", result.Status)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("data = %q", data)
|
||||
}
|
||||
}
|
||||
@ -57,3 +57,60 @@ func TestAdminOrgSettings(t *testing.T) {
|
||||
t.Fatalf("default_mail_gib = %v, want 10", updatedStorage["default_mail_gib"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) {
|
||||
h := integrationtest.RequireHarness(t)
|
||||
adminClient, _ := integrationtest.RequireAdminClient(t, h)
|
||||
|
||||
putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
|
||||
"policy": map[string]any{
|
||||
"file_policies": map[string]any{
|
||||
"virus_scan_enabled": true,
|
||||
"virustotal_api_key": "vt-test-secret-key",
|
||||
},
|
||||
},
|
||||
})
|
||||
integrationtest.FailIf(err, t, "put org settings with virustotal key")
|
||||
integrationtest.FailUnlessStatus(t, putResp, 200)
|
||||
|
||||
var afterPut map[string]any
|
||||
integrationtest.DecodeJSON(t, putResp, &afterPut)
|
||||
secrets, ok := afterPut["secrets"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing secrets: %#v", afterPut)
|
||||
}
|
||||
vtSecret, ok := secrets["virustotal_api_key"].(map[string]any)
|
||||
if !ok || vtSecret["configured"] != true {
|
||||
t.Fatalf("virustotal_api_key not configured: %#v", secrets)
|
||||
}
|
||||
filePolicies, ok := afterPut["policy"].(map[string]any)["file_policies"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing file_policies")
|
||||
}
|
||||
if filePolicies["virustotal_api_key"] != "" {
|
||||
t.Fatalf("virustotal_api_key should be masked on GET, got %q", filePolicies["virustotal_api_key"])
|
||||
}
|
||||
|
||||
// Empty patch must preserve stored secret.
|
||||
preserveResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
|
||||
"policy": map[string]any{
|
||||
"file_policies": map[string]any{
|
||||
"virus_scan_enabled": false,
|
||||
"virustotal_api_key": "",
|
||||
},
|
||||
},
|
||||
})
|
||||
integrationtest.FailIf(err, t, "preserve virustotal key")
|
||||
integrationtest.FailUnlessStatus(t, preserveResp, 200)
|
||||
|
||||
var preserved map[string]any
|
||||
integrationtest.DecodeJSON(t, preserveResp, &preserved)
|
||||
preservedSecrets, ok := preserved["secrets"].(map[string]any)
|
||||
if !ok {
|
||||
t.Fatalf("missing secrets after preserve")
|
||||
}
|
||||
vtPreserved, ok := preservedSecrets["virustotal_api_key"].(map[string]any)
|
||||
if !ok || vtPreserved["configured"] != true {
|
||||
t.Fatalf("virustotal key not preserved: %#v", preservedSecrets)
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,6 +16,7 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/connect"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||
@ -38,6 +39,7 @@ type SyncDeps struct {
|
||||
Rules *rules.Engine
|
||||
Automation MailAutomation
|
||||
Hub *realtime.Hub
|
||||
FileScanner *filescan.Scanner
|
||||
}
|
||||
|
||||
type SyncWorker struct {
|
||||
@ -48,6 +50,7 @@ type SyncWorker struct {
|
||||
oauth *mailoauth.Service
|
||||
storage *storage.Client
|
||||
attachBucket string
|
||||
scanner *filescan.Scanner
|
||||
pipeline *syncPipeline
|
||||
}
|
||||
|
||||
@ -60,6 +63,7 @@ func NewSyncWorker(db *pgxpool.Pool, interval time.Duration, credManager *creden
|
||||
oauth: oauthSvc,
|
||||
storage: deps.Storage,
|
||||
attachBucket: deps.AttachBucket,
|
||||
scanner: deps.FileScanner,
|
||||
pipeline: newSyncPipeline(db, deps.Rules, deps.Automation, deps.Hub),
|
||||
}
|
||||
}
|
||||
@ -575,14 +579,30 @@ func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID str
|
||||
if messageExisted && attachmentPartExists(ctx, w.db, messageID, part) {
|
||||
continue
|
||||
}
|
||||
|
||||
scanStatus := "skipped"
|
||||
partData := part.Data
|
||||
if w.scanner != nil {
|
||||
result, err := w.scanner.ScanBytes(ctx, part.Filename, part.Data)
|
||||
if err != nil {
|
||||
if errors.Is(err, filescan.ErrMalicious) {
|
||||
w.logger.Warn("imap attachment skipped: malware detected",
|
||||
"message_id", messageID, "filename", part.Filename)
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
scanStatus = result.Status
|
||||
}
|
||||
|
||||
objectKey := storage.MessageObjectKey(userID, messageID, part.Filename)
|
||||
if err := w.storage.Put(ctx, objectKey, bytes.NewReader(part.Data), int64(len(part.Data)), part.ContentType); err != nil {
|
||||
if err := w.storage.Put(ctx, objectKey, bytes.NewReader(partData), int64(len(partData)), part.ContentType); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.db.Exec(ctx, `
|
||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`, messageID, part.Filename, part.ContentType, len(part.Data), bucket, objectKey, part.ContentID, part.IsInline)
|
||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`, messageID, part.Filename, part.ContentType, len(partData), bucket, objectKey, part.ContentID, part.IsInline, scanStatus)
|
||||
if err != nil {
|
||||
_ = w.storage.Delete(ctx, objectKey)
|
||||
return err
|
||||
|
||||
135
internal/orgpolicy/loader.go
Normal file
135
internal/orgpolicy/loader.go
Normal file
@ -0,0 +1,135 @@
|
||||
package orgpolicy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
)
|
||||
|
||||
const orgSettingsSingletonID = 1
|
||||
|
||||
const defaultMaxUploadMiB = 512
|
||||
|
||||
// FilePolicies holds runtime file upload policy from org settings.
|
||||
type FilePolicies struct {
|
||||
VirusScanEnabled bool
|
||||
VirusTotalAPIKey string
|
||||
MaxUploadBytes int64
|
||||
}
|
||||
|
||||
type Loader struct {
|
||||
db *pgxpool.Pool
|
||||
cfg *config.Config
|
||||
|
||||
mu sync.Mutex
|
||||
cached FilePolicies
|
||||
cachedAt time.Time
|
||||
ttl time.Duration
|
||||
}
|
||||
|
||||
func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader {
|
||||
return &Loader{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
ttl: 60 * time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Loader) FilePolicies(ctx context.Context) (FilePolicies, error) {
|
||||
l.mu.Lock()
|
||||
if !l.cachedAt.IsZero() && time.Since(l.cachedAt) < l.ttl {
|
||||
out := l.cached
|
||||
l.mu.Unlock()
|
||||
return out, nil
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
fp, err := l.loadFilePolicies(ctx)
|
||||
if err != nil {
|
||||
return FilePolicies{}, err
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
l.cached = fp
|
||||
l.cachedAt = time.Now()
|
||||
l.mu.Unlock()
|
||||
return fp, nil
|
||||
}
|
||||
|
||||
func (l *Loader) loadFilePolicies(ctx context.Context) (FilePolicies, error) {
|
||||
var raw []byte
|
||||
err := l.db.QueryRow(ctx, `
|
||||
SELECT settings FROM org_settings WHERE id = $1
|
||||
`, orgSettingsSingletonID).Scan(&raw)
|
||||
if err != nil && err != pgx.ErrNoRows {
|
||||
return FilePolicies{}, err
|
||||
}
|
||||
|
||||
stored := map[string]any{}
|
||||
if len(raw) > 0 {
|
||||
if err := json.Unmarshal(raw, &stored); err != nil {
|
||||
return FilePolicies{}, err
|
||||
}
|
||||
}
|
||||
|
||||
filePolicies, _ := stored["file_policies"].(map[string]any)
|
||||
enabled := boolValue(filePolicies["virus_scan_enabled"])
|
||||
apiKey := stringValue(filePolicies["virustotal_api_key"])
|
||||
if strings.TrimSpace(apiKey) == "" && l.cfg != nil {
|
||||
apiKey = strings.TrimSpace(l.cfg.VirusTotalAPIKey)
|
||||
}
|
||||
|
||||
maxMiB := int64(defaultMaxUploadMiB)
|
||||
switch v := filePolicies["max_upload_mib"].(type) {
|
||||
case float64:
|
||||
if v > 0 {
|
||||
maxMiB = int64(v)
|
||||
}
|
||||
case int:
|
||||
if v > 0 {
|
||||
maxMiB = int64(v)
|
||||
}
|
||||
case int64:
|
||||
if v > 0 {
|
||||
maxMiB = v
|
||||
}
|
||||
}
|
||||
|
||||
return FilePolicies{
|
||||
VirusScanEnabled: enabled,
|
||||
VirusTotalAPIKey: apiKey,
|
||||
MaxUploadBytes: maxMiB * 1024 * 1024,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) {
|
||||
fp, err := l.FilePolicies(ctx)
|
||||
if err != nil {
|
||||
return false, "", err
|
||||
}
|
||||
if !fp.VirusScanEnabled || strings.TrimSpace(fp.VirusTotalAPIKey) == "" {
|
||||
return false, "", nil
|
||||
}
|
||||
return true, fp.VirusTotalAPIKey, nil
|
||||
}
|
||||
|
||||
func boolValue(v any) bool {
|
||||
switch t := v.(type) {
|
||||
case bool:
|
||||
return t
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func stringValue(v any) string {
|
||||
s, _ := v.(string)
|
||||
return s
|
||||
}
|
||||
@ -32,6 +32,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||
"github.com/ultisuite/ulti-backend/internal/config"
|
||||
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||
"github.com/ultisuite/ulti-backend/internal/httpcors"
|
||||
mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||
imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap"
|
||||
@ -43,6 +44,7 @@ import (
|
||||
"github.com/ultisuite/ulti-backend/internal/meet"
|
||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||
"github.com/ultisuite/ulti-backend/internal/observability"
|
||||
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||
"github.com/ultisuite/ulti-backend/internal/photos"
|
||||
"github.com/ultisuite/ulti-backend/internal/realtime"
|
||||
"github.com/ultisuite/ulti-backend/internal/search"
|
||||
@ -203,6 +205,9 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
RedirectURL: oauthRedirect,
|
||||
}, rdb)
|
||||
|
||||
orgPolicyLoader := orgpolicy.NewLoader(pool, cfg)
|
||||
fileScanner := filescan.NewScanner(orgPolicyLoader, slog.Default())
|
||||
|
||||
var syncWorker *imapsync.SyncWorker
|
||||
if !opts.WithoutWorkers {
|
||||
syncWorker = imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{
|
||||
@ -210,6 +215,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
AttachBucket: cfg.MailAttachmentsBucket,
|
||||
Automation: autoDispatcher,
|
||||
Hub: hub,
|
||||
FileScanner: fileScanner,
|
||||
})
|
||||
go syncWorker.Start(workerCtx)
|
||||
}
|
||||
@ -229,6 +235,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
|
||||
sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst)
|
||||
mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender)
|
||||
mailHandler.SetFileScanner(fileScanner)
|
||||
if syncWorker != nil {
|
||||
mailHandler.SetAccountSync(syncWorker)
|
||||
}
|
||||
@ -262,6 +269,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
||||
if ncClient != nil {
|
||||
driveSvc = drive.NewService(ncClient, hub, pool)
|
||||
driveSvc.SetAutomation(autoDispatcher)
|
||||
driveSvc.SetFileScanner(fileScanner)
|
||||
driveHandler = drive.NewHandlerWithService(driveSvc)
|
||||
mailHandler.SetDriveUploader(&drivebridge.Bridge{Svc: driveSvc})
|
||||
contactsHandler = contacts.NewHandler(ncClient, pool)
|
||||
|
||||
267
internal/virustotal/client.go
Normal file
267
internal/virustotal/client.go
Normal file
@ -0,0 +1,267 @@
|
||||
package virustotal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultBaseURL = "https://www.virustotal.com/api/v3"
|
||||
directUploadLimit = 32 * 1024 * 1024
|
||||
maxUploadLimit = 650 * 1024 * 1024
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
apiKey string
|
||||
baseURL string
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
func NewClient(apiKey string) *Client {
|
||||
return &Client{
|
||||
apiKey: apiKey,
|
||||
baseURL: defaultBaseURL,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 120 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) apiURL(path string) string {
|
||||
return strings.TrimRight(c.baseURL, "/") + path
|
||||
}
|
||||
|
||||
type analysisStats struct {
|
||||
Malicious int `json:"malicious"`
|
||||
Suspicious int `json:"suspicious"`
|
||||
}
|
||||
|
||||
type fileReport struct {
|
||||
Data struct {
|
||||
Attributes struct {
|
||||
LastAnalysisStats analysisStats `json:"last_analysis_stats"`
|
||||
} `json:"attributes"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type analysisResponse struct {
|
||||
Data struct {
|
||||
Attributes struct {
|
||||
Status string `json:"status"`
|
||||
Stats analysisStats `json:"stats"`
|
||||
} `json:"attributes"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type uploadResponse struct {
|
||||
Data struct {
|
||||
ID string `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
|
||||
type uploadURLResponse struct {
|
||||
Data string `json:"data"`
|
||||
}
|
||||
|
||||
func (c *Client) headers() http.Header {
|
||||
h := http.Header{}
|
||||
h.Set("Accept", "application/json")
|
||||
h.Set("x-apikey", c.apiKey)
|
||||
return h
|
||||
}
|
||||
|
||||
func (c *Client) lookupFile(ctx context.Context, sha256 string) (analysisStats, bool, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/"+sha256), nil)
|
||||
if err != nil {
|
||||
return analysisStats{}, false, err
|
||||
}
|
||||
req.Header = c.headers()
|
||||
|
||||
res, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return analysisStats{}, false, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode == http.StatusNotFound {
|
||||
return analysisStats{}, false, nil
|
||||
}
|
||||
if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests {
|
||||
return analysisStats{}, false, fmt.Errorf("virustotal lookup unavailable: %d", res.StatusCode)
|
||||
}
|
||||
if res.StatusCode >= 400 {
|
||||
return analysisStats{}, false, nil
|
||||
}
|
||||
|
||||
var report fileReport
|
||||
if err := json.NewDecoder(res.Body).Decode(&report); err != nil {
|
||||
return analysisStats{}, false, err
|
||||
}
|
||||
return report.Data.Attributes.LastAnalysisStats, true, nil
|
||||
}
|
||||
|
||||
func (c *Client) uploadFile(ctx context.Context, data []byte, filename string) (string, error) {
|
||||
if len(data) <= directUploadLimit {
|
||||
return c.uploadDirect(ctx, data, filename)
|
||||
}
|
||||
return c.uploadLarge(ctx, data, filename)
|
||||
}
|
||||
|
||||
func (c *Client) uploadDirect(ctx context.Context, data []byte, filename string) (string, error) {
|
||||
var body bytes.Buffer
|
||||
w := multipart.NewWriter(&body)
|
||||
part, err := w.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL("/files"), &body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = c.headers()
|
||||
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||
|
||||
res, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
b, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||
return "", fmt.Errorf("virustotal upload failed: %d %s", res.StatusCode, string(b))
|
||||
}
|
||||
|
||||
var out uploadResponse
|
||||
if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return out.Data.ID, nil
|
||||
}
|
||||
|
||||
func (c *Client) uploadLarge(ctx context.Context, data []byte, filename string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/upload_url"), nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header = c.headers()
|
||||
|
||||
res, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode >= 400 {
|
||||
b, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||
return "", fmt.Errorf("virustotal upload_url failed: %d %s", res.StatusCode, string(b))
|
||||
}
|
||||
|
||||
var urlResp uploadURLResponse
|
||||
if err := json.NewDecoder(res.Body).Decode(&urlResp); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if urlResp.Data == "" {
|
||||
return "", fmt.Errorf("virustotal upload_url empty")
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
w := multipart.NewWriter(&body)
|
||||
part, err := w.CreateFormFile("file", filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := part.Write(data); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := w.Close(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
upReq, err := http.NewRequestWithContext(ctx, http.MethodPost, urlResp.Data, &body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
upReq.Header = c.headers()
|
||||
upReq.Header.Set("Content-Type", w.FormDataContentType())
|
||||
|
||||
upRes, err := c.httpClient.Do(upReq)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer upRes.Body.Close()
|
||||
|
||||
if upRes.StatusCode >= 400 {
|
||||
b, _ := io.ReadAll(io.LimitReader(upRes.Body, 4096))
|
||||
return "", fmt.Errorf("virustotal large upload failed: %d %s", upRes.StatusCode, string(b))
|
||||
}
|
||||
|
||||
var out uploadResponse
|
||||
if err := json.NewDecoder(upRes.Body).Decode(&out); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return out.Data.ID, nil
|
||||
}
|
||||
|
||||
func (c *Client) pollAnalysis(ctx context.Context, analysisID string, timeout time.Duration) (analysisStats, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return analysisStats{}, ctx.Err()
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
return analysisStats{}, fmt.Errorf("virustotal analysis timeout")
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/analyses/"+analysisID), nil)
|
||||
if err != nil {
|
||||
return analysisStats{}, err
|
||||
}
|
||||
req.Header = c.headers()
|
||||
|
||||
res, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return analysisStats{}, err
|
||||
}
|
||||
|
||||
var out analysisResponse
|
||||
decodeErr := json.NewDecoder(res.Body).Decode(&out)
|
||||
res.Body.Close()
|
||||
if decodeErr != nil {
|
||||
return analysisStats{}, decodeErr
|
||||
}
|
||||
|
||||
if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests {
|
||||
return analysisStats{}, fmt.Errorf("virustotal poll unavailable: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
status := out.Data.Attributes.Status
|
||||
if status == "completed" {
|
||||
return out.Data.Attributes.Stats, nil
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return analysisStats{}, ctx.Err()
|
||||
case <-time.After(2 * time.Second):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func statsMalicious(stats analysisStats) bool {
|
||||
return stats.Malicious > 0
|
||||
}
|
||||
7
internal/virustotal/errors.go
Normal file
7
internal/virustotal/errors.go
Normal file
@ -0,0 +1,7 @@
|
||||
package virustotal
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrMalicious = errors.New("malware detected")
|
||||
)
|
||||
106
internal/virustotal/scan.go
Normal file
106
internal/virustotal/scan.go
Normal file
@ -0,0 +1,106 @@
|
||||
package virustotal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultScanTimeout = 60 * time.Second
|
||||
|
||||
// ScanResult is the outcome of a VirusTotal scan attempt.
|
||||
type ScanResult struct {
|
||||
Status string // clean | skipped
|
||||
}
|
||||
|
||||
// Scanner wraps Client with fail-open behavior on API errors.
|
||||
type Scanner struct {
|
||||
client *Client
|
||||
logger *slog.Logger
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewScanner(apiKey string, logger *slog.Logger) *Scanner {
|
||||
if logger == nil {
|
||||
logger = slog.Default()
|
||||
}
|
||||
return &Scanner{
|
||||
client: NewClient(apiKey),
|
||||
logger: logger,
|
||||
timeout: defaultScanTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// ScanBytes scans file content. Returns ErrMalicious if detected.
|
||||
// On VT unavailability, returns skipped (fail-open).
|
||||
func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte, sha256Hex string) (ScanResult, error) {
|
||||
if len(data) == 0 {
|
||||
return ScanResult{Status: "skipped"}, nil
|
||||
}
|
||||
if len(data) > maxUploadLimit {
|
||||
s.logger.Warn("virustotal scan skipped: file too large for VT", "filename", filename, "size", len(data))
|
||||
return ScanResult{Status: "skipped"}, nil
|
||||
}
|
||||
|
||||
if sha256Hex == "" {
|
||||
sum := sha256.Sum256(data)
|
||||
sha256Hex = hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
stats, found, err := s.client.lookupFile(ctx, sha256Hex)
|
||||
if err != nil {
|
||||
s.logger.Warn("virustotal lookup failed, skipping scan", "filename", filename, "error", err)
|
||||
return ScanResult{Status: "skipped"}, nil
|
||||
}
|
||||
if found {
|
||||
if statsMalicious(stats) {
|
||||
return ScanResult{}, ErrMalicious
|
||||
}
|
||||
return ScanResult{Status: "clean"}, nil
|
||||
}
|
||||
|
||||
analysisID, err := s.client.uploadFile(ctx, data, safeFilename(filename))
|
||||
if err != nil {
|
||||
s.logger.Warn("virustotal upload failed, skipping scan", "filename", filename, "error", err)
|
||||
return ScanResult{Status: "skipped"}, nil
|
||||
}
|
||||
|
||||
scanCtx, cancel := context.WithTimeout(ctx, s.timeout)
|
||||
defer cancel()
|
||||
|
||||
stats, err = s.client.pollAnalysis(scanCtx, analysisID, s.timeout)
|
||||
if err != nil {
|
||||
s.logger.Warn("virustotal analysis failed, skipping scan", "filename", filename, "error", err)
|
||||
return ScanResult{Status: "skipped"}, nil
|
||||
}
|
||||
|
||||
if statsMalicious(stats) {
|
||||
return ScanResult{}, ErrMalicious
|
||||
}
|
||||
return ScanResult{Status: "clean"}, nil
|
||||
}
|
||||
|
||||
func safeFilename(name string) string {
|
||||
if name == "" {
|
||||
return "upload"
|
||||
}
|
||||
base := name
|
||||
if len(base) > 200 {
|
||||
base = base[:200]
|
||||
}
|
||||
return base
|
||||
}
|
||||
|
||||
// SHA256Hex returns hex-encoded SHA-256 of data.
|
||||
func SHA256Hex(data []byte) string {
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// FormatAnalysisID is a helper for logging.
|
||||
func FormatAnalysisID(id string) string {
|
||||
return fmt.Sprintf("vt:%s", id)
|
||||
}
|
||||
109
internal/virustotal/scan_test.go
Normal file
109
internal/virustotal/scan_test.go
Normal file
@ -0,0 +1,109 @@
|
||||
package virustotal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestScannerLookupMalicious(t *testing.T) {
|
||||
sha := sha256.Sum256([]byte("evil"))
|
||||
shaHex := hex.EncodeToString(sha[:])
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/files/"+shaHex) {
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{
|
||||
"attributes": map[string]any{
|
||||
"last_analysis_stats": map[string]any{
|
||||
"malicious": 2,
|
||||
"suspicious": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
http.NotFound(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc := NewScanner("test-key", nil)
|
||||
sc.client.baseURL = srv.URL + "/api/v3"
|
||||
|
||||
_, err := sc.ScanBytes(context.Background(), "evil.bin", []byte("evil"), shaHex)
|
||||
if err == nil {
|
||||
t.Fatal("expected ErrMalicious")
|
||||
}
|
||||
if err != ErrMalicious {
|
||||
t.Fatalf("err = %v, want ErrMalicious", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScannerUploadAndPollClean(t *testing.T) {
|
||||
pollCount := 0
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/"+SHA256Hex([]byte("clean"))):
|
||||
http.NotFound(w, r)
|
||||
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/files"):
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{"id": "analysis-1"},
|
||||
})
|
||||
case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/analyses/analysis-1"):
|
||||
pollCount++
|
||||
status := "queued"
|
||||
if pollCount >= 2 {
|
||||
status = "completed"
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||
"data": map[string]any{
|
||||
"attributes": map[string]any{
|
||||
"status": status,
|
||||
"stats": map[string]any{
|
||||
"malicious": 0,
|
||||
"suspicious": 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
default:
|
||||
http.NotFound(w, r)
|
||||
}
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc := NewScanner("test-key", nil)
|
||||
sc.client.baseURL = srv.URL + "/api/v3"
|
||||
|
||||
result, err := sc.ScanBytes(context.Background(), "clean.txt", []byte("clean"), "")
|
||||
if err != nil {
|
||||
t.Fatalf("ScanBytes: %v", err)
|
||||
}
|
||||
if result.Status != "clean" {
|
||||
t.Fatalf("status = %q, want clean", result.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScannerFailOpenOnUnavailable(t *testing.T) {
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "down", http.StatusServiceUnavailable)
|
||||
}))
|
||||
defer srv.Close()
|
||||
|
||||
sc := NewScanner("test-key", nil)
|
||||
sc.client.baseURL = srv.URL + "/api/v3"
|
||||
|
||||
result, err := sc.ScanBytes(context.Background(), "file.bin", []byte("payload"), "")
|
||||
if err != nil {
|
||||
t.Fatalf("expected fail-open, got err %v", err)
|
||||
}
|
||||
if result.Status != "skipped" {
|
||||
t.Fatalf("status = %q, want skipped", result.Status)
|
||||
}
|
||||
}
|
||||
1
migrations/000034_attachment_virus_scan.down.sql
Normal file
1
migrations/000034_attachment_virus_scan.down.sql
Normal file
@ -0,0 +1 @@
|
||||
ALTER TABLE attachments DROP COLUMN IF EXISTS virus_scan_status;
|
||||
3
migrations/000034_attachment_virus_scan.up.sql
Normal file
3
migrations/000034_attachment_virus_scan.up.sql
Normal file
@ -0,0 +1,3 @@
|
||||
ALTER TABLE attachments
|
||||
ADD COLUMN virus_scan_status TEXT NOT NULL DEFAULT 'skipped'
|
||||
CHECK (virus_scan_status IN ('clean', 'skipped', 'malicious'));
|
||||
Loading…
Reference in New Issue
Block a user