diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index a5fadfa21..6ac4f7415 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -2,6 +2,7 @@ Function module """ import logging +from collections import namedtuple from itertools import groupby from slither.core.children.child_contract import ChildContract @@ -17,6 +18,8 @@ from slither.core.variables.state_variable import StateVariable logger = logging.getLogger("Function") +ReacheableNode = namedtuple('ReacheableNode', ['node', 'ir']) + class Function(ChildContract, SourceMapping): """ Function class @@ -75,6 +78,10 @@ class Function(ChildContract, SourceMapping): self._all_conditional_solidity_variables_read = None self._all_solidity_variables_used_as_args = None + # set(ReacheableNode) + self._reachable_from_nodes = set() + self._reachable_from_functions = set() + @property def contains_assembly(self): return self._contains_assembly @@ -415,6 +422,22 @@ class Function(ChildContract, SourceMapping): def slither(self): return self.contract.slither + @property + def reachable_from_nodes(self): + ''' + Return + ReacheableNode + ''' + return self._reachable_from_nodes + + @property + def reachable_from_functions(self): + return self._reachable_from_functions + + def add_reachable_from_node(self, n, ir): + self._reachable_from_nodes.add(ReacheableNode(n, ir)) + self._reachable_from_functions.add(n.function) + def _filter_state_variables_written(self, expressions): ret = [] for expression in expressions: diff --git a/slither/core/slither_core.py b/slither/core/slither_core.py index cd8f8a113..31714d46b 100644 --- a/slither/core/slither_core.py +++ b/slither/core/slither_core.py @@ -3,6 +3,7 @@ """ import os from slither.core.context.context import Context +from slither.slithir.operations import InternalCall class Slither(Context): """ @@ -18,6 +19,8 @@ class Slither(Context): self._pragma_directives = [] self._import_directives = [] self._raw_source_code = {} + self._all_functions = set() + self._all_modifiers = set() @property def source_units(self): @@ -39,6 +42,24 @@ class Slither(Context): """list(dict(str: Contract): List of contracts as dict: name -> Contract.""" return self._contracts + @property + def functions(self): + return list(self._all_functions) + + def add_function(self, func): + self._all_functions.add(func) + + @property + def modifiers(self): + return list(self._all_modifiers) + + def add_modifier(self, modif): + self._all_modifiers.add(modif) + + @property + def functions_and_modifiers(self): + return self.functions + self.modifiers + @property def filename(self): """str: Filename.""" @@ -64,6 +85,13 @@ class Slither(Context): """ {filename: source_code}: source code """ return self._raw_source_code + def _propagate_function_calls(self): + for f in self.functions_and_modifiers: + for node in f.nodes: + for ir in node.irs_ssa: + if isinstance(ir, InternalCall): + ir.function.add_reachable_from_node(node, ir) + def get_contract_from_name(self, contract_name): """ Return a contract from a name diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index 02392068a..c895a3eb2 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -320,6 +320,7 @@ class ContractSolc04(Contract): modif = ModifierSolc(modifier, self) modif.set_contract(self) modif.set_offset(modifier['src'], self.slither) + self.slither.add_modifier(modif) self._modifiers_no_params.append(modif) def parse_modifiers(self): @@ -333,6 +334,7 @@ class ContractSolc04(Contract): def _parse_function(self, function): func = FunctionSolc(function, self) func.set_offset(function['src'], self.slither) + self.slither.add_function(func) self._functions_no_params.append(func) def parse_functions(self): diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index 1bce49757..d75a43797 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -357,6 +357,7 @@ class SlitherSolc(Slither): def _convert_to_slithir(self): for contract in self.contracts: contract.convert_expression_to_slithir() + self._propagate_function_calls() for contract in self.contracts: contract.fix_phi() contract.update_read_write_using_ssa()