Improve state variable SSA conversion

WIP of data dependency based on SSA
pull/87/head
Josselin 6 years ago
parent d33c3e8b83
commit cacf3cf483
  1. 0
      slither/analyses/data_depencency/__init__.py
  2. 105
      slither/analyses/data_depencency/data_depency.py
  3. 11
      slither/slithir/operations/init_array.py
  4. 11
      slither/slithir/operations/phi.py
  5. 51
      slither/slithir/utils/ssa.py
  6. 2
      slither/slithir/variables/local_variable.py
  7. 4
      slither/solc_parsing/declarations/contract.py
  8. 4
      slither/solc_parsing/declarations/function.py
  9. 4
      slither/solc_parsing/slitherSolc.py

@ -0,0 +1,105 @@
"""
Compute the data depenency between all the SSA variables
"""
from slither.slithir.operations import Index, Member, OperationWithLValue
from slither.slithir.variables import ReferenceVariable, Constant
from slither.slithir.variables import (Constant, LocalIRVariable, StateIRVariable,
ReferenceVariable, TemporaryVariable,
TupleVariable)
KEY = "DATA_DEPENDENCY_SSA"
KEY_NON_SSA = "DATA_DEPENDENCY"
def compute_dependency(slither):
for contract in slither.contracts:
compute_dependency_contract(contract)
def compute_dependency_contract(contract):
if KEY in contract.context:
return
contract.context[KEY] = dict()
for function in contract.all_functions_called:
compute_dependency_function(function)
data_depencencies = function.context[KEY]
for (key, values) in data_depencencies.items():
if not key in contract.context[KEY]:
contract.context[KEY][key] = set(values)
else:
contract.context[KEY][key].union(values)
# transitive closure
changed = True
while changed:
changed = False
# Need to create new set() as its changed during iteration
data_depencencies = {k: set([v for v in values]) for k, values in contract.context[KEY].items()}
for key, items in data_depencencies.items():
for item in items:
if item in data_depencencies:
additional_items = contract.context[KEY][item]
for additional_item in additional_items:
if not additional_item in items and additional_item != key:
changed = True
contract.context[KEY][key].add(additional_item)
contract.context[KEY_NON_SSA] = convert_to_non_ssa(contract.context[KEY])
def compute_dependency_function(function):
if KEY in function.context:
return function.context[KEY]
function.context[KEY] = dict()
for node in function.nodes:
for ir in node.irs_ssa:
if isinstance(ir, OperationWithLValue) and ir.lvalue:
lvalue = ir.lvalue
# if isinstance(ir.lvalue, ReferenceVariable):
# lvalue = lvalue.points_to_origin
# # TODO fix incorrect points_to for BALANCE
# if not lvalue:
# continue
if not lvalue in function.context[KEY]:
function.context[KEY][lvalue] = set()
if isinstance(ir, Index):
read = [ir.variable_left]
else:
read = ir.read
[function.context[KEY][lvalue].add(v) for v in read if not isinstance(v, Constant)]
function.context[KEY_NON_SSA] = convert_to_non_ssa(function.context[KEY])
def valid_non_ssa(v):
if isinstance(v, (TemporaryVariable,
ReferenceVariable,
TupleVariable)):
return False
return True
def convert_variable_to_non_ssa(v):
if isinstance(v, (LocalIRVariable, StateIRVariable)):
if isinstance(v, LocalIRVariable):
function = v.function
return function.get_local_variable_from_name(v.name)
else:
contract = v.contract
return contract.get_state_variable_from_name(v.name)
return v
def convert_to_non_ssa(data_depencies):
# Need to create new set() as its changed during iteration
ret = dict()
for (k, values) in data_depencies.items():
if not valid_non_ssa(k):
continue
var = convert_variable_to_non_ssa(k)
if not var in ret:
ret[var] = set()
ret[var] = ret[var].union(set([convert_variable_to_non_ssa(v) for v in
values if valid_non_ssa(v)]))
return ret

@ -21,18 +21,19 @@ class InitArray(OperationWithLValue):
self._init_values = init_values self._init_values = init_values
self._lvalue = lvalue self._lvalue = lvalue
@property
def read(self):
# if array inside the init values # if array inside the init values
def unroll(l): def _unroll(self, l):
ret = [] ret = []
for x in l: for x in l:
if not isinstance(x, list): if not isinstance(x, list):
ret += [x] ret += [x]
else: else:
ret += unroll(x) ret += self._unroll(x)
return ret return ret
return unroll(self.init_values)
@property
def read(self):
return self._unroll(self.init_values)
@property @property
def init_values(self): def init_values(self):

@ -21,9 +21,18 @@ class Phi(OperationWithLValue):
self._rvalues = [] self._rvalues = []
self._nodes = nodes self._nodes = nodes
def _unroll(self, l):
ret = []
for x in l:
if not isinstance(x, list):
ret += [x]
else:
ret += self._unroll(x)
return ret
@property @property
def read(self): def read(self):
return [self.rvalues] return self.rvalues
@property @property
def rvalues(self): def rvalues(self):

@ -42,13 +42,12 @@ def transform_slithir_vars_to_ssa(function):
for idx in range(len(tuple_variables)): for idx in range(len(tuple_variables)):
tuple_variables[idx].index = idx tuple_variables[idx].index = idx
def add_ssa_ir(function, all_state_variables_instances, all_state_variables_written): def add_ssa_ir(function, all_state_variables_instances):
''' '''
Add SSA version of the IR Add SSA version of the IR
Args: Args:
function function
all_state_variables_instances all_state_variables_instances
all_state_variables_written (set(str)): canonical name of all the state variables written
''' '''
if not function.is_implemented: if not function.is_implemented:
@ -106,6 +105,12 @@ def add_ssa_ir(function, all_state_variables_instances, all_state_variables_writ
all_state_variables_instances, all_state_variables_instances,
init_local_variables_instances, init_local_variables_instances,
[]) [])
fix_phi_rvalues_and_storage_ref(function.entry_point,
dict(init_local_variables_instances),
all_init_local_variables_instances,
dict(init_state_variables_instances),
all_state_variables_instances,
init_local_variables_instances)
def last_name(n, var, init_vars): def last_name(n, var, init_vars):
@ -188,7 +193,7 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
if node in visited: if node in visited:
return return
if node.fathers and any(not father in visited for father in node.fathers): if node.type in [NodeType.ENDIF, NodeType.ENDLOOP] and any(not father in visited for father in node.fathers):
return return
# visited is shared # visited is shared
@ -203,8 +208,14 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
assert isinstance(ir, Phi) assert isinstance(ir, Phi)
update_lvalue(ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances) update_lvalue(ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances)
# these variables are lived only during the liveness of the block
# They dont need phi function
temporary_variables_instances = dict()
reference_variables_instances = dict()
for ir in node.irs: for ir in node.irs:
new_ir = copy_ir(ir, local_variables_instances, state_variables_instances) new_ir = copy_ir(ir, local_variables_instances, state_variables_instances, temporary_variables_instances, reference_variables_instances)
update_lvalue(new_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances) update_lvalue(new_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances)
if new_ir: if new_ir:
@ -233,6 +244,14 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
else: else:
new_ir.lvalue.add_points_to(new_ir.rvalue) new_ir.lvalue.add_points_to(new_ir.rvalue)
for succ in node.dominator_successors:
generate_ssa_irs(succ, dict(local_variables_instances), all_local_variables_instances, dict(state_variables_instances), all_state_variables_instances, init_local_variables_instances, visited)
for dominated in node.dominance_frontier:
generate_ssa_irs(dominated, dict(local_variables_instances), all_local_variables_instances, dict(state_variables_instances), all_state_variables_instances, init_local_variables_instances, visited)
def fix_phi_rvalues_and_storage_ref(node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances, init_local_variables_instances):
for ir in node.irs_ssa: for ir in node.irs_ssa:
if isinstance(ir, (Phi)) and not ir.rvalues: if isinstance(ir, (Phi)) and not ir.rvalues:
variables = [last_name(dst, ir.lvalue, init_local_variables_instances) for dst in ir.nodes] variables = [last_name(dst, ir.lvalue, init_local_variables_instances) for dst in ir.nodes]
@ -255,13 +274,10 @@ def generate_ssa_irs(node, local_variables_instances, all_local_variables_instan
phi_ir.rvalues = [origin] phi_ir.rvalues = [origin]
node.add_ssa_ir(phi_ir) node.add_ssa_ir(phi_ir)
update_lvalue(phi_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances) update_lvalue(phi_ir, node, local_variables_instances, all_local_variables_instances, state_variables_instances, all_state_variables_instances)
for succ in node.dominator_successors: for succ in node.dominator_successors:
generate_ssa_irs(succ, dict(local_variables_instances), all_local_variables_instances, dict(state_variables_instances), all_state_variables_instances, init_local_variables_instances, visited) fix_phi_rvalues_and_storage_ref(succ, dict(local_variables_instances), all_local_variables_instances, dict(state_variables_instances), all_state_variables_instances, init_local_variables_instances)
for dominated in node.dominance_frontier:
generate_ssa_irs(dominated, dict(local_variables_instances), all_local_variables_instances, dict(state_variables_instances), all_state_variables_instances, init_local_variables_instances, visited)
def add_phi_origins(node, local_variables_definition, state_variables_definition): def add_phi_origins(node, local_variables_definition, state_variables_definition):
@ -294,11 +310,16 @@ def add_phi_origins(node, local_variables_definition, state_variables_definition
for succ in node.dominator_successors: for succ in node.dominator_successors:
add_phi_origins(succ, local_variables_definition, state_variables_definition) add_phi_origins(succ, local_variables_definition, state_variables_definition)
def copy_ir(ir, local_variables_instances, state_variables_instances): def copy_ir(ir, local_variables_instances, state_variables_instances, temporary_variables_instances, reference_variables_instances):
''' '''
Args: Args:
ir (Operation) ir (Operation)
variables_instances(dict(str -> Variable)) local_variables_instances(dict(str -> LocalVariable))
state_variables_instances(dict(str -> StateVariable))
temporary_variables_instances(dict(int -> Variable))
reference_variables_instances(dict(int -> Variable))
Note: temporary and reference can be indexed by int, as they dont need phi functions
''' '''
def get(variable): def get(variable):
@ -307,15 +328,19 @@ def copy_ir(ir, local_variables_instances, state_variables_instances):
if isinstance(variable, StateVariable) and variable.canonical_name in state_variables_instances: if isinstance(variable, StateVariable) and variable.canonical_name in state_variables_instances:
return state_variables_instances[variable.canonical_name] return state_variables_instances[variable.canonical_name]
elif isinstance(variable, ReferenceVariable): elif isinstance(variable, ReferenceVariable):
if not variable.index in reference_variables_instances:
new_variable = ReferenceVariable(variable.node, index=variable.index) new_variable = ReferenceVariable(variable.node, index=variable.index)
if variable.points_to: if variable.points_to:
new_variable.points_to = get(variable.points_to) new_variable.points_to = get(variable.points_to)
new_variable.set_type(variable.type) new_variable.set_type(variable.type)
return new_variable reference_variables_instances[variable.index] = new_variable
return reference_variables_instances[variable.index]
elif isinstance(variable, TemporaryVariable): elif isinstance(variable, TemporaryVariable):
if not variable.index in temporary_variables_instances:
new_variable = TemporaryVariable(variable.node, index=variable.index) new_variable = TemporaryVariable(variable.node, index=variable.index)
new_variable.set_type(variable.type) new_variable.set_type(variable.type)
return new_variable temporary_variables_instances[variable.index] = new_variable
return temporary_variables_instances[variable.index]
return variable return variable
def get_variable(ir, f): def get_variable(ir, f):

@ -58,4 +58,4 @@ class LocalIRVariable(LocalVariable, SlithIRVariable):
return '{}_{} (-> {})'.format(self._name, return '{}_{} (-> {})'.format(self._name,
self.index, self.index,
[v.name for v in self.points_to]) [v.name for v in self.points_to])
return '{}_{} ({})'.format(self._name, self.index, self.location) return '{}_{}'.format(self._name, self.index)

@ -355,7 +355,6 @@ class ContractSolc04(Contract):
func.generate_slithir_and_analyze() func.generate_slithir_and_analyze()
all_ssa_state_variables_instances = dict() all_ssa_state_variables_instances = dict()
all_state_variables_written = {v.canonical_name for v in self.all_state_variables_written}
for contract in self.inheritance: for contract in self.inheritance:
for v in contract.variables: for v in contract.variables:
@ -372,8 +371,7 @@ class ContractSolc04(Contract):
for func in self.functions + self.modifiers: for func in self.functions + self.modifiers:
if func.contract == self: if func.contract == self:
func.generate_slithir_ssa(all_ssa_state_variables_instances, func.generate_slithir_ssa(all_ssa_state_variables_instances)
all_state_variables_written)
def fix_phi(self): def fix_phi(self):
last_state_variables_instances = dict() last_state_variables_instances = dict()

@ -909,11 +909,11 @@ class FunctionSolc(Function):
self._analyze_read_write() self._analyze_read_write()
self._analyze_calls() self._analyze_calls()
def generate_slithir_ssa(self, all_ssa_state_variables_instances, all_written_state_variables): def generate_slithir_ssa(self, all_ssa_state_variables_instances):
compute_dominators(self.nodes) compute_dominators(self.nodes)
compute_dominance_frontier(self.nodes) compute_dominance_frontier(self.nodes)
transform_slithir_vars_to_ssa(self) transform_slithir_vars_to_ssa(self)
add_ssa_ir(self, all_ssa_state_variables_instances, all_written_state_variables) add_ssa_ir(self, all_ssa_state_variables_instances)
def update_read_write_using_ssa(self): def update_read_write_using_ssa(self):
for node in self.nodes: for node in self.nodes:

@ -9,6 +9,7 @@ from slither.solc_parsing.declarations.contract import ContractSolc04
from slither.core.slither_core import Slither from slither.core.slither_core import Slither
from slither.core.declarations.pragma_directive import Pragma from slither.core.declarations.pragma_directive import Pragma
from slither.core.declarations.import_directive import Import from slither.core.declarations.import_directive import Import
from slither.analyses.data_depencency.data_depency import compute_dependency
class SlitherSolc(Slither): class SlitherSolc(Slither):
@ -190,6 +191,8 @@ class SlitherSolc(Slither):
self._convert_to_slithir() self._convert_to_slithir()
compute_dependency(self)
# TODO refactor the following functions, and use a lambda function # TODO refactor the following functions, and use a lambda function
@property @property
@ -317,3 +320,4 @@ class SlitherSolc(Slither):
contract.fix_phi() contract.fix_phi()
contract.update_read_write_using_ssa() contract.update_read_write_using_ssa()

Loading…
Cancel
Save