From 309d04c19f7da1241a5953dc51da43264d0362a2 Mon Sep 17 00:00:00 2001 From: Elwin Tamminga Date: Mon, 8 Jul 2024 20:36:46 +0200 Subject: [PATCH] Fix group middlewares getting overwritten and fix partial response --- context.go | 10 +++++++++- default.go | 2 +- go.mod | 5 +++++ go.sum | 2 ++ group.go | 21 ++++++++++++--------- router.go | 20 ++++++++++++-------- 6 files changed, 41 insertions(+), 19 deletions(-) create mode 100644 go.mod create mode 100644 go.sum diff --git a/context.go b/context.go index e9c2fd6..0699ec5 100644 --- a/context.go +++ b/context.go @@ -70,9 +70,17 @@ func (c *Context) NoContent(code int) error { // JSON returns the given status code and writes JSON to the body func (c *Context) JSON(code int, data interface{}) error { + // write to buffer first in case of error + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(data) + if err != nil { + return err + } + c.Response.Header().Set(`Content-Type`, `application/json`) c.Response.WriteHeader(code) - return json.NewEncoder(c.Response).Encode(data) // TODO: Encode to buffer first to prevent partial responses on error + _, err = io.Copy(c.Response, &buf) + return err } // Render renders a templating using the Renderer set in router diff --git a/default.go b/default.go index ecda050..075cdf1 100644 --- a/default.go +++ b/default.go @@ -13,7 +13,7 @@ func defaultMethodNotAllowedHandler(c *Context) error { return c.StatusText(http.StatusMethodNotAllowed) } -func defaultErrorHandler(c *Context, err interface{}) { +func defaultErrorHandler(c *Context, _ interface{}) { _ = c.StatusText(http.StatusInternalServerError) } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7ae044b --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.fuyu.moe/Fuyu/router + +go 1.22.0 + +require github.com/julienschmidt/httprouter v1.3.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..096c54e --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= diff --git a/group.go b/group.go index e9a1451..150876d 100644 --- a/group.go +++ b/group.go @@ -1,6 +1,9 @@ package router -import urlpath "path" +import ( + urlpath "path" + "slices" +) func join(prefix, path string) string { return urlpath.Join(prefix, urlpath.Clean(path)) @@ -15,40 +18,40 @@ type Group struct { // Group creates a new router group with a shared prefix and set of middlewares func (g *Group) Group(prefix string, middleware ...Middleware) *Group { - return &Group{prefix: join(g.prefix, prefix), router: g.router, middleware: append(g.middleware, middleware...)} + return &Group{prefix: join(g.prefix, prefix), router: g.router, middleware: slices.Concat(g.middleware, middleware)} } // GET adds a GET route func (g *Group) GET(path string, handle Handle, middleware ...Middleware) { - g.router.GET(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.GET(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // POST adds a POST route func (g *Group) POST(path string, handle interface{}, middleware ...Middleware) { - g.router.POST(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.POST(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // DELETE adds a DELETE route func (g *Group) DELETE(path string, handle Handle, middleware ...Middleware) { - g.router.DELETE(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.DELETE(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // PUT adds a PUT route func (g *Group) PUT(path string, handle interface{}, middleware ...Middleware) { - g.router.PUT(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.PUT(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // PATCH adds a PATCH route func (g *Group) PATCH(path string, handle interface{}, middleware ...Middleware) { - g.router.PATCH(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.PATCH(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // HEAD adds a HEAD route func (g *Group) HEAD(path string, handle Handle, middleware ...Middleware) { - g.router.HEAD(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.HEAD(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // OPTIONS adds a OPTIONS route func (g *Group) OPTIONS(path string, handle Handle, middleware ...Middleware) { - g.router.OPTIONS(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.OPTIONS(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } diff --git a/router.go b/router.go index d869fc6..f3252d9 100644 --- a/router.go +++ b/router.go @@ -3,8 +3,10 @@ package router import ( "context" "crypto/tls" + "errors" "net/http" "reflect" + "slices" "time" "github.com/julienschmidt/httprouter" @@ -44,7 +46,12 @@ type Router struct { // New returns a new Router func New() *Router { - return &Router{Reader: defaultReader, NotFoundHandler: defaultNotFoundHandler, MethodNotAllowedHandler: defaultMethodNotAllowedHandler, ErrorHandler: defaultErrorHandler} + return &Router{ + Reader: defaultReader, + NotFoundHandler: defaultNotFoundHandler, + MethodNotAllowedHandler: defaultMethodNotAllowedHandler, + ErrorHandler: defaultErrorHandler, + } } // Use adds a global middleware @@ -120,7 +127,7 @@ func (r *Router) Stop() error { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) defer cancel() err := r.server.Shutdown(ctx) - if err == context.DeadlineExceeded { + if errors.Is(err, context.DeadlineExceeded) { err = r.server.Close() } r.server = nil @@ -137,11 +144,9 @@ func (r *Router) getHttpr() http.Handler { handle = handlePOST(r, v.Handle) } - middleware := make([]Middleware, len(r.middleware)+len(v.Middleware)) - copy(middleware, r.middleware) - copy(middleware[len(r.middleware):], v.Middleware) - - httpr.Handle(v.Method, v.Path, handleReq(r, handle, middleware)) + httpr.Handle(v.Method, v.Path, handleReq(r, handle, + slices.Concat(r.middleware, v.Middleware), + )) } httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { @@ -232,7 +237,6 @@ func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle { } err := f(c) - if err != nil { r.ErrorHandler(c, err) }