ultisuite-backend/internal/api/middleware/trace_test.go

49 lines
1.4 KiB
Go

package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
)
func TestTraceIDMiddlewareUsesRequestHeader(t *testing.T) {
var gotTraceID string
handler := TraceID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotTraceID = apiresponse.TraceIDFromContext(r.Context())
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.Header.Set(apiresponse.TraceIDHeader, "client-trace")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if gotTraceID != "client-trace" {
t.Fatalf("context trace_id = %q, want client-trace", gotTraceID)
}
if rec.Header().Get(apiresponse.TraceIDHeader) != "client-trace" {
t.Fatalf("response header = %q, want client-trace", rec.Header().Get(apiresponse.TraceIDHeader))
}
}
// TestTraceIDMiddlewareGeneratesWhenMissing verifies that when the request has
// no trace ID header, TraceID still places a non-empty trace ID in the request
// context and sets that same value on the response header.
func TestTraceIDMiddlewareGeneratesWhenMissing(t *testing.T) {
	// ctxTraceID records what the wrapped handler observed in its context.
	var ctxTraceID string
	handler := TraceID(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		ctxTraceID = apiresponse.TraceIDFromContext(r.Context())
	}))

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, "/", nil))

	// Check that a value was generated first, then that the response header matches it.
	switch headerTraceID := recorder.Header().Get(apiresponse.TraceIDHeader); {
	case ctxTraceID == "":
		t.Fatal("expected generated trace id in context")
	case headerTraceID != ctxTraceID:
		t.Fatalf("response header = %q, context = %q", headerTraceID, ctxTraceID)
	}
}