package middleware import ( "bytes" "encoding/json" "log/slog" "net/http" "net/http/httptest" "testing" "github.com/ultisuite/ulti-backend/internal/api/apiresponse" ) func withTestLogger(t *testing.T) *bytes.Buffer { t.Helper() var buf bytes.Buffer old := slog.Default() slog.SetDefault(slog.New(slog.NewJSONHandler(&buf, nil))) t.Cleanup(func() { slog.SetDefault(old) }) return &buf } func parseLogRecord(t *testing.T, buf *bytes.Buffer) map[string]any { t.Helper() var record map[string]any if err := json.Unmarshal(buf.Bytes(), &record); err != nil { t.Fatalf("unmarshal log record: %v", err) } return record } func TestLoggingIncludesRequestFields(t *testing.T) { buf := withTestLogger(t) handler := Logging(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusCreated) })) req := httptest.NewRequest(http.MethodPost, "/api/messages", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) record := parseLogRecord(t, buf) if record["msg"] != "request" { t.Fatalf("msg = %v, want request", record["msg"]) } if record["method"] != http.MethodPost { t.Fatalf("method = %v, want POST", record["method"]) } if record["path"] != "/api/messages" { t.Fatalf("path = %v, want /api/messages", record["path"]) } if record["status"] != float64(http.StatusCreated) { t.Fatalf("status = %v, want %d", record["status"], http.StatusCreated) } if _, ok := record["duration"]; !ok { t.Fatal("expected duration field in log record") } } func TestLoggingIncludesRequestIDFromContext(t *testing.T) { buf := withTestLogger(t) handler := Logging(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) req := httptest.NewRequest(http.MethodGet, "/", nil) req = req.WithContext(apiresponse.WithTraceID(req.Context(), "trace-abc")) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) record := parseLogRecord(t, buf) if record["request_id"] != "trace-abc" { t.Fatalf("request_id = %v, want trace-abc", record["request_id"]) } } func TestLoggingOmitsRequestIDWhenMissing(t *testing.T) { buf := withTestLogger(t) handler := Logging(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) record := parseLogRecord(t, buf) if _, ok := record["request_id"]; ok { t.Fatalf("request_id = %v, want field omitted", record["request_id"]) } } func TestLoggingWithTraceMiddlewareUsesClientTraceID(t *testing.T) { buf := withTestLogger(t) handler := TraceID(Logging(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))) req := httptest.NewRequest(http.MethodGet, "/healthz", nil) req.Header.Set(apiresponse.TraceIDHeader, "client-trace") rec := httptest.NewRecorder() handler.ServeHTTP(rec, req) record := parseLogRecord(t, buf) if record["request_id"] != "client-trace" { t.Fatalf("request_id = %v, want client-trace", record["request_id"]) } }