From cbc3077a65e6cefc21537aeb59ef57f1963212ba Mon Sep 17 00:00:00 2001 From: Simone Date: Fri, 6 Jan 2023 01:18:41 +0100 Subject: [PATCH] Fix using for with alias import --- slither/solc_parsing/declarations/contract.py | 78 +++++++++++-------- .../declarations/using_for_top_level.py | 38 ++++++--- 2 files changed, 74 insertions(+), 42 deletions(-) diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 3095b6854..a93914449 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -616,39 +616,55 @@ class ContractSolc(CallerContextExpression): def _analyze_function_list(self, function_list: List, type_name: Type): for f in function_list: - function_name = f["function"]["name"] - if function_name.find(".") != -1: - # Library function - self._analyze_library_function(function_name, type_name) - else: + full_name_split = f["function"]["name"].split(".") + if len(full_name_split) == 1: # Top level function - for tl_function in self.compilation_unit.functions_top_level: - if tl_function.name == function_name: - self._contract.using_for[type_name].append(tl_function) - - def _analyze_library_function(self, function_name: str, type_name: Type) -> None: - function_name_split = function_name.split(".") - # TODO this doesn't handle the case if there is an import with an alias - # e.g. MyImport.MyLib.a - if len(function_name_split) == 2: - library_name = function_name_split[0] - function_name = function_name_split[1] - # Get the library function - found = False - for c in self.compilation_unit.contracts: - if found: - break - if c.name == library_name: - for f in c.functions: - if f.name == function_name: - self._contract.using_for[type_name].append(f) - found = True - break - if not found: - self.log_incorrect_parsing(f"Library function not found {function_name}") - else: + function_name = full_name_split[0] + self._analyze_top_level_function(function_name, type_name) + elif len(full_name_split) == 2: + # It can be a top level function behind an aliased import + # or a library function + first_part = full_name_split[0] + function_name = full_name_split[1] + self._check_aliased_import(first_part, function_name, type_name) + else: + # MyImport.MyLib.a we don't care of the alias + library_name = full_name_split[1] + function_name = full_name_split[2] + self._analyze_library_function(library_name, function_name, type_name) + + def _check_aliased_import(self, first_part: str, function_name: str, type_name: Type): + # We check if the first part appear as alias for an import + # if it is then function_name must be a top level function + # otherwise it's a library function + for i in self._contract.file_scope.imports: + if i.alias == first_part: + self._analyze_top_level_function(function_name, type_name) + return + self._analyze_library_function(first_part, function_name, type_name) + + def _analyze_top_level_function(self, function_name: str, type_name: Type): + for tl_function in self.compilation_unit.functions_top_level: + if tl_function.name == function_name: + self._contract.using_for[type_name].append(tl_function) + + def _analyze_library_function( + self, library_name: str, function_name: str, type_name: Type + ) -> None: + # Get the library function + found = False + for c in self.compilation_unit.contracts: + if found: + break + if c.name == library_name: + for f in c.functions: + if f.name == function_name: + self._contract.using_for[type_name].append(f) + found = True + break + if not found: self.log_incorrect_parsing( - f"Expected library function instead received {function_name}" + f"Contract level using for: Library {library_name} - function {function_name} not found" ) def analyze_enums(self): diff --git a/slither/solc_parsing/declarations/using_for_top_level.py b/slither/solc_parsing/declarations/using_for_top_level.py index 3ec191d46..ad4dd008d 100644 --- a/slither/solc_parsing/declarations/using_for_top_level.py +++ b/slither/solc_parsing/declarations/using_for_top_level.py @@ -61,17 +61,31 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- function_name = full_name_split[0] self._analyze_top_level_function(function_name, type_name) elif len(full_name_split) == 2: - # Library function - library_name = full_name_split[0] + # It can be a top level function behind an aliased import + # or a library function + first_part = full_name_split[0] function_name = full_name_split[1] - self._analyze_library_function(function_name, library_name, type_name) + self._check_aliased_import(first_part, function_name, type_name) else: - # probably case if there is an import with an alias we don't handle it for now - # e.g. MyImport.MyLib.a - LOGGER.warning( - f"Using for directive for function {f['function']['name']} not supported" - ) - continue + # MyImport.MyLib.a we don't care of the alias + library_name = full_name_split[1] + function_name = full_name_split[2] + self._analyze_library_function(library_name, function_name, type_name) + + def _check_aliased_import( + self, + first_part: str, + function_name: str, + type_name: Union[TypeAliasTopLevel, UserDefinedType], + ): + # We check if the first part appear as alias for an import + # if it is then function_name must be a top level function + # otherwise it's a library function + for i in self._using_for.file_scope.imports: + if i.alias == first_part: + self._analyze_top_level_function(function_name, type_name) + return + self._analyze_library_function(first_part, function_name, type_name) def _analyze_top_level_function( self, function_name: str, type_name: Union[TypeAliasTopLevel, UserDefinedType] @@ -84,8 +98,8 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- def _analyze_library_function( self, - function_name: str, library_name: str, + function_name: str, type_name: Union[TypeAliasTopLevel, UserDefinedType], ) -> None: found = False @@ -100,7 +114,9 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few- found = True break if not found: - LOGGER.warning(f"Library {library_name} - function {function_name} not found") + LOGGER.warning( + f"Top level using for: Library {library_name} - function {function_name} not found" + ) def _propagate_global(self, type_name: Union[TypeAliasTopLevel, UserDefinedType]) -> None: if self._global: