Command gencodec generates marshaling methods for Go struct types.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gencodec/main.go

536 lines
14 KiB

8 years ago
// Copyright 2017 Felix Lange <fjl@twurst.com>.
// Use of this source code is governed by the MIT license,
// which can be found in the LICENSE file.
/*
Command gencodec generates marshaling methods for struct types.
When gencodec is invoked on a directory and type name, it creates a Go source file
containing JSON and YAML marshaling methods for the type. The generated methods add
features which the standard json package cannot offer.
gencodec -dir . -type MyType -out mytype_json.go
Struct Tags
All fields are required unless the "optional" struct tag is set to "true" or "yes". The
generated unmarshaling methods return an error if a required field is missing. Other struct
tags are carried over as is. The standard "json" and "yaml" tags can be used to rename a
field when marshaling to/from JSON or YAML.
Example:
type foo struct {
Required string
Optional string `optional:"true"`
Renamed string `json:"otherName"`
}
Field Type Overrides
An invocation of gencodec can specify an additional 'field override' struct from which
marshaling type replacements are taken. If the override struct contains a field whose name
matches the original type, the generated marshaling methods will use the overridden type
and convert to and from the original field type.
In this example, the specialString type implements json.Unmarshaler to enforce additional
parsing rules. When json.Unmarshal is used with type foo, the specialString unmarshaler
will be used to parse the value of SpecialField.
//go:generate gencodec -dir . -type foo -field-override fooMarshaling -out foo_json.go
type foo struct {
Field string
SpecialField string
}
type fooMarshaling struct {
SpecialField specialString // overrides type of SpecialField when marshaling/unmarshaling
}
*/
package main
import (
"bytes"
"errors"
"flag"
"fmt"
"go/ast"
"go/importer"
"go/parser"
"go/token"
"go/types"
"io/ioutil"
"os"
"reflect"
"strconv"
"strings"
"text/template"
"golang.org/x/tools/imports"
)
// main parses the command line flags, loads and type-checks the input package,
// generates the marshaling code for the requested type and writes the result
// to the file named by -out, or to standard output when -out is "-".
func main() {
	var (
		pkgdir    = flag.String("dir", ".", "input package directory")
		output    = flag.String("out", "-", "output file")
		typename  = flag.String("type", "", "type to generate")
		overrides = flag.String("field-override", "", "type to take field type replacements from")
	)
	flag.Parse()

	fs := token.NewFileSet()
	pkg := loadPackage(fs, *pkgdir)
	code := makeMarshalingCode(fs, pkg, *typename, *overrides)

	if *output == "-" {
		// A failed write to stdout (closed pipe, full disk, ...) must not be
		// reported as a successful run.
		if _, err := os.Stdout.Write(code); err != nil {
			fatal(err)
		}
		return
	}
	if err := ioutil.WriteFile(*output, code, 0644); err != nil {
		fatal(err)
	}
}
// loadPackage parses all Go files in dir and type-checks them as one package.
// Function bodies are not type-checked. The process exits with an error
// message if parsing or type-checking fails, or if dir does not contain
// exactly one package.
func loadPackage(fs *token.FileSet, dir string) *types.Package {
	// Load the package.
	pkgs, err := parser.ParseDir(fs, dir, nil, parser.AllErrors)
	if err != nil {
		fatal(err)
	}
	// err is nil at this point, so report the real problem instead of "<nil>".
	if len(pkgs) != 1 {
		fatalf("%s: found %d packages, need exactly one", dir, len(pkgs))
	}
	var (
		files []*ast.File
		name  string
	)
	// pkgs holds exactly one entry here.
	for _, pkg := range pkgs {
		name = pkg.Name
		for _, file := range pkg.Files {
			files = append(files, file)
		}
	}

	// Type-check the package.
	cfg := types.Config{
		IgnoreFuncBodies: true,
		FakeImportC:      true,
		Importer:         importer.Default(),
	}
	tpkg, err := cfg.Check(name, fs, files, nil)
	if err != nil {
		fatal(err)
	}
	return tpkg
}
// makeMarshalingCode generates the complete, formatted source file containing
// the JSON and YAML marshaling methods for the struct type named typename in
// pkg. If otypename is not empty, it names a struct type in the same package
// whose fields override the marshaling types of the matching fields of
// typename. The process exits with an error message if a type cannot be
// found or the generated code cannot be formatted.
func makeMarshalingCode(fs *token.FileSet, pkg *types.Package, typename, otypename string) (packageBody []byte) {
	typ, err := lookupStructType(pkg.Scope(), typename)
	if err != nil {
		fatalf("can't find %s: %v", typename, err)
	}
	mtyp := newMarshalerType(fs, pkg, typ)
	if otypename != "" {
		otyp, err := lookupStructType(pkg.Scope(), otypename)
		if err != nil {
			fatalf("can't find field replacement type %s: %v", otypename, err)
		}
		// lookupStructType guarantees that the underlying type is a struct.
		mtyp.loadOverrides(otypename, otyp.Underlying().(*types.Struct))
	}

	w := new(bytes.Buffer)
	// Fprint with an explicit blank line; Fprintln with a trailing "\n" in its
	// argument is rejected by go vet.
	fmt.Fprint(w, "// generated by gencodec, do not edit.\n\n")
	fmt.Fprintln(w, "package ", pkg.Name())
	// computeImports must run before any method is rendered because it
	// renames packages to avoid name clashes.
	fmt.Fprintln(w, render(mtyp.computeImports(), `
import (
{{- range $name, $path := . }}
{{ $name }} "{{ $path }}"
{{- end }}
)`))
	fmt.Fprintln(w)
	fmt.Fprintln(w, mtyp.JSONMarshalMethod())
	fmt.Fprintln(w)
	fmt.Fprintln(w, mtyp.JSONUnmarshalMethod())
	fmt.Fprintln(w)
	fmt.Fprintln(w, mtyp.YAMLMarshalMethod())
	fmt.Fprintln(w)
	fmt.Fprintln(w, mtyp.YAMLUnmarshalMethod())

	// Use goimports to format the source because it separates imports.
	opt := &imports.Options{Comments: true, FormatOnly: true, TabIndent: true, TabWidth: 8}
	body, err := imports.Process("", w.Bytes(), opt)
	if err != nil {
		fatal("can't gofmt generated code:", err, "\n"+w.String())
	}
	return body
}
// marshalerType represents the intermediate struct type used during marshaling.
// This is the input data to all the Go code templates.
type marshalerType struct {
// OrigName is the name of the original type; the generated methods use it
// as their receiver type.
OrigName string
// Name is the name of the intermediate type declared inside the generated
// methods (the original name with a "JSON" suffix).
Name string
// Fields holds the exported, non-embedded fields of the original struct.
Fields []*marshalerField
// fs resolves source positions for error messages.
fs *token.FileSet
// orig is the original named type. Types from its package are printed
// without a package qualifier.
orig *types.Named
}
// marshalerField represents a field of the intermediate marshaling type.
type marshalerField struct {
// parent is the marshaling type this field belongs to.
parent *marshalerType
// field is the corresponding field of the original struct.
field *types.Var
// typ is the type of the field in the intermediate struct. It is always a
// pointer type and may be replaced by a field override.
typ types.Type
// tag is the struct tag of the original field, without backquotes.
tag string
}
// newMarshalerType builds the intermediate marshaling type for the struct
// type typ. Unexported fields are skipped silently, embedded fields are
// skipped with a warning on stderr. Every remaining field gets a pointer type
// in the intermediate struct.
func newMarshalerType(fs *token.FileSet, pkg *types.Package, typ *types.Named) *marshalerType {
	structType := typ.Underlying().(*types.Struct)
	origName := typ.Obj().Name()
	result := &marshalerType{
		OrigName: origName,
		Name:     origName + "JSON",
		fs:       fs,
		orig:     typ,
	}
	for idx, count := 0, structType.NumFields(); idx < count; idx++ {
		field := structType.Field(idx)
		if !field.Exported() {
			continue
		}
		mf := &marshalerField{
			parent: result,
			field:  field,
			typ:    ensurePointer(field.Type()),
			tag:    structType.Tag(idx),
		}
		if !field.Anonymous() {
			result.Fields = append(result.Fields, mf)
			continue
		}
		fmt.Fprintln(os.Stderr, mf.errorf("Warning: ignoring embedded field"))
	}
	return result
}
// loadOverrides replaces the marshaling types of mtyp's fields with the types
// of the equally named fields of the override struct otyp. The process exits
// with an error message if otyp has an embedded or unexported field, a field
// without a counterpart in the original type, or a field whose type is not
// convertible to the original field type.
func (mtyp *marshalerType) loadOverrides(otypename string, otyp *types.Struct) {
	for idx, count := 0, otyp.NumFields(); idx < count; idx++ {
		override := otyp.Field(idx)
		pos := mtyp.fs.Position(override.Pos())
		if override.Anonymous() || !override.Exported() {
			fatalf("%v: field override type cannot have embedded or unexported fields", pos)
		}
		target := mtyp.fieldByName(override.Name())
		if target == nil {
			fatalf("%v: no matching field for %s in original type %s", pos, override.Name(), mtyp.OrigName)
		}
		origType := target.field.Type()
		if !types.ConvertibleTo(override.Type(), origType) {
			fatalf("%v: field override type %s is not convertible to %s", pos, mtyp.typeString(override.Type()), mtyp.typeString(origType))
		}
		target.typ = ensurePointer(override.Type())
	}
}
// fieldByName returns the marshaling field whose original field is called
// name, or nil if there is no such field.
func (mtyp *marshalerType) fieldByName(name string) *marshalerField {
	for i := range mtyp.Fields {
		if candidate := mtyp.Fields[i]; candidate.field.Name() == name {
			return candidate
		}
	}
	return nil
}
// computeImports returns the import paths of all referenced types, keyed by
// the name under which each package is imported in the generated file.
// computeImports must be called before generating any code because it
// renames packages to avoid name clashes.
func (mtyp *marshalerType) computeImports() map[string]string {
	seen := make(map[string]string)
	counter := 0
	add := func(name string, path string, pkg *types.Package) {
		if seen[name] == path {
			return // already imported under this name
		}
		if pkg != nil {
			name = "_" + name
			pkg.SetName(name)
		}
		if seen[name] != "" {
			// Name clash, add counter.
			name += "_" + strconv.Itoa(counter)
			counter++
			// Packages added without a *types.Package cannot be renamed.
			if pkg != nil {
				pkg.SetName(name)
			}
		}
		seen[name] = path
	}
	addNamed := func(typ *types.Named) {
		pkg := typ.Obj().Pkg()
		// Universe types such as error have no package and need no import;
		// calling Name or Path on their nil package would panic.
		if pkg == nil || pkg == mtyp.orig.Obj().Pkg() {
			return
		}
		add(pkg.Name(), pkg.Path(), pkg)
	}

	// Add packages which are always referenced by the generated code.
	add("json", "encoding/json", nil)
	add("errors", "errors", nil)

	for _, f := range mtyp.Fields {
		// Add field types of the intermediate struct.
		walkNamedTypes(f.typ, addNamed)
		// Add field types of the original struct. Note that this won't generate unused
		// imports because all fields are either referenced by a conversion or by fields
		// of the intermediate struct (if no conversion is needed).
		walkNamedTypes(f.field.Type(), addNamed)
	}
	return seen
}
// JSONMarshalMethod generates the source of the MarshalJSON method, which
// copies the receiver into the intermediate type and encodes that.
func (mtyp *marshalerType) JSONMarshalMethod() string {
	const methodTemplate = `
// MarshalJSON implements json.Marshaler.
func (x *{{.OrigName}}) MarshalJSON() ([]byte, error) {
{{.TypeDecl}}
return json.Marshal(&{{.Name}}{
{{- range .Fields}}
{{.Name}}: {{.Convert "x"}},
{{- end}}
})
}`
	return render(mtyp, methodTemplate)
}
// YAMLMarshalMethod generates the source of the MarshalYAML method, which
// returns the receiver converted to the intermediate type.
func (mtyp *marshalerType) YAMLMarshalMethod() string {
	const methodTemplate = `
// MarshalYAML implements yaml.Marshaler
func (x *{{.OrigName}}) MarshalYAML() (interface{}, error) {
{{.TypeDecl}}
return &{{.Name}}{
{{- range .Fields}}
{{.Name}}: {{.Convert "x"}},
{{- end}}
}, nil
}`
	return render(mtyp, methodTemplate)
}
// JSONUnmarshalMethod generates the source of the UnmarshalJSON method, which
// decodes into the intermediate type, checks required fields and only then
// overwrites the receiver.
func (mtyp *marshalerType) JSONUnmarshalMethod() string {
	const methodTemplate = `
// UnmarshalJSON implements json.Unmarshaler.
func (x *{{.OrigName}}) UnmarshalJSON(input []byte) error {
{{.TypeDecl}}
var dec {{.Name}}
if err := json.Unmarshal(input, &dec); err != nil {
return err
}
var v {{.OrigName}}
{{.UnmarshalConversions "json"}}
*x = v
return nil
}`
	return render(mtyp, methodTemplate)
}
// YAMLUnmarshalMethod generates the source of the UnmarshalYAML method, which
// decodes into the intermediate type through the supplied callback, checks
// required fields and only then overwrites the receiver.
func (mtyp *marshalerType) YAMLUnmarshalMethod() string {
	const methodTemplate = `
// UnmarshalYAML implements yaml.Unmarshaler.
func (x *{{.OrigName}}) UnmarshalYAML(fn func (interface{}) error) error {
{{.TypeDecl}}
var dec {{.Name}}
if err := fn(&dec); err != nil {
return err
}
var v {{.OrigName}}
{{.UnmarshalConversions "yaml"}}
*x = v
return nil
}`
	return render(mtyp, methodTemplate)
}
// TypeDecl generates the declaration of the intermediate marshaling type,
// with one line per field consisting of its name, type and struct tag.
func (mtyp *marshalerType) TypeDecl() string {
	const declTemplate = `
type {{.Name}} struct{
{{- range .Fields}}
{{.Name}} {{.Type}} {{.StructTag}}
{{- end}}
}`
	return render(mtyp, declTemplate)
}
// UnmarshalConversions generates field conversions and presence checks for
// the format selected by formatTag ("json" or "yaml"). Optional fields are
// copied only when present in the decoded value; a missing required field
// makes the generated code return an error.
func (mtyp *marshalerType) UnmarshalConversions(formatTag string) (s string) {
	type fieldContext struct{ Typ, Name, EncName, Conv string }
	const (
		optionalTemplate = `
if dec.{{.Name}} != nil {
v.{{.Name}} = {{.Conv}}
}`
		requiredTemplate = `
if dec.{{.Name}} == nil {
return errors.New("missing required field '{{.EncName}}' in {{.Typ}}")
}
v.{{.Name}} = {{.Conv}}`
	)

	// The type label used in error messages is the same for every field.
	typeLabel := strings.ToUpper(formatTag) + " " + mtyp.OrigName

	var out bytes.Buffer
	for _, mf := range mtyp.Fields {
		tmpl := requiredTemplate
		if mf.isOptional(formatTag) {
			tmpl = optionalTemplate
		}
		out.WriteString(render(fieldContext{
			Typ:     typeLabel,
			Name:    mf.Name(),
			EncName: mf.encodedName(formatTag),
			Conv:    mf.ConvertBack("dec"),
		}, tmpl))
		out.WriteByte('\n')
	}
	return out.String()
}
// Name returns the name of the original struct field. The intermediate
// struct uses the same field name.
func (mf *marshalerField) Name() string {
	orig := mf.field
	return orig.Name()
}
// Type returns the source representation of the field's type in the
// intermediate marshaling struct.
func (mf *marshalerField) Type() string {
	owner := mf.parent
	return owner.typeString(mf.typ)
}
// OrigType returns the source representation of the field's type in the
// original struct, as opposed to Type, which describes the (pointer or
// overridden) type used in the intermediate marshaling struct.
func (mf *marshalerField) OrigType() string {
	return mf.parent.typeString(mf.field.Type())
}
// StructTag returns the field's struct tag wrapped in backquotes, ready to be
// placed in a struct declaration, or the empty string if the field has no tag.
func (mf *marshalerField) StructTag() string {
	if tag := mf.tag; tag != "" {
		return "`" + tag + "`"
	}
	return ""
}
// Convert returns an expression that reads the field from the original
// struct held in variable and converts it to the intermediate field type.
func (mf *marshalerField) Convert(variable string) string {
	selector := variable + "." + mf.field.Name()
	origType, encType := mf.field.Type(), mf.typ
	return mf.parent.conversionExpr(selector, origType, encType)
}
// ConvertBack returns an expression that reads the field from the
// intermediate struct held in variable and converts it to the original field
// type.
func (mf *marshalerField) ConvertBack(variable string) string {
	selector := variable + "." + mf.field.Name()
	encType, origType := mf.typ, mf.field.Type()
	return mf.parent.conversionExpr(selector, encType, origType)
}
// conversionExpr returns an expression converting valueExpr, which has type
// from, to type to. A pointer level is removed or added first when exactly
// one of the two types is a pointer; an explicit conversion is emitted only
// if the adjusted source type is not assignable to the target type.
func (mtyp *marshalerType) conversionExpr(valueExpr string, from, to types.Type) string {
	fromPtr, toPtr := isPointer(from), isPointer(to)
	switch {
	case fromPtr && !toPtr:
		// Dereference to reach the value.
		valueExpr = "*" + valueExpr
		from = from.(*types.Pointer).Elem()
	case !fromPtr && toPtr:
		// Take the address to produce a pointer.
		valueExpr = "&" + valueExpr
		from = types.NewPointer(from)
	}
	if !types.AssignableTo(from, to) {
		return fmt.Sprintf("(%s)(%s)", mtyp.typeString(to), valueExpr)
	}
	return valueExpr
}
// errorf returns an error whose message is the formatted text prefixed with
// the field's source position and its qualified name (Type.Field).
func (mf *marshalerField) errorf(format string, args ...interface{}) error {
	owner := mf.parent
	location := owner.fs.Position(mf.field.Pos()).String()
	detail := fmt.Sprintf(format, args...)
	qualified := owner.OrigName + "." + mf.Name()
	return errors.New(location + ": (" + qualified + ") " + detail)
}
// isOptional reports whether the field is optional when decoding the given
// format. A field is optional if its "optional" tag is "true" or "yes", or if
// its tag for the format starts with "-".
func (mf *marshalerField) isOptional(format string) bool {
	tags := reflect.StructTag(mf.tag)
	switch tags.Get("optional") {
	case "true", "yes":
		return true
	}
	// Fields with json:"-" must be treated as optional.
	formatTag := tags.Get(format)
	return strings.HasPrefix(formatTag, "-")
}
// encodedName returns the alternative field name assigned by the format's
// struct tag. Tag options after the first comma are ignored. If the tag
// assigns no name, or the name is "-", the uncapitalized Go field name is
// returned.
func (mf *marshalerField) encodedName(format string) string {
	name := reflect.StructTag(mf.tag).Get(format)
	if comma := strings.IndexByte(name, ','); comma >= 0 {
		name = name[:comma]
	}
	switch name {
	case "", "-":
		return uncapitalize(mf.Name())
	default:
		return name
	}
}
// typeString returns the source representation of typ. Types from the
// package of the original type are printed unqualified, all others are
// qualified with the current name of their package.
func (mtyp *marshalerType) typeString(typ types.Type) string {
	qualifier := func(pkg *types.Package) string {
		if pkg != mtyp.orig.Obj().Pkg() {
			return pkg.Name()
		}
		return ""
	}
	return types.TypeString(typ, qualifier)
}
// walkNamedTypes runs the callback for all named types contained in the given type.
// Named types are reported but not descended into. Arrays, interfaces, function
// signatures and tuples are walked in addition to the basic composite types so
// that fields such as [4]T, interface{} or func(T) do not abort the generator.
// It panics for any other kind of type.
func walkNamedTypes(typ types.Type, callback func(*types.Named)) {
	switch typ := typ.(type) {
	case *types.Basic:
		// Nothing to report.
	case *types.Array:
		walkNamedTypes(typ.Elem(), callback)
	case *types.Chan:
		walkNamedTypes(typ.Elem(), callback)
	case *types.Interface:
		for i := 0; i < typ.NumEmbeddeds(); i++ {
			walkNamedTypes(typ.EmbeddedType(i), callback)
		}
		for i := 0; i < typ.NumExplicitMethods(); i++ {
			walkNamedTypes(typ.ExplicitMethod(i).Type(), callback)
		}
	case *types.Map:
		walkNamedTypes(typ.Key(), callback)
		walkNamedTypes(typ.Elem(), callback)
	case *types.Named:
		callback(typ)
	case *types.Pointer:
		walkNamedTypes(typ.Elem(), callback)
	case *types.Signature:
		// Params and Results may be nil tuples, which have length zero.
		walkNamedTypes(typ.Params(), callback)
		walkNamedTypes(typ.Results(), callback)
	case *types.Slice:
		walkNamedTypes(typ.Elem(), callback)
	case *types.Struct:
		for i := 0; i < typ.NumFields(); i++ {
			walkNamedTypes(typ.Field(i).Type(), callback)
		}
	case *types.Tuple:
		for i := 0; i < typ.Len(); i++ {
			walkNamedTypes(typ.At(i).Type(), callback)
		}
	default:
		panic(fmt.Errorf("can't walk %T", typ))
	}
}
// lookupStructType looks up the named type called name in scope and returns
// it if its underlying type is a struct. An error is returned if the lookup
// fails or the type is not a struct.
func lookupStructType(scope *types.Scope, name string) (*types.Named, error) {
	named, err := lookupType(scope, name)
	if err != nil {
		return nil, err
	}
	if _, isStruct := named.Underlying().(*types.Struct); !isStruct {
		return nil, errors.New("not a struct type")
	}
	return named, nil
}
// lookupType looks up name in scope and returns the named type it denotes.
// An error is returned if the identifier does not exist, is not a type name,
// or is a type name that does not denote a named type (for example an alias
// of a basic type such as "type A = int").
func lookupType(scope *types.Scope, name string) (*types.Named, error) {
	obj := scope.Lookup(name)
	if obj == nil {
		return nil, errors.New("no such identifier")
	}
	tname, ok := obj.(*types.TypeName)
	if !ok {
		return nil, errors.New("not a type")
	}
	// Checked assertion: an unchecked one panics for aliases of unnamed types.
	named, ok := tname.Type().(*types.Named)
	if !ok {
		return nil, errors.New("not a named type")
	}
	return named, nil
}
// isPointer reports whether typ is an unnamed pointer type (*types.Pointer).
func isPointer(typ types.Type) bool {
	switch typ.(type) {
	case *types.Pointer:
		return true
	default:
		return false
	}
}
// ensurePointer returns typ unchanged if it already is a pointer type and a
// pointer to typ otherwise.
func ensurePointer(typ types.Type) types.Type {
	if !isPointer(typ) {
		typ = types.NewPointer(typ)
	}
	return typ
}
// uncapitalize returns s with its first character converted to lower case.
// The empty string is returned unchanged, and a multi-byte first character
// is lowercased as a whole rather than being cut after its first byte.
func uncapitalize(s string) string {
	for i := range s {
		if i > 0 {
			// i is the byte offset of the second rune, i.e. the width of
			// the first one.
			return strings.ToLower(s[:i]) + s[i:]
		}
	}
	// Empty or single-rune string.
	return strings.ToLower(s)
}
// render executes text as a text/template with data as its input and returns
// the output. Leading and trailing white space is trimmed from text before
// parsing. It panics if the template cannot be parsed or executed.
func render(data interface{}, text string) string {
	tmpl, err := template.New("").Parse(strings.TrimSpace(text))
	if err != nil {
		panic(err)
	}
	var out bytes.Buffer
	if err := tmpl.Execute(&out, data); err != nil {
		panic(err)
	}
	return out.String()
}
// fatal prints args to stderr in the manner of fmt.Println and terminates the
// process with exit status 1.
func fatal(args ...interface{}) {
	message := fmt.Sprintln(args...)
	fmt.Fprint(os.Stderr, message)
	os.Exit(1)
}
// fatalf prints the formatted message followed by a newline to stderr and
// terminates the process with exit status 1.
func fatalf(format string, args ...interface{}) {
	message := fmt.Sprintf(format+"\n", args...)
	os.Stderr.WriteString(message)
	os.Exit(1)
}