From 5778878a2a428fb89a707e35098749352243e9f6 Mon Sep 17 00:00:00 2001 From: NiseVoid Date: Tue, 16 Oct 2018 20:18:39 +0200 Subject: [PATCH] Initial commit --- bench_test.go | 35 +++++++ input.go | 41 ++++++++ rules.go | 142 +++++++++++++++++++++++++++ rules_test.go | 118 ++++++++++++++++++++++ validate.go | 251 +++++++++++++++++++++++++++++++++++++++++++++++ validate_test.go | 78 +++++++++++++++ 6 files changed, 665 insertions(+) create mode 100644 bench_test.go create mode 100644 input.go create mode 100644 rules.go create mode 100644 rules_test.go create mode 100644 validate.go create mode 100644 validate_test.go diff --git a/bench_test.go b/bench_test.go new file mode 100644 index 0000000..9e13b8d --- /dev/null +++ b/bench_test.go @@ -0,0 +1,35 @@ +package validate + +import ( + "testing" +) + +type testCase struct { + A string `validate:"required"` + B int `validate:"gt=9"` + C float64 `validate:"eq=10"` + D uint64 `validate:"required,lt=9"` +} + +var case1 = testCase{`abc`, 10, 10.0, 5} +var case2 = testCase{`abc`, 8, 10.0, 0} + +func BenchmarkValidateCorrect(b *testing.B) { + var errs []ValidationError + for i := 0; i < b.N; i++ { + errs = Validate(case1) + if len(errs) != 0 { + b.Fatal(`This case should pass but got`, errs) + } + } +} + +func BenchmarkValidateIncorrect(b *testing.B) { + var errs []ValidationError + for i := 0; i < b.N; i++ { + errs = Validate(case2) + if len(errs) != 2 { + b.Fatal(`This case should get 2 errors but got`, len(errs)) + } + } +} diff --git a/input.go b/input.go new file mode 100644 index 0000000..0acd88d --- /dev/null +++ b/input.go @@ -0,0 +1,41 @@ +package validate + +import ( + "reflect" + "regexp" + "strconv" +) + +func inputInt(kind reflect.Kind, value string) interface{} { + return int(inputSame(reflect.Int, value).(int64)) +} + +func inputRegexp(kind reflect.Kind, value string) interface{} { + return regexp.MustCompile(value) +} + +func inputSame(kind reflect.Kind, value string) interface{} { + var val interface{} + var err error + + switch kind { + case reflect.String: + val = value + case reflect.Int: + val, err = strconv.ParseInt(value, 10, 64) + case reflect.Uint: + val, err = strconv.ParseUint(value, 10, 64) + case reflect.Float64: + val, err = strconv.ParseFloat(value, 64) + case reflect.Bool: + val, err = strconv.ParseBool(value) + default: + panic(`Cannot pass value to checks on type ` + kind.String()) + } + + if err != nil { + panic(`Invalid value "` + value + `"`) + } + + return val +} diff --git a/rules.go b/rules.go new file mode 100644 index 0000000..69dddd0 --- /dev/null +++ b/rules.go @@ -0,0 +1,142 @@ +package validate + +import ( + "reflect" + "regexp" + "strings" +) + +type listFunc func(reflect.Value, interface{}) bool +type listFuncInfo struct { + inputFunc func(reflect.Kind, string) interface{} + kinds _kinds +} + +type _kinds map[reflect.Kind]listFunc + +// nolint: dupl +var funcs = map[string]listFuncInfo{ + `required`: {nil, _kinds{ + reflect.Ptr: func(rv reflect.Value, _ interface{}) bool { + return !rv.IsNil() + }, + reflect.Interface: func(rv reflect.Value, _ interface{}) bool { + return !rv.IsNil() + }, + reflect.Slice: func(rv reflect.Value, _ interface{}) bool { + return !rv.IsNil() && rv.Len() > 0 + }, + reflect.Map: func(rv reflect.Value, _ interface{}) bool { + return !rv.IsNil() && rv.Len() > 0 + }, + reflect.String: func(rv reflect.Value, _ interface{}) bool { + return rv.String() != `` + }, + reflect.Int: func(rv reflect.Value, _ interface{}) bool { + return rv.Int() != 0 + }, + reflect.Uint: func(rv reflect.Value, _ interface{}) bool { + return rv.Uint() != 0 + }, + reflect.Float64: func(rv reflect.Value, _ interface{}) bool { + return rv.Float() != 0 + }, + }}, + + // Strings + `prefix`: {inputSame, _kinds{ + reflect.String: func(rv reflect.Value, val interface{}) bool { + return strings.HasPrefix(rv.String(), val.(string)) + }, + }}, + `suffix`: {inputSame, _kinds{ + reflect.String: func(rv reflect.Value, val interface{}) bool { + return strings.HasSuffix(rv.String(), val.(string)) + }, + }}, + `contains`: {inputSame, _kinds{ + reflect.String: func(rv reflect.Value, val interface{}) bool { + return strings.Contains(rv.String(), val.(string)) + }, + }}, + `regexp`: {inputRegexp, _kinds{ + reflect.String: func(rv reflect.Value, val interface{}) bool { + return val.(*regexp.Regexp).MatchString(rv.String()) + }, + }}, + + // Comparisons + `eq`: {inputSame, _kinds{ + reflect.String: func(rv reflect.Value, val interface{}) bool { + return rv.String() == val.(string) + }, + reflect.Int: func(rv reflect.Value, val interface{}) bool { + return rv.Int() == val.(int64) + }, + reflect.Uint: func(rv reflect.Value, val interface{}) bool { + return rv.Uint() == val.(uint64) + }, + reflect.Float64: func(rv reflect.Value, val interface{}) bool { + return rv.Float() == val.(float64) + }, + }}, + + // Integers + `gt`: {inputSame, _kinds{ + reflect.Int: func(rv reflect.Value, val interface{}) bool { + return rv.Int() > val.(int64) + }, + reflect.Uint: func(rv reflect.Value, val interface{}) bool { + return rv.Uint() > val.(uint64) + }, + reflect.Float64: func(rv reflect.Value, val interface{}) bool { + return rv.Float() > val.(float64) + }, + }}, + `lt`: {inputSame, _kinds{ + reflect.Int: func(rv reflect.Value, val interface{}) bool { + return rv.Int() < val.(int64) + }, + reflect.Uint: func(rv reflect.Value, val interface{}) bool { + return rv.Uint() < val.(uint64) + }, + reflect.Float64: func(rv reflect.Value, val interface{}) bool { + return rv.Float() < val.(float64) + }, + }}, + + // Slices, maps & strings + `len`: {inputInt, _kinds{ + reflect.Slice: func(rv reflect.Value, val interface{}) bool { + return rv.Len() == val.(int) + }, + reflect.Map: func(rv reflect.Value, val interface{}) bool { + return rv.Len() == val.(int) + }, + reflect.String: func(rv reflect.Value, val interface{}) bool { + return rv.Len() == val.(int) + }, + }}, + `min`: {inputInt, _kinds{ + reflect.Slice: func(rv reflect.Value, val interface{}) bool { + return rv.Len() >= val.(int) + }, + reflect.Map: func(rv reflect.Value, val interface{}) bool { + return rv.Len() >= val.(int) + }, + reflect.String: func(rv reflect.Value, val interface{}) bool { + return rv.Len() >= val.(int) + }, + }}, + `max`: {inputInt, _kinds{ + reflect.Slice: func(rv reflect.Value, val interface{}) bool { + return rv.Len() <= val.(int) + }, + reflect.Map: func(rv reflect.Value, val interface{}) bool { + return rv.Len() <= val.(int) + }, + reflect.String: func(rv reflect.Value, val interface{}) bool { + return rv.Len() <= val.(int) + }, + }}, +} diff --git a/rules_test.go b/rules_test.go new file mode 100644 index 0000000..46e2c6b --- /dev/null +++ b/rules_test.go @@ -0,0 +1,118 @@ +package validate + +import ( + "testing" +) + +func TestRuleRequired(t *testing.T) { + type s struct { + A *string `validate:"required"` + B []int `validate:"required"` + C []int `validate:"required"` + D string `validate:"required"` + E int `validate:"required"` + F uint `validate:"required"` + G float64 `validate:"required"` + H interface{} `validate:"required"` + I map[int]int `validate:"required"` + } + + str := `` + var pass = s{&str, make([]int, 1), make([]int, 1), ` `, -1, 1, 0.01, ``, map[int]int{0: 1}} + + var fail = s{nil, nil, make([]int, 0), ``, 0, 0, 0.000, nil, nil} + + check(t, pass, 0) + check(t, fail, 9) +} + +func TestRulePrefixSuffix(t *testing.T) { + type s struct { + A string `validate:"prefix=#"` + B string `validate:"suffix=@"` + } + + var pass = s{`#a`, `a@`} + + var fail = s{`a#`, `@a`} + + check(t, pass, 0) + check(t, fail, 2) +} + +func TestRuleContains(t *testing.T) { + type s struct { + A string `validate:"contains=%"` + } + + var pass1 = s{`a%`} + var pass2 = s{`%a`} + var pass3 = s{`%`} + var pass4 = s{`a%a`} + + var fail = s{`aa`} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, pass3, 0) + check(t, pass4, 0) + check(t, fail, 1) +} + +func TestRuleRegexp(t *testing.T) { + type s struct { + A string `validate:"regexp=^[0-9]$"` + } + + var pass1 = s{`0`} + var pass2 = s{`7`} + + var fail1 = s{`A`} + var fail2 = s{`11`} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail1, 1) + check(t, fail2, 1) +} + +func TestRuleEqGtLt(t *testing.T) { + type s struct { + A int `validate:"eq=3"` + B float64 `validate:"gt=1e5"` + C uint `validate:"lt=1"` + } + + var pass = s{3, 100001, 0} + + var fail1 = s{2, 1e5, 1} + var fail2 = s{4, 9999, 2} + + check(t, pass, 0) + check(t, fail1, 3) + check(t, fail2, 3) +} + +func TestLenMinMax(t *testing.T) { + type s struct { + A string `validate:"len=3"` + B []int `validate:"min=2"` + C map[int]string `validate:"max=1"` + } + + var pass = s{`abc`, []int{1, 2}, nil} + + var fail1 = s{`ab`, []int{1}, map[int]string{1: `a`, 2: `b`}} + var fail2 = s{`abcd`, nil, nil} + + check(t, pass, 0) + check(t, fail1, 3) + check(t, fail2, 2) +} + +func check(t *testing.T, c interface{}, errCount int) { + errs := Validate(c) + if len(errs) != errCount { + t.Errorf(`Case %T(%v) should get %d errors, but got %v`, c, c, errCount, errs) + } +} diff --git a/validate.go b/validate.go new file mode 100644 index 0000000..63ed1a4 --- /dev/null +++ b/validate.go @@ -0,0 +1,251 @@ +package validate + +import ( + "reflect" + "strconv" + "strings" + "sync" +) + +// Validate validates a variable +func Validate(v interface{}) []ValidationError { + rv := reflect.ValueOf(v) + + return validate(rv) +} + +// Field is always an array/slice index or struct field +type Field struct { + Index *int + Field *reflect.StructField +} + +// Fields is a list of Field +type Fields []Field + +// ToString converts a list of fields to a string using the given struct tag +func (f Fields) ToString(tag string) string { + var field string + for k, v := range f { + if v.Index != nil { + field += `[` + strconv.Itoa(*v.Index) + `]` + continue + } + + if k != 0 { + field += `.` + } + + var name string + if tag != `` { + name = v.Field.Tag.Get(tag) + } + if name == `` { + name = v.Field.Name + } + + field += name + } + + return field +} + +func (f Fields) String() string { + return f.ToString(``) +} + +// ValidationError contains information about a failed validation +type ValidationError struct { + Field Fields + Check string + Value string +} + +func (e ValidationError) String() string { + var val string + if e.Value != `` { + val = `=` + e.Value + } + return e.Field.String() + `: ` + e.Check + val +} + +func prependErrs(f Field, errs []ValidationError) []ValidationError { + for k := range errs { + fields := make([]Field, len(errs[k].Field)+1) + + fields[0] = f + for k, v := range errs[k].Field { + fields[k+1] = v + } + + errs[k].Field = fields + } + + return errs +} + +func validate(rv reflect.Value) []ValidationError { + for rv.Kind() == reflect.Ptr { + if rv.IsNil() { + return nil + } + + rv = rv.Elem() + } + + if rv.Kind() == reflect.Array || rv.Kind() == reflect.Slice { + var errs []ValidationError + for i := 0; i < rv.Len(); i++ { + newErrs := validate(rv.Index(i)) + + index := i + errs = append(errs, prependErrs(Field{&index, nil}, newErrs)...) + } + return errs + } + + if rv.Kind() != reflect.Struct { + return nil + } + + var errs []ValidationError + skip := -1 + for _, rule := range getCachedRules(rv.Type()) { + if skip == rule.index { + continue + } + + err, cont := rule.f(rv.Field(rule.index)) + errs = append(errs, err...) + + if !cont { + skip = rule.index + } + } + + return errs +} + +var cache = struct { + sync.Mutex + data map[reflect.Type][]rule +}{data: map[reflect.Type][]rule{}} + +func getCachedRules(rt reflect.Type) []rule { + cache.Lock() + defer cache.Unlock() + + rules, ok := cache.data[rt] + if !ok { + rules = getRules(rt) + + cache.data[rt] = rules + } + + return rules +} + +type rule struct { + index int + f func(reflect.Value) ([]ValidationError, bool) +} + +func getRules(rt reflect.Type) []rule { + var rules []rule + + for i := 0; i < rt.NumField(); i++ { + ft := rt.Field(i) + kind := simplifyKind(ft.Type.Kind()) + + tags := strings.Split(ft.Tag.Get(`validate`), `,`) + rules = append(rules, getTagFuncs(i, ft, kind, tags)...) + + // TODO: Add validator interface + + switch kind { + case reflect.Slice, reflect.Struct, reflect.Interface: + rules = append(rules, rule{i, nest(ft)}) + } + } + + return rules +} + +func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool) { + return func(rv reflect.Value) ([]ValidationError, bool) { + errs := validate(rv) + if errs != nil { + return prependErrs(Field{nil, &ft}, errs), false + } + return nil, true + } +} + +func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string) []rule { + var rules []rule + for _, v := range tags { + if v == `` { + continue + } + parts := strings.SplitN(v, `=`, 2) + + tag, value := parts[0], `` + if len(parts) > 1 { + value = parts[1] + } + + if tag == `optional` { + rules = append(rules, rule{i, func(rv reflect.Value) ([]ValidationError, bool) { + check, _ := getTagFunc(`required`, ``, kind) + return nil, check(rv, nil) + }}) + continue + } + + check, val := getTagFunc(tag, value, kind) + + f := func(rv reflect.Value) ([]ValidationError, bool) { + if check(rv, val) { + return nil, true + } + + return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false + } + + rules = append(rules, rule{i, f}) + } + + return rules +} + +func getTagFunc(tag, value string, kind reflect.Kind) (listFunc, interface{}) { + tagInfo, ok := funcs[tag] + if !ok { + panic(`Unknown validation ` + tag) + } + check, ok := tagInfo.kinds[kind] + if !ok { + panic(`Validation ` + tag + ` does not support ` + kind.String()) + } + + var val interface{} + if value != `` && tagInfo.inputFunc != nil { + val = tagInfo.inputFunc(kind, value) + } + + return check, val +} + +func simplifyKind(kind reflect.Kind) reflect.Kind { + switch kind { + case reflect.Array: + return reflect.Slice + case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return reflect.Int + case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return reflect.Uint + case reflect.Float32: + return reflect.Float64 + } + return kind +} diff --git a/validate_test.go b/validate_test.go new file mode 100644 index 0000000..035b0b9 --- /dev/null +++ b/validate_test.go @@ -0,0 +1,78 @@ +package validate + +import ( + "testing" +) + +func TestOptionalMultiple(t *testing.T) { + type s struct { + A string `validate:"optional,eq=a"` + B int `validate:"gt=3,lt=20"` + } + + var pass1 = s{``, 4} + var pass2 = s{`a`, 19} + + var fail = s{`b`, 3} + + check(t, pass1, 0) + check(t, pass2, 0) + check(t, fail, 2) +} + +func TestNesting(t *testing.T) { + type sa struct { + AA string `validate:"required"` + } + + type sb struct { + BA int `validate:"gt=10"` + } + + type s struct { + A []sa `validate:"required"` + B sb + } + + var pass = s{[]sa{{`abc`}}, sb{12}} + + var fail1 = s{nil, sb{12}} + var fail2 = s{[]sa{{``}}, sb{12}} + var fail3 = s{[]sa{{``}}, sb{9}} + + check(t, pass, 0) + check(t, fail1, 1) + check(t, fail2, 1) + check(t, fail3, 2) +} + +func TestValidationErrorField(t *testing.T) { + type sc struct { + D int `validate:"eq=1"` + } + type sb struct { + C []sc + } + type sa struct { + B sb + } + type s struct { + A sa + } + + errs := Validate(s{ + A: sa{ + B: sb{ + C: []sc{ + {D: 0}, + {D: 1}, + {D: 2}, + }, + }, + }, + }) + + if errs[0].Field.String() != `A.B.C[0].D` || errs[1].Field.String() != `A.B.C[2].D` { + t.Fatal(`Expected errors to be A.B.C[0].D and A.B.C[2].D; got`, errs[0].Field, `and`, errs[1].Field) + } +}