Add a new type of function: constructor_variable to hold state variable initialization

Add Function.FunctionType to determine if the function is a constructor/fallback/constructor_variables
API change: function.is_fallback() becomes a property
pull/293/head
Josselin 5 years ago
parent 22f8ed731d
commit 8da3b86a30
  1. 209
      slither/core/declarations/function.py
  2. 47
      slither/solc_parsing/declarations/contract.py
  3. 142
      slither/solc_parsing/declarations/function.py
  4. 4
      slither/solc_parsing/slitherSolc.py

@ -4,6 +4,7 @@
import logging
from collections import namedtuple
from itertools import groupby
from enum import Enum
from slither.core.children.child_contract import ChildContract
from slither.core.children.child_inheritance import ChildInheritance
@ -13,8 +14,9 @@ from slither.core.declarations.solidity_variables import (SolidityFunction,
from slither.core.expressions import (Identifier, IndexAccess, MemberAccess,
UnaryOperation)
from slither.core.source_mapping.source_mapping import SourceMapping
from slither.core.variables.state_variable import StateVariable
from slither.core.variables.state_variable import StateVariable
from slither.utils.utils import unroll
logger = logging.getLogger("Function")
@ -22,6 +24,12 @@ ReacheableNode = namedtuple('ReacheableNode', ['node', 'ir'])
ModifierStatements = namedtuple('Modifier', ['modifier', 'node'])
class FunctionType(Enum):
NORMAL = 0
CONSTRUCTOR = 1
FALLBACK = 2
CONSTRUCTOR_VARIABLES = 3 # Fake function to hold variable declaration statements
class Function(ChildContract, ChildInheritance, SourceMapping):
"""
Function class
@ -34,7 +42,7 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
self._pure = None
self._payable = None
self._visibility = None
self._is_constructor = None
self._is_implemented = None
self._is_empty = None
self._entry_point = None
@ -94,6 +102,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
self._reachable_from_nodes = set()
self._reachable_from_functions = set()
# Constructor, fallback, State variable constructor
self._function_type = None
self._is_constructor = None
###################################################################################
###################################################################################
@ -106,11 +117,12 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
"""
str: function name
"""
if self._name == '':
if self.is_constructor:
return 'constructor'
else:
return 'fallback'
if self._function_type == FunctionType.CONSTRUCTOR:
return 'constructor'
elif self._function_type == FunctionType.FALLBACK:
return 'fallback'
elif self._function_type == FunctionType.CONSTRUCTOR_VARIABLES:
return 'slither_constructor_variables'
return self._name
@property
@ -131,12 +143,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
name, parameters, _ = self.signature
return self.contract_declarer.name + '.' + name + '(' + ','.join(parameters) + ')'
@property
def is_constructor(self):
"""
bool: True if the function is the constructor
"""
return self._is_constructor or self._name == self.contract_declarer.name
@property
def contains_assembly(self):
@ -154,6 +160,41 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
"""
return self.contract_declarer == contract
# endregion
###################################################################################
###################################################################################
# region Type (FunctionType)
###################################################################################
###################################################################################
def set_function_type(self, t):
assert isinstance(t, FunctionType)
self._function_type = t
@property
def is_constructor(self):
"""
bool: True if the function is the constructor
"""
return self._function_type == FunctionType.CONSTRUCTOR
@property
def is_constructor_variables(self):
"""
bool: True if the function is the constructor of the variables
Slither has a inbuilt function to hold the state variables initialization
"""
return self._function_type == FunctionType.CONSTRUCTOR_VARIABLES
@property
def is_fallback(self):
"""
Determine if the function is the fallback function for the contract
Returns
(bool)
"""
return self._function_type == FunctionType.FALLBACK
# endregion
###################################################################################
###################################################################################
@ -182,6 +223,9 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
"""
return self._visibility
def set_visibility(self, v):
self._visibility = v
@property
def view(self):
"""
@ -248,6 +292,11 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
"""
return self._entry_point
def add_node(self, node):
if not self._entry_point:
self._entry_point = node
self._nodes.append(node)
# endregion
###################################################################################
###################################################################################
@ -1005,13 +1054,6 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
args_vars = self.all_solidity_variables_used_as_args()
return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars
def is_fallback(self):
"""
Determine if the function is the fallback function for the contract
Returns
(bool)
"""
return self._name == "" and not self.is_constructor
# endregion
###################################################################################
@ -1119,6 +1161,133 @@ class Function(ChildContract, ChildInheritance, SourceMapping):
# endregion
###################################################################################
###################################################################################
# region SlithIr and SSA
###################################################################################
###################################################################################
def get_last_ssa_state_variables_instances(self):
from slither.slithir.variables import ReferenceVariable
from slither.slithir.operations import OperationWithLValue
from slither.core.cfg.node import NodeType
if not self.is_implemented:
return dict()
# node, values
to_explore = [(self._entry_point, dict())]
# node -> values
explored = dict()
# name -> instances
ret = dict()
while to_explore:
node, values = to_explore[0]
to_explore = to_explore[1::]
if node.type != NodeType.ENTRYPOINT:
for ir_ssa in node.irs_ssa:
if isinstance(ir_ssa, OperationWithLValue):
lvalue = ir_ssa.lvalue
if isinstance(lvalue, ReferenceVariable):
lvalue = lvalue.points_to_origin
if isinstance(lvalue, StateVariable):
values[lvalue.canonical_name] = {lvalue}
# Check for fixpoint
if node in explored:
if values == explored[node]:
continue
for k, instances in values.items():
if not k in explored[node]:
explored[node][k] = set()
explored[node][k] |= instances
values = explored[node]
else:
explored[node] = values
# Return condition
if not node.sons and node.type != NodeType.THROW:
for name, instances in values.items():
if name not in ret:
ret[name] = set()
ret[name] |= instances
for son in node.sons:
to_explore.append((son, dict(values)))
return ret
@staticmethod
def _unchange_phi(ir):
from slither.slithir.operations import (Phi, PhiCallback)
if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1:
return False
if not ir.rvalues:
return True
return ir.rvalues[0] == ir.lvalue
def fix_phi(self, last_state_variables_instances, initial_state_variables_instances):
from slither.slithir.operations import (InternalCall, PhiCallback)
from slither.slithir.variables import (Constant, StateIRVariable)
for node in self.nodes:
for ir in node.irs_ssa:
if node == self.entry_point:
if isinstance(ir.lvalue, StateIRVariable):
additional = [initial_state_variables_instances[ir.lvalue.canonical_name]]
additional += last_state_variables_instances[ir.lvalue.canonical_name]
ir.rvalues = list(set(additional + ir.rvalues))
# function parameter
else:
# find index of the parameter
idx = self.parameters.index(ir.lvalue.non_ssa_version)
# find non ssa version of that index
additional = [n.ir.arguments[idx] for n in self.reachable_from_nodes]
additional = unroll(additional)
additional = [a for a in additional if not isinstance(a, Constant)]
ir.rvalues = list(set(additional + ir.rvalues))
if isinstance(ir, PhiCallback):
callee_ir = ir.callee_ir
if isinstance(callee_ir, InternalCall):
last_ssa = callee_ir.function.get_last_ssa_state_variables_instances()
if ir.lvalue.canonical_name in last_ssa:
ir.rvalues = list(last_ssa[ir.lvalue.canonical_name])
else:
ir.rvalues = [ir.lvalue]
else:
additional = last_state_variables_instances[ir.lvalue.canonical_name]
ir.rvalues = list(set(additional + ir.rvalues))
node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)]
def generate_slithir_and_analyze(self):
for node in self.nodes:
node.slithir_generation()
for modifier_statement in self.modifiers_statements:
modifier_statement.node.slithir_generation()
for modifier_statement in self.explicit_base_constructor_calls_statements:
modifier_statement.node.slithir_generation()
self._analyze_read_write()
self._analyze_calls()
def generate_slithir_ssa(self, all_ssa_state_variables_instances):
from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa
from slither.core.dominators.utils import (compute_dominance_frontier,
compute_dominators)
compute_dominators(self.nodes)
compute_dominance_frontier(self.nodes)
transform_slithir_vars_to_ssa(self)
add_ssa_ir(self, all_ssa_state_variables_instances)
def update_read_write_using_ssa(self):
for node in self.nodes:
node.update_read_write_using_ssa()
self._analyze_read_write()
###################################################################################
###################################################################################
# region Built in definitions

@ -1,7 +1,10 @@
import logging
from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function, FunctionType
from slither.core.declarations.enum import Enum
from slither.core.cfg.node import Node, NodeType
from slither.core.expressions import AssignmentOperation, Identifier, AssignmentOperationType
from slither.slithir.variables import StateIRVariable
from slither.solc_parsing.declarations.event import EventSolc
from slither.solc_parsing.declarations.function import FunctionSolc
@ -319,7 +322,6 @@ class ContractSolc04(Contract):
:return:
"""
all_elements = {}
accessible_elements = {}
for father in self.inheritance:
for element in getter(father):
@ -368,6 +370,49 @@ class ContractSolc04(Contract):
pass
return
def _create_node(self, func, counter, variable):
# Function uses to create node for state variable declaration statements
node = Node(NodeType.STANDALONE, counter)
node.set_offset(variable.source_mapping, self.slither)
node.set_function(func)
func.add_node(node)
print(variable.expression)
expression = AssignmentOperation(Identifier(variable),
variable.expression,
AssignmentOperationType.ASSIGN,
variable.type)
node.add_expression(expression)
return node
def add_constructor_variables(self):
if self.state_variables:
found_candidate = False
for (idx, variable_candidate) in enumerate(self.state_variables):
if variable_candidate.expression and not variable_candidate.is_constant:
found_candidate = True
break
if found_candidate:
constructor_variable = Function()
constructor_variable.set_function_type(FunctionType.CONSTRUCTOR_VARIABLES)
constructor_variable.set_contract(self)
constructor_variable.set_contract_declarer(self)
constructor_variable.set_visibility('internal')
self._functions[constructor_variable.canonical_name] = constructor_variable
prev_node = self._create_node(constructor_variable, 0, variable_candidate)
counter = 1
for v in self.state_variables[idx+1:]:
if v.expression and not v.is_constant:
next_node = self._create_node(constructor_variable, counter, v)
prev_node.add_son(next_node)
next_node.add_father(prev_node)
counter += 1
def analyze_state_variables(self):
for var in self.variables:
var.analyze(self)

@ -4,16 +4,10 @@ import logging
from slither.core.cfg.node import NodeType, link_nodes
from slither.core.declarations.contract import Contract
from slither.core.declarations.function import Function, ModifierStatements
from slither.core.dominators.utils import (compute_dominance_frontier,
compute_dominators)
from slither.core.declarations.function import Function, ModifierStatements, FunctionType
from slither.core.expressions import AssignmentOperation
from slither.core.variables.state_variable import StateVariable
from slither.slithir.operations import (InternalCall, OperationWithLValue, Phi,
PhiCallback)
from slither.slithir.utils.ssa import add_ssa_ir, transform_slithir_vars_to_ssa
from slither.slithir.variables import (Constant, ReferenceVariable,
StateIRVariable)
from slither.solc_parsing.cfg.node import NodeSolc
from slither.solc_parsing.expressions.expression_parsing import \
parse_expression
@ -23,7 +17,6 @@ from slither.solc_parsing.variables.local_variable_init_from_tuple import \
from slither.solc_parsing.variables.variable_declaration import \
MultipleVariablesDeclaration
from slither.utils.expression_manipulations import SplitTernaryExpression
from slither.utils.utils import unroll
from slither.visitors.expression.export_values import ExportValues
from slither.visitors.expression.has_conditional import HasConditional
from slither.solc_parsing.exceptions import ParsingError
@ -139,14 +132,20 @@ class FunctionSolc(Function):
if 'constant' in attributes:
self._view = attributes['constant']
self._is_constructor = False
if self._name == '':
self._function_type = FunctionType.FALLBACK
else:
self._function_type = FunctionType.NORMAL
if self._name == self.contract_declarer.name:
self._function_type = FunctionType.CONSTRUCTOR
if 'isConstructor' in attributes:
self._is_constructor = attributes['isConstructor']
if 'isConstructor' in attributes and attributes['isConstructor']:
self._function_type = FunctionType.CONSTRUCTOR
if 'kind' in attributes:
if attributes['kind'] == 'constructor':
self._is_constructor = True
self._function_type = FunctionType.CONSTRUCTOR
if 'visibility' in attributes:
self._visibility = attributes['visibility']
@ -1038,120 +1037,5 @@ class FunctionSolc(Function):
# endregion
###################################################################################
###################################################################################
# region SlithIr and SSA
###################################################################################
###################################################################################
def get_last_ssa_state_variables_instances(self):
if not self.is_implemented:
return dict()
# node, values
to_explore = [(self._entry_point, dict())]
# node -> values
explored = dict()
# name -> instances
ret = dict()
while to_explore:
node, values = to_explore[0]
to_explore = to_explore[1::]
if node.type != NodeType.ENTRYPOINT:
for ir_ssa in node.irs_ssa:
if isinstance(ir_ssa, OperationWithLValue):
lvalue = ir_ssa.lvalue
if isinstance(lvalue, ReferenceVariable):
lvalue = lvalue.points_to_origin
if isinstance(lvalue, StateVariable):
values[lvalue.canonical_name] = {lvalue}
# Check for fixpoint
if node in explored:
if values == explored[node]:
continue
for k, instances in values.items():
if not k in explored[node]:
explored[node][k] = set()
explored[node][k] |= instances
values = explored[node]
else:
explored[node] = values
# Return condition
if not node.sons and node.type != NodeType.THROW:
for name, instances in values.items():
if name not in ret:
ret[name] = set()
ret[name] |= instances
for son in node.sons:
to_explore.append((son, dict(values)))
return ret
@staticmethod
def _unchange_phi(ir):
if not isinstance(ir, (Phi, PhiCallback)) or len(ir.rvalues) > 1:
return False
if not ir.rvalues:
return True
return ir.rvalues[0] == ir.lvalue
def fix_phi(self, last_state_variables_instances, initial_state_variables_instances):
for node in self.nodes:
for ir in node.irs_ssa:
if node == self.entry_point:
if isinstance(ir.lvalue, StateIRVariable):
additional = [initial_state_variables_instances[ir.lvalue.canonical_name]]
additional += last_state_variables_instances[ir.lvalue.canonical_name]
ir.rvalues = list(set(additional + ir.rvalues))
# function parameter
else:
# find index of the parameter
idx = self.parameters.index(ir.lvalue.non_ssa_version)
# find non ssa version of that index
additional = [n.ir.arguments[idx] for n in self.reachable_from_nodes]
additional = unroll(additional)
additional = [a for a in additional if not isinstance(a, Constant)]
ir.rvalues = list(set(additional + ir.rvalues))
if isinstance(ir, PhiCallback):
callee_ir = ir.callee_ir
if isinstance(callee_ir, InternalCall):
last_ssa = callee_ir.function.get_last_ssa_state_variables_instances()
if ir.lvalue.canonical_name in last_ssa:
ir.rvalues = list(last_ssa[ir.lvalue.canonical_name])
else:
ir.rvalues = [ir.lvalue]
else:
additional = last_state_variables_instances[ir.lvalue.canonical_name]
ir.rvalues = list(set(additional + ir.rvalues))
node.irs_ssa = [ir for ir in node.irs_ssa if not self._unchange_phi(ir)]
def generate_slithir_and_analyze(self):
for node in self.nodes:
node.slithir_generation()
for modifier_statement in self.modifiers_statements:
modifier_statement.node.slithir_generation()
for modifier_statement in self.explicit_base_constructor_calls_statements:
modifier_statement.node.slithir_generation()
self._analyze_read_write()
self._analyze_calls()
def generate_slithir_ssa(self, all_ssa_state_variables_instances):
compute_dominators(self.nodes)
compute_dominance_frontier(self.nodes)
transform_slithir_vars_to_ssa(self)
add_ssa_ir(self, all_ssa_state_variables_instances)
def update_read_write_using_ssa(self):
for node in self.nodes:
node.update_read_write_using_ssa()
self._analyze_read_write()

@ -373,10 +373,14 @@ class SlitherSolc(Slither):
contract.analyze_content_modifiers()
contract.analyze_content_functions()
contract.set_is_analyzed(True)
def _convert_to_slithir(self):
for contract in self.contracts:
contract.add_constructor_variables()
contract.convert_expression_to_slithir()
self._propagate_function_calls()
for contract in self.contracts:

Loading…
Cancel
Save