diff --git a/context.go b/context.go index e6cc4b3..0699ec5 100644 --- a/context.go +++ b/context.go @@ -4,7 +4,9 @@ import ( "bytes" "encoding/json" "io" + "net" "net/http" + "strings" "github.com/julienschmidt/httprouter" ) @@ -18,7 +20,8 @@ type Context struct { store map[string]interface{} } -func newContext(router *Router, res http.ResponseWriter, req *http.Request, param httprouter.Params) *Context { +// NewContext creates a new context, this function is only exported for use in tests +func NewContext(router *Router, res http.ResponseWriter, req *http.Request, param httprouter.Params) *Context { return &Context{router, req, res, param.ByName, make(map[string]interface{})} } @@ -67,9 +70,17 @@ func (c *Context) NoContent(code int) error { // JSON returns the given status code and writes JSON to the body func (c *Context) JSON(code int, data interface{}) error { + // write to buffer first in case of error + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(data) + if err != nil { + return err + } + c.Response.Header().Set(`Content-Type`, `application/json`) c.Response.WriteHeader(code) - return json.NewEncoder(c.Response).Encode(data) // TODO: Encode to buffer first to prevent partial responses on error + _, err = io.Copy(c.Response, &buf) + return err } // Render renders a templating using the Renderer set in router @@ -99,3 +110,20 @@ func (c *Context) Set(key string, value interface{}) { func (c *Context) Get(key string) interface{} { return c.store[key] } + +// RealIP uses proxy headers for the real ip, if none exist the IP of the current connection is returned +func (c *Context) RealIP() string { + reqIP := c.Request.RemoteAddr + + if ip := c.Request.Header.Get(`X-Forwarded-For`); ip != `` { + reqIP = strings.Split(ip, `, `)[0] + } else if ip := c.Request.Header.Get(`X-Real-IP`); ip != `` { + reqIP = ip + } + + ra, _, _ := net.SplitHostPort(reqIP) + if ra != `` { + reqIP = ra + } + return reqIP +} diff --git a/default.go b/default.go index ecda050..075cdf1 100644 --- a/default.go +++ b/default.go @@ -13,7 +13,7 @@ func defaultMethodNotAllowedHandler(c *Context) error { return c.StatusText(http.StatusMethodNotAllowed) } -func defaultErrorHandler(c *Context, err interface{}) { +func defaultErrorHandler(c *Context, _ interface{}) { _ = c.StatusText(http.StatusInternalServerError) } diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7ae044b --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module git.fuyu.moe/Fuyu/router + +go 1.22.0 + +require github.com/julienschmidt/httprouter v1.3.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..096c54e --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U= +github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM= diff --git a/group.go b/group.go index e9a1451..150876d 100644 --- a/group.go +++ b/group.go @@ -1,6 +1,9 @@ package router -import urlpath "path" +import ( + urlpath "path" + "slices" +) func join(prefix, path string) string { return urlpath.Join(prefix, urlpath.Clean(path)) @@ -15,40 +18,40 @@ type Group struct { // Group creates a new router group with a shared prefix and set of middlewares func (g *Group) Group(prefix string, middleware ...Middleware) *Group { - return &Group{prefix: join(g.prefix, prefix), router: g.router, middleware: append(g.middleware, middleware...)} + return &Group{prefix: join(g.prefix, prefix), router: g.router, middleware: slices.Concat(g.middleware, middleware)} } // GET adds a GET route func (g *Group) GET(path string, handle Handle, middleware ...Middleware) { - g.router.GET(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.GET(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // POST adds a POST route func (g *Group) POST(path string, handle interface{}, middleware ...Middleware) { - g.router.POST(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.POST(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // DELETE adds a DELETE route func (g *Group) DELETE(path string, handle Handle, middleware ...Middleware) { - g.router.DELETE(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.DELETE(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // PUT adds a PUT route func (g *Group) PUT(path string, handle interface{}, middleware ...Middleware) { - g.router.PUT(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.PUT(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // PATCH adds a PATCH route func (g *Group) PATCH(path string, handle interface{}, middleware ...Middleware) { - g.router.PATCH(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.PATCH(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // HEAD adds a HEAD route func (g *Group) HEAD(path string, handle Handle, middleware ...Middleware) { - g.router.HEAD(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.HEAD(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } // OPTIONS adds a OPTIONS route func (g *Group) OPTIONS(path string, handle Handle, middleware ...Middleware) { - g.router.OPTIONS(join(g.prefix, path), handle, append(g.middleware, middleware...)...) + g.router.OPTIONS(join(g.prefix, path), handle, slices.Concat(g.middleware, middleware)...) } diff --git a/router.go b/router.go index a878555..f3252d9 100644 --- a/router.go +++ b/router.go @@ -1,8 +1,13 @@ package router import ( + "context" + "crypto/tls" + "errors" "net/http" "reflect" + "slices" + "time" "github.com/julienschmidt/httprouter" ) @@ -35,11 +40,18 @@ type Router struct { NotFoundHandler Handle MethodNotAllowedHandler Handle ErrorHandler ErrorHandle + TrimTrailingSlashes bool + server *http.Server } // New returns a new Router func New() *Router { - return &Router{Reader: defaultReader, NotFoundHandler: defaultNotFoundHandler, MethodNotAllowedHandler: defaultMethodNotAllowedHandler, ErrorHandler: defaultErrorHandler} + return &Router{ + Reader: defaultReader, + NotFoundHandler: defaultNotFoundHandler, + MethodNotAllowedHandler: defaultMethodNotAllowedHandler, + ErrorHandler: defaultErrorHandler, + } } // Use adds a global middleware @@ -94,10 +106,36 @@ func (r *Router) OPTIONS(path string, handle Handle, middleware ...Middleware) { func (r *Router) Start(addr string) error { httpr := r.getHttpr() - return http.ListenAndServe(addr, httpr) + r.server = &http.Server{Addr: addr, Handler: httpr} + return r.server.ListenAndServe() } -func (r *Router) getHttpr() *httprouter.Router { +// StartTLS starts a TLS web server using the given key, cert and config and binds to the given address +func (r *Router) StartTLS(addr, certFile, keyFile string, conf *tls.Config) error { + httpr := r.getHttpr() + + r.server = &http.Server{Addr: addr, Handler: httpr, TLSConfig: conf} + return r.server.ListenAndServeTLS(certFile, keyFile) +} + +// Stop stops the web server +func (r *Router) Stop() error { + if r.server == nil { + return nil + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + err := r.server.Shutdown(ctx) + if errors.Is(err, context.DeadlineExceeded) { + err = r.server.Close() + } + r.server = nil + + return err +} + +func (r *Router) getHttpr() http.Handler { httpr := httprouter.New() for _, v := range r.routes { @@ -106,7 +144,9 @@ func (r *Router) getHttpr() *httprouter.Router { handle = handlePOST(r, v.Handle) } - httpr.Handle(v.Method, v.Path, handleReq(r, handle, append(r.middleware, v.Middleware...))) + httpr.Handle(v.Method, v.Path, handleReq(r, handle, + slices.Concat(r.middleware, v.Middleware), + )) } httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { @@ -118,10 +158,22 @@ func (r *Router) getHttpr() *httprouter.Router { }) httpr.PanicHandler = func(res http.ResponseWriter, req *http.Request, err interface{}) { - c := newContext(r, res, req, nil) + c := NewContext(r, res, req, nil) r.ErrorHandler(c, err) } + if r.TrimTrailingSlashes { + httpr.RedirectTrailingSlash = false + + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + l := len(req.URL.Path) + if l > 1 && req.URL.Path[l-1] == '/' { + req.URL.Path = req.URL.Path[:l-1] + } + httpr.ServeHTTP(w, req) + }) + } + return httpr } @@ -147,8 +199,6 @@ func checkInterfaceHandle(f interface{}) { if rt.In(0) != reflect.TypeOf(&Context{}) { panic(`handle should accept Context as first argument`) } - - return } func handlePOST(r *Router, f interface{}) Handle { @@ -179,15 +229,14 @@ func handlePOST(r *Router, f interface{}) Handle { func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle { return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) { - c := newContext(r, res, req, param) + c := NewContext(r, res, req, param) f := handle - for i := len(m) - 1; i >= 0; i-- { // TODO: 1,2,3 of 3,2,1 + for i := len(m) - 1; i >= 0; i-- { f = m[i](f) } err := f(c) - if err != nil { r.ErrorHandler(c, err) }