Merge pull request #449 from crytic/dev-true-false-branch

Node API: add true/false branchs for if statement
pull/451/head
Feist Josselin 5 years ago committed by GitHub
commit 453e7a1459
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 30
      slither/core/cfg/node.py
  2. 8
      slither/core/declarations/function.py

@ -2,6 +2,7 @@
Node module
"""
import logging
from typing import Optional
from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract
@ -24,6 +25,7 @@ from slither.all_exceptions import SlitherException
logger = logging.getLogger("Node")
###################################################################################
###################################################################################
# region NodeType
@ -31,7 +33,6 @@ logger = logging.getLogger("Node")
###################################################################################
class NodeType:
ENTRYPOINT = 0x0 # no expression
# Node with expression
@ -69,7 +70,6 @@ class NodeType:
# Use for state variable declaration
OTHER_ENTRYPOINT = 0x50
# @staticmethod
def str(t):
if t == NodeType.ENTRYPOINT:
@ -118,6 +118,7 @@ def link_nodes(n1, n2):
n1.add_son(n2)
n2.add_father(n1)
def insert_node(origin, node_inserted):
sons = origin.sons
link_nodes(origin, node_inserted)
@ -127,6 +128,7 @@ def insert_node(origin, node_inserted):
link_nodes(node_inserted, son)
def recheable(node):
'''
Return the set of nodes reacheable from the node
@ -396,6 +398,7 @@ class Node(SourceMapping, ChildFunction):
Include library calls
"""
return list(self._library_calls)
@property
def low_level_calls(self):
"""
@ -547,7 +550,6 @@ class Node(SourceMapping, ChildFunction):
def add_inline_asm(self, asm):
self._asm_source_code = asm
# endregion
###################################################################################
###################################################################################
@ -621,6 +623,20 @@ class Node(SourceMapping, ChildFunction):
"""
return list(self._sons)
@property
def son_true(self) -> Optional["Node"]:
if self.type == NodeType.IF:
return self._sons[0]
else:
return None
@property
def son_false(self) -> Optional["Node"]:
if self.type == NodeType.IF and len(self._sons) >= 1:
return self._sons[1]
else:
return None
# endregion
###################################################################################
###################################################################################
@ -748,9 +764,6 @@ class Node(SourceMapping, ChildFunction):
assert v == variable
nodes.add(node)
# endregion
###################################################################################
###################################################################################
@ -809,7 +822,8 @@ class Node(SourceMapping, ChildFunction):
try:
self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError:
raise SlitherException(f'Function not found on {ir}. Please try compiling with a recent Solidity version.')
raise SlitherException(
f'Function not found on {ir}. Please try compiling with a recent Solidity version.')
elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract)
self._high_level_calls.append((ir.destination, ir.function))
@ -884,7 +898,6 @@ class Node(SourceMapping, ChildFunction):
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
@ -893,7 +906,6 @@ class Node(SourceMapping, ChildFunction):
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
# endregion
###################################################################################
###################################################################################

@ -1111,6 +1111,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if node.irs:
label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs])
content += '{}[label="{}"];\n'.format(node.node_id, label)
if node.type == NodeType.IF:
true_node = node.son_true
if true_node:
content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id)
false_node = node.son_false
if false_node:
content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id)
else:
for son in node.sons:
content += '{}->{};\n'.format(node.node_id, son.node_id)

Loading…
Cancel
Save