From 90983d99deb1a46a57ef60a7fd16e4ccfa7409e3 Mon Sep 17 00:00:00 2001 From: Felix Lange Date: Mon, 8 Jan 2018 13:43:47 +0100 Subject: [PATCH] apply slice, map pointer wrapping to named types only --- internal/tests/mapconv/output.go | 30 ++++++++++----------- internal/tests/sliceconv/output.go | 42 +++++++++++++++--------------- typeutil.go | 10 +++++++ 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/internal/tests/mapconv/output.go b/internal/tests/mapconv/output.go index 4924846..0d9089c 100644 --- a/internal/tests/mapconv/output.go +++ b/internal/tests/mapconv/output.go @@ -43,9 +43,9 @@ func (x X) MarshalJSON() ([]byte, error) { func (x *X) UnmarshalJSON(input []byte) error { type X struct { - Map *map[replacedString]replacedInt + Map map[replacedString]replacedInt Named *namedMap2 - NoConv *map[string]int + NoConv map[string]int NoConvNamed *namedMap } var dec X @@ -53,8 +53,8 @@ func (x *X) UnmarshalJSON(input []byte) error { return err } if dec.Map != nil { - x.Map = make(map[string]int, len(*dec.Map)) - for k, v := range *dec.Map { + x.Map = make(map[string]int, len(dec.Map)) + for k, v := range dec.Map { x.Map[string(k)] = int(v) } } @@ -65,7 +65,7 @@ func (x *X) UnmarshalJSON(input []byte) error { } } if dec.NoConv != nil { - x.NoConv = *dec.NoConv + x.NoConv = dec.NoConv } if dec.NoConvNamed != nil { x.NoConvNamed = *dec.NoConvNamed @@ -108,9 +108,9 @@ func (x X) MarshalYAML() (interface{}, error) { func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { type X struct { - Map *map[replacedString]replacedInt + Map map[replacedString]replacedInt Named *namedMap2 - NoConv *map[string]int + NoConv map[string]int NoConvNamed *namedMap } var dec X @@ -118,8 +118,8 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { return err } if dec.Map != nil { - x.Map = make(map[string]int, len(*dec.Map)) - for k, v := range *dec.Map { + x.Map = make(map[string]int, len(dec.Map)) + for k, v := range dec.Map { x.Map[string(k)] = int(v) } } @@ -130,7 +130,7 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { } } if dec.NoConv != nil { - x.NoConv = *dec.NoConv + x.NoConv = dec.NoConv } if dec.NoConvNamed != nil { x.NoConvNamed = *dec.NoConvNamed @@ -173,9 +173,9 @@ func (x X) MarshalTOML() (interface{}, error) { func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { type X struct { - Map *map[replacedString]replacedInt + Map map[replacedString]replacedInt Named *namedMap2 - NoConv *map[string]int + NoConv map[string]int NoConvNamed *namedMap } var dec X @@ -183,8 +183,8 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { return err } if dec.Map != nil { - x.Map = make(map[string]int, len(*dec.Map)) - for k, v := range *dec.Map { + x.Map = make(map[string]int, len(dec.Map)) + for k, v := range dec.Map { x.Map[string(k)] = int(v) } } @@ -195,7 +195,7 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { } } if dec.NoConv != nil { - x.NoConv = *dec.NoConv + x.NoConv = dec.NoConv } if dec.NoConvNamed != nil { x.NoConvNamed = *dec.NoConvNamed diff --git a/internal/tests/sliceconv/output.go b/internal/tests/sliceconv/output.go index 8fd4207..0f6e719 100644 --- a/internal/tests/sliceconv/output.go +++ b/internal/tests/sliceconv/output.go @@ -45,10 +45,10 @@ func (x X) MarshalJSON() ([]byte, error) { func (x *X) UnmarshalJSON(input []byte) error { type X struct { - Slice *[]replacedInt + Slice []replacedInt Named *namedSlice2 - ByteString *[]byte - NoConv *[]int + ByteString []byte + NoConv []int NoConvNamed *namedSlice } var dec X @@ -56,8 +56,8 @@ func (x *X) UnmarshalJSON(input []byte) error { return err } if dec.Slice != nil { - x.Slice = make([]int, len(*dec.Slice)) - for k, v := range *dec.Slice { + x.Slice = make([]int, len(dec.Slice)) + for k, v := range dec.Slice { x.Slice[k] = int(v) } } @@ -68,10 +68,10 @@ func (x *X) UnmarshalJSON(input []byte) error { } } if dec.ByteString != nil { - x.ByteString = string(*dec.ByteString) + x.ByteString = string(dec.ByteString) } if dec.NoConv != nil { - x.NoConv = *dec.NoConv + x.NoConv = dec.NoConv } if dec.NoConvNamed != nil { x.NoConvNamed = *dec.NoConvNamed @@ -116,10 +116,10 @@ func (x X) MarshalYAML() (interface{}, error) { func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { type X struct { - Slice *[]replacedInt + Slice []replacedInt Named *namedSlice2 - ByteString *[]byte - NoConv *[]int + ByteString []byte + NoConv []int NoConvNamed *namedSlice } var dec X @@ -127,8 +127,8 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { return err } if dec.Slice != nil { - x.Slice = make([]int, len(*dec.Slice)) - for k, v := range *dec.Slice { + x.Slice = make([]int, len(dec.Slice)) + for k, v := range dec.Slice { x.Slice[k] = int(v) } } @@ -139,10 +139,10 @@ func (x *X) UnmarshalYAML(unmarshal func(interface{}) error) error { } } if dec.ByteString != nil { - x.ByteString = string(*dec.ByteString) + x.ByteString = string(dec.ByteString) } if dec.NoConv != nil { - x.NoConv = *dec.NoConv + x.NoConv = dec.NoConv } if dec.NoConvNamed != nil { x.NoConvNamed = *dec.NoConvNamed @@ -187,10 +187,10 @@ func (x X) MarshalTOML() (interface{}, error) { func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { type X struct { - Slice *[]replacedInt + Slice []replacedInt Named *namedSlice2 - ByteString *[]byte - NoConv *[]int + ByteString []byte + NoConv []int NoConvNamed *namedSlice } var dec X @@ -198,8 +198,8 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { return err } if dec.Slice != nil { - x.Slice = make([]int, len(*dec.Slice)) - for k, v := range *dec.Slice { + x.Slice = make([]int, len(dec.Slice)) + for k, v := range dec.Slice { x.Slice[k] = int(v) } } @@ -210,10 +210,10 @@ func (x *X) UnmarshalTOML(unmarshal func(interface{}) error) error { } } if dec.ByteString != nil { - x.ByteString = string(*dec.ByteString) + x.ByteString = string(dec.ByteString) } if dec.NoConv != nil { - x.NoConv = *dec.NoConv + x.NoConv = dec.NoConv } if dec.NoConvNamed != nil { x.NoConvNamed = *dec.NoConvNamed diff --git a/typeutil.go b/typeutil.go index 774272e..3a68dc8 100644 --- a/typeutil.go +++ b/typeutil.go @@ -98,10 +98,20 @@ func underlyingMap(typ types.Type) *types.Map { func ensureNilCheckable(typ types.Type) types.Type { orig := typ + named := false for { switch typ.(type) { case *types.Named: typ = typ.Underlying() + named = true + case *types.Slice, *types.Map: + if named { + // Named slices, maps, etc. are special because they can have a custom + // decoder function that prevents the JSON null value. Wrap them with a + // pointer to allow null always so required/optional works as expected. + return types.NewPointer(orig) + } + return orig case *types.Pointer, *types.Interface: return orig default: