Merge pull request #449 from crytic/dev-true-false-branch

Node API: add true/false branchs for if statement
pull/451/head
Feist Josselin 5 years ago committed by GitHub
commit 453e7a1459
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 54
      slither/core/cfg/node.py
  2. 12
      slither/core/declarations/function.py

@ -2,6 +2,7 @@
Node module Node module
""" """
import logging import logging
from typing import Optional
from slither.core.children.child_function import ChildFunction from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract from slither.core.declarations import Contract
@ -24,6 +25,7 @@ from slither.all_exceptions import SlitherException
logger = logging.getLogger("Node") logger = logging.getLogger("Node")
################################################################################### ###################################################################################
################################################################################### ###################################################################################
# region NodeType # region NodeType
@ -31,23 +33,22 @@ logger = logging.getLogger("Node")
################################################################################### ###################################################################################
class NodeType: class NodeType:
ENTRYPOINT = 0x0 # no expression ENTRYPOINT = 0x0 # no expression
# Node with expression # Node with expression
EXPRESSION = 0x10 # normal case EXPRESSION = 0x10 # normal case
RETURN = 0x11 # RETURN may contain an expression RETURN = 0x11 # RETURN may contain an expression
IF = 0x12 IF = 0x12
VARIABLE = 0x13 # Declaration of variable VARIABLE = 0x13 # Declaration of variable
ASSEMBLY = 0x14 ASSEMBLY = 0x14
IFLOOP = 0x15 IFLOOP = 0x15
# Merging nodes # Merging nodes
# Can have phi IR operation # Can have phi IR operation
ENDIF = 0x50 # ENDIF node source mapping points to the if/else body ENDIF = 0x50 # ENDIF node source mapping points to the if/else body
STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body
ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body
# Below the nodes have no expression # Below the nodes have no expression
# But are used to expression CFG structure # But are used to expression CFG structure
@ -69,8 +70,7 @@ class NodeType:
# Use for state variable declaration # Use for state variable declaration
OTHER_ENTRYPOINT = 0x50 OTHER_ENTRYPOINT = 0x50
# @staticmethod
# @staticmethod
def str(t): def str(t):
if t == NodeType.ENTRYPOINT: if t == NodeType.ENTRYPOINT:
return 'ENTRY_POINT' return 'ENTRY_POINT'
@ -118,6 +118,7 @@ def link_nodes(n1, n2):
n1.add_son(n2) n1.add_son(n2)
n2.add_father(n1) n2.add_father(n1)
def insert_node(origin, node_inserted): def insert_node(origin, node_inserted):
sons = origin.sons sons = origin.sons
link_nodes(origin, node_inserted) link_nodes(origin, node_inserted)
@ -127,6 +128,7 @@ def insert_node(origin, node_inserted):
link_nodes(node_inserted, son) link_nodes(node_inserted, son)
def recheable(node): def recheable(node):
''' '''
Return the set of nodes reacheable from the node Return the set of nodes reacheable from the node
@ -167,7 +169,7 @@ class Node(SourceMapping, ChildFunction):
self._dominators = set() self._dominators = set()
self._immediate_dominator = None self._immediate_dominator = None
## Nodes of the dominators tree ## Nodes of the dominators tree
#self._dom_predecessors = set() # self._dom_predecessors = set()
self._dom_successors = set() self._dom_successors = set()
# Dominance frontier # Dominance frontier
self._dominance_frontier = set() self._dominance_frontier = set()
@ -189,7 +191,7 @@ class Node(SourceMapping, ChildFunction):
self._internal_calls = [] self._internal_calls = []
self._solidity_calls = [] self._solidity_calls = []
self._high_level_calls = [] # contains library calls self._high_level_calls = [] # contains library calls
self._library_calls = [] self._library_calls = []
self._low_level_calls = [] self._low_level_calls = []
self._external_calls_as_expressions = [] self._external_calls_as_expressions = []
@ -207,7 +209,7 @@ class Node(SourceMapping, ChildFunction):
self._local_vars_read = [] self._local_vars_read = []
self._local_vars_written = [] self._local_vars_written = []
self._slithir_vars = set() # non SSA self._slithir_vars = set() # non SSA
self._ssa_local_vars_read = [] self._ssa_local_vars_read = []
self._ssa_local_vars_written = [] self._ssa_local_vars_written = []
@ -396,6 +398,7 @@ class Node(SourceMapping, ChildFunction):
Include library calls Include library calls
""" """
return list(self._library_calls) return list(self._library_calls)
@property @property
def low_level_calls(self): def low_level_calls(self):
""" """
@ -547,7 +550,6 @@ class Node(SourceMapping, ChildFunction):
def add_inline_asm(self, asm): def add_inline_asm(self, asm):
self._asm_source_code = asm self._asm_source_code = asm
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -621,6 +623,20 @@ class Node(SourceMapping, ChildFunction):
""" """
return list(self._sons) return list(self._sons)
@property
def son_true(self) -> Optional["Node"]:
if self.type == NodeType.IF:
return self._sons[0]
else:
return None
@property
def son_false(self) -> Optional["Node"]:
if self.type == NodeType.IF and len(self._sons) >= 1:
return self._sons[1]
else:
return None
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -648,7 +664,7 @@ class Node(SourceMapping, ChildFunction):
@irs_ssa.setter @irs_ssa.setter
def irs_ssa(self, irs): def irs_ssa(self, irs):
self._irs_ssa = irs self._irs_ssa = irs
def add_ssa_ir(self, ir): def add_ssa_ir(self, ir):
''' '''
@ -748,9 +764,6 @@ class Node(SourceMapping, ChildFunction):
assert v == variable assert v == variable
nodes.add(node) nodes.add(node)
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -789,7 +802,7 @@ class Node(SourceMapping, ChildFunction):
if isinstance(var, (ReferenceVariable)): if isinstance(var, (ReferenceVariable)):
var = var.points_to_origin var = var.points_to_origin
if var and self._is_non_slithir_var(var): if var and self._is_non_slithir_var(var):
self._vars_written.append(var) self._vars_written.append(var)
if isinstance(ir, InternalCall): if isinstance(ir, InternalCall):
self._internal_calls.append(ir.function) self._internal_calls.append(ir.function)
@ -809,7 +822,8 @@ class Node(SourceMapping, ChildFunction):
try: try:
self._high_level_calls.append((ir.destination.type.type, ir.function)) self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError: except AttributeError:
raise SlitherException(f'Function not found on {ir}. Please try compiling with a recent Solidity version.') raise SlitherException(
f'Function not found on {ir}. Please try compiling with a recent Solidity version.')
elif isinstance(ir, LibraryCall): elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract) assert isinstance(ir.destination, Contract)
self._high_level_calls.append((ir.destination, ir.function)) self._high_level_calls.append((ir.destination, ir.function))
@ -884,7 +898,6 @@ class Node(SourceMapping, ChildFunction):
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read] self._vars_read += [v for v in vars_read if v not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
@ -893,7 +906,6 @@ class Node(SourceMapping, ChildFunction):
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
# endregion # endregion
################################################################################### ###################################################################################
################################################################################### ###################################################################################
@ -902,7 +914,7 @@ class Node(SourceMapping, ChildFunction):
################################################################################### ###################################################################################
def __str__(self): def __str__(self):
txt = NodeType.str(self._node_type) + ' '+ str(self.expression) txt = NodeType.str(self._node_type) + ' ' + str(self.expression)
return txt return txt
# endregion # endregion

@ -1111,8 +1111,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if node.irs: if node.irs:
label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs]) label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs])
content += '{}[label="{}"];\n'.format(node.node_id, label) content += '{}[label="{}"];\n'.format(node.node_id, label)
for son in node.sons: if node.type == NodeType.IF:
content += '{}->{};\n'.format(node.node_id, son.node_id) true_node = node.son_true
if true_node:
content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id)
false_node = node.son_false
if false_node:
content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id)
else:
for son in node.sons:
content += '{}->{};\n'.format(node.node_id, son.node_id)
content += "}\n" content += "}\n"
return content return content

Loading…
Cancel
Save