Improve reentrancy heuristic

pull/146/head
Josselin 6 years ago
parent ff280c2b6f
commit 4340b86bb4
  1. 11
      slither/detectors/reentrancy/reentrancy_eth.py
  2. 11
      slither/detectors/reentrancy/reentrancy_read_before_write.py

@ -88,32 +88,37 @@ class ReentrancyEth(AbstractDetector):
# send_eth returns the list of calls sending value # send_eth returns the list of calls sending value
# calls returns the list of calls that can callback # calls returns the list of calls that can callback
# read returns the variable read # read returns the variable read
fathers_context = {'send_eth':[], 'calls':[], 'read':[]} fathers_context = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]}
for father in node.fathers: for father in node.fathers:
if self.key in father.context: if self.key in father.context:
fathers_context['send_eth'] += [s for s in father.context[self.key]['send_eth'] if s!=skip_father] fathers_context['send_eth'] += [s for s in father.context[self.key]['send_eth'] if s!=skip_father]
fathers_context['calls'] += [c for c in father.context[self.key]['calls'] if c!=skip_father] fathers_context['calls'] += [c for c in father.context[self.key]['calls'] if c!=skip_father]
fathers_context['read'] += father.context[self.key]['read'] fathers_context['read'] += father.context[self.key]['read']
fathers_context['read_prior_calls'] += father.context[self.key]['read_prior_calls']
# Exclude path that dont bring further information # Exclude path that dont bring further information
if node in self.visited_all_paths: if node in self.visited_all_paths:
if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']): if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']):
if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']): if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']):
if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']): if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']):
if all(read in self.visited_all_paths[node]['read_prior_calls'] for read in fathers_context['read_prior_calls']):
return return
else: else:
self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[]} self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]}
self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth'])) self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth']))
self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls'])) self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls']))
self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read'])) self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read']))
self.visited_all_paths[node]['read_prior_calls'] = list(set(self.visited_all_paths[node]['read_prior_calls'] + fathers_context['read_prior_calls']))
node.context[self.key] = fathers_context node.context[self.key] = fathers_context
contains_call = False contains_call = False
if self._can_callback(node): if self._can_callback(node):
node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node])) node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node]))
node.context[self.key]['read_prior_calls'] = list(set(node.context[self.key]['read_prior_calls'] + node.context[self.key]['read']+ node.state_variables_read))
node.context[self.key]['read'] = []
contains_call = True contains_call = True
if self._can_send_eth(node): if self._can_send_eth(node):
node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node])) node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node]))
@ -127,7 +132,7 @@ class ReentrancyEth(AbstractDetector):
if isinstance(internal_call, Function): if isinstance(internal_call, Function):
state_vars_written += internal_call.all_state_variables_written() state_vars_written += internal_call.all_state_variables_written()
read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read']] read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read_prior_calls']]
node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read)) node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read))
# If a state variables was read and is then written, there is a dangerous call and # If a state variables was read and is then written, there is a dangerous call and

@ -89,32 +89,37 @@ class ReentrancyReadBeforeWritten(AbstractDetector):
# send_eth returns the list of calls sending value # send_eth returns the list of calls sending value
# calls returns the list of calls that can callback # calls returns the list of calls that can callback
# read returns the variable read # read returns the variable read
fathers_context = {'send_eth':[], 'calls':[], 'read':[]} fathers_context = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]}
for father in node.fathers: for father in node.fathers:
if self.key in father.context: if self.key in father.context:
fathers_context['send_eth'] += [s for s in father.context[self.key]['send_eth'] if s!=skip_father] fathers_context['send_eth'] += [s for s in father.context[self.key]['send_eth'] if s!=skip_father]
fathers_context['calls'] += [c for c in father.context[self.key]['calls'] if c!=skip_father] fathers_context['calls'] += [c for c in father.context[self.key]['calls'] if c!=skip_father]
fathers_context['read'] += father.context[self.key]['read'] fathers_context['read'] += father.context[self.key]['read']
fathers_context['read_prior_calls'] += father.context[self.key]['read_prior_calls']
# Exclude path that dont bring further information # Exclude path that dont bring further information
if node in self.visited_all_paths: if node in self.visited_all_paths:
if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']): if all(call in self.visited_all_paths[node]['calls'] for call in fathers_context['calls']):
if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']): if all(send in self.visited_all_paths[node]['send_eth'] for send in fathers_context['send_eth']):
if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']): if all(read in self.visited_all_paths[node]['read'] for read in fathers_context['read']):
if all(read in self.visited_all_paths[node]['read_prior_calls'] for read in fathers_context['read_prior_calls']):
return return
else: else:
self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[]} self.visited_all_paths[node] = {'send_eth':[], 'calls':[], 'read':[], 'read_prior_calls':[]}
self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth'])) self.visited_all_paths[node]['send_eth'] = list(set(self.visited_all_paths[node]['send_eth'] + fathers_context['send_eth']))
self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls'])) self.visited_all_paths[node]['calls'] = list(set(self.visited_all_paths[node]['calls'] + fathers_context['calls']))
self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read'])) self.visited_all_paths[node]['read'] = list(set(self.visited_all_paths[node]['read'] + fathers_context['read']))
self.visited_all_paths[node]['read_prior_calls'] = list(set(self.visited_all_paths[node]['read_prior_calls'] + fathers_context['read_prior_calls']))
node.context[self.key] = fathers_context node.context[self.key] = fathers_context
contains_call = False contains_call = False
if self._can_callback(node): if self._can_callback(node):
node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node])) node.context[self.key]['calls'] = list(set(node.context[self.key]['calls'] + [node]))
node.context[self.key]['read_prior_calls'] = list(set(node.context[self.key]['read_prior_calls'] + node.context[self.key]['read'] + node.state_variables_read))
node.context[self.key]['read'] = []
contains_call = True contains_call = True
if self._can_send_eth(node): if self._can_send_eth(node):
node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node])) node.context[self.key]['send_eth'] = list(set(node.context[self.key]['send_eth'] + [node]))
@ -128,7 +133,7 @@ class ReentrancyReadBeforeWritten(AbstractDetector):
if isinstance(internal_call, Function): if isinstance(internal_call, Function):
state_vars_written += internal_call.all_state_variables_written() state_vars_written += internal_call.all_state_variables_written()
read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read']] read_then_written = [(v, node) for v in state_vars_written if v in node.context[self.key]['read_prior_calls']]
node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read)) node.context[self.key]['read'] = list(set(node.context[self.key]['read'] + node.state_variables_read))
# If a state variables was read and is then written, there is a dangerous call and # If a state variables was read and is then written, there is a dangerous call and

Loading…
Cancel
Save