diff --git a/validate.go b/validate.go index 50c840e..59216c8 100644 --- a/validate.go +++ b/validate.go @@ -7,6 +7,12 @@ import ( "sync" ) +type ValidateValuer interface { + ValidateValue() interface{} +} + +var vvType = reflect.TypeOf((*ValidateValuer)(nil)).Elem() + // Validate validates a variable func Validate(v interface{}) []ValidationError { rv := reflect.ValueOf(v) @@ -160,7 +166,14 @@ func getRules(rt reflect.Type) []rule { for i := 0; i < rt.NumField(); i++ { ft := rt.Field(i) - kind := ft.Type.Kind() + var valuer bool + if ft.Type.Implements(vvType) { + ft.Type = reflect.TypeOf(reflect.New(ft.Type).Interface().(ValidateValuer).ValidateValue()) + valuer = true + } + + kind := simplifyKind(ft.Type.Kind()) + var ptr bool if kind == reflect.Ptr { kind = ft.Type.Elem().Kind() @@ -169,9 +182,7 @@ func getRules(rt reflect.Type) []rule { kind = simplifyKind(kind) tags := strings.Split(ft.Tag.Get(`validate`), `,`) - rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr)...) - - // TODO: Add validator interface + rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr, valuer)...) switch kind { case reflect.Slice, reflect.Struct, reflect.Interface: @@ -192,7 +203,7 @@ func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool) } } -func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr bool) []rule { +func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr, valuer bool) []rule { var rules []rule for _, v := range tags { if v == `` { @@ -212,38 +223,35 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string ptr = false } + var f validateCheck + if tag == `optional` { - rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) { + f = func(rv reflect.Value) ([]ValidationError, bool) { check, _ := getTagFunc(`required`, ``, kind) return nil, check(rv, nil) - }}) - continue - } - - var not bool - if strings.HasPrefix(tag, `!`) { - not = true - } - - check, val := getTagFunc(strings.TrimPrefix(tag, `!`), value, kind) - - f := func(rv reflect.Value) ([]ValidationError, bool) { - if check(rv, val) == !not { - return nil, true + } + } else { + var not bool + if strings.HasPrefix(tag, `!`) { + not = true } - return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + check, val := getTagFunc(strings.TrimPrefix(tag, `!`), value, kind) + + f = func(rv reflect.Value) ([]ValidationError, bool) { + if check(rv, val) == !not { + return nil, true + } + + return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + } } if ptr { - oldF := f - f = func(rv reflect.Value) ([]ValidationError, bool) { - if rv.IsNil() { - return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false - } - - return oldF(rv.Elem()) - } + f = depointerFunc(f, ft, tag, value) + } + if valuer { + f = valuerFunc(f) } rules = append(rules, rule{i, f}) @@ -252,6 +260,25 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string return rules } +type validateCheck = func(rv reflect.Value) ([]ValidationError, bool) + +func depointerFunc(f validateCheck, ft reflect.StructField, tag, value string) validateCheck { + return func(rv reflect.Value) ([]ValidationError, bool) { + if rv.IsNil() { + return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + } + + return f(rv.Elem()) + } +} + +func valuerFunc(f validateCheck) validateCheck { + return func(rv reflect.Value) ([]ValidationError, bool) { + rv = reflect.ValueOf(rv.Interface().(ValidateValuer).ValidateValue()) + return f(rv) + } +} + func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) { tagInfo, ok := funcs[tag] if !ok { diff --git a/validate_test.go b/validate_test.go index 2a03c42..96a83b4 100644 --- a/validate_test.go +++ b/validate_test.go @@ -10,10 +10,10 @@ func TestOptionalMultiple(t *testing.T) { B int `validate:"gt=3,lt=20"` } - var pass1 = s{``, 4} - var pass2 = s{`a`, 19} + pass1 := s{``, 4} + pass2 := s{`a`, 19} - var fail = s{`b`, 3} + fail := s{`b`, 3} check(t, pass1, 0) check(t, pass2, 0) @@ -34,11 +34,11 @@ func TestNesting(t *testing.T) { B sb } - var pass = s{[]sa{{`abc`}}, sb{12}} + pass := s{[]sa{{`abc`}}, sb{12}} - var fail1 = s{nil, sb{12}} - var fail2 = s{[]sa{{``}}, sb{12}} - var fail3 = s{[]sa{{``}}, sb{9}} + fail1 := s{nil, sb{12}} + fail2 := s{[]sa{{``}}, sb{12}} + fail3 := s{[]sa{{``}}, sb{9}} check(t, pass, 0) check(t, fail1, 1) @@ -83,10 +83,10 @@ func TestNot(t *testing.T) { B int `validate:"!eq=3"` } - var pass1 = s{`ab`, 2} - var pass2 = s{`abcd`, 4} + pass1 := s{`ab`, 2} + pass2 := s{`abcd`, 4} - var fail = s{`abc`, 3} + fail := s{`abc`, 3} check(t, pass1, 0) check(t, pass2, 0) @@ -107,14 +107,47 @@ func TestPtr(t *testing.T) { two := 2 three := 3 - var pass1 = s{&three, &three} - var pass2 = s{&three, nil} + pass1 := s{&three, &three} + pass2 := s{&three, nil} - var fail1 = s{&two, &two} - var fail2 = s{nil, nil} + fail1 := s{&two, &two} + fail2 := s{nil, nil} check(t, pass1, 0) check(t, pass2, 0) check(t, fail1, 2) check(t, fail2, 1) } + +type val struct { + Int int + Valid bool +} + +func (v val) ValidateValue() interface{} { + if !v.Valid { + return (*int)(nil) + } + + return &v.Int +} + +func TestValidateValuer(t *testing.T) { + type s struct { + A val `validate:"required,lt=2"` + B val `validate:"optional,eq=3"` + } + + pass1 := s{val{Valid: true}, val{}} + pass2 := s{val{Valid: true}, val{Int: 3, Valid: true}} + + fail1 := s{} + fail2 := s{val{Int: 2, Valid: true}, val{}} + fail3 := s{val{Valid: true}, val{Int: 2, Valid: true}} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail1, 1) + check(t, fail2, 1) + check(t, fail3, 1) +}