49 lines
1.4 KiB
Go
49 lines
1.4 KiB
Go
package middleware
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
|
)
|
|
|
|
func TestTraceIDMiddlewareUsesRequestHeader(t *testing.T) {
|
|
var gotTraceID string
|
|
handler := TraceID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotTraceID = apiresponse.TraceIDFromContext(r.Context())
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set(apiresponse.TraceIDHeader, "client-trace")
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if gotTraceID != "client-trace" {
|
|
t.Fatalf("context trace_id = %q, want client-trace", gotTraceID)
|
|
}
|
|
if rec.Header().Get(apiresponse.TraceIDHeader) != "client-trace" {
|
|
t.Fatalf("response header = %q, want client-trace", rec.Header().Get(apiresponse.TraceIDHeader))
|
|
}
|
|
}
|
|
|
|
func TestTraceIDMiddlewareGeneratesWhenMissing(t *testing.T) {
|
|
var gotTraceID string
|
|
handler := TraceID(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
gotTraceID = apiresponse.TraceIDFromContext(r.Context())
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rec, req)
|
|
|
|
if gotTraceID == "" {
|
|
t.Fatal("expected generated trace id in context")
|
|
}
|
|
if rec.Header().Get(apiresponse.TraceIDHeader) != gotTraceID {
|
|
t.Fatalf("response header = %q, context = %q", rec.Header().Get(apiresponse.TraceIDHeader), gotTraceID)
|
|
}
|
|
}
|