diff --git a/jsonenums.go b/jsonenums.go index 60346ee..7dcbfc4 100644 --- a/jsonenums.go +++ b/jsonenums.go @@ -82,6 +82,9 @@ var ( typeNames = flag.String("type", "", "comma-separated list of type names; must be set") outputPrefix = flag.String("prefix", "", "prefix to be added to the output file") outputSuffix = flag.String("suffix", "_jsonenums", "suffix to be added to the output file") + + // BSON Support + bsonMode = flag.Bool("bson", false, "enable BSON-mode and generate BSON output in addition to JSON") ) func main() { @@ -109,14 +112,26 @@ func main() { log.Fatalf("parsing package: %v", err) } + funcPrefixes := []string{"JSON"} + imports := []string{"encoding/json"} + if *bsonMode { + funcPrefixes = append(funcPrefixes, "BSON") + imports = append(imports, "gopkg.in/mgo.v2/bson") + } var analysis = struct { Command string PackageName string TypesAndValues map[string][]string + + // ["JSON", "BSON"] + FuncPrefixes []string + Imports []string }{ Command: strings.Join(os.Args[1:], " "), PackageName: pkg.Name, TypesAndValues: make(map[string][]string), + FuncPrefixes: funcPrefixes, + Imports: imports, } // Run generate for each type. diff --git a/template.go b/template.go index 7c14f0c..ad2bf5b 100644 --- a/template.go +++ b/template.go @@ -6,17 +6,24 @@ package main -import "text/template" +import ( + "strings" + "text/template" +) -var generatedTmpl = template.Must(template.New("generated").Parse(` +var generatedTmpl = template.Must(template.New("generated"). + Funcs(template.FuncMap{"toLower": strings.ToLower}).Parse(` // generated by jsonenums {{.Command}}; DO NOT EDIT package {{.PackageName}} import ( - "encoding/json" +{{range .Imports}} + "{{.}}" +{{end}} "fmt" ) +{{$funcPrefixes := .FuncPrefixes}} {{range $typename, $values := .TypesAndValues}} @@ -42,6 +49,9 @@ func init() { } } +{{range $_, $funcPrefix := $funcPrefixes}} + +{{if eq $funcPrefix "JSON"}} // MarshalJSON is generated so {{$typename}} satisfies json.Marshaler. func (r {{$typename}}) MarshalJSON() ([]byte, error) { if s, ok := interface{}(r).(fmt.Stringer); ok { @@ -51,7 +61,7 @@ func (r {{$typename}}) MarshalJSON() ([]byte, error) { if !ok { return nil, fmt.Errorf("invalid {{$typename}}: %d", r) } - return json.Marshal(s) + return {{$funcPrefix | toLower}}.Marshal(s) } // UnmarshalJSON is generated so {{$typename}} satisfies json.Unmarshaler. @@ -67,6 +77,37 @@ func (r *{{$typename}}) UnmarshalJSON(data []byte) error { *r = v return nil } +{{else if eq $funcPrefix "BSON"}} +// GetBSON is generated so {{$typename}} satisfies bson.Marshaler. +func (r {{$typename}}) GetBSON() (interface{}, error) { + var s string + if stringer, ok := interface{}(r).(fmt.Stringer); ok { + s = stringer.String() + } + s, ok := _{{$typename}}ValueToName[r] + if !ok { + return nil, fmt.Errorf("invalid {{$typename}}: %d", r) + } + return s, nil +} + +// SetBSON is generated so {{$typename}} satisfies bson.Unmarshaler. +func (r *{{$typename}}) SetBSON(raw bson.Raw) error { + var s []byte + if err := raw.Unmarshal(&s); err != nil { + return fmt.Errorf("{{$typename}} should be a string, got %s", raw) + } + v, ok := _{{$typename}}NameToValue[string(s)] + if !ok { + return fmt.Errorf("invalid {{$typename}} %q", s) + } + *r = v + return nil +} + +{{end}} + +{{end}} {{end}} `))