diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 4c88150d2..ccabb87e6 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -189,7 +189,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu # set(ReacheableNode) self._reachable_from_nodes: Set[ReacheableNode] = set() - self._reachable_from_functions: Set[ReacheableNode] = set() + self._reachable_from_functions: Set[Function] = set() + self._all_reachable_from_functions: Optional[Set[Function]] = None # Constructor, fallback, State variable constructor self._function_type: Optional[FunctionType] = None @@ -214,7 +215,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu self.compilation_unit: "SlitherCompilationUnit" = compilation_unit - # Assume we are analyzing Solidty by default + # Assume we are analyzing Solidity by default self.function_language: FunctionLanguage = FunctionLanguage.Solidity self._id: Optional[str] = None @@ -1024,9 +1025,32 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu return self._reachable_from_nodes @property - def reachable_from_functions(self) -> Set[ReacheableNode]: + def reachable_from_functions(self) -> Set["Function"]: return self._reachable_from_functions + @property + def all_reachable_from_functions(self) -> Set["Function"]: + """ + Give the recursive version of reachable_from_functions (all the functions that lead to call self in the CFG) + """ + if self._all_reachable_from_functions is None: + functions: Set["Function"] = set() + + new_functions = self.reachable_from_functions + print([str(f) for f in new_functions]) + # iterate until we have are finding new functions + while new_functions and new_functions not in functions: + print([str(f) for f in new_functions]) + functions = functions.union(new_functions) + # Use a temporary set, because we iterate over new_functions + new_functionss: Set["Function"] = set() + for f in new_functions: + new_functionss = new_functionss.union(f.reachable_from_functions) + new_functions = new_functionss + + self._all_reachable_from_functions = functions + return self._all_reachable_from_functions + def add_reachable_from_node(self, n: "Node", ir: "Operation"): self._reachable_from_nodes.add(ReacheableNode(n, ir)) self._reachable_from_functions.add(n.function) @@ -1455,6 +1479,26 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu ) return self._is_protected + @property + def is_reentrant(self) -> bool: + """ + Determine if the function can be re-entered + """ + # TODO: compare with hash of known nonReentrant modifier instead of the name + if "nonReentrant" in [m.name for m in self.modifiers]: + return False + + if self.visibility in ["public", "external"]: + return True + + # If it's an internal function, check if all its entry points have the nonReentrant modifier + all_entry_points = [ + f for f in self.all_reachable_from_functions if f.visibility in ["public", "external"] + ] + if not all_entry_points: + return True + return not all(("nonReentrant" in [m.name for m in f.modifiers] for f in all_entry_points)) + # endregion ################################################################################### ################################################################################### diff --git a/tests/test_function.py b/tests/test_function.py index 19fa596ab..09ecdc6ba 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -247,6 +247,7 @@ def test_functions(): def test_function_can_send_eth(): + solc_select.switch_global_version("0.6.12", always_install=True) slither = Slither("tests/test_function.sol") compilation_unit = slither.compilation_units[0] functions = compilation_unit.get_contract_from_name("TestFunctionCanSendEth")[ @@ -267,3 +268,19 @@ def test_function_can_send_eth(): assert functions["transfer_via_external()"].can_send_eth() is False assert functions["call_via_external()"].can_send_eth() is False assert functions["highlevel_call_via_external()"].can_send_eth() is False + + +def test_reentrant(): + solc_select.switch_global_version("0.8.10", always_install=True) + slither = Slither("tests/test_function_reentrant.sol") + compilation_unit = slither.compilation_units[0] + functions = compilation_unit.get_contract_from_name("TestReentrant")[ + 0 + ].available_functions_as_dict() + + assert functions["is_reentrant()"].is_reentrant + assert not functions["is_non_reentrant()"].is_reentrant + assert not functions["internal_and_not_reentrant()"].is_reentrant + assert not functions["internal_and_not_reentrant2()"].is_reentrant + assert functions["internal_and_could_be_reentrant()"].is_reentrant + assert functions["internal_and_reentrant()"].is_reentrant diff --git a/tests/test_function_reentrant.sol b/tests/test_function_reentrant.sol new file mode 100644 index 000000000..a1a8faa7b --- /dev/null +++ b/tests/test_function_reentrant.sol @@ -0,0 +1,36 @@ +contract TestReentrant{ + + modifier nonReentrant(){ + _; + } + + function is_reentrant() public{ + internal_and_could_be_reentrant(); + internal_and_reentrant(); + } + + function is_non_reentrant() nonReentrant() public{ + internal_and_could_be_reentrant(); + internal_and_not_reentrant2(); + } + + function internal_and_not_reentrant() nonReentrant() internal{ + + } + + function internal_and_not_reentrant2() internal{ + + } + + // Called by a protected and unprotected function + function internal_and_could_be_reentrant() internal{ + + } + + // Called by a protected and unprotected function + function internal_and_reentrant() internal{ + + } + + +} \ No newline at end of file