96 lines
2.0 KiB
Go
96 lines
2.0 KiB
Go
package httpcors
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/go-chi/cors"
|
|
|
|
"github.com/ultisuite/ulti-backend/internal/api/apiresponse"
|
|
"github.com/ultisuite/ulti-backend/internal/config"
|
|
)
|
|
|
|
// Middleware returns chi CORS handler configured from app config.
|
|
func Middleware(cfg *config.Config) func(http.Handler) http.Handler {
|
|
opts := cors.Options{
|
|
AllowedMethods: []string{
|
|
http.MethodGet,
|
|
http.MethodHead,
|
|
http.MethodPost,
|
|
http.MethodPut,
|
|
http.MethodPatch,
|
|
http.MethodDelete,
|
|
http.MethodOptions,
|
|
},
|
|
AllowedHeaders: []string{
|
|
"Accept",
|
|
"Authorization",
|
|
"Content-Type",
|
|
"Idempotency-Key",
|
|
"Origin",
|
|
"X-Requested-With",
|
|
apiresponse.TraceIDHeader,
|
|
},
|
|
ExposedHeaders: []string{apiresponse.TraceIDHeader},
|
|
AllowCredentials: false,
|
|
MaxAge: 300,
|
|
}
|
|
|
|
allowed := cfg.CORSAllowedOrigins
|
|
switch {
|
|
case len(allowed) == 1 && allowed[0] == "*":
|
|
opts.AllowedOrigins = []string{"*"}
|
|
case len(allowed) > 0:
|
|
opts.AllowedOrigins = allowed
|
|
case cfg != nil && !cfg.IsProduction():
|
|
opts.AllowOriginFunc = allowLocalDevOrigin
|
|
default:
|
|
opts.AllowedOrigins = defaultProductionOrigins(cfg)
|
|
}
|
|
|
|
return cors.Handler(opts)
|
|
}
|
|
|
|
func defaultProductionOrigins(cfg *config.Config) []string {
|
|
if cfg == nil || cfg.Domain == "" {
|
|
return []string{"*"}
|
|
}
|
|
domain := strings.TrimSpace(cfg.Domain)
|
|
return []string{
|
|
"https://" + domain,
|
|
"http://" + domain,
|
|
}
|
|
}
|
|
|
|
func allowLocalDevOrigin(_ *http.Request, origin string) bool {
|
|
u, err := url.Parse(origin)
|
|
if err != nil || u.Scheme == "" || u.Host == "" {
|
|
return false
|
|
}
|
|
if u.Scheme != "http" && u.Scheme != "https" {
|
|
return false
|
|
}
|
|
|
|
host, _, err := net.SplitHostPort(u.Host)
|
|
if err != nil {
|
|
host = u.Host
|
|
}
|
|
|
|
switch strings.ToLower(host) {
|
|
case "localhost", "127.0.0.1", "::1":
|
|
return true
|
|
default:
|
|
return isPrivateLANHost(host)
|
|
}
|
|
}
|
|
|
|
// isPrivateLANHost reports whether host is a literal IP address that is either
// loopback or in a private range as defined by net.IP.IsPrivate. Hostnames are
// never resolved: any string that does not parse as an IP yields false.
func isPrivateLANHost(host string) bool {
	if addr := net.ParseIP(host); addr != nil {
		return addr.IsLoopback() || addr.IsPrivate()
	}
	return false
}
|