Keep original slithIR and generate the SSA version as a second list of IR (WIP)

pull/87/head
Josselin 6 years ago
parent 0e6e0cb176
commit 2915840be2
  1. 3
      slither/__main__.py
  2. 21
      slither/core/cfg/node.py
  3. 2
      slither/slithir/operations/high_level_call.py
  4. 4
      slither/slithir/operations/index.py
  5. 4
      slither/slithir/operations/unary.py
  6. 308
      slither/slithir/utils/ssa.py
  7. 4
      slither/solc_parsing/declarations/function.py
  8. 29
      slither/visitors/slithir/expression_to_slithir.py

@ -157,6 +157,7 @@ def get_detectors_and_printers():
from slither.printers.call.call_graph import PrinterCallGraph from slither.printers.call.call_graph import PrinterCallGraph
from slither.printers.functions.authorization import PrinterWrittenVariablesAndAuthorization from slither.printers.functions.authorization import PrinterWrittenVariablesAndAuthorization
from slither.printers.summary.slithir import PrinterSlithIR from slither.printers.summary.slithir import PrinterSlithIR
from slither.printers.summary.slithir_ssa import PrinterSlithIRSSA
from slither.printers.summary.human_summary import PrinterHumanSummary from slither.printers.summary.human_summary import PrinterHumanSummary
printers = [FunctionSummary, printers = [FunctionSummary,
@ -166,6 +167,7 @@ def get_detectors_and_printers():
PrinterCallGraph, PrinterCallGraph,
PrinterWrittenVariablesAndAuthorization, PrinterWrittenVariablesAndAuthorization,
PrinterSlithIR, PrinterSlithIR,
PrinterSlithIRSSA,
PrinterHumanSummary] PrinterHumanSummary]
# Handle plugins! # Handle plugins!
@ -213,6 +215,7 @@ def main_impl(all_detector_classes, all_printer_classes):
('FunctionSolc', default_log), ('FunctionSolc', default_log),
('ExpressionParsing', default_log), ('ExpressionParsing', default_log),
('TypeParsing', default_log), ('TypeParsing', default_log),
('SSA_Conversion', default_log),
('Printers', default_log)]: ('Printers', default_log)]:
l = logging.getLogger(l_name) l = logging.getLogger(l_name)
l.setLevel(l_level) l.setLevel(l_level)

@ -131,6 +131,7 @@ class Node(SourceMapping, ChildFunction):
self._low_level_calls = [] self._low_level_calls = []
self._external_calls_as_expressions = [] self._external_calls_as_expressions = []
self._irs = [] self._irs = []
self._irs_ssa = []
self._state_vars_written = [] self._state_vars_written = []
self._state_vars_read = [] self._state_vars_read = []
@ -424,8 +425,24 @@ class Node(SourceMapping, ChildFunction):
""" """
return self._irs return self._irs
def add_pre_ir(self, ir): @property
self._irs.insert(0, ir) def irs_ssa(self):
""" Returns the slithIR representation with SSA
return
list(slithIR.Operation)
"""
return self._irs_ssa
@irs_ssa.setter
def irs_ssa(self, irs):
self._irs_ssa = irs
def add_ssa_ir(self, ir):
'''
Use to place phi operation
'''
self._irs_ssa.append(ir)
def slithir_generation(self): def slithir_generation(self):
if self.expression: if self.expression:

@ -13,7 +13,7 @@ class HighLevelCall(Call, OperationWithLValue):
def __init__(self, destination, function_name, nbr_arguments, result, type_call): def __init__(self, destination, function_name, nbr_arguments, result, type_call):
assert isinstance(function_name, Constant) assert isinstance(function_name, Constant)
assert is_valid_lvalue(result) assert is_valid_lvalue(result) or result is None
self._check_destination(destination) self._check_destination(destination)
super(HighLevelCall, self).__init__() super(HighLevelCall, self).__init__()
self._destination = destination self._destination = destination

@ -33,5 +33,9 @@ class Index(OperationWithLValue):
def variable_right(self): def variable_right(self):
return self._variables[1] return self._variables[1]
@property
def index_type(self):
return self._type
def __str__(self): def __str__(self):
return "{}({}) -> {}[{}]".format(self.lvalue, self.lvalue.type, self.variable_left, self.variable_right) return "{}({}) -> {}[{}]".format(self.lvalue, self.lvalue.type, self.variable_left, self.variable_right)

@ -48,6 +48,10 @@ class Unary(OperationWithLValue):
def rvalue(self): def rvalue(self):
return self._variable return self._variable
@property
def type(self):
return self._type
@property @property
def type_str(self): def type_str(self):
return UnaryType.str(self._type) return UnaryType.str(self._type)

@ -1,7 +1,25 @@
from slither.slithir.variables import (Constant, ReferenceVariable, import logging
TemporaryVariable, TupleVariable)
from slither.slithir.operations import OperationWithLValue, Phi from slither.core.cfg.node import NodeType
from slither.slithir.variables import LocalIRVariable from slither.core.variables.local_variable import LocalVariable
from slither.slithir.operations import (Assignment, Balance, Binary,
BinaryType, Condition, Delete,
EventCall, HighLevelCall, Index,
InitArray, InternalCall,
InternalDynamicCall, Length,
LibraryCall, LowLevelCall, Member,
NewArray, NewContract,
NewElementaryType, NewStructure,
OperationWithLValue, Phi, Push, Return,
Send, SolidityCall, Transfer,
TypeConversion, Unary, Unpack)
from slither.slithir.variables import (Constant, LocalIRVariable,
ReferenceVariable, TemporaryVariable,
TupleVariable)
logger = logging.getLogger('SSA_Conversion')
def transform_slithir_vars_to_ssa(function): def transform_slithir_vars_to_ssa(function):
""" """
@ -9,7 +27,7 @@ def transform_slithir_vars_to_ssa(function):
""" """
variables = [] variables = []
for node in function.nodes: for node in function.nodes:
for ir in node.irs: for ir in node.irs_ssa:
if isinstance(ir, OperationWithLValue) and not ir.lvalue in variables: if isinstance(ir, OperationWithLValue) and not ir.lvalue in variables:
variables += [ir.lvalue] variables += [ir.lvalue]
@ -23,12 +41,18 @@ def transform_slithir_vars_to_ssa(function):
for idx in range(len(tuple_variables)): for idx in range(len(tuple_variables)):
tuple_variables[idx].index = idx tuple_variables[idx].index = idx
def add_phi_operations(function): def add_ssa_ir(function):
'''
Add SSA version of the IR
'''
if not function.is_implemented:
return
init_definition = dict() init_definition = dict()
for v in function.parameters+function.returns: for v in function.parameters+function.returns:
if v.name: if v.name:
init_definition[v.name] = function.entry_point init_definition[v.name] = function.entry_point
add_phi_origins(function.entry_point, init_definition) add_phi_origins(function.entry_point, init_definition)
for node in function.nodes: for node in function.nodes:
@ -39,57 +63,75 @@ def add_phi_operations(function):
# an instance of the variable # an instance of the variable
# by looking at the variables written # by looking at the variables written
# of any of the nodes # of any of the nodes
n = list(nodes)[0] for n in nodes:
variable = next((v for v in n.variables_written if v.name == variable_name), None) variable = next((v for v in n.variables_written if v.name == variable_name), None)
if variable is None: if variable is None:
variable = n.variable_declaration variable = n.variable_declaration
node.add_pre_ir(Phi(LocalIRVariable(variable), nodes)) if variable:
break
assert variable
node.add_ssa_ir(Phi(LocalIRVariable(variable), nodes))
init_local_variables_counter = dict() init_local_variables_instances = dict()
for v in function.parameters+function.returns: for v in function.parameters+function.returns:
if v.name: if v.name:
init_local_variables_counter[v.name] = 0 init_local_variables_instances[v.name] = LocalIRVariable(v)
init_global_variables_counter = dict(init_local_variables_counter) init_global_variables_instances = dict(init_local_variables_instances)
rename_variables(function.entry_point, init_local_variables_counter, init_global_variables_counter) generate_ssa_irs(function.entry_point,
dict(init_local_variables_instances),
init_global_variables_instances)
fix_phi_operations(function.nodes) fix_phi_operations(function.nodes, init_local_variables_instances)
def fix_phi_operations(nodes):
def fix_phi_operations(nodes, init_vars):
def last_name(n, var): def last_name(n, var):
candidates = [v for v in n.variables_written if v.name == var.name] candidates = []
# Todo optimize by creating a variables_ssa_written attribute
for ir_ssa in n.irs_ssa:
if isinstance(ir_ssa, OperationWithLValue):
if ir_ssa.lvalue and ir_ssa.lvalue.name == var.name:
candidates.append(ir_ssa.lvalue)
if n.variable_declaration and n.variable_declaration.name == var.name: if n.variable_declaration and n.variable_declaration.name == var.name:
candidates.append(LocalIRVariable(n.variable_declaration)) candidates.append(LocalIRVariable(n.variable_declaration))
if n.type == NodeType.ENTRYPOINT:
if var.name in init_vars:
candidates.append(init_vars[var.name])
assert candidates assert candidates
return max(candidates, key=lambda v: v.index) return max(candidates, key=lambda v: v.index)
for node in nodes: for node in nodes:
for ir in node.irs: for ir in node.irs_ssa:
if isinstance(ir, Phi): if isinstance(ir, Phi):
variables = [last_name(dst, ir.lvalue) for dst in ir.nodes] variables = [last_name(dst, ir.lvalue) for dst in ir.nodes]
ir.rvalues = variables ir.rvalues = variables
def rename_variables(node, local_variables_counter, global_variables_counter): def generate_ssa_irs(node, local_variables_instances, global_variables_instances):
if node.variable_declaration: if node.variable_declaration:
local_variables_counter[node.variable_declaration.name] = 0 new_var = LocalIRVariable(node.variable_declaration)
global_variables_counter[node.variable_declaration.name] = 0 local_variables_instances[node.variable_declaration.name] = new_var
global_variables_instances[node.variable_declaration.name] = new_var
for idx in range(len(node.irs)):
ir = node.irs[idx] for ir in node.irs:
for used in ir.used: # ir = node.irs[idx]
if isinstance(used, LocalIRVariable): # for used in ir.used:
used.index = local_variables_counter[used.name] # if isinstance(used, LocalIRVariable):
# used.index = local_variables_instances[used.name]
if isinstance(ir, OperationWithLValue): new_ir = copy_ir(ir, local_variables_instances)
if isinstance(ir.lvalue, LocalIRVariable): if new_ir:
counter = global_variables_counter[ir.lvalue.name] node.add_ssa_ir(new_ir)
counter = counter + 1
global_variables_counter[ir.lvalue.name] = counter if isinstance(new_ir, OperationWithLValue):
local_variables_counter[ir.lvalue.name] = counter if isinstance(new_ir.lvalue, LocalIRVariable):
ir.lvalue.index = counter new_var = LocalIRVariable(new_ir.lvalue)
new_var.index = global_variables_instances[new_ir.lvalue.name].index + 1
global_variables_instances[new_ir.lvalue.name] = new_var
local_variables_instances[new_ir.lvalue.name] = new_var
new_ir.lvalue = new_var
for succ in node.dominator_successors: for succ in node.dominator_successors:
rename_variables(succ, dict(local_variables_counter), global_variables_counter) generate_ssa_irs(succ, dict(local_variables_instances), global_variables_instances)
def add_phi_origins(node, variables_definition): def add_phi_origins(node, variables_definition):
@ -112,14 +154,208 @@ def add_phi_origins(node, variables_definition):
if not node.dominator_successors: if not node.dominator_successors:
return return
for succ in node.dominator_successors: for succ in node.dominator_successors:
add_phi_origins(succ, variables_definition) add_phi_origins(succ, variables_definition)
def copy_ir(ir, variables_instances):
'''
Args:
ir (Operation)
variables_instances(dict(str -> Variable))
'''
def get_variable(ir, f):
variable = f(ir)
if isinstance(variable, LocalVariable) and variable.name in variables_instances:
variable = variables_instances[variable.name]
return variable
def get_arguments(ir):
arguments = []
for arg in ir.arguments:
if isinstance(arg, LocalVariable) and arg.name in variables_instances:
arg = variables_instances[arg.name]
arguments.append(arg)
return arguments
def get_rec_values(ir, f):
# Use by InitArray and NewArray
# Potential recursive array(s)
ori_init_values = f(ir)
def traversal(values):
ret = []
for v in values:
if isinstance(v, list):
v = traversal(v)
else:
if isinstance(v, LocalVariable) and v.name in variables_instances:
v = variables_instances[v.name]
ret.append(v)
return ret
return traversal(ori_init_values)
if isinstance(ir, Assignment):
lvalue = get_variable(ir, lambda x: ir.lvalue)
rvalue = get_variable(ir, lambda x: ir.rvalue)
variable_return_type = ir.variable_return_type
return Assignment(lvalue, rvalue, variable_return_type)
elif isinstance(ir, Balance):
lvalue = get_variable(ir, lambda x: ir.lvalue)
value = get_variable(ir, lambda x: ir.value)
return Balance(value, lvalue)
elif isinstance(ir, Binary):
lvalue = get_variable(ir, lambda x: ir.lvalue)
variable_left = get_variable(ir, lambda x: ir.variable_left)
variable_right = get_variable(ir, lambda x: ir.variable_right)
operation_type = ir.type
return Binary(lvalue, variable_left, variable_right, operation_type)
elif isinstance(ir, Condition):
val = get_variable(ir, lambda x: ir.value)
return Condition(val)
elif isinstance(ir, Delete):
lvalue = get_variable(ir, lambda x: ir.lvalue)
variable = get_variable(ir, lambda x: ir.variable)
return Delete(lvalue, variable)
elif isinstance(ir, EventCall):
name = ir.name
return EventCall(name)
elif isinstance(ir, HighLevelCall): # include LibraryCall
destination = get_variable(ir, lambda x: ir.destination)
function_name = ir.function_name
nbr_arguments = ir.nbr_arguments
lvalue = get_variable(ir, lambda x: ir.lvalue)
type_call = ir.type_call
if isinstance(ir, LibraryCall):
new_ir = LibraryCall(destination, function_name, nbr_arguments, lvalue, type_call)
else:
new_ir = HighLevelCall(destination, function_name, nbr_arguments, lvalue, type_call)
new_ir.call_id = ir.call_id
new_ir.call_value = get_variable(ir, lambda x: ir.call_value)
new_ir.call_gas = get_variable(ir, lambda x: ir.call_gas)
new_ir.arguments = get_arguments(ir)
new_ir.function_instance = ir.function
return new_ir
elif isinstance(ir, Index):
lvalue = get_variable(ir, lambda x: ir.lvalue)
variable_left = get_variable(ir, lambda x: ir.variable_left)
variable_right = get_variable(ir, lambda x: ir.variable_right)
index_type = ir.index_type
return Index(lvalue, variable_left, variable_right, index_type)
elif isinstance(ir, InitArray):
lvalue = get_variable(ir, lambda x: ir.lvalue)
init_values = get_rec_values(ir, lambda x: ir.init_values)
return InitArray(init_values, lvalue)
elif isinstance(ir, InternalCall):
function = ir.function
nbr_arguments = ir.nbr_arguments
lvalue = get_variable(ir, lambda x: ir.lvalue)
type_call = ir.type_call
new_ir = InternalCall(function, nbr_arguments, lvalue, type_call)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, InternalDynamicCall):
lvalue = get_variable(ir, lambda x: ir.lvalue)
function = ir.function
function_type = ir.function_type
new_ir = InternalDynamicCall(lvalue, function, function_type)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, LowLevelCall):
destination = get_variable(ir, lambda x: x.destination)
function_name = ir.function_name
nbr_arguments = ir.nbr_arguments
lvalue = get_variable(ir, lambda x: ir.lvalue)
type_call = ir.type_call
new_ir = LowLevelCall(destination, function_name, nbr_arguments, lvalue, type_call)
new_ir.call_id = ir.call_id
new_ir.call_value = get_variable(ir, lambda x: ir.call_value)
new_ir.call_gas = get_variable(ir, lambda x: ir.call_gas)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, Member):
lvalue = get_variable(ir, lambda x: ir.lvalue)
variable_left = get_variable(ir, lambda x: ir.variable_left)
variable_right = get_variable(ir, lambda x: ir.variable_right)
return Member(variable_left, variable_right, lvalue)
elif isinstance(ir, NewArray):
depth = ir.depth
array_type = ir.array_type
lvalue = get_variable(ir, lambda x: ir.lvalue)
new_ir = NewArray(depth, array_type, lvalue)
new_ir.arguments = get_rec_values(ir, lambda x: ir.arguments)
return new_ir
elif isinstance(ir, NewElementaryType):
new_type = ir.type
lvalue = get_variable(ir, lambda x: ir.lvalue)
new_ir = NewElementaryType(new_type, lvalue)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, NewContract):
contract_name = ir.contract_name
lvalue = get_variable(ir, lambda x: ir.lvalue)
new_ir = NewContract(contract_name, lvalue)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, NewStructure):
structure = ir.structure
lvalue = get_variable(ir, lambda x: ir.lvalue)
new_ir = NewStructure(structure, lvalue)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, Push):
array = get_variable(ir, lambda x: ir.array)
lvalue = get_variable(ir, lambda x: ir.lvalue)
return Push(array, lvalue)
elif isinstance(ir, Return):
value = get_variable(ir, lambda x: ir.value)
return Return(value)
elif isinstance(ir, Send):
destination = get_variable(ir, lambda x: ir.destination)
value = get_variable(ir, lambda x: ir.call_value)
lvalue = get_variable(ir, lambda x: ir.lvalue)
return Send(destination, value, lvalue)
elif isinstance(ir, SolidityCall):
function = ir.function
nbr_arguments = ir.nbr_arguments
lvalue = get_variable(ir, lambda x: ir.lvalue)
type_call = ir.type_call
new_ir = SolidityCall(function, nbr_arguments, lvalue, type_call)
new_ir.arguments = get_arguments(ir)
return new_ir
elif isinstance(ir, Transfer):
destination = get_variable(ir, lambda x: ir.destination)
value = get_variable(ir, lambda x: ir.call_value)
return Transfer(destination, value)
elif isinstance(ir, TypeConversion):
lvalue = get_variable(ir, lambda x: ir.lvalue)
variable = get_variable(ir, lambda x: ir.variable)
variable_type = ir.type
return TypeConversion(lvalue, variable, variable_type)
elif isinstance(ir, Unary):
lvalue = get_variable(ir, lambda x: ir.lvalue)
rvalue = get_variable(ir, lambda x: ir.rvalue)
operation_type = ir.type
return Unary(lvalue, rvalue, operation_type)
elif isinstance(ir, Unpack):
lvalue = get_variable(ir, lambda x: ir.lvalue)
tuple_var = ir.tuple
idx = ir.index
return Unpack(lvalue, tuple_var, idx)
elif isinstance(ir, Length):
lvalue = get_variable(ir, lambda x: ir.lvalue)
value = get_variable(ir, lambda x: ir.value)
return Length(value, lvalue)
logger.error('Impossible ir copy on {} ({})'.format(ir, type(ir)))
exit(-1)
def transform_localir_vars_to_ssa(function): def transform_localir_vars_to_ssa(function):
""" """
Transform slithIR vars to SSA Transform slithIR vars to SSA
""" """
pass pass

@ -20,7 +20,7 @@ from slither.visitors.expression.has_conditional import HasConditional
from slither.utils.expression_manipulations import SplitTernaryExpression from slither.utils.expression_manipulations import SplitTernaryExpression
from slither.slithir.utils.ssa import transform_slithir_vars_to_ssa, add_phi_operations from slither.slithir.utils.ssa import transform_slithir_vars_to_ssa, add_ssa_ir
from slither.core.dominators.utils import compute_dominators, compute_dominance_frontier from slither.core.dominators.utils import compute_dominators, compute_dominance_frontier
logger = logging.getLogger("FunctionSolc") logger = logging.getLogger("FunctionSolc")
@ -839,8 +839,8 @@ class FunctionSolc(Function):
compute_dominance_frontier(self.nodes) compute_dominance_frontier(self.nodes)
for node in self.nodes: for node in self.nodes:
node.slithir_generation() node.slithir_generation()
add_phi_operations(self)
transform_slithir_vars_to_ssa(self) transform_slithir_vars_to_ssa(self)
add_ssa_ir(self)
self._analyze_read_write() self._analyze_read_write()
self._analyze_calls() self._analyze_calls()

@ -20,8 +20,8 @@ from slither.slithir.variables import (Constant, ReferenceVariable,
from slither.visitors.expression.expression import ExpressionVisitor from slither.visitors.expression.expression import ExpressionVisitor
from slither.core.variables.state_variable import StateVariable from slither.core.variables.state_variable import StateVariable
from slither.core.variables.local_variable import LocalVariable from slither.core.variables.local_variable import LocalVariable
from slither.slithir.variables.state_variable import StateIRVariable #from slither.slithir.variables.state_variable import StateIRVariable
from slither.slithir.variables.local_variable import LocalIRVariable #from slither.slithir.variables.local_variable import LocalIRVariable
logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") logger = logging.getLogger("VISTIOR:ExpressionToSlithIR")
@ -37,8 +37,9 @@ def set_val(expression, val):
expression.context[key] = val expression.context[key] = val
def convert_assignment(left, right, t, return_type): def convert_assignment(left, right, t, return_type):
if isinstance(left, LocalVariable): # if isinstance(left, LocalVariable):
left = LocalIRVariable(left) # left = LocalIRVariable(left)
# print(left)
if t == AssignmentOperationType.ASSIGN: if t == AssignmentOperationType.ASSIGN:
return Assignment(left, right, return_type) return Assignment(left, right, return_type)
elif t == AssignmentOperationType.ASSIGN_OR: elif t == AssignmentOperationType.ASSIGN_OR:
@ -108,7 +109,7 @@ class ExpressionToSlithIR(ExpressionVisitor):
self._result.append(operation) self._result.append(operation)
# Return left to handle # Return left to handle
# a = b = 1; # a = b = 1;
set_val(expression, left) set_val(expression, operation.lvalue)
def _post_binary_operation(self, expression): def _post_binary_operation(self, expression):
left = get(expression.expression_left) left = get(expression.expression_left)
@ -156,12 +157,12 @@ class ExpressionToSlithIR(ExpressionVisitor):
set_val(expression, expression.type) set_val(expression, expression.type)
def _post_identifier(self, expression): def _post_identifier(self, expression):
if isinstance(expression.value, StateVariable): # if isinstance(expression.value, StateVariable):
set_val(expression, StateIRVariable(expression.value)) # set_val(expression, StateIRVariable(expression.value))
elif isinstance(expression.value, LocalVariable): # elif isinstance(expression.value, LocalVariable):
set_val(expression, LocalIRVariable(expression.value)) # set_val(expression, LocalIRVariable(expression.value))
else: # else:
assert isinstance(expression.value, (SolidityVariable, SolidityFunction, Function, Contract, Enum, Structure)) assert isinstance(expression.value, (SolidityVariable, SolidityFunction, Function, Contract, Enum, Structure, StateVariable, LocalVariable))
set_val(expression, expression.value) set_val(expression, expression.value)
def _post_index_access(self, expression): def _post_index_access(self, expression):
@ -219,9 +220,9 @@ class ExpressionToSlithIR(ExpressionVisitor):
def _post_unary_operation(self, expression): def _post_unary_operation(self, expression):
value = get(expression.expression) value = get(expression.expression)
new_value = value new_value = value
# need new instance for ssa # # need new instance for ssa
if isinstance(new_value, LocalVariable): # if isinstance(new_value, LocalVariable):
new_value = LocalIRVariable(new_value) # new_value = LocalIRVariable(new_value)
if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]:
lvalue = TemporaryVariable(self._node) lvalue = TemporaryVariable(self._node)
operation = Unary(lvalue, value, expression.type) operation = Unary(lvalue, value, expression.type)

Loading…
Cancel
Save