package validate import ( "reflect" "strconv" "strings" "sync" ) type ValidateValuer interface { ValidateValue() interface{} } var vvType = reflect.TypeOf((*ValidateValuer)(nil)).Elem() // Validate validates a variable func Validate(v interface{}) []ValidationError { rv := reflect.ValueOf(v) return validate(rv) } // Field is always an array/slice index or struct field type Field struct { Index *int Field *reflect.StructField } // Fields is a list of Field type Fields []Field // ToString converts a list of fields to a string using the given struct tag func (f Fields) ToString(tag string) string { var field string for k, v := range f { if v.Index != nil { field += `[` + strconv.Itoa(*v.Index) + `]` continue } if k != 0 { field += `.` } var name string if tag != `` { name = v.Field.Tag.Get(tag) if idx := strings.IndexRune(name, ','); idx != -1 { name = name[:idx] } } if name == `` { name = v.Field.Name } field += name } return field } func (f Fields) String() string { return f.ToString(``) } // ValidationError contains information about a failed validation type ValidationError struct { Field Fields Check string Value string } func (e ValidationError) String() string { var val string if e.Value != `` { val = `=` + e.Value } return e.Field.String() + `: ` + e.Check + val } func prependErrs(f Field, errs []ValidationError) []ValidationError { for k := range errs { fields := make([]Field, len(errs[k].Field)+1) fields[0] = f for k, v := range errs[k].Field { fields[k+1] = v } errs[k].Field = fields } return errs } func validate(rv reflect.Value) []ValidationError { for rv.Kind() == reflect.Ptr { if rv.IsNil() { return nil } rv = rv.Elem() } if rv.Kind() == reflect.Array || rv.Kind() == reflect.Slice { var errs []ValidationError for i := 0; i < rv.Len(); i++ { newErrs := validate(rv.Index(i)) index := i errs = append(errs, prependErrs(Field{&index, nil}, newErrs)...) } return errs } if rv.Kind() != reflect.Struct { return nil } var errs []ValidationError skip := -1 for _, rule := range getCachedRules(rv.Type()) { if skip == rule.index { continue } err, cont := rule.f(rv.Field(rule.index)) errs = append(errs, err...) if !cont { skip = rule.index } } return errs } var cache = struct { sync.Mutex data map[reflect.Type][]rule }{data: map[reflect.Type][]rule{}} func getCachedRules(rt reflect.Type) []rule { cache.Lock() defer cache.Unlock() rules, ok := cache.data[rt] if !ok { rules = getRules(rt) cache.data[rt] = rules } return rules } type rule struct { index int f func(reflect.Value) ([]ValidationError, bool) } func getRules(rt reflect.Type) []rule { var rules []rule for i := 0; i < rt.NumField(); i++ { ft := rt.Field(i) var valuer bool if ft.Type.Implements(vvType) { ft.Type = reflect.TypeOf(reflect.New(ft.Type).Interface().(ValidateValuer).ValidateValue()) valuer = true } kind := simplifyKind(ft.Type.Kind()) var ptr bool if kind == reflect.Ptr { kind = ft.Type.Elem().Kind() ptr = true } kind = simplifyKind(kind) tags := strings.Split(ft.Tag.Get(`validate`), `,`) rules = append(rules, getTagFuncs(i, ft, kind, tags, ptr, valuer)...) switch kind { case reflect.Slice, reflect.Struct, reflect.Interface: rules = append(rules, rule{i, nest(ft)}) } } return rules } func nest(ft reflect.StructField) func(reflect.Value) ([]ValidationError, bool) { return func(rv reflect.Value) ([]ValidationError, bool) { errs := validate(rv) if errs != nil { return prependErrs(Field{nil, &ft}, errs), false } return nil, true } } func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string, ptr, valuer bool) []rule { var rules []rule for _, v := range tags { if v == `` { continue } parts := strings.SplitN(v, `=`, 2) tag, value := parts[0], `` if len(parts) > 1 { value = parts[1] } kind := kind ptr := ptr if ptr && (tag == `optional` || strings.TrimPrefix(tag, `!`) == `required`) { kind = reflect.Ptr ptr = false } var f validateCheck if tag == `optional` { f = func(rv reflect.Value) ([]ValidationError, bool) { check, _ := getTagFunc(`required`, ``, kind) return nil, check(rv, nil) } } else { var not bool if strings.HasPrefix(tag, `!`) { not = true } check, val := getTagFunc(strings.TrimPrefix(tag, `!`), value, kind) f = func(rv reflect.Value) ([]ValidationError, bool) { if check(rv, val) == !not { return nil, true } return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false } } if ptr { f = depointerFunc(f, ft, tag, value) } if valuer { f = valuerFunc(f) } rules = append(rules, rule{i, f}) } return rules } type validateCheck = func(rv reflect.Value) ([]ValidationError, bool) func depointerFunc(f validateCheck, ft reflect.StructField, tag, value string) validateCheck { return func(rv reflect.Value) ([]ValidationError, bool) { if rv.IsNil() { return []ValidationError{{Field: []Field{{nil, &ft}}, Check: tag, Value: value}}, false } return f(rv.Elem()) } } func valuerFunc(f validateCheck) validateCheck { return func(rv reflect.Value) ([]ValidationError, bool) { rv = reflect.ValueOf(rv.Interface().(ValidateValuer).ValidateValue()) return f(rv) } } func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) { tagInfo, ok := funcs[tag] if !ok { panic(`Unknown validation ` + tag) } check, ok := tagInfo.kinds[kind] if !ok { panic(`Validation ` + tag + ` does not support ` + kind.String()) } var val interface{} if value != `` && tagInfo.inputFunc != nil { val = tagInfo.inputFunc(kind, value) } return check, val } func simplifyKind(kind reflect.Kind) reflect.Kind { switch kind { case reflect.Array: return reflect.Slice case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return reflect.Int case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return reflect.Uint case reflect.Float32: return reflect.Float64 } return kind }