diff --git a/slither/__main__.py b/slither/__main__.py index 6c1e264dd..876dd6443 100644 --- a/slither/__main__.py +++ b/slither/__main__.py @@ -157,6 +157,7 @@ def get_detectors_and_printers(): from slither.printers.call.call_graph import PrinterCallGraph from slither.printers.functions.authorization import PrinterWrittenVariablesAndAuthorization from slither.printers.summary.slithir import PrinterSlithIR + from slither.printers.summary.slithir_ssa import PrinterSlithIRSSA from slither.printers.summary.human_summary import PrinterHumanSummary printers = [FunctionSummary, @@ -166,6 +167,7 @@ def get_detectors_and_printers(): PrinterCallGraph, PrinterWrittenVariablesAndAuthorization, PrinterSlithIR, + PrinterSlithIRSSA, PrinterHumanSummary] # Handle plugins! @@ -213,6 +215,7 @@ def main_impl(all_detector_classes, all_printer_classes): ('FunctionSolc', default_log), ('ExpressionParsing', default_log), ('TypeParsing', default_log), + ('SSA_Conversion', default_log), ('Printers', default_log)]: l = logging.getLogger(l_name) l.setLevel(l_level) diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index b593831e1..b9077f976 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -131,6 +131,7 @@ class Node(SourceMapping, ChildFunction): self._low_level_calls = [] self._external_calls_as_expressions = [] self._irs = [] + self._irs_ssa = [] self._state_vars_written = [] self._state_vars_read = [] @@ -424,8 +425,24 @@ class Node(SourceMapping, ChildFunction): """ return self._irs - def add_pre_ir(self, ir): - self._irs.insert(0, ir) + @property + def irs_ssa(self): + """ Returns the slithIR representation with SSA + + return + list(slithIR.Operation) + """ + return self._irs_ssa + + @irs_ssa.setter + def irs_ssa(self, irs): + self._irs_ssa = irs + + def add_ssa_ir(self, ir): + ''' + Use to place phi operation + ''' + self._irs_ssa.append(ir) def slithir_generation(self): if self.expression: diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index fd54a1113..7b497fcd1 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -13,7 +13,7 @@ class HighLevelCall(Call, OperationWithLValue): def __init__(self, destination, function_name, nbr_arguments, result, type_call): assert isinstance(function_name, Constant) - assert is_valid_lvalue(result) + assert is_valid_lvalue(result) or result is None self._check_destination(destination) super(HighLevelCall, self).__init__() self._destination = destination diff --git a/slither/slithir/operations/index.py b/slither/slithir/operations/index.py index ac3140f46..354ab56a4 100644 --- a/slither/slithir/operations/index.py +++ b/slither/slithir/operations/index.py @@ -33,5 +33,9 @@ class Index(OperationWithLValue): def variable_right(self): return self._variables[1] + @property + def index_type(self): + return self._type + def __str__(self): return "{}({}) -> {}[{}]".format(self.lvalue, self.lvalue.type, self.variable_left, self.variable_right) diff --git a/slither/slithir/operations/unary.py b/slither/slithir/operations/unary.py index 44bcdf0bc..9a09f09ba 100644 --- a/slither/slithir/operations/unary.py +++ b/slither/slithir/operations/unary.py @@ -48,6 +48,10 @@ class Unary(OperationWithLValue): def rvalue(self): return self._variable + @property + def type(self): + return self._type + @property def type_str(self): return UnaryType.str(self._type) diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index b71dbd9e3..9b325e036 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -1,7 +1,25 @@ -from slither.slithir.variables import (Constant, ReferenceVariable, - TemporaryVariable, TupleVariable) -from slither.slithir.operations import OperationWithLValue, Phi -from slither.slithir.variables import LocalIRVariable +import logging + +from slither.core.cfg.node import NodeType +from slither.core.variables.local_variable import LocalVariable +from slither.slithir.operations import (Assignment, Balance, Binary, + BinaryType, Condition, Delete, + EventCall, HighLevelCall, Index, + InitArray, InternalCall, + InternalDynamicCall, Length, + LibraryCall, LowLevelCall, Member, + NewArray, NewContract, + NewElementaryType, NewStructure, + OperationWithLValue, Phi, Push, Return, + Send, SolidityCall, Transfer, + TypeConversion, Unary, Unpack) +from slither.slithir.variables import (Constant, LocalIRVariable, + ReferenceVariable, TemporaryVariable, + TupleVariable) + +logger = logging.getLogger('SSA_Conversion') + + def transform_slithir_vars_to_ssa(function): """ @@ -9,7 +27,7 @@ def transform_slithir_vars_to_ssa(function): """ variables = [] for node in function.nodes: - for ir in node.irs: + for ir in node.irs_ssa: if isinstance(ir, OperationWithLValue) and not ir.lvalue in variables: variables += [ir.lvalue] @@ -23,12 +41,18 @@ def transform_slithir_vars_to_ssa(function): for idx in range(len(tuple_variables)): tuple_variables[idx].index = idx -def add_phi_operations(function): +def add_ssa_ir(function): + ''' + Add SSA version of the IR + ''' + + if not function.is_implemented: + return + init_definition = dict() for v in function.parameters+function.returns: if v.name: init_definition[v.name] = function.entry_point - add_phi_origins(function.entry_point, init_definition) for node in function.nodes: @@ -39,57 +63,75 @@ def add_phi_operations(function): # an instance of the variable # by looking at the variables written # of any of the nodes - n = list(nodes)[0] - variable = next((v for v in n.variables_written if v.name == variable_name), None) - if variable is None: - variable = n.variable_declaration - node.add_pre_ir(Phi(LocalIRVariable(variable), nodes)) + for n in nodes: + variable = next((v for v in n.variables_written if v.name == variable_name), None) + if variable is None: + variable = n.variable_declaration + if variable: + break + assert variable + node.add_ssa_ir(Phi(LocalIRVariable(variable), nodes)) - init_local_variables_counter = dict() + init_local_variables_instances = dict() for v in function.parameters+function.returns: if v.name: - init_local_variables_counter[v.name] = 0 - init_global_variables_counter = dict(init_local_variables_counter) - rename_variables(function.entry_point, init_local_variables_counter, init_global_variables_counter) + init_local_variables_instances[v.name] = LocalIRVariable(v) + init_global_variables_instances = dict(init_local_variables_instances) + generate_ssa_irs(function.entry_point, + dict(init_local_variables_instances), + init_global_variables_instances) + + fix_phi_operations(function.nodes, init_local_variables_instances) - fix_phi_operations(function.nodes) -def fix_phi_operations(nodes): +def fix_phi_operations(nodes, init_vars): def last_name(n, var): - candidates = [v for v in n.variables_written if v.name == var.name] + candidates = [] + # Todo optimize by creating a variables_ssa_written attribute + for ir_ssa in n.irs_ssa: + if isinstance(ir_ssa, OperationWithLValue): + if ir_ssa.lvalue and ir_ssa.lvalue.name == var.name: + candidates.append(ir_ssa.lvalue) if n.variable_declaration and n.variable_declaration.name == var.name: candidates.append(LocalIRVariable(n.variable_declaration)) + if n.type == NodeType.ENTRYPOINT: + if var.name in init_vars: + candidates.append(init_vars[var.name]) assert candidates return max(candidates, key=lambda v: v.index) for node in nodes: - for ir in node.irs: + for ir in node.irs_ssa: if isinstance(ir, Phi): variables = [last_name(dst, ir.lvalue) for dst in ir.nodes] ir.rvalues = variables -def rename_variables(node, local_variables_counter, global_variables_counter): +def generate_ssa_irs(node, local_variables_instances, global_variables_instances): if node.variable_declaration: - local_variables_counter[node.variable_declaration.name] = 0 - global_variables_counter[node.variable_declaration.name] = 0 - - for idx in range(len(node.irs)): - ir = node.irs[idx] - for used in ir.used: - if isinstance(used, LocalIRVariable): - used.index = local_variables_counter[used.name] - - if isinstance(ir, OperationWithLValue): - if isinstance(ir.lvalue, LocalIRVariable): - counter = global_variables_counter[ir.lvalue.name] - counter = counter + 1 - global_variables_counter[ir.lvalue.name] = counter - local_variables_counter[ir.lvalue.name] = counter - ir.lvalue.index = counter + new_var = LocalIRVariable(node.variable_declaration) + local_variables_instances[node.variable_declaration.name] = new_var + global_variables_instances[node.variable_declaration.name] = new_var + + for ir in node.irs: +# ir = node.irs[idx] +# for used in ir.used: +# if isinstance(used, LocalIRVariable): +# used.index = local_variables_instances[used.name] + new_ir = copy_ir(ir, local_variables_instances) + if new_ir: + node.add_ssa_ir(new_ir) + + if isinstance(new_ir, OperationWithLValue): + if isinstance(new_ir.lvalue, LocalIRVariable): + new_var = LocalIRVariable(new_ir.lvalue) + new_var.index = global_variables_instances[new_ir.lvalue.name].index + 1 + global_variables_instances[new_ir.lvalue.name] = new_var + local_variables_instances[new_ir.lvalue.name] = new_var + new_ir.lvalue = new_var for succ in node.dominator_successors: - rename_variables(succ, dict(local_variables_counter), global_variables_counter) + generate_ssa_irs(succ, dict(local_variables_instances), global_variables_instances) def add_phi_origins(node, variables_definition): @@ -112,14 +154,208 @@ def add_phi_origins(node, variables_definition): if not node.dominator_successors: return - for succ in node.dominator_successors: add_phi_origins(succ, variables_definition) +def copy_ir(ir, variables_instances): + ''' + Args: + ir (Operation) + variables_instances(dict(str -> Variable)) + ''' + + def get_variable(ir, f): + variable = f(ir) + if isinstance(variable, LocalVariable) and variable.name in variables_instances: + variable = variables_instances[variable.name] + return variable + + def get_arguments(ir): + arguments = [] + for arg in ir.arguments: + if isinstance(arg, LocalVariable) and arg.name in variables_instances: + arg = variables_instances[arg.name] + arguments.append(arg) + return arguments + + def get_rec_values(ir, f): + # Use by InitArray and NewArray + # Potential recursive array(s) + ori_init_values = f(ir) + + def traversal(values): + ret = [] + for v in values: + if isinstance(v, list): + v = traversal(v) + else: + if isinstance(v, LocalVariable) and v.name in variables_instances: + v = variables_instances[v.name] + ret.append(v) + return ret + return traversal(ori_init_values) + + + if isinstance(ir, Assignment): + lvalue = get_variable(ir, lambda x: ir.lvalue) + rvalue = get_variable(ir, lambda x: ir.rvalue) + variable_return_type = ir.variable_return_type + return Assignment(lvalue, rvalue, variable_return_type) + elif isinstance(ir, Balance): + lvalue = get_variable(ir, lambda x: ir.lvalue) + value = get_variable(ir, lambda x: ir.value) + return Balance(value, lvalue) + elif isinstance(ir, Binary): + lvalue = get_variable(ir, lambda x: ir.lvalue) + variable_left = get_variable(ir, lambda x: ir.variable_left) + variable_right = get_variable(ir, lambda x: ir.variable_right) + operation_type = ir.type + return Binary(lvalue, variable_left, variable_right, operation_type) + elif isinstance(ir, Condition): + val = get_variable(ir, lambda x: ir.value) + return Condition(val) + elif isinstance(ir, Delete): + lvalue = get_variable(ir, lambda x: ir.lvalue) + variable = get_variable(ir, lambda x: ir.variable) + return Delete(lvalue, variable) + elif isinstance(ir, EventCall): + name = ir.name + return EventCall(name) + elif isinstance(ir, HighLevelCall): # include LibraryCall + destination = get_variable(ir, lambda x: ir.destination) + function_name = ir.function_name + nbr_arguments = ir.nbr_arguments + lvalue = get_variable(ir, lambda x: ir.lvalue) + type_call = ir.type_call + if isinstance(ir, LibraryCall): + new_ir = LibraryCall(destination, function_name, nbr_arguments, lvalue, type_call) + else: + new_ir = HighLevelCall(destination, function_name, nbr_arguments, lvalue, type_call) + new_ir.call_id = ir.call_id + new_ir.call_value = get_variable(ir, lambda x: ir.call_value) + new_ir.call_gas = get_variable(ir, lambda x: ir.call_gas) + new_ir.arguments = get_arguments(ir) + new_ir.function_instance = ir.function + return new_ir + elif isinstance(ir, Index): + lvalue = get_variable(ir, lambda x: ir.lvalue) + variable_left = get_variable(ir, lambda x: ir.variable_left) + variable_right = get_variable(ir, lambda x: ir.variable_right) + index_type = ir.index_type + return Index(lvalue, variable_left, variable_right, index_type) + elif isinstance(ir, InitArray): + lvalue = get_variable(ir, lambda x: ir.lvalue) + init_values = get_rec_values(ir, lambda x: ir.init_values) + return InitArray(init_values, lvalue) + elif isinstance(ir, InternalCall): + function = ir.function + nbr_arguments = ir.nbr_arguments + lvalue = get_variable(ir, lambda x: ir.lvalue) + type_call = ir.type_call + new_ir = InternalCall(function, nbr_arguments, lvalue, type_call) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, InternalDynamicCall): + lvalue = get_variable(ir, lambda x: ir.lvalue) + function = ir.function + function_type = ir.function_type + new_ir = InternalDynamicCall(lvalue, function, function_type) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, LowLevelCall): + destination = get_variable(ir, lambda x: x.destination) + function_name = ir.function_name + nbr_arguments = ir.nbr_arguments + lvalue = get_variable(ir, lambda x: ir.lvalue) + type_call = ir.type_call + new_ir = LowLevelCall(destination, function_name, nbr_arguments, lvalue, type_call) + new_ir.call_id = ir.call_id + new_ir.call_value = get_variable(ir, lambda x: ir.call_value) + new_ir.call_gas = get_variable(ir, lambda x: ir.call_gas) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, Member): + lvalue = get_variable(ir, lambda x: ir.lvalue) + variable_left = get_variable(ir, lambda x: ir.variable_left) + variable_right = get_variable(ir, lambda x: ir.variable_right) + return Member(variable_left, variable_right, lvalue) + elif isinstance(ir, NewArray): + depth = ir.depth + array_type = ir.array_type + lvalue = get_variable(ir, lambda x: ir.lvalue) + new_ir = NewArray(depth, array_type, lvalue) + new_ir.arguments = get_rec_values(ir, lambda x: ir.arguments) + return new_ir + elif isinstance(ir, NewElementaryType): + new_type = ir.type + lvalue = get_variable(ir, lambda x: ir.lvalue) + new_ir = NewElementaryType(new_type, lvalue) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, NewContract): + contract_name = ir.contract_name + lvalue = get_variable(ir, lambda x: ir.lvalue) + new_ir = NewContract(contract_name, lvalue) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, NewStructure): + structure = ir.structure + lvalue = get_variable(ir, lambda x: ir.lvalue) + new_ir = NewStructure(structure, lvalue) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, Push): + array = get_variable(ir, lambda x: ir.array) + lvalue = get_variable(ir, lambda x: ir.lvalue) + return Push(array, lvalue) + elif isinstance(ir, Return): + value = get_variable(ir, lambda x: ir.value) + return Return(value) + elif isinstance(ir, Send): + destination = get_variable(ir, lambda x: ir.destination) + value = get_variable(ir, lambda x: ir.call_value) + lvalue = get_variable(ir, lambda x: ir.lvalue) + return Send(destination, value, lvalue) + elif isinstance(ir, SolidityCall): + function = ir.function + nbr_arguments = ir.nbr_arguments + lvalue = get_variable(ir, lambda x: ir.lvalue) + type_call = ir.type_call + new_ir = SolidityCall(function, nbr_arguments, lvalue, type_call) + new_ir.arguments = get_arguments(ir) + return new_ir + elif isinstance(ir, Transfer): + destination = get_variable(ir, lambda x: ir.destination) + value = get_variable(ir, lambda x: ir.call_value) + return Transfer(destination, value) + elif isinstance(ir, TypeConversion): + lvalue = get_variable(ir, lambda x: ir.lvalue) + variable = get_variable(ir, lambda x: ir.variable) + variable_type = ir.type + return TypeConversion(lvalue, variable, variable_type) + elif isinstance(ir, Unary): + lvalue = get_variable(ir, lambda x: ir.lvalue) + rvalue = get_variable(ir, lambda x: ir.rvalue) + operation_type = ir.type + return Unary(lvalue, rvalue, operation_type) + elif isinstance(ir, Unpack): + lvalue = get_variable(ir, lambda x: ir.lvalue) + tuple_var = ir.tuple + idx = ir.index + return Unpack(lvalue, tuple_var, idx) + elif isinstance(ir, Length): + lvalue = get_variable(ir, lambda x: ir.lvalue) + value = get_variable(ir, lambda x: ir.value) + return Length(value, lvalue) + + + logger.error('Impossible ir copy on {} ({})'.format(ir, type(ir))) + exit(-1) def transform_localir_vars_to_ssa(function): """ Transform slithIR vars to SSA """ pass + diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 452ba6823..042fc4340 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -20,7 +20,7 @@ from slither.visitors.expression.has_conditional import HasConditional from slither.utils.expression_manipulations import SplitTernaryExpression -from slither.slithir.utils.ssa import transform_slithir_vars_to_ssa, add_phi_operations +from slither.slithir.utils.ssa import transform_slithir_vars_to_ssa, add_ssa_ir from slither.core.dominators.utils import compute_dominators, compute_dominance_frontier logger = logging.getLogger("FunctionSolc") @@ -839,8 +839,8 @@ class FunctionSolc(Function): compute_dominance_frontier(self.nodes) for node in self.nodes: node.slithir_generation() - add_phi_operations(self) transform_slithir_vars_to_ssa(self) + add_ssa_ir(self) self._analyze_read_write() self._analyze_calls() diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index 5682511df..b1673bd09 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -20,8 +20,8 @@ from slither.slithir.variables import (Constant, ReferenceVariable, from slither.visitors.expression.expression import ExpressionVisitor from slither.core.variables.state_variable import StateVariable from slither.core.variables.local_variable import LocalVariable -from slither.slithir.variables.state_variable import StateIRVariable -from slither.slithir.variables.local_variable import LocalIRVariable +#from slither.slithir.variables.state_variable import StateIRVariable +#from slither.slithir.variables.local_variable import LocalIRVariable logger = logging.getLogger("VISTIOR:ExpressionToSlithIR") @@ -37,8 +37,9 @@ def set_val(expression, val): expression.context[key] = val def convert_assignment(left, right, t, return_type): - if isinstance(left, LocalVariable): - left = LocalIRVariable(left) +# if isinstance(left, LocalVariable): +# left = LocalIRVariable(left) +# print(left) if t == AssignmentOperationType.ASSIGN: return Assignment(left, right, return_type) elif t == AssignmentOperationType.ASSIGN_OR: @@ -108,7 +109,7 @@ class ExpressionToSlithIR(ExpressionVisitor): self._result.append(operation) # Return left to handle # a = b = 1; - set_val(expression, left) + set_val(expression, operation.lvalue) def _post_binary_operation(self, expression): left = get(expression.expression_left) @@ -156,12 +157,12 @@ class ExpressionToSlithIR(ExpressionVisitor): set_val(expression, expression.type) def _post_identifier(self, expression): - if isinstance(expression.value, StateVariable): - set_val(expression, StateIRVariable(expression.value)) - elif isinstance(expression.value, LocalVariable): - set_val(expression, LocalIRVariable(expression.value)) - else: - assert isinstance(expression.value, (SolidityVariable, SolidityFunction, Function, Contract, Enum, Structure)) +# if isinstance(expression.value, StateVariable): +# set_val(expression, StateIRVariable(expression.value)) +# elif isinstance(expression.value, LocalVariable): +# set_val(expression, LocalIRVariable(expression.value)) +# else: + assert isinstance(expression.value, (SolidityVariable, SolidityFunction, Function, Contract, Enum, Structure, StateVariable, LocalVariable)) set_val(expression, expression.value) def _post_index_access(self, expression): @@ -219,9 +220,9 @@ class ExpressionToSlithIR(ExpressionVisitor): def _post_unary_operation(self, expression): value = get(expression.expression) new_value = value - # need new instance for ssa - if isinstance(new_value, LocalVariable): - new_value = LocalIRVariable(new_value) + # # need new instance for ssa + # if isinstance(new_value, LocalVariable): + # new_value = LocalIRVariable(new_value) if expression.type in [UnaryOperationType.BANG, UnaryOperationType.TILD]: lvalue = TemporaryVariable(self._node) operation = Unary(lvalue, value, expression.type)