diff --git a/input.go b/input.go index 0acd88d..b6edb38 100644 --- a/input.go +++ b/input.go @@ -6,15 +6,21 @@ import ( "strconv" ) -func inputInt(kind reflect.Kind, value string) interface{} { - return int(inputSame(reflect.Int, value).(int64)) +// InputFunc is a function that converts the parameter of a validation rule to the desired type +type InputFunc func(reflect.Kind, string) interface{} + +// InputInt always returns an int +func InputInt(kind reflect.Kind, value string) interface{} { + return int(InputSame(reflect.Int, value).(int64)) } -func inputRegexp(kind reflect.Kind, value string) interface{} { +// InputRegexp always returns a compiled regular expression +func InputRegexp(kind reflect.Kind, value string) interface{} { return regexp.MustCompile(value) } -func inputSame(kind reflect.Kind, value string) interface{} { +// InputSame returns the type matching the field +func InputSame(kind reflect.Kind, value string) interface{} { var val interface{} var err error diff --git a/rules.go b/rules.go index 69dddd0..685d0d5 100644 --- a/rules.go +++ b/rules.go @@ -6,17 +6,25 @@ import ( "strings" ) -type listFunc func(reflect.Value, interface{}) bool +// ValidationFunc is a function to validate a field +type ValidationFunc func(reflect.Value, interface{}) bool + +// Kinds is a map with validation funcs for each reflect.Kind +type Kinds map[reflect.Kind]ValidationFunc + type listFuncInfo struct { - inputFunc func(reflect.Kind, string) interface{} - kinds _kinds + inputFunc InputFunc + kinds Kinds } -type _kinds map[reflect.Kind]listFunc +// AddRule adds a rule to the list of validation functions +func AddRule(name string, inputFunc InputFunc, kinds Kinds) { + funcs[name] = listFuncInfo{inputFunc, kinds} +} // nolint: dupl var funcs = map[string]listFuncInfo{ - `required`: {nil, _kinds{ + `required`: {nil, Kinds{ reflect.Ptr: func(rv reflect.Value, _ interface{}) bool { return !rv.IsNil() }, @@ -44,29 +52,29 @@ var funcs = map[string]listFuncInfo{ }}, // Strings - `prefix`: {inputSame, _kinds{ + `prefix`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return strings.HasPrefix(rv.String(), val.(string)) }, }}, - `suffix`: {inputSame, _kinds{ + `suffix`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return strings.HasSuffix(rv.String(), val.(string)) }, }}, - `contains`: {inputSame, _kinds{ + `contains`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return strings.Contains(rv.String(), val.(string)) }, }}, - `regexp`: {inputRegexp, _kinds{ + `regexp`: {InputRegexp, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return val.(*regexp.Regexp).MatchString(rv.String()) }, }}, // Comparisons - `eq`: {inputSame, _kinds{ + `eq`: {InputSame, Kinds{ reflect.String: func(rv reflect.Value, val interface{}) bool { return rv.String() == val.(string) }, @@ -82,7 +90,7 @@ var funcs = map[string]listFuncInfo{ }}, // Integers - `gt`: {inputSame, _kinds{ + `gt`: {InputSame, Kinds{ reflect.Int: func(rv reflect.Value, val interface{}) bool { return rv.Int() > val.(int64) }, @@ -93,7 +101,7 @@ var funcs = map[string]listFuncInfo{ return rv.Float() > val.(float64) }, }}, - `lt`: {inputSame, _kinds{ + `lt`: {InputSame, Kinds{ reflect.Int: func(rv reflect.Value, val interface{}) bool { return rv.Int() < val.(int64) }, @@ -106,7 +114,7 @@ var funcs = map[string]listFuncInfo{ }}, // Slices, maps & strings - `len`: {inputInt, _kinds{ + `len`: {InputInt, Kinds{ reflect.Slice: func(rv reflect.Value, val interface{}) bool { return rv.Len() == val.(int) }, @@ -117,7 +125,7 @@ var funcs = map[string]listFuncInfo{ return rv.Len() == val.(int) }, }}, - `min`: {inputInt, _kinds{ + `min`: {InputInt, Kinds{ reflect.Slice: func(rv reflect.Value, val interface{}) bool { return rv.Len() >= val.(int) }, @@ -128,7 +136,7 @@ var funcs = map[string]listFuncInfo{ return rv.Len() >= val.(int) }, }}, - `max`: {inputInt, _kinds{ + `max`: {InputInt, Kinds{ reflect.Slice: func(rv reflect.Value, val interface{}) bool { return rv.Len() <= val.(int) }, diff --git a/rules_test.go b/rules_test.go index 46e2c6b..32d632d 100644 --- a/rules_test.go +++ b/rules_test.go @@ -1,9 +1,29 @@ package validate import ( + "reflect" "testing" ) +func TestAddRule(t *testing.T) { + type s struct { + A string `validate:"custom"` + } + + AddRule(`custom`, nil, Kinds{ + reflect.String: func(rv reflect.Value, _ interface{}) bool { + return rv.String() == `custom` + }, + }) + + var pass = s{`custom`} + + var fail = s{`somethingelse`} + + check(t, pass, 0) + check(t, fail, 1) +} + func TestRuleRequired(t *testing.T) { type s struct { A *string `validate:"required"` diff --git a/validate.go b/validate.go index 2317edb..f1cc1aa 100644 --- a/validate.go +++ b/validate.go @@ -222,7 +222,7 @@ func getTagFuncs(i int, ft reflect.StructField, kind reflect.Kind, tags []string return rules } -func getTagFunc(tag, value string, kind reflect.Kind) (listFunc, interface{}) { +func getTagFunc(tag, value string, kind reflect.Kind) (ValidationFunc, interface{}) { tagInfo, ok := funcs[tag] if !ok { panic(`Unknown validation ` + tag)