Refactor parsing

pull/1378/head
Simone 2 years ago
parent 6571ada9a2
commit 3278417af0
  1. 81
      slither/solc_parsing/declarations/contract.py
  2. 81
      slither/solc_parsing/declarations/using_for_top_level.py
  3. 13
      slither/solc_parsing/slither_compilation_unit_solc.py

@ -5,7 +5,7 @@ from slither.core.declarations import Modifier, Event, EnumContract, StructureCo
from slither.core.declarations.contract import Contract from slither.core.declarations.contract import Contract
from slither.core.declarations.custom_error_contract import CustomErrorContract from slither.core.declarations.custom_error_contract import CustomErrorContract
from slither.core.declarations.function_contract import FunctionContract from slither.core.declarations.function_contract import FunctionContract
from slither.core.solidity_types import ElementaryType, TypeAliasContract from slither.core.solidity_types import ElementaryType, TypeAliasContract, Type
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.solc_parsing.declarations.caller_context import CallerContextExpression from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
@ -574,7 +574,7 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing state variable {e}") self.log_incorrect_parsing(f"Missing state variable {e}")
def analyze_using_for(self): def analyze_using_for(self): # pylint: disable=too-many-branches
try: try:
for father in self._contract.inheritance: for father in self._contract.inheritance:
self._contract.using_for.update(father.using_for) self._contract.using_for.update(father.using_for)
@ -593,18 +593,7 @@ class ContractSolc(CallerContextExpression):
) )
else: else:
# We have a list of functions. A function can be topLevel or a library function # We have a list of functions. A function can be topLevel or a library function
# at this point library function are yet to be parsed so we add the function name self._analyze_function_list(using_for["functionList"], type_name)
# and add the real function later
for f in using_for["functionList"]:
function_name = f["function"]["name"]
if function_name.find(".") != -1:
# Library function
self._contract.using_for[type_name].append(function_name)
else:
# Top level function
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)
else: else:
for using_for in self._usingForNotParsed: for using_for in self._usingForNotParsed:
children = using_for[self.get_children()] children = using_for[self.get_children()]
@ -622,34 +611,42 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e: except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}") self.log_incorrect_parsing(f"Missing using for {e}")
def analyze_library_function_using_for(self): def _analyze_function_list(self, function_list: List, type_name: Type):
for type_name, full_names in self._contract.using_for.items(): for f in function_list:
# If it's a string is a library function e.g. L.a function_name = f["function"]["name"]
# We add the actual function and remove the string if function_name.find(".") != -1:
for full_name in full_names: # Library function
if isinstance(full_name, str): self._analyze_library_function(function_name, type_name)
full_name_split = full_name.split(".") else:
# TODO this doesn't handle the case if there is an import with an alias # Top level function
# e.g. MyImport.MyLib.a for tl_function in self.compilation_unit.functions_top_level:
if len(full_name_split) == 2: if tl_function.name == function_name:
library_name = full_name_split[0] self._contract.using_for[type_name].append(tl_function)
function_name = full_name_split[1]
# Get the library function def _analyze_library_function(self, function_name: str, type_name: Type) -> None:
found = False function_name_split = function_name.split(".")
for c in self.compilation_unit.contracts: # TODO this doesn't handle the case if there is an import with an alias
if found: # e.g. MyImport.MyLib.a
break if len(function_name_split) == 2:
if c.name == library_name: library_name = function_name_split[0]
for f in c.functions: function_name = function_name_split[1]
if f.name == function_name: # Get the library function
self._contract.using_for[type_name].append(f) found = False
found = True for c in self.compilation_unit.contracts:
break if found:
self._contract.using_for[type_name].remove(full_name) break
else: if c.name == library_name:
self.log_incorrect_parsing( for f in c.functions:
f"Expected library function instead received {full_name}" if f.name == function_name:
) self._contract.using_for[type_name].append(f)
found = True
break
if not found:
self.log_incorrect_parsing(f"Library function not found {function_name}")
else:
self.log_incorrect_parsing(
f"Expected library function instead received {function_name}"
)
def analyze_enums(self): def analyze_enums(self):
try: try:

@ -2,10 +2,11 @@
Using For Top Level module Using For Top Level module
""" """
import logging import logging
from typing import TYPE_CHECKING, Dict, Union from typing import TYPE_CHECKING, Dict, Union, Any
from slither.core.compilation_unit import SlitherCompilationUnit from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.using_for_top_level import UsingForTopLevel from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import TypeAliasTopLevel from slither.core.solidity_types import TypeAliasTopLevel
from slither.core.declarations import ( from slither.core.declarations import (
StructureTopLevel, StructureTopLevel,
@ -57,26 +58,12 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
if len(full_name_split) == 1: if len(full_name_split) == 1:
# Top level function # Top level function
function_name = full_name_split[0] function_name = full_name_split[0]
for tl_function in self.compilation_unit.functions_top_level: self._analyze_top_level_function(function_name, type_name)
if tl_function.name == function_name:
self._using_for.using_for[type_name].append(tl_function)
self._propagate_global(type_name)
break
elif len(full_name_split) == 2: elif len(full_name_split) == 2:
# Library function # Library function
library_name = full_name_split[0] library_name = full_name_split[0]
function_name = full_name_split[1] function_name = full_name_split[1]
found = False self._analyze_library_function(function_name, library_name, type_name)
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for cf in c.functions:
if cf.name == function_name:
self._using_for.using_for[type_name].append(cf)
self._propagate_global(type_name)
found = True
break
else: else:
# probably case if there is an import with an alias we don't handle it for now # probably case if there is an import with an alias we don't handle it for now
# e.g. MyImport.MyLib.a # e.g. MyImport.MyLib.a
@ -85,6 +72,35 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
) )
continue continue
def _analyze_top_level_function(
self, function_name: str, type_name: Union[TypeAliasTopLevel, UserDefinedType]
) -> None:
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._using_for.using_for[type_name].append(tl_function)
self._propagate_global(type_name)
break
def _analyze_library_function(
self,
function_name: str,
library_name: str,
type_name: Union[TypeAliasTopLevel, UserDefinedType],
) -> None:
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for cf in c.functions:
if cf.name == function_name:
self._using_for.using_for[type_name].append(cf)
self._propagate_global(type_name)
found = True
break
if not found:
LOGGER.warning(f"Library {library_name} - function {function_name} not found")
def _propagate_global(self, type_name: Union[TypeAliasTopLevel, UserDefinedType]) -> None: def _propagate_global(self, type_name: Union[TypeAliasTopLevel, UserDefinedType]) -> None:
if self._global: if self._global:
for scope in self.compilation_unit.scopes.values(): for scope in self.compilation_unit.scopes.values():
@ -93,24 +109,29 @@ class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-
if alias == type_name: if alias == type_name:
scope.usingFor.add(self._using_for) scope.usingFor.add(self._using_for)
elif isinstance(type_name, UserDefinedType): elif isinstance(type_name, UserDefinedType):
underlying = type_name.type self._propagate_global_UserDefinedType(scope, type_name)
if isinstance(underlying, StructureTopLevel):
for struct in scope.structures.values():
if struct == underlying:
scope.usingFor.add(self._using_for)
elif isinstance(underlying, EnumTopLevel):
for enum in scope.enums.values():
if enum == underlying:
scope.usingFor.add(self._using_for)
else:
LOGGER.error(
f"Error when propagating global {underlying} {type(underlying)} not a StructTopLevel or EnumTopLevel"
)
else: else:
LOGGER.error( LOGGER.error(
f"Error when propagating global using for {type_name} {type(type_name)}" f"Error when propagating global using for {type_name} {type(type_name)}"
) )
def _propagate_global_UserDefinedType(
self, scope: Dict[Any, FileScope], type_name: UserDefinedType
):
underlying = type_name.type
if isinstance(underlying, StructureTopLevel):
for struct in scope.structures.values():
if struct == underlying:
scope.usingFor.add(self._using_for)
elif isinstance(underlying, EnumTopLevel):
for enum in scope.enums.values():
if enum == underlying:
scope.usingFor.add(self._using_for)
else:
LOGGER.error(
f"Error when propagating global {underlying} {type(underlying)} not a StructTopLevel or EnumTopLevel"
)
@property @property
def is_compact_ast(self) -> bool: def is_compact_ast(self) -> bool:
return self._slither_parser.is_compact_ast return self._slither_parser.is_compact_ast

@ -509,11 +509,7 @@ Please rename it, this name is reserved for Slither's internals"""
# Then we analyse state variables, functions and modifiers # Then we analyse state variables, functions and modifiers
self._analyze_third_part(contracts_to_be_analyzed, libraries) self._analyze_third_part(contracts_to_be_analyzed, libraries)
self._analyze_top_level_using_for() self._analyze_using_for(contracts_to_be_analyzed)
# Convert library function (at the moment are string) in using for that specifies list of functions
# to actual function
self._analyze_library_function_using_for(contracts_to_be_analyzed)
self._parsed = True self._parsed = True
@ -625,9 +621,11 @@ Please rename it, this name is reserved for Slither's internals"""
else: else:
contracts_to_be_analyzed += [contract] contracts_to_be_analyzed += [contract]
def _analyze_library_function_using_for(self, contracts_to_be_analyzed: List[ContractSolc]): def _analyze_using_for(self, contracts_to_be_analyzed: List[ContractSolc]):
self._analyze_top_level_using_for()
for c in contracts_to_be_analyzed: for c in contracts_to_be_analyzed:
c.analyze_library_function_using_for() c.analyze_using_for()
def _analyze_enums(self, contract: ContractSolc): def _analyze_enums(self, contract: ContractSolc):
# Enum must be analyzed first # Enum must be analyzed first
@ -651,7 +649,6 @@ Please rename it, this name is reserved for Slither's internals"""
# Event can refer to struct # Event can refer to struct
contract.analyze_events() contract.analyze_events()
contract.analyze_using_for()
contract.analyze_custom_errors() contract.analyze_custom_errors()
contract.set_is_analyzed(True) contract.set_is_analyzed(True)

Loading…
Cancel
Save