Add function.all_reachable_from_functions and function.is_reentrant

Fix type in function.reacheable_from_functions
pull/1351/head
Josselin Feist 2 years ago
parent 831f9dead3
commit 70dff581ed
  1. 50
      slither/core/declarations/function.py
  2. 17
      tests/test_function.py
  3. 36
      tests/test_function_reentrant.sol

@ -189,7 +189,8 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
# set(ReacheableNode)
self._reachable_from_nodes: Set[ReacheableNode] = set()
self._reachable_from_functions: Set[ReacheableNode] = set()
self._reachable_from_functions: Set[Function] = set()
self._all_reachable_from_functions: Optional[Set[Function]] = None
# Constructor, fallback, State variable constructor
self._function_type: Optional[FunctionType] = None
@ -214,7 +215,7 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
self.compilation_unit: "SlitherCompilationUnit" = compilation_unit
# Assume we are analyzing Solidty by default
# Assume we are analyzing Solidity by default
self.function_language: FunctionLanguage = FunctionLanguage.Solidity
self._id: Optional[str] = None
@ -1024,9 +1025,32 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
return self._reachable_from_nodes
@property
def reachable_from_functions(self) -> Set[ReacheableNode]:
def reachable_from_functions(self) -> Set["Function"]:
return self._reachable_from_functions
@property
def all_reachable_from_functions(self) -> Set["Function"]:
"""
Give the recursive version of reachable_from_functions (all the functions that lead to call self in the CFG)
"""
if self._all_reachable_from_functions is None:
functions: Set["Function"] = set()
new_functions = self.reachable_from_functions
print([str(f) for f in new_functions])
# iterate until we have are finding new functions
while new_functions and new_functions not in functions:
print([str(f) for f in new_functions])
functions = functions.union(new_functions)
# Use a temporary set, because we iterate over new_functions
new_functionss: Set["Function"] = set()
for f in new_functions:
new_functionss = new_functionss.union(f.reachable_from_functions)
new_functions = new_functionss
self._all_reachable_from_functions = functions
return self._all_reachable_from_functions
def add_reachable_from_node(self, n: "Node", ir: "Operation"):
self._reachable_from_nodes.add(ReacheableNode(n, ir))
self._reachable_from_functions.add(n.function)
@ -1455,6 +1479,26 @@ class Function(SourceMapping, metaclass=ABCMeta): # pylint: disable=too-many-pu
)
return self._is_protected
@property
def is_reentrant(self) -> bool:
"""
Determine if the function can be re-entered
"""
# TODO: compare with hash of known nonReentrant modifier instead of the name
if "nonReentrant" in [m.name for m in self.modifiers]:
return False
if self.visibility in ["public", "external"]:
return True
# If it's an internal function, check if all its entry points have the nonReentrant modifier
all_entry_points = [
f for f in self.all_reachable_from_functions if f.visibility in ["public", "external"]
]
if not all_entry_points:
return True
return not all(("nonReentrant" in [m.name for m in f.modifiers] for f in all_entry_points))
# endregion
###################################################################################
###################################################################################

@ -247,6 +247,7 @@ def test_functions():
def test_function_can_send_eth():
solc_select.switch_global_version("0.6.12", always_install=True)
slither = Slither("tests/test_function.sol")
compilation_unit = slither.compilation_units[0]
functions = compilation_unit.get_contract_from_name("TestFunctionCanSendEth")[
@ -267,3 +268,19 @@ def test_function_can_send_eth():
assert functions["transfer_via_external()"].can_send_eth() is False
assert functions["call_via_external()"].can_send_eth() is False
assert functions["highlevel_call_via_external()"].can_send_eth() is False
def test_reentrant():
solc_select.switch_global_version("0.8.10", always_install=True)
slither = Slither("tests/test_function_reentrant.sol")
compilation_unit = slither.compilation_units[0]
functions = compilation_unit.get_contract_from_name("TestReentrant")[
0
].available_functions_as_dict()
assert functions["is_reentrant()"].is_reentrant
assert not functions["is_non_reentrant()"].is_reentrant
assert not functions["internal_and_not_reentrant()"].is_reentrant
assert not functions["internal_and_not_reentrant2()"].is_reentrant
assert functions["internal_and_could_be_reentrant()"].is_reentrant
assert functions["internal_and_reentrant()"].is_reentrant

@ -0,0 +1,36 @@
contract TestReentrant{
modifier nonReentrant(){
_;
}
function is_reentrant() public{
internal_and_could_be_reentrant();
internal_and_reentrant();
}
function is_non_reentrant() nonReentrant() public{
internal_and_could_be_reentrant();
internal_and_not_reentrant2();
}
function internal_and_not_reentrant() nonReentrant() internal{
}
function internal_and_not_reentrant2() internal{
}
// Called by a protected and unprotected function
function internal_and_could_be_reentrant() internal{
}
// Called by a protected and unprotected function
function internal_and_reentrant() internal{
}
}
Loading…
Cancel
Save