package middleware

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
)

// serveThroughTraceID runs req through the TraceID middleware, wrapping a
// handler that records the trace ID found in its request context and then
// writes a response. It returns that context trace ID and the trace ID header
// of the response as a client would receive it.
//
// The inner handler deliberately writes a status code: httptest's
// ResponseRecorder snapshots headers at the first write, and rec.Result()
// returns that snapshot. A middleware that only sets the response header after
// calling next would therefore be caught here, whereas reading rec.Header()
// directly (the live map) would hide that bug.
func serveThroughTraceID(t *testing.T, req *http.Request) (ctxTraceID, headerTraceID string) {
	t.Helper()

	called := false
	handler := TraceID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		called = true
		ctxTraceID = apiresponse.TraceIDFromContext(r.Context())
		w.WriteHeader(http.StatusNoContent)
	}))

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	if !called {
		t.Fatal("TraceID middleware did not call the next handler")
	}

	res := rec.Result()
	defer res.Body.Close()
	return ctxTraceID, res.Header.Get(apiresponse.TraceIDHeader)
}

// TestTraceIDMiddlewareUsesRequestHeader verifies that a trace ID supplied by
// the client in the request header is propagated unchanged into the request
// context and echoed back in the response header.
func TestTraceIDMiddlewareUsesRequestHeader(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set(apiresponse.TraceIDHeader, "client-trace")

	ctxID, headerID := serveThroughTraceID(t, req)

	if ctxID != "client-trace" {
		t.Fatalf("context trace_id = %q, want client-trace", ctxID)
	}
	if headerID != "client-trace" {
		t.Fatalf("response header = %q, want client-trace", headerID)
	}
}

// TestTraceIDMiddlewareGeneratesWhenMissing verifies that, when the request
// carries no trace ID header, the middleware places a non-empty trace ID in
// the request context and returns that same ID in the response header.
func TestTraceIDMiddlewareGeneratesWhenMissing(t *testing.T) {
	req := httptest.NewRequest(http.MethodGet, "/", nil)

	ctxID, headerID := serveThroughTraceID(t, req)

	if ctxID == "" {
		t.Fatal("expected generated trace id in context")
	}
	if headerID != ctxID {
		t.Fatalf("response header = %q, context = %q", headerID, ctxID)
	}
}