SlithIR: improve support for type

pull/20/head
Josselin 6 years ago
parent 53d9a9de71
commit 10b3b4037d
  1. 2
      slither/core/declarations/solidity_variables.py
  2. 401
      slither/slithir/convert.py
  3. 2
      slither/slithir/operations/binary.py
  4. 5
      slither/slithir/operations/high_level_call.py
  5. 4
      slither/slithir/operations/internal_call.py
  6. 11
      slither/slithir/operations/library_call.py
  7. 5
      slither/slithir/operations/return_operation.py
  8. 1
      slither/solc_parsing/slitherSolc.py

@ -104,7 +104,7 @@ class SolidityVariableComposed(SolidityVariable):
@property
def type(self):
return SOLIDITY_VARIABLES_COMPOSED[self.name]
return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name])
def __str__(self):
return self._name

@ -116,7 +116,9 @@ def propage_type_and_convert_call(result, node):
call_data = []
for idx in range(len(result)):
idx = 0
# use of while len() as result can be modified during the iteration
while idx < len(result):
ins = result[idx]
if isinstance(ins, TmpCall):
@ -144,11 +146,165 @@ def propage_type_and_convert_call(result, node):
if isinstance(ins, (Call, NewContract, NewStructure)):
ins.arguments = call_data
propagate_types(ins, node)
call_data = []
if is_temporary(ins):
del result[idx]
continue
new_ins = propagate_types(ins, node)
if new_ins:
if isinstance(new_ins, (list,)):
assert len(new_ins) == 2
result.insert(idx, new_ins[0])
result.insert(idx+1, new_ins[1])
idx = idx + 1
else:
result[idx] = new_ins
idx = idx +1
return result
def convert_to_low_level(ir):
"""
Convert to a transfer/send/or low level call
The funciton assume to receive a correct IR
The checks must be done by the caller
"""
if ir.function_name == 'transfer':
assert len(ir.arguments) == 1
ir = Transfer(ir.destination, ir.arguments[0])
return ir
elif ir.function_name == 'send':
assert len(ir.arguments) == 1
ir = Send(ir.destination, ir.arguments[0], ir.lvalue)
ir.lvalue.set_type(ElementaryType('bool'))
return ir
elif ir.function_name in ['call', 'delegatecall', 'callcode']:
new_ir = LowLevelCall(ir.destination,
ir.function_name,
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
new_ir.call_gas = ir.call_gas
new_ir.call_value = ir.call_value
new_ir.arguments = ir.arguments
new_ir.lvalue.set_type(ElementaryType('bool'))
return new_ir
logger.error('Incorrect conversion to low level {}'.format(ir))
exit(0)
def convert_to_push(ir):
"""
Convert a call to a PUSH operaiton
The funciton assume to receive a correct IR
The checks must be done by the caller
May necessitate to create an intermediate operation (InitArray)
As a result, the function return may return a list
"""
if isinstance(ir.arguments[0], list):
ret = []
val = TemporaryVariable()
operation = InitArray(ir.arguments[0], val)
ret.append(operation)
ir = Push(ir.destination, val)
length = len(operation.init_values)
t = operation.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length))
ret.insert(ir)
return ret
ir = Push(ir.destination, ir.arguments[0])
return ir
def convert_to_library(ir, node, using_for):
contract = node.function.contract
t = ir.destination.type
for destination in using_for[t]:
lib_contract = contract.slither.get_contract_from_name(str(destination))
if destination:
lib_call = LibraryCall(lib_contract,
ir.function_name,
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments
prev = ir
ir = lib_call
sig = '{}({})'.format(ir.function_name,
','.join([str(x.type) for x in ir.arguments]))
func = lib_contract.get_function_from_signature(sig)
if not func:
func = lib_contract.get_state_variable_from_name(ir.function_name)
assert func
ir.function = func
if isinstance(func, Function):
t = func.return_type
# if its not a tuple, return a singleton
if len(t) == 1:
t = t[0]
else:
# otherwise its a variable (getter)
t = func.type
if t:
ir.lvalue.set_type(t)
else:
ir.lvalue = None
return ir
logger.error('Library not found {}'.format(ir))
exit(0)
def get_type(t):
"""
Convert a type to a str
If the instance is a Contract, return 'address' instead
"""
if isinstance(t, UserDefinedType):
if isinstance(t.type, Contract):
return 'address'
return str(t)
def convert_type_of_high_level_call(ir, contract):
sig = '{}({})'.format(ir.function_name,
','.join([get_type(x.type) for x in ir.arguments]))
func = contract.get_function_from_signature(sig)
if not func:
func = contract.get_state_variable_from_name(ir.function_name)
else:
return_type = func.return_type
# if its not a tuple; return a singleton
if return_type and len(return_type) == 1:
return_type = return_type[0]
if not func and ir.function_name in ['call',
'delegatecall',
'codecall',
'transfer',
'send']:
return convert_to_low_level(ir)
if not func:
logger.error('Function not found {}'.format(sig))
ir.function = func
if isinstance(func, Function):
t = return_type
else:
# otherwise its a variable (getter)
t = func.type
if t:
ir.lvalue.set_type(t)
else:
ir.lvalue = None
return None
def propagate_types(ir, node):
# propagate the type
using_for = node.function.contract.using_for
if isinstance(ir, OperationWithLValue):
if not ir.lvalue.type:
if isinstance(ir, Assignment):
@ -157,52 +313,47 @@ def propagate_types(ir, node):
if BinaryType.return_bool(ir.type):
ir.lvalue.set_type(ElementaryType('bool'))
else:
ir.lvalue.set_type(ir.left_variable.type)
ir.lvalue.set_type(ir.variable_left.type)
elif isinstance(ir, Delete):
# nothing to propagate
pass
elif isinstance(ir, HighLevelCall):
t = ir.destination.type
# can be None due to temporary operation
if t:
if isinstance(t, UserDefinedType):
# UserdefinedType
t = t.type
if isinstance(t, Contract):
sig = '{}({})'.format(ir.function_name,
','.join([str(x.type) for x in ir.arguments]))
contract = node.slither.get_contract_from_name(t.name)
func = contract.get_function_from_signature(sig)
if not func:
func = t.get_state_variable_from_name(ir.function_name)
else:
return_type = func.return_type
if not func and ir.function_name in ['call', 'delegatecall','codecall']:
return
if not func:
logger.error('Function not found {}'.format(sig))
ir.function = func
if isinstance(func, Function):
t = func.return_type
else:
# otherwise its a variable (getter)
t = func.type
if t:
ir.lvalue.set_type(t)
else:
ir.lvalue = None
if isinstance(t, ElementaryType):
print(t.name)
# TODO here handle library call
# we can probably directly remove the ins, as alow level
# or a lib
if t.name == 'address':
ir.lvalue.set_type(ElementaryType('bool'))
# Temporary operaiton (they are removed later)
if t is None:
return
# convert library
if t in using_for:
return convert_to_library(ir, node, using_for)
if isinstance(t, UserDefinedType):
# UserdefinedType
t = t.type
if isinstance(t, Contract):
contract = node.slither.get_contract_from_name(t.name)
return convert_type_of_high_level_call(ir, contract)
else:
return None
# Convert HighLevelCall to LowLevelCall
if isinstance(t, ElementaryType) and t.name == 'address':
if ir.destination.name == 'this':
return convert_type_of_high_level_call(ir, node.function.contract)
else:
return convert_to_low_level(ir)
# Convert push operations
# May need to insert a new operation
# Which leads to return a list of operation
if isinstance(t, ArrayType):
if ir.function_name == 'push' and len(ir.arguments) == 1:
return convert_to_push(ir)
elif isinstance(ir, Index):
if isinstance(ir.variable_left.type, MappingType):
ir.lvalue.set_type(ir.variable_left.type.type_to)
else:
assert isinstance(ir.variable_left.type, ArrayType)
elif isinstance(ir.variable_left.type, ArrayType):
ir.lvalue.set_type(ir.variable_left.type.type)
elif isinstance(ir, InitArray):
@ -210,9 +361,13 @@ def propagate_types(ir, node):
t = ir.init_values[0].type
ir.lvalue.set_type(ArrayType(t, length))
elif isinstance(ir, InternalCall):
# if its not a tuple, return a singleton
return_type = ir.function.return_type
if return_type:
ir.lvalue.set_type(return_type)
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
else:
ir.lvalue.set_type(return_type)
else:
ir.lvalue = None
elif isinstance(ir, LowLevelCall):
@ -253,7 +408,11 @@ def propagate_types(ir, node):
elif isinstance(ir, Send):
ir.lvalue.set_type(ElementaryType('bool'))
elif isinstance(ir, SolidityCall):
ir.lvalue.set_type(ir.function.return_type)
return_type = ir.function.return_type
if len(return_type) == 1:
ir.lvalue.set_type(return_type[0])
else:
ir.lvalue.set_type(return_type)
elif isinstance(ir, TypeConversion):
ir.lvalue.set_type(ir.type)
elif isinstance(ir, Unary):
@ -262,7 +421,7 @@ def propagate_types(ir, node):
types = ir.tuple.type.type
idx = ir.index
t = types[idx]
ir.lvalue.set_type(t)
ir.lvalue.set_type(t)
elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)):
# temporary operation; they will be removed
pass
@ -278,8 +437,8 @@ def apply_ir_heuristics(irs, node):
irs = integrate_value_gas(irs)
irs = propage_type_and_convert_call(irs, node)
irs = remove_temporary(irs)
irs = replace_calls(irs)
# irs = remove_temporary(irs)
# irs = replace_calls(irs)
irs = remove_unused(irs)
reset_variable_number(irs)
@ -306,6 +465,12 @@ def reset_variable_number(result):
for idx in range(len(tuple_variables)):
tuple_variables[idx].index = idx
def is_temporary(ins):
return isinstance(ins, (Argument,
TmpNewElementaryType,
TmpNewContract,
TmpNewArray,
TmpNewStructure))
def remove_temporary(result):
@ -344,56 +509,56 @@ def remove_unused(result):
return result
def replace_calls(result):
'''
replace call to push to a Push Operation
Replace to call 'call' 'delegatecall', 'callcode' to an LowLevelCall
'''
reset = True
def is_address(v):
if v in [SolidityVariableComposed('msg.sender'),
SolidityVariableComposed('tx.origin')]:
return True
if not isinstance(v, Variable):
return False
if not isinstance(v.type, ElementaryType):
return False
return v.type.type == 'address'
while reset:
reset = False
for idx in range(len(result)):
ins = result[idx]
if isinstance(ins, HighLevelCall):
# TODO better handle collision with function named push
if ins.function_name == 'push' and len(ins.arguments) == 1:
if isinstance(ins.arguments[0], list):
val = TemporaryVariable()
operation = InitArray(ins.arguments[0], val)
result.insert(idx, operation)
result[idx+1] = Push(ins.destination, val)
reset = True
break
else:
result[idx] = Push(ins.destination, ins.arguments[0])
if is_address(ins.destination):
if ins.function_name == 'transfer':
assert len(ins.arguments) == 1
result[idx] = Transfer(ins.destination, ins.arguments[0])
elif ins.function_name == 'send':
assert len(ins.arguments) == 1
result[idx] = Send(ins.destination, ins.arguments[0], ins.lvalue)
elif ins.function_name in ['call', 'delegatecall', 'callcode']:
# TODO: handle name collision
result[idx] = LowLevelCall(ins.destination,
ins.function_name,
ins.nbr_arguments,
ins.lvalue,
ins.type_call)
result[idx].call_gas = ins.call_gas
result[idx].call_value = ins.call_value
result[idx].arguments = ins.arguments
# other case are library on address
return result
#def replace_calls(result):
# '''
# replace call to push to a Push Operation
# Replace to call 'call' 'delegatecall', 'callcode' to an LowLevelCall
# '''
# reset = True
# def is_address(v):
# if v in [SolidityVariableComposed('msg.sender'),
# SolidityVariableComposed('tx.origin')]:
# return True
# if not isinstance(v, Variable):
# return False
# if not isinstance(v.type, ElementaryType):
# return False
# return v.type.type == 'address'
# while reset:
# reset = False
# for idx in range(len(result)):
# ins = result[idx]
# if isinstance(ins, HighLevelCall):
# # TODO better handle collision with function named push
# if ins.function_name == 'push' and len(ins.arguments) == 1:
# if isinstance(ins.arguments[0], list):
# val = TemporaryVariable()
# operation = InitArray(ins.arguments[0], val)
# result.insert(idx, operation)
# result[idx+1] = Push(ins.destination, val)
# reset = True
# break
# else:
# result[idx] = Push(ins.destination, ins.arguments[0])
# if is_address(ins.destination):
# if ins.function_name == 'transfer':
# assert len(ins.arguments) == 1
# result[idx] = Transfer(ins.destination, ins.arguments[0])
# elif ins.function_name == 'send':
# assert len(ins.arguments) == 1
# result[idx] = Send(ins.destination, ins.arguments[0], ins.lvalue)
# elif ins.function_name in ['call', 'delegatecall', 'callcode']:
# # TODO: handle name collision
# result[idx] = LowLevelCall(ins.destination,
# ins.function_name,
# ins.nbr_arguments,
# ins.lvalue,
# ins.type_call)
# result[idx].call_gas = ins.call_gas
# result[idx].call_value = ins.call_value
# result[idx].arguments = ins.arguments
# # other case are library on address
# return result
def extract_tmp_call(ins):
@ -441,28 +606,28 @@ def extract_tmp_call(ins):
raise Exception('Not extracted {} {}'.format(type(ins.called), ins))
def convert_libs(result, contract):
using_for = contract.using_for
for idx in range(len(result)):
ir = result[idx]
if isinstance(ir, HighLevelCall) and isinstance(ir.destination, Variable):
if ir.destination.type in using_for:
for destination in using_for[ir.destination.type]:
# destination is a UserDefinedType
destination = contract.slither.get_contract_from_name(str(destination))
if destination:
lib_call = LibraryCall(destination,
ir.function_name,
ir.nbr_arguments,
ir.lvalue,
ir.type_call)
lib_call.call_gas = ir.call_gas
lib_call.arguments = [ir.destination] + ir.arguments
result[idx] = lib_call
break
assert destination
return result
#def convert_libs(result, contract):
# using_for = contract.using_for
# for idx in range(len(result)):
# ir = result[idx]
# if isinstance(ir, HighLevelCall) and isinstance(ir.destination, Variable):
# if ir.destination.type in using_for:
# for destination in using_for[ir.destination.type]:
# # destination is a UserDefinedType
# destination = contract.slither.get_contract_from_name(str(destination))
# if destination:
# lib_call = LibraryCall(destination,
# ir.function_name,
# ir.nbr_arguments,
# ir.lvalue,
# ir.type_call)
# lib_call.call_gas = ir.call_gas
# lib_call.arguments = [ir.destination] + ir.arguments
# result[idx] = lib_call
# break
# assert destination
#
# return result
def convert_expression(expression, node):
# handle standlone expression
@ -479,7 +644,7 @@ def convert_expression(expression, node):
result = apply_ir_heuristics(result, node)
result = convert_libs(result, node.function.contract)
# result = convert_libs(result, node.function.contract)
if result:
if node.type in [NodeType.IF, NodeType.IFLOOP]:

@ -165,7 +165,7 @@ class Binary(OperationWithLValue):
def __str__(self):
return '{}({}) = {} {} {}'.format(str(self.lvalue),
str(self.lvalue.type),
self.lvalue.type,
self.variable_left,
self.type_str,
self.variable_right)

@ -98,9 +98,10 @@ class HighLevelCall(Call, OperationWithLValue):
txt = '{}HIGH_LEVEL_CALL, dest:{}({}), function:{}, arguments:{} {} {}'
if not self.lvalue:
lvalue = ''
else:
print(self.lvalue)
elif isinstance(self.lvalue.type, (list,)):
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
return txt.format(lvalue,
self.destination,
self.destination.type,

@ -34,8 +34,10 @@ class InternalCall(Call, OperationWithLValue):
args = [str(a) for a in self.arguments]
if not self.lvalue:
lvalue = ''
else:
elif isinstance(self.lvalue.type, (list,)):
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
txt = '{}INTERNAL_CALL, {}.{}({})'
return txt.format(lvalue,
self.function.contract.name,

@ -16,9 +16,14 @@ class LibraryCall(HighLevelCall):
arguments = []
if self.arguments:
arguments = self.arguments
txt = '{}({}) = LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}'
return txt.format(self.lvalue,
self.lvalue.type,
if not self.lvalue:
lvalue = ''
elif isinstance(self.lvalue.type, (list,)):
lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type))
else:
lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type)
txt = '{}LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}'
return txt.format(lvalue,
self.destination,
self.function_name,
[str(x) for x in arguments],

@ -8,7 +8,10 @@ class Return(Operation):
Only present as last operation in RETURN node
"""
def __init__(self, value):
assert is_valid_rvalue(value) or isinstance(value, TupleVariable)
# Note: Can return None
# ex: return call()
# where call() dont return
assert is_valid_rvalue(value) or isinstance(value, TupleVariable) or value == None
super(Return, self).__init__()
self._value = value

@ -17,7 +17,6 @@ class SlitherSolc(Slither):
self._contractsNotParsed = []
self._contracts_by_id = {}
self._analyzed = False
print(filename)
def _parse_contracts_from_json(self, json_data):
first = json_data.find('{')

Loading…
Cancel
Save