Merge pull request #449 from crytic/dev-true-false-branch

Node API: add true/false branchs for if statement
pull/451/head
Feist Josselin 5 years ago committed by GitHub
commit 453e7a1459
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 54
      slither/core/cfg/node.py
  2. 12
      slither/core/declarations/function.py

@ -2,6 +2,7 @@
Node module
"""
import logging
from typing import Optional
from slither.core.children.child_function import ChildFunction
from slither.core.declarations import Contract
@ -24,6 +25,7 @@ from slither.all_exceptions import SlitherException
logger = logging.getLogger("Node")
###################################################################################
###################################################################################
# region NodeType
@ -31,23 +33,22 @@ logger = logging.getLogger("Node")
###################################################################################
class NodeType:
ENTRYPOINT = 0x0 # no expression
# Node with expression
EXPRESSION = 0x10 # normal case
RETURN = 0x11 # RETURN may contain an expression
RETURN = 0x11 # RETURN may contain an expression
IF = 0x12
VARIABLE = 0x13 # Declaration of variable
VARIABLE = 0x13 # Declaration of variable
ASSEMBLY = 0x14
IFLOOP = 0x15
# Merging nodes
# Can have phi IR operation
ENDIF = 0x50 # ENDIF node source mapping points to the if/else body
STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body
ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body
ENDIF = 0x50 # ENDIF node source mapping points to the if/else body
STARTLOOP = 0x51 # STARTLOOP node source mapping points to the entire loop body
ENDLOOP = 0x52 # ENDLOOP node source mapping points to the entire loop body
# Below the nodes have no expression
# But are used to expression CFG structure
@ -69,8 +70,7 @@ class NodeType:
# Use for state variable declaration
OTHER_ENTRYPOINT = 0x50
# @staticmethod
# @staticmethod
def str(t):
if t == NodeType.ENTRYPOINT:
return 'ENTRY_POINT'
@ -118,6 +118,7 @@ def link_nodes(n1, n2):
n1.add_son(n2)
n2.add_father(n1)
def insert_node(origin, node_inserted):
sons = origin.sons
link_nodes(origin, node_inserted)
@ -127,6 +128,7 @@ def insert_node(origin, node_inserted):
link_nodes(node_inserted, son)
def recheable(node):
'''
Return the set of nodes reacheable from the node
@ -167,7 +169,7 @@ class Node(SourceMapping, ChildFunction):
self._dominators = set()
self._immediate_dominator = None
## Nodes of the dominators tree
#self._dom_predecessors = set()
# self._dom_predecessors = set()
self._dom_successors = set()
# Dominance frontier
self._dominance_frontier = set()
@ -189,7 +191,7 @@ class Node(SourceMapping, ChildFunction):
self._internal_calls = []
self._solidity_calls = []
self._high_level_calls = [] # contains library calls
self._high_level_calls = [] # contains library calls
self._library_calls = []
self._low_level_calls = []
self._external_calls_as_expressions = []
@ -207,7 +209,7 @@ class Node(SourceMapping, ChildFunction):
self._local_vars_read = []
self._local_vars_written = []
self._slithir_vars = set() # non SSA
self._slithir_vars = set() # non SSA
self._ssa_local_vars_read = []
self._ssa_local_vars_written = []
@ -396,6 +398,7 @@ class Node(SourceMapping, ChildFunction):
Include library calls
"""
return list(self._library_calls)
@property
def low_level_calls(self):
"""
@ -547,7 +550,6 @@ class Node(SourceMapping, ChildFunction):
def add_inline_asm(self, asm):
self._asm_source_code = asm
# endregion
###################################################################################
###################################################################################
@ -621,6 +623,20 @@ class Node(SourceMapping, ChildFunction):
"""
return list(self._sons)
@property
def son_true(self) -> Optional["Node"]:
if self.type == NodeType.IF:
return self._sons[0]
else:
return None
@property
def son_false(self) -> Optional["Node"]:
if self.type == NodeType.IF and len(self._sons) >= 1:
return self._sons[1]
else:
return None
# endregion
###################################################################################
###################################################################################
@ -648,7 +664,7 @@ class Node(SourceMapping, ChildFunction):
@irs_ssa.setter
def irs_ssa(self, irs):
self._irs_ssa = irs
self._irs_ssa = irs
def add_ssa_ir(self, ir):
'''
@ -748,9 +764,6 @@ class Node(SourceMapping, ChildFunction):
assert v == variable
nodes.add(node)
# endregion
###################################################################################
###################################################################################
@ -789,7 +802,7 @@ class Node(SourceMapping, ChildFunction):
if isinstance(var, (ReferenceVariable)):
var = var.points_to_origin
if var and self._is_non_slithir_var(var):
self._vars_written.append(var)
self._vars_written.append(var)
if isinstance(ir, InternalCall):
self._internal_calls.append(ir.function)
@ -809,7 +822,8 @@ class Node(SourceMapping, ChildFunction):
try:
self._high_level_calls.append((ir.destination.type.type, ir.function))
except AttributeError:
raise SlitherException(f'Function not found on {ir}. Please try compiling with a recent Solidity version.')
raise SlitherException(
f'Function not found on {ir}. Please try compiling with a recent Solidity version.')
elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract)
self._high_level_calls.append((ir.destination, ir.function))
@ -884,7 +898,6 @@ class Node(SourceMapping, ChildFunction):
vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read]
vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written]
self._vars_read += [v for v in vars_read if v not in self._vars_read]
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)]
@ -893,7 +906,6 @@ class Node(SourceMapping, ChildFunction):
self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)]
self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)]
# endregion
###################################################################################
###################################################################################
@ -902,7 +914,7 @@ class Node(SourceMapping, ChildFunction):
###################################################################################
def __str__(self):
txt = NodeType.str(self._node_type) + ' '+ str(self.expression)
txt = NodeType.str(self._node_type) + ' ' + str(self.expression)
return txt
# endregion

@ -1111,8 +1111,16 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
if node.irs:
label += '\nIRs:\n' + '\n'.join([str(ir) for ir in node.irs])
content += '{}[label="{}"];\n'.format(node.node_id, label)
for son in node.sons:
content += '{}->{};\n'.format(node.node_id, son.node_id)
if node.type == NodeType.IF:
true_node = node.son_true
if true_node:
content += '{}->{}[label="True"];\n'.format(node.node_id, true_node.node_id)
false_node = node.son_false
if false_node:
content += '{}->{}[label="False"];\n'.format(node.node_id, false_node.node_id)
else:
for son in node.sons:
content += '{}->{};\n'.format(node.node_id, son.node_id)
content += "}\n"
return content

Loading…
Cancel
Save