From 10b3b4037d48c902770bdf0c93a1f45dd84af708 Mon Sep 17 00:00:00 2001 From: Josselin Date: Fri, 12 Oct 2018 11:51:48 +0100 Subject: [PATCH] SlithIR: improve support for type --- .../core/declarations/solidity_variables.py | 2 +- slither/slithir/convert.py | 401 ++++++++++++------ slither/slithir/operations/binary.py | 2 +- slither/slithir/operations/high_level_call.py | 5 +- slither/slithir/operations/internal_call.py | 4 +- slither/slithir/operations/library_call.py | 11 +- .../slithir/operations/return_operation.py | 5 +- slither/solc_parsing/slitherSolc.py | 1 - 8 files changed, 303 insertions(+), 128 deletions(-) diff --git a/slither/core/declarations/solidity_variables.py b/slither/core/declarations/solidity_variables.py index 1ad4a90ed..96e5dc2e9 100644 --- a/slither/core/declarations/solidity_variables.py +++ b/slither/core/declarations/solidity_variables.py @@ -104,7 +104,7 @@ class SolidityVariableComposed(SolidityVariable): @property def type(self): - return SOLIDITY_VARIABLES_COMPOSED[self.name] + return ElementaryType(SOLIDITY_VARIABLES_COMPOSED[self.name]) def __str__(self): return self._name diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index c64cf805a..a5eb2d468 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -116,7 +116,9 @@ def propage_type_and_convert_call(result, node): call_data = [] - for idx in range(len(result)): + idx = 0 + # use of while len() as result can be modified during the iteration + while idx < len(result): ins = result[idx] if isinstance(ins, TmpCall): @@ -144,11 +146,165 @@ def propage_type_and_convert_call(result, node): if isinstance(ins, (Call, NewContract, NewStructure)): ins.arguments = call_data - propagate_types(ins, node) + call_data = [] + + if is_temporary(ins): + del result[idx] + continue + + new_ins = propagate_types(ins, node) + if new_ins: + if isinstance(new_ins, (list,)): + assert len(new_ins) == 2 + result.insert(idx, new_ins[0]) + result.insert(idx+1, new_ins[1]) + idx = idx + 1 + else: + result[idx] = new_ins + idx = idx +1 return result +def convert_to_low_level(ir): + """ + Convert to a transfer/send/or low level call + The funciton assume to receive a correct IR + The checks must be done by the caller + """ + if ir.function_name == 'transfer': + assert len(ir.arguments) == 1 + ir = Transfer(ir.destination, ir.arguments[0]) + return ir + elif ir.function_name == 'send': + assert len(ir.arguments) == 1 + ir = Send(ir.destination, ir.arguments[0], ir.lvalue) + ir.lvalue.set_type(ElementaryType('bool')) + return ir + elif ir.function_name in ['call', 'delegatecall', 'callcode']: + new_ir = LowLevelCall(ir.destination, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call) + new_ir.call_gas = ir.call_gas + new_ir.call_value = ir.call_value + new_ir.arguments = ir.arguments + new_ir.lvalue.set_type(ElementaryType('bool')) + return new_ir + logger.error('Incorrect conversion to low level {}'.format(ir)) + exit(0) + +def convert_to_push(ir): + """ + Convert a call to a PUSH operaiton + + The funciton assume to receive a correct IR + The checks must be done by the caller + + May necessitate to create an intermediate operation (InitArray) + As a result, the function return may return a list + """ + if isinstance(ir.arguments[0], list): + ret = [] + + val = TemporaryVariable() + operation = InitArray(ir.arguments[0], val) + ret.append(operation) + + ir = Push(ir.destination, val) + + length = len(operation.init_values) + t = operation.init_values[0].type + ir.lvalue.set_type(ArrayType(t, length)) + + ret.insert(ir) + return ret + + ir = Push(ir.destination, ir.arguments[0]) + return ir + +def convert_to_library(ir, node, using_for): + contract = node.function.contract + t = ir.destination.type + for destination in using_for[t]: + lib_contract = contract.slither.get_contract_from_name(str(destination)) + if destination: + lib_call = LibraryCall(lib_contract, + ir.function_name, + ir.nbr_arguments, + ir.lvalue, + ir.type_call) + lib_call.call_gas = ir.call_gas + lib_call.arguments = [ir.destination] + ir.arguments + prev = ir + ir = lib_call + sig = '{}({})'.format(ir.function_name, + ','.join([str(x.type) for x in ir.arguments])) + func = lib_contract.get_function_from_signature(sig) + if not func: + func = lib_contract.get_state_variable_from_name(ir.function_name) + assert func + ir.function = func + if isinstance(func, Function): + t = func.return_type + # if its not a tuple, return a singleton + if len(t) == 1: + t = t[0] + else: + # otherwise its a variable (getter) + t = func.type + if t: + ir.lvalue.set_type(t) + else: + ir.lvalue = None + return ir + logger.error('Library not found {}'.format(ir)) + exit(0) + +def get_type(t): + """ + Convert a type to a str + If the instance is a Contract, return 'address' instead + """ + if isinstance(t, UserDefinedType): + if isinstance(t.type, Contract): + return 'address' + return str(t) + +def convert_type_of_high_level_call(ir, contract): + sig = '{}({})'.format(ir.function_name, + ','.join([get_type(x.type) for x in ir.arguments])) + func = contract.get_function_from_signature(sig) + if not func: + func = contract.get_state_variable_from_name(ir.function_name) + else: + return_type = func.return_type + # if its not a tuple; return a singleton + if return_type and len(return_type) == 1: + return_type = return_type[0] + if not func and ir.function_name in ['call', + 'delegatecall', + 'codecall', + 'transfer', + 'send']: + return convert_to_low_level(ir) + if not func: + logger.error('Function not found {}'.format(sig)) + ir.function = func + if isinstance(func, Function): + t = return_type + else: + # otherwise its a variable (getter) + t = func.type + if t: + ir.lvalue.set_type(t) + else: + ir.lvalue = None + + return None + def propagate_types(ir, node): # propagate the type + using_for = node.function.contract.using_for if isinstance(ir, OperationWithLValue): if not ir.lvalue.type: if isinstance(ir, Assignment): @@ -157,52 +313,47 @@ def propagate_types(ir, node): if BinaryType.return_bool(ir.type): ir.lvalue.set_type(ElementaryType('bool')) else: - ir.lvalue.set_type(ir.left_variable.type) + ir.lvalue.set_type(ir.variable_left.type) elif isinstance(ir, Delete): # nothing to propagate pass elif isinstance(ir, HighLevelCall): t = ir.destination.type - # can be None due to temporary operation - if t: - if isinstance(t, UserDefinedType): - # UserdefinedType - t = t.type - if isinstance(t, Contract): - sig = '{}({})'.format(ir.function_name, - ','.join([str(x.type) for x in ir.arguments])) - contract = node.slither.get_contract_from_name(t.name) - func = contract.get_function_from_signature(sig) - if not func: - func = t.get_state_variable_from_name(ir.function_name) - else: - return_type = func.return_type - if not func and ir.function_name in ['call', 'delegatecall','codecall']: - return - if not func: - logger.error('Function not found {}'.format(sig)) - ir.function = func - if isinstance(func, Function): - t = func.return_type - else: - # otherwise its a variable (getter) - t = func.type - if t: - ir.lvalue.set_type(t) - else: - ir.lvalue = None - if isinstance(t, ElementaryType): - print(t.name) - # TODO here handle library call - # we can probably directly remove the ins, as alow level - # or a lib - if t.name == 'address': - ir.lvalue.set_type(ElementaryType('bool')) + + # Temporary operaiton (they are removed later) + if t is None: + return + # convert library + if t in using_for: + return convert_to_library(ir, node, using_for) + + if isinstance(t, UserDefinedType): + # UserdefinedType + t = t.type + if isinstance(t, Contract): + contract = node.slither.get_contract_from_name(t.name) + return convert_type_of_high_level_call(ir, contract) + else: + return None + + # Convert HighLevelCall to LowLevelCall + if isinstance(t, ElementaryType) and t.name == 'address': + if ir.destination.name == 'this': + return convert_type_of_high_level_call(ir, node.function.contract) + else: + return convert_to_low_level(ir) + + # Convert push operations + # May need to insert a new operation + # Which leads to return a list of operation + if isinstance(t, ArrayType): + if ir.function_name == 'push' and len(ir.arguments) == 1: + return convert_to_push(ir) + elif isinstance(ir, Index): if isinstance(ir.variable_left.type, MappingType): ir.lvalue.set_type(ir.variable_left.type.type_to) - else: - assert isinstance(ir.variable_left.type, ArrayType) + elif isinstance(ir.variable_left.type, ArrayType): ir.lvalue.set_type(ir.variable_left.type.type) elif isinstance(ir, InitArray): @@ -210,9 +361,13 @@ def propagate_types(ir, node): t = ir.init_values[0].type ir.lvalue.set_type(ArrayType(t, length)) elif isinstance(ir, InternalCall): + # if its not a tuple, return a singleton return_type = ir.function.return_type if return_type: - ir.lvalue.set_type(return_type) + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + else: + ir.lvalue.set_type(return_type) else: ir.lvalue = None elif isinstance(ir, LowLevelCall): @@ -253,7 +408,11 @@ def propagate_types(ir, node): elif isinstance(ir, Send): ir.lvalue.set_type(ElementaryType('bool')) elif isinstance(ir, SolidityCall): - ir.lvalue.set_type(ir.function.return_type) + return_type = ir.function.return_type + if len(return_type) == 1: + ir.lvalue.set_type(return_type[0]) + else: + ir.lvalue.set_type(return_type) elif isinstance(ir, TypeConversion): ir.lvalue.set_type(ir.type) elif isinstance(ir, Unary): @@ -262,7 +421,7 @@ def propagate_types(ir, node): types = ir.tuple.type.type idx = ir.index t = types[idx] - ir.lvalue.set_type(t) + ir.lvalue.set_type(t) elif isinstance(ir, (Argument, TmpCall, TmpNewArray, TmpNewContract, TmpNewStructure, TmpNewElementaryType)): # temporary operation; they will be removed pass @@ -278,8 +437,8 @@ def apply_ir_heuristics(irs, node): irs = integrate_value_gas(irs) irs = propage_type_and_convert_call(irs, node) - irs = remove_temporary(irs) - irs = replace_calls(irs) +# irs = remove_temporary(irs) +# irs = replace_calls(irs) irs = remove_unused(irs) reset_variable_number(irs) @@ -306,6 +465,12 @@ def reset_variable_number(result): for idx in range(len(tuple_variables)): tuple_variables[idx].index = idx +def is_temporary(ins): + return isinstance(ins, (Argument, + TmpNewElementaryType, + TmpNewContract, + TmpNewArray, + TmpNewStructure)) def remove_temporary(result): @@ -344,56 +509,56 @@ def remove_unused(result): return result -def replace_calls(result): - ''' - replace call to push to a Push Operation - Replace to call 'call' 'delegatecall', 'callcode' to an LowLevelCall - ''' - reset = True - def is_address(v): - if v in [SolidityVariableComposed('msg.sender'), - SolidityVariableComposed('tx.origin')]: - return True - if not isinstance(v, Variable): - return False - if not isinstance(v.type, ElementaryType): - return False - return v.type.type == 'address' - while reset: - reset = False - for idx in range(len(result)): - ins = result[idx] - if isinstance(ins, HighLevelCall): - # TODO better handle collision with function named push - if ins.function_name == 'push' and len(ins.arguments) == 1: - if isinstance(ins.arguments[0], list): - val = TemporaryVariable() - operation = InitArray(ins.arguments[0], val) - result.insert(idx, operation) - result[idx+1] = Push(ins.destination, val) - reset = True - break - else: - result[idx] = Push(ins.destination, ins.arguments[0]) - if is_address(ins.destination): - if ins.function_name == 'transfer': - assert len(ins.arguments) == 1 - result[idx] = Transfer(ins.destination, ins.arguments[0]) - elif ins.function_name == 'send': - assert len(ins.arguments) == 1 - result[idx] = Send(ins.destination, ins.arguments[0], ins.lvalue) - elif ins.function_name in ['call', 'delegatecall', 'callcode']: - # TODO: handle name collision - result[idx] = LowLevelCall(ins.destination, - ins.function_name, - ins.nbr_arguments, - ins.lvalue, - ins.type_call) - result[idx].call_gas = ins.call_gas - result[idx].call_value = ins.call_value - result[idx].arguments = ins.arguments - # other case are library on address - return result +#def replace_calls(result): +# ''' +# replace call to push to a Push Operation +# Replace to call 'call' 'delegatecall', 'callcode' to an LowLevelCall +# ''' +# reset = True +# def is_address(v): +# if v in [SolidityVariableComposed('msg.sender'), +# SolidityVariableComposed('tx.origin')]: +# return True +# if not isinstance(v, Variable): +# return False +# if not isinstance(v.type, ElementaryType): +# return False +# return v.type.type == 'address' +# while reset: +# reset = False +# for idx in range(len(result)): +# ins = result[idx] +# if isinstance(ins, HighLevelCall): +# # TODO better handle collision with function named push +# if ins.function_name == 'push' and len(ins.arguments) == 1: +# if isinstance(ins.arguments[0], list): +# val = TemporaryVariable() +# operation = InitArray(ins.arguments[0], val) +# result.insert(idx, operation) +# result[idx+1] = Push(ins.destination, val) +# reset = True +# break +# else: +# result[idx] = Push(ins.destination, ins.arguments[0]) +# if is_address(ins.destination): +# if ins.function_name == 'transfer': +# assert len(ins.arguments) == 1 +# result[idx] = Transfer(ins.destination, ins.arguments[0]) +# elif ins.function_name == 'send': +# assert len(ins.arguments) == 1 +# result[idx] = Send(ins.destination, ins.arguments[0], ins.lvalue) +# elif ins.function_name in ['call', 'delegatecall', 'callcode']: +# # TODO: handle name collision +# result[idx] = LowLevelCall(ins.destination, +# ins.function_name, +# ins.nbr_arguments, +# ins.lvalue, +# ins.type_call) +# result[idx].call_gas = ins.call_gas +# result[idx].call_value = ins.call_value +# result[idx].arguments = ins.arguments +# # other case are library on address +# return result def extract_tmp_call(ins): @@ -441,28 +606,28 @@ def extract_tmp_call(ins): raise Exception('Not extracted {} {}'.format(type(ins.called), ins)) -def convert_libs(result, contract): - using_for = contract.using_for - for idx in range(len(result)): - ir = result[idx] - if isinstance(ir, HighLevelCall) and isinstance(ir.destination, Variable): - if ir.destination.type in using_for: - for destination in using_for[ir.destination.type]: - # destination is a UserDefinedType - destination = contract.slither.get_contract_from_name(str(destination)) - if destination: - lib_call = LibraryCall(destination, - ir.function_name, - ir.nbr_arguments, - ir.lvalue, - ir.type_call) - lib_call.call_gas = ir.call_gas - lib_call.arguments = [ir.destination] + ir.arguments - result[idx] = lib_call - break - assert destination - - return result +#def convert_libs(result, contract): +# using_for = contract.using_for +# for idx in range(len(result)): +# ir = result[idx] +# if isinstance(ir, HighLevelCall) and isinstance(ir.destination, Variable): +# if ir.destination.type in using_for: +# for destination in using_for[ir.destination.type]: +# # destination is a UserDefinedType +# destination = contract.slither.get_contract_from_name(str(destination)) +# if destination: +# lib_call = LibraryCall(destination, +# ir.function_name, +# ir.nbr_arguments, +# ir.lvalue, +# ir.type_call) +# lib_call.call_gas = ir.call_gas +# lib_call.arguments = [ir.destination] + ir.arguments +# result[idx] = lib_call +# break +# assert destination +# +# return result def convert_expression(expression, node): # handle standlone expression @@ -479,7 +644,7 @@ def convert_expression(expression, node): result = apply_ir_heuristics(result, node) - result = convert_libs(result, node.function.contract) +# result = convert_libs(result, node.function.contract) if result: if node.type in [NodeType.IF, NodeType.IFLOOP]: diff --git a/slither/slithir/operations/binary.py b/slither/slithir/operations/binary.py index 8abe46351..5a770f417 100644 --- a/slither/slithir/operations/binary.py +++ b/slither/slithir/operations/binary.py @@ -165,7 +165,7 @@ class Binary(OperationWithLValue): def __str__(self): return '{}({}) = {} {} {}'.format(str(self.lvalue), - str(self.lvalue.type), + self.lvalue.type, self.variable_left, self.type_str, self.variable_right) diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index a37e532eb..ad4cc5956 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -98,9 +98,10 @@ class HighLevelCall(Call, OperationWithLValue): txt = '{}HIGH_LEVEL_CALL, dest:{}({}), function:{}, arguments:{} {} {}' if not self.lvalue: lvalue = '' - else: - print(self.lvalue) + elif isinstance(self.lvalue.type, (list,)): lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) return txt.format(lvalue, self.destination, self.destination.type, diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index 4a7edb8e3..e8bbb3db7 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -34,8 +34,10 @@ class InternalCall(Call, OperationWithLValue): args = [str(a) for a in self.arguments] if not self.lvalue: lvalue = '' - else: + elif isinstance(self.lvalue.type, (list,)): lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) txt = '{}INTERNAL_CALL, {}.{}({})' return txt.format(lvalue, self.function.contract.name, diff --git a/slither/slithir/operations/library_call.py b/slither/slithir/operations/library_call.py index c1b3821eb..76abfbe8a 100644 --- a/slither/slithir/operations/library_call.py +++ b/slither/slithir/operations/library_call.py @@ -16,9 +16,14 @@ class LibraryCall(HighLevelCall): arguments = [] if self.arguments: arguments = self.arguments - txt = '{}({}) = LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}' - return txt.format(self.lvalue, - self.lvalue.type, + if not self.lvalue: + lvalue = '' + elif isinstance(self.lvalue.type, (list,)): + lvalue = '{}({}) = '.format(self.lvalue, ','.join(str(x) for x in self.lvalue.type)) + else: + lvalue = '{}({}) = '.format(self.lvalue, self.lvalue.type) + txt = '{}LIBRARY_CALL, dest:{}, function:{}, arguments:{} {}' + return txt.format(lvalue, self.destination, self.function_name, [str(x) for x in arguments], diff --git a/slither/slithir/operations/return_operation.py b/slither/slithir/operations/return_operation.py index 5489087ea..225e40823 100644 --- a/slither/slithir/operations/return_operation.py +++ b/slither/slithir/operations/return_operation.py @@ -8,7 +8,10 @@ class Return(Operation): Only present as last operation in RETURN node """ def __init__(self, value): - assert is_valid_rvalue(value) or isinstance(value, TupleVariable) + # Note: Can return None + # ex: return call() + # where call() dont return + assert is_valid_rvalue(value) or isinstance(value, TupleVariable) or value == None super(Return, self).__init__() self._value = value diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index 7f3157534..d0a69cc13 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -17,7 +17,6 @@ class SlitherSolc(Slither): self._contractsNotParsed = [] self._contracts_by_id = {} self._analyzed = False - print(filename) def _parse_contracts_from_json(self, json_data): first = json_data.find('{')