diff --git a/genmethod.go b/genmethod.go index 69a66a8..0339f3d 100644 --- a/genmethod.go +++ b/genmethod.go @@ -23,19 +23,21 @@ var ( ) type marshalMethod struct { - mtyp *marshalerType - scope *funcScope + mtyp *marshalerType + scope *funcScope + isUnmarshal bool // cached identifiers for map, slice conversions iterKey, iterVal Var } -func newMarshalMethod(mtyp *marshalerType) *marshalMethod { +func newMarshalMethod(mtyp *marshalerType, isUnmarshal bool) *marshalMethod { s := newFuncScope(mtyp.scope) return &marshalMethod{ - mtyp: mtyp, - scope: newFuncScope(mtyp.scope), - iterKey: Name(s.newIdent("k")), - iterVal: Name(s.newIdent("v")), + mtyp: mtyp, + scope: newFuncScope(mtyp.scope), + isUnmarshal: isUnmarshal, + iterKey: Name(s.newIdent("k")), + iterVal: Name(s.newIdent("v")), } } @@ -47,8 +49,8 @@ func writeFunction(w io.Writer, fs *token.FileSet, fn Function) { // genUnmarshalJSON generates the UnmarshalJSON method. func genUnmarshalJSON(mtyp *marshalerType) Function { var ( - m = newMarshalMethod(mtyp) - recv = m.receiver(true) + m = newMarshalMethod(mtyp, true) + recv = m.receiver() input = Name(m.scope.newIdent("input")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON")) dec = Name(m.scope.newIdent("dec")) @@ -79,8 +81,8 @@ func genUnmarshalJSON(mtyp *marshalerType) Function { // genMarshalJSON generates the MarshalJSON method. func genMarshalJSON(mtyp *marshalerType) Function { var ( - m = newMarshalMethod(mtyp) - recv = m.receiver(false) + m = newMarshalMethod(mtyp, false) + recv = m.receiver() intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON")) enc = Name(m.scope.newIdent("enc")) json = Name(m.scope.parent.packageName("encoding/json")) @@ -107,8 +109,8 @@ func genMarshalJSON(mtyp *marshalerType) Function { // genUnmarshalYAML generates the UnmarshalYAML method. func genUnmarshalYAML(mtyp *marshalerType) Function { var ( - m = newMarshalMethod(mtyp) - recv = m.receiver(true) + m = newMarshalMethod(mtyp, true) + recv = m.receiver() unmarshal = Name(m.scope.newIdent("unmarshal")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "YAML")) dec = Name(m.scope.newIdent("dec")) @@ -135,8 +137,8 @@ func genUnmarshalYAML(mtyp *marshalerType) Function { // genMarshalYAML generates the MarshalYAML method. func genMarshalYAML(mtyp *marshalerType) Function { var ( - m = newMarshalMethod(mtyp) - recv = m.receiver(false) + m = newMarshalMethod(mtyp, false) + recv = m.receiver() intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "YAML")) enc = Name(m.scope.newIdent("enc")) ) @@ -154,10 +156,10 @@ func genMarshalYAML(mtyp *marshalerType) Function { return fn } -func (m *marshalMethod) receiver(pointer bool) Receiver { +func (m *marshalMethod) receiver() Receiver { letter := strings.ToLower(m.mtyp.name[:1]) r := Receiver{Name: m.scope.newIdent(letter), Type: Name(m.mtyp.name)} - if pointer { + if m.isUnmarshal { r.Type = Star{Value: r.Type} } return r @@ -166,9 +168,13 @@ func (m *marshalMethod) receiver(pointer bool) Receiver { func (m *marshalMethod) intermediateType(name string) Struct { s := Struct{Name: name} for _, f := range m.mtyp.Fields { + typ := f.typ + if m.isUnmarshal { + typ = ensureNilCheckable(typ) + } s.Fields = append(s.Fields, Field{ Name: f.name, - TypeName: types.TypeString(f.typ, m.mtyp.scope.qualify), + TypeName: types.TypeString(typ, m.mtyp.scope.qualify), Tag: f.tag, }) } @@ -179,10 +185,11 @@ func (m *marshalMethod) unmarshalConversions(from, to Var, format string) (s []S for _, f := range m.mtyp.Fields { accessFrom := Dotted{Receiver: from, Name: f.name} accessTo := Dotted{Receiver: to, Name: f.name} + typ := ensureNilCheckable(f.typ) if f.isOptional(format) { s = append(s, If{ Condition: NotEqual{Lhs: accessFrom, Rhs: NIL}, - Body: m.convert(accessFrom, accessTo, f.typ, f.origTyp), + Body: m.convert(accessFrom, accessTo, typ, f.origTyp), }) } else { err := fmt.Sprintf("missing required field '%s' for %s", f.encodedName(format), m.mtyp.name) @@ -200,7 +207,7 @@ func (m *marshalMethod) unmarshalConversions(from, to Var, format string) (s []S }, }, }) - s = append(s, m.convert(accessFrom, accessTo, f.typ, f.origTyp)...) + s = append(s, m.convert(accessFrom, accessTo, typ, f.origTyp)...) } } return s diff --git a/internal/tests/nameclash/output.go b/internal/tests/nameclash/output.go index 480fc03..831cb49 100644 --- a/internal/tests/nameclash/output.go +++ b/internal/tests/nameclash/output.go @@ -11,18 +11,18 @@ import ( func (y Y) MarshalJSON() ([]byte, error) { type YJSON0 struct { - Foo *json0.Foo `optional:"true"` - Foo2 *json0.Foo `optional:"true"` - Bar *errors0.Foo `optional:"true"` - Gazonk *YJSON `optional:"true"` - Over *enc `optional:"true"` + Foo json0.Foo `optional:"true"` + Foo2 json0.Foo `optional:"true"` + Bar errors0.Foo `optional:"true"` + Gazonk YJSON `optional:"true"` + Over enc `optional:"true"` } var enc0 YJSON0 - enc0.Foo = &y.Foo - enc0.Foo2 = &y.Foo2 - enc0.Bar = &y.Bar - enc0.Gazonk = &y.Gazonk - enc0.Over = (*enc)(&y.Over) + enc0.Foo = y.Foo + enc0.Foo2 = y.Foo2 + enc0.Bar = y.Bar + enc0.Gazonk = y.Gazonk + enc0.Over = enc(y.Over) return json.Marshal(&enc0) } @@ -60,18 +60,18 @@ func (y *Y) UnmarshalJSON(input []byte) error { func (y Y) MarshalYAML() (interface{}, error) { type YYAML struct { - Foo *json0.Foo `optional:"true"` - Foo2 *json0.Foo `optional:"true"` - Bar *errors0.Foo `optional:"true"` - Gazonk *YJSON `optional:"true"` - Over *enc `optional:"true"` + Foo json0.Foo `optional:"true"` + Foo2 json0.Foo `optional:"true"` + Bar errors0.Foo `optional:"true"` + Gazonk YJSON `optional:"true"` + Over enc `optional:"true"` } var enc0 YYAML - enc0.Foo = &y.Foo - enc0.Foo2 = &y.Foo2 - enc0.Bar = &y.Bar - enc0.Gazonk = &y.Gazonk - enc0.Over = (*enc)(&y.Over) + enc0.Foo = y.Foo + enc0.Foo2 = y.Foo2 + enc0.Bar = y.Bar + enc0.Gazonk = y.Gazonk + enc0.Over = enc(y.Over) return &enc0, nil } diff --git a/internal/tests/omitempty/input.go b/internal/tests/omitempty/input.go new file mode 100644 index 0000000..0fc8fef --- /dev/null +++ b/internal/tests/omitempty/input.go @@ -0,0 +1,17 @@ +// Copyright 2017 Felix Lange . +// Use of this source code is governed by the MIT license, +// which can be found in the LICENSE file. + +//go:generate gencodec -type X -field-override Xo -formats json,yaml -out output.go + +package omitempty + +type replacedInt int + +type X struct { + Int int `json:",omitempty"` +} + +type Xo struct { + Int replacedInt +} diff --git a/internal/tests/omitempty/input_test.go b/internal/tests/omitempty/input_test.go new file mode 100644 index 0000000..2bce65d --- /dev/null +++ b/internal/tests/omitempty/input_test.go @@ -0,0 +1,21 @@ +// Copyright 2017 Felix Lange . +// Use of this source code is governed by the MIT license, +// which can be found in the LICENSE file. + +package omitempty + +import ( + "encoding/json" + "testing" +) + +func TestOmitemptyJSON(t *testing.T) { + want := `{}` + out, err := json.Marshal(new(X)) + if err != nil { + t.Fatal(err) + } + if string(out) != want { + t.Fatalf("got %#q, want %#q", string(out), want) + } +} diff --git a/internal/tests/omitempty/output.go b/internal/tests/omitempty/output.go new file mode 100644 index 0000000..cf6148f --- /dev/null +++ b/internal/tests/omitempty/output.go @@ -0,0 +1,60 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package omitempty + +import ( + "encoding/json" + "errors" +) + +func (x X) MarshalJSON() ([]byte, error) { + type XJSON struct { + Int replacedInt `json:",omitempty"` + } + var enc XJSON + enc.Int = replacedInt(x.Int) + return json.Marshal(&enc) +} + +func (x *X) UnmarshalJSON(input []byte) error { + type XJSON struct { + Int *replacedInt `json:",omitempty"` + } + var dec XJSON + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + var x0 X + if dec.Int == nil { + return errors.New("missing required field 'int' for X") + } + x0.Int = int(*dec.Int) + *x = x0 + return nil +} + +func (x X) MarshalYAML() (interface{}, error) { + type XYAML struct { + Int replacedInt `json:",omitempty"` + } + var enc XYAML + enc.Int = replacedInt(x.Int) + return &enc, nil +} + +func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { + type XYAML struct { + Int *replacedInt `json:",omitempty"` + } + var dec XYAML + if err := unmarshal(&dec); err != nil { + return err + } + var x0 X + if dec.Int == nil { + return errors.New("missing required field 'int' for X") + } + x0.Int = int(*dec.Int) + *x = x0 + return nil +} diff --git a/main.go b/main.go index ad2bbdb..2bdecff 100644 --- a/main.go +++ b/main.go @@ -284,7 +284,7 @@ func newMarshalerType(fs *token.FileSet, imp types.Importer, typ *types.Named) * } mf := &marshalerField{ name: f.Name(), - typ: ensureNilCheckable(f.Type()), + typ: f.Type(), origTyp: f.Type(), tag: styp.Tag(i), } @@ -312,7 +312,7 @@ func (mtyp *marshalerType) loadOverrides(otypename string, otyp *types.Struct) e if err := checkConvertible(of.Type(), f.origTyp); err != nil { return fmt.Errorf("%v: invalid field override: %v", mtyp.fs.Position(of.Pos()), err) } - f.typ = ensureNilCheckable(of.Type()) + f.typ = of.Type() } mtyp.scope.addReferences(otyp) return nil diff --git a/main_test.go b/main_test.go index f2afa62..9b2fb62 100644 --- a/main_test.go +++ b/main_test.go @@ -29,6 +29,10 @@ func TestNameClash(t *testing.T) { runGoldenTest(t, Config{Dir: "nameclash", Type: "Y", FieldOverride: "Yo", Formats: AllFormats}) } +func TestOmitempty(t *testing.T) { + runGoldenTest(t, Config{Dir: "omitempty", Type: "X", FieldOverride: "Xo", Formats: AllFormats}) +} + func runGoldenTest(t *testing.T, cfg Config) { cfg.Dir = filepath.Join("internal", "tests", cfg.Dir) want, err := ioutil.ReadFile(filepath.Join(cfg.Dir, "output.go"))