Command gencodec generates marshaling methods for Go struct types.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
gencodec/genmethod.go

383 lines
11 KiB

// Copyright 2017 Felix Lange <fjl@twurst.com>.
// Use of this source code is governed by the MIT license,
// which can be found in the LICENSE file.
package main
import (
"fmt"
"go/ast"
"go/printer"
"go/token"
"go/types"
"io"
"strconv"
"strings"
. "github.com/garslo/gogen"
)
// NIL is the gogen expression for the predeclared nil identifier.
var NIL = Name("nil")

// intType is the predeclared int type. sliceKV uses it as the "key" type of a
// slice, whose range loop index is an int.
var intType = types.Universe.Lookup("int").Type()
// marshalMethod carries the state used while generating one marshaling or
// unmarshaling method for a single marshalerType.
type marshalMethod struct {
	// mtyp is the type the method is generated for.
	mtyp *marshalerType
	// scope hands out identifiers (via newIdent) for the generated function.
	scope *funcScope
	// isUnmarshal is true for Unmarshal* methods and false for Marshal* methods.
	isUnmarshal bool

	// iterKey and iterVal are cached loop identifiers reused by the map and
	// slice conversions emitted by loopConv.
	iterKey Var
	iterVal Var
}
// newMarshalMethod creates the generation state for one method of mtyp.
// isUnmarshal selects an Unmarshal* method (true) or a Marshal* method (false).
func newMarshalMethod(mtyp *marshalerType, isUnmarshal bool) *marshalMethod {
	// The range-loop identifiers must be reserved in the same scope that later
	// hands out the receiver and other local names. Otherwise newIdent cannot
	// avoid reusing "k"/"v" (e.g. as the receiver of a type named "Key"), and
	// the generated loop variable would shadow it.
	scope := newFuncScope(mtyp.scope)
	return &marshalMethod{
		mtyp:        mtyp,
		scope:       scope,
		isUnmarshal: isUnmarshal,
		iterKey:     Name(scope.newIdent("k")),
		iterVal:     Name(scope.newIdent("v")),
	}
}
// writeFunction pretty-prints fn's declaration to w, using fs for position
// information, and terminates it with a newline. Write errors are ignored.
func writeFunction(w io.Writer, fs *token.FileSet, fn Function) {
	decl := fn.Declaration()
	_ = printer.Fprint(w, fs, decl)
	_, _ = fmt.Fprintln(w)
}
// genUnmarshalJSON generates the UnmarshalJSON method. The generated code
// decodes its []byte input into an intermediate struct with encoding/json's
// Unmarshal and then copies the decoded fields onto the receiver.
func genUnmarshalJSON(mtyp *marshalerType) Function {
	// Keep this order: each call below reserves a name in m.scope or its parent,
	// so reordering could change the generated identifiers.
	m := newMarshalMethod(mtyp, true)
	recv := m.receiver()
	input := Name(m.scope.newIdent("input"))
	intertyp := m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name()))
	dec := Name(m.scope.newIdent("dec"))
	json := Name(m.scope.parent.packageName("encoding/json"))

	decode := CallFunction{
		Func:   Dotted{Receiver: json, Name: "Unmarshal"},
		Params: []Expression{input, AddressOf{Value: dec}},
	}
	body := []Statement{
		declStmt{intertyp},
		Declare{Name: dec.Name, TypeName: intertyp.Name},
		errCheck(decode),
	}
	body = append(body, m.unmarshalConversions(dec, Name(recv.Name), "json")...)
	body = append(body, Return{Values: []Expression{NIL}})

	return Function{
		Receiver:    recv,
		Name:        "UnmarshalJSON",
		ReturnTypes: Types{{TypeName: "error"}},
		Parameters:  Types{{Name: input.Name, TypeName: "[]byte"}},
		Body:        body,
	}
}
// genMarshalJSON generates the MarshalJSON method. The generated code copies
// the receiver's fields into an intermediate struct and returns the result of
// encoding/json's Marshal on a pointer to it.
func genMarshalJSON(mtyp *marshalerType) Function {
	// Keep this order: each call below reserves a name in m.scope or its parent.
	m := newMarshalMethod(mtyp, false)
	recv := m.receiver()
	intertyp := m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name()))
	enc := Name(m.scope.newIdent("enc"))
	json := Name(m.scope.parent.packageName("encoding/json"))

	body := []Statement{
		declStmt{intertyp},
		Declare{Name: enc.Name, TypeName: intertyp.Name},
	}
	body = append(body, m.marshalConversions(Name(recv.Name), enc, "json")...)
	encode := CallFunction{
		Func:   Dotted{Receiver: json, Name: "Marshal"},
		Params: []Expression{AddressOf{Value: enc}},
	}
	body = append(body, Return{Values: []Expression{encode}})

	return Function{
		Receiver:    recv,
		Name:        "MarshalJSON",
		ReturnTypes: Types{{TypeName: "[]byte"}, {TypeName: "error"}},
		Body:        body,
	}
}
// genUnmarshalYAML generates the UnmarshalYAML method, whose single parameter
// is an unmarshal callback of type func(interface{}) error.
func genUnmarshalYAML(mtyp *marshalerType) Function {
	const format = "YAML"
	return genUnmarshalLikeYAML(mtyp, format)
}
// genUnmarshalTOML generates the UnmarshalTOML method. It has the same shape
// as UnmarshalYAML: a single unmarshal callback parameter.
func genUnmarshalTOML(mtyp *marshalerType) Function {
	const format = "TOML"
	return genUnmarshalLikeYAML(mtyp, format)
}
// genUnmarshalLikeYAML generates an "Unmarshal"+name method. The generated
// method takes an unmarshal func(interface{}) error callback, uses it to decode
// into an intermediate struct, and copies the decoded fields onto the receiver.
// strings.ToLower(name) is passed as the format to the field conversions.
func genUnmarshalLikeYAML(mtyp *marshalerType, name string) Function {
	// Keep this order: each call below reserves a name in m.scope.
	m := newMarshalMethod(mtyp, true)
	recv := m.receiver()
	unmarshal := Name(m.scope.newIdent("unmarshal"))
	intertyp := m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name()))
	dec := Name(m.scope.newIdent("dec"))
	format := strings.ToLower(name)

	decode := CallFunction{Func: unmarshal, Params: []Expression{AddressOf{Value: dec}}}
	body := []Statement{
		declStmt{intertyp},
		Declare{Name: dec.Name, TypeName: intertyp.Name},
		errCheck(decode),
	}
	body = append(body, m.unmarshalConversions(dec, Name(recv.Name), format)...)
	body = append(body, Return{Values: []Expression{NIL}})

	return Function{
		Receiver:    recv,
		Name:        "Unmarshal" + name,
		ReturnTypes: Types{{TypeName: "error"}},
		Parameters:  Types{{Name: unmarshal.Name, TypeName: "func (interface{}) error"}},
		Body:        body,
	}
}
// genMarshalYAML generates the MarshalYAML method, which returns a pointer to
// an intermediate struct as an interface{} together with a nil error.
func genMarshalYAML(mtyp *marshalerType) Function {
	const format = "YAML"
	return genMarshalLikeYAML(mtyp, format)
}
// genMarshalTOML generates the MarshalTOML method. It has the same shape as
// MarshalYAML.
func genMarshalTOML(mtyp *marshalerType) Function {
	const format = "TOML"
	return genMarshalLikeYAML(mtyp, format)
}
// genMarshalLikeYAML generates a "Marshal"+name method. The generated method
// copies the receiver's fields into an intermediate struct and returns
// (&intermediate, nil) as (interface{}, error). strings.ToLower(name) is passed
// as the format to the field conversions.
func genMarshalLikeYAML(mtyp *marshalerType, name string) Function {
	// Keep this order: each call below reserves a name in m.scope.
	m := newMarshalMethod(mtyp, false)
	recv := m.receiver()
	intertyp := m.intermediateType(m.scope.newIdent(m.mtyp.orig.Obj().Name()))
	enc := Name(m.scope.newIdent("enc"))
	format := strings.ToLower(name)

	body := []Statement{
		declStmt{intertyp},
		Declare{Name: enc.Name, TypeName: intertyp.Name},
	}
	body = append(body, m.marshalConversions(Name(recv.Name), enc, format)...)
	body = append(body, Return{Values: []Expression{AddressOf{Value: enc}, NIL}})

	return Function{
		Receiver:    recv,
		Name:        "Marshal" + name,
		ReturnTypes: Types{{TypeName: "interface{}"}, {TypeName: "error"}},
		Body:        body,
	}
}
// receiver returns the receiver of the generated method. Its name is a fresh
// identifier derived from the lowercased first character of the type name.
// Its type is the named type for Marshal* methods and a pointer to it for
// Unmarshal* methods, which assign to the receiver's fields.
func (m *marshalMethod) receiver() Receiver {
	// Use the first rune rather than the first byte: slicing name[:1] would cut
	// a multi-byte leading character in half and yield invalid UTF-8.
	first := []rune(m.mtyp.name)[0]
	letter := strings.ToLower(string(first))
	r := Receiver{Name: m.scope.newIdent(letter), Type: Name(m.mtyp.name)}
	if m.isUnmarshal {
		r.Type = Star{Value: r.Type}
	}
	return r
}
// intermediateType builds the struct type named name that the generated method
// encodes from or decodes into. It has one field per field of m.mtyp, keeping
// the field's name and tag, with its type rendered via the type's qualifier.
// For Unmarshal* methods, function-backed fields are left out and every field
// type is passed through ensureNilCheckable, so absent values can later be
// detected with a nil comparison.
func (m *marshalMethod) intermediateType(name string) Struct {
	st := Struct{Name: name}
	for _, fld := range m.mtyp.Fields {
		var t types.Type
		if m.isUnmarshal {
			if fld.function != nil {
				continue // fields generated from functions cannot be assigned on unmarshal
			}
			t = ensureNilCheckable(fld.typ)
		} else {
			t = fld.typ
		}
		st.Fields = append(st.Fields, Field{
			Name:     fld.name,
			TypeName: types.TypeString(t, m.mtyp.scope.qualify),
			Tag:      fld.tag,
		})
	}
	return st
}
// unmarshalConversions returns statements that copy each decoded field of the
// intermediate value from onto to. Fields generated from functions are skipped.
//
// A field that is not required in format is copied only when it is non-nil in
// from. For a required field, the generated code returns errors.New(...) naming
// the field's encoded name and the type when it is nil, and otherwise copies it
// unconditionally.
func (m *marshalMethod) unmarshalConversions(from, to Var, format string) []Statement {
	var out []Statement
	for _, fld := range m.mtyp.Fields {
		if fld.function != nil {
			continue // fields generated from functions cannot be assigned
		}
		src := Dotted{Receiver: from, Name: fld.name}
		dst := Dotted{Receiver: to, Name: fld.name}
		nilableTyp := ensureNilCheckable(fld.typ)

		if !fld.isRequired(format) {
			// Optional: copy only if the input supplied a value.
			out = append(out, If{
				Condition: NotEqual{Lhs: src, Rhs: NIL},
				Body:      m.convert(src, dst, nilableTyp, fld.origTyp),
			})
			continue
		}

		// Required: fail when absent, then copy unconditionally.
		msg := fmt.Sprintf("missing required field '%s' for %s", fld.encodedName(format), m.mtyp.name)
		errorsPkg := Name(m.scope.parent.packageName("errors"))
		newErr := CallFunction{
			Func:   Dotted{Receiver: errorsPkg, Name: "New"},
			Params: []Expression{stringLit{msg}},
		}
		out = append(out, If{
			Condition: Equals{Lhs: src, Rhs: NIL},
			Body:      []Statement{Return{Values: []Expression{newErr}}},
		})
		out = append(out, m.convert(src, dst, nilableTyp, fld.origTyp)...)
	}
	return out
}
// marshalConversions returns statements that copy every field of from (the
// receiver) into the intermediate value to. For fields generated from
// functions, the value is obtained by calling from.<name>() instead of reading
// a field. format is accepted for symmetry with unmarshalConversions and is
// currently unused.
func (m *marshalMethod) marshalConversions(from, to Var, format string) []Statement {
	var out []Statement
	for _, fld := range m.mtyp.Fields {
		var src Expression = Dotted{Receiver: from, Name: fld.name}
		if fld.function != nil {
			src = CallFunction{Func: src}
		}
		dst := Dotted{Receiver: to, Name: fld.name}
		out = append(out, m.convert(src, dst, fld.origTyp, fld.typ)...)
	}
	return out
}
// convert returns statements assigning from (of type fromtyp) to to (of type
// totyp).
//
// When exactly one side is a pointer (as introduced by ensureNilCheckable),
// from is first dereferenced or has its address taken to match. Directly
// convertible types then become a single assignment, while slices and maps are
// converted element by element via loopConv. Any other combination panics via
// invalidConv.
func (m *marshalMethod) convert(from, to Expression, fromtyp, totyp types.Type) []Statement {
	fromPtr, toPtr := isPointer(fromtyp), isPointer(totyp)
	switch {
	case fromPtr && !toPtr:
		from = Star{Value: from}
		fromtyp = fromtyp.(*types.Pointer).Elem()
	case !fromPtr && toPtr:
		from = AddressOf{Value: from}
		fromtyp = types.NewPointer(fromtyp)
	}

	qf := m.mtyp.scope.qualify
	if types.ConvertibleTo(fromtyp, totyp) {
		return []Statement{Assign{Lhs: to, Rhs: simpleConv(from, fromtyp, totyp, qf)}}
	}
	if underlyingSlice(fromtyp) != nil {
		return m.loopConv(from, to, sliceKV(fromtyp), sliceKV(totyp))
	}
	if underlyingMap(fromtyp) != nil {
		return m.loopConv(from, to, mapKV(fromtyp), mapKV(totyp))
	}
	invalidConv(fromtyp, totyp, qf)
	return nil
}
// kvType describes a slice or map type for element-by-element conversion.
// Field order matters: values are built with positional literals.
type kvType struct {
	Type types.Type // the slice or map type itself
	Key  types.Type // the map key type, or int for slices
	Elem types.Type // the element type
}
// mapKV describes typ, whose underlying type is expected to be a map, as a
// kvType holding typ and the map's key and element types.
func mapKV(typ types.Type) kvType {
	mt := underlyingMap(typ)
	return kvType{Type: typ, Key: mt.Key(), Elem: mt.Elem()}
}
// sliceKV describes typ, whose underlying type is expected to be a slice, as a
// kvType. The key type is int, matching the slice's index.
func sliceKV(typ types.Type) kvType {
	st := underlyingSlice(typ)
	return kvType{Type: typ, Key: intType, Elem: st.Elem()}
}
// loopConv returns statements converting the slice or map from into to. It
// allocates to with make(T, len(from)) and then assigns every entry in a range
// loop, converting keys and elements with simpleConv.
//
// If from may have side effects, it is evaluated once into a "tmp" variable
// first. For Marshal* methods, the conversion is wrapped in `if from != nil` so
// nil maps and slices stay nil. Unmarshal* callers have already nil-checked
// the field.
func (m *marshalMethod) loopConv(from, to Expression, fromTyp, toTyp kvType) []Statement {
	var prelude []Statement
	if hasSideEffects(from) {
		tmp := Name(m.scope.newIdent("tmp"))
		prelude = []Statement{DeclareAndAssign{Lhs: tmp, Rhs: from}}
		from = tmp
	}

	// The qualifier is invoked in the same order as before (make type, key,
	// element) in case it records imports as a side effect.
	qf := m.scope.parent.qualify
	alloc := Assign{Lhs: to, Rhs: makeExpr(toTyp.Type, from, qf)}
	key := simpleConv(m.iterKey, fromTyp.Key, toTyp.Key, qf)
	val := simpleConv(m.iterVal, fromTyp.Elem, toTyp.Elem, qf)

	loop := Range{
		Key:        m.iterKey,
		Value:      m.iterVal,
		RangeValue: from,
		Body:       []Statement{Assign{Lhs: Index{Value: to, Index: key}, Rhs: val}},
	}
	body := []Statement{alloc, loop}

	if !m.isUnmarshal {
		body = []Statement{If{
			Condition: NotEqual{Lhs: from, Rhs: NIL},
			Body:      body,
		}}
	}
	return append(prelude, body...)
}
// simpleConv returns from unchanged when fromtyp is assignable to totyp, and
// otherwise the conversion expression T(from), with T rendered using qf. It
// panics via invalidConv if the types are not convertible.
func simpleConv(from Expression, fromtyp, totyp types.Type, qf types.Qualifier) Expression {
	switch {
	case types.AssignableTo(fromtyp, totyp):
		return from
	case !types.ConvertibleTo(fromtyp, totyp):
		invalidConv(fromtyp, totyp, qf)
	}
	target := types.TypeString(totyp, qf)
	if isPointer(totyp) {
		// A conversion to a pointer type must be written (*T)(x); *T(x) would
		// parse as *(T(x)). The type name is wrapped in parentheses by hand.
		target = "(" + target + ")"
	}
	return CallFunction{Func: Name(target), Params: []Expression{from}}
}
// invalidConv panics with an error describing a conversion from -> to that the
// generator cannot express. The "BUG:" prefix marks it as an internal
// invariant violation rather than a user error.
func invalidConv(from, to types.Type, qf types.Qualifier) {
	src := types.TypeString(from, qf)
	dst := types.TypeString(to, qf)
	panic(fmt.Errorf("BUG: invalid conversion %s -> %s", src, dst))
}
// makeExpr returns the expression make(T, len(lenfrom)), where T is typ
// rendered with the qualifier qf.
func makeExpr(typ types.Type, lenfrom Expression, qf types.Qualifier) Expression {
	typeName := Name(types.TypeString(typ, qf))
	length := CallFunction{Func: Name("len"), Params: []Expression{lenfrom}}
	return CallFunction{Func: Name("make"), Params: []Expression{typeName, length}}
}
// errCheck returns the statement `if err := expr; err != nil { return err }`.
func errCheck(expr Expression) If {
	e := Name("err")
	var stmt If
	stmt.Init = DeclareAndAssign{Lhs: e, Rhs: expr}
	stmt.Condition = NotEqual{Lhs: e, Rhs: NIL}
	stmt.Body = []Statement{Return{Values: []Expression{e}}}
	return stmt
}
// hasSideEffects reports whether evaluating expr may have side effects.
//
// The check is conservative. Plain variables are side-effect free, and so are
// selectors, dereferences and index expressions built only from side-effect
// free operands. Every other expression (calls, address-of, ...) is assumed to
// have side effects.
func hasSideEffects(expr Expression) bool {
	switch expr := expr.(type) {
	case Var:
		return false
	case Dotted:
		return hasSideEffects(expr.Receiver)
	case Star:
		return hasSideEffects(expr.Value)
	case Index:
		// x[i] has side effects if either x or i does. The previous && wrongly
		// treated e.g. a[f()] as side-effect free.
		return hasSideEffects(expr.Value) || hasSideEffects(expr.Index)
	default:
		return true
	}
}
// stringLit is an Expression that renders V as a Go string literal.
type stringLit struct {
	V string
}

// Expression returns V as a double-quoted Go string literal, escaped the way
// strconv.Quote escapes it.
func (lit stringLit) Expression() ast.Expr {
	quoted := strconv.Quote(lit.V)
	return &ast.BasicLit{Kind: token.STRING, Value: quoted}
}
// declStmt lets a Declaration (such as the intermediate struct type) appear as
// a Statement inside a generated function body.
type declStmt struct {
	d Declaration
}

// Statement wraps the declaration's AST node in an *ast.DeclStmt.
func (s declStmt) Statement() ast.Stmt {
	decl := s.d.Declaration()
	return &ast.DeclStmt{Decl: decl}
}