feat(scan): add VirusTotal upload antivirus
Admin-stored API key with env fallback; scan drive/mail/IMAP uploads. Fail-open if VT down, 422 on malware; migration for virus_scan_status.
This commit is contained in:
parent
f67c109f2f
commit
b90edf317c
@ -255,3 +255,8 @@ SEARCH_ENGINE=postgres
|
|||||||
# TYPESENSE_URL=http://typesense:8108
|
# TYPESENSE_URL=http://typesense:8108
|
||||||
# TYPESENSE_API_KEY={{TYPESENSE_API_KEY}}
|
# TYPESENSE_API_KEY={{TYPESENSE_API_KEY}}
|
||||||
# TYPESENSE_COLLECTION=ulti
|
# TYPESENSE_COLLECTION=ulti
|
||||||
|
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# VirusTotal (optional env fallback; prefer admin Settings > File policies)
|
||||||
|
# -----------------------------------------------------------------------------
|
||||||
|
# VIRUSTOTAL_API_KEY=
|
||||||
|
|||||||
@ -51,6 +51,7 @@ func defaultOrgPolicy() map[string]any {
|
|||||||
"external_sharing": "authenticated",
|
"external_sharing": "authenticated",
|
||||||
"default_link_expiry_days": 30,
|
"default_link_expiry_days": 30,
|
||||||
"virus_scan_enabled": false,
|
"virus_scan_enabled": false,
|
||||||
|
"virustotal_api_key": "",
|
||||||
"retention_trash_days": 30,
|
"retention_trash_days": 30,
|
||||||
},
|
},
|
||||||
"llm": map[string]any{
|
"llm": map[string]any{
|
||||||
@ -174,6 +175,7 @@ func mergeOrgSecrets(existing, patch map[string]any) map[string]any {
|
|||||||
{"onlyoffice", "jwt_secret"},
|
{"onlyoffice", "jwt_secret"},
|
||||||
{"search", "meilisearch_api_key"},
|
{"search", "meilisearch_api_key"},
|
||||||
{"search", "typesense_api_key"},
|
{"search", "typesense_api_key"},
|
||||||
|
{"file_policies", "virustotal_api_key"},
|
||||||
}
|
}
|
||||||
for _, p := range secretPaths {
|
for _, p := range secretPaths {
|
||||||
patchSection, _ := patch[p.section].(map[string]any)
|
patchSection, _ := patch[p.section].(map[string]any)
|
||||||
@ -296,6 +298,7 @@ func maskOrgPolicy(policy map[string]any) map[string]any {
|
|||||||
maskStringField(cloned, "onlyoffice", "jwt_secret")
|
maskStringField(cloned, "onlyoffice", "jwt_secret")
|
||||||
maskStringField(cloned, "search", "meilisearch_api_key")
|
maskStringField(cloned, "search", "meilisearch_api_key")
|
||||||
maskStringField(cloned, "search", "typesense_api_key")
|
maskStringField(cloned, "search", "typesense_api_key")
|
||||||
|
maskStringField(cloned, "file_policies", "virustotal_api_key")
|
||||||
if llm, ok := cloned["llm"].(map[string]any); ok {
|
if llm, ok := cloned["llm"].(map[string]any); ok {
|
||||||
if providers, ok := llm["providers"].([]any); ok {
|
if providers, ok := llm["providers"].([]any); ok {
|
||||||
for i, p := range providers {
|
for i, p := range providers {
|
||||||
@ -374,6 +377,9 @@ func buildOrgSecretsStatus(policy map[string]any, cfg *config.Config) map[string
|
|||||||
"typesense_api_key": map[string]any{
|
"typesense_api_key": map[string]any{
|
||||||
"configured": secretConfigured(policy, "search", "typesense_api_key") || strings.TrimSpace(cfg.TypesenseKey) != "",
|
"configured": secretConfigured(policy, "search", "typesense_api_key") || strings.TrimSpace(cfg.TypesenseKey) != "",
|
||||||
},
|
},
|
||||||
|
"virustotal_api_key": map[string]any{
|
||||||
|
"configured": secretConfigured(policy, "file_policies", "virustotal_api_key") || strings.TrimSpace(cfg.VirusTotalAPIKey) != "",
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import (
|
|||||||
"path"
|
"path"
|
||||||
|
|
||||||
"github.com/ultisuite/ulti-backend/internal/automation"
|
"github.com/ultisuite/ulti-backend/internal/automation"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||||
)
|
)
|
||||||
@ -13,6 +14,10 @@ func (s *Service) SetAutomation(d driveAutomation) {
|
|||||||
s.automation = d
|
s.automation = d
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) SetFileScanner(scanner *filescan.Scanner) {
|
||||||
|
s.scanner = scanner
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) afterDriveFileEvent(ctx context.Context, externalUserID string, trigger rules.TriggerType, filePath string, isFolder bool) {
|
func (s *Service) afterDriveFileEvent(ctx context.Context, externalUserID string, trigger rules.TriggerType, filePath string, isFolder bool) {
|
||||||
normalized := nextcloud.NormalizeClientPath(filePath)
|
normalized := nextcloud.NormalizeClientPath(filePath)
|
||||||
s.notifyFileChanged(externalUserID, normalized)
|
s.notifyFileChanged(externalUserID, normalized)
|
||||||
|
|||||||
@ -733,6 +733,8 @@ func writeDriveError(w http.ResponseWriter, r *http.Request, err error) {
|
|||||||
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "forbidden", nil)
|
apiresponse.WriteError(w, r, http.StatusForbidden, apiresponse.CodeAuthForbidden, "forbidden", nil)
|
||||||
case errors.Is(err, ErrQuotaExceeded):
|
case errors.Is(err, ErrQuotaExceeded):
|
||||||
apiresponse.WriteError(w, r, http.StatusInsufficientStorage, "drive.quota_exceeded", "quota exceeded", nil)
|
apiresponse.WriteError(w, r, http.StatusInsufficientStorage, "drive.quota_exceeded", "quota exceeded", nil)
|
||||||
|
case errors.Is(err, ErrMalware):
|
||||||
|
apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "drive.malware_detected", "malware detected in file", nil)
|
||||||
case errors.Is(err, ErrInvalid):
|
case errors.Is(err, ErrInvalid):
|
||||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid request body", nil)
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "invalid request body", nil)
|
||||||
default:
|
default:
|
||||||
|
|||||||
@ -1,11 +1,14 @@
|
|||||||
package drive
|
package drive
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -17,7 +20,18 @@ func (s *Service) UploadPublicShare(ctx context.Context, token, filePath, passwo
|
|||||||
if !nextcloud.PublicShareCanCreate(perms) && !nextcloud.PublicShareCanUpdate(perms) {
|
if !nextcloud.PublicShareCanCreate(perms) && !nextcloud.PublicShareCanUpdate(perms) {
|
||||||
return ErrForbidden
|
return ErrForbidden
|
||||||
}
|
}
|
||||||
if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, body, contentType)); err != nil {
|
reader := body
|
||||||
|
if s.scanner != nil {
|
||||||
|
data, _, err := s.scanner.ScanReader(ctx, filePath, body, -1)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, filescan.ErrMalicious) {
|
||||||
|
return ErrMalware
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reader = bytes.NewReader(data)
|
||||||
|
}
|
||||||
|
if err := mapPublicShareError(s.nc.UploadPublicShare(ctx, token, filePath, password, reader, contentType)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.recordPublicShareAccess(ctx, token)
|
s.recordPublicShareAccess(ctx, token)
|
||||||
|
|||||||
@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||||
"github.com/ultisuite/ulti-backend/internal/automation"
|
"github.com/ultisuite/ulti-backend/internal/automation"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||||
"github.com/ultisuite/ulti-backend/internal/publicshare"
|
"github.com/ultisuite/ulti-backend/internal/publicshare"
|
||||||
@ -30,6 +31,7 @@ var (
|
|||||||
ErrForbidden = errors.New("forbidden")
|
ErrForbidden = errors.New("forbidden")
|
||||||
ErrQuotaExceeded = errors.New("quota exceeded")
|
ErrQuotaExceeded = errors.New("quota exceeded")
|
||||||
ErrInvalid = errors.New("invalid request")
|
ErrInvalid = errors.New("invalid request")
|
||||||
|
ErrMalware = errors.New("malware detected")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
@ -37,6 +39,7 @@ type Service struct {
|
|||||||
hub *realtime.Hub
|
hub *realtime.Hub
|
||||||
db *pgxpool.Pool
|
db *pgxpool.Pool
|
||||||
automation driveAutomation
|
automation driveAutomation
|
||||||
|
scanner *filescan.Scanner
|
||||||
maxUploadBytes int64
|
maxUploadBytes int64
|
||||||
quotaReserveByte int64
|
quotaReserveByte int64
|
||||||
}
|
}
|
||||||
@ -184,7 +187,18 @@ func (s *Service) Upload(ctx context.Context, userID, path string, body io.Reade
|
|||||||
if err := s.ensureQuota(ctx, userID, contentLength); err != nil {
|
if err := s.ensureQuota(ctx, userID, contentLength); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return mapDriveError(s.nc.Upload(ctx, userID, path, body, contentType))
|
reader := body
|
||||||
|
if s.scanner != nil {
|
||||||
|
data, _, err := s.scanner.ScanReader(ctx, path, body, contentLength)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, filescan.ErrMalicious) {
|
||||||
|
return ErrMalware
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
reader = bytes.NewReader(data)
|
||||||
|
}
|
||||||
|
return mapDriveError(s.nc.Upload(ctx, userID, path, reader, contentType))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath string, chunk ChunkUpload, body io.Reader, contentType string) error {
|
func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath string, chunk ChunkUpload, body io.Reader, contentType string) error {
|
||||||
@ -202,6 +216,31 @@ func (s *Service) UploadChunk(ctx context.Context, userID, uploadID, targetPath
|
|||||||
if err := mapDriveError(s.nc.AssembleChunks(ctx, userID, uploadID, targetPath, chunk.TotalSize)); err != nil {
|
if err := mapDriveError(s.nc.AssembleChunks(ctx, userID, uploadID, targetPath, chunk.TotalSize)); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if s.scanner != nil {
|
||||||
|
if err := s.scanAssembledUpload(ctx, userID, targetPath); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) scanAssembledUpload(ctx context.Context, userID, targetPath string) error {
|
||||||
|
body, _, err := s.nc.Download(ctx, userID, targetPath)
|
||||||
|
if err != nil {
|
||||||
|
return mapDriveError(err)
|
||||||
|
}
|
||||||
|
defer body.Close()
|
||||||
|
|
||||||
|
data, scanResult, err := s.scanner.ScanReader(ctx, targetPath, body, -1)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, filescan.ErrMalicious) {
|
||||||
|
_ = s.nc.Delete(ctx, userID, targetPath)
|
||||||
|
return ErrMalware
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = data
|
||||||
|
_ = scanResult
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package mail
|
package mail
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
@ -39,7 +40,7 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag
|
|||||||
}
|
}
|
||||||
|
|
||||||
rows, err := s.db.Query(ctx, `
|
rows, err := s.db.Query(ctx, `
|
||||||
SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, '')
|
SELECT id, filename, content_type, size, content_id, is_inline, COALESCE(drive_path, ''), virus_scan_status
|
||||||
FROM attachments WHERE message_id = $1
|
FROM attachments WHERE message_id = $1
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC
|
||||||
`, messageID)
|
`, messageID)
|
||||||
@ -50,15 +51,16 @@ func (s *Service) ListMessageAttachments(ctx context.Context, externalID, messag
|
|||||||
|
|
||||||
out := make([]map[string]any, 0)
|
out := make([]map[string]any, 0)
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var id, filename, contentType, contentID, drivePath string
|
var id, filename, contentType, contentID, drivePath, virusScanStatus string
|
||||||
var size int64
|
var size int64
|
||||||
var isInline bool
|
var isInline bool
|
||||||
if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath); err != nil {
|
if err := rows.Scan(&id, &filename, &contentType, &size, &contentID, &isInline, &drivePath, &virusScanStatus); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
entry := map[string]any{
|
entry := map[string]any{
|
||||||
"id": id, "filename": filename, "content_type": contentType,
|
"id": id, "filename": filename, "content_type": contentType,
|
||||||
"size": size, "is_inline": isInline,
|
"size": size, "is_inline": isInline,
|
||||||
|
"virus_scan_status": virusScanStatus,
|
||||||
}
|
}
|
||||||
if contentID != "" {
|
if contentID != "" {
|
||||||
entry["content_id"] = contentID
|
entry["content_id"] = contentID
|
||||||
@ -142,16 +144,28 @@ func (s *Service) UploadMessageAttachment(
|
|||||||
}
|
}
|
||||||
|
|
||||||
objectKey := storage.MessageObjectKey(userID, messageID, filename)
|
objectKey := storage.MessageObjectKey(userID, messageID, filename)
|
||||||
if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil {
|
scanStatus := "skipped"
|
||||||
|
putReader := reader
|
||||||
|
putSize := size
|
||||||
|
if s.scanner != nil {
|
||||||
|
data, result, err := s.scanner.ScanReader(ctx, filename, reader, size)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
scanStatus = result.Status
|
||||||
|
putReader = bytes.NewReader(data)
|
||||||
|
putSize = int64(len(data))
|
||||||
|
}
|
||||||
|
if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
var id string
|
var id string
|
||||||
err = s.db.QueryRow(ctx, `
|
err = s.db.QueryRow(ctx, `
|
||||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline)
|
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
RETURNING id
|
RETURNING id
|
||||||
`, messageID, filename, contentType, size, s.storageBucket(), objectKey, contentID, isInline).Scan(&id)
|
`, messageID, filename, contentType, putSize, s.storageBucket(), objectKey, contentID, isInline, scanStatus).Scan(&id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = s.storage.Delete(ctx, objectKey)
|
_ = s.storage.Delete(ctx, objectKey)
|
||||||
return "", err
|
return "", err
|
||||||
@ -224,7 +238,17 @@ func (s *Service) UploadDraftAttachment(
|
|||||||
}
|
}
|
||||||
|
|
||||||
objectKey := storage.DraftObjectKey(userID, draftID, filename)
|
objectKey := storage.DraftObjectKey(userID, draftID, filename)
|
||||||
if err := s.storage.Put(ctx, objectKey, reader, size, contentType); err != nil {
|
putReader := reader
|
||||||
|
putSize := size
|
||||||
|
if s.scanner != nil {
|
||||||
|
data, _, err := s.scanner.ScanReader(ctx, filename, reader, size)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
putReader = bytes.NewReader(data)
|
||||||
|
putSize = int64(len(data))
|
||||||
|
}
|
||||||
|
if err := s.storage.Put(ctx, objectKey, putReader, putSize, contentType); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -237,13 +261,13 @@ func (s *Service) UploadDraftAttachment(
|
|||||||
for _, ref := range refs {
|
for _, ref := range refs {
|
||||||
totalSize += ref.Size
|
totalSize += ref.Size
|
||||||
}
|
}
|
||||||
if err := limits.ValidateAttachmentQuota(len(refs), totalSize, size); err != nil {
|
if err := limits.ValidateAttachmentQuota(len(refs), totalSize, putSize); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
attID := uuid.NewString()
|
attID := uuid.NewString()
|
||||||
refs = append(refs, draftAttachmentRef{
|
refs = append(refs, draftAttachmentRef{
|
||||||
ID: attID, Filename: filename, ContentType: contentType, Size: size,
|
ID: attID, Filename: filename, ContentType: contentType, Size: putSize,
|
||||||
S3Bucket: s.storageBucket(), S3Key: objectKey,
|
S3Bucket: s.storageBucket(), S3Key: objectKey,
|
||||||
ContentID: contentID, IsInline: isInline,
|
ContentID: contentID, IsInline: isInline,
|
||||||
})
|
})
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/jackc/pgx/v5"
|
"github.com/jackc/pgx/v5"
|
||||||
|
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -15,6 +16,10 @@ func (s *Service) SetDriveUploader(uploader DriveUploader) {
|
|||||||
s.driveUploader = uploader
|
s.driveUploader = uploader
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) SetFileScanner(scanner *filescan.Scanner) {
|
||||||
|
s.scanner = scanner
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) SaveAttachmentToDrive(
|
func (s *Service) SaveAttachmentToDrive(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
externalID, email, sub, displayName, messageID, attachmentID, folderPath string,
|
externalID, email, sub, displayName, messageID, attachmentID, folderPath string,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/ultisuite/ulti-backend/internal/api/mail/sendguard"
|
"github.com/ultisuite/ulti-backend/internal/api/mail/sendguard"
|
||||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||||
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
||||||
@ -42,6 +43,13 @@ func (h *Handler) SetDriveUploader(uploader DriveUploader) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetFileScanner wires VirusTotal scanning for mail attachment uploads.
|
||||||
|
func (h *Handler) SetFileScanner(scanner *filescan.Scanner) {
|
||||||
|
if s, ok := h.svc.(*Service); ok {
|
||||||
|
s.SetFileScanner(scanner)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func NewHandlerWithService(svc ServiceAPI) *Handler {
|
func NewHandlerWithService(svc ServiceAPI) *Handler {
|
||||||
return &Handler{
|
return &Handler{
|
||||||
svc: svc,
|
svc: svc,
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
"github.com/ultisuite/ulti-backend/internal/api/apivalidate"
|
||||||
driveapi "github.com/ultisuite/ulti-backend/internal/api/drive"
|
driveapi "github.com/ultisuite/ulti-backend/internal/api/drive"
|
||||||
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
"github.com/ultisuite/ulti-backend/internal/api/middleware"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -342,6 +343,9 @@ func writeAttachmentUploadError(w http.ResponseWriter, r *http.Request, err erro
|
|||||||
case errors.Is(err, limits.ErrTooManyAttachments):
|
case errors.Is(err, limits.ErrTooManyAttachments):
|
||||||
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "too many attachments", nil)
|
apiresponse.WriteError(w, r, http.StatusBadRequest, apiresponse.CodeInvalidRequest, "too many attachments", nil)
|
||||||
return true
|
return true
|
||||||
|
case errors.Is(err, ErrMalware), errors.Is(err, filescan.ErrMalicious):
|
||||||
|
apiresponse.WriteError(w, r, http.StatusUnprocessableEntity, "mail.malware_detected", "malware detected in attachment", nil)
|
||||||
|
return true
|
||||||
default:
|
default:
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
"github.com/ultisuite/ulti-backend/internal/api/query"
|
"github.com/ultisuite/ulti-backend/internal/api/query"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/imap"
|
"github.com/ultisuite/ulti-backend/internal/mail/imap"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
|
"github.com/ultisuite/ulti-backend/internal/mail/sanitize"
|
||||||
@ -28,6 +29,7 @@ var (
|
|||||||
ErrInvalidAccountCredentials = errors.New("account credentials invalid")
|
ErrInvalidAccountCredentials = errors.New("account credentials invalid")
|
||||||
ErrInvalidFolderScope = errors.New("invalid folder scope")
|
ErrInvalidFolderScope = errors.New("invalid folder scope")
|
||||||
ErrFolderHasChildren = errors.New("folder has children")
|
ErrFolderHasChildren = errors.New("folder has children")
|
||||||
|
ErrMalware = filescan.ErrMalicious
|
||||||
)
|
)
|
||||||
|
|
||||||
type Service struct {
|
type Service struct {
|
||||||
@ -37,6 +39,7 @@ type Service struct {
|
|||||||
storage *storage.Client
|
storage *storage.Client
|
||||||
attachmentsBucket string
|
attachmentsBucket string
|
||||||
driveUploader DriveUploader
|
driveUploader DriveUploader
|
||||||
|
scanner *filescan.Scanner
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -116,6 +116,9 @@ type Config struct {
|
|||||||
TypesenseKey string
|
TypesenseKey string
|
||||||
TypesenseCollection string
|
TypesenseCollection string
|
||||||
|
|
||||||
|
// VirusTotal (optional env fallback for org file_policies.virustotal_api_key)
|
||||||
|
VirusTotalAPIKey string
|
||||||
|
|
||||||
// Observability
|
// Observability
|
||||||
HealthNextcloudURL string
|
HealthNextcloudURL string
|
||||||
HealthImmichURL string
|
HealthImmichURL string
|
||||||
@ -221,6 +224,8 @@ func Load() (*Config, error) {
|
|||||||
TypesenseKey: secrets.Env("TYPESENSE_API_KEY"),
|
TypesenseKey: secrets.Env("TYPESENSE_API_KEY"),
|
||||||
TypesenseCollection: envOrDefault("TYPESENSE_COLLECTION", "ulti"),
|
TypesenseCollection: envOrDefault("TYPESENSE_COLLECTION", "ulti"),
|
||||||
|
|
||||||
|
VirusTotalAPIKey: secrets.Env("VIRUSTOTAL_API_KEY"),
|
||||||
|
|
||||||
HealthNextcloudURL: envOrDefault("HEALTH_NEXTCLOUD_URL", joinURL(envOrDefault("NEXTCLOUD_URL", "http://nextcloud:80"), "/status.php")),
|
HealthNextcloudURL: envOrDefault("HEALTH_NEXTCLOUD_URL", joinURL(envOrDefault("NEXTCLOUD_URL", "http://nextcloud:80"), "/status.php")),
|
||||||
HealthImmichURL: envOrDefault("HEALTH_IMMICH_URL", joinURL(envOrDefault("IMMICH_API_URL", "http://immich-server:2283/api"), "/server-info/ping")),
|
HealthImmichURL: envOrDefault("HEALTH_IMMICH_URL", joinURL(envOrDefault("IMMICH_API_URL", "http://immich-server:2283/api"), "/server-info/ping")),
|
||||||
HealthJitsiURL: envOrDefault("HEALTH_JITSI_URL", defaultHealthJitsiURL(envOrDefault("JITSI_PUBLIC_URL", "https://localhost/meet"))),
|
HealthJitsiURL: envOrDefault("HEALTH_JITSI_URL", defaultHealthJitsiURL(envOrDefault("JITSI_PUBLIC_URL", "https://localhost/meet"))),
|
||||||
|
|||||||
144
internal/filescan/scanner.go
Normal file
144
internal/filescan/scanner.go
Normal file
@ -0,0 +1,144 @@
|
|||||||
|
package filescan
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/virustotal"
|
||||||
|
)
|
||||||
|
|
||||||
|
const memoryBufferLimit = 32 * 1024 * 1024
|
||||||
|
|
||||||
|
// Result holds scan outcome for persistence.
|
||||||
|
type Result struct {
|
||||||
|
Status string // clean | skipped
|
||||||
|
}
|
||||||
|
|
||||||
|
// PolicyLoader supplies org file policies at runtime.
|
||||||
|
type PolicyLoader interface {
|
||||||
|
FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error)
|
||||||
|
ScanEnabled(ctx context.Context) (bool, string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scanner coordinates org policy and VirusTotal scanning.
|
||||||
|
type Scanner struct {
|
||||||
|
policies PolicyLoader
|
||||||
|
logger *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScanner(policies PolicyLoader, logger *slog.Logger) *Scanner {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
return &Scanner{policies: policies, logger: logger}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanReader reads all bytes from r (up to maxBytes), optionally scans, returns data for storage.
|
||||||
|
func (s *Scanner) ScanReader(ctx context.Context, filename string, r io.Reader, size int64) ([]byte, Result, error) {
|
||||||
|
fp, err := s.policies.FilePolicies(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, Result{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
maxBytes := fp.MaxUploadBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 512 * 1024 * 1024
|
||||||
|
}
|
||||||
|
if maxBytes > virustotalMaxBytes() {
|
||||||
|
maxBytes = virustotalMaxBytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := readLimited(r, size, maxBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, Result{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := fp.VirusScanEnabled && fp.VirusTotalAPIKey != ""
|
||||||
|
if !enabled {
|
||||||
|
return data, Result{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
vt := virustotal.NewScanner(fp.VirusTotalAPIKey, s.logger)
|
||||||
|
scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, virustotal.ErrMalicious) {
|
||||||
|
return nil, Result{Status: "malicious"}, virustotal.ErrMalicious
|
||||||
|
}
|
||||||
|
return nil, Result{}, err
|
||||||
|
}
|
||||||
|
return data, Result{Status: scanResult.Status}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanBytes scans pre-loaded bytes when policy is already known.
|
||||||
|
func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte) (Result, error) {
|
||||||
|
enabled, apiKey, err := s.policies.ScanEnabled(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return Result{}, err
|
||||||
|
}
|
||||||
|
if !enabled {
|
||||||
|
return Result{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
vt := virustotal.NewScanner(apiKey, s.logger)
|
||||||
|
scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data))
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, virustotal.ErrMalicious) {
|
||||||
|
return Result{Status: "malicious"}, virustotal.ErrMalicious
|
||||||
|
}
|
||||||
|
return Result{}, err
|
||||||
|
}
|
||||||
|
return Result{Status: scanResult.Status}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func readLimited(r io.Reader, size int64, maxBytes int64) ([]byte, error) {
|
||||||
|
if size >= 0 && size <= memoryBufferLimit {
|
||||||
|
limited := io.LimitReader(r, maxBytes+1)
|
||||||
|
data, err := io.ReadAll(limited)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if int64(len(data)) > maxBytes {
|
||||||
|
return nil, errors.New("file exceeds max upload size")
|
||||||
|
}
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
tmp, err := os.CreateTemp("", "ulti-scan-*")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tmpPath := tmp.Name()
|
||||||
|
defer os.Remove(tmpPath)
|
||||||
|
|
||||||
|
written, err := io.Copy(tmp, io.LimitReader(r, maxBytes+1))
|
||||||
|
if err != nil {
|
||||||
|
tmp.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := tmp.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if written > maxBytes {
|
||||||
|
return nil, errors.New("file exceeds max upload size")
|
||||||
|
}
|
||||||
|
|
||||||
|
return os.ReadFile(filepath.Clean(tmpPath))
|
||||||
|
}
|
||||||
|
|
||||||
|
func virustotalMaxBytes() int64 {
|
||||||
|
return 650 * 1024 * 1024
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrMalicious re-exports virustotal.ErrMalicious for handlers.
|
||||||
|
var ErrMalicious = virustotal.ErrMalicious
|
||||||
|
|
||||||
|
// NopReader helper for tests.
|
||||||
|
func NopReader(data []byte) io.Reader {
|
||||||
|
return bytes.NewReader(data)
|
||||||
|
}
|
||||||
40
internal/filescan/scanner_test.go
Normal file
40
internal/filescan/scanner_test.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package filescan
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type stubPolicyLoader struct {
|
||||||
|
fp orgpolicy.FilePolicies
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s stubPolicyLoader) FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error) {
|
||||||
|
return s.fp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s stubPolicyLoader) ScanEnabled(ctx context.Context) (bool, string, error) {
|
||||||
|
if s.fp.VirusScanEnabled && s.fp.VirusTotalAPIKey != "" {
|
||||||
|
return true, s.fp.VirusTotalAPIKey, nil
|
||||||
|
}
|
||||||
|
return false, "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScannerDisabledReturnsSkipped(t *testing.T) {
|
||||||
|
scanner := &Scanner{
|
||||||
|
policies: stubPolicyLoader{fp: orgpolicy.FilePolicies{VirusScanEnabled: false}},
|
||||||
|
}
|
||||||
|
data, result, err := scanner.ScanReader(context.Background(), "test.txt", bytes.NewReader([]byte("hello")), 5)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ScanReader: %v", err)
|
||||||
|
}
|
||||||
|
if result.Status != "skipped" {
|
||||||
|
t.Fatalf("status = %q, want skipped", result.Status)
|
||||||
|
}
|
||||||
|
if string(data) != "hello" {
|
||||||
|
t.Fatalf("data = %q", data)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -57,3 +57,60 @@ func TestAdminOrgSettings(t *testing.T) {
|
|||||||
t.Fatalf("default_mail_gib = %v, want 10", updatedStorage["default_mail_gib"])
|
t.Fatalf("default_mail_gib = %v, want 10", updatedStorage["default_mail_gib"])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAdminOrgSettingsVirusTotalSecret(t *testing.T) {
|
||||||
|
h := integrationtest.RequireHarness(t)
|
||||||
|
adminClient, _ := integrationtest.RequireAdminClient(t, h)
|
||||||
|
|
||||||
|
putResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
|
||||||
|
"policy": map[string]any{
|
||||||
|
"file_policies": map[string]any{
|
||||||
|
"virus_scan_enabled": true,
|
||||||
|
"virustotal_api_key": "vt-test-secret-key",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
integrationtest.FailIf(err, t, "put org settings with virustotal key")
|
||||||
|
integrationtest.FailUnlessStatus(t, putResp, 200)
|
||||||
|
|
||||||
|
var afterPut map[string]any
|
||||||
|
integrationtest.DecodeJSON(t, putResp, &afterPut)
|
||||||
|
secrets, ok := afterPut["secrets"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("missing secrets: %#v", afterPut)
|
||||||
|
}
|
||||||
|
vtSecret, ok := secrets["virustotal_api_key"].(map[string]any)
|
||||||
|
if !ok || vtSecret["configured"] != true {
|
||||||
|
t.Fatalf("virustotal_api_key not configured: %#v", secrets)
|
||||||
|
}
|
||||||
|
filePolicies, ok := afterPut["policy"].(map[string]any)["file_policies"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("missing file_policies")
|
||||||
|
}
|
||||||
|
if filePolicies["virustotal_api_key"] != "" {
|
||||||
|
t.Fatalf("virustotal_api_key should be masked on GET, got %q", filePolicies["virustotal_api_key"])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty patch must preserve stored secret.
|
||||||
|
preserveResp, err := adminClient.Put("/api/v1/admin/org/settings", map[string]any{
|
||||||
|
"policy": map[string]any{
|
||||||
|
"file_policies": map[string]any{
|
||||||
|
"virus_scan_enabled": false,
|
||||||
|
"virustotal_api_key": "",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
integrationtest.FailIf(err, t, "preserve virustotal key")
|
||||||
|
integrationtest.FailUnlessStatus(t, preserveResp, 200)
|
||||||
|
|
||||||
|
var preserved map[string]any
|
||||||
|
integrationtest.DecodeJSON(t, preserveResp, &preserved)
|
||||||
|
preservedSecrets, ok := preserved["secrets"].(map[string]any)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("missing secrets after preserve")
|
||||||
|
}
|
||||||
|
vtPreserved, ok := preservedSecrets["virustotal_api_key"].(map[string]any)
|
||||||
|
if !ok || vtPreserved["configured"] != true {
|
||||||
|
t.Fatalf("virustotal key not preserved: %#v", preservedSecrets)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@ -16,6 +16,7 @@ import (
|
|||||||
"github.com/jackc/pgx/v5/pgxpool"
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
"github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/connect"
|
"github.com/ultisuite/ulti-backend/internal/mail/connect"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
mailoauth "github.com/ultisuite/ulti-backend/internal/mail/oauth"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
"github.com/ultisuite/ulti-backend/internal/mail/limits"
|
||||||
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
"github.com/ultisuite/ulti-backend/internal/mail/rules"
|
||||||
@ -38,6 +39,7 @@ type SyncDeps struct {
|
|||||||
Rules *rules.Engine
|
Rules *rules.Engine
|
||||||
Automation MailAutomation
|
Automation MailAutomation
|
||||||
Hub *realtime.Hub
|
Hub *realtime.Hub
|
||||||
|
FileScanner *filescan.Scanner
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncWorker struct {
|
type SyncWorker struct {
|
||||||
@ -48,6 +50,7 @@ type SyncWorker struct {
|
|||||||
oauth *mailoauth.Service
|
oauth *mailoauth.Service
|
||||||
storage *storage.Client
|
storage *storage.Client
|
||||||
attachBucket string
|
attachBucket string
|
||||||
|
scanner *filescan.Scanner
|
||||||
pipeline *syncPipeline
|
pipeline *syncPipeline
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -60,6 +63,7 @@ func NewSyncWorker(db *pgxpool.Pool, interval time.Duration, credManager *creden
|
|||||||
oauth: oauthSvc,
|
oauth: oauthSvc,
|
||||||
storage: deps.Storage,
|
storage: deps.Storage,
|
||||||
attachBucket: deps.AttachBucket,
|
attachBucket: deps.AttachBucket,
|
||||||
|
scanner: deps.FileScanner,
|
||||||
pipeline: newSyncPipeline(db, deps.Rules, deps.Automation, deps.Hub),
|
pipeline: newSyncPipeline(db, deps.Rules, deps.Automation, deps.Hub),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -575,14 +579,30 @@ func (w *SyncWorker) storeAttachments(ctx context.Context, userID, messageID str
|
|||||||
if messageExisted && attachmentPartExists(ctx, w.db, messageID, part) {
|
if messageExisted && attachmentPartExists(ctx, w.db, messageID, part) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
scanStatus := "skipped"
|
||||||
|
partData := part.Data
|
||||||
|
if w.scanner != nil {
|
||||||
|
result, err := w.scanner.ScanBytes(ctx, part.Filename, part.Data)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, filescan.ErrMalicious) {
|
||||||
|
w.logger.Warn("imap attachment skipped: malware detected",
|
||||||
|
"message_id", messageID, "filename", part.Filename)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
scanStatus = result.Status
|
||||||
|
}
|
||||||
|
|
||||||
objectKey := storage.MessageObjectKey(userID, messageID, part.Filename)
|
objectKey := storage.MessageObjectKey(userID, messageID, part.Filename)
|
||||||
if err := w.storage.Put(ctx, objectKey, bytes.NewReader(part.Data), int64(len(part.Data)), part.ContentType); err != nil {
|
if err := w.storage.Put(ctx, objectKey, bytes.NewReader(partData), int64(len(partData)), part.ContentType); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err := w.db.Exec(ctx, `
|
_, err := w.db.Exec(ctx, `
|
||||||
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline)
|
INSERT INTO attachments (message_id, filename, content_type, size, s3_bucket, s3_key, content_id, is_inline, virus_scan_status)
|
||||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||||
`, messageID, part.Filename, part.ContentType, len(part.Data), bucket, objectKey, part.ContentID, part.IsInline)
|
`, messageID, part.Filename, part.ContentType, len(partData), bucket, objectKey, part.ContentID, part.IsInline, scanStatus)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = w.storage.Delete(ctx, objectKey)
|
_ = w.storage.Delete(ctx, objectKey)
|
||||||
return err
|
return err
|
||||||
|
|||||||
135
internal/orgpolicy/loader.go
Normal file
135
internal/orgpolicy/loader.go
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
package orgpolicy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/jackc/pgx/v5"
|
||||||
|
"github.com/jackc/pgx/v5/pgxpool"
|
||||||
|
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
const orgSettingsSingletonID = 1
|
||||||
|
|
||||||
|
const defaultMaxUploadMiB = 512
|
||||||
|
|
||||||
|
// FilePolicies holds runtime file upload policy from org settings.
|
||||||
|
type FilePolicies struct {
|
||||||
|
VirusScanEnabled bool
|
||||||
|
VirusTotalAPIKey string
|
||||||
|
MaxUploadBytes int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type Loader struct {
|
||||||
|
db *pgxpool.Pool
|
||||||
|
cfg *config.Config
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
cached FilePolicies
|
||||||
|
cachedAt time.Time
|
||||||
|
ttl time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLoader(db *pgxpool.Pool, cfg *config.Config) *Loader {
|
||||||
|
return &Loader{
|
||||||
|
db: db,
|
||||||
|
cfg: cfg,
|
||||||
|
ttl: 60 * time.Second,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Loader) FilePolicies(ctx context.Context) (FilePolicies, error) {
|
||||||
|
l.mu.Lock()
|
||||||
|
if !l.cachedAt.IsZero() && time.Since(l.cachedAt) < l.ttl {
|
||||||
|
out := l.cached
|
||||||
|
l.mu.Unlock()
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
fp, err := l.loadFilePolicies(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return FilePolicies{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.mu.Lock()
|
||||||
|
l.cached = fp
|
||||||
|
l.cachedAt = time.Now()
|
||||||
|
l.mu.Unlock()
|
||||||
|
return fp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Loader) loadFilePolicies(ctx context.Context) (FilePolicies, error) {
|
||||||
|
var raw []byte
|
||||||
|
err := l.db.QueryRow(ctx, `
|
||||||
|
SELECT settings FROM org_settings WHERE id = $1
|
||||||
|
`, orgSettingsSingletonID).Scan(&raw)
|
||||||
|
if err != nil && err != pgx.ErrNoRows {
|
||||||
|
return FilePolicies{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
stored := map[string]any{}
|
||||||
|
if len(raw) > 0 {
|
||||||
|
if err := json.Unmarshal(raw, &stored); err != nil {
|
||||||
|
return FilePolicies{}, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filePolicies, _ := stored["file_policies"].(map[string]any)
|
||||||
|
enabled := boolValue(filePolicies["virus_scan_enabled"])
|
||||||
|
apiKey := stringValue(filePolicies["virustotal_api_key"])
|
||||||
|
if strings.TrimSpace(apiKey) == "" && l.cfg != nil {
|
||||||
|
apiKey = strings.TrimSpace(l.cfg.VirusTotalAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
maxMiB := int64(defaultMaxUploadMiB)
|
||||||
|
switch v := filePolicies["max_upload_mib"].(type) {
|
||||||
|
case float64:
|
||||||
|
if v > 0 {
|
||||||
|
maxMiB = int64(v)
|
||||||
|
}
|
||||||
|
case int:
|
||||||
|
if v > 0 {
|
||||||
|
maxMiB = int64(v)
|
||||||
|
}
|
||||||
|
case int64:
|
||||||
|
if v > 0 {
|
||||||
|
maxMiB = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return FilePolicies{
|
||||||
|
VirusScanEnabled: enabled,
|
||||||
|
VirusTotalAPIKey: apiKey,
|
||||||
|
MaxUploadBytes: maxMiB * 1024 * 1024,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Loader) ScanEnabled(ctx context.Context) (bool, string, error) {
|
||||||
|
fp, err := l.FilePolicies(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return false, "", err
|
||||||
|
}
|
||||||
|
if !fp.VirusScanEnabled || strings.TrimSpace(fp.VirusTotalAPIKey) == "" {
|
||||||
|
return false, "", nil
|
||||||
|
}
|
||||||
|
return true, fp.VirusTotalAPIKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolValue(v any) bool {
|
||||||
|
switch t := v.(type) {
|
||||||
|
case bool:
|
||||||
|
return t
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringValue(v any) string {
|
||||||
|
s, _ := v.(string)
|
||||||
|
return s
|
||||||
|
}
|
||||||
@ -32,6 +32,7 @@ import (
|
|||||||
"github.com/ultisuite/ulti-backend/internal/authentik"
|
"github.com/ultisuite/ulti-backend/internal/authentik"
|
||||||
"github.com/ultisuite/ulti-backend/internal/auth"
|
"github.com/ultisuite/ulti-backend/internal/auth"
|
||||||
"github.com/ultisuite/ulti-backend/internal/config"
|
"github.com/ultisuite/ulti-backend/internal/config"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/filescan"
|
||||||
"github.com/ultisuite/ulti-backend/internal/httpcors"
|
"github.com/ultisuite/ulti-backend/internal/httpcors"
|
||||||
mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
mailcredentials "github.com/ultisuite/ulti-backend/internal/mail/credentials"
|
||||||
imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap"
|
imapsync "github.com/ultisuite/ulti-backend/internal/mail/imap"
|
||||||
@ -43,6 +44,7 @@ import (
|
|||||||
"github.com/ultisuite/ulti-backend/internal/meet"
|
"github.com/ultisuite/ulti-backend/internal/meet"
|
||||||
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
"github.com/ultisuite/ulti-backend/internal/nextcloud"
|
||||||
"github.com/ultisuite/ulti-backend/internal/observability"
|
"github.com/ultisuite/ulti-backend/internal/observability"
|
||||||
|
"github.com/ultisuite/ulti-backend/internal/orgpolicy"
|
||||||
"github.com/ultisuite/ulti-backend/internal/photos"
|
"github.com/ultisuite/ulti-backend/internal/photos"
|
||||||
"github.com/ultisuite/ulti-backend/internal/realtime"
|
"github.com/ultisuite/ulti-backend/internal/realtime"
|
||||||
"github.com/ultisuite/ulti-backend/internal/search"
|
"github.com/ultisuite/ulti-backend/internal/search"
|
||||||
@ -203,6 +205,9 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
|||||||
RedirectURL: oauthRedirect,
|
RedirectURL: oauthRedirect,
|
||||||
}, rdb)
|
}, rdb)
|
||||||
|
|
||||||
|
orgPolicyLoader := orgpolicy.NewLoader(pool, cfg)
|
||||||
|
fileScanner := filescan.NewScanner(orgPolicyLoader, slog.Default())
|
||||||
|
|
||||||
var syncWorker *imapsync.SyncWorker
|
var syncWorker *imapsync.SyncWorker
|
||||||
if !opts.WithoutWorkers {
|
if !opts.WithoutWorkers {
|
||||||
syncWorker = imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{
|
syncWorker = imapsync.NewSyncWorker(pool, cfg.MailSyncInterval, credentialManager, mailOAuthSvc, imapsync.SyncDeps{
|
||||||
@ -210,6 +215,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
|||||||
AttachBucket: cfg.MailAttachmentsBucket,
|
AttachBucket: cfg.MailAttachmentsBucket,
|
||||||
Automation: autoDispatcher,
|
Automation: autoDispatcher,
|
||||||
Hub: hub,
|
Hub: hub,
|
||||||
|
FileScanner: fileScanner,
|
||||||
})
|
})
|
||||||
go syncWorker.Start(workerCtx)
|
go syncWorker.Start(workerCtx)
|
||||||
}
|
}
|
||||||
@ -229,6 +235,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
|||||||
|
|
||||||
sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst)
|
sendRateLimiter := sendguard.NewRateLimiter(cfg.MailSendRatePerMinute, cfg.MailSendBurst)
|
||||||
mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender)
|
mailHandler := mailapi.NewHandler(pool, auditLogger, credentialManager, attachmentStorage, cfg.MailAttachmentsBucket, sendRateLimiter, mailOAuthSvc, cfg.MailAppURL, sender)
|
||||||
|
mailHandler.SetFileScanner(fileScanner)
|
||||||
if syncWorker != nil {
|
if syncWorker != nil {
|
||||||
mailHandler.SetAccountSync(syncWorker)
|
mailHandler.SetAccountSync(syncWorker)
|
||||||
}
|
}
|
||||||
@ -262,6 +269,7 @@ func New(ctx context.Context, cfg *config.Config, opts Options) (*App, error) {
|
|||||||
if ncClient != nil {
|
if ncClient != nil {
|
||||||
driveSvc = drive.NewService(ncClient, hub, pool)
|
driveSvc = drive.NewService(ncClient, hub, pool)
|
||||||
driveSvc.SetAutomation(autoDispatcher)
|
driveSvc.SetAutomation(autoDispatcher)
|
||||||
|
driveSvc.SetFileScanner(fileScanner)
|
||||||
driveHandler = drive.NewHandlerWithService(driveSvc)
|
driveHandler = drive.NewHandlerWithService(driveSvc)
|
||||||
mailHandler.SetDriveUploader(&drivebridge.Bridge{Svc: driveSvc})
|
mailHandler.SetDriveUploader(&drivebridge.Bridge{Svc: driveSvc})
|
||||||
contactsHandler = contacts.NewHandler(ncClient, pool)
|
contactsHandler = contacts.NewHandler(ncClient, pool)
|
||||||
|
|||||||
267
internal/virustotal/client.go
Normal file
267
internal/virustotal/client.go
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
package virustotal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"mime/multipart"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultBaseURL = "https://www.virustotal.com/api/v3"
|
||||||
|
directUploadLimit = 32 * 1024 * 1024
|
||||||
|
maxUploadLimit = 650 * 1024 * 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
type Client struct {
|
||||||
|
apiKey string
|
||||||
|
baseURL string
|
||||||
|
httpClient *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewClient(apiKey string) *Client {
|
||||||
|
return &Client{
|
||||||
|
apiKey: apiKey,
|
||||||
|
baseURL: defaultBaseURL,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 120 * time.Second,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) apiURL(path string) string {
|
||||||
|
return strings.TrimRight(c.baseURL, "/") + path
|
||||||
|
}
|
||||||
|
|
||||||
|
type analysisStats struct {
|
||||||
|
Malicious int `json:"malicious"`
|
||||||
|
Suspicious int `json:"suspicious"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type fileReport struct {
|
||||||
|
Data struct {
|
||||||
|
Attributes struct {
|
||||||
|
LastAnalysisStats analysisStats `json:"last_analysis_stats"`
|
||||||
|
} `json:"attributes"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type analysisResponse struct {
|
||||||
|
Data struct {
|
||||||
|
Attributes struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Stats analysisStats `json:"stats"`
|
||||||
|
} `json:"attributes"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type uploadResponse struct {
|
||||||
|
Data struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
} `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type uploadURLResponse struct {
|
||||||
|
Data string `json:"data"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) headers() http.Header {
|
||||||
|
h := http.Header{}
|
||||||
|
h.Set("Accept", "application/json")
|
||||||
|
h.Set("x-apikey", c.apiKey)
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) lookupFile(ctx context.Context, sha256 string) (analysisStats, bool, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/"+sha256), nil)
|
||||||
|
if err != nil {
|
||||||
|
return analysisStats{}, false, err
|
||||||
|
}
|
||||||
|
req.Header = c.headers()
|
||||||
|
|
||||||
|
res, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return analysisStats{}, false, err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode == http.StatusNotFound {
|
||||||
|
return analysisStats{}, false, nil
|
||||||
|
}
|
||||||
|
if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests {
|
||||||
|
return analysisStats{}, false, fmt.Errorf("virustotal lookup unavailable: %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
return analysisStats{}, false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var report fileReport
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&report); err != nil {
|
||||||
|
return analysisStats{}, false, err
|
||||||
|
}
|
||||||
|
return report.Data.Attributes.LastAnalysisStats, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) uploadFile(ctx context.Context, data []byte, filename string) (string, error) {
|
||||||
|
if len(data) <= directUploadLimit {
|
||||||
|
return c.uploadDirect(ctx, data, filename)
|
||||||
|
}
|
||||||
|
return c.uploadLarge(ctx, data, filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) uploadDirect(ctx context.Context, data []byte, filename string) (string, error) {
|
||||||
|
var body bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&body)
|
||||||
|
part, err := w.CreateFormFile("file", filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if _, err := part.Write(data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.apiURL("/files"), &body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header = c.headers()
|
||||||
|
req.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
|
||||||
|
res, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
b, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||||
|
return "", fmt.Errorf("virustotal upload failed: %d %s", res.StatusCode, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
var out uploadResponse
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&out); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return out.Data.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) uploadLarge(ctx context.Context, data []byte, filename string) (string, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/files/upload_url"), nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
req.Header = c.headers()
|
||||||
|
|
||||||
|
res, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
if res.StatusCode >= 400 {
|
||||||
|
b, _ := io.ReadAll(io.LimitReader(res.Body, 4096))
|
||||||
|
return "", fmt.Errorf("virustotal upload_url failed: %d %s", res.StatusCode, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
var urlResp uploadURLResponse
|
||||||
|
if err := json.NewDecoder(res.Body).Decode(&urlResp); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if urlResp.Data == "" {
|
||||||
|
return "", fmt.Errorf("virustotal upload_url empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
var body bytes.Buffer
|
||||||
|
w := multipart.NewWriter(&body)
|
||||||
|
part, err := w.CreateFormFile("file", filename)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if _, err := part.Write(data); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
if err := w.Close(); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
upReq, err := http.NewRequestWithContext(ctx, http.MethodPost, urlResp.Data, &body)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
upReq.Header = c.headers()
|
||||||
|
upReq.Header.Set("Content-Type", w.FormDataContentType())
|
||||||
|
|
||||||
|
upRes, err := c.httpClient.Do(upReq)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
defer upRes.Body.Close()
|
||||||
|
|
||||||
|
if upRes.StatusCode >= 400 {
|
||||||
|
b, _ := io.ReadAll(io.LimitReader(upRes.Body, 4096))
|
||||||
|
return "", fmt.Errorf("virustotal large upload failed: %d %s", upRes.StatusCode, string(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
var out uploadResponse
|
||||||
|
if err := json.NewDecoder(upRes.Body).Decode(&out); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return out.Data.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) pollAnalysis(ctx context.Context, analysisID string, timeout time.Duration) (analysisStats, error) {
|
||||||
|
deadline := time.Now().Add(timeout)
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return analysisStats{}, ctx.Err()
|
||||||
|
}
|
||||||
|
if time.Now().After(deadline) {
|
||||||
|
return analysisStats{}, fmt.Errorf("virustotal analysis timeout")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.apiURL("/analyses/"+analysisID), nil)
|
||||||
|
if err != nil {
|
||||||
|
return analysisStats{}, err
|
||||||
|
}
|
||||||
|
req.Header = c.headers()
|
||||||
|
|
||||||
|
res, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return analysisStats{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var out analysisResponse
|
||||||
|
decodeErr := json.NewDecoder(res.Body).Decode(&out)
|
||||||
|
res.Body.Close()
|
||||||
|
if decodeErr != nil {
|
||||||
|
return analysisStats{}, decodeErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.StatusCode >= 500 || res.StatusCode == http.StatusTooManyRequests {
|
||||||
|
return analysisStats{}, fmt.Errorf("virustotal poll unavailable: %d", res.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
status := out.Data.Attributes.Status
|
||||||
|
if status == "completed" {
|
||||||
|
return out.Data.Attributes.Stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return analysisStats{}, ctx.Err()
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func statsMalicious(stats analysisStats) bool {
|
||||||
|
return stats.Malicious > 0
|
||||||
|
}
|
||||||
7
internal/virustotal/errors.go
Normal file
7
internal/virustotal/errors.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package virustotal
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMalicious = errors.New("malware detected")
|
||||||
|
)
|
||||||
106
internal/virustotal/scan.go
Normal file
106
internal/virustotal/scan.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package virustotal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const defaultScanTimeout = 60 * time.Second
|
||||||
|
|
||||||
|
// ScanResult is the outcome of a VirusTotal scan attempt.
|
||||||
|
type ScanResult struct {
|
||||||
|
Status string // clean | skipped
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scanner wraps Client with fail-open behavior on API errors.
|
||||||
|
type Scanner struct {
|
||||||
|
client *Client
|
||||||
|
logger *slog.Logger
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewScanner(apiKey string, logger *slog.Logger) *Scanner {
|
||||||
|
if logger == nil {
|
||||||
|
logger = slog.Default()
|
||||||
|
}
|
||||||
|
return &Scanner{
|
||||||
|
client: NewClient(apiKey),
|
||||||
|
logger: logger,
|
||||||
|
timeout: defaultScanTimeout,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ScanBytes scans file content. Returns ErrMalicious if detected.
|
||||||
|
// On VT unavailability, returns skipped (fail-open).
|
||||||
|
func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte, sha256Hex string) (ScanResult, error) {
|
||||||
|
if len(data) == 0 {
|
||||||
|
return ScanResult{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
if len(data) > maxUploadLimit {
|
||||||
|
s.logger.Warn("virustotal scan skipped: file too large for VT", "filename", filename, "size", len(data))
|
||||||
|
return ScanResult{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if sha256Hex == "" {
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
sha256Hex = hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
stats, found, err := s.client.lookupFile(ctx, sha256Hex)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("virustotal lookup failed, skipping scan", "filename", filename, "error", err)
|
||||||
|
return ScanResult{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
if statsMalicious(stats) {
|
||||||
|
return ScanResult{}, ErrMalicious
|
||||||
|
}
|
||||||
|
return ScanResult{Status: "clean"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
analysisID, err := s.client.uploadFile(ctx, data, safeFilename(filename))
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("virustotal upload failed, skipping scan", "filename", filename, "error", err)
|
||||||
|
return ScanResult{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
scanCtx, cancel := context.WithTimeout(ctx, s.timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
stats, err = s.client.pollAnalysis(scanCtx, analysisID, s.timeout)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("virustotal analysis failed, skipping scan", "filename", filename, "error", err)
|
||||||
|
return ScanResult{Status: "skipped"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if statsMalicious(stats) {
|
||||||
|
return ScanResult{}, ErrMalicious
|
||||||
|
}
|
||||||
|
return ScanResult{Status: "clean"}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func safeFilename(name string) string {
|
||||||
|
if name == "" {
|
||||||
|
return "upload"
|
||||||
|
}
|
||||||
|
base := name
|
||||||
|
if len(base) > 200 {
|
||||||
|
base = base[:200]
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// SHA256Hex returns hex-encoded SHA-256 of data.
|
||||||
|
func SHA256Hex(data []byte) string {
|
||||||
|
sum := sha256.Sum256(data)
|
||||||
|
return hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// FormatAnalysisID is a helper for logging.
|
||||||
|
func FormatAnalysisID(id string) string {
|
||||||
|
return fmt.Sprintf("vt:%s", id)
|
||||||
|
}
|
||||||
109
internal/virustotal/scan_test.go
Normal file
109
internal/virustotal/scan_test.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
package virustotal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestScannerLookupMalicious(t *testing.T) {
|
||||||
|
sha := sha256.Sum256([]byte("evil"))
|
||||||
|
shaHex := hex.EncodeToString(sha[:])
|
||||||
|
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/files/"+shaHex) {
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": map[string]any{
|
||||||
|
"attributes": map[string]any{
|
||||||
|
"last_analysis_stats": map[string]any{
|
||||||
|
"malicious": 2,
|
||||||
|
"suspicious": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc := NewScanner("test-key", nil)
|
||||||
|
sc.client.baseURL = srv.URL + "/api/v3"
|
||||||
|
|
||||||
|
_, err := sc.ScanBytes(context.Background(), "evil.bin", []byte("evil"), shaHex)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected ErrMalicious")
|
||||||
|
}
|
||||||
|
if err != ErrMalicious {
|
||||||
|
t.Fatalf("err = %v, want ErrMalicious", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScannerUploadAndPollClean(t *testing.T) {
|
||||||
|
pollCount := 0
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch {
|
||||||
|
case r.Method == http.MethodGet && strings.HasSuffix(r.URL.Path, "/files/"+SHA256Hex([]byte("clean"))):
|
||||||
|
http.NotFound(w, r)
|
||||||
|
case r.Method == http.MethodPost && strings.HasSuffix(r.URL.Path, "/files"):
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": map[string]any{"id": "analysis-1"},
|
||||||
|
})
|
||||||
|
case r.Method == http.MethodGet && strings.Contains(r.URL.Path, "/analyses/analysis-1"):
|
||||||
|
pollCount++
|
||||||
|
status := "queued"
|
||||||
|
if pollCount >= 2 {
|
||||||
|
status = "completed"
|
||||||
|
}
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"data": map[string]any{
|
||||||
|
"attributes": map[string]any{
|
||||||
|
"status": status,
|
||||||
|
"stats": map[string]any{
|
||||||
|
"malicious": 0,
|
||||||
|
"suspicious": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc := NewScanner("test-key", nil)
|
||||||
|
sc.client.baseURL = srv.URL + "/api/v3"
|
||||||
|
|
||||||
|
result, err := sc.ScanBytes(context.Background(), "clean.txt", []byte("clean"), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ScanBytes: %v", err)
|
||||||
|
}
|
||||||
|
if result.Status != "clean" {
|
||||||
|
t.Fatalf("status = %q, want clean", result.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestScannerFailOpenOnUnavailable(t *testing.T) {
|
||||||
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "down", http.StatusServiceUnavailable)
|
||||||
|
}))
|
||||||
|
defer srv.Close()
|
||||||
|
|
||||||
|
sc := NewScanner("test-key", nil)
|
||||||
|
sc.client.baseURL = srv.URL + "/api/v3"
|
||||||
|
|
||||||
|
result, err := sc.ScanBytes(context.Background(), "file.bin", []byte("payload"), "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected fail-open, got err %v", err)
|
||||||
|
}
|
||||||
|
if result.Status != "skipped" {
|
||||||
|
t.Fatalf("status = %q, want skipped", result.Status)
|
||||||
|
}
|
||||||
|
}
|
||||||
1
migrations/000034_attachment_virus_scan.down.sql
Normal file
1
migrations/000034_attachment_virus_scan.down.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
ALTER TABLE attachments DROP COLUMN IF EXISTS virus_scan_status;
|
||||||
3
migrations/000034_attachment_virus_scan.up.sql
Normal file
3
migrations/000034_attachment_virus_scan.up.sql
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
ALTER TABLE attachments
|
||||||
|
ADD COLUMN virus_scan_status TEXT NOT NULL DEFAULT 'skipped'
|
||||||
|
CHECK (virus_scan_status IN ('clean', 'skipped', 'malicious'));
|
||||||
Loading…
Reference in New Issue
Block a user