package filescan import ( "bytes" "context" "errors" "io" "log/slog" "os" "path/filepath" "github.com/ultisuite/ulti-backend/internal/orgpolicy" "github.com/ultisuite/ulti-backend/internal/virustotal" ) const memoryBufferLimit = 32 * 1024 * 1024 // Result holds scan outcome for persistence. type Result struct { Status string // clean | skipped } // PolicyLoader supplies org file policies at runtime. type PolicyLoader interface { FilePolicies(ctx context.Context) (orgpolicy.FilePolicies, error) ScanEnabled(ctx context.Context) (bool, string, error) } // Scanner coordinates org policy and VirusTotal scanning. type Scanner struct { policies PolicyLoader logger *slog.Logger } func NewScanner(policies PolicyLoader, logger *slog.Logger) *Scanner { if logger == nil { logger = slog.Default() } return &Scanner{policies: policies, logger: logger} } // ScanReader reads all bytes from r (up to maxBytes), optionally scans, returns data for storage. func (s *Scanner) ScanReader(ctx context.Context, filename string, r io.Reader, size int64) ([]byte, Result, error) { fp, err := s.policies.FilePolicies(ctx) if err != nil { return nil, Result{}, err } maxBytes := fp.MaxUploadBytes if maxBytes <= 0 { maxBytes = 512 * 1024 * 1024 } if maxBytes > virustotalMaxBytes() { maxBytes = virustotalMaxBytes() } data, err := readLimited(r, size, maxBytes) if err != nil { return nil, Result{}, err } enabled := fp.VirusScanEnabled && fp.VirusTotalAPIKey != "" if !enabled { return data, Result{Status: "skipped"}, nil } vt := virustotal.NewScanner(fp.VirusTotalAPIKey, s.logger) scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data)) if err != nil { if errors.Is(err, virustotal.ErrMalicious) { return nil, Result{Status: "malicious"}, virustotal.ErrMalicious } return nil, Result{}, err } return data, Result{Status: scanResult.Status}, nil } // ScanBytes scans pre-loaded bytes when policy is already known. func (s *Scanner) ScanBytes(ctx context.Context, filename string, data []byte) (Result, error) { enabled, apiKey, err := s.policies.ScanEnabled(ctx) if err != nil { return Result{}, err } if !enabled { return Result{Status: "skipped"}, nil } vt := virustotal.NewScanner(apiKey, s.logger) scanResult, err := vt.ScanBytes(ctx, filename, data, virustotal.SHA256Hex(data)) if err != nil { if errors.Is(err, virustotal.ErrMalicious) { return Result{Status: "malicious"}, virustotal.ErrMalicious } return Result{}, err } return Result{Status: scanResult.Status}, nil } func readLimited(r io.Reader, size int64, maxBytes int64) ([]byte, error) { if size >= 0 && size <= memoryBufferLimit { limited := io.LimitReader(r, maxBytes+1) data, err := io.ReadAll(limited) if err != nil { return nil, err } if int64(len(data)) > maxBytes { return nil, errors.New("file exceeds max upload size") } return data, nil } tmp, err := os.CreateTemp("", "ulti-scan-*") if err != nil { return nil, err } tmpPath := tmp.Name() defer os.Remove(tmpPath) written, err := io.Copy(tmp, io.LimitReader(r, maxBytes+1)) if err != nil { tmp.Close() return nil, err } if err := tmp.Close(); err != nil { return nil, err } if written > maxBytes { return nil, errors.New("file exceeds max upload size") } return os.ReadFile(filepath.Clean(tmpPath)) } func virustotalMaxBytes() int64 { return 650 * 1024 * 1024 } // ErrMalicious re-exports virustotal.ErrMalicious for handlers. var ErrMalicious = virustotal.ErrMalicious // NopReader helper for tests. func NopReader(data []byte) io.Reader { return bytes.NewReader(data) }