disable field type pointer-wrapping for Marshal*

This makes the ,omitempty tag work correctly.
master
Felix Lange 8 years ago
parent 50a5fcc22f
commit 8c7f4ae1e5
  1. 35
      genmethod.go
  2. 40
      internal/tests/nameclash/output.go
  3. 17
      internal/tests/omitempty/input.go
  4. 21
      internal/tests/omitempty/input_test.go
  5. 60
      internal/tests/omitempty/output.go
  6. 4
      main.go
  7. 4
      main_test.go

@ -25,15 +25,17 @@ var (
type marshalMethod struct { type marshalMethod struct {
mtyp *marshalerType mtyp *marshalerType
scope *funcScope scope *funcScope
isUnmarshal bool
// cached identifiers for map, slice conversions // cached identifiers for map, slice conversions
iterKey, iterVal Var iterKey, iterVal Var
} }
func newMarshalMethod(mtyp *marshalerType) *marshalMethod { func newMarshalMethod(mtyp *marshalerType, isUnmarshal bool) *marshalMethod {
s := newFuncScope(mtyp.scope) s := newFuncScope(mtyp.scope)
return &marshalMethod{ return &marshalMethod{
mtyp: mtyp, mtyp: mtyp,
scope: newFuncScope(mtyp.scope), scope: newFuncScope(mtyp.scope),
isUnmarshal: isUnmarshal,
iterKey: Name(s.newIdent("k")), iterKey: Name(s.newIdent("k")),
iterVal: Name(s.newIdent("v")), iterVal: Name(s.newIdent("v")),
} }
@ -47,8 +49,8 @@ func writeFunction(w io.Writer, fs *token.FileSet, fn Function) {
// genUnmarshalJSON generates the UnmarshalJSON method. // genUnmarshalJSON generates the UnmarshalJSON method.
func genUnmarshalJSON(mtyp *marshalerType) Function { func genUnmarshalJSON(mtyp *marshalerType) Function {
var ( var (
m = newMarshalMethod(mtyp) m = newMarshalMethod(mtyp, true)
recv = m.receiver(true) recv = m.receiver()
input = Name(m.scope.newIdent("input")) input = Name(m.scope.newIdent("input"))
intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON"))
dec = Name(m.scope.newIdent("dec")) dec = Name(m.scope.newIdent("dec"))
@ -79,8 +81,8 @@ func genUnmarshalJSON(mtyp *marshalerType) Function {
// genMarshalJSON generates the MarshalJSON method. // genMarshalJSON generates the MarshalJSON method.
func genMarshalJSON(mtyp *marshalerType) Function { func genMarshalJSON(mtyp *marshalerType) Function {
var ( var (
m = newMarshalMethod(mtyp) m = newMarshalMethod(mtyp, false)
recv = m.receiver(false) recv = m.receiver()
intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON"))
enc = Name(m.scope.newIdent("enc")) enc = Name(m.scope.newIdent("enc"))
json = Name(m.scope.parent.packageName("encoding/json")) json = Name(m.scope.parent.packageName("encoding/json"))
@ -107,8 +109,8 @@ func genMarshalJSON(mtyp *marshalerType) Function {
// genUnmarshalYAML generates the UnmarshalYAML method. // genUnmarshalYAML generates the UnmarshalYAML method.
func genUnmarshalYAML(mtyp *marshalerType) Function { func genUnmarshalYAML(mtyp *marshalerType) Function {
var ( var (
m = newMarshalMethod(mtyp) m = newMarshalMethod(mtyp, true)
recv = m.receiver(true) recv = m.receiver()
unmarshal = Name(m.scope.newIdent("unmarshal")) unmarshal = Name(m.scope.newIdent("unmarshal"))
intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "YAML")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "YAML"))
dec = Name(m.scope.newIdent("dec")) dec = Name(m.scope.newIdent("dec"))
@ -135,8 +137,8 @@ func genUnmarshalYAML(mtyp *marshalerType) Function {
// genMarshalYAML generates the MarshalYAML method. // genMarshalYAML generates the MarshalYAML method.
func genMarshalYAML(mtyp *marshalerType) Function { func genMarshalYAML(mtyp *marshalerType) Function {
var ( var (
m = newMarshalMethod(mtyp) m = newMarshalMethod(mtyp, false)
recv = m.receiver(false) recv = m.receiver()
intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "YAML")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "YAML"))
enc = Name(m.scope.newIdent("enc")) enc = Name(m.scope.newIdent("enc"))
) )
@ -154,10 +156,10 @@ func genMarshalYAML(mtyp *marshalerType) Function {
return fn return fn
} }
func (m *marshalMethod) receiver(pointer bool) Receiver { func (m *marshalMethod) receiver() Receiver {
letter := strings.ToLower(m.mtyp.name[:1]) letter := strings.ToLower(m.mtyp.name[:1])
r := Receiver{Name: m.scope.newIdent(letter), Type: Name(m.mtyp.name)} r := Receiver{Name: m.scope.newIdent(letter), Type: Name(m.mtyp.name)}
if pointer { if m.isUnmarshal {
r.Type = Star{Value: r.Type} r.Type = Star{Value: r.Type}
} }
return r return r
@ -166,9 +168,13 @@ func (m *marshalMethod) receiver(pointer bool) Receiver {
func (m *marshalMethod) intermediateType(name string) Struct { func (m *marshalMethod) intermediateType(name string) Struct {
s := Struct{Name: name} s := Struct{Name: name}
for _, f := range m.mtyp.Fields { for _, f := range m.mtyp.Fields {
typ := f.typ
if m.isUnmarshal {
typ = ensureNilCheckable(typ)
}
s.Fields = append(s.Fields, Field{ s.Fields = append(s.Fields, Field{
Name: f.name, Name: f.name,
TypeName: types.TypeString(f.typ, m.mtyp.scope.qualify), TypeName: types.TypeString(typ, m.mtyp.scope.qualify),
Tag: f.tag, Tag: f.tag,
}) })
} }
@ -179,10 +185,11 @@ func (m *marshalMethod) unmarshalConversions(from, to Var, format string) (s []S
for _, f := range m.mtyp.Fields { for _, f := range m.mtyp.Fields {
accessFrom := Dotted{Receiver: from, Name: f.name} accessFrom := Dotted{Receiver: from, Name: f.name}
accessTo := Dotted{Receiver: to, Name: f.name} accessTo := Dotted{Receiver: to, Name: f.name}
typ := ensureNilCheckable(f.typ)
if f.isOptional(format) { if f.isOptional(format) {
s = append(s, If{ s = append(s, If{
Condition: NotEqual{Lhs: accessFrom, Rhs: NIL}, Condition: NotEqual{Lhs: accessFrom, Rhs: NIL},
Body: m.convert(accessFrom, accessTo, f.typ, f.origTyp), Body: m.convert(accessFrom, accessTo, typ, f.origTyp),
}) })
} else { } else {
err := fmt.Sprintf("missing required field '%s' for %s", f.encodedName(format), m.mtyp.name) err := fmt.Sprintf("missing required field '%s' for %s", f.encodedName(format), m.mtyp.name)
@ -200,7 +207,7 @@ func (m *marshalMethod) unmarshalConversions(from, to Var, format string) (s []S
}, },
}, },
}) })
s = append(s, m.convert(accessFrom, accessTo, f.typ, f.origTyp)...) s = append(s, m.convert(accessFrom, accessTo, typ, f.origTyp)...)
} }
} }
return s return s

@ -11,18 +11,18 @@ import (
func (y Y) MarshalJSON() ([]byte, error) { func (y Y) MarshalJSON() ([]byte, error) {
type YJSON0 struct { type YJSON0 struct {
Foo *json0.Foo `optional:"true"` Foo json0.Foo `optional:"true"`
Foo2 *json0.Foo `optional:"true"` Foo2 json0.Foo `optional:"true"`
Bar *errors0.Foo `optional:"true"` Bar errors0.Foo `optional:"true"`
Gazonk *YJSON `optional:"true"` Gazonk YJSON `optional:"true"`
Over *enc `optional:"true"` Over enc `optional:"true"`
} }
var enc0 YJSON0 var enc0 YJSON0
enc0.Foo = &y.Foo enc0.Foo = y.Foo
enc0.Foo2 = &y.Foo2 enc0.Foo2 = y.Foo2
enc0.Bar = &y.Bar enc0.Bar = y.Bar
enc0.Gazonk = &y.Gazonk enc0.Gazonk = y.Gazonk
enc0.Over = (*enc)(&y.Over) enc0.Over = enc(y.Over)
return json.Marshal(&enc0) return json.Marshal(&enc0)
} }
@ -60,18 +60,18 @@ func (y *Y) UnmarshalJSON(input []byte) error {
func (y Y) MarshalYAML() (interface{}, error) { func (y Y) MarshalYAML() (interface{}, error) {
type YYAML struct { type YYAML struct {
Foo *json0.Foo `optional:"true"` Foo json0.Foo `optional:"true"`
Foo2 *json0.Foo `optional:"true"` Foo2 json0.Foo `optional:"true"`
Bar *errors0.Foo `optional:"true"` Bar errors0.Foo `optional:"true"`
Gazonk *YJSON `optional:"true"` Gazonk YJSON `optional:"true"`
Over *enc `optional:"true"` Over enc `optional:"true"`
} }
var enc0 YYAML var enc0 YYAML
enc0.Foo = &y.Foo enc0.Foo = y.Foo
enc0.Foo2 = &y.Foo2 enc0.Foo2 = y.Foo2
enc0.Bar = &y.Bar enc0.Bar = y.Bar
enc0.Gazonk = &y.Gazonk enc0.Gazonk = y.Gazonk
enc0.Over = (*enc)(&y.Over) enc0.Over = enc(y.Over)
return &enc0, nil return &enc0, nil
} }

@ -0,0 +1,17 @@
// Copyright 2017 Felix Lange <fjl@twurst.com>.
// Use of this source code is governed by the MIT license,
// which can be found in the LICENSE file.
//go:generate gencodec -type X -field-override Xo -formats json,yaml -out output.go
package omitempty
type replacedInt int
type X struct {
Int int `json:",omitempty"`
}
type Xo struct {
Int replacedInt
}

@ -0,0 +1,21 @@
// Copyright 2017 Felix Lange <fjl@twurst.com>.
// Use of this source code is governed by the MIT license,
// which can be found in the LICENSE file.
package omitempty
import (
"encoding/json"
"testing"
)
func TestOmitemptyJSON(t *testing.T) {
want := `{}`
out, err := json.Marshal(new(X))
if err != nil {
t.Fatal(err)
}
if string(out) != want {
t.Fatalf("got %#q, want %#q", string(out), want)
}
}

@ -0,0 +1,60 @@
// Code generated by github.com/fjl/gencodec. DO NOT EDIT.
package omitempty
import (
"encoding/json"
"errors"
)
func (x X) MarshalJSON() ([]byte, error) {
type XJSON struct {
Int replacedInt `json:",omitempty"`
}
var enc XJSON
enc.Int = replacedInt(x.Int)
return json.Marshal(&enc)
}
func (x *X) UnmarshalJSON(input []byte) error {
type XJSON struct {
Int *replacedInt `json:",omitempty"`
}
var dec XJSON
if err := json.Unmarshal(input, &dec); err != nil {
return err
}
var x0 X
if dec.Int == nil {
return errors.New("missing required field 'int' for X")
}
x0.Int = int(*dec.Int)
*x = x0
return nil
}
func (x X) MarshalYAML() (interface{}, error) {
type XYAML struct {
Int replacedInt `json:",omitempty"`
}
var enc XYAML
enc.Int = replacedInt(x.Int)
return &enc, nil
}
func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error {
type XYAML struct {
Int *replacedInt `json:",omitempty"`
}
var dec XYAML
if err := unmarshal(&dec); err != nil {
return err
}
var x0 X
if dec.Int == nil {
return errors.New("missing required field 'int' for X")
}
x0.Int = int(*dec.Int)
*x = x0
return nil
}

@ -284,7 +284,7 @@ func newMarshalerType(fs *token.FileSet, imp types.Importer, typ *types.Named) *
} }
mf := &marshalerField{ mf := &marshalerField{
name: f.Name(), name: f.Name(),
typ: ensureNilCheckable(f.Type()), typ: f.Type(),
origTyp: f.Type(), origTyp: f.Type(),
tag: styp.Tag(i), tag: styp.Tag(i),
} }
@ -312,7 +312,7 @@ func (mtyp *marshalerType) loadOverrides(otypename string, otyp *types.Struct) e
if err := checkConvertible(of.Type(), f.origTyp); err != nil { if err := checkConvertible(of.Type(), f.origTyp); err != nil {
return fmt.Errorf("%v: invalid field override: %v", mtyp.fs.Position(of.Pos()), err) return fmt.Errorf("%v: invalid field override: %v", mtyp.fs.Position(of.Pos()), err)
} }
f.typ = ensureNilCheckable(of.Type()) f.typ = of.Type()
} }
mtyp.scope.addReferences(otyp) mtyp.scope.addReferences(otyp)
return nil return nil

@ -29,6 +29,10 @@ func TestNameClash(t *testing.T) {
runGoldenTest(t, Config{Dir: "nameclash", Type: "Y", FieldOverride: "Yo", Formats: AllFormats}) runGoldenTest(t, Config{Dir: "nameclash", Type: "Y", FieldOverride: "Yo", Formats: AllFormats})
} }
func TestOmitempty(t *testing.T) {
runGoldenTest(t, Config{Dir: "omitempty", Type: "X", FieldOverride: "Xo", Formats: AllFormats})
}
func runGoldenTest(t *testing.T, cfg Config) { func runGoldenTest(t *testing.T, cfg Config) {
cfg.Dir = filepath.Join("internal", "tests", cfg.Dir) cfg.Dir = filepath.Join("internal", "tests", cfg.Dir)
want, err := ioutil.ReadFile(filepath.Join(cfg.Dir, "output.go")) want, err := ioutil.ReadFile(filepath.Join(cfg.Dir, "output.go"))

Loading…
Cancel
Save