From faf3ce92a4c84a9e01510b610312fa31202fea26 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Tue, 4 Apr 2017 00:53:46 +0200 Subject: [PATCH] don't reset all fields in Unmarshal* This preserves decoding into a struct with pre-filled, possibly unexported fields. --- genmethod.go | 10 ++-------- internal/tests/mapconv/output.go | 28 +++++++++++--------------- internal/tests/nameclash/output.go | 24 ++++++++++------------ internal/tests/omitempty/output.go | 8 ++------ internal/tests/sliceconv/output.go | 32 +++++++++++++----------------- 5 files changed, 40 insertions(+), 62 deletions(-) diff --git a/genmethod.go b/genmethod.go index 2bea522..41795a5 100644 --- a/genmethod.go +++ b/genmethod.go @@ -54,7 +54,6 @@ func genUnmarshalJSON(mtyp *marshalerType) Function { input = Name(m.scope.newIdent("input")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + "JSON")) dec = Name(m.scope.newIdent("dec")) - conv = Name(m.scope.newIdent("x")) json = Name(m.scope.parent.packageName("encoding/json")) ) fn := Function{ @@ -69,11 +68,9 @@ func genUnmarshalJSON(mtyp *marshalerType) Function { Func: Dotted{Receiver: json, Name: "Unmarshal"}, Params: []Expression{input, AddressOf{Value: dec}}, }), - Declare{Name: conv.Name, TypeName: m.mtyp.name}, }, } - fn.Body = append(fn.Body, m.unmarshalConversions(dec, conv, "json")...) - fn.Body = append(fn.Body, Assign{Lhs: Star{Value: Name(recv.Name)}, Rhs: conv}) + fn.Body = append(fn.Body, m.unmarshalConversions(dec, Name(recv.Name), "json")...) fn.Body = append(fn.Body, Return{Values: []Expression{NIL}}) return fn } @@ -123,7 +120,6 @@ func genUnmarshalLikeYAML(mtyp *marshalerType, name string) Function { unmarshal = Name(m.scope.newIdent("unmarshal")) intertyp = m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name() + name)) dec = Name(m.scope.newIdent("dec")) - conv = Name(m.scope.newIdent("x")) tag = strings.ToLower(name) ) fn := Function{ @@ -135,11 +131,9 @@ func genUnmarshalLikeYAML(mtyp *marshalerType, name string) Function { declStmt{intertyp}, Declare{Name: dec.Name, TypeName: intertyp.Name}, errCheck(CallFunction{Func: unmarshal, Params: []Expression{AddressOf{Value: dec}}}), - Declare{Name: conv.Name, TypeName: m.mtyp.name}, }, } - fn.Body = append(fn.Body, m.unmarshalConversions(dec, conv, tag)...) - fn.Body = append(fn.Body, Assign{Lhs: Star{Value: Name(recv.Name)}, Rhs: conv}) + fn.Body = append(fn.Body, m.unmarshalConversions(dec, Name(recv.Name), tag)...) fn.Body = append(fn.Body, Return{Values: []Expression{NIL}}) return fn } diff --git a/internal/tests/mapconv/output.go b/internal/tests/mapconv/output.go index 316753b..5eaffb4 100644 --- a/internal/tests/mapconv/output.go +++ b/internal/tests/mapconv/output.go @@ -106,30 +106,28 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x0 X if dec.Map == nil { return errors.New("missing required field 'map' for X") } - x0.Map = make(map[string]int, len(dec.Map)) + x.Map = make(map[string]int, len(dec.Map)) for k, v := range dec.Map { - x0.Map[string(k)] = int(v) + x.Map[string(k)] = int(v) } if dec.Named == nil { return errors.New("missing required field 'named' for X") } - x0.Named = make(namedMap, len(dec.Named)) + x.Named = make(namedMap, len(dec.Named)) for k, v := range dec.Named { - x0.Named[string(k)] = int(v) + x.Named[string(k)] = int(v) } if dec.NoConv == nil { return errors.New("missing required field 'noConv' for X") } - x0.NoConv = dec.NoConv + x.NoConv = dec.NoConv if dec.NoConvNamed == nil { return errors.New("missing required field 'noConvNamed' for X") } - x0.NoConvNamed = dec.NoConvNamed - *x = x0 + x.NoConvNamed = dec.NoConvNamed return nil } @@ -169,29 +167,27 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x0 X if dec.Map == nil { return errors.New("missing required field 'map' for X") } - x0.Map = make(map[string]int, len(dec.Map)) + x.Map = make(map[string]int, len(dec.Map)) for k, v := range dec.Map { - x0.Map[string(k)] = int(v) + x.Map[string(k)] = int(v) } if dec.Named == nil { return errors.New("missing required field 'named' for X") } - x0.Named = make(namedMap, len(dec.Named)) + x.Named = make(namedMap, len(dec.Named)) for k, v := range dec.Named { - x0.Named[string(k)] = int(v) + x.Named[string(k)] = int(v) } if dec.NoConv == nil { return errors.New("missing required field 'noConv' for X") } - x0.NoConv = dec.NoConv + x.NoConv = dec.NoConv if dec.NoConvNamed == nil { return errors.New("missing required field 'noConvNamed' for X") } - x0.NoConvNamed = dec.NoConvNamed - *x = x0 + x.NoConvNamed = dec.NoConvNamed return nil } diff --git a/internal/tests/nameclash/output.go b/internal/tests/nameclash/output.go index 614c222..d1ebc9e 100644 --- a/internal/tests/nameclash/output.go +++ b/internal/tests/nameclash/output.go @@ -87,23 +87,21 @@ func (y *Y) UnmarshalYAML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x Y if dec.Foo != nil { - x.Foo = *dec.Foo + y.Foo = *dec.Foo } if dec.Foo2 != nil { - x.Foo2 = *dec.Foo2 + y.Foo2 = *dec.Foo2 } if dec.Bar != nil { - x.Bar = *dec.Bar + y.Bar = *dec.Bar } if dec.Gazonk != nil { - x.Gazonk = *dec.Gazonk + y.Gazonk = *dec.Gazonk } if dec.Over != nil { - x.Over = int(*dec.Over) + y.Over = int(*dec.Over) } - *y = x return nil } @@ -136,22 +134,20 @@ func (y *Y) UnmarshalTOML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x Y if dec.Foo != nil { - x.Foo = *dec.Foo + y.Foo = *dec.Foo } if dec.Foo2 != nil { - x.Foo2 = *dec.Foo2 + y.Foo2 = *dec.Foo2 } if dec.Bar != nil { - x.Bar = *dec.Bar + y.Bar = *dec.Bar } if dec.Gazonk != nil { - x.Gazonk = *dec.Gazonk + y.Gazonk = *dec.Gazonk } if dec.Over != nil { - x.Over = int(*dec.Over) + y.Over = int(*dec.Over) } - *y = x return nil } diff --git a/internal/tests/omitempty/output.go b/internal/tests/omitempty/output.go index 7c7f8c5..d2a3770 100644 --- a/internal/tests/omitempty/output.go +++ b/internal/tests/omitempty/output.go @@ -50,12 +50,10 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x0 X if dec.Int == nil { return errors.New("missing required field 'int' for X") } - x0.Int = int(*dec.Int) - *x = x0 + x.Int = int(*dec.Int) return nil } @@ -76,11 +74,9 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x0 X if dec.Int == nil { return errors.New("missing required field 'int' for X") } - x0.Int = int(*dec.Int) - *x = x0 + x.Int = int(*dec.Int) return nil } diff --git a/internal/tests/sliceconv/output.go b/internal/tests/sliceconv/output.go index 47417d8..3e14114 100644 --- a/internal/tests/sliceconv/output.go +++ b/internal/tests/sliceconv/output.go @@ -116,34 +116,32 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x0 X if dec.Slice == nil { return errors.New("missing required field 'slice' for X") } - x0.Slice = make([]int, len(dec.Slice)) + x.Slice = make([]int, len(dec.Slice)) for k, v := range dec.Slice { - x0.Slice[k] = int(v) + x.Slice[k] = int(v) } if dec.Named == nil { return errors.New("missing required field 'named' for X") } - x0.Named = make(namedSlice, len(dec.Named)) + x.Named = make(namedSlice, len(dec.Named)) for k, v := range dec.Named { - x0.Named[k] = int(v) + x.Named[k] = int(v) } if dec.ByteString == nil { return errors.New("missing required field 'byteString' for X") } - x0.ByteString = string(dec.ByteString) + x.ByteString = string(dec.ByteString) if dec.NoConv == nil { return errors.New("missing required field 'noConv' for X") } - x0.NoConv = dec.NoConv + x.NoConv = dec.NoConv if dec.NoConvNamed == nil { return errors.New("missing required field 'noConvNamed' for X") } - x0.NoConvNamed = dec.NoConvNamed - *x = x0 + x.NoConvNamed = dec.NoConvNamed return nil } @@ -186,33 +184,31 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { if err := unmarshal(&dec); err != nil { return err } - var x0 X if dec.Slice == nil { return errors.New("missing required field 'slice' for X") } - x0.Slice = make([]int, len(dec.Slice)) + x.Slice = make([]int, len(dec.Slice)) for k, v := range dec.Slice { - x0.Slice[k] = int(v) + x.Slice[k] = int(v) } if dec.Named == nil { return errors.New("missing required field 'named' for X") } - x0.Named = make(namedSlice, len(dec.Named)) + x.Named = make(namedSlice, len(dec.Named)) for k, v := range dec.Named { - x0.Named[k] = int(v) + x.Named[k] = int(v) } if dec.ByteString == nil { return errors.New("missing required field 'byteString' for X") } - x0.ByteString = string(dec.ByteString) + x.ByteString = string(dec.ByteString) if dec.NoConv == nil { return errors.New("missing required field 'noConv' for X") } - x0.NoConv = dec.NoConv + x.NoConv = dec.NoConv if dec.NoConvNamed == nil { return errors.New("missing required field 'noConvNamed' for X") } - x0.NoConvNamed = dec.NoConvNamed - *x = x0 + x.NoConvNamed = dec.NoConvNamed return nil }