package router

import (
	"context"
	"crypto/tls"
	"errors"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/julienschmidt/httprouter"
)

// route is a single registered endpoint. Routes are only handed to
// httprouter when Start or StartTLS builds the handler.
type route struct {
	Method     string
	Path       string
	Handle     Handle
	Middleware []Middleware
}

// Handle handles a request. A non-nil returned error is passed to the
// router's ErrorHandler.
type Handle = func(*Context) error

// ErrorHandle handles an error returned by a Handle, or a value recovered
// from a panic while serving a request.
type ErrorHandle func(*Context, interface{})

// BadRequestFormatter formats a bad request.
type BadRequestFormatter func(*Context, string) error

// Validator validates data.
// If validation fails it should write a response and return false.
type Validator func(*Context, interface{}) bool

// Middleware is a function that runs before your route. It receives the next
// handler as a parameter and returns the handler to run in its place.
type Middleware func(Handle) Handle

// Router collects routes and middleware and serves them over HTTP via Start
// or StartTLS. Create one with New.
type Router struct {
	routes                  []route
	Renderer                Renderer
	middleware              []Middleware
	NotFoundHandler         Handle
	MethodNotAllowedHandler Handle
	ErrorHandler            ErrorHandle
	BadRequestFormatter     BadRequestFormatter
	Validator               Validator
	TrimTrailingSlashes     bool

	// mu guards server. Start and StartTLS block while serving, so Stop is
	// normally called from a different goroutine.
	mu     sync.Mutex
	server *http.Server
}

// New returns a new Router with the package's default not-found,
// method-not-allowed, error and bad-request handlers installed.
func New() *Router {
	return &Router{
		NotFoundHandler:         defaultNotFoundHandler,
		MethodNotAllowedHandler: defaultMethodNotAllowedHandler,
		ErrorHandler:            defaultErrorHandler,
		BadRequestFormatter:     defaultBadRequestFormatter,
	}
}

// Use adds global middleware. Global middleware wraps every route (outside
// the route's own middleware) as well as the not-found and
// method-not-allowed handlers. Routes only pick up middleware registered
// before Start or StartTLS is called.
func (r *Router) Use(m ...Middleware) {
	r.middleware = append(r.middleware, m...)
}

// Group creates a new router group with a shared prefix and set of middleware.
func (r *Router) Group(prefix string, middleware ...Middleware) *Group {
	return &Group{prefix: prefix, router: r, middleware: middleware}
}

// addRoute records a route for the given HTTP method; it is registered with
// httprouter when the server starts.
func (r *Router) addRoute(method, path string, handle Handle, middleware []Middleware) {
	r.routes = append(r.routes, route{Method: method, Path: path, Handle: handle, Middleware: middleware})
}

// GET adds a GET route.
func (r *Router) GET(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodGet, path, handle, middleware)
}

// POST adds a POST route.
func (r *Router) POST(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodPost, path, handle, middleware)
}

// DELETE adds a DELETE route.
func (r *Router) DELETE(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodDelete, path, handle, middleware)
}

// PUT adds a PUT route.
func (r *Router) PUT(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodPut, path, handle, middleware)
}

// PATCH adds a PATCH route.
func (r *Router) PATCH(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodPatch, path, handle, middleware)
}

// HEAD adds a HEAD route.
func (r *Router) HEAD(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodHead, path, handle, middleware)
}

// OPTIONS adds an OPTIONS route.
func (r *Router) OPTIONS(path string, handle Handle, middleware ...Middleware) {
	r.addRoute(http.MethodOptions, path, handle, middleware)
}

// Start builds the handler from the registered routes and serves HTTP on
// addr, blocking until the server stops. After Stop it returns
// http.ErrServerClosed.
func (r *Router) Start(addr string) error {
	srv := &http.Server{Addr: addr, Handler: r.getHttpr()}
	r.mu.Lock()
	r.server = srv
	r.mu.Unlock()
	// Use the local srv: a concurrent Stop may already have cleared r.server.
	return srv.ListenAndServe()
}

// StartTLS builds the handler from the registered routes and serves HTTPS on
// addr using the given certificate and key files, blocking until the server
// stops. conf is used as the server's TLSConfig and may be nil. After Stop it
// returns http.ErrServerClosed.
func (r *Router) StartTLS(addr, certFile, keyFile string, conf *tls.Config) error {
	srv := &http.Server{Addr: addr, Handler: r.getHttpr(), TLSConfig: conf}
	r.mu.Lock()
	r.server = srv
	r.mu.Unlock()
	// Use the local srv: a concurrent Stop may already have cleared r.server.
	return srv.ListenAndServeTLS(certFile, keyFile)
}

// Stop gracefully shuts down the running server, giving it up to five seconds
// before forcibly closing the remaining connections. It is a no-op when no
// server is running.
func (r *Router) Stop() error {
	// Detach the server under the lock, but do not hold the lock during the
	// (potentially slow) shutdown.
	r.mu.Lock()
	srv := r.server
	r.server = nil
	r.mu.Unlock()
	if srv == nil {
		return nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err := srv.Shutdown(ctx)
	if errors.Is(err, context.DeadlineExceeded) {
		// Graceful shutdown timed out; drop whatever is still open.
		err = srv.Close()
	}
	return err
}

// getHttpr builds an http.Handler from the registered routes. Each route is
// wrapped with the global middleware followed by its own middleware, and the
// router's not-found, method-not-allowed and error handlers are wired into
// httprouter.
func (r *Router) getHttpr() http.Handler {
	httpr := httprouter.New()
	for _, v := range r.routes {
		// Fresh slice per route so route middleware never aliases r.middleware.
		middleware := make([]Middleware, len(r.middleware)+len(v.Middleware))
		copy(middleware, r.middleware)
		copy(middleware[len(r.middleware):], v.Middleware)
		httpr.Handle(v.Method, v.Path, handleReq(r, v.Handle, middleware))
	}

	// These read the handler fields and global middleware at request time.
	httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		handleReq(r, r.NotFoundHandler, r.middleware)(res, req, nil)
	})
	httpr.MethodNotAllowed = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		handleReq(r, r.MethodNotAllowedHandler, r.middleware)(res, req, nil)
	})
	httpr.PanicHandler = func(res http.ResponseWriter, req *http.Request, err interface{}) {
		c := NewContext(r, res, req, nil)
		r.ErrorHandler(c, err)
	}

	if !r.TrimTrailingSlashes {
		return httpr
	}

	// Strip the trailing slash ourselves instead of letting httprouter
	// redirect to the other form.
	httpr.RedirectTrailingSlash = false
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		if p := req.URL.Path; len(p) > 1 && strings.HasSuffix(p, "/") {
			req.URL.Path = p[:len(p)-1]
			// Keep RawPath in step with Path so handlers that read the
			// escaped form see the same trimmed path.
			req.URL.RawPath = strings.TrimSuffix(req.URL.RawPath, "/")
		}
		httpr.ServeHTTP(w, req)
	})
}

// handleReq adapts handle and its middleware into an httprouter.Handle.
// m[0] becomes the outermost wrapper. The chain is composed on every request,
// and a non-nil error from it is passed to r.ErrorHandler.
func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle {
	return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) {
		c := NewContext(r, res, req, param)
		f := handle
		for i := len(m) - 1; i >= 0; i-- {
			f = m[i](f)
		}
		if err := f(c); err != nil {
			r.ErrorHandler(c, err)
		}
	}
}