// Package httpcors provides a CORS middleware configured from the application
// config.
package httpcors

import (
	"net"
	"net/http"
	"net/url"
	"strings"

	"github.com/go-chi/cors"
	"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
	"github.com/ultisuite/ulti-backend/internal/config"
)

// Middleware returns a chi CORS handler configured from the app config.
//
// Origin policy, in order of precedence:
//   - cfg.CORSAllowedOrigins has at least one non-blank entry: exactly those
//     origins are allowed (entries are whitespace-trimmed; a lone "*" allows
//     any origin).
//   - otherwise, when cfg is non-nil and not production: any http/https origin
//     whose host is "localhost", a loopback IP, or a private-network IP.
//   - otherwise: the origins returned by defaultProductionOrigins.
//
// A nil cfg is tolerated and handled by the last rule.
func Middleware(cfg *config.Config) func(http.Handler) http.Handler {
	opts := cors.Options{
		AllowedMethods: []string{
			http.MethodGet,
			http.MethodHead,
			http.MethodPost,
			http.MethodPut,
			http.MethodPatch,
			http.MethodDelete,
			http.MethodOptions,
		},
		AllowedHeaders: []string{
			"Accept",
			"Authorization",
			"Content-Type",
			"Idempotency-Key",
			"Origin",
			"X-Requested-With",
			apiresponse.TraceIDHeader,
		},
		// Let browser clients read the trace ID header from responses.
		ExposedHeaders: []string{apiresponse.TraceIDHeader},
		// Credentials are never allowed, which limits the impact of the "*"
		// fallbacks below.
		AllowCredentials: false,
		MaxAge:           300,
	}

	// Guard the dereference: a nil cfg must reach the default branch instead
	// of panicking here.
	var allowed []string
	if cfg != nil {
		allowed = normalizeOrigins(cfg.CORSAllowedOrigins)
	}

	switch {
	case len(allowed) == 1 && allowed[0] == "*":
		opts.AllowedOrigins = []string{"*"}
	case len(allowed) > 0:
		opts.AllowedOrigins = allowed
	case cfg != nil && !cfg.IsProduction():
		opts.AllowOriginFunc = allowLocalDevOrigin
	default:
		opts.AllowedOrigins = defaultProductionOrigins(cfg)
	}
	return cors.Handler(opts)
}

// normalizeOrigins returns origins with surrounding whitespace trimmed from
// each entry and blank entries removed. A padded entry could never match a
// browser Origin header. An all-blank list (e.g. []string{""}) is treated the
// same as an unset one.
func normalizeOrigins(origins []string) []string {
	out := make([]string, 0, len(origins))
	for _, o := range origins {
		if o = strings.TrimSpace(o); o != "" {
			out = append(out, o)
		}
	}
	return out
}

// defaultProductionOrigins returns the https:// and http:// origins for the
// whitespace-trimmed cfg.Domain. If cfg is nil or the trimmed domain is empty,
// it returns []string{"*"}, which allows any origin.
//
// NOTE(review): the "*" fallback fails open when no domain is configured;
// confirm this is intended for production deployments.
func defaultProductionOrigins(cfg *config.Config) []string {
	var domain string
	if cfg != nil {
		// Trim before the emptiness check so a whitespace-only domain does
		// not produce the bogus origins "https://" and "http://".
		domain = strings.TrimSpace(cfg.Domain)
	}
	if domain == "" {
		return []string{"*"}
	}
	return []string{
		"https://" + domain,
		"http://" + domain,
	}
}

// allowLocalDevOrigin reports whether origin is an http or https URL whose host
// is "localhost" (case-insensitive), a loopback IP, or a private-network IP.
// Any port is accepted. Middleware installs it as the AllowOriginFunc outside
// production.
func allowLocalDevOrigin(_ *http.Request, origin string) bool {
	u, err := url.Parse(origin)
	if err != nil || u.Host == "" {
		return false
	}
	// url.Parse lowercases the scheme, so this also covers "HTTP://...".
	// An empty scheme is rejected here as well.
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}
	// Hostname drops the port and the brackets around IPv6 literals, so both
	// "http://[::1]" and "http://[::1]:3000" yield "::1". Splitting u.Host
	// manually left the brackets in place when no port was present.
	host := u.Hostname()
	if strings.EqualFold(host, "localhost") {
		return true
	}
	return isPrivateLANHost(host)
}

// isPrivateLANHost reports whether host is an IP literal that is either a
// private address (net.IP.IsPrivate: RFC 1918 for IPv4, RFC 4193 for IPv6) or
// a loopback address. Host names that are not IP literals return false.
func isPrivateLANHost(host string) bool {
	ip := net.ParseIP(host)
	if ip == nil {
		return false
	}
	return ip.IsPrivate() || ip.IsLoopback()
}