Command gencodec generates marshaling methods for Go struct types.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
gencodec/typeutil.go

273 lines
6.9 KiB

// Copyright 2017 Felix Lange <fjl@twurst.com>.
// Use of this source code is governed by the MIT license,
// which can be found in the LICENSE file.
package main
import (
"errors"
"fmt"
"go/types"
"io"
"sort"
"strconv"
)
// walkNamedTypes runs the callback for all named types contained in the given type.
func walkNamedTypes(typ types.Type, callback func(*types.Named)) {
switch typ := typ.(type) {
case *types.Basic:
case *types.Chan:
walkNamedTypes(typ.Elem(), callback)
case *types.Map:
walkNamedTypes(typ.Key(), callback)
walkNamedTypes(typ.Elem(), callback)
case *types.Named:
callback(typ)
case *types.Pointer:
walkNamedTypes(typ.Elem(), callback)
case *types.Slice:
walkNamedTypes(typ.Elem(), callback)
case *types.Struct:
for i := 0; i < typ.NumFields(); i++ {
walkNamedTypes(typ.Field(i).Type(), callback)
}
case *types.Interface:
if typ.NumMethods() > 0 {
panic("BUG: can't walk non-empty interface")
}
default:
panic(fmt.Errorf("BUG: can't walk %T", typ))
}
}
func lookupStructType(scope *types.Scope, name string) (*types.Named, error) {
typ, err := lookupType(scope, name)
if err != nil {
return nil, err
}
_, ok := typ.Underlying().(*types.Struct)
if !ok {
return nil, errors.New("not a struct type")
}
return typ, nil
}
func lookupType(scope *types.Scope, name string) (*types.Named, error) {
obj := scope.Lookup(name)
if obj == nil {
return nil, errors.New("no such identifier")
}
typ, ok := obj.(*types.TypeName)
if !ok {
return nil, errors.New("not a type")
}
return typ.Type().(*types.Named), nil
}
func isPointer(typ types.Type) bool {
_, ok := typ.(*types.Pointer)
return ok
}
func underlyingSlice(typ types.Type) *types.Slice {
for {
switch typ.(type) {
case *types.Named:
typ = typ.Underlying()
case *types.Slice:
return typ.(*types.Slice)
default:
return nil
}
}
}
func underlyingMap(typ types.Type) *types.Map {
for {
switch typ.(type) {
case *types.Named:
typ = typ.Underlying()
case *types.Map:
return typ.(*types.Map)
default:
return nil
}
}
}
func ensureNilCheckable(typ types.Type) types.Type {
orig := typ
for {
switch typ.(type) {
case *types.Named:
typ = typ.Underlying()
case *types.Map, *types.Slice, *types.Pointer, *types.Interface:
return orig
default:
return types.NewPointer(orig)
}
}
}
// checkConvertible determines whether values of type from can be converted to type to. It
// returns nil if convertible and a descriptive error otherwise.
// See package documentation for this definition of 'convertible'.
func checkConvertible(from, to types.Type) error {
if types.ConvertibleTo(from, to) {
return nil
}
// Slices.
sfrom := underlyingSlice(from)
sto := underlyingSlice(to)
if sfrom != nil && sto != nil {
if !types.ConvertibleTo(sfrom.Elem(), sto.Elem()) {
return fmt.Errorf("slice element type %s is not convertible to %s", sfrom.Elem(), sto.Elem())
}
return nil
}
// Maps.
mfrom := underlyingMap(from)
mto := underlyingMap(to)
if mfrom != nil && mto != nil {
if !types.ConvertibleTo(mfrom.Key(), mto.Key()) {
return fmt.Errorf("map key type %s is not convertible to %s", mfrom.Key(), mto.Key())
}
if !types.ConvertibleTo(mfrom.Elem(), mto.Elem()) {
return fmt.Errorf("map element type %s is not convertible to %s", mfrom.Elem(), mto.Elem())
}
return nil
}
return fmt.Errorf("type %s is not convertible to %s", from, to)
}
// fileScope tracks imports and other names at file scope.
type fileScope struct {
imports []*types.Package
importsByName map[string]*types.Package
importNames map[string]string
otherNames map[string]bool // non-package identifiers
pkg *types.Package
imp types.Importer
}
func newFileScope(imp types.Importer, pkg *types.Package) *fileScope {
return &fileScope{otherNames: make(map[string]bool), pkg: pkg, imp: imp}
}
func (s *fileScope) writeImportDecl(w io.Writer) {
fmt.Fprintln(w, "import (")
for _, pkg := range s.imports {
if s.importNames[pkg.Path()] != pkg.Name() {
fmt.Fprintf(w, "\t%s %q\n", s.importNames[pkg.Path()], pkg.Path())
} else {
fmt.Fprintf(w, "\t%q\n", pkg.Path())
}
}
fmt.Fprintln(w, ")")
}
// addImport loads a package and adds it to the import set.
func (s *fileScope) addImport(path string) {
pkg, err := s.imp.Import(path)
if err != nil {
panic(fmt.Errorf("can't import %q: %v", path, err))
}
s.insertImport(pkg)
s.rebuildImports()
}
// addReferences marks all names referenced by typ as used.
func (s *fileScope) addReferences(typ types.Type) {
walkNamedTypes(typ, func(nt *types.Named) {
pkg := nt.Obj().Pkg()
if pkg == s.pkg {
s.otherNames[nt.Obj().Name()] = true
} else if pkg != nil {
s.insertImport(nt.Obj().Pkg())
}
})
s.rebuildImports()
}
// insertImport adds pkg to the list of known imports.
// This method should not be used directly because it doesn't
// rebuild the import name cache.
func (s *fileScope) insertImport(pkg *types.Package) {
i := sort.Search(len(s.imports), func(i int) bool {
return s.imports[i].Path() >= pkg.Path()
})
if i < len(s.imports) && s.imports[i] == pkg {
return
}
s.imports = append(s.imports[:i], append([]*types.Package{pkg}, s.imports[i:]...)...)
}
// rebuildImports caches the names of imported packages.
func (s *fileScope) rebuildImports() {
s.importNames = make(map[string]string)
s.importsByName = make(map[string]*types.Package)
for _, pkg := range s.imports {
s.maybeRenameImport(pkg)
}
}
func (s *fileScope) maybeRenameImport(pkg *types.Package) {
name := pkg.Name()
for i := 0; s.isNameTaken(name); i++ {
name = pkg.Name()
if i > 0 {
name += strconv.Itoa(i - 1)
}
}
s.importNames[pkg.Path()] = name
s.importsByName[name] = pkg
}
// isNameTaken reports whether the given name is used by an import or other identifier.
func (s *fileScope) isNameTaken(name string) bool {
return s.importsByName[name] != nil || s.otherNames[name] || types.Universe.Lookup(name) != nil
}
// qualify is a types.Qualifier that prepends the (possibly renamed) package name of
// imported types to a type name.
func (s *fileScope) qualify(pkg *types.Package) string {
if pkg == s.pkg {
return ""
}
return s.packageName(pkg.Path())
}
func (s *fileScope) packageName(path string) string {
name, ok := s.importNames[path]
if !ok {
panic("BUG: missing package " + path)
}
return name
}
// funcScope tracks used identifiers in a function. It can create new identifiers that do
// not clash with the parent scope.
type funcScope struct {
used map[string]bool
parent *fileScope
}
func newFuncScope(parent *fileScope) *funcScope {
return &funcScope{make(map[string]bool), parent}
}
// newIdent creates a new identifier that doesn't clash with any name
// in the scope or its parent file scope.
func (s *funcScope) newIdent(base string) string {
for i := 0; ; i++ {
name := base
if i > 0 {
name += strconv.Itoa(i - 1)
}
if !s.parent.isNameTaken(name) && !s.used[name] {
s.used[name] = true
return name
}
}
}