diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..694d4e9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019, Fuyu +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/input.go b/input.go index 0acd88d..b6edb38 100644 --- a/input.go +++ b/input.go @@ -6,15 +6,21 @@ import ( "strconv" ) -func inputInt(kind reflect.Kind, value string) interface{} { - return int(inputSame(reflect.Int, value).(int64)) +// InputFunc is a function that converts the parameter of a validation rule to the desired type +type InputFunc func(reflect.Kind, string) interface{} + +// InputInt always returns an int +func InputInt(kind reflect.Kind, value string) interface{} { + return int(InputSame(reflect.Int, value).(int64)) } -func inputRegexp(kind reflect.Kind, value string) interface{} { +// InputRegexp always returns a compiled regular expression +func InputRegexp(kind reflect.Kind, value string) interface{} { return regexp.MustCompile(value) } -func inputSame(kind reflect.Kind, value string) interface{} { +// InputSame returns the type matching the field +func InputSame(kind reflect.Kind, value string) interface{} { var val interface{} var err error diff --git a/rules.go b/rules.go index 69dddd0..3656938 100644 --- a/rules.go +++ b/rules.go @@ -6,17 +6,25 @@ import ( "strings" ) -type listFunc func(reflect.Value, interface{}) bool +// ValidationFunc is a function to validate a field +type ValidationFunc func(reflect.Value, interface{}) bool + +// Kinds is a map with validation funcs for each reflect.Kind +type Kinds map[reflect.Kind]ValidationFunc + type listFuncInfo struct { - inputFunc func(reflect.Kind, string) interface{} - kinds _kinds + inputFunc InputFunc + kinds Kinds } -type _kinds map[reflect.Kind]listFunc +// AddRule adds a rule to the list of validation functions +func AddRule(name string, inputFunc InputFunc, kinds Kinds) { + funcs[name] = listFuncInfo{inputFunc, kinds} +} // nolint: dupl var funcs = map[string]listFuncInfo{ - `required`: {nil, _kinds{ + `required`: {nil, Kinds{ reflect.Ptr: func(rv reflect.Value, _ interface{}) bool { return !rv.IsNil() }, @@ -44,29 +52,29 @@ var funcs = map[string]listFuncInfo{ }}, // Strings - `prefix`: {inputSame, _kinds{ + `prefix`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return strings.HasPrefix(rv.String(), val.(string)) }, }}, - `suffix`: {inputSame, _kinds{ + `suffix`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return strings.HasSuffix(rv.String(), val.(string)) }, }}, - `contains`: {inputSame, _kinds{ + `contains`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return strings.Contains(rv.String(), val.(string)) }, }}, - `regexp`: {inputRegexp, _kinds{ + `regexp`: {InputRegexp, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return val.(*regexp.Regexp).MatchString(rv.String()) }, }}, // Comparisons - `eq`: {inputSame, _kinds{ + `eq`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return rv.String() == val.(string) }, @@ -82,7 +90,29 @@ var funcs = map[string]listFuncInfo{ }}, // Integers - `gt`: {inputSame, _kinds{ + `gte`: {InputSame, Kinds{ + reflect.Int: func(rv reflect.Value, val interface{}) bool { + return rv.Int() >= val.(int64) + }, + reflect.Uint: func(rv reflect.Value, val interface{}) bool { + return rv.Uint() >= val.(uint64) + }, + reflect.Float64: func(rv reflect.Value, val interface{}) bool { + return rv.Float() >= val.(float64) + }, + }}, + `lte`: {InputSame, Kinds{ + reflect.Int: func(rv reflect.Value, val interface{}) bool { + return rv.Int() <= val.(int64) + }, + reflect.Uint: func(rv reflect.Value, val interface{}) bool { + return rv.Uint() <= val.(uint64) + }, + reflect.Float64: func(rv reflect.Value, val interface{}) bool { + return rv.Float() <= val.(float64) + }, + }}, + `gt`: {InputSame, Kinds{ reflect.Int: func(rv reflect.Value, val interface{}) bool { return rv.Int() > val.(int64) }, @@ -93,7 +123,7 @@ var funcs = map[string]listFuncInfo{ return rv.Float() > val.(float64) }, }}, - `lt`: {inputSame, _kinds{ + `lt`: {InputSame, Kinds{ reflect.Int: func(rv reflect.Value, val interface{}) bool { return rv.Int() < val.(int64) }, @@ -106,7 +136,7 @@ var funcs = map[string]listFuncInfo{ }}, // Slices, maps & strings - `len`: {inputInt, _kinds{ + `len`: {InputInt, Kinds{ reflect.Slice: func(rv reflect.Value, val interface{}) bool { return rv.Len() == val.(int) }, @@ -117,7 +147,7 @@ var funcs = map[string]listFuncInfo{ return rv.Len() == val.(int) }, }}, - `min`: {inputInt, _kinds{ + `min`: {InputInt, Kinds{ reflect.Slice: func(rv reflect.Value, val interface{}) bool { return rv.Len() >= val.(int) }, @@ -128,7 +158,7 @@ var funcs = map[string]listFuncInfo{ return rv.Len() >= val.(int) }, }}, - `max`: {inputInt, _kinds{ + `max`: {InputInt, Kinds{ reflect.Slice: func(rv reflect.Value, val interface{}) bool { return rv.Len() <= val.(int) }, diff --git a/rules_test.go b/rules_test.go index 46e2c6b..409c8fc 100644 --- a/rules_test.go +++ b/rules_test.go @@ -1,9 +1,29 @@ package validate import ( + "reflect" "testing" ) +func TestAddRule(t *testing.T) { + type s struct { + A string `validate:"custom"` + } + + AddRule(`custom`, nil, Kinds{ + reflect.String: func(rv reflect.Value, _ interface{}) bool { + return rv.String() == `custom` + }, + }) + + pass := s{`custom`} + + fail := s{`somethingelse`} + + check(t, pass, 0) + check(t, fail, 1) +} + func TestRuleRequired(t *testing.T) { type s struct { A *string `validate:"required"` @@ -18,9 +38,9 @@ func TestRuleRequired(t *testing.T) { } str := `` - var pass = s{&str, make([]int, 1), make([]int, 1), ` `, -1, 1, 0.01, ``, map[int]int{0: 1}} + pass := s{&str, make([]int, 1), make([]int, 1), ` `, -1, 1, 0.01, ``, map[int]int{0: 1}} - var fail = s{nil, nil, make([]int, 0), ``, 0, 0, 0.000, nil, nil} + fail := s{nil, nil, make([]int, 0), ``, 0, 0, 0.000, nil, nil} check(t, pass, 0) check(t, fail, 9) @@ -32,9 +52,9 @@ func TestRulePrefixSuffix(t *testing.T) { B string `validate:"suffix=@"` } - var pass = s{`#a`, `a@`} + pass := s{`#a`, `a@`} - var fail = s{`a#`, `@a`} + fail := s{`a#`, `@a`} check(t, pass, 0) check(t, fail, 2) @@ -45,12 +65,12 @@ func TestRuleContains(t *testing.T) { A string `validate:"contains=%"` } - var pass1 = s{`a%`} - var pass2 = s{`%a`} - var pass3 = s{`%`} - var pass4 = s{`a%a`} + pass1 := s{`a%`} + pass2 := s{`%a`} + pass3 := s{`%`} + pass4 := s{`a%a`} - var fail = s{`aa`} + fail := s{`aa`} check(t, pass1, 0) check(t, pass2, 0) @@ -64,11 +84,11 @@ func TestRuleRegexp(t *testing.T) { A string `validate:"regexp=^[0-9]$"` } - var pass1 = s{`0`} - var pass2 = s{`7`} + pass1 := s{`0`} + pass2 := s{`7`} - var fail1 = s{`A`} - var fail2 = s{`11`} + fail1 := s{`A`} + fail2 := s{`11`} check(t, pass1, 0) check(t, pass2, 0) @@ -83,16 +103,46 @@ func TestRuleEqGtLt(t *testing.T) { C uint `validate:"lt=1"` } - var pass = s{3, 100001, 0} + pass := s{3, 100001, 0} - var fail1 = s{2, 1e5, 1} - var fail2 = s{4, 9999, 2} + fail1 := s{2, 1e5, 1} + fail2 := s{4, 9999, 2} check(t, pass, 0) check(t, fail1, 3) check(t, fail2, 3) } +func TestRuleGteLte(t *testing.T) { + type s struct { + U uint `validate:"gte=0,lte=10"` + I int `validate:"gte=-10,lte=0"` + F float64 `validate:"gte=0,lte=10"` + } + + pass1 := s{0, -10, 0} + pass2 := s{10, 0, 10} + + // Uint + fail1 := s{11, -10, 0} + + // Int + fail2 := s{0, -11, 0} + fail3 := s{0, 1, 0} + + // Float + fail4 := s{0, -10, -0.0001} + fail5 := s{0, -10, 10.0001} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail1, 1) + check(t, fail2, 1) + check(t, fail3, 1) + check(t, fail4, 1) + check(t, fail5, 1) +} + func TestLenMinMax(t *testing.T) { type s struct { A string `validate:"len=3"` @@ -100,10 +150,10 @@ func TestLenMinMax(t *testing.T) { C map[int]string `validate:"max=1"` } - var pass = s{`abc`, []int{1, 2}, nil} + pass := s{`abc`, []int{1, 2}, nil} - var fail1 = s{`ab`, []int{1}, map[int]string{1: `a`, 2: `b`}} - var fail2 = s{`abcd`, nil, nil} + fail1 := s{`ab`, []int{1}, map[int]string{1: `a`, 2: `b`}} + fail2 := s{`abcd`, nil, nil} check(t, pass, 0) check(t, fail1, 3) @@ -111,6 +161,7 @@ func TestLenMinMax(t *testing.T) { } func check(t *testing.T, c interface{}, errCount int) { + t.Helper() errs := Validate(c) if len(errs) != errCount { t.Errorf(`Case %T(%v) should get %d errors, but got %v`, c, c, errCount, errs) diff --git a/validate.go b/validate.go index 2317edb..59216c8 100644 --- a/validate.go +++ b/validate.go @@ -7,6 +7,12 @@ import ( "sync" ) +type ValidateValuer interface { + ValidateValue() interface{} +} + +var vvType = reflect.TypeOf((*ValidateValuer)(nil)).Elem() + // Validate validates a variable func Validate(v interface{}) []ValidationError { rv := reflect.ValueOf(v) @@ -159,12 +165,24 @@ func getRules(rt reflect.Type) []rule { for i := 0; i < rt.NumField(); i++ { ft := rt.Field(i) + + var valuer bool + if ft.Type.Implements(vvType) { + ft.Type = reflect.TypeOf(reflect.New(ft.Type).Interface().(ValidateValuer).ValidateValue()) + valuer = true + } + kind := simplifyKind(ft.Type.Kind()) - tags := strings.Split(ft.Tag.Get(`validate`), `,`) - rules = append(rules, getTagFuncs(i, ft, kind, tags)...) + var ptr bool + if kind == reflect.Ptr { + kind = ft.Type.Elem().Kind() + ptr = true + } + kind = simplifyKind(kind) - // TODO: Add validator interface + tags := strings.Split(ft.Tag.Get(`validate`), `,`) + rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr, valuer)...) switch kind { case reflect.Slice, reflect.Struct, reflect.Interface: @@ -185,7 +203,7 @@ func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool) } } -func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string) []rule { +func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr, valuer bool) []rule { var rules []rule for _, v := range tags { if v == `` { @@ -198,22 +216,42 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string value = parts[1] } - if tag == `optional` { - rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) { - check, _ := getTagFunc(`required`, ``, kind) - return nil, check(rv, nil) - }}) - continue + kind := kind + ptr := ptr + if ptr && (tag == `optional` || strings.TrimPrefix(tag, `!`) == `required`) { + kind = reflect.Ptr + ptr = false } - check, val := getTagFunc(tag, value, kind) + var f validateCheck - f := func(rv reflect.Value) ([]ValidationError, bool) { - if check(rv, val) { - return nil, true + if tag == `optional` { + f = func(rv reflect.Value) ([]ValidationError, bool) { + check, _ := getTagFunc(`required`, ``, kind) + return nil, check(rv, nil) + } + } else { + var not bool + if strings.HasPrefix(tag, `!`) { + not = true } - return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + check, val := getTagFunc(strings.TrimPrefix(tag, `!`), value, kind) + + f = func(rv reflect.Value) ([]ValidationError, bool) { + if check(rv, val) == !not { + return nil, true + } + + return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + } + } + + if ptr { + f = depointerFunc(f, ft, tag, value) + } + if valuer { + f = valuerFunc(f) } rules = append(rules, rule{i, f}) @@ -222,7 +260,26 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string return rules } -func getTagFunc(tag, value string, kind reflect.Kind) (listFunc, interface{}) { +type validateCheck = func(rv reflect.Value) ([]ValidationError, bool) + +func depointerFunc(f validateCheck, ft reflect.StructField, tag, value string) validateCheck { + return func(rv reflect.Value) ([]ValidationError, bool) { + if rv.IsNil() { + return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + } + + return f(rv.Elem()) + } +} + +func valuerFunc(f validateCheck) validateCheck { + return func(rv reflect.Value) ([]ValidationError, bool) { + rv = reflect.ValueOf(rv.Interface().(ValidateValuer).ValidateValue()) + return f(rv) + } +} + +func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) { tagInfo, ok := funcs[tag] if !ok { panic(`Unknown validation ` + tag) diff --git a/validate_test.go b/validate_test.go index 035b0b9..96a83b4 100644 --- a/validate_test.go +++ b/validate_test.go @@ -10,10 +10,10 @@ func TestOptionalMultiple(t *testing.T) { B int `validate:"gt=3,lt=20"` } - var pass1 = s{``, 4} - var pass2 = s{`a`, 19} + pass1 := s{``, 4} + pass2 := s{`a`, 19} - var fail = s{`b`, 3} + fail := s{`b`, 3} check(t, pass1, 0) check(t, pass2, 0) @@ -34,11 +34,11 @@ func TestNesting(t *testing.T) { B sb } - var pass = s{[]sa{{`abc`}}, sb{12}} + pass := s{[]sa{{`abc`}}, sb{12}} - var fail1 = s{nil, sb{12}} - var fail2 = s{[]sa{{``}}, sb{12}} - var fail3 = s{[]sa{{``}}, sb{9}} + fail1 := s{nil, sb{12}} + fail2 := s{[]sa{{``}}, sb{12}} + fail3 := s{[]sa{{``}}, sb{9}} check(t, pass, 0) check(t, fail1, 1) @@ -76,3 +76,78 @@ func TestValidationErrorField(t *testing.T) { t.Fatal(`Expected errors to be A.B.C[0].D and A.B.C[2].D; got`, errs[0].Field, `and`, errs[1].Field) } } + +func TestNot(t *testing.T) { + type s struct { + A string `validate:"!len=3"` + B int `validate:"!eq=3"` + } + + pass1 := s{`ab`, 2} + pass2 := s{`abcd`, 4} + + fail := s{`abc`, 3} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail, 2) + + errs := Validate(fail) + if errs[0].Check != `!len` || errs[1].Check != `!eq` { + t.Errorf(`Checknames missing !, got "%s" and "%s"`, errs[0].Check, errs[1].Check) + } +} + +func TestPtr(t *testing.T) { + type s struct { + A *int `validate:"eq=3"` + B *int `validate:"optional,eq=3"` + } + + two := 2 + three := 3 + + pass1 := s{&three, &three} + pass2 := s{&three, nil} + + fail1 := s{&two, &two} + fail2 := s{nil, nil} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail1, 2) + check(t, fail2, 1) +} + +type val struct { + Int int + Valid bool +} + +func (v val) ValidateValue() interface{} { + if !v.Valid { + return (*int)(nil) + } + + return &v.Int +} + +func TestValidateValuer(t *testing.T) { + type s struct { + A val `validate:"required,lt=2"` + B val `validate:"optional,eq=3"` + } + + pass1 := s{val{Valid: true}, val{}} + pass2 := s{val{Valid: true}, val{Int: 3, Valid: true}} + + fail1 := s{} + fail2 := s{val{Int: 2, Valid: true}, val{}} + fail3 := s{val{Valid: true}, val{Int: 2, Valid: true}} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail1, 1) + check(t, fail2, 1) + check(t, fail3, 1) +}