diff --git a/genmethod.go b/genmethod.go index 2c99e77..f28d44c 100644 --- a/genmethod.go +++ b/genmethod.go @@ -182,6 +182,9 @@ func (m *marshalMethod) receiver() Receiver { func (m *marshalMethod) intermediateType(name string) Struct { s := Struct{Name: name} for _, f := range m.mtyp.Fields { + if m.isUnmarshal && f.function != nil { + continue // fields generated from functions cannot be assigned on unmarshal + } typ := f.typ if m.isUnmarshal { typ = ensureNilCheckable(typ) @@ -197,6 +200,10 @@ func (m *marshalMethod) intermediateType(name string) Struct { func (m *marshalMethod) unmarshalConversions(from, to Var, format string) (s []Statement) { for _, f := range m.mtyp.Fields { + if f.function != nil { + continue // fields generated from functions cannot be assigned + } + accessFrom := Dotted{Receiver: from, Name: f.name} accessTo := Dotted{Receiver: to, Name: f.name} typ := ensureNilCheckable(f.typ) @@ -231,8 +238,11 @@ func (m *marshalMethod) marshalConversions(from, to Var, format string) (s []Sta for _, f := range m.mtyp.Fields { accessFrom := Dotted{Receiver: from, Name: f.name} accessTo := Dotted{Receiver: to, Name: f.name} - conversion := m.convert(accessFrom, accessTo, f.origTyp, f.typ) - s = append(s, conversion...) + if f.function != nil { + s = append(s, m.convert(CallFunction{Func: accessFrom}, accessTo, f.origTyp, f.typ)...) + } else { + s = append(s, m.convert(accessFrom, accessTo, f.origTyp, f.typ)...) + } } return s } diff --git a/internal/tests/funcoverride/input.go b/internal/tests/funcoverride/input.go new file mode 100644 index 0000000..6edef47 --- /dev/null +++ b/internal/tests/funcoverride/input.go @@ -0,0 +1,33 @@ +// Copyright 2017 Felix Lange . +// Use of this source code is governed by the MIT license, +// which can be found in the LICENSE file. + +//go:generate gencodec -type Z -field-override Zo -formats json,yaml,toml -out output.go + +package funcoverride + +import ( + "fmt" +) + +type Z struct { + S string `json:"s"` + I int32 `json:"iVal"` +} + +func (z *Z) Hash() string { + return fmt.Sprintf("%s-%d", z.S, z.I) +} + +func (z *Z) MultiplyIByTwo() int32 { + return 2 * z.I +} + +func (z *Z) NotUsed() string { + return "not used" +} + +type Zo struct { + Hash string + MultiplyIByTwo int64 `json:"multipliedByTwo"` +} diff --git a/internal/tests/funcoverride/input_test.go b/internal/tests/funcoverride/input_test.go new file mode 100644 index 0000000..325b3c6 --- /dev/null +++ b/internal/tests/funcoverride/input_test.go @@ -0,0 +1,44 @@ +// Copyright 2017 Felix Lange . +// Use of this source code is governed by the MIT license, +// which can be found in the LICENSE file. + +package funcoverride + +import ( + "encoding/json" + "fmt" + "testing" +) + +func TestOverrideFuncJSON(t *testing.T) { + z := Z{"str", 1234} + hash := z.Hash() + multiply := z.MultiplyIByTwo() + want := fmt.Sprintf(`{"s":"%s","iVal":%d,"Hash":"%s","multipliedByTwo":%d}`, z.S, z.I, hash, multiply) + out, err := json.Marshal(z) + if err != nil { + t.Fatal(err) + } + if string(out) != want { + t.Fatalf("got %#q, want %#q", string(out), want) + } + + var zUnmarshaled Z + if err := json.Unmarshal([]byte(want), &zUnmarshaled); err != nil { + t.Fatalf("could not unmarshal Z: %v", err) + } + if zUnmarshaled.I != z.I { + t.Errorf("Z.I has an unexpected value, want %d, got %d", z.I, zUnmarshaled.I) + } + if zUnmarshaled.S != z.S { + t.Errorf("Z.Str has an unexpected value, want %s, got %s", z.S, zUnmarshaled.S) + } + uHash := zUnmarshaled.Hash() + if uHash != hash { + t.Errorf("Z.Hash() returned unexpected value, want %s, got %s", hash, uHash) + } + uMultiply := zUnmarshaled.MultiplyIByTwo() + if uMultiply != multiply { + t.Errorf("Z.MultiplIByTwo() returned unexpected value, want %d, got %d", multiply, uMultiply) + } +} diff --git a/internal/tests/funcoverride/output.go b/internal/tests/funcoverride/output.go new file mode 100644 index 0000000..cb9ad80 --- /dev/null +++ b/internal/tests/funcoverride/output.go @@ -0,0 +1,106 @@ +// Code generated by github.com/fjl/gencodec. DO NOT EDIT. + +package funcoverride + +import ( + "encoding/json" +) + +func (z Z) MarshalJSON() ([]byte, error) { + type Z struct { + S string `json:"s"` + I int32 `json:"iVal"` + Hash string + MultiplyIByTwo int64 `json:"multipliedByTwo"` + } + var enc Z + enc.S = z.S + enc.I = z.I + enc.Hash = z.Hash() + enc.MultiplyIByTwo = int64(z.MultiplyIByTwo()) + return json.Marshal(&enc) +} + +func (z *Z) UnmarshalJSON(input []byte) error { + type Z struct { + S *string `json:"s"` + I *int32 `json:"iVal"` + } + var dec Z + if err := json.Unmarshal(input, &dec); err != nil { + return err + } + if dec.S != nil { + z.S = *dec.S + } + if dec.I != nil { + z.I = *dec.I + } + return nil +} + +func (z Z) MarshalYAML() (interface{}, error) { + type Z struct { + S string `json:"s"` + I int32 `json:"iVal"` + Hash string + MultiplyIByTwo int64 `json:"multipliedByTwo"` + } + var enc Z + enc.S = z.S + enc.I = z.I + enc.Hash = z.Hash() + enc.MultiplyIByTwo = int64(z.MultiplyIByTwo()) + return &enc, nil +} + +func (z *Z) UnmarshalYAML(unmarshal func(interface{}) error) error { + type Z struct { + S *string `json:"s"` + I *int32 `json:"iVal"` + } + var dec Z + if err := unmarshal(&dec); err != nil { + return err + } + if dec.S != nil { + z.S = *dec.S + } + if dec.I != nil { + z.I = *dec.I + } + return nil +} + +func (z Z) MarshalTOML() (interface{}, error) { + type Z struct { + S string `json:"s"` + I int32 `json:"iVal"` + Hash string + MultiplyIByTwo int64 `json:"multipliedByTwo"` + } + var enc Z + enc.S = z.S + enc.I = z.I + enc.Hash = z.Hash() + enc.MultiplyIByTwo = int64(z.MultiplyIByTwo()) + return &enc, nil +} + +func (z *Z) UnmarshalTOML(unmarshal func(interface{}) error) error { + type Z struct { + S *string `json:"s"` + I *int32 `json:"iVal"` + } + var dec Z + if err := unmarshal(&dec); err != nil { + return err + } + if dec.S != nil { + z.S = *dec.S + } + if dec.I != nil { + z.I = *dec.I + } + return nil +} diff --git a/internal/tests/mapconv/input.go b/internal/tests/mapconv/input.go index 89d4d49..c9db943 100644 --- a/internal/tests/mapconv/input.go +++ b/internal/tests/mapconv/input.go @@ -21,8 +21,14 @@ type X struct { NoConvNamed map[string]int } +func (x *X) Func() map[string]int { + return map[string]int{"a": 1, "b": 2} +} + type Xo struct { Map map[replacedString]replacedInt Named namedMap2 NoConvNamed namedMap + + Func map[replacedString]replacedInt } diff --git a/internal/tests/mapconv/output.go b/internal/tests/mapconv/output.go index 3188583..4593155 100644 --- a/internal/tests/mapconv/output.go +++ b/internal/tests/mapconv/output.go @@ -12,6 +12,7 @@ func (x X) MarshalJSON() ([]byte, error) { Named namedMap2 NoConv map[string]int NoConvNamed namedMap + Func map[replacedString]replacedInt } var enc X if x.Map != nil { @@ -28,6 +29,12 @@ func (x X) MarshalJSON() ([]byte, error) { } enc.NoConv = x.NoConv enc.NoConvNamed = x.NoConvNamed + if x.Func() != nil { + enc.Func = make(map[replacedString]replacedInt, len(x.Func())) + for k, v := range x.Func() { + enc.Func[replacedString(k)] = replacedInt(v) + } + } return json.Marshal(&enc) } @@ -69,6 +76,7 @@ func (x X) MarshalYAML() (interface{}, error) { Named namedMap2 NoConv map[string]int NoConvNamed namedMap + Func map[replacedString]replacedInt } var enc X if x.Map != nil { @@ -85,6 +93,12 @@ func (x X) MarshalYAML() (interface{}, error) { } enc.NoConv = x.NoConv enc.NoConvNamed = x.NoConvNamed + if x.Func() != nil { + enc.Func = make(map[replacedString]replacedInt, len(x.Func())) + for k, v := range x.Func() { + enc.Func[replacedString(k)] = replacedInt(v) + } + } return &enc, nil } @@ -126,6 +140,7 @@ func (x X) MarshalTOML() (interface{}, error) { Named namedMap2 NoConv map[string]int NoConvNamed namedMap + Func map[replacedString]replacedInt } var enc X if x.Map != nil { @@ -142,6 +157,12 @@ func (x X) MarshalTOML() (interface{}, error) { } enc.NoConv = x.NoConv enc.NoConvNamed = x.NoConvNamed + if x.Func() != nil { + enc.Func = make(map[replacedString]replacedInt, len(x.Func())) + for k, v := range x.Func() { + enc.Func[replacedString(k)] = replacedInt(v) + } + } return &enc, nil } diff --git a/internal/tests/sliceconv/input.go b/internal/tests/sliceconv/input.go index 7b59f4d..1d020d6 100644 --- a/internal/tests/sliceconv/input.go +++ b/internal/tests/sliceconv/input.go @@ -20,9 +20,15 @@ type X struct { NoConvNamed []int } +func (x *X) Func() []int { + return []int{1,2,3,4} +} + type Xo struct { Slice []replacedInt Named namedSlice2 ByteString []byte NoConvNamed namedSlice + + Func []replacedInt } diff --git a/internal/tests/sliceconv/output.go b/internal/tests/sliceconv/output.go index 335881e..fd7d80a 100644 --- a/internal/tests/sliceconv/output.go +++ b/internal/tests/sliceconv/output.go @@ -13,6 +13,7 @@ func (x X) MarshalJSON() ([]byte, error) { ByteString []byte NoConv []int NoConvNamed namedSlice + Func []replacedInt } var enc X if x.Slice != nil { @@ -30,6 +31,12 @@ func (x X) MarshalJSON() ([]byte, error) { enc.ByteString = []byte(x.ByteString) enc.NoConv = x.NoConv enc.NoConvNamed = x.NoConvNamed + if x.Func() != nil { + enc.Func = make([]replacedInt, len(x.Func())) + for k, v := range x.Func() { + enc.Func[k] = replacedInt(v) + } + } return json.Marshal(&enc) } @@ -76,6 +83,7 @@ func (x X) MarshalYAML() (interface{}, error) { ByteString []byte NoConv []int NoConvNamed namedSlice + Func []replacedInt } var enc X if x.Slice != nil { @@ -93,6 +101,12 @@ func (x X) MarshalYAML() (interface{}, error) { enc.ByteString = []byte(x.ByteString) enc.NoConv = x.NoConv enc.NoConvNamed = x.NoConvNamed + if x.Func() != nil { + enc.Func = make([]replacedInt, len(x.Func())) + for k, v := range x.Func() { + enc.Func[k] = replacedInt(v) + } + } return &enc, nil } @@ -139,6 +153,7 @@ func (x X) MarshalTOML() (interface{}, error) { ByteString []byte NoConv []int NoConvNamed namedSlice + Func []replacedInt } var enc X if x.Slice != nil { @@ -156,6 +171,12 @@ func (x X) MarshalTOML() (interface{}, error) { enc.ByteString = []byte(x.ByteString) enc.NoConv = x.NoConv enc.NoConvNamed = x.NoConvNamed + if x.Func() != nil { + enc.Func = make([]replacedInt, len(x.Func())) + for k, v := range x.Func() { + enc.Func[k] = replacedInt(v) + } + } return &enc, nil } diff --git a/main.go b/main.go index 9885bad..ee06597 100644 --- a/main.go +++ b/main.go @@ -32,11 +32,17 @@ Field Type Overrides An invocation of gencodec can specify an additional 'field override' struct from which marshaling type replacements are taken. If the override struct contains a field whose name matches the original type, the generated marshaling methods will use the overridden type -and convert to and from the original field type. +and convert to and from the original field type. If the override struct contains a field F +of type T, which does not exist in the original type, and the original type has a method +named F with no arguments and return type assignable to T, the method is called by Marshal*. +If there is a matching method F but the return type or arguments are unsuitable, an error +is raised. In this example, the specialString type implements json.Unmarshaler to enforce additional parsing rules. When json.Unmarshal is used with type foo, the specialString unmarshaler -will be used to parse the value of SpecialField. +will be used to parse the value of SpecialField. The result of foo.Func() is added to the +result on marshaling under the key `id`. If the input on unmarshal contains a key `id` this +field is ignored. //go:generate gencodec -type foo -field-override fooMarshaling -out foo_json.go @@ -45,8 +51,13 @@ will be used to parse the value of SpecialField. SpecialField string } + func (f foo) Func() string { + return f.Field + "-" + f.SpecialField + } + type fooMarshaling struct { SpecialField specialString // overrides type of SpecialField when marshaling/unmarshaling + Func string `json:"id"` // adds the result of foo.Func() to the serialised object under the key id } Relaxed Field Conversions @@ -186,7 +197,8 @@ func (cfg *Config) process() (code []byte, err error) { if err != nil { return nil, fmt.Errorf("can't find field replacement type %s: %v", cfg.FieldOverride, err) } - err = mtyp.loadOverrides(cfg.FieldOverride, otyp.Underlying().(*types.Struct)) + + err = mtyp.loadOverrides(typ, otyp.Underlying().(*types.Struct)) if err != nil { return nil, err } @@ -266,10 +278,11 @@ type marshalerType struct { // marshalerField represents a field of the intermediate marshaling type. type marshalerField struct { - name string - typ types.Type - origTyp types.Type - tag string + name string + typ types.Type + origTyp types.Type + tag string + function *types.Func // map to a function instead of a field } func newMarshalerType(fs *token.FileSet, imp types.Importer, typ *types.Named) *marshalerType { @@ -287,24 +300,46 @@ func newMarshalerType(fs *token.FileSet, imp types.Importer, typ *types.Named) * if !f.Exported() { continue } + if f.Anonymous() { + fmt.Fprintf(os.Stderr, "Warning: ignoring embedded field %s\n", f.Name()) + continue + } + mf := &marshalerField{ name: f.Name(), typ: f.Type(), origTyp: f.Type(), tag: styp.Tag(i), } - if f.Anonymous() { - fmt.Fprintf(os.Stderr, "Warning: ignoring embedded field %s\n", f.Name()) - continue - } + mtyp.Fields = append(mtyp.Fields, mf) } + return mtyp } +// findFunction returns a function with `name` that accepts no arguments +// and returns a single value that is convertible to the given to type. +func findFunction(typ *types.Named, name string, to types.Type) (*types.Func, types.Type) { + for i := 0; i < typ.NumMethods(); i++ { + fun := typ.Method(i) + if fun.Name() != name || !fun.Exported() { + continue + } + sign := fun.Type().(*types.Signature) + if sign.Params().Len() != 0 || sign.Results().Len() != 1 { + continue + } + if err := checkConvertible(sign.Results().At(0).Type(), to); err == nil { + return fun, sign.Results().At(0).Type() + } + } + return nil, nil +} + // loadOverrides sets field types of the intermediate marshaling type from // matching fields of otyp. -func (mtyp *marshalerType) loadOverrides(otypename string, otyp *types.Struct) error { +func (mtyp *marshalerType) loadOverrides(typ *types.Named, otyp *types.Struct) error { for i := 0; i < otyp.NumFields(); i++ { of := otyp.Field(i) if of.Anonymous() || !of.Exported() { @@ -312,7 +347,13 @@ func (mtyp *marshalerType) loadOverrides(otypename string, otyp *types.Struct) e } f := mtyp.fieldByName(of.Name()) if f == nil { - return fmt.Errorf("%v: no matching field for %s in original type %s", mtyp.fs.Position(of.Pos()), of.Name(), mtyp.name) + // field not defined in original type, check if it maps to a suitable function and add it as an override + if fun, retType := findFunction(typ, of.Name(), of.Type()); fun != nil { + f = &marshalerField{name: fun.Name(), origTyp: retType, typ: of.Type(), function: fun, tag: otyp.Tag(i)} + mtyp.Fields = append(mtyp.Fields, f) + } else { + return fmt.Errorf("%v: no matching field or function for %s in original type %s", mtyp.fs.Position(of.Pos()), of.Name(), mtyp.name) + } } if err := checkConvertible(of.Type(), f.origTyp); err != nil { return fmt.Errorf("%v: invalid field override: %v", mtyp.fs.Position(of.Pos()), err) diff --git a/main_test.go b/main_test.go index 608ac60..7c196a9 100644 --- a/main_test.go +++ b/main_test.go @@ -24,6 +24,7 @@ func TestGolden(t *testing.T) { Config{Dir: "nameclash", Type: "Y", FieldOverride: "Yo", Formats: AllFormats}, Config{Dir: "omitempty", Type: "X", FieldOverride: "Xo", Formats: AllFormats}, Config{Dir: "reqfield", Type: "X", Formats: []string{"json"}}, + Config{Dir: "funcoverride", Type: "Z", FieldOverride: "Zo", Formats: AllFormats}, } for _, test := range tests { test := test