Add ValidateValuer

This commit is contained in:
Nise Void 2020-11-04 17:07:45 +01:00
parent 43deea02af
commit 84ee3011ac
Signed by untrusted user: NiseVoid
GPG Key ID: FBA14AC83EA602F3
2 changed files with 103 additions and 43 deletions

View File

@ -7,6 +7,12 @@ import (
"sync" "sync"
) )
type ValidateValuer interface {
ValidateValue() interface{}
}
var vvType = reflect.TypeOf((*ValidateValuer)(nil)).Elem()
// Validate validates a variable // Validate validates a variable
func Validate(v interface{}) []ValidationError { func Validate(v interface{}) []ValidationError {
rv := reflect.ValueOf(v) rv := reflect.ValueOf(v)
@ -160,7 +166,14 @@ func getRules(rt reflect.Type) []rule {
for i := 0; i < rt.NumField(); i++ { for i := 0; i < rt.NumField(); i++ {
ft := rt.Field(i) ft := rt.Field(i)
kind := ft.Type.Kind() var valuer bool
if ft.Type.Implements(vvType) {
ft.Type = reflect.TypeOf(reflect.New(ft.Type).Interface().(ValidateValuer).ValidateValue())
valuer = true
}
kind := simplifyKind(ft.Type.Kind())
var ptr bool var ptr bool
if kind == reflect.Ptr { if kind == reflect.Ptr {
kind = ft.Type.Elem().Kind() kind = ft.Type.Elem().Kind()
@ -169,9 +182,7 @@ func getRules(rt reflect.Type) []rule {
kind = simplifyKind(kind) kind = simplifyKind(kind)
tags := strings.Split(ft.Tag.Get(`validate`), `,`) tags := strings.Split(ft.Tag.Get(`validate`), `,`)
rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr)...) rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr, valuer)...)
// TODO: Add validator interface
switch kind { switch kind {
case reflect.Slice, reflect.Struct, reflect.Interface: case reflect.Slice, reflect.Struct, reflect.Interface:
@ -192,7 +203,7 @@ func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool)
} }
} }
func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr bool) []rule { func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr, valuer bool) []rule {
var rules []rule var rules []rule
for _, v := range tags { for _, v := range tags {
if v == `` { if v == `` {
@ -212,14 +223,14 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string
ptr = false ptr = false
} }
var f validateCheck
if tag == `optional` { if tag == `optional` {
rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) { f = func(rv reflect.Value) ([]ValidationError, bool) {
check, _ := getTagFunc(`required`, ``, kind) check, _ := getTagFunc(`required`, ``, kind)
return nil, check(rv, nil) return nil, check(rv, nil)
}})
continue
} }
} else {
var not bool var not bool
if strings.HasPrefix(tag, `!`) { if strings.HasPrefix(tag, `!`) {
not = true not = true
@ -227,23 +238,20 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string
check, val := getTagFunc(strings.TrimPrefix(tag, `!`), value, kind) check, val := getTagFunc(strings.TrimPrefix(tag, `!`), value, kind)
f := func(rv reflect.Value) ([]ValidationError, bool) { f = func(rv reflect.Value) ([]ValidationError, bool) {
if check(rv, val) == !not { if check(rv, val) == !not {
return nil, true return nil, true
} }
return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false
} }
}
if ptr { if ptr {
oldF := f f = depointerFunc(f, ft, tag, value)
f = func(rv reflect.Value) ([]ValidationError, bool) {
if rv.IsNil() {
return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false
}
return oldF(rv.Elem())
} }
if valuer {
f = valuerFunc(f)
} }
rules = append(rules, rule{i, f}) rules = append(rules, rule{i, f})
@ -252,6 +260,25 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string
return rules return rules
} }
type validateCheck = func(rv reflect.Value) ([]ValidationError, bool)
func depointerFunc(f validateCheck, ft reflect.StructField, tag, value string) validateCheck {
return func(rv reflect.Value) ([]ValidationError, bool) {
if rv.IsNil() {
return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false
}
return f(rv.Elem())
}
}
func valuerFunc(f validateCheck) validateCheck {
return func(rv reflect.Value) ([]ValidationError, bool) {
rv = reflect.ValueOf(rv.Interface().(ValidateValuer).ValidateValue())
return f(rv)
}
}
func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) { func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) {
tagInfo, ok := funcs[tag] tagInfo, ok := funcs[tag]
if !ok { if !ok {

View File

@ -10,10 +10,10 @@ func TestOptionalMultiple(t *testing.T) {
B int `validate:"gt=3,lt=20"` B int `validate:"gt=3,lt=20"`
} }
var pass1 = s{``, 4} pass1 := s{``, 4}
var pass2 = s{`a`, 19} pass2 := s{`a`, 19}
var fail = s{`b`, 3} fail := s{`b`, 3}
check(t, pass1, 0) check(t, pass1, 0)
check(t, pass2, 0) check(t, pass2, 0)
@ -34,11 +34,11 @@ func TestNesting(t *testing.T) {
B sb B sb
} }
var pass = s{[]sa{{`abc`}}, sb{12}} pass := s{[]sa{{`abc`}}, sb{12}}
var fail1 = s{nil, sb{12}} fail1 := s{nil, sb{12}}
var fail2 = s{[]sa{{``}}, sb{12}} fail2 := s{[]sa{{``}}, sb{12}}
var fail3 = s{[]sa{{``}}, sb{9}} fail3 := s{[]sa{{``}}, sb{9}}
check(t, pass, 0) check(t, pass, 0)
check(t, fail1, 1) check(t, fail1, 1)
@ -83,10 +83,10 @@ func TestNot(t *testing.T) {
B int `validate:"!eq=3"` B int `validate:"!eq=3"`
} }
var pass1 = s{`ab`, 2} pass1 := s{`ab`, 2}
var pass2 = s{`abcd`, 4} pass2 := s{`abcd`, 4}
var fail = s{`abc`, 3} fail := s{`abc`, 3}
check(t, pass1, 0) check(t, pass1, 0)
check(t, pass2, 0) check(t, pass2, 0)
@ -107,14 +107,47 @@ func TestPtr(t *testing.T) {
two := 2 two := 2
three := 3 three := 3
var pass1 = s{&three, &three} pass1 := s{&three, &three}
var pass2 = s{&three, nil} pass2 := s{&three, nil}
var fail1 = s{&two, &two} fail1 := s{&two, &two}
var fail2 = s{nil, nil} fail2 := s{nil, nil}
check(t, pass1, 0) check(t, pass1, 0)
check(t, pass2, 0) check(t, pass2, 0)
check(t, fail1, 2) check(t, fail1, 2)
check(t, fail2, 1) check(t, fail2, 1)
} }
type val struct {
Int int
Valid bool
}
func (v val) ValidateValue() interface{} {
if !v.Valid {
return (*int)(nil)
}
return &v.Int
}
func TestValidateValuer(t *testing.T) {
type s struct {
A val `validate:"required,lt=2"`
B val `validate:"optional,eq=3"`
}
pass1 := s{val{Valid: true}, val{}}
pass2 := s{val{Valid: true}, val{Int: 3, Valid: true}}
fail1 := s{}
fail2 := s{val{Int: 2, Valid: true}, val{}}
fail3 := s{val{Valid: true}, val{Int: 2, Valid: true}}
check(t, pass1, 0)
check(t, pass2, 0)
check(t, fail1, 1)
check(t, fail2, 1)
check(t, fail3, 1)
}