// Copyright 2017 Felix Lange . // Use of this source code is governed by the MIT license, // which can be found in the LICENSE file. package main import ( "errors" "fmt" "go/types" "io" "sort" "strconv" ) // walkNamedTypes runs the callback for all named types contained in the given type. func walkNamedTypes(typ types.Type, callback func(*types.Named)) { switch typ := typ.(type) { case *types.Basic: case *types.Chan: walkNamedTypes(typ.Elem(), callback) case *types.Map: walkNamedTypes(typ.Key(), callback) walkNamedTypes(typ.Elem(), callback) case *types.Named: callback(typ) case *types.Pointer: walkNamedTypes(typ.Elem(), callback) case *types.Slice: walkNamedTypes(typ.Elem(), callback) case *types.Struct: for i := 0; i < typ.NumFields(); i++ { walkNamedTypes(typ.Field(i).Type(), callback) } default: panic(fmt.Errorf("BUG: can't walk %T", typ)) } } func lookupStructType(scope *types.Scope, name string) (*types.Named, error) { typ, err := lookupType(scope, name) if err != nil { return nil, err } _, ok := typ.Underlying().(*types.Struct) if !ok { return nil, errors.New("not a struct type") } return typ, nil } func lookupType(scope *types.Scope, name string) (*types.Named, error) { obj := scope.Lookup(name) if obj == nil { return nil, errors.New("no such identifier") } typ, ok := obj.(*types.TypeName) if !ok { return nil, errors.New("not a type") } return typ.Type().(*types.Named), nil } func isPointer(typ types.Type) bool { _, ok := typ.(*types.Pointer) return ok } func underlyingSlice(typ types.Type) *types.Slice { for { switch typ.(type) { case *types.Named: typ = typ.Underlying() case *types.Slice: return typ.(*types.Slice) default: return nil } } } func underlyingMap(typ types.Type) *types.Map { for { switch typ.(type) { case *types.Named: typ = typ.Underlying() case *types.Map: return typ.(*types.Map) default: return nil } } } func ensureNilCheckable(typ types.Type) types.Type { orig := typ for { switch typ.(type) { case *types.Named: typ = typ.Underlying() case *types.Map, *types.Slice, *types.Pointer: return orig default: return types.NewPointer(orig) } } } // checkConvertible determines whether values of type from can be converted to type to. It // returns nil if convertible and a descriptive error otherwise. // See package documentation for this definition of 'convertible'. func checkConvertible(from, to types.Type) error { if types.ConvertibleTo(from, to) { return nil } // Slices. sfrom := underlyingSlice(from) sto := underlyingSlice(to) if sfrom != nil && sto != nil { if !types.ConvertibleTo(sfrom.Elem(), sto.Elem()) { return fmt.Errorf("slice element type %s is not convertible to %s", sfrom.Elem(), sto.Elem()) } return nil } // Maps. mfrom := underlyingMap(from) mto := underlyingMap(to) if mfrom != nil && mto != nil { if !types.ConvertibleTo(mfrom.Key(), mto.Key()) { return fmt.Errorf("map key type %s is not convertible to %s", mfrom.Key(), mto.Key()) } if !types.ConvertibleTo(mfrom.Elem(), mto.Elem()) { return fmt.Errorf("map element type %s is not convertible to %s", mfrom.Elem(), mto.Elem()) } return nil } return fmt.Errorf("type %s is not convertible to %s", from, to) } // fileScope tracks imports and other names at file scope. type fileScope struct { imports []*types.Package importsByName map[string]*types.Package importNames map[string]string otherNames map[string]bool // non-package identifiers pkg *types.Package imp types.Importer } func newFileScope(imp types.Importer, pkg *types.Package) *fileScope { return &fileScope{otherNames: make(map[string]bool), pkg: pkg, imp: imp} } func (s *fileScope) writeImportDecl(w io.Writer) { fmt.Fprintln(w, "import (") for _, pkg := range s.imports { if s.importNames[pkg.Path()] != pkg.Name() { fmt.Fprintf(w, "\t%s %q\n", s.importNames[pkg.Path()], pkg.Path()) } else { fmt.Fprintf(w, "\t%q\n", pkg.Path()) } } fmt.Fprintln(w, ")") } // addImport loads a package and adds it to the import set. func (s *fileScope) addImport(path string) { pkg, err := s.imp.Import(path) if err != nil { panic(fmt.Errorf("can't import %q: %v", path, err)) } s.insertImport(pkg) s.rebuildImports() } // addReferences marks all names referenced by typ as used. func (s *fileScope) addReferences(typ types.Type) { walkNamedTypes(typ, func(nt *types.Named) { if nt.Obj().Pkg() == s.pkg { s.otherNames[nt.Obj().Name()] = true } else { s.insertImport(nt.Obj().Pkg()) } }) s.rebuildImports() } // insertImport adds pkg to the list of known imports. // This method should not be used directly because it doesn't // rebuild the import name cache. func (s *fileScope) insertImport(pkg *types.Package) { i := sort.Search(len(s.imports), func(i int) bool { return s.imports[i].Path() >= pkg.Path() }) if i < len(s.imports) && s.imports[i] == pkg { return } s.imports = append(s.imports[:i], append([]*types.Package{pkg}, s.imports[i:]...)...) } // rebuildImports caches the names of imported packages. func (s *fileScope) rebuildImports() { s.importNames = make(map[string]string) s.importsByName = make(map[string]*types.Package) for _, pkg := range s.imports { s.maybeRenameImport(pkg) } } func (s *fileScope) maybeRenameImport(pkg *types.Package) { name := pkg.Name() for i := 0; s.isNameTaken(name); i++ { name = pkg.Name() if i > 0 { name += strconv.Itoa(i - 1) } } s.importNames[pkg.Path()] = name s.importsByName[name] = pkg } // isNameTaken reports whether the given name is used by an import or other identifier. func (s *fileScope) isNameTaken(name string) bool { return s.importsByName[name] != nil || s.otherNames[name] || types.Universe.Lookup(name) != nil } // qualify is a types.Qualifier that prepends the (possibly renamed) package name of // imported types to a type name. func (s *fileScope) qualify(pkg *types.Package) string { if pkg == s.pkg { return "" } return s.packageName(pkg.Path()) } func (s *fileScope) packageName(path string) string { name, ok := s.importNames[path] if !ok { panic("BUG: missing package " + path) } return name } // funcScope tracks used identifiers in a function. It can create new identifiers that do // not clash with the parent scope. type funcScope struct { used map[string]bool parent *fileScope } func newFuncScope(parent *fileScope) *funcScope { return &funcScope{make(map[string]bool), parent} } // newIdent creates a new identifier that doesn't clash with any name // in the scope or its parent file scope. func (s *funcScope) newIdent(base string) string { for i := 0; ; i++ { name := base if i > 0 { name += strconv.Itoa(i - 1) } if !s.parent.isNameTaken(name) && !s.used[name] { s.used[name] = true return name } } }