Merge pull request #449 from crytic/dev-true-false-branch

Node API: add true/false branchs for if statement
pull/451/head
Feist Josselin 5 years ago committed by GitHub
commit 453e7a1459
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 36
      slither/core/cfg/node.py
  2. 8
      slither/core/declarations/function.py

@ -2,6 +2,7 @@
Node module Node module
""" """
import logging import logging
from typing import Optional
from slither.core.children.child_function import ChildFunction from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -24,6 +25,7 @@ from slither.all_exceptions import SlitherException
logger = logging.getLogger("Node") logger = logging.getLogger("Node")
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region NodeType # region NodeType
@ -31,7 +33,6 @@ logger = logging.getLogger("Node")
################################################################################### ###################################################################################
class NodeType: class NodeType:
ENTRYPOINT = 0x0 # no expression ENTRYPOINT = 0x0 # no expression
# Node with expression # Node with expression
@ -69,8 +70,7 @@ class NodeType:
# Use for state variable declaration # Use for state variable declaration
OTHER_ENTRYPOINT = 0x50 OTHER_ENTRYPOINT = 0x50
# @staticmethod
# @staticmethod
def str(t): def str(t):
if t == NodeType.ENTRYPOINT: if t == NodeType.ENTRYPOINT:
return 'ENTRY_POINT' return 'ENTRY_POINT'
@ -118,6 +118,7 @@ def link_nodes(n1, n2):
n1.add_son(n2) n1.add_son(n2)
n2.add_father(n1) n2.add_father(n1)
def insert_node(origin, node_inserted): def insert_node(origin, node_inserted):
sons = origin.sons sons = origin.sons
link_nodes(origin, node_inserted) link_nodes(origin, node_inserted)
@ -127,6 +128,7 @@ def insert_node(origin, node_inserted):
link_nodes(node_inserted, son) link_nodes(node_inserted, son)
def recheable(node): def recheable(node):
''' '''
Return the set of nodes reacheable from the node Return the set of nodes reacheable from the node
@ -167,7 +169,7 @@ class Node(SourceMapping, ChildFunction):
self._dominators = set() self._dominators = set()
self._immediate_dominator = None self._immediate_dominator = None
## Nodes of the dominators tree ## Nodes of the dominators tree
#self._dom_predecessors = set() # self._dom_predecessors = set()
self._dom_successors = set() self._dom_successors = set()
# Dominance frontier # Dominance frontier
self._dominance_frontier = set() self._dominance_frontier = set()
@ -396,6 +398,7 @@ class Node(SourceMapping, ChildFunction):
Include library calls Include library calls
""" """
return list(self._library_calls) return list(self._library_calls)
@property @property
def low_level_calls(self): def low_level_calls(self):
""" """
@ -547,7 +550,6 @@ class Node(SourceMapping, ChildFunction):
def add_inline_asm(self, asm): def add_inline_asm(self, asm):
self._asm_source_code = asm self._asm_source_code = asm
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -621,6 +623,20 @@ class Node(SourceMapping, ChildFunction):
""" """
return list(self._sons) return list(self._sons)
@property
def son_true(self) -> Optional["Node"]:
if self.type == NodeType.IF:
return self._sons[0]
else:
return None
@property
def son_false(self) -> Optional["Node"]:
if self.type == NodeType.IF and len(self._sons) >= 1:
return self._sons[1]
else:
return None
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -748,9 +764,6 @@ class Node(SourceMapping, ChildFunction):
assert v == variable assert v == variable
nodes.add(node) nodes.add(node)
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -809,7 +822,8 @@ class Node(SourceMapping, ChildFunction):
try: try:
self._high_level_calls.append((ir.destination.type.type, ir.function)) self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError: except AttributeError:
raise SlitherException(f'Function not found on {ir}. Please try compiling with a recent Solidity version.') raise SlitherException(
f'Function not found on {ir}. Please try compiling with a recent Solidity version.')
elif isinstance(ir, LibraryCall): elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract) assert isinstance(ir.destination, Contract)
self._high_level_calls.append((ir.destination, ir.function)) self._high_level_calls.append((ir.destination, ir.function))
@ -884,7 +898,6 @@ class Node(SourceMapping, ChildFunction):
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read] self._vars_read += [v for v in vars_read if v not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
@ -893,7 +906,6 @@ class Node(SourceMapping, ChildFunction):
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -902,7 +914,7 @@ class Node(SourceMapping, ChildFunction):
################################################################################### ###################################################################################
def __str__(self): def __str__(self):
txt = NodeType.str(self._node_type) + ' '+ str(self.expression) txt = NodeType.str(self._node_type) + ' ' + str(self.expression)
return txt return txt
# endregion # endregion

@ -1111,6 +1111,14 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if node.irs: if node.irs:
label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs]) label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs])
content += '{}[label="{}"];\n'.format(node.node_id, label) content += '{}[label="{}"];\n'.format(node.node_id, label)
if node.type == NodeType.IF:
true_node = node.son_true
if true_node:
content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id)
false_node = node.son_false
if false_node:
content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id)
else:
for son in node.sons: for son in node.sons:
content += '{}->{};\n'.format(node.node_id, son.node_id) content += '{}->{};\n'.format(node.node_id, son.node_id)

Loading…
Cancel
Save