improve name resolution of type aliases (#2061)

* BREAKING CHANGE: Renamed user_defined_types to type_aliases so it's less confusing with what we call UserDefinedType.
* Added type aliased at the Contract level so now at the file scope there are only top level aliasing and fully qualified contract aliasing like C.myAlias.
* Fix #1809
pull/2111/head
Simone 1 year ago committed by GitHub
parent 53c97f9f48
commit 8b07fe59d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 6
      slither/core/compilation_unit.py
  2. 34
      slither/core/declarations/contract.py
  3. 6
      slither/core/scope/scope.py
  4. 8
      slither/solc_parsing/declarations/contract.py
  5. 2
      slither/solc_parsing/declarations/using_for_top_level.py
  6. 9
      slither/solc_parsing/expressions/find_variable.py
  7. 8
      slither/solc_parsing/slither_compilation_unit_solc.py
  8. 30
      slither/solc_parsing/solidity_types/type_parsing.py
  9. 4
      slither/visitors/slithir/expression_to_slithir.py
  10. 1
      tests/e2e/solc_parsing/test_ast_parsing.py
  11. BIN
      tests/e2e/solc_parsing/test_data/compile/type-aliases.sol-0.8.19-compact.zip
  12. 6
      tests/e2e/solc_parsing/test_data/expected/type-aliases.sol-0.8.19-compact.json
  13. 20
      tests/e2e/solc_parsing/test_data/type-aliases.sol
  14. 6
      tests/unit/core/test_source_mapping.py

@ -47,7 +47,7 @@ class SlitherCompilationUnit(Context):
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomErrorTopLevel] = []
self._user_defined_value_types: Dict[str, TypeAliasTopLevel] = {}
self._type_aliases: Dict[str, TypeAliasTopLevel] = {}
self._all_functions: Set[Function] = set()
self._all_modifiers: Set[Modifier] = set()
@ -220,8 +220,8 @@ class SlitherCompilationUnit(Context):
return self._custom_errors
@property
def user_defined_value_types(self) -> Dict[str, TypeAliasTopLevel]:
return self._user_defined_value_types
def type_aliases(self) -> Dict[str, TypeAliasTopLevel]:
return self._type_aliases
# endregion
###################################################################################

@ -45,6 +45,7 @@ if TYPE_CHECKING:
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.scope.scope import FileScope
from slither.core.cfg.node import Node
from slither.core.solidity_types import TypeAliasContract
LOGGER = logging.getLogger("Contract")
@ -81,6 +82,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
self._functions: Dict[str, "FunctionContract"] = {}
self._linearizedBaseContracts: List[int] = []
self._custom_errors: Dict[str, "CustomErrorContract"] = {}
self._type_aliases: Dict[str, "TypeAliasContract"] = {}
# The only str is "*"
self._using_for: Dict[USING_FOR_KEY, USING_FOR_ITEM] = {}
@ -364,6 +366,38 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
def custom_errors_as_dict(self) -> Dict[str, "CustomErrorContract"]:
return self._custom_errors
# endregion
###################################################################################
###################################################################################
# region Custom Errors
###################################################################################
###################################################################################
@property
def type_aliases(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the contract's custom errors
"""
return list(self._type_aliases.values())
@property
def type_aliases_inherited(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the inherited custom errors
"""
return [s for s in self.type_aliases if s.contract != self]
@property
def type_aliases_declared(self) -> List["TypeAliasContract"]:
"""
list(TypeAliasContract): List of the custom errors declared within the contract (not inherited)
"""
return [s for s in self.type_aliases if s.contract == self]
@property
def type_aliases_as_dict(self) -> Dict[str, "TypeAliasContract"]:
return self._type_aliases
# endregion
###################################################################################
###################################################################################

@ -52,7 +52,7 @@ class FileScope:
# User defined types
# Name -> type alias
self.user_defined_types: Dict[str, TypeAlias] = {}
self.type_aliases: Dict[str, TypeAlias] = {}
def add_accesible_scopes(self) -> bool:
"""
@ -95,8 +95,8 @@ class FileScope:
if not _dict_contain(new_scope.renaming, self.renaming):
self.renaming.update(new_scope.renaming)
learn_something = True
if not _dict_contain(new_scope.user_defined_types, self.user_defined_types):
self.user_defined_types.update(new_scope.user_defined_types)
if not _dict_contain(new_scope.type_aliases, self.type_aliases):
self.type_aliases.update(new_scope.type_aliases)
learn_something = True
return learn_something

@ -291,10 +291,10 @@ class ContractSolc(CallerContextExpression):
alias = item["name"]
alias_canonical = self._contract.name + "." + item["name"]
user_defined_type = TypeAliasContract(original_type, alias, self.underlying_contract)
user_defined_type.set_offset(item["src"], self.compilation_unit)
self._contract.file_scope.user_defined_types[alias] = user_defined_type
self._contract.file_scope.user_defined_types[alias_canonical] = user_defined_type
type_alias = TypeAliasContract(original_type, alias, self.underlying_contract)
type_alias.set_offset(item["src"], self.compilation_unit)
self._contract.type_aliases_as_dict[alias] = type_alias
self._contract.file_scope.type_aliases[alias_canonical] = type_alias
def _parse_struct(self, struct: Dict) -> None:

@ -152,7 +152,7 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
if self._global:
for scope in self.compilation_unit.scopes.values():
if isinstance(type_name, TypeAliasTopLevel):
for alias in scope.user_defined_types.values():
for alias in scope.type_aliases.values():
if alias == type_name:
scope.using_for_directives.add(self._using_for)
elif isinstance(type_name, UserDefinedType):

@ -114,6 +114,8 @@ def find_top_level(
:return:
:rtype:
"""
if var_name in scope.type_aliases:
return scope.type_aliases[var_name], False
if var_name in scope.structures:
return scope.structures[var_name], False
@ -205,6 +207,10 @@ def _find_in_contract(
if sig == var_name:
return modifier
type_aliases = contract.type_aliases_as_dict
if var_name in type_aliases:
return type_aliases[var_name]
# structures are looked on the contract declarer
structures = contract.structures_as_dict
if var_name in structures:
@ -362,9 +368,6 @@ def find_variable(
if var_name in current_scope.renaming:
var_name = current_scope.renaming[var_name]
if var_name in current_scope.user_defined_types:
return current_scope.user_defined_types[var_name], False
# Use ret0/ret1 to help mypy
ret0 = _find_variable_from_ref_declaration(
referenced_declaration, direct_contracts, direct_functions

@ -344,10 +344,10 @@ class SlitherCompilationUnitSolc(CallerContextExpression):
original_type = ElementaryType(underlying_type["name"])
user_defined_type = TypeAliasTopLevel(original_type, alias, scope)
user_defined_type.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.user_defined_value_types[alias] = user_defined_type
scope.user_defined_types[alias] = user_defined_type
type_alias = TypeAliasTopLevel(original_type, alias, scope)
type_alias.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.type_aliases[alias] = type_alias
scope.type_aliases[alias] = type_alias
else:
raise SlitherException(f"Top level {top_level_data[self.get_key()]} not supported")

@ -235,7 +235,7 @@ def parse_type(
sl: "SlitherCompilationUnit"
renaming: Dict[str, str]
user_defined_types: Dict[str, TypeAlias]
type_aliases: Dict[str, TypeAlias]
enums_direct_access: List["Enum"] = []
# Note: for convenicence top level functions use the same parser than function in contract
# but contract_parser is set to None
@ -247,13 +247,13 @@ def parse_type(
sl = caller_context.compilation_unit
next_context = caller_context
renaming = {}
user_defined_types = sl.user_defined_value_types
type_aliases = sl.type_aliases
else:
assert isinstance(caller_context, FunctionSolc)
sl = caller_context.underlying_function.compilation_unit
next_context = caller_context.slither_parser
renaming = caller_context.underlying_function.file_scope.renaming
user_defined_types = caller_context.underlying_function.file_scope.user_defined_types
type_aliases = caller_context.underlying_function.file_scope.type_aliases
structures_direct_access = sl.structures_top_level
all_structuress = [c.structures for c in sl.contracts]
all_structures = [item for sublist in all_structuress for item in sublist]
@ -299,7 +299,7 @@ def parse_type(
functions = list(scope.functions)
renaming = scope.renaming
user_defined_types = scope.user_defined_types
type_aliases = scope.type_aliases
elif isinstance(caller_context, (ContractSolc, FunctionSolc)):
sl = caller_context.compilation_unit
if isinstance(caller_context, FunctionSolc):
@ -329,7 +329,7 @@ def parse_type(
functions = contract.functions + contract.modifiers
renaming = scope.renaming
user_defined_types = scope.user_defined_types
type_aliases = scope.type_aliases
else:
raise ParsingError(f"Incorrect caller context: {type(caller_context)}")
@ -343,8 +343,8 @@ def parse_type(
name = t.name
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
if name in type_aliases:
return type_aliases[name]
return _find_from_type_name(
name,
functions,
@ -365,9 +365,9 @@ def parse_type(
name = t["typeDescriptions"]["typeString"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
_add_type_references(user_defined_types[name], t["src"], sl)
return user_defined_types[name]
if name in type_aliases:
_add_type_references(type_aliases[name], t["src"], sl)
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
@ -386,9 +386,9 @@ def parse_type(
name = t["attributes"][type_name_key]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
_add_type_references(user_defined_types[name], t["src"], sl)
return user_defined_types[name]
if name in type_aliases:
_add_type_references(type_aliases[name], t["src"], sl)
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,
@ -407,8 +407,8 @@ def parse_type(
name = t["name"]
if name in renaming:
name = renaming[name]
if name in user_defined_types:
return user_defined_types[name]
if name in type_aliases:
return type_aliases[name]
type_found = _find_from_type_name(
name,
functions,

@ -516,8 +516,8 @@ class ExpressionToSlithIR(ExpressionVisitor):
# contract A { type MyInt is int}
# contract B { function f() public{ A.MyInt test = A.MyInt.wrap(1);}}
# The logic is handled by _post_call_expression
if expression.member_name in expr.file_scope.user_defined_types:
set_val(expression, expr.file_scope.user_defined_types[expression.member_name])
if expression.member_name in expr.file_scope.type_aliases:
set_val(expression, expr.file_scope.type_aliases[expression.member_name])
return
# Lookup errors referred to as member of contract e.g. Test.myError.selector
if expression.member_name in expr.custom_errors_as_dict:

@ -459,6 +459,7 @@ ALL_TESTS = [
["0.6.9", "0.7.6", "0.8.16"],
),
Test("user_defined_operators-0.8.19.sol", ["0.8.19"]),
Test("type-aliases.sol", ["0.8.19"]),
]
# create the output folder if needed
try:

@ -0,0 +1,6 @@
{
"OtherTest": {
"myfunc()": "digraph{\n0[label=\"Node Type: ENTRY_POINT 0\n\"];\n0->1;\n1[label=\"Node Type: NEW VARIABLE 1\n\"];\n}\n"
},
"DeleteTest": {}
}

@ -0,0 +1,20 @@
struct Z {
int x;
int y;
}
contract OtherTest {
struct Z {
int x;
int y;
}
function myfunc() external {
Z memory z = Z(2,3);
}
}
contract DeleteTest {
type Z is int;
}

@ -85,15 +85,13 @@ def test_references_user_defined_aliases(solc_binary_path):
file = Path(SRC_MAPPING_TEST_ROOT, "ReferencesUserDefinedAliases.sol").as_posix()
slither = Slither(file, solc=solc_path)
alias_top_level = slither.compilation_units[0].user_defined_value_types["aliasTopLevel"]
alias_top_level = slither.compilation_units[0].type_aliases["aliasTopLevel"]
assert len(alias_top_level.references) == 2
lines = _sort_references_lines(alias_top_level.references)
assert lines == [12, 16]
alias_contract_level = (
slither.compilation_units[0]
.contracts[0]
.file_scope.user_defined_types["C.aliasContractLevel"]
slither.compilation_units[0].contracts[0].file_scope.type_aliases["C.aliasContractLevel"]
)
assert len(alias_contract_level.references) == 2
lines = _sort_references_lines(alias_contract_level.references)

Loading…
Cancel
Save