Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(middleware/cors): Validation of multiple Origins #2883

Merged
merged 8 commits into from
Mar 1, 2024
36 changes: 21 additions & 15 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,32 @@ func New(config ...Config) fiber.Handler {
panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.")
}

// Validate and normalize static AllowOrigins if not using AllowOriginsFunc
if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
validatedOrigins := []string{}
for _, origin := range strings.Split(cfg.AllowOrigins, ",") {
isValid, normalizedOrigin := normalizeOrigin(origin)
// allowOrigins is a slice of strings that contains the allowed origins
// defined in the 'AllowOrigins' configuration.
var allowOrigins []string

// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
origins := strings.Split(cfg.AllowOrigins, ",")
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
allowOrigins = make([]string, len(origins))

for i, origin := range origins {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)

if isValid {
validatedOrigins = append(validatedOrigins, normalizedOrigin)
allowOrigins[i] = normalizedOrigin
} else {
log.Warnf("[CORS] Invalid origin format in configuration: %s", origin)
log.Warnf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin)
panic("[CORS] Invalid origin provided in configuration")
}
}
cfg.AllowOrigins = strings.Join(validatedOrigins, ",")
} else {
// If AllowOrigins is set to a wildcard or not set,
// set allowOrigins to a slice with a single element
allowOrigins = []string{cfg.AllowOrigins}
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
}

// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")

// Strip white spaces
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
Expand Down Expand Up @@ -165,10 +173,8 @@ func New(config ...Config) fiber.Handler {
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}

// Simple request
Expand Down
114 changes: 114 additions & 0 deletions middleware/cors/cors_test.go
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,21 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
reqOrigin: "https://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
{
pattern: "https://domain-1.com, https://example.com",
reqOrigin: "https://example.com",
shouldAllowOrigin: true,
},
{
pattern: "https://domain-1.com, https://example.com",
reqOrigin: "https://domain-2.com",
shouldAllowOrigin: false,
},
{
pattern: "https://domain-1.com,https://example.com",
reqOrigin: "https://domain-1.com",
shouldAllowOrigin: true,
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -452,6 +467,33 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
RequestOrigin: "https://aaa.com",
ResponseOrigin: "https://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginAllowed",
Config: Config{
AllowOrigins: "https://aaa.com,https://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "https://bbb.com",
ResponseOrigin: "https://bbb.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginNotAllowed",
Config: Config{
AllowOrigins: "https://aaa.com,https://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "https://ccc.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed",
Config: Config{
AllowOrigins: "https://aaa.com, https://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "https://aaa.com",
ResponseOrigin: "https://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
Config: Config{
Expand Down Expand Up @@ -647,3 +689,75 @@ func Test_CORS_AllowCredentials(t *testing.T) {
})
}
}

// go test -v -run=^$ -bench=Benchmark_CORS_NewHandler -benchmem -count=4
func Benchmark_CORS_NewHandler(b *testing.B) {
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
app := fiber.New()
c := New(Config{
AllowOrigins: "https://localhost,https://example.com",
sixcolors marked this conversation as resolved.
Show resolved Hide resolved
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})

app.Use(c)
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}

req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "https://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)

b.ReportAllocs()
b.ResetTimer()

for i := 0; i < b.N; i++ {
h(ctx)
}
}

// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "https://localhost,https://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})

app.Use(c)
app.Use(func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})

h := app.Handler()
ctx := &fasthttp.RequestCtx{}

req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "https://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")

ctx.Init(req, nil, nil)

b.ReportAllocs()
b.ResetTimer()

for i := 0; i < b.N; i++ {
h(ctx)
}
}
Loading