diff --git a/internal/tests/mapconv/input.go b/internal/tests/mapconv/input.go index 3542e46..990121f 100644 --- a/internal/tests/mapconv/input.go +++ b/internal/tests/mapconv/input.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the MIT license, // which can be found in the LICENSE file. -//go:generate gencodec -type X -field-override Xo -out output.go +//go:generate gencodec -type X -field-override Xo -formats json,yaml -out output.go package mapconv diff --git a/internal/tests/nameclash/input.go b/internal/tests/nameclash/input.go index 9c874d4..4e997d6 100644 --- a/internal/tests/nameclash/input.go +++ b/internal/tests/nameclash/input.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the MIT license, // which can be found in the LICENSE file. -//go:generate gencodec -type Y -field-override Yo -out output.go +//go:generate gencodec -type Y -field-override Yo -formats json,yaml -out output.go package nameclash diff --git a/internal/tests/sliceconv/input.go b/internal/tests/sliceconv/input.go index ad6ba8a..7112339 100644 --- a/internal/tests/sliceconv/input.go +++ b/internal/tests/sliceconv/input.go @@ -2,7 +2,7 @@ // Use of this source code is governed by the MIT license, // which can be found in the LICENSE file. -//go:generate gencodec -type X -field-override Xo -out output.go +//go:generate gencodec -type X -field-override Xo -formats json,yaml -out output.go package sliceconv diff --git a/main.go b/main.go index 1e48609..f8fda42 100644 --- a/main.go +++ b/main.go @@ -121,13 +121,18 @@ import ( func main() { var ( pkgdir = flag.String("dir", ".", "input package") - output = flag.String("out", "-", "output file") - typename = flag.String("type", "", "type to generate") + output = flag.String("out", "-", "output file (default is stdout)") + typename = flag.String("type", "", "type to generate methods for") overrides = flag.String("field-override", "", "type to take field type replacements from") + formats = flag.String("formats", "json", `marshaling formats (e.g. "json,yaml")`) ) flag.Parse() - cfg := Config{Dir: *pkgdir, Type: *typename, FieldOverride: *overrides} + formatList := strings.Split(*formats, ",") + for i := range formatList { + formatList[i] = strings.TrimSpace(formatList[i]) + } + cfg := Config{Dir: *pkgdir, Type: *typename, FieldOverride: *overrides, Formats: formatList} code, err := cfg.process() if err != nil { fatal(err) @@ -144,10 +149,13 @@ func fatal(args ...interface{}) { os.Exit(1) } +var AllFormats = []string{"json", "yaml"} + type Config struct { - Dir string // input package directory - Type string // type to generate methods for - FieldOverride string // name of struct type for field overrides + Dir string // input package directory + Type string // type to generate methods for + FieldOverride string // name of struct type for field overrides + Formats []string // defaults to just "json", supported: "json", "yaml" Importer types.Importer FileSet *token.FileSet } @@ -159,6 +167,9 @@ func (cfg *Config) process() (code []byte, err error) { if cfg.Importer == nil { cfg.Importer = importer.Default() } + if cfg.Formats == nil { + cfg.Formats = []string{"json"} + } pkg, err := loadPackage(cfg) if err != nil { return nil, err @@ -183,7 +194,10 @@ func (cfg *Config) process() (code []byte, err error) { // Generate and format the output. Formatting uses goimports because it // removes unused imports. - code = genPackage(mtyp) + code, err = generate(mtyp, cfg) + if err != nil { + return nil, err + } opt := &imports.Options{Comments: true, TabIndent: true, TabWidth: 8} code, err = imports.Process("", code, opt) if err != nil { @@ -211,21 +225,29 @@ func loadPackage(cfg *Config) (*types.Package, error) { return prog.Package(pkg.ImportPath).Pkg, nil } -func genPackage(mtyp *marshalerType) []byte { +func generate(mtyp *marshalerType, cfg *Config) ([]byte, error) { w := new(bytes.Buffer) fmt.Fprintln(w, "// generated by gencodec, do not edit.\n") fmt.Fprintln(w, "package", mtyp.orig.Obj().Pkg().Name()) fmt.Fprintln(w) mtyp.scope.writeImportDecl(w) fmt.Fprintln(w) - writeFunction(w, mtyp.fs, genMarshalJSON(mtyp)) - fmt.Fprintln(w) - writeFunction(w, mtyp.fs, genUnmarshalJSON(mtyp)) - fmt.Fprintln(w) - writeFunction(w, mtyp.fs, genMarshalYAML(mtyp)) - fmt.Fprintln(w) - writeFunction(w, mtyp.fs, genUnmarshalYAML(mtyp)) - return w.Bytes() + for _, format := range cfg.Formats { + switch format { + case "json": + writeFunction(w, mtyp.fs, genMarshalJSON(mtyp)) + fmt.Fprintln(w) + writeFunction(w, mtyp.fs, genUnmarshalJSON(mtyp)) + case "yaml": + writeFunction(w, mtyp.fs, genMarshalYAML(mtyp)) + fmt.Fprintln(w) + writeFunction(w, mtyp.fs, genUnmarshalYAML(mtyp)) + default: + return nil, fmt.Errorf("unknown format: %q", format) + } + fmt.Fprintln(w) + } + return w.Bytes(), nil } // marshalerType represents the intermediate struct type used during marshaling. diff --git a/main_test.go b/main_test.go index 597460a..f2afa62 100644 --- a/main_test.go +++ b/main_test.go @@ -18,15 +18,15 @@ import ( // go generate ./internal/... func TestMapConv(t *testing.T) { - runGoldenTest(t, Config{Dir: "mapconv", Type: "X", FieldOverride: "Xo"}) + runGoldenTest(t, Config{Dir: "mapconv", Type: "X", FieldOverride: "Xo", Formats: AllFormats}) } func TestSliceConv(t *testing.T) { - runGoldenTest(t, Config{Dir: "sliceconv", Type: "X", FieldOverride: "Xo"}) + runGoldenTest(t, Config{Dir: "sliceconv", Type: "X", FieldOverride: "Xo", Formats: AllFormats}) } func TestNameClash(t *testing.T) { - runGoldenTest(t, Config{Dir: "nameclash", Type: "Y", FieldOverride: "Yo"}) + runGoldenTest(t, Config{Dir: "nameclash", Type: "Y", FieldOverride: "Yo", Formats: AllFormats}) } func runGoldenTest(t *testing.T, cfg Config) {