diff --git a/slither/detectors/reentrancy/reentrancy_eth.py b/slither/detectors/reentrancy/reentrancy_eth.py index c9b0efcc9..f9698c963 100644 --- a/slither/detectors/reentrancy/reentrancy_eth.py +++ b/slither/detectors/reentrancy/reentrancy_eth.py @@ -88,32 +88,37 @@ class ReentrancyEth(AbstractDetector): # send_eth returns the list of calls sending value # calls returns the list of calls that can callback # read returns the variable read - fathers_context = {'send_eth':[], 'calls':[], 'read':[]} + fathers_context = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]} for father in node.fathers: if self.key in father.context: fathers_context['send_eth'] += [s for s in father.context[self.key]['send_eth'] if s!=skip_father] fathers_context['calls'] += [c for c in father.context[self.key]['calls'] if c!=skip_father] fathers_context['read'] += father.context[self.key]['read'] + fathers_context['read_prior_calls'] += father.context[self.key]['read_prior_calls'] # Exclude path that dont bring further information if node in self.visited_all_paths: if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']): if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']): if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']): - return + if all(read in self.visited_all_paths[node]['read_prior_calls'] for read in fathers_context['read_prior_calls']): + return else: - self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[]} + self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]} self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth'])) self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls'])) self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read'])) + self.visited_all_paths[node]['read_prior_calls'] = list(set(self.visited_all_paths[node]['read_prior_calls'] + fathers_context['read_prior_calls'])) node.context[self.key] = fathers_context contains_call = False if self._can_callback(node): node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node])) + node.context[self.key]['read_prior_calls'] = list(set(node.context[self.key]['read_prior_calls'] + node.context[self.key]['read']+ node.state_variables_read)) + node.context[self.key]['read'] = [] contains_call = True if self._can_send_eth(node): node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node])) @@ -127,7 +132,7 @@ class ReentrancyEth(AbstractDetector): if isinstance(internal_call, Function): state_vars_written += internal_call.all_state_variables_written() - read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read']] + read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read_prior_calls']] node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read)) # If a state variables was read and is then written, there is a dangerous call and diff --git a/slither/detectors/reentrancy/reentrancy_read_before_write.py b/slither/detectors/reentrancy/reentrancy_read_before_write.py index 508a34393..cc23b09e7 100644 --- a/slither/detectors/reentrancy/reentrancy_read_before_write.py +++ b/slither/detectors/reentrancy/reentrancy_read_before_write.py @@ -89,32 +89,37 @@ class ReentrancyReadBeforeWritten(AbstractDetector): # send_eth returns the list of calls sending value # calls returns the list of calls that can callback # read returns the variable read - fathers_context = {'send_eth':[], 'calls':[], 'read':[]} + fathers_context = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]} for father in node.fathers: if self.key in father.context: fathers_context['send_eth'] += [s for s in father.context[self.key]['send_eth'] if s!=skip_father] fathers_context['calls'] += [c for c in father.context[self.key]['calls'] if c!=skip_father] fathers_context['read'] += father.context[self.key]['read'] + fathers_context['read_prior_calls'] += father.context[self.key]['read_prior_calls'] # Exclude path that dont bring further information if node in self.visited_all_paths: if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']): if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']): if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']): - return + if all(read in self.visited_all_paths[node]['read_prior_calls'] for read in fathers_context['read_prior_calls']): + return else: - self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[]} + self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]} self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth'])) self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls'])) self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read'])) + self.visited_all_paths[node]['read_prior_calls'] = list(set(self.visited_all_paths[node]['read_prior_calls'] + fathers_context['read_prior_calls'])) node.context[self.key] = fathers_context contains_call = False if self._can_callback(node): node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node])) + node.context[self.key]['read_prior_calls'] = list(set(node.context[self.key]['read_prior_calls'] + node.context[self.key]['read'] + node.state_variables_read)) + node.context[self.key]['read'] = [] contains_call = True if self._can_send_eth(node): node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node])) @@ -128,7 +133,7 @@ class ReentrancyReadBeforeWritten(AbstractDetector): if isinstance(internal_call, Function): state_vars_written += internal_call.all_state_variables_written() - read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read']] + read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read_prior_calls']] node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read)) # If a state variables was read and is then written, there is a dangerous call and