Remove reflect; Improve defaults

This commit is contained in:
Nise Void 2022-11-25 16:01:32 +01:00
parent 283c75b32e
commit b6045aa53d
6 changed files with 62 additions and 80 deletions

View File

@ -25,6 +25,13 @@ func NewContext(router *Router, res http.ResponseWriter, req *http.Request, para
return &Context{router, req, res, param.ByName, make(map[string]interface{})} return &Context{router, req, res, param.ByName, make(map[string]interface{})}
} }
func (c *Context) Validate(v interface{}) bool {
if c.router.Validator == nil {
return true
}
return c.router.Validator(c, v)
}
// QueryParam returns the specified parameter from the query string. // QueryParam returns the specified parameter from the query string.
// Returns an empty string if it doesn't exist. Returns the first parameter if multiple instances exist // Returns an empty string if it doesn't exist. Returns the first parameter if multiple instances exist
func (c *Context) QueryParam(param string) string { func (c *Context) QueryParam(param string) string {
@ -57,6 +64,10 @@ func (c *Context) String(code int, s string) error {
return err return err
} }
func (c *Context) BadRequest(reason string) error {
return c.router.BadRequestFormatter(c, reason)
}
// StatusText returns the given status code with the matching status text // StatusText returns the given status code with the matching status text
func (c *Context) StatusText(code int) error { func (c *Context) StatusText(code int) error {
return c.String(code, http.StatusText(code)) return c.String(code, http.StatusText(code))

View File

@ -2,6 +2,7 @@ package router
import ( import (
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
) )
@ -14,14 +15,27 @@ func defaultMethodNotAllowedHandler(c *Context) error {
} }
func defaultErrorHandler(c *Context, err interface{}) { func defaultErrorHandler(c *Context, err interface{}) {
fmt.Printf("%s '%s': %s\n", c.Request.Method, c.Request.URL, err)
_ = c.StatusText(http.StatusInternalServerError) _ = c.StatusText(http.StatusInternalServerError)
} }
func defaultReader(c *Context, dst interface{}) (bool, error) { func defaultBadRequestFormatter(c *Context, reason string) error {
err := json.NewDecoder(c.Request.Body).Decode(dst) return c.String(http.StatusBadRequest, reason)
if err != nil {
return false, c.StatusText(http.StatusBadRequest)
} }
return true, nil func BodyJSON[T any](handle func(*Context, T) error) Handle {
return func(c *Context) error {
var dst T
err := json.NewDecoder(c.Request.Body).Decode(&dst)
if err != nil {
return c.BadRequest(`Invalid JSON`)
}
if !c.Validate(dst) {
return nil
}
return handle(c, dst)
}
} }

5
go.mod Normal file
View File

@ -0,0 +1,5 @@
module git.fuyu.moe/Fuyu/router
go 1.19
require github.com/julienschmidt/httprouter v1.3.0

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/julienschmidt/httprouter v1.3.0 h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=

View File

@ -24,7 +24,7 @@ func (g *Group) GET(path string, handle Handle, middleware ...Middleware) {
} }
// POST adds a POST route // POST adds a POST route
func (g *Group) POST(path string, handle interface{}, middleware ...Middleware) { func (g *Group) POST(path string, handle Handle, middleware ...Middleware) {
g.router.POST(join(g.prefix, path), handle, append(g.middleware, middleware...)...) g.router.POST(join(g.prefix, path), handle, append(g.middleware, middleware...)...)
} }
@ -34,12 +34,12 @@ func (g *Group) DELETE(path string, handle Handle, middleware ...Middleware) {
} }
// PUT adds a PUT route // PUT adds a PUT route
func (g *Group) PUT(path string, handle interface{}, middleware ...Middleware) { func (g *Group) PUT(path string, handle Handle, middleware ...Middleware) {
g.router.PUT(join(g.prefix, path), handle, append(g.middleware, middleware...)...) g.router.PUT(join(g.prefix, path), handle, append(g.middleware, middleware...)...)
} }
// PATCH adds a PATCH route // PATCH adds a PATCH route
func (g *Group) PATCH(path string, handle interface{}, middleware ...Middleware) { func (g *Group) PATCH(path string, handle Handle, middleware ...Middleware) {
g.router.PATCH(join(g.prefix, path), handle, append(g.middleware, middleware...)...) g.router.PATCH(join(g.prefix, path), handle, append(g.middleware, middleware...)...)
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"net/http" "net/http"
"reflect"
"time" "time"
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
@ -13,7 +12,7 @@ import (
type route struct { type route struct {
Method string Method string
Path string Path string
Handle interface{} Handle Handle
Middleware []Middleware Middleware []Middleware
} }
@ -23,28 +22,38 @@ type Handle = func(*Context) error
// ErrorHandle handles a request // ErrorHandle handles a request
type ErrorHandle func(*Context, interface{}) type ErrorHandle func(*Context, interface{})
// BadRequestFormatter formats a bad request
type BadRequestFormatter func(*Context, string) error
// Validator validates data.
// If validation fails it should write a response and return false
type Validator func(*Context, interface{}) bool
// Middleware is a function that runs before your route, it gets the next handler as a parameter // Middleware is a function that runs before your route, it gets the next handler as a parameter
type Middleware func(Handle) Handle type Middleware func(Handle) Handle
// Reader reads input to dst, returns true if successful
type Reader func(c *Context, dst interface{}) (bool, error)
// Router is the router itself // Router is the router itself
type Router struct { type Router struct {
routes []route routes []route
Reader Reader
Renderer Renderer Renderer Renderer
middleware []Middleware middleware []Middleware
NotFoundHandler Handle NotFoundHandler Handle
MethodNotAllowedHandler Handle MethodNotAllowedHandler Handle
ErrorHandler ErrorHandle ErrorHandler ErrorHandle
BadRequestFormatter BadRequestFormatter
Validator Validator
TrimTrailingSlashes bool TrimTrailingSlashes bool
server *http.Server server *http.Server
} }
// New returns a new Router // New returns a new Router
func New() *Router { func New() *Router {
return &Router{Reader: defaultReader, NotFoundHandler: defaultNotFoundHandler, MethodNotAllowedHandler: defaultMethodNotAllowedHandler, ErrorHandler: defaultErrorHandler} return &Router{
NotFoundHandler: defaultNotFoundHandler,
MethodNotAllowedHandler: defaultMethodNotAllowedHandler,
ErrorHandler: defaultErrorHandler,
BadRequestFormatter: defaultBadRequestFormatter,
}
} }
// Use adds a global middleware // Use adds a global middleware
@ -63,8 +72,7 @@ func (r *Router) GET(path string, handle Handle, middleware ...Middleware) {
} }
// POST adds a POST route // POST adds a POST route
func (r *Router) POST(path string, handle interface{}, middleware ...Middleware) { func (r *Router) POST(path string, handle Handle, middleware ...Middleware) {
checkInterfaceHandle(handle)
r.routes = append(r.routes, route{`POST`, path, handle, middleware}) r.routes = append(r.routes, route{`POST`, path, handle, middleware})
} }
@ -74,14 +82,12 @@ func (r *Router) DELETE(path string, handle Handle, middleware ...Middleware) {
} }
// PUT adds a PUT route // PUT adds a PUT route
func (r *Router) PUT(path string, handle interface{}, middleware ...Middleware) { func (r *Router) PUT(path string, handle Handle, middleware ...Middleware) {
checkInterfaceHandle(handle)
r.routes = append(r.routes, route{`PUT`, path, handle, middleware}) r.routes = append(r.routes, route{`PUT`, path, handle, middleware})
} }
// PATCH adds a PATCH route // PATCH adds a PATCH route
func (r *Router) PATCH(path string, handle interface{}, middleware ...Middleware) { func (r *Router) PATCH(path string, handle Handle, middleware ...Middleware) {
checkInterfaceHandle(handle)
r.routes = append(r.routes, route{`PATCH`, path, handle, middleware}) r.routes = append(r.routes, route{`PATCH`, path, handle, middleware})
} }
@ -132,16 +138,11 @@ func (r *Router) getHttpr() http.Handler {
httpr := httprouter.New() httpr := httprouter.New()
for _, v := range r.routes { for _, v := range r.routes {
handle, ok := v.Handle.(Handle)
if !ok {
handle = handlePOST(r, v.Handle)
}
middleware := make([]Middleware, len(r.middleware)+len(v.Middleware)) middleware := make([]Middleware, len(r.middleware)+len(v.Middleware))
copy(middleware, r.middleware) copy(middleware, r.middleware)
copy(middleware[len(r.middleware):], v.Middleware) copy(middleware[len(r.middleware):], v.Middleware)
httpr.Handle(v.Method, v.Path, handleReq(r, handle, middleware)) httpr.Handle(v.Method, v.Path, handleReq(r, v.Handle, middleware))
} }
httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { httpr.NotFound = http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
@ -172,56 +173,6 @@ func (r *Router) getHttpr() http.Handler {
return httpr return httpr
} }
func checkInterfaceHandle(f interface{}) {
if _, ok := f.(Handle); ok {
return
}
rt := reflect.TypeOf(f)
if rt.Kind() != reflect.Func {
panic(`non-func handle`)
}
if rt.NumIn() != 2 {
panic(`handle should take 2 arguments`)
}
if rt.NumOut() != 1 || rt.Out(0).Name() != `error` {
panic(`handle should return only error`)
}
if rt.In(0) != reflect.TypeOf(&Context{}) {
panic(`handle should accept Context as first argument`)
}
}
func handlePOST(r *Router, f interface{}) Handle {
funcRv, inputRt := reflect.ValueOf(f), reflect.TypeOf(f).In(1)
return func(c *Context) error {
data := reflect.New(inputRt)
if r.Reader != nil {
ok, err := r.Reader(c, data.Interface())
_ = c.Request.Body.Close()
if err != nil {
return err
}
if !ok {
return nil
}
}
out := funcRv.Call([]reflect.Value{reflect.ValueOf(c), data.Elem()})
if out[0].IsNil() {
return nil
}
return out[0].Interface().(error)
}
}
func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle { func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle {
return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) { return func(res http.ResponseWriter, req *http.Request, param httprouter.Params) {
c := NewContext(r, res, req, param) c := NewContext(r, res, req, param)
@ -232,7 +183,6 @@ func handleReq(r *Router, handle Handle, m []Middleware) httprouter.Handle {
} }
err := f(c) err := f(c)
if err != nil { if err != nil {
r.ErrorHandler(c, err) r.ErrorHandler(c, err)
} }