validate/validate.go
2020-11-04 18:09:07 +01:00

313 lines
6.1 KiB
Go

package validate
import (
"reflect"
"strconv"
"strings"
"sync"
)
// ValidateValuer is implemented by field types that want their checks to run
// against a substitute value. When a struct field's type implements it, the
// rules for that field are chosen from the type of the ValidateValue result
// and, at validation time, the checks receive that result rather than the
// field itself.
type ValidateValuer interface {
// ValidateValue returns the value the field's checks are applied to.
ValidateValue() interface{}
}
// vvType is the reflect.Type of the ValidateValuer interface itself, obtained
// through a nil *ValidateValuer; getRules passes it to Type.Implements.
var vvType = reflect.TypeOf((*ValidateValuer)(nil)).Elem()
// Validate checks v and returns one ValidationError per failed check.
// The result is nil when nothing failed or when v holds nothing that can be
// validated.
func Validate(v interface{}) []ValidationError {
	return validate(reflect.ValueOf(v))
}
// Field is one step on the path to a validated value: either an array/slice
// index (Index set) or a struct field (Field set). Index takes precedence
// when rendering; Field must be non-nil whenever Index is nil.
type Field struct {
	Index *int
	Field *reflect.StructField
}

// Fields is a list of Field describing the full path to a value, outermost
// step first.
type Fields []Field

// ToString renders the path as text, e.g. `Items[3].Name`.
//
// Index steps are written as `[n]`. Struct-field steps are joined with `.`
// (no separator before the very first step) and are named after the part of
// the given struct tag that precedes the first comma. If tag is empty, or the
// field has no such tag or an empty name in it, the Go field name is used.
func (f Fields) ToString(tag string) string {
	var path strings.Builder
	for pos, step := range f {
		if step.Index != nil {
			path.WriteByte('[')
			path.WriteString(strconv.Itoa(*step.Index))
			path.WriteByte(']')
			continue
		}

		if pos > 0 {
			path.WriteByte('.')
		}

		label := ``
		if tag != `` {
			label = step.Field.Tag.Get(tag)
			// Drop tag options such as `,omitempty`.
			if comma := strings.IndexByte(label, ','); comma >= 0 {
				label = label[:comma]
			}
		}
		if label == `` {
			label = step.Field.Name
		}
		path.WriteString(label)
	}
	return path.String()
}

// String renders the path using the Go field names.
func (f Fields) String() string {
	return f.ToString(``)
}
// ValidationError describes a single failed check.
type ValidationError struct {
	// Field is the path to the value that failed.
	Field Fields
	// Check is the name of the failed check as written in the struct tag.
	Check string
	// Value is the check's argument from the struct tag; empty if none.
	Value string
}

// String formats the error as `path: check`, or `path: check=value` when the
// check carries an argument. The path uses Go field names.
func (e ValidationError) String() string {
	msg := e.Field.String() + `: ` + e.Check
	if e.Value == `` {
		return msg
	}
	return msg + `=` + e.Value
}
// prependErrs puts f in front of the field path of every error in errs.
// The errs slice is updated in place and returned; each error gets a freshly
// allocated path, so the previous path slices are left untouched.
func prependErrs(f Field, errs []ValidationError) []ValidationError {
	for i := range errs {
		tail := errs[i].Field
		path := make([]Field, len(tail)+1)
		path[0] = f
		copy(path[1:], tail)
		errs[i].Field = path
	}
	return errs
}
// validate walks rv and returns every validation failure found beneath it.
//
// Pointers and interfaces are unwrapped first; a nil at any level yields no
// errors. Arrays and slices are validated element by element, and each
// element's errors get the element index prepended to their field path.
// Structs are checked against the cached rules for their type. Values of any
// other kind have nothing to validate and yield nil.
func validate(rv reflect.Value) []ValidationError {
	// Interfaces are unwrapped together with pointers: getRules registers a
	// nested rule for interface-kind fields, and without this step such a
	// value would reach the "not a struct" return below and the value it
	// holds would never be checked.
	for rv.Kind() == reflect.Ptr || rv.Kind() == reflect.Interface {
		if rv.IsNil() {
			return nil
		}
		rv = rv.Elem()
	}

	if rv.Kind() == reflect.Array || rv.Kind() == reflect.Slice {
		var errs []ValidationError
		for i := 0; i < rv.Len(); i++ {
			newErrs := validate(rv.Index(i))
			// Fresh variable per element: the Field keeps a pointer to it.
			index := i
			errs = append(errs, prependErrs(Field{&index, nil}, newErrs)...)
		}
		return errs
	}

	if rv.Kind() != reflect.Struct {
		return nil
	}

	var errs []ValidationError
	// skip holds the index of the field whose remaining rules must not run
	// because an earlier rule for that field asked to stop.
	skip := -1
	for _, rule := range getCachedRules(rv.Type()) {
		if skip == rule.index {
			continue
		}
		err, cont := rule.f(rv.Field(rule.index))
		errs = append(errs, err...)
		if !cont {
			skip = rule.index
		}
	}
	return errs
}
// cache stores the rules built for each struct type so getRules runs only
// once per type. The embedded mutex is held by getCachedRules around every
// access to data.
var cache = struct {
sync.Mutex
data map[reflect.Type][]rule
}{data: map[reflect.Type][]rule{}}
// getCachedRules returns the rules for struct type rt, building and storing
// them on first use. The cache mutex is held for the whole call, including
// while the rules are being built.
func getCachedRules(rt reflect.Type) []rule {
	cache.Lock()
	defer cache.Unlock()

	if cached, hit := cache.data[rt]; hit {
		return cached
	}

	built := getRules(rt)
	cache.data[rt] = built
	return built
}
// rule is a single check bound to one struct field.
type rule struct {
// index is the position of the field within its struct, as used with
// reflect.Value.Field.
index int
// f runs the check on the field value. It returns the errors found and
// whether later rules for the same field should still run; validate skips
// them when this is false.
f func(reflect.Value) ([]ValidationError, bool)
}
// getRules builds the validation rules for every field of struct type rt.
//
// Fields are processed in declaration order. For each one it appends a rule
// per non-empty entry of the field's `validate` struct tag (see getTagFuncs),
// followed by a rule that validates the field's contents recursively when the
// field, directly or behind one pointer, is a slice, array, struct or
// interface.
//
// If a field's type implements ValidateValuer, the kind used to select checks
// is taken from the dynamic type of the value returned by ValidateValue on a
// zero instance, and the StructField handed to the rules carries that type.
//
// It panics if a tag names an unknown check or one that does not support the
// field's kind (see getTagFunc).
func getRules(rt reflect.Type) []rule {
	var rules []rule
	for i := 0; i < rt.NumField(); i++ {
		ft := rt.Field(i)

		var valuer bool
		if ft.Type.Implements(vvType) {
			// reflect.New returns a pointer to its argument type. For a value
			// type T that is *T, whose method set includes T's methods. For a
			// pointer type *S it would be **S, which has no methods, so the
			// assertion below would panic; allocate the pointee instead to
			// get a non-nil *S.
			probe := reflect.New(ft.Type)
			if ft.Type.Kind() == reflect.Ptr {
				probe = reflect.New(ft.Type.Elem())
			}
			// NOTE(review): an interface-typed field that embeds
			// ValidateValuer still fails this assertion, and a nil result
			// from ValidateValue leaves ft.Type nil; both panic here.
			ft.Type = reflect.TypeOf(probe.Interface().(ValidateValuer).ValidateValue())
			valuer = true
		}

		kind := simplifyKind(ft.Type.Kind())
		var ptr bool
		if kind == reflect.Ptr {
			// Checks are selected by the pointee's kind; ptr tells
			// getTagFuncs the value must be dereferenced first.
			kind = ft.Type.Elem().Kind()
			ptr = true
		}
		kind = simplifyKind(kind)

		tags := strings.Split(ft.Tag.Get(`validate`), `,`)
		rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr, valuer)...)

		// simplifyKind has already folded arrays into reflect.Slice.
		switch kind {
		case reflect.Slice, reflect.Struct, reflect.Interface:
			rules = append(rules, rule{i, nest(ft)})
		}
	}
	return rules
}
// nest returns a check that validates the contents of a field recursively.
// Errors found inside get ft prepended to their field path, and the check
// then reports false so that later rules for the same field are skipped.
// With no nested errors it returns nil and true.
func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool) {
	return func(rv reflect.Value) ([]ValidationError, bool) {
		nested := validate(rv)
		if nested == nil {
			return nil, true
		}
		return prependErrs(Field{Field: &ft}, nested), false
	}
}
// getTagFuncs turns the entries of a field's `validate` tag into rules.
//
// i is the field's index in its struct, ft its StructField, kind the
// simplified kind to select checks by, ptr whether the field is a pointer to
// a value of that kind, and valuer whether the field value must first be
// replaced by its ValidateValue result.
//
// Each non-empty entry has the form `name` or `name=argument`; a leading `!`
// on the name inverts the check. Special cases:
//   - `optional` never reports an error; it runs the `required` check and
//     returns its result as the "continue" flag, so later rules for the field
//     are skipped when that check fails.
//   - For pointer fields, `optional`, `required` and `!required` are applied
//     to the pointer itself (kind reflect.Ptr); every other check is applied
//     to the pointee through depointerFunc.
//
// Checks other than `optional` are resolved immediately, so getTagFunc's
// panics for unknown or unsupported checks surface while rules are built.
func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr, valuer bool) []rule {
	var out []rule
	for _, entry := range tags {
		if entry == `` {
			continue
		}

		// Split `name=argument` at the first `=`.
		name, arg := entry, ``
		if eq := strings.Index(entry, `=`); eq >= 0 {
			name, arg = entry[:eq], entry[eq+1:]
		}
		bare := strings.TrimPrefix(name, `!`)
		negated := strings.HasPrefix(name, `!`)

		// Presence checks look at the pointer itself, not what it points to.
		checkKind, deref := kind, ptr
		if deref && (name == `optional` || bare == `required`) {
			checkKind, deref = reflect.Ptr, false
		}

		var f validateCheck
		switch name {
		case `optional`:
			f = func(rv reflect.Value) ([]ValidationError, bool) {
				present, _ := getTagFunc(`required`, ``, checkKind)
				return nil, present(rv, nil)
			}
		default:
			check, param := getTagFunc(bare, arg, checkKind)
			f = func(rv reflect.Value) ([]ValidationError, bool) {
				if check(rv, param) != negated {
					return nil, true
				}
				return []ValidationError{{Field: []Field{{nil, &ft}}, Check: name, Value: arg}}, false
			}
		}

		if deref {
			f = depointerFunc(f, ft, name, arg)
		}
		if valuer {
			f = valuerFunc(f)
		}
		out = append(out, rule{i, f})
	}
	return out
}
// validateCheck is an alias for the signature of a rule's check function:
// it receives the value to check and returns the errors found plus whether
// later rules for the same field should still run.
type validateCheck = func(rv reflect.Value) ([]ValidationError, bool)
// depointerFunc adapts f, which expects a pointee, to a pointer value.
// A nil pointer fails the check outright: the result is a single error for
// field ft with the given tag and value, and false. Otherwise f is run on
// the value the pointer refers to.
func depointerFunc(f validateCheck, ft reflect.StructField, tag, value string) validateCheck {
	return func(rv reflect.Value) ([]ValidationError, bool) {
		if !rv.IsNil() {
			return f(rv.Elem())
		}
		failure := ValidationError{Field: []Field{{nil, &ft}}, Check: tag, Value: value}
		return []ValidationError{failure}, false
	}
}
// valuerFunc adapts f to a field whose type implements ValidateValuer: the
// returned check calls ValidateValue on the field value and runs f on the
// result instead of on the field itself.
func valuerFunc(f validateCheck) validateCheck {
	return func(rv reflect.Value) ([]ValidationError, bool) {
		source := rv.Interface().(ValidateValuer)
		return f(reflect.ValueOf(source.ValidateValue()))
	}
}
// getTagFunc looks up the check registered under tag for the given kind and
// prepares its argument.
//
// The second result is the outcome of the check's inputFunc applied to value,
// or nil when value is empty or the check has no inputFunc.
//
// It panics if tag is not a registered check or the check has no
// implementation for kind.
func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) {
	info, known := funcs[tag]
	if !known {
		panic(`Unknown validation ` + tag)
	}

	check, supported := info.kinds[kind]
	if !supported {
		panic(`Validation ` + tag + ` does not support ` + kind.String())
	}

	var parsed interface{}
	if info.inputFunc != nil && value != `` {
		parsed = info.inputFunc(kind, value)
	}
	return check, parsed
}
// simplifyKind folds kinds that share a check implementation into a single
// representative: arrays become Slice, the sized signed integers become Int,
// the sized unsigned integers become Uint and Float32 becomes Float64.
// Every other kind, including Uintptr, is returned unchanged.
func simplifyKind(kind reflect.Kind) reflect.Kind {
	// The sized integer kinds are declared consecutively in package reflect,
	// so each family can be matched with a range comparison.
	switch {
	case kind == reflect.Array:
		return reflect.Slice
	case reflect.Int8 <= kind && kind <= reflect.Int64:
		return reflect.Int
	case reflect.Uint8 <= kind && kind <= reflect.Uint64:
		return reflect.Uint
	case kind == reflect.Float32:
		return reflect.Float64
	default:
		return kind
	}
}