package router

import (
	"context"
	"crypto/tls"
	"errors"
	"net/http"
	"reflect"
	"sync"
	"time"

	"github.com/julienschmidt/httprouter"
)

// route is a single registered route. Routes are kept in registration order
// and compiled into an httprouter.Router by getHttpr when the server starts.
// Handle is either a Handle or a func(*Context, T) error that has been
// validated by checkInterfaceHandle.
type route struct {
	Method     string
	Path       string
	Handle     interface{}
	Middleware []Middleware
}

// Handle handles a request. A non-nil error is passed to Router.ErrorHandler.
type Handle = func(*Context) error

// ErrorHandle handles an error returned by a Handle, or a value recovered from
// a panic while serving a request.
type ErrorHandle func(*Context, interface{})

// Middleware is a function that runs before your route. It receives the next
// handler as a parameter and returns the handler to run in its place.
type Middleware func(Handle) Handle

// Reader reads the request input into dst, a pointer to a freshly allocated
// value of the handler's input type. It returns true if the handler should
// run. If it returns false with a nil error, the handler is skipped and no
// error is reported (presumably the Reader has written a response itself).
type Reader func(c *Context, dst interface{}) (bool, error)

// Router is the router itself. Routes are compiled into the routing table when
// Start or StartTLS is called.
type Router struct {
	routes []route

	// Reader decodes input for handlers registered with a typed second
	// argument (see POST, PUT and PATCH). If nil, the input is left zero-valued.
	Reader Reader

	// Renderer is not referenced in this file; presumably used by Context.
	Renderer Renderer

	middleware []Middleware

	// NotFoundHandler is installed as httprouter's NotFound handler, wrapped
	// in the global middleware.
	NotFoundHandler Handle

	// MethodNotAllowedHandler is installed as httprouter's MethodNotAllowed
	// handler, wrapped in the global middleware.
	MethodNotAllowedHandler Handle

	// ErrorHandler receives errors returned by handlers and recovered panics.
	ErrorHandler ErrorHandle

	// TrimTrailingSlashes strips a single trailing slash from the request path
	// before routing, instead of letting httprouter redirect.
	TrimTrailingSlashes bool

	// serverMu guards server: Start/StartTLS write it and then block, so Stop
	// necessarily reads it from a different goroutine.
	serverMu sync.Mutex
	server   *http.Server
}

// New returns a Router with the package's default Reader, NotFoundHandler,
// MethodNotAllowedHandler and ErrorHandler. Renderer is left nil.
func New() *Router {
	return &Router{
		Reader:                  defaultReader,
		NotFoundHandler:         defaultNotFoundHandler,
		MethodNotAllowedHandler: defaultMethodNotAllowedHandler,
		ErrorHandler:            defaultErrorHandler,
	}
}

// Use adds global middleware. Global middleware wraps, and therefore runs
// before, each route's own middleware.
func (r *Router) Use(m ...Middleware) {
	r.middleware = append(r.middleware, m...)
}

// Group creates a new router group with a shared prefix and set of middlewares.
func (r *Router) Group(prefix string, middleware ...Middleware) *Group {
	return &Group{prefix: prefix, router: r, middleware: middleware}
}

// GET adds a GET route.
func (r *Router) GET(path string, handle Handle, middleware ...Middleware) {
	r.routes = append(r.routes, route{http.MethodGet, path, handle, middleware})
}

// POST adds a POST route. handle is either a Handle or a function of the form
// func(*Context, T) error; in the latter case a new T is filled by
// Router.Reader before each call. POST panics if handle has any other shape.
func (r *Router) POST(path string, handle interface{}, middleware ...Middleware) {
	checkInterfaceHandle(handle)
	r.routes = append(r.routes, route{http.MethodPost, path, handle, middleware})
}

// DELETE adds a DELETE route.
func (r *Router) DELETE(path string, handle Handle, middleware ...Middleware) {
	r.routes = append(r.routes, route{http.MethodDelete, path, handle, middleware})
}

// PUT adds a PUT route. handle follows the same rules as for POST.
func (r *Router) PUT(path string, handle interface{}, middleware ...Middleware) {
	checkInterfaceHandle(handle)
	r.routes = append(r.routes, route{http.MethodPut, path, handle, middleware})
}

// PATCH adds a PATCH route. handle follows the same rules as for POST.
func (r *Router) PATCH(path string, handle interface{}, middleware ...Middleware) {
	checkInterfaceHandle(handle)
	r.routes = append(r.routes, route{http.MethodPatch, path, handle, middleware})
}

// HEAD adds a HEAD route.
func (r *Router) HEAD(path string, handle Handle, middleware ...Middleware) {
	r.routes = append(r.routes, route{http.MethodHead, path, handle, middleware})
}

// OPTIONS adds an OPTIONS route.
func (r *Router) OPTIONS(path string, handle Handle, middleware ...Middleware) {
	r.routes = append(r.routes, route{http.MethodOptions, path, handle, middleware})
}

// Start compiles the registered routes and serves HTTP on addr. It blocks
// until the server stops; after Stop it returns http.ErrServerClosed.
func (r *Router) Start(addr string) error {
	srv := &http.Server{Addr: addr, Handler: r.getHttpr()}
	r.serverMu.Lock()
	r.server = srv
	r.serverMu.Unlock()
	return srv.ListenAndServe()
}

// StartTLS starts a TLS web server using the given key, cert and config and
// binds to the given address. Like Start, it blocks until the server stops.
func (r *Router) StartTLS(addr, certFile, keyFile string, conf *tls.Config) error {
	srv := &http.Server{Addr: addr, Handler: r.getHttpr(), TLSConfig: conf}
	r.serverMu.Lock()
	r.server = srv
	r.serverMu.Unlock()
	return srv.ListenAndServeTLS(certFile, keyFile)
}

// Stop gracefully shuts down the server started by Start or StartTLS. It waits
// up to 5 seconds for the graceful shutdown and then closes the server
// forcibly. It returns nil if no server has been started or it was already
// stopped.
func (r *Router) Stop() error {
	r.serverMu.Lock()
	srv := r.server
	r.server = nil
	r.serverMu.Unlock()
	if srv == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	err := srv.Shutdown(ctx)
	if errors.Is(err, context.DeadlineExceeded) {
		// Graceful shutdown timed out; drop whatever is still open.
		err = srv.Close()
	}
	return err
}

// getHttpr compiles the registered routes into an http.Handler backed by
// httprouter. Each route handler is wrapped in the global middleware followed
// by the route's own middleware, and recovered panics go to ErrorHandler.
func (r *Router) getHttpr() http.Handler {
	httpr := httprouter.New()
	for _, v := range r.routes {
		handle, ok := v.Handle.(Handle)
		if !ok {
			handle = handlePOST(r, v.Handle)
		}
		// Build a fresh slice so no route shares a backing array with
		// r.middleware or with another route.
		middleware := make([]Middleware, len(r.middleware)+len(v.Middleware))
		copy(middleware, r.middleware)
		copy(middleware[len(r.middleware):], v.Middleware)
		httpr.Handle(v.Method, v.Path, handleReq(r, handle, middleware))
	}

	// NotFoundHandler, MethodNotAllowedHandler and the global middleware are
	// read per request here, not snapshotted like the routes above.
	httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		handleReq(r, r.NotFoundHandler, r.middleware)(res, req, nil)
	})
	httpr.MethodNotAllowed = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		handleReq(r, r.MethodNotAllowedHandler, r.middleware)(res, req, nil)
	})
	httpr.PanicHandler = func(res http.ResponseWriter, req *http.Request, err interface{}) {
		c := NewContext(r, res, req, nil)
		r.ErrorHandler(c, err)
	}

	if r.TrimTrailingSlashes {
		httpr.RedirectTrailingSlash = false
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			// Keep "/" intact; otherwise drop exactly one trailing slash.
			l := len(req.URL.Path)
			if l > 1 && req.URL.Path[l-1] == '/' {
				req.URL.Path = req.URL.Path[:l-1]
			}
			httpr.ServeHTTP(w, req)
		})
	}
	return httpr
}

// checkInterfaceHandle panics unless f is a Handle or a non-variadic
// func(*Context, T) error. It runs at registration time so that a malformed
// handler fails fast instead of on every request.
func checkInterfaceHandle(f interface{}) {
	if _, ok := f.(Handle); ok {
		return
	}
	rt := reflect.TypeOf(f)
	// reflect.TypeOf(nil) is nil; without this guard a nil handle would
	// produce a nil-pointer panic instead of a clear message.
	if rt == nil || rt.Kind() != reflect.Func {
		panic(`non-func handle`)
	}
	if rt.NumIn() != 2 {
		panic(`handle should take 2 arguments`)
	}
	// handlePOST passes the decoded input as one value, which reflect's Call
	// cannot bind to a variadic parameter; reject it here rather than panic
	// on every request.
	if rt.IsVariadic() {
		panic(`handle should not be variadic`)
	}
	// Compare with the built-in error type itself, not its name, so that a
	// user-defined type that happens to be named "error" is not accepted.
	if rt.NumOut() != 1 || rt.Out(0) != reflect.TypeOf((*error)(nil)).Elem() {
		panic(`handle should return only error`)
	}
	if rt.In(0) != reflect.TypeOf((*Context)(nil)) {
		panic(`handle should accept Context as first argument`)
	}
}

// handlePOST adapts f, a func(*Context, T) error already validated by
// checkInterfaceHandle, into a Handle. For each request it allocates a new T,
// lets r.Reader fill it (when a Reader is set) and calls f with the value. If
// the Reader returns false with a nil error, f is skipped and nil is returned.
func handlePOST(r *Router, f interface{}) Handle {
	funcRv, inputRt := reflect.ValueOf(f), reflect.TypeOf(f).In(1)
	return func(c *Context) error {
		data := reflect.New(inputRt)
		if r.Reader != nil {
			ok, err := r.Reader(c, data.Interface())
			// Best-effort close once the Reader is done with the body; a
			// close failure is not actionable for this request.
			_ = c.Request.Body.Close()
			if err != nil {
				return err
			}
			if !ok {
				return nil
			}
		}
		out := funcRv.Call([]reflect.Value{reflect.ValueOf(c), data.Elem()})
		if out[0].IsNil() {
			return nil
		}
		return out[0].Interface().(error)
	}
}

// handleReq adapts handle into an httprouter.Handle. For every request it
// creates a Context, wraps handle in m (m[0] outermost, so it runs first) and
// passes any returned error to r.ErrorHandler.
func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle {
	return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) {
		c := NewContext(r, res, req, param)
		// The chain is rebuilt for every request, so each request gets fresh
		// closures from the middleware constructors.
		f := handle
		for i := len(m) - 1; i >= 0; i-- {
			f = m[i](f)
		}
		if err := f(c); err != nil {
			r.ErrorHandler(c, err)
		}
	}
}