From 43deea02afad4841f5592d14f9a2bec6b1319d40 Mon Sep 17 00:00:00 2001 From: NiseVoid Date: Fri, 24 Apr 2020 13:01:25 +0200 Subject: [PATCH] Support pointer values --- rules_test.go | 1 + validate.go | 31 ++++++++++++++++++++++++++++--- validate_test.go | 21 +++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/rules_test.go b/rules_test.go index 32d632d..7f2fd57 100644 --- a/rules_test.go +++ b/rules_test.go @@ -131,6 +131,7 @@ func TestLenMinMax(t *testing.T) { } func check(t *testing.T, c interface{}, errCount int) { + t.Helper() errs := Validate(c) if len(errs) != errCount { t.Errorf(`Case %T(%v) should get %d errors, but got %v`, c, c, errCount, errs) diff --git a/validate.go b/validate.go index 035d30c..50c840e 100644 --- a/validate.go +++ b/validate.go @@ -159,10 +159,17 @@ func getRules(rt reflect.Type) []rule { for i := 0; i < rt.NumField(); i++ { ft := rt.Field(i) - kind := simplifyKind(ft.Type.Kind()) + + kind := ft.Type.Kind() + var ptr bool + if kind == reflect.Ptr { + kind = ft.Type.Elem().Kind() + ptr = true + } + kind = simplifyKind(kind) tags := strings.Split(ft.Tag.Get(`validate`), `,`) - rules = append(rules, getTagFuncs(i, ft, kind, tags)...) + rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr)...) // TODO: Add validator interface @@ -185,7 +192,7 @@ func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool) } } -func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string) []rule { +func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr bool) []rule { var rules []rule for _, v := range tags { if v == `` { @@ -198,6 +205,13 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string value = parts[1] } + kind := kind + ptr := ptr + if ptr && (tag == `optional` || strings.TrimPrefix(tag, `!`) == `required`) { + kind = reflect.Ptr + ptr = false + } + if tag == `optional` { rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) { check, _ := getTagFunc(`required`, ``, kind) @@ -221,6 +235,17 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false } + if ptr { + oldF := f + f = func(rv reflect.Value) ([]ValidationError, bool) { + if rv.IsNil() { + return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + } + + return oldF(rv.Elem()) + } + } + rules = append(rules, rule{i, f}) } diff --git a/validate_test.go b/validate_test.go index 1e9b3b9..2a03c42 100644 --- a/validate_test.go +++ b/validate_test.go @@ -97,3 +97,24 @@ func TestNot(t *testing.T) { t.Errorf(`Checknames missing !, got "%s" and "%s"`, errs[0].Check, errs[1].Check) } } + +func TestPtr(t *testing.T) { + type s struct { + A *int `validate:"eq=3"` + B *int `validate:"optional,eq=3"` + } + + two := 2 + three := 3 + + var pass1 = s{&three, &three} + var pass2 = s{&three, nil} + + var fail1 = s{&two, &two} + var fail2 = s{nil, nil} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail1, 2) + check(t, fail2, 1) +}