From 69754a99460898dd1e7faf8419a2d7d0cee21490 Mon Sep 17 00:00:00 2001 From: NiseVoid Date: Tue, 9 Oct 2018 18:07:43 +0200 Subject: [PATCH] Add middleware --- default.go | 18 ++++++++ group.go | 41 +++++++++-------- router.go | 133 +++++++++++++++++++++++++++++++++++------------------ 3 files changed, 127 insertions(+), 65 deletions(-) create mode 100644 default.go diff --git a/default.go b/default.go new file mode 100644 index 0000000..99cc394 --- /dev/null +++ b/default.go @@ -0,0 +1,18 @@ +package router + +import ( + "fmt" +) + +func defaultNotFoundHandler(c *Context) error { + return c.String(404, `not found`) +} + +func defaultMethodNotAllowedHandler(c *Context) error { + return c.String(504, `method not allowed`) +} + +func defaultErrorHandler(c *Context, err interface{}) { + fmt.Println(err) + c.String(500, `internal server error`) +} diff --git a/group.go b/group.go index 2e7bb80..e9a1451 100644 --- a/group.go +++ b/group.go @@ -6,48 +6,49 @@ func join(prefix, path string) string { return urlpath.Join(prefix, urlpath.Clean(path)) } +// Group is a router group with a shared prefix and set of middlewares type Group struct { - router *Router - prefix string + router *Router + prefix string + middleware []Middleware } -func (g *Group) Group(prefix string) *Group { - return &Group{prefix: join(g.prefix, prefix), router: g.router} +// Group creates a new router group with a shared prefix and set of middlewares +func (g *Group) Group(prefix string, middleware ...Middleware) *Group { + return &Group{prefix: join(g.prefix, prefix), router: g.router, middleware: append(g.middleware, middleware...)} } // GET adds a GET route -func (g *Group) GET(path string, handle GetHandle) { - g.router.GET(join(g.prefix, path), handle) +func (g *Group) GET(path string, handle Handle, middleware ...Middleware) { + g.router.GET(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } // POST adds a POST route -func (g *Group) POST(path string, handle interface{}) { - g.router.POST(join(g.prefix, path), handle) +func (g *Group) POST(path string, handle interface{}, middleware ...Middleware) { + g.router.POST(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } // DELETE adds a DELETE route -func (g *Group) DELETE(path string, handle GetHandle) { - g.router.DELETE(join(g.prefix, path), handle) +func (g *Group) DELETE(path string, handle Handle, middleware ...Middleware) { + g.router.DELETE(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } // PUT adds a PUT route -func (g *Group) PUT(path string, handle interface{}) { - checkInterfaceHandle(handle) - g.router.PUT(join(g.prefix, path), handle) +func (g *Group) PUT(path string, handle interface{}, middleware ...Middleware) { + g.router.PUT(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } // PATCH adds a PATCH route -func (g *Group) PATCH(path string, handle interface{}) { - checkInterfaceHandle(handle) - g.router.PATCH(join(g.prefix, path), handle) +func (g *Group) PATCH(path string, handle interface{}, middleware ...Middleware) { + g.router.PATCH(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } // HEAD adds a HEAD route -func (g *Group) HEAD(path string, handle GetHandle) { - g.router.HEAD(join(g.prefix, path), handle) +func (g *Group) HEAD(path string, handle Handle, middleware ...Middleware) { + g.router.HEAD(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } // OPTIONS adds a OPTIONS route -func (g *Group) OPTIONS(path string, handle GetHandle) { - g.router.OPTIONS(join(g.prefix, path), handle) +func (g *Group) OPTIONS(path string, handle Handle, middleware ...Middleware) { + g.router.OPTIONS(join(g.prefix, path), handle, append(g.middleware, middleware...)...) } diff --git a/router.go b/router.go index 4c1c616..720c417 100644 --- a/router.go +++ b/router.go @@ -2,7 +2,6 @@ package router import ( "encoding/json" - "fmt" "net/http" "reflect" @@ -10,65 +9,82 @@ import ( ) type route struct { - Method string - Path string - Handle interface{} + Method string + Path string + Handle interface{} + Middleware []Middleware } -// GetHandle handles a request that doesn't receive a body -type GetHandle func(*Context) error +// Handle handles a request +type Handle func(*Context) error + +// ErrorHandle handles a request +type ErrorHandle func(*Context, interface{}) + +// Middleware TODO: +type Middleware func(Handle) Handle // Router is the router itself type Router struct { - routes []route - Renderer Renderer + routes []route + Renderer Renderer + middleware []Middleware + NotFoundHandler Handle + MethodNotAllowedHandler Handle + ErrorHandler ErrorHandle } // New returns a new Router func New() *Router { - return &Router{} + return &Router{NotFoundHandler: defaultNotFoundHandler, MethodNotAllowedHandler: defaultMethodNotAllowedHandler, ErrorHandler: defaultErrorHandler} } -func (r *Router) Group(prefix string) *Group { - return &Group{prefix: prefix, router: r} +// Use adds a global middleware +func (r *Router) Use(m ...Middleware) { + r.middleware = append(r.middleware, m...) +} + +// Group creates a new router group with a shared prefix and set of middlewares +func (r *Router) Group(prefix string, middleware ...Middleware) *Group { + return &Group{prefix: prefix, router: r, middleware: middleware} } // GET adds a GET route -func (r *Router) GET(path string, handle GetHandle) { - r.routes = append(r.routes, route{`GET`, path, handle}) +func (r *Router) GET(path string, handle Handle, middleware ...Middleware) { + r.routes = append(r.routes, route{`GET`, path, handle, middleware}) } // POST adds a POST route -func (r *Router) POST(path string, handle interface{}) { +func (r *Router) POST(path string, handle interface{}, middleware ...Middleware) { checkInterfaceHandle(handle) - r.routes = append(r.routes, route{`POST`, path, handle}) + r.routes = append(r.routes, route{`POST`, path, handle, middleware}) } // DELETE adds a DELETE route -func (r *Router) DELETE(path string, handle GetHandle) { - r.routes = append(r.routes, route{`DELETE`, path, handle}) +func (r *Router) DELETE(path string, handle Handle, middleware ...Middleware) { + r.routes = append(r.routes, route{`DELETE`, path, handle, middleware}) } // PUT adds a PUT route -func (r *Router) PUT(path string, handle interface{}) { +func (r *Router) PUT(path string, handle interface{}, middleware ...Middleware) { checkInterfaceHandle(handle) - r.routes = append(r.routes, route{`PUT`, path, handle}) + r.routes = append(r.routes, route{`PUT`, path, handle, middleware}) } // PATCH adds a PATCH route -func (r *Router) PATCH(path string, handle interface{}) { +func (r *Router) PATCH(path string, handle interface{}, middleware ...Middleware) { checkInterfaceHandle(handle) - r.routes = append(r.routes, route{`PATCH`, path, handle}) + r.routes = append(r.routes, route{`PATCH`, path, handle, middleware}) } // HEAD adds a HEAD route -func (r *Router) HEAD(path string, handle GetHandle) { - r.routes = append(r.routes, route{`HEAD`, path, handle}) +func (r *Router) HEAD(path string, handle Handle, middleware ...Middleware) { + r.routes = append(r.routes, route{`HEAD`, path, handle, middleware}) } // OPTIONS adds a OPTIONS route -func (r *Router) OPTIONS(path string, handle GetHandle) { - r.routes = append(r.routes, route{`OPTIONS`, path, handle}) +func (r *Router) OPTIONS(path string, handle Handle, middleware ...Middleware) { + r.routes = append(r.routes, route{`OPTIONS`, path, handle, middleware}) } // Start starts the web server and binds to the given address @@ -82,19 +98,39 @@ func (r *Router) getHttpr() *httprouter.Router { httpr := httprouter.New() for _, v := range r.routes { - if handle, ok := v.Handle.(GetHandle); ok { - httpr.Handle(v.Method, v.Path, handleGET(r, handle)) - continue + handle, ok := v.Handle.(Handle) + if !ok { + handle = handlePOST(r, v.Handle) } - httpr.Handle(v.Method, v.Path, handlePOST(r, v.Handle)) + httpr.Handle(v.Method, v.Path, handleReq(r, handle, append(r.middleware, v.Middleware...))) + } + + httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + handleReq(r, r.NotFoundHandler, r.middleware)(res, req, nil) + }) + + httpr.MethodNotAllowed = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + handleReq(r, r.MethodNotAllowedHandler, r.middleware)(res, req, nil) + }) + + httpr.PanicHandler = func(res http.ResponseWriter, req *http.Request, err interface{}) { + c := newContext(r, res, req, nil) + r.ErrorHandler(c, err) } return httpr } +func handleErr(errHandler ErrorHandle, err interface{}) Handle { + return func(c *Context) error { + errHandler(c, err) + return nil + } +} + func checkInterfaceHandle(f interface{}) { - if _, ok := f.(GetHandle); ok { + if _, ok := f.(Handle); ok { return } @@ -119,34 +155,41 @@ func checkInterfaceHandle(f interface{}) { return } -func handlePOST(r *Router, f interface{}) httprouter.Handle { +func handlePOST(r *Router, f interface{}) Handle { funcRv, inputRt := reflect.ValueOf(f), reflect.TypeOf(f).In(1) - return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) { - c := newContext(r, res, req, param) - + return func(c *Context) error { data := reflect.New(inputRt) - { - err := json.NewDecoder(req.Body).Decode(data.Interface()) - req.Body.Close() - if err != nil { - c.NoContent(400) // TODO: send info about error (BindError) - return - } + + err := json.NewDecoder(c.Request.Body).Decode(data.Interface()) + c.Request.Body.Close() + if err != nil { + c.NoContent(400) // TODO: send info about error (BindError) + return nil } out := funcRv.Call([]reflect.Value{reflect.ValueOf(c), data.Elem()}) - err := out[0].Interface() - _ = err + + if out[0].IsNil() { + return nil + } + return out[0].Interface().(error) } } -func handleGET(r *Router, f GetHandle) httprouter.Handle { +func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle { return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) { c := newContext(r, res, req, param) + f := handle + for i := len(m) - 1; i >= 0; i-- { // TODO: 1,2,3 of 3,2,1 + f = m[i](f) + } + err := f(c) - fmt.Println(err) + if err != nil { + r.ErrorHandler(c, err) + } } }