Improve using for support

pull/1378/head
Simone 2 years ago
parent 78397b94dd
commit f0439905e0
  1. 6
      slither/core/compilation_unit.py
  2. 3
      slither/core/declarations/__init__.py
  3. 22
      slither/core/declarations/contract.py
  4. 18
      slither/core/declarations/using_for_top_level.py
  5. 5
      slither/core/scope/scope.py
  6. 29
      slither/slithir/convert.py
  7. 51
      slither/solc_parsing/declarations/contract.py
  8. 134
      slither/solc_parsing/declarations/using_for_top_level.py
  9. 28
      slither/solc_parsing/slither_compilation_unit_solc.py
  10. 8
      slither/solc_parsing/solidity_types/type_parsing.py

@ -16,6 +16,7 @@ from slither.core.declarations import (
from slither.core.declarations.custom_error import CustomError
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.scope.scope import FileScope
from slither.core.variables.state_variable import StateVariable
@ -41,6 +42,7 @@ class SlitherCompilationUnit(Context):
self._enums_top_level: List[EnumTopLevel] = []
self._variables_top_level: List[TopLevelVariable] = []
self._functions_top_level: List[FunctionTopLevel] = []
self._using_for_top_level: List[UsingForTopLevel] = []
self._pragma_directives: List[Pragma] = []
self._import_directives: List[Import] = []
self._custom_errors: List[CustomError] = []
@ -205,6 +207,10 @@ class SlitherCompilationUnit(Context):
def functions_top_level(self) -> List[FunctionTopLevel]:
return self._functions_top_level
@property
def using_for_top_level(self) -> List[UsingForTopLevel]:
return self._using_for_top_level
@property
def custom_errors(self) -> List[CustomError]:
return self._custom_errors

@ -12,5 +12,8 @@ from .solidity_variables import (
)
from .structure import Structure
from .enum_contract import EnumContract
from .enum_top_level import EnumTopLevel
from .structure_contract import StructureContract
from .structure_top_level import StructureTopLevel
from .function_contract import FunctionContract
from .function_top_level import FunctionTopLevel

@ -78,6 +78,7 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
# The only str is "*"
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self._using_for_complete: Dict[Union[str, Type], List[Type]] = None
self._kind: Optional[str] = None
self._is_interface: bool = False
self._is_library: bool = False
@ -259,6 +260,27 @@ class Contract(SourceMapping): # pylint: disable=too-many-public-methods
def using_for(self) -> Dict[Union[str, Type], List[Type]]:
return self._using_for
@property
def using_for_complete(self) -> Dict[Union[str, Type], List[Type]]:
"""
Dict[Union[str, Type], List[Type]]: Dict of merged local using for directive with top level directive
"""
def _merge_using_for(uf1, uf2):
result = {**uf1, **uf2}
for key, value in result.items():
if key in uf1 and key in uf2:
result[key] = value + uf1[key]
return result
if self._using_for_complete is None:
result = self.using_for
top_level_using_for = self.file_scope.usingFor
for uftl in top_level_using_for:
result = _merge_using_for(result, uftl.using_for)
self._using_for_complete = result
return self._using_for_complete
# endregion
###################################################################################
###################################################################################

@ -0,0 +1,18 @@
from typing import TYPE_CHECKING, List, Dict, Union
from slither.core.solidity_types.type import Type
from slither.core.declarations.top_level import TopLevel
if TYPE_CHECKING:
from slither.core.scope.scope import FileScope
class UsingForTopLevel(TopLevel):
def __init__(self, scope: "FileScope"):
super().__init__()
self._using_for: Dict[Union[str, Type], List[Type]] = {}
self.file_scope: "FileScope" = scope
@property
def using_for(self) -> Dict[Type, List[Type]]:
return self._using_for

@ -5,6 +5,7 @@ from slither.core.declarations import Contract, Import, Pragma
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.solidity_types import TypeAlias
from slither.core.variables.top_level_variable import TopLevelVariable
@ -35,6 +36,7 @@ class FileScope:
# Because we parse the function signature later on
# So we simplify the logic and have the scope fields all populated
self.functions: Set[FunctionTopLevel] = set()
self.usingFor: Set[UsingForTopLevel] = set()
self.imports: Set[Import] = set()
self.pragmas: Set[Pragma] = set()
self.structures: Dict[str, StructureTopLevel] = {}
@ -72,6 +74,9 @@ class FileScope:
if not new_scope.functions.issubset(self.functions):
self.functions |= new_scope.functions
learn_something = True
if not new_scope.usingFor.issubset(self.usingFor):
self.usingFor |= new_scope.usingFor
learn_something = True
if not new_scope.imports.issubset(self.imports):
self.imports |= new_scope.imports
learn_something = True

@ -171,10 +171,10 @@ def _fits_under_integer(val: int, can_be_int: bool, can_be_uint) -> List[str]:
assert can_be_int | can_be_uint
while n <= 256:
if can_be_uint:
if val <= 2**n - 1:
if val <= 2 ** n - 1:
ret.append(f"uint{n}")
if can_be_int:
if val <= (2**n) / 2 - 1:
if val <= (2 ** n) / 2 - 1:
ret.append(f"int{n}")
n = n + 8
return ret
@ -498,7 +498,9 @@ def propagate_types(ir, node: "Node"): # pylint: disable=too-many-locals
# propagate the type
node_function = node.function
using_for = (
node_function.contract.using_for if isinstance(node_function, FunctionContract) else {}
node_function.contract.using_for_complete
if isinstance(node_function, FunctionContract)
else {}
)
if isinstance(ir, OperationWithLValue):
# Force assignment in case of missing previous correct type
@ -879,7 +881,9 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
# }
node_func = ins.node.function
using_for = (
node_func.contract.using_for if isinstance(node_func, FunctionContract) else {}
node_func.contract.using_for_complete
if isinstance(node_func, FunctionContract)
else {}
)
targeted_libraries = (
@ -892,10 +896,14 @@ def extract_tmp_call(ins: TmpCall, contract: Optional[Contract]): # pylint: dis
lib_contract_type.type, Contract
):
continue
lib_contract = lib_contract_type.type
for lib_func in lib_contract.functions:
if lib_func.name == ins.ori.variable_right:
candidates.append(lib_func)
if isinstance(lib_contract_type, FunctionContract):
# Using for with list of functions, this is the function called
candidates.append(lib_contract_type)
else:
lib_contract = lib_contract_type.type
for lib_func in lib_contract.functions:
if lib_func.name == ins.ori.variable_right:
candidates.append(lib_func)
if len(candidates) == 1:
lib_func = candidates[0]
@ -1325,7 +1333,10 @@ def convert_to_pop(ir, node):
def look_for_library(contract, ir, using_for, t):
for destination in using_for[t]:
lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if isinstance(destination, FunctionContract) and destination.contract.is_library:
lib_contract = destination.contract
else:
lib_contract = contract.file_scope.get_contract_from_name(str(destination))
if lib_contract:
lib_call = LibraryCall(
lib_contract,

@ -578,17 +578,33 @@ class ContractSolc(CallerContextExpression):
try:
for father in self._contract.inheritance:
self._contract.using_for.update(father.using_for)
if self.is_compact_ast:
for using_for in self._usingForNotParsed:
lib_name = parse_type(using_for["libraryName"], self)
if "typeName" in using_for and using_for["typeName"]:
type_name = parse_type(using_for["typeName"], self)
else:
type_name = "*"
if type_name not in self._contract.using_for:
self._contract.using_for[type_name] = []
self._contract.using_for[type_name].append(lib_name)
if "libraryName" in using_for:
self._contract.using_for[type_name].append(
parse_type(using_for["libraryName"], self)
)
else:
# We have a list of functions. A function can be topLevel or a library function
# at this point library function are yet to be parsed so we add the function name
# and add the real function later
for f in using_for["functionList"]:
function_name = f["function"]["name"]
if function_name.find(".") != -1:
# Library function
self._contract.using_for[type_name].append(function_name)
else:
# Top level function
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._contract.using_for[type_name].append(tl_function)
else:
for using_for in self._usingForNotParsed:
children = using_for[self.get_children()]
@ -606,6 +622,35 @@ class ContractSolc(CallerContextExpression):
except (VariableNotFound, KeyError) as e:
self.log_incorrect_parsing(f"Missing using for {e}")
def analyze_library_function_using_for(self):
for type_name, full_names in self._contract.using_for.items():
# If it's a string is a library function e.g. L.a
# We add the actual function and remove the string
for full_name in full_names:
if isinstance(full_name, str):
full_name_split = full_name.split(".")
# TODO this doesn't handle the case if there is an import with an alias
# e.g. MyImport.MyLib.a
if len(full_name_split) == 2:
library_name = full_name_split[0]
function_name = full_name_split[1]
# Get the library function
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for f in c.functions:
if f.name == function_name:
self._contract.using_for[type_name].append(f)
found = True
break
self._contract.using_for[type_name].remove(full_name)
else:
self.log_incorrect_parsing(
f"Expected library function instead received {full_name}"
)
def analyze_enums(self):
try:
for father in self._contract.inheritance:

@ -0,0 +1,134 @@
"""
Using For Top Level module
"""
import logging
from typing import TYPE_CHECKING, Dict, Union
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.solidity_types import Type, TypeAliasTopLevel
from slither.core.declarations import (
FunctionContract,
FunctionTopLevel,
StructureTopLevel,
EnumTopLevel,
)
from slither.solc_parsing.declarations.caller_context import CallerContextExpression
from slither.solc_parsing.solidity_types.type_parsing import parse_type
from slither.core.solidity_types.user_defined_type import UserDefinedType
if TYPE_CHECKING:
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
LOGGER = logging.getLogger("UsingForTopLevelSolc")
class UsingForTopLevelSolc(CallerContextExpression): # pylint: disable=too-few-public-methods
"""
UsingFor class
"""
# elems = [(type, name)]
def __init__( # pylint: disable=too-many-arguments
self,
uftl: UsingForTopLevel,
top_level_data: Dict,
slither_parser: "SlitherCompilationUnitSolc",
):
# TODO think if save global here is useful
self._type_name = top_level_data["typeName"]
self._global = top_level_data["global"]
if "libraryName" in top_level_data:
self._library_name = top_level_data["libraryName"]
else:
self._functions = top_level_data["functionList"]
self._using_for = uftl
self._slither_parser = slither_parser
def analyze(self):
type_name = parse_type(self._type_name, self)
self._using_for.using_for[type_name] = []
if hasattr(self, "_library_name"):
library_name = parse_type(self._library_name, self)
self._using_for.using_for[type_name].append(library_name)
self._propagate_global(type_name, library_name)
else:
for f in self._functions:
full_name_split = f["function"]["name"].split(".")
if len(full_name_split) == 1:
# Top level function
function_name = full_name_split[0]
for tl_function in self.compilation_unit.functions_top_level:
if tl_function.name == function_name:
self._using_for.using_for[type_name].append(tl_function)
self._propagate_global(type_name, tl_function)
elif len(full_name_split) == 2:
# Library function
library_name = full_name_split[0]
function_name = full_name_split[1]
found = False
for c in self.compilation_unit.contracts:
if found:
break
if c.name == library_name:
for cf in c.functions:
if cf.name == function_name:
self._using_for.using_for[type_name].append(cf)
self._propagate_global(type_name, cf)
found = True
break
else:
# probably case if there is an import with an alias we don't handle it for now
# e.g. MyImport.MyLib.a
return
def _propagate_global(
self, type_name: Type, to_add: Union[FunctionTopLevel, FunctionContract, UserDefinedType]
):
if self._global:
for scope in self.compilation_unit.scopes.values():
if isinstance(type_name, TypeAliasTopLevel):
for alias in scope.user_defined_types.values():
if alias == type_name:
scope.usingFor.add(self._using_for)
elif isinstance(type_name, UserDefinedType):
underlying = type_name.type
if isinstance(underlying, StructureTopLevel):
for struct in scope.structures.values():
if struct == underlying:
scope.usingFor.add(self._using_for)
elif isinstance(underlying, EnumTopLevel):
for enum in scope.enums.values():
if enum == underlying:
scope.usingFor.add(self._using_for)
else:
LOGGER.error(
f"Error propagating global {underlying} {type(underlying)} not a StructTopLevel or EnumTopLevel"
)
else:
LOGGER.error(
f"Found {to_add} {type(to_add)} when propagating global using for {type_name} {type(type_name)}"
)
@property
def is_compact_ast(self) -> bool:
return self._slither_parser.is_compact_ast
@property
def compilation_unit(self) -> SlitherCompilationUnit:
return self._slither_parser.compilation_unit
def get_key(self) -> str:
return self._slither_parser.get_key()
@property
def slither_parser(self) -> "SlitherCompilationUnitSolc":
return self._slither_parser
@property
def underlying_using_for(self) -> UsingForTopLevel:
return self._using_for

@ -14,6 +14,7 @@ from slither.core.declarations.function_top_level import FunctionTopLevel
from slither.core.declarations.import_directive import Import
from slither.core.declarations.pragma_directive import Pragma
from slither.core.declarations.structure_top_level import StructureTopLevel
from slither.core.declarations.using_for_top_level import UsingForTopLevel
from slither.core.scope.scope import FileScope
from slither.core.solidity_types import ElementaryType, TypeAliasTopLevel
from slither.core.variables.top_level_variable import TopLevelVariable
@ -22,6 +23,7 @@ from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc
from slither.solc_parsing.declarations.using_for_top_level import UsingForTopLevelSolc
from slither.solc_parsing.exceptions import VariableNotFound
from slither.solc_parsing.variables.top_level_variable import TopLevelVariableSolc
@ -71,6 +73,7 @@ class SlitherCompilationUnitSolc:
self._custom_error_parser: List[CustomErrorSolc] = []
self._variables_top_level_parser: List[TopLevelVariableSolc] = []
self._functions_top_level_parser: List[FunctionSolc] = []
self._using_for_top_level_parser: List[UsingForTopLevelSolc] = []
self._is_compact_ast = False
# self._core: SlitherCore = core
@ -221,6 +224,17 @@ class SlitherCompilationUnitSolc:
scope.pragmas.add(pragma)
pragma.set_offset(top_level_data["src"], self._compilation_unit)
self._compilation_unit.pragma_directives.append(pragma)
elif top_level_data[self.get_key()] == "UsingForDirective":
scope = self.compilation_unit.get_scope(filename)
usingFor = UsingForTopLevel(scope)
usingFor_parser = UsingForTopLevelSolc(usingFor, top_level_data, self)
usingFor.set_offset(top_level_data["src"], self._compilation_unit)
scope.usingFor.add(usingFor)
self._compilation_unit.using_for_top_level.append(usingFor)
self._using_for_top_level_parser.append(usingFor_parser)
elif top_level_data[self.get_key()] == "ImportDirective":
if self.is_compact_ast:
import_directive = Import(
@ -495,6 +509,12 @@ Please rename it, this name is reserved for Slither's internals"""
# Then we analyse state variables, functions and modifiers
self._analyze_third_part(contracts_to_be_analyzed, libraries)
self._analyze_top_level_using_for()
# Convert library function (at the moment are string) in using for that specifies list of functions
# to actual function
self._analyze_library_function_using_for(contracts_to_be_analyzed)
self._parsed = True
def analyze_contracts(self): # pylint: disable=too-many-statements,too-many-branches
@ -605,6 +625,10 @@ Please rename it, this name is reserved for Slither's internals"""
else:
contracts_to_be_analyzed += [contract]
def _analyze_library_function_using_for(self, contracts_to_be_analyzed: List[ContractSolc]):
for c in contracts_to_be_analyzed:
c.analyze_library_function_using_for()
def _analyze_enums(self, contract: ContractSolc):
# Enum must be analyzed first
contract.analyze_enums()
@ -651,6 +675,10 @@ Please rename it, this name is reserved for Slither's internals"""
func_parser.analyze_params()
self._compilation_unit.add_function(func_parser.underlying_function)
def _analyze_top_level_using_for(self):
for using_for in self._using_for_top_level_parser:
using_for.analyze()
def _analyze_params_custom_error(self):
for custom_error_parser in self._custom_error_parser:
custom_error_parser.analyze_params()

@ -224,6 +224,7 @@ def parse_type(
from slither.solc_parsing.variables.function_type_variable import FunctionTypeVariableSolc
from slither.solc_parsing.declarations.contract import ContractSolc
from slither.solc_parsing.declarations.function import FunctionSolc
from slither.solc_parsing.declarations.using_for_top_level import UsingForTopLevelSolc
from slither.solc_parsing.declarations.custom_error import CustomErrorSolc
from slither.solc_parsing.declarations.structure_top_level import StructureTopLevelSolc
from slither.solc_parsing.slither_compilation_unit_solc import SlitherCompilationUnitSolc
@ -259,11 +260,16 @@ def parse_type(
all_enums += enums_direct_access
contracts = sl.contracts
functions = []
elif isinstance(caller_context, (StructureTopLevelSolc, CustomErrorSolc, TopLevelVariableSolc)):
elif isinstance(
caller_context,
(StructureTopLevelSolc, CustomErrorSolc, TopLevelVariableSolc, UsingForTopLevelSolc),
):
if isinstance(caller_context, StructureTopLevelSolc):
scope = caller_context.underlying_structure.file_scope
elif isinstance(caller_context, TopLevelVariableSolc):
scope = caller_context.underlying_variable.file_scope
elif isinstance(caller_context, UsingForTopLevelSolc):
scope = caller_context.underlying_using_for.file_scope
else:
assert isinstance(caller_context, CustomErrorSolc)
custom_error = caller_context.underlying_custom_error

Loading…
Cancel
Save