SlithIR: improve support for type

pull/20/head
Josselin 6 years ago
parent 53d9a9de71
commit 10b3b4037d
  1. 2
      slither/core/declarations/solidity_variables.py
  2. 383
      slither/slithir/convert.py
  3. 2
      slither/slithir/operations/binary.py
  4. 5
      slither/slithir/operations/high_level_call.py
  5. 4
      slither/slithir/operations/internal_call.py
  6. 11
      slither/slithir/operations/library_call.py
  7. 5
      slither/slithir/operations/return_operation.py
  8. 1
      slither/solc_parsing/slitherSolc.py

@ -104,7 +104,7 @@ class SolidityVariableComposed(SolidityVariable):
@property @property
def type(self): def type(self):
return SOLIDITY_VARIABLES_COMPOSED[self.name] return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name])
def __str__(self): def __str__(self):
return self._name return self._name

@ -116,7 +116,9 @@ def propage_type_and_convert_call(result, node):
call_data = [] call_data = []
for idx in range(len(result)): idx = 0
# use of while len() as result can be modified during the iteration
while idx < len(result):
ins = result[idx] ins = result[idx]
if isinstance(ins, TmpCall): if isinstance(ins, TmpCall):
@ -144,11 +146,165 @@ def propage_type_and_convert_call(result, node):
if isinstance(ins, (Call, NewContract, NewStructure)): if isinstance(ins, (Call, NewContract, NewStructure)):
ins.arguments = call_data ins.arguments = call_data
propagate_types(ins, node) call_data = []
if is_temporary(ins):
del result[idx]
continue
new_ins = propagate_types(ins, node)
if new_ins:
if isinstance(new_ins, (list,)):
assert len(new_ins) == 2
result.insert(idx, new_ins[0])
result.insert(idx+1, new_ins[1])
idx = idx + 1
else:
result[idx] = new_ins
idx = idx +1
return result return result
def convert_to_low_level(ir):
    """
    Convert a call to a Transfer, a Send or a LowLevelCall operation.

    The function assumes it receives a correct IR;
    the checks must be done by the caller.

    The conversion is selected from ir.function_name:
      - 'transfer': Transfer built from the destination and the single argument
      - 'send': Send built from the destination, the single argument and
        the lvalue; the lvalue is typed as bool
      - 'call', 'delegatecall', 'callcode': LowLevelCall that carries over
        call_gas, call_value and arguments; the lvalue is typed as bool

    Any other function name is logged as an error and the process is
    terminated with a non-zero exit status.
    """
    # Local import: the interactive `exit` helper is not meant for programs,
    # and an error must not terminate with status 0 (success).
    import sys

    name = ir.function_name

    if name == 'transfer':
        assert len(ir.arguments) == 1
        return Transfer(ir.destination, ir.arguments[0])

    if name == 'send':
        assert len(ir.arguments) == 1
        send = Send(ir.destination, ir.arguments[0], ir.lvalue)
        send.lvalue.set_type(ElementaryType('bool'))
        return send

    if name in ['call', 'delegatecall', 'callcode']:
        new_ir = LowLevelCall(ir.destination,
                              ir.function_name,
                              ir.nbr_arguments,
                              ir.lvalue,
                              ir.type_call)
        new_ir.call_gas = ir.call_gas
        new_ir.call_value = ir.call_value
        new_ir.arguments = ir.arguments
        new_ir.lvalue.set_type(ElementaryType('bool'))
        return new_ir

    logger.error('Incorrect conversion to low level {}'.format(ir))
    sys.exit(1)
def convert_to_push(ir):
    """
    Convert a call to a Push operation.

    The function assumes it receives a correct IR;
    the checks must be done by the caller.

    When the pushed argument is a list, an intermediate InitArray
    operation is created to store the values in a temporary variable,
    and that temporary is pushed instead.

    Returns:
        A single Push operation, or a list [InitArray, Push] when an
        intermediate operation was needed.
    """
    pushed = ir.arguments[0]

    if not isinstance(pushed, list):
        return Push(ir.destination, pushed)

    val = TemporaryVariable()
    operation = InitArray(pushed, val)
    push = Push(ir.destination, val)

    # The array type is derived from the first init value and the
    # number of init values
    length = len(operation.init_values)
    t = operation.init_values[0].type
    push.lvalue.set_type(ArrayType(t, length))

    # list.insert() needs an index and an object; the original
    # ret.insert(ir) raised TypeError. The Push goes after the InitArray.
    return [operation, push]
def convert_to_library(ir, node, using_for):
    """
    Convert a high level call to a LibraryCall.

    The candidate libraries are using_for[ir.destination.type]; each one is
    resolved to a contract by name through node.function.contract.slither.
    The first library that resolves is used. The destination of the
    original call becomes the first argument of the library call.

    The called function is looked up by signature in the library; if no
    function matches, a state variable with the same name is looked up.
    The lvalue is typed from the function return type (a single return
    value is unwrapped from its list) or from the variable type. If there
    is no type, the lvalue is set to None.

    Args:
        ir: the high level call to convert
        node: the node holding the call
        using_for: mapping indexed by type, giving the libraries to try

    Returns:
        The LibraryCall replacing ir.
        If no library resolves, an error is logged and the process is
        terminated with a non-zero exit status.
    """
    # Local import: the interactive `exit` helper is not meant for programs,
    # and an error must not terminate with status 0 (success).
    import sys

    contract = node.function.contract
    t = ir.destination.type

    for destination in using_for[t]:
        lib_contract = contract.slither.get_contract_from_name(str(destination))
        # Check the looked-up contract, not the loop variable:
        # an unresolved library must be skipped, not dereferenced
        if not lib_contract:
            continue

        lib_call = LibraryCall(lib_contract,
                               ir.function_name,
                               ir.nbr_arguments,
                               ir.lvalue,
                               ir.type_call)
        lib_call.call_gas = ir.call_gas
        lib_call.arguments = [ir.destination] + ir.arguments

        sig = '{}({})'.format(lib_call.function_name,
                              ','.join([str(x.type) for x in lib_call.arguments]))
        func = lib_contract.get_function_from_signature(sig)
        if not func:
            func = lib_contract.get_state_variable_from_name(lib_call.function_name)
        assert func
        lib_call.function = func

        if isinstance(func, Function):
            return_type = func.return_type
            # if its not a tuple, return a singleton
            # (guarded: a function may have no return type)
            if return_type and len(return_type) == 1:
                return_type = return_type[0]
        else:
            # otherwise its a variable (getter)
            return_type = func.type

        if return_type:
            lib_call.lvalue.set_type(return_type)
        else:
            lib_call.lvalue = None
        return lib_call

    logger.error('Library not found {}'.format(ir))
    sys.exit(1)
def get_type(t):
    """
    Return the string form of a type

    A UserDefinedType wrapping a Contract is reported as 'address';
    every other type is reported as str(t)
    """
    wraps_contract = isinstance(t, UserDefinedType) and isinstance(t.type, Contract)
    return 'address' if wraps_contract else str(t)
def convert_type_of_high_level_call(ir, contract):
    """
    Resolve the target of a high level call and type its lvalue.

    The function is looked up in contract by signature, built from
    ir.function_name and the argument types (see get_type). If no function
    matches, a state variable with the same name is looked up (getter).

    If nothing is found and the name is one of 'call', 'delegatecall',
    'callcode', 'transfer' or 'send', the call is converted through
    convert_to_low_level and the new operation is returned.

    Args:
        ir: the high level call
        contract: the contract in which the function is looked up

    Returns:
        The replacement operation when the call is converted to a low
        level one, None otherwise (ir is then updated in place: ir.function
        is set, and ir.lvalue is typed or set to None).
    """
    sig = '{}({})'.format(ir.function_name,
                          ','.join([get_type(x.type) for x in ir.arguments]))

    # Always bound, so the Function branch below cannot hit an unset name
    return_type = None

    func = contract.get_function_from_signature(sig)
    if func:
        return_type = func.return_type
        # if its not a tuple; return a singleton
        if return_type and len(return_type) == 1:
            return_type = return_type[0]
    else:
        func = contract.get_state_variable_from_name(ir.function_name)

    if not func:
        # 'callcode' (not 'codecall') is the name handled by convert_to_low_level
        if ir.function_name in ['call',
                                'delegatecall',
                                'callcode',
                                'transfer',
                                'send']:
            return convert_to_low_level(ir)
        logger.error('Function not found {}'.format(sig))
        # Nothing to propagate: stop here instead of dereferencing None
        return None

    ir.function = func
    if isinstance(func, Function):
        t = return_type
    else:
        # otherwise its a variable (getter)
        t = func.type

    if t:
        ir.lvalue.set_type(t)
    else:
        ir.lvalue = None
    return None
def propagate_types(ir, node): def propagate_types(ir, node):
# propagate the type # propagate the type
using_for = node.function.contract.using_for
if isinstance(ir, OperationWithLValue): if isinstance(ir, OperationWithLValue):
if not ir.lvalue.type: if not ir.lvalue.type:
if isinstance(ir, Assignment): if isinstance(ir, Assignment):
@ -157,52 +313,47 @@ def propagate_types(ir, node):
if BinaryType.return_bool(ir.type): if BinaryType.return_bool(ir.type):
ir.lvalue.set_type(ElementaryType('bool')) ir.lvalue.set_type(ElementaryType('bool'))
else: else:
ir.lvalue.set_type(ir.left_variable.type) ir.lvalue.set_type(ir.variable_left.type)
elif isinstance(ir, Delete): elif isinstance(ir, Delete):
# nothing to propagate # nothing to propagate
pass pass
elif isinstance(ir, HighLevelCall): elif isinstance(ir, HighLevelCall):
t = ir.destination.type t = ir.destination.type
# can be None due to temporary operation
if t: # Temporary operation (they are removed later)
if t is None:
return
# convert library
if t in using_for:
return convert_to_library(ir, node, using_for)
if isinstance(t, UserDefinedType): if isinstance(t, UserDefinedType):
# UserdefinedType # UserdefinedType
t = t.type t = t.type
if isinstance(t, Contract): if isinstance(t, Contract):
sig = '{}({})'.format(ir.function_name,
','.join([str(x.type) for x in ir.arguments]))
contract = node.slither.get_contract_from_name(t.name) contract = node.slither.get_contract_from_name(t.name)
func = contract.get_function_from_signature(sig) return convert_type_of_high_level_call(ir, contract)
if not func:
func = t.get_state_variable_from_name(ir.function_name)
else:
return_type = func.return_type
if not func and ir.function_name in ['call', 'delegatecall','codecall']:
return
if not func:
logger.error('Function not found {}'.format(sig))
ir.function = func
if isinstance(func, Function):
t = func.return_type
else: else:
# otherwise its a variable (getter) return None
t = func.type
if t: # Convert HighLevelCall to LowLevelCall
ir.lvalue.set_type(t) if isinstance(t, ElementaryType) and t.name == 'address':
if ir.destination.name == 'this':
return convert_type_of_high_level_call(ir, node.function.contract)
else: else:
ir.lvalue = None return convert_to_low_level(ir)
if isinstance(t, ElementaryType):
print(t.name) # Convert push operations
# TODO here handle library call # May need to insert a new operation
# we can probably directly remove the ins, as alow level # Which leads to return a list of operation
# or a lib if isinstance(t, ArrayType):
if t.name == 'address': if ir.function_name == 'push' and len(ir.arguments) == 1:
ir.lvalue.set_type(ElementaryType('bool')) return convert_to_push(ir)
elif isinstance(ir, Index): elif isinstance(ir, Index):
if isinstance(ir.variable_left.type, MappingType): if isinstance(ir.variable_left.type, MappingType):
ir.lvalue.set_type(ir.variable_left.type.type_to) ir.lvalue.set_type(ir.variable_left.type.type_to)
else: elif isinstance(ir.variable_left.type, ArrayType):
assert isinstance(ir.variable_left.type, ArrayType)
ir.lvalue.set_type(ir.variable_left.type.type) ir.lvalue.set_type(ir.variable_left.type.type)
elif isinstance(ir, InitArray): elif isinstance(ir, InitArray):
@ -210,8 +361,12 @@ def propagate_types(ir, node):
t = ir.init_values[0].type t = ir.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length)) ir.lvalue.set_type(ArrayType(t, length))
elif isinstance(ir, InternalCall): elif isinstance(ir, InternalCall):
# if its not a tuple, return a singleton
return_type = ir.function.return_type return_type = ir.function.return_type
if return_type: if return_type:
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
else:
ir.lvalue.set_type(return_type) ir.lvalue.set_type(return_type)
else: else:
ir.lvalue = None ir.lvalue = None
@ -253,7 +408,11 @@ def propagate_types(ir, node):
elif isinstance(ir, Send): elif isinstance(ir, Send):
ir.lvalue.set_type(ElementaryType('bool')) ir.lvalue.set_type(ElementaryType('bool'))
elif isinstance(ir, SolidityCall): elif isinstance(ir, SolidityCall):
ir.lvalue.set_type(ir.function.return_type) return_type = ir.function.return_type
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
else:
ir.lvalue.set_type(return_type)
elif isinstance(ir, TypeConversion): elif isinstance(ir, TypeConversion):
ir.lvalue.set_type(ir.type) ir.lvalue.set_type(ir.type)
elif isinstance(ir, Unary): elif isinstance(ir, Unary):
@ -278,8 +437,8 @@ def apply_ir_heuristics(irs, node):
irs = integrate_value_gas(irs) irs = integrate_value_gas(irs)
irs = propage_type_and_convert_call(irs, node) irs = propage_type_and_convert_call(irs, node)
irs = remove_temporary(irs) # irs = remove_temporary(irs)
irs = replace_calls(irs) # irs = replace_calls(irs)
irs = remove_unused(irs) irs = remove_unused(irs)
reset_variable_number(irs) reset_variable_number(irs)
@ -306,6 +465,12 @@ def reset_variable_number(result):
for idx in range(len(tuple_variables)): for idx in range(len(tuple_variables)):
tuple_variables[idx].index = idx tuple_variables[idx].index = idx
def is_temporary(ins):
    """
    Return True if ins is an instance of one of the temporary operations:
    Argument, TmpNewElementaryType, TmpNewContract, TmpNewArray
    or TmpNewStructure
    """
    temporary_operations = (Argument,
                            TmpNewElementaryType,
                            TmpNewContract,
                            TmpNewArray,
                            TmpNewStructure)
    return isinstance(ins, temporary_operations)
def remove_temporary(result): def remove_temporary(result):
@ -344,56 +509,56 @@ def remove_unused(result):
return result return result
def replace_calls(result): #def replace_calls(result):
''' # '''
replace call to push to a Push Operation # replace call to push to a Push Operation
Replace to call 'call' 'delegatecall', 'callcode' to an LowLevelCall # Replace to call 'call' 'delegatecall', 'callcode' to an LowLevelCall
''' # '''
reset = True # reset = True
def is_address(v): # def is_address(v):
if v in [SolidityVariableComposed('msg.sender'), # if v in [SolidityVariableComposed('msg.sender'),
SolidityVariableComposed('tx.origin')]: # SolidityVariableComposed('tx.origin')]:
return True # return True
if not isinstance(v, Variable): # if not isinstance(v, Variable):
return False # return False
if not isinstance(v.type, ElementaryType): # if not isinstance(v.type, ElementaryType):
return False # return False
return v.type.type == 'address' # return v.type.type == 'address'
while reset: # while reset:
reset = False # reset = False
for idx in range(len(result)): # for idx in range(len(result)):
ins = result[idx] # ins = result[idx]
if isinstance(ins, HighLevelCall): # if isinstance(ins, HighLevelCall):
# TODO better handle collision with function named push # # TODO better handle collision with function named push
if ins.function_name == 'push' and len(ins.arguments) == 1: # if ins.function_name == 'push' and len(ins.arguments) == 1:
if isinstance(ins.arguments[0], list): # if isinstance(ins.arguments[0], list):
val = TemporaryVariable() # val = TemporaryVariable()
operation = InitArray(ins.arguments[0], val) # operation = InitArray(ins.arguments[0], val)
result.insert(idx, operation) # result.insert(idx, operation)
result[idx+1] = Push(ins.destination, val) # result[idx+1] = Push(ins.destination, val)
reset = True # reset = True
break # break
else: # else:
result[idx] = Push(ins.destination, ins.arguments[0]) # result[idx] = Push(ins.destination, ins.arguments[0])
if is_address(ins.destination): # if is_address(ins.destination):
if ins.function_name == 'transfer': # if ins.function_name == 'transfer':
assert len(ins.arguments) == 1 # assert len(ins.arguments) == 1
result[idx] = Transfer(ins.destination, ins.arguments[0]) # result[idx] = Transfer(ins.destination, ins.arguments[0])
elif ins.function_name == 'send': # elif ins.function_name == 'send':
assert len(ins.arguments) == 1 # assert len(ins.arguments) == 1
result[idx] = Send(ins.destination, ins.arguments[0], ins.lvalue) # result[idx] = Send(ins.destination, ins.arguments[0], ins.lvalue)
elif ins.function_name in ['call', 'delegatecall', 'callcode']: # elif ins.function_name in ['call', 'delegatecall', 'callcode']:
# TODO: handle name collision # # TODO: handle name collision
result[idx] = LowLevelCall(ins.destination, # result[idx] = LowLevelCall(ins.destination,
ins.function_name, # ins.function_name,
ins.nbr_arguments, # ins.nbr_arguments,
ins.lvalue, # ins.lvalue,
ins.type_call) # ins.type_call)
result[idx].call_gas = ins.call_gas # result[idx].call_gas = ins.call_gas
result[idx].call_value = ins.call_value # result[idx].call_value = ins.call_value
result[idx].arguments = ins.arguments # result[idx].arguments = ins.arguments
# other case are library on address # # other case are library on address
return result # return result
def extract_tmp_call(ins): def extract_tmp_call(ins):
@ -441,28 +606,28 @@ def extract_tmp_call(ins):
raise Exception('Not extracted {} {}'.format(type(ins.called), ins)) raise Exception('Not extracted {} {}'.format(type(ins.called), ins))
def convert_libs(result, contract): #def convert_libs(result, contract):
using_for = contract.using_for # using_for = contract.using_for
for idx in range(len(result)): # for idx in range(len(result)):
ir = result[idx] # ir = result[idx]
if isinstance(ir, HighLevelCall) and isinstance(ir.destination, Variable): # if isinstance(ir, HighLevelCall) and isinstance(ir.destination, Variable):
if ir.destination.type in using_for: # if ir.destination.type in using_for:
for destination in using_for[ir.destination.type]: # for destination in using_for[ir.destination.type]:
# destination is a UserDefinedType # # destination is a UserDefinedType
destination = contract.slither.get_contract_from_name(str(destination)) # destination = contract.slither.get_contract_from_name(str(destination))
if destination: # if destination:
lib_call = LibraryCall(destination, # lib_call = LibraryCall(destination,
ir.function_name, # ir.function_name,
ir.nbr_arguments, # ir.nbr_arguments,
ir.lvalue, # ir.lvalue,
ir.type_call) # ir.type_call)
lib_call.call_gas = ir.call_gas # lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments # lib_call.arguments = [ir.destination] + ir.arguments
result[idx] = lib_call # result[idx] = lib_call
break # break
assert destination # assert destination
#
return result # return result
def convert_expression(expression, node): def convert_expression(expression, node):
# handle standlone expression # handle standlone expression
@ -479,7 +644,7 @@ def convert_expression(expression, node):
result = apply_ir_heuristics(result, node) result = apply_ir_heuristics(result, node)
result = convert_libs(result, node.function.contract) # result = convert_libs(result, node.function.contract)
if result: if result:
if node.type in [NodeType.IF, NodeType.IFLOOP]: if node.type in [NodeType.IF, NodeType.IFLOOP]:

@ -165,7 +165,7 @@ class Binary(OperationWithLValue):
def __str__(self): def __str__(self):
return '{}({}) = {} {} {}'.format(str(self.lvalue), return '{}({}) = {} {} {}'.format(str(self.lvalue),
str(self.lvalue.type), self.lvalue.type,
self.variable_left, self.variable_left,
self.type_str, self.type_str,
self.variable_right) self.variable_right)

@ -98,9 +98,10 @@ class HighLevelCall(Call, OperationWithLValue):
txt = '{}HIGH_LEVEL_CALL, dest:{}({}), function:{}, arguments:{} {} {}' txt = '{}HIGH_LEVEL_CALL, dest:{}({}), function:{}, arguments:{} {} {}'
if not self.lvalue: if not self.lvalue:
lvalue = '' lvalue = ''
else: elif isinstance(self.lvalue.type, (list,)):
print(self.lvalue)
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
return txt.format(lvalue, return txt.format(lvalue,
self.destination, self.destination,
self.destination.type, self.destination.type,

@ -34,8 +34,10 @@ class InternalCall(Call, OperationWithLValue):
args = [str(a) for a in self.arguments] args = [str(a) for a in self.arguments]
if not self.lvalue: if not self.lvalue:
lvalue = '' lvalue = ''
else: elif isinstance(self.lvalue.type, (list,)):
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
txt = '{}INTERNAL_CALL, {}.{}({})' txt = '{}INTERNAL_CALL, {}.{}({})'
return txt.format(lvalue, return txt.format(lvalue,
self.function.contract.name, self.function.contract.name,

@ -16,9 +16,14 @@ class LibraryCall(HighLevelCall):
arguments = [] arguments = []
if self.arguments: if self.arguments:
arguments = self.arguments arguments = self.arguments
txt = '{}({}) = LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}' if not self.lvalue:
return txt.format(self.lvalue, lvalue = ''
self.lvalue.type, elif isinstance(self.lvalue.type, (list,)):
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
txt = '{}LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}'
return txt.format(lvalue,
self.destination, self.destination,
self.function_name, self.function_name,
[str(x) for x in arguments], [str(x) for x in arguments],

@ -8,7 +8,10 @@ class Return(Operation):
Only present as last operation in RETURN node Only present as last operation in RETURN node
""" """
def __init__(self, value): def __init__(self, value):
assert is_valid_rvalue(value) or isinstance(value, TupleVariable) # Note: Can return None
# ex: return call()
# where call() does not return a value
assert is_valid_rvalue(value) or isinstance(value, TupleVariable) or value == None
super(Return, self).__init__() super(Return, self).__init__()
self._value = value self._value = value

@ -17,7 +17,6 @@ class SlitherSolc(Slither):
self._contractsNotParsed = [] self._contractsNotParsed = []
self._contracts_by_id = {} self._contracts_by_id = {}
self._analyzed = False self._analyzed = False
print(filename)
def _parse_contracts_from_json(self, json_data): def _parse_contracts_from_json(self, json_data):
first = json_data.find('{') first = json_data.find('{')

Loading…
Cancel
Save