diff --git a/slither/core/compilation_unit.py b/slither/core/compilation_unit.py index a2568a5de..d97b7fbf5 100644 --- a/slither/core/compilation_unit.py +++ b/slither/core/compilation_unit.py @@ -16,6 +16,7 @@ from slither.core.declarations import ( from slither.core.declarations.custom_error import CustomError from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel +from slither.core.declarations.using_for_top_level import UsingForTopLevel from slither.core.declarations.structure_top_level import StructureTopLevel from slither.core.scope.scope import FileScope from slither.core.variables.state_variable import StateVariable @@ -41,6 +42,7 @@ class SlitherCompilationUnit(Context): self._enums_top_level: List[EnumTopLevel] = [] self._variables_top_level: List[TopLevelVariable] = [] self._functions_top_level: List[FunctionTopLevel] = [] + self._using_for_top_level: List[UsingForTopLevel] = [] self._pragma_directives: List[Pragma] = [] self._import_directives: List[Import] = [] self._custom_errors: List[CustomError] = [] @@ -205,6 +207,10 @@ class SlitherCompilationUnit(Context): def functions_top_level(self) -> List[FunctionTopLevel]: return self._functions_top_level + @property + def using_for_top_level(self) -> List[UsingForTopLevel]: + return self._using_for_top_level + @property def custom_errors(self) -> List[CustomError]: return self._custom_errors diff --git a/slither/core/declarations/__init__.py b/slither/core/declarations/__init__.py index 3b619c1d1..f891ad621 100644 --- a/slither/core/declarations/__init__.py +++ b/slither/core/declarations/__init__.py @@ -12,5 +12,8 @@ from .solidity_variables import ( ) from .structure import Structure from .enum_contract import EnumContract +from .enum_top_level import EnumTopLevel from .structure_contract import StructureContract +from .structure_top_level import StructureTopLevel from .function_contract import FunctionContract +from .function_top_level import FunctionTopLevel diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 5d65d5cc5..ba738b0f8 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -78,6 +78,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods # The only str is "*" self._using_for: Dict[Union[str, Type], List[Type]] = {} + self._using_for_complete: Dict[Union[str, Type], List[Type]] = None self._kind: Optional[str] = None self._is_interface: bool = False self._is_library: bool = False @@ -259,6 +260,27 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods def using_for(self) -> Dict[Union[str, Type], List[Type]]: return self._using_for + @property + def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]: + """ + Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive + """ + + def _merge_using_for(uf1, uf2): + result = {**uf1, **uf2} + for key, value in result.items(): + if key in uf1 and key in uf2: + result[key] = value + uf1[key] + return result + + if self._using_for_complete is None: + result = self.using_for + top_level_using_for = self.file_scope.usingFor + for uftl in top_level_using_for: + result = _merge_using_for(result, uftl.using_for) + self._using_for_complete = result + return self._using_for_complete + # endregion ################################################################################### ################################################################################### diff --git a/slither/core/declarations/using_for_top_level.py b/slither/core/declarations/using_for_top_level.py new file mode 100644 index 000000000..a1b43e1c1 --- /dev/null +++ b/slither/core/declarations/using_for_top_level.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING, List, Dict, Union + +from slither.core.solidity_types.type import Type +from slither.core.declarations.top_level import TopLevel + +if TYPE_CHECKING: + from slither.core.scope.scope import FileScope + + +class UsingForTopLevel(TopLevel): + def __init__(self, scope: "FileScope"): + super().__init__() + self._using_for: Dict[Union[str, Type], List[Type]] = {} + self.file_scope: "FileScope" = scope + + @property + def using_for(self) -> Dict[Type, List[Type]]: + return self._using_for diff --git a/slither/core/scope/scope.py b/slither/core/scope/scope.py index c6d18556e..a27483824 100644 --- a/slither/core/scope/scope.py +++ b/slither/core/scope/scope.py @@ -5,6 +5,7 @@ from slither.core.declarations import Contract, Import, Pragma from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel from slither.core.declarations.enum_top_level import EnumTopLevel from slither.core.declarations.function_top_level import FunctionTopLevel +from slither.core.declarations.using_for_top_level import UsingForTopLevel from slither.core.declarations.structure_top_level import StructureTopLevel from slither.core.solidity_types import TypeAlias from slither.core.variables.top_level_variable import TopLevelVariable @@ -35,6 +36,7 @@ class FileScope: # Because we parse the function signature later on # So we simplify the logic and have the scope fields all populated self.functions: Set[FunctionTopLevel] = set() + self.usingFor: Set[UsingForTopLevel] = set() self.imports: Set[Import] = set() self.pragmas: Set[Pragma] = set() self.structures: Dict[str, StructureTopLevel] = {} @@ -72,6 +74,9 @@ class FileScope: if not new_scope.functions.issubset(self.functions): self.functions |= new_scope.functions learn_something = True + if not new_scope.usingFor.issubset(self.usingFor): + self.usingFor |= new_scope.usingFor + learn_something = True if not new_scope.imports.issubset(self.imports): self.imports |= new_scope.imports learn_something = True diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index 829a23e92..7658deb26 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -171,10 +171,10 @@ def _fits_under_integer(val: int, can_be_int: bool, can_be_uint) -> List[str]: assert can_be_int | can_be_uint while n <= 256: if can_be_uint: - if val <= 2**n - 1: + if val <= 2 ** n - 1: ret.append(f"uint{n}") if can_be_int: - if val <= (2**n) / 2 - 1: + if val <= (2 ** n) / 2 - 1: ret.append(f"int{n}") n = n + 8 return ret @@ -498,7 +498,9 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals # propagate the type node_function = node.function using_for = ( - node_function.contract.using_for if isinstance(node_function, FunctionContract) else {} + node_function.contract.using_for_complete + if isinstance(node_function, FunctionContract) + else {} ) if isinstance(ir, OperationWithLValue): # Force assignment in case of missing previous correct type @@ -879,7 +881,9 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis # } node_func = ins.node.function using_for = ( - node_func.contract.using_for if isinstance(node_func, FunctionContract) else {} + node_func.contract.using_for_complete + if isinstance(node_func, FunctionContract) + else {} ) targeted_libraries = ( @@ -892,10 +896,14 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis lib_contract_type.type, Contract ): continue - lib_contract = lib_contract_type.type - for lib_func in lib_contract.functions: - if lib_func.name == ins.ori.variable_right: - candidates.append(lib_func) + if isinstance(lib_contract_type, FunctionContract): + # Using for with list of functions, this is the function called + candidates.append(lib_contract_type) + else: + lib_contract = lib_contract_type.type + for lib_func in lib_contract.functions: + if lib_func.name == ins.ori.variable_right: + candidates.append(lib_func) if len(candidates) == 1: lib_func = candidates[0] @@ -1325,7 +1333,10 @@ def convert_to_pop(ir, node): def look_for_library(contract, ir, using_for, t): for destination in using_for[t]: - lib_contract = contract.file_scope.get_contract_from_name(str(destination)) + if isinstance(destination, FunctionContract) and destination.contract.is_library: + lib_contract = destination.contract + else: + lib_contract = contract.file_scope.get_contract_from_name(str(destination)) if lib_contract: lib_call = LibraryCall( lib_contract, diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index b7f938d1d..32ca0e8ab 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -578,17 +578,33 @@ class ContractSolc(CallerContextExpression): try: for father in self._contract.inheritance: self._contract.using_for.update(father.using_for) - if self.is_compact_ast: for using_for in self._usingForNotParsed: - lib_name = parse_type(using_for["libraryName"], self) if "typeName" in using_for and using_for["typeName"]: type_name = parse_type(using_for["typeName"], self) else: type_name = "*" if type_name not in self._contract.using_for: self._contract.using_for[type_name] = [] - self._contract.using_for[type_name].append(lib_name) + + if "libraryName" in using_for: + self._contract.using_for[type_name].append( + parse_type(using_for["libraryName"], self) + ) + else: + # We have a list of functions. A function can be topLevel or a library function + # at this point library function are yet to be parsed so we add the function name + # and add the real function later + for f in using_for["functionList"]: + function_name = f["function"]["name"] + if function_name.find(".") != -1: + # Library function + self._contract.using_for[type_name].append(function_name) + else: + # Top level function + for tl_function in self.compilation_unit.functions_top_level: + if tl_function.name == function_name: + self._contract.using_for[type_name].append(tl_function) else: for using_for in self._usingForNotParsed: children = using_for[self.get_children()] @@ -606,6 +622,35 @@ class ContractSolc(CallerContextExpression): except (VariableNotFound, KeyError) as e: self.log_incorrect_parsing(f"Missing using for {e}") + def analyze_library_function_using_for(self): + for type_name, full_names in self._contract.using_for.items(): + # If it's a string is a library function e.g. L.a + # We add the actual function and remove the string + for full_name in full_names: + if isinstance(full_name, str): + full_name_split = full_name.split(".") + # TODO this doesn't handle the case if there is an import with an alias + # e.g. MyImport.MyLib.a + if len(full_name_split) == 2: + library_name = full_name_split[0] + function_name = full_name_split[1] + # Get the library function + found = False + for c in self.compilation_unit.contracts: + if found: + break + if c.name == library_name: + for f in c.functions: + if f.name == function_name: + self._contract.using_for[type_name].append(f) + found = True + break + self._contract.using_for[type_name].remove(full_name) + else: + self.log_incorrect_parsing( + f"Expected library function instead received {full_name}" + ) + def analyze_enums(self): try: for father in self._contract.inheritance: diff --git a/slither/solc_parsing/declarations/using_for_top_level.py b/slither/solc_parsing/declarations/using_for_top_level.py new file mode 100644 index 000000000..070ed8cf7 --- /dev/null +++ b/slither/solc_parsing/declarations/using_for_top_level.py @@ -0,0 +1,134 @@ +""" + Using For Top Level module +""" +import logging +from typing import TYPE_CHECKING, Dict, Union + +from slither.core.compilation_unit import SlitherCompilationUnit +from slither.core.declarations.using_for_top_level import UsingForTopLevel +from slither.core.solidity_types import Type, TypeAliasTopLevel +from slither.core.declarations import ( + FunctionContract, + FunctionTopLevel, + StructureTopLevel, + EnumTopLevel, +) +from slither.solc_parsing.declarations.caller_context import CallerContextExpression +from slither.solc_parsing.solidity_types.type_parsing import parse_type +from slither.core.solidity_types.user_defined_type import UserDefinedType + +if TYPE_CHECKING: + from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc + +LOGGER = logging.getLogger("UsingForTopLevelSolc") + + +class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-public-methods + """ + UsingFor class + """ + + # elems = [(type, name)] + + def __init__( # pylint: disable=too-many-arguments + self, + uftl: UsingForTopLevel, + top_level_data: Dict, + slither_parser: "SlitherCompilationUnitSolc", + ): + # TODO think if save global here is useful + self._type_name = top_level_data["typeName"] + self._global = top_level_data["global"] + + if "libraryName" in top_level_data: + self._library_name = top_level_data["libraryName"] + else: + self._functions = top_level_data["functionList"] + + self._using_for = uftl + self._slither_parser = slither_parser + + def analyze(self): + type_name = parse_type(self._type_name, self) + self._using_for.using_for[type_name] = [] + + if hasattr(self, "_library_name"): + library_name = parse_type(self._library_name, self) + self._using_for.using_for[type_name].append(library_name) + self._propagate_global(type_name, library_name) + else: + for f in self._functions: + full_name_split = f["function"]["name"].split(".") + if len(full_name_split) == 1: + # Top level function + function_name = full_name_split[0] + for tl_function in self.compilation_unit.functions_top_level: + if tl_function.name == function_name: + self._using_for.using_for[type_name].append(tl_function) + self._propagate_global(type_name, tl_function) + elif len(full_name_split) == 2: + # Library function + library_name = full_name_split[0] + function_name = full_name_split[1] + found = False + for c in self.compilation_unit.contracts: + if found: + break + if c.name == library_name: + for cf in c.functions: + if cf.name == function_name: + self._using_for.using_for[type_name].append(cf) + self._propagate_global(type_name, cf) + found = True + break + else: + # probably case if there is an import with an alias we don't handle it for now + # e.g. MyImport.MyLib.a + return + + def _propagate_global( + self, type_name: Type, to_add: Union[FunctionTopLevel, FunctionContract, UserDefinedType] + ): + if self._global: + for scope in self.compilation_unit.scopes.values(): + if isinstance(type_name, TypeAliasTopLevel): + for alias in scope.user_defined_types.values(): + if alias == type_name: + scope.usingFor.add(self._using_for) + elif isinstance(type_name, UserDefinedType): + underlying = type_name.type + if isinstance(underlying, StructureTopLevel): + for struct in scope.structures.values(): + if struct == underlying: + scope.usingFor.add(self._using_for) + elif isinstance(underlying, EnumTopLevel): + for enum in scope.enums.values(): + if enum == underlying: + scope.usingFor.add(self._using_for) + else: + LOGGER.error( + f"Error propagating global {underlying} {type(underlying)} not a StructTopLevel or EnumTopLevel" + ) + else: + LOGGER.error( + f"Found {to_add} {type(to_add)} when propagating global using for {type_name} {type(type_name)}" + ) + + @property + def is_compact_ast(self) -> bool: + return self._slither_parser.is_compact_ast + + @property + def compilation_unit(self) -> SlitherCompilationUnit: + return self._slither_parser.compilation_unit + + def get_key(self) -> str: + return self._slither_parser.get_key() + + @property + def slither_parser(self) -> "SlitherCompilationUnitSolc": + return self._slither_parser + + @property + def underlying_using_for(self) -> UsingForTopLevel: + return self._using_for diff --git a/slither/solc_parsing/slither_compilation_unit_solc.py b/slither/solc_parsing/slither_compilation_unit_solc.py index 3054b4470..828229c0c 100644 --- a/slither/solc_parsing/slither_compilation_unit_solc.py +++ b/slither/solc_parsing/slither_compilation_unit_solc.py @@ -14,6 +14,7 @@ from slither.core.declarations.function_top_level import FunctionTopLevel from slither.core.declarations.import_directive import Import from slither.core.declarations.pragma_directive import Pragma from slither.core.declarations.structure_top_level import StructureTopLevel +from slither.core.declarations.using_for_top_level import UsingForTopLevel from slither.core.scope.scope import FileScope from slither.core.solidity_types import ElementaryType, TypeAliasTopLevel from slither.core.variables.top_level_variable import TopLevelVariable @@ -22,6 +23,7 @@ from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.function import FunctionSolc from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc +from slither.solc_parsing.declarations.using_for_top_level import UsingForTopLevelSolc from slither.solc_parsing.exceptions import VariableNotFound from slither.solc_parsing.variables.top_level_variable import TopLevelVariableSolc @@ -71,6 +73,7 @@ class SlitherCompilationUnitSolc: self._custom_error_parser: List[CustomErrorSolc] = [] self._variables_top_level_parser: List[TopLevelVariableSolc] = [] self._functions_top_level_parser: List[FunctionSolc] = [] + self._using_for_top_level_parser: List[UsingForTopLevelSolc] = [] self._is_compact_ast = False # self._core: SlitherCore = core @@ -221,6 +224,17 @@ class SlitherCompilationUnitSolc: scope.pragmas.add(pragma) pragma.set_offset(top_level_data["src"], self._compilation_unit) self._compilation_unit.pragma_directives.append(pragma) + + elif top_level_data[self.get_key()] == "UsingForDirective": + scope = self.compilation_unit.get_scope(filename) + usingFor = UsingForTopLevel(scope) + usingFor_parser = UsingForTopLevelSolc(usingFor, top_level_data, self) + usingFor.set_offset(top_level_data["src"], self._compilation_unit) + scope.usingFor.add(usingFor) + + self._compilation_unit.using_for_top_level.append(usingFor) + self._using_for_top_level_parser.append(usingFor_parser) + elif top_level_data[self.get_key()] == "ImportDirective": if self.is_compact_ast: import_directive = Import( @@ -495,6 +509,12 @@ Please rename it, this name is reserved for Slither's internals""" # Then we analyse state variables, functions and modifiers self._analyze_third_part(contracts_to_be_analyzed, libraries) + self._analyze_top_level_using_for() + + # Convert library function (at the moment are string) in using for that specifies list of functions + # to actual function + self._analyze_library_function_using_for(contracts_to_be_analyzed) + self._parsed = True def analyze_contracts(self): # pylint: disable=too-many-statements,too-many-branches @@ -605,6 +625,10 @@ Please rename it, this name is reserved for Slither's internals""" else: contracts_to_be_analyzed += [contract] + def _analyze_library_function_using_for(self, contracts_to_be_analyzed: List[ContractSolc]): + for c in contracts_to_be_analyzed: + c.analyze_library_function_using_for() + def _analyze_enums(self, contract: ContractSolc): # Enum must be analyzed first contract.analyze_enums() @@ -651,6 +675,10 @@ Please rename it, this name is reserved for Slither's internals""" func_parser.analyze_params() self._compilation_unit.add_function(func_parser.underlying_function) + def _analyze_top_level_using_for(self): + for using_for in self._using_for_top_level_parser: + using_for.analyze() + def _analyze_params_custom_error(self): for custom_error_parser in self._custom_error_parser: custom_error_parser.analyze_params() diff --git a/slither/solc_parsing/solidity_types/type_parsing.py b/slither/solc_parsing/solidity_types/type_parsing.py index 9a8ef5db2..b62f908d1 100644 --- a/slither/solc_parsing/solidity_types/type_parsing.py +++ b/slither/solc_parsing/solidity_types/type_parsing.py @@ -224,6 +224,7 @@ def parse_type( from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc from slither.solc_parsing.declarations.contract import ContractSolc from slither.solc_parsing.declarations.function import FunctionSolc + from slither.solc_parsing.declarations.using_for_top_level import UsingForTopLevelSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc @@ -259,11 +260,16 @@ def parse_type( all_enums += enums_direct_access contracts = sl.contracts functions = [] - elif isinstance(caller_context, (StructureTopLevelSolc, CustomErrorSolc, TopLevelVariableSolc)): + elif isinstance( + caller_context, + (StructureTopLevelSolc, CustomErrorSolc, TopLevelVariableSolc, UsingForTopLevelSolc), + ): if isinstance(caller_context, StructureTopLevelSolc): scope = caller_context.underlying_structure.file_scope elif isinstance(caller_context, TopLevelVariableSolc): scope = caller_context.underlying_variable.file_scope + elif isinstance(caller_context, UsingForTopLevelSolc): + scope = caller_context.underlying_using_for.file_scope else: assert isinstance(caller_context, CustomErrorSolc) custom_error = caller_context.underlying_custom_error