From 8da3b86a306a62300f6aa981a95045380924f620 Mon Sep 17 00:00:00 2001 From: Josselin Date: Fri, 5 Jul 2019 13:29:56 +0200 Subject: [PATCH] Add a new type of function: constructor_variable to hold state variable initialization Add Function.FunctionType to determine if the function is a constructor/fallback/constructor_variables API change: function.is_fallback() becomes a property --- slither/core/declarations/function.py | 209 ++++++++++++++++-- slither/solc_parsing/declarations/contract.py | 47 +++- slither/solc_parsing/declarations/function.py | 142 ++---------- slither/solc_parsing/slitherSolc.py | 4 + 4 files changed, 252 insertions(+), 150 deletions(-) diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 6025f7829..631e262bf 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -4,6 +4,7 @@ import logging from collections import namedtuple from itertools import groupby +from enum import Enum from slither.core.children.child_contract import ChildContract from slither.core.children.child_inheritance import ChildInheritance @@ -13,8 +14,9 @@ from slither.core.declarations.solidity_variables import (SolidityFunction, from slither.core.expressions import (Identifier, IndexAccess, MemberAccess, UnaryOperation) from slither.core.source_mapping.source_mapping import SourceMapping -from slither.core.variables.state_variable import StateVariable +from slither.core.variables.state_variable import StateVariable +from slither.utils.utils import unroll logger = logging.getLogger("Function") @@ -22,6 +24,12 @@ ReacheableNode = namedtuple('ReacheableNode', ['node', 'ir']) ModifierStatements = namedtuple('Modifier', ['modifier', 'node']) +class FunctionType(Enum): + NORMAL = 0 + CONSTRUCTOR = 1 + FALLBACK = 2 + CONSTRUCTOR_VARIABLES = 3 # Fake function to hold variable declaration statements + class Function(ChildContract, ChildInheritance, SourceMapping): """ Function class @@ -34,7 +42,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping): self._pure = None self._payable = None self._visibility = None - self._is_constructor = None + self._is_implemented = None self._is_empty = None self._entry_point = None @@ -94,6 +102,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): self._reachable_from_nodes = set() self._reachable_from_functions = set() + # Constructor, fallback, State variable constructor + self._function_type = None + self._is_constructor = None ################################################################################### ################################################################################### @@ -106,11 +117,12 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ str: function name """ - if self._name == '': - if self.is_constructor: - return 'constructor' - else: - return 'fallback' + if self._function_type == FunctionType.CONSTRUCTOR: + return 'constructor' + elif self._function_type == FunctionType.FALLBACK: + return 'fallback' + elif self._function_type == FunctionType.CONSTRUCTOR_VARIABLES: + return 'slither_constructor_variables' return self._name @property @@ -131,12 +143,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping): name, parameters, _ = self.signature return self.contract_declarer.name + '.' + name + '(' + ','.join(parameters) + ')' - @property - def is_constructor(self): - """ - bool: True if the function is the constructor - """ - return self._is_constructor or self._name == self.contract_declarer.name @property def contains_assembly(self): @@ -154,6 +160,41 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ return self.contract_declarer == contract + # endregion + ################################################################################### + ################################################################################### + # region Type (FunctionType) + ################################################################################### + ################################################################################### + + def set_function_type(self, t): + assert isinstance(t, FunctionType) + self._function_type = t + + @property + def is_constructor(self): + """ + bool: True if the function is the constructor + """ + return self._function_type == FunctionType.CONSTRUCTOR + + @property + def is_constructor_variables(self): + """ + bool: True if the function is the constructor of the variables + Slither has a inbuilt function to hold the state variables initialization + """ + return self._function_type == FunctionType.CONSTRUCTOR_VARIABLES + + @property + def is_fallback(self): + """ + Determine if the function is the fallback function for the contract + Returns + (bool) + """ + return self._function_type == FunctionType.FALLBACK + # endregion ################################################################################### ################################################################################### @@ -182,6 +223,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ return self._visibility + def set_visibility(self, v): + self._visibility = v + @property def view(self): """ @@ -248,6 +292,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping): """ return self._entry_point + def add_node(self, node): + if not self._entry_point: + self._entry_point = node + self._nodes.append(node) + # endregion ################################################################################### ################################################################################### @@ -1005,13 +1054,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping): args_vars = self.all_solidity_variables_used_as_args() return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars - def is_fallback(self): - """ - Determine if the function is the fallback function for the contract - Returns - (bool) - """ - return self._name == "" and not self.is_constructor # endregion ################################################################################### @@ -1119,6 +1161,133 @@ class Function(ChildContract, ChildInheritance, SourceMapping): # endregion + ################################################################################### + ################################################################################### + # region SlithIr and SSA + ################################################################################### + ################################################################################### + + def get_last_ssa_state_variables_instances(self): + from slither.slithir.variables import ReferenceVariable + from slither.slithir.operations import OperationWithLValue + from slither.core.cfg.node import NodeType + + if not self.is_implemented: + return dict() + + # node, values + to_explore = [(self._entry_point, dict())] + # node -> values + explored = dict() + # name -> instances + ret = dict() + + while to_explore: + node, values = to_explore[0] + to_explore = to_explore[1::] + + if node.type != NodeType.ENTRYPOINT: + for ir_ssa in node.irs_ssa: + if isinstance(ir_ssa, OperationWithLValue): + lvalue = ir_ssa.lvalue + if isinstance(lvalue, ReferenceVariable): + lvalue = lvalue.points_to_origin + if isinstance(lvalue, StateVariable): + values[lvalue.canonical_name] = {lvalue} + + # Check for fixpoint + if node in explored: + if values == explored[node]: + continue + for k, instances in values.items(): + if not k in explored[node]: + explored[node][k] = set() + explored[node][k] |= instances + values = explored[node] + else: + explored[node] = values + + # Return condition + if not node.sons and node.type != NodeType.THROW: + for name, instances in values.items(): + if name not in ret: + ret[name] = set() + ret[name] |= instances + + for son in node.sons: + to_explore.append((son, dict(values))) + + return ret + + @staticmethod + def _unchange_phi(ir): + from slither.slithir.operations import (Phi, PhiCallback) + if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1: + return False + if not ir.rvalues: + return True + return ir.rvalues[0] == ir.lvalue + + def fix_phi(self, last_state_variables_instances, initial_state_variables_instances): + from slither.slithir.operations import (InternalCall, PhiCallback) + from slither.slithir.variables import (Constant, StateIRVariable) + for node in self.nodes: + for ir in node.irs_ssa: + if node == self.entry_point: + if isinstance(ir.lvalue, StateIRVariable): + additional = [initial_state_variables_instances[ir.lvalue.canonical_name]] + additional += last_state_variables_instances[ir.lvalue.canonical_name] + ir.rvalues = list(set(additional + ir.rvalues)) + # function parameter + else: + # find index of the parameter + idx = self.parameters.index(ir.lvalue.non_ssa_version) + # find non ssa version of that index + additional = [n.ir.arguments[idx] for n in self.reachable_from_nodes] + additional = unroll(additional) + additional = [a for a in additional if not isinstance(a, Constant)] + ir.rvalues = list(set(additional + ir.rvalues)) + if isinstance(ir, PhiCallback): + callee_ir = ir.callee_ir + if isinstance(callee_ir, InternalCall): + last_ssa = callee_ir.function.get_last_ssa_state_variables_instances() + if ir.lvalue.canonical_name in last_ssa: + ir.rvalues = list(last_ssa[ir.lvalue.canonical_name]) + else: + ir.rvalues = [ir.lvalue] + else: + additional = last_state_variables_instances[ir.lvalue.canonical_name] + ir.rvalues = list(set(additional + ir.rvalues)) + + node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)] + + def generate_slithir_and_analyze(self): + for node in self.nodes: + node.slithir_generation() + + for modifier_statement in self.modifiers_statements: + modifier_statement.node.slithir_generation() + + for modifier_statement in self.explicit_base_constructor_calls_statements: + modifier_statement.node.slithir_generation() + + self._analyze_read_write() + self._analyze_calls() + + def generate_slithir_ssa(self, all_ssa_state_variables_instances): + from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa + from slither.core.dominators.utils import (compute_dominance_frontier, + compute_dominators) + compute_dominators(self.nodes) + compute_dominance_frontier(self.nodes) + transform_slithir_vars_to_ssa(self) + add_ssa_ir(self, all_ssa_state_variables_instances) + + def update_read_write_using_ssa(self): + for node in self.nodes: + node.update_read_write_using_ssa() + self._analyze_read_write() + ################################################################################### ################################################################################### # region Built in definitions diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index d3a01c51f..f5fba278c 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -1,7 +1,10 @@ import logging from slither.core.declarations.contract import Contract +from slither.core.declarations.function import Function, FunctionType from slither.core.declarations.enum import Enum +from slither.core.cfg.node import Node, NodeType +from slither.core.expressions import AssignmentOperation, Identifier, AssignmentOperationType from slither.slithir.variables import StateIRVariable from slither.solc_parsing.declarations.event import EventSolc from slither.solc_parsing.declarations.function import FunctionSolc @@ -319,7 +322,6 @@ class ContractSolc04(Contract): :return: """ all_elements = {} - accessible_elements = {} for father in self.inheritance: for element in getter(father): @@ -368,6 +370,49 @@ class ContractSolc04(Contract): pass return + + def _create_node(self, func, counter, variable): + # Function uses to create node for state variable declaration statements + node = Node(NodeType.STANDALONE, counter) + node.set_offset(variable.source_mapping, self.slither) + node.set_function(func) + func.add_node(node) + print(variable.expression) + expression = AssignmentOperation(Identifier(variable), + variable.expression, + AssignmentOperationType.ASSIGN, + variable.type) + + node.add_expression(expression) + return node + + def add_constructor_variables(self): + if self.state_variables: + found_candidate = False + for (idx, variable_candidate) in enumerate(self.state_variables): + if variable_candidate.expression and not variable_candidate.is_constant: + found_candidate = True + break + if found_candidate: + constructor_variable = Function() + constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES) + constructor_variable.set_contract(self) + constructor_variable.set_contract_declarer(self) + constructor_variable.set_visibility('internal') + self._functions[constructor_variable.canonical_name] = constructor_variable + + prev_node = self._create_node(constructor_variable, 0, variable_candidate) + counter = 1 + for v in self.state_variables[idx+1:]: + if v.expression and not v.is_constant: + next_node = self._create_node(constructor_variable, counter, v) + prev_node.add_son(next_node) + next_node.add_father(prev_node) + counter += 1 + + + + def analyze_state_variables(self): for var in self.variables: var.analyze(self) diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index 6da6233c5..5c95cc231 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -4,16 +4,10 @@ import logging from slither.core.cfg.node import NodeType, link_nodes from slither.core.declarations.contract import Contract -from slither.core.declarations.function import Function, ModifierStatements -from slither.core.dominators.utils import (compute_dominance_frontier, - compute_dominators) +from slither.core.declarations.function import Function, ModifierStatements, FunctionType + from slither.core.expressions import AssignmentOperation -from slither.core.variables.state_variable import StateVariable -from slither.slithir.operations import (InternalCall, OperationWithLValue, Phi, - PhiCallback) -from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa -from slither.slithir.variables import (Constant, ReferenceVariable, - StateIRVariable) + from slither.solc_parsing.cfg.node import NodeSolc from slither.solc_parsing.expressions.expression_parsing import \ parse_expression @@ -23,7 +17,6 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import \ from slither.solc_parsing.variables.variable_declaration import \ MultipleVariablesDeclaration from slither.utils.expression_manipulations import SplitTernaryExpression -from slither.utils.utils import unroll from slither.visitors.expression.export_values import ExportValues from slither.visitors.expression.has_conditional import HasConditional from slither.solc_parsing.exceptions import ParsingError @@ -139,14 +132,20 @@ class FunctionSolc(Function): if 'constant' in attributes: self._view = attributes['constant'] - self._is_constructor = False + if self._name == '': + self._function_type = FunctionType.FALLBACK + else: + self._function_type = FunctionType.NORMAL + + if self._name == self.contract_declarer.name: + self._function_type = FunctionType.CONSTRUCTOR - if 'isConstructor' in attributes: - self._is_constructor = attributes['isConstructor'] + if 'isConstructor' in attributes and attributes['isConstructor']: + self._function_type = FunctionType.CONSTRUCTOR if 'kind' in attributes: if attributes['kind'] == 'constructor': - self._is_constructor = True + self._function_type = FunctionType.CONSTRUCTOR if 'visibility' in attributes: self._visibility = attributes['visibility'] @@ -1038,120 +1037,5 @@ class FunctionSolc(Function): # endregion - ################################################################################### - ################################################################################### - # region SlithIr and SSA - ################################################################################### - ################################################################################### - - def get_last_ssa_state_variables_instances(self): - if not self.is_implemented: - return dict() - - # node, values - to_explore = [(self._entry_point, dict())] - # node -> values - explored = dict() - # name -> instances - ret = dict() - - while to_explore: - node, values = to_explore[0] - to_explore = to_explore[1::] - - if node.type != NodeType.ENTRYPOINT: - for ir_ssa in node.irs_ssa: - if isinstance(ir_ssa, OperationWithLValue): - lvalue = ir_ssa.lvalue - if isinstance(lvalue, ReferenceVariable): - lvalue = lvalue.points_to_origin - if isinstance(lvalue, StateVariable): - values[lvalue.canonical_name] = {lvalue} - - # Check for fixpoint - if node in explored: - if values == explored[node]: - continue - for k, instances in values.items(): - if not k in explored[node]: - explored[node][k] = set() - explored[node][k] |= instances - values = explored[node] - else: - explored[node] = values - - # Return condition - if not node.sons and node.type != NodeType.THROW: - for name, instances in values.items(): - if name not in ret: - ret[name] = set() - ret[name] |= instances - - for son in node.sons: - to_explore.append((son, dict(values))) - return ret - - @staticmethod - def _unchange_phi(ir): - if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1: - return False - if not ir.rvalues: - return True - return ir.rvalues[0] == ir.lvalue - - def fix_phi(self, last_state_variables_instances, initial_state_variables_instances): - for node in self.nodes: - for ir in node.irs_ssa: - if node == self.entry_point: - if isinstance(ir.lvalue, StateIRVariable): - additional = [initial_state_variables_instances[ir.lvalue.canonical_name]] - additional += last_state_variables_instances[ir.lvalue.canonical_name] - ir.rvalues = list(set(additional + ir.rvalues)) - # function parameter - else: - # find index of the parameter - idx = self.parameters.index(ir.lvalue.non_ssa_version) - # find non ssa version of that index - additional = [n.ir.arguments[idx] for n in self.reachable_from_nodes] - additional = unroll(additional) - additional = [a for a in additional if not isinstance(a, Constant)] - ir.rvalues = list(set(additional + ir.rvalues)) - if isinstance(ir, PhiCallback): - callee_ir = ir.callee_ir - if isinstance(callee_ir, InternalCall): - last_ssa = callee_ir.function.get_last_ssa_state_variables_instances() - if ir.lvalue.canonical_name in last_ssa: - ir.rvalues = list(last_ssa[ir.lvalue.canonical_name]) - else: - ir.rvalues = [ir.lvalue] - else: - additional = last_state_variables_instances[ir.lvalue.canonical_name] - ir.rvalues = list(set(additional + ir.rvalues)) - - node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)] - - def generate_slithir_and_analyze(self): - for node in self.nodes: - node.slithir_generation() - - for modifier_statement in self.modifiers_statements: - modifier_statement.node.slithir_generation() - - for modifier_statement in self.explicit_base_constructor_calls_statements: - modifier_statement.node.slithir_generation() - - self._analyze_read_write() - self._analyze_calls() - - def generate_slithir_ssa(self, all_ssa_state_variables_instances): - compute_dominators(self.nodes) - compute_dominance_frontier(self.nodes) - transform_slithir_vars_to_ssa(self) - add_ssa_ir(self, all_ssa_state_variables_instances) - - def update_read_write_using_ssa(self): - for node in self.nodes: - node.update_read_write_using_ssa() - self._analyze_read_write() diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index 6d11e2cc8..e6ac1c932 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -373,10 +373,14 @@ class SlitherSolc(Slither): contract.analyze_content_modifiers() contract.analyze_content_functions() + + contract.set_is_analyzed(True) def _convert_to_slithir(self): + for contract in self.contracts: + contract.add_constructor_variables() contract.convert_expression_to_slithir() self._propagate_function_calls() for contract in self.contracts: