Skip to content
This repository has been archived by the owner on Dec 23, 2023. It is now read-only.

Commit

Permalink
Revert middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
razonyang committed Mar 23, 2020
1 parent 843d3bf commit 481b38a
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 59 deletions.
16 changes: 6 additions & 10 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,14 @@

package clevergo

// MiddlewareFunc is a Handle.
type MiddlewareFunc Handle
// MiddlewareFunc is a function that receives a handle and returns a handle.
type MiddlewareFunc func(Handle) Handle

// Chain wraps handle with middlewares, middlewares will be invoked in sequence.
func Chain(handle Handle, middlewares ...MiddlewareFunc) Handle {
return func(ctx *Context) (err error) {
for _, f := range middlewares {
if err = f(ctx); err != nil {
return
}
}

return handle(ctx)
for i := len(middlewares) - 1; i >= 0; i-- {
handle = middlewares[i](handle)
}

return handle
}
27 changes: 12 additions & 15 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
package clevergo

import (
"errors"
"fmt"
"net/http/httptest"
"testing"
Expand All @@ -20,16 +19,20 @@ func echoHandler(s string) Handle {
}

func echoMiddleware(s string) MiddlewareFunc {
return func(ctx *Context) error {
ctx.WriteString(s + " ")
return nil
return func(next Handle) Handle {
return func(ctx *Context) error {
ctx.WriteString(s + " ")
return next(ctx)
}
}
}

func terminatedMiddleware() MiddlewareFunc {
return func(ctx *Context) error {
ctx.WriteString("terminated")
return errors.New("terminated")
return func(next Handle) Handle {
return func(ctx *Context) error {
ctx.WriteString("terminated")
return nil
}
}
}

Expand All @@ -56,14 +59,8 @@ func TestChain(t *testing.T) {
}

func ExampleChain() {
m1 := func(ctx *Context) error {
ctx.WriteString("m1 ")
return nil
}
m2 := func(ctx *Context) error {
ctx.WriteString("m2 ")
return nil
}
m1 := echoMiddleware("m1")
m2 := echoMiddleware("m2")
handle := Chain(echoHandler("hello"), m1, m2)
w := httptest.NewRecorder()
handle(&Context{Response: w})
Expand Down
66 changes: 37 additions & 29 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ type Router struct {
UseRawPath bool

middlewares []MiddlewareFunc
handle Handle
}

// Make sure the Router conforms with the http.Handler interface
Expand Down Expand Up @@ -153,6 +154,8 @@ func (r *Router) Group(path string, opts ...RouteGroupOption) IRouter {
// Use attaches global middlewares.
func (r *Router) Use(middlewares ...MiddlewareFunc) {
r.middlewares = append(r.middlewares, middlewares...)

r.handle = Chain(r.handleRequest, r.middlewares...)
}

// Get implements IRouter.Get.
Expand Down Expand Up @@ -353,34 +356,39 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
ctx.Request = req
ctx.Response = w
defer r.putContext(ctx)
for _, f := range r.middlewares {
if err := f(ctx); err != nil {
r.HandleError(ctx, err)
return
}

var err error
if r.handle != nil {
err = r.handle(ctx)
} else {
err = r.handleRequest(ctx)
}
if err != nil {
r.HandleError(ctx, err)
}
}

func (r *Router) handleRequest(ctx *Context) (err error) {

path := req.URL.Path
if r.UseRawPath && req.URL.RawPath != "" {
path = req.URL.RawPath
path := ctx.Request.URL.Path
if r.UseRawPath && ctx.Request.URL.RawPath != "" {
path = ctx.Request.URL.RawPath
}

if root := r.trees[req.Method]; root != nil {
if root := r.trees[ctx.Request.Method]; root != nil {
if route, ps, tsr := root.getValue(path, r.getParams, r.UseRawPath); route != nil {
ctx.Route = route
if ps != nil {
r.putParams(ps)
ctx.Params = *ps
}
err := route.handle(ctx)
if err != nil {
r.HandleError(ctx, err)
}

err = route.handle(ctx)
return
} else if req.Method != http.MethodConnect && path != "/" {
} else if ctx.Request.Method != http.MethodConnect && path != "/" {
// Moved Permanently, request with Get method
code := http.StatusMovedPermanently
if req.Method != http.MethodGet {
if ctx.Request.Method != http.MethodGet {
// Permanent Redirect, request with same method
code = http.StatusPermanentRedirect
}
Expand All @@ -391,7 +399,7 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
} else {
path = path + "/"
}
http.Redirect(w, req, path, code)
ctx.Redirect(path, code)
return
}

Expand All @@ -402,40 +410,40 @@ func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.RedirectTrailingSlash,
)
if found {
http.Redirect(w, req, fixedPath, code)
ctx.Redirect(fixedPath, code)
return
}
}
}
}

if req.Method == http.MethodOptions && r.HandleOPTIONS {
if ctx.Request.Method == http.MethodOptions && r.HandleOPTIONS {
// Handle OPTIONS requests
if allow := r.allowed(path, http.MethodOptions); allow != "" {
w.Header().Set("Allow", allow)
ctx.Response.Header().Set("Allow", allow)
if r.GlobalOPTIONS != nil {
r.GlobalOPTIONS.ServeHTTP(w, req)
r.GlobalOPTIONS.ServeHTTP(ctx.Response, ctx.Request)
}
return
}
} else if r.HandleMethodNotAllowed { // Handle 405
if allow := r.allowed(path, req.Method); allow != "" {
w.Header().Set("Allow", allow)
if allow := r.allowed(path, ctx.Request.Method); allow != "" {
ctx.Response.Header().Set("Allow", allow)
if r.MethodNotAllowed != nil {
r.MethodNotAllowed.ServeHTTP(w, req)
} else {
r.HandleError(ctx, ErrMethodNotAllowed)
r.MethodNotAllowed.ServeHTTP(ctx.Response, ctx.Request)
return
}
return
return ErrMethodNotAllowed
}
}

// Handle 404
if r.NotFound != nil {
r.NotFound.ServeHTTP(w, req)
} else {
r.HandleError(ctx, ErrNotFound)
r.NotFound.ServeHTTP(ctx.Response, ctx.Request)
return
}

return ErrNotFound
}

// HandleError handles error.
Expand Down
8 changes: 3 additions & 5 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1057,13 +1057,11 @@ func TestRouterUse(t *testing.T) {
t.Errorf("expeceted body %s, got %s", "m1 m2 foobar", w.Body.String())
}

router.Use(func(_ *Context) error {
return NewError(http.StatusForbidden, errors.New("forbidden"))
})
router.Use(terminatedMiddleware())
w = httptest.NewRecorder()
req = httptest.NewRequest(http.MethodGet, "/", nil)
router.ServeHTTP(w, req)
if w.Body.String() != "m1 m2 forbidden\n" {
t.Errorf("expected body %q, got %q", "m1 m2 forbidden\n", w.Body.String())
if w.Body.String() != "m1 m2 terminated" {
t.Errorf("expected body %q, got %q", "m1 m2 terminated", w.Body.String())
}
}

0 comments on commit 481b38a

Please sign in to comment.