Support pointer values

This commit is contained in:
Nise Void 2020-04-24 13:01:25 +02:00
parent cb4720cd31
commit 43deea02af
Signed by untrusted user: NiseVoid
GPG Key ID: FBA14AC83EA602F3
3 changed files with 50 additions and 3 deletions

View File

@ -131,6 +131,7 @@ func TestLenMinMax(t *testing.T) {
} }
func check(t *testing.T, c interface{}, errCount int) { func check(t *testing.T, c interface{}, errCount int) {
t.Helper()
errs := Validate(c) errs := Validate(c)
if len(errs) != errCount { if len(errs) != errCount {
t.Errorf(`Case %T(%v) should get %d errors, but got %v`, c, c, errCount, errs) t.Errorf(`Case %T(%v) should get %d errors, but got %v`, c, c, errCount, errs)

View File

@ -159,10 +159,17 @@ func getRules(rt reflect.Type) []rule {
for i := 0; i < rt.NumField(); i++ { for i := 0; i < rt.NumField(); i++ {
ft := rt.Field(i) ft := rt.Field(i)
kind := simplifyKind(ft.Type.Kind())
kind := ft.Type.Kind()
var ptr bool
if kind == reflect.Ptr {
kind = ft.Type.Elem().Kind()
ptr = true
}
kind = simplifyKind(kind)
tags := strings.Split(ft.Tag.Get(`validate`), `,`) tags := strings.Split(ft.Tag.Get(`validate`), `,`)
rules = append(rules, getTagFuncs(i, ft, kind, tags)...) rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr)...)
// TODO: Add validator interface // TODO: Add validator interface
@ -185,7 +192,7 @@ func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool)
} }
} }
func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string) []rule { func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr bool) []rule {
var rules []rule var rules []rule
for _, v := range tags { for _, v := range tags {
if v == `` { if v == `` {
@ -198,6 +205,13 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string
value = parts[1] value = parts[1]
} }
kind := kind
ptr := ptr
if ptr && (tag == `optional` || strings.TrimPrefix(tag, `!`) == `required`) {
kind = reflect.Ptr
ptr = false
}
if tag == `optional` { if tag == `optional` {
rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) { rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) {
check, _ := getTagFunc(`required`, ``, kind) check, _ := getTagFunc(`required`, ``, kind)
@ -221,6 +235,17 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string
return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false
} }
if ptr {
oldF := f
f = func(rv reflect.Value) ([]ValidationError, bool) {
if rv.IsNil() {
return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false
}
return oldF(rv.Elem())
}
}
rules = append(rules, rule{i, f}) rules = append(rules, rule{i, f})
} }

View File

@ -97,3 +97,24 @@ func TestNot(t *testing.T) {
t.Errorf(`Checknames missing !, got "%s" and "%s"`, errs[0].Check, errs[1].Check) t.Errorf(`Checknames missing !, got "%s" and "%s"`, errs[0].Check, errs[1].Check)
} }
} }
func TestPtr(t *testing.T) {
type s struct {
A *int `validate:"eq=3"`
B *int `validate:"optional,eq=3"`
}
two := 2
three := 3
var pass1 = s{&three, &three}
var pass2 = s{&three, nil}
var fail1 = s{&two, &two}
var fail2 = s{nil, nil}
check(t, pass1, 0)
check(t, pass2, 0)
check(t, fail1, 2)
check(t, fail2, 1)
}