diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index cd73effb0..483a618af 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -997,51 +997,34 @@ class Function(ChildContract, SourceMapping): calls = [x.calls_as_expression for x in self.nodes] calls = [x for x in calls if x] calls = [item for sublist in calls for item in sublist] - # Remove dupplicate if they share the same string representation - # TODO: check if groupby is still necessary here - calls = [next(obj) for i, obj in\ - groupby(sorted(calls, key=lambda x: str(x)), lambda x: str(x))] - self._expression_calls = calls + self._expression_calls = list(set(calls)) internal_calls = [x.internal_calls for x in self.nodes] internal_calls = [x for x in internal_calls if x] internal_calls = [item for sublist in internal_calls for item in sublist] - internal_calls = [next(obj) for i, obj in - groupby(sorted(internal_calls, key=lambda x: str(x)), lambda x: str(x))] - self._internal_calls = internal_calls + self._internal_calls = list(set(internal_calls)) self._solidity_calls = [c for c in internal_calls if isinstance(c, SolidityFunction)] low_level_calls = [x.low_level_calls for x in self.nodes] low_level_calls = [x for x in low_level_calls if x] low_level_calls = [item for sublist in low_level_calls for item in sublist] - low_level_calls = [next(obj) for i, obj in - groupby(sorted(low_level_calls, key=lambda x: str(x)), lambda x: str(x))] - - self._low_level_calls = low_level_calls + self._low_level_calls = list(set(low_level_calls)) high_level_calls = [x.high_level_calls for x in self.nodes] high_level_calls = [x for x in high_level_calls if x] high_level_calls = [item for sublist in high_level_calls for item in sublist] - high_level_calls = [next(obj) for i, obj in - groupby(sorted(high_level_calls, key=lambda x: str(x)), lambda x: str(x))] - - self._high_level_calls = high_level_calls + self._high_level_calls = list(set(high_level_calls)) library_calls = [x.library_calls for x in self.nodes] library_calls = [x for x in library_calls if x] library_calls = [item for sublist in library_calls for item in sublist] - library_calls = [next(obj) for i, obj in - groupby(sorted(library_calls, key=lambda x: str(x)), lambda x: str(x))] - - self._library_calls = library_calls + self._library_calls = list(set(library_calls)) external_calls_as_expressions = [x.external_calls_as_expressions for x in self.nodes] external_calls_as_expressions = [x for x in external_calls_as_expressions if x] external_calls_as_expressions = [item for sublist in external_calls_as_expressions for item in sublist] - external_calls_as_expressions = [next(obj) for i, obj in - groupby(sorted(external_calls_as_expressions, key=lambda x: str(x)), lambda x: str(x))] - self._external_calls_as_expressions = external_calls_as_expressions + self._external_calls_as_expressions = list(set(external_calls_as_expressions)) diff --git a/slither/slithir/convert.py b/slither/slithir/convert.py index e09cd2737..23e84db37 100644 --- a/slither/slithir/convert.py +++ b/slither/slithir/convert.py @@ -230,7 +230,7 @@ def propagate_type_and_convert_call(result, node): ins = result[idx] if isinstance(ins, TmpCall): - new_ins = extract_tmp_call(ins) + new_ins = extract_tmp_call(ins, node.function.contract) if new_ins: new_ins.set_node(ins.node) ins = new_ins @@ -323,12 +323,12 @@ def propagate_types(ir, node): t_type = t.type if isinstance(t_type, Contract): contract = node.slither.get_contract_from_name(t_type.name) - return convert_type_of_high_level_call(ir, contract) + return convert_type_of_high_and_internal_level_call(ir, contract) # Convert HighLevelCall to LowLevelCall if isinstance(t, ElementaryType) and t.name == 'address': if ir.destination.name == 'this': - return convert_type_of_high_level_call(ir, node.function.contract) + return convert_type_of_high_and_internal_level_call(ir, node.function.contract) return convert_to_low_level(ir) # Convert push operations @@ -350,6 +350,8 @@ def propagate_types(ir, node): ir.lvalue.set_type(ArrayType(t, length)) elif isinstance(ir, InternalCall): # if its not a tuple, return a singleton + if ir.function is None: + convert_type_of_high_and_internal_level_call(ir, ir.contract) return_type = ir.function.return_type if return_type: if len(return_type) == 1: @@ -435,7 +437,7 @@ def propagate_types(ir, node): logger.error('Not handling {} during type propgation'.format(type(ir))) exit(-1) -def extract_tmp_call(ins): +def extract_tmp_call(ins, contract): assert isinstance(ins, TmpCall) if isinstance(ins.called, Variable) and isinstance(ins.called.type, FunctionType): @@ -443,6 +445,11 @@ def extract_tmp_call(ins): call.call_id = ins.call_id return call if isinstance(ins.ori, Member): + # If there is a call on an inherited contract, it is an internal call + if ins.ori.variable_left in contract.inheritance + [contract]: + internalcall = InternalCall(ins.ori.variable_right, ins.ori.variable_left, ins.nbr_arguments, ins.lvalue, ins.type_call) + internalcall.call_id = ins.call_id + return internalcall if isinstance(ins.ori.variable_left, Contract): st = ins.ori.variable_left.get_structure_from_name(ins.ori.variable_right) if st: @@ -457,7 +464,7 @@ def extract_tmp_call(ins): return msgcall if isinstance(ins.ori, TmpCall): - r = extract_tmp_call(ins.ori) + r = extract_tmp_call(ins.ori, contract) return r if isinstance(ins.called, SolidityVariableComposed): if str(ins.called) == 'block.blockhash': @@ -671,7 +678,7 @@ def convert_type_library_call(ir, lib_contract): ir.lvalue = None return ir -def convert_type_of_high_level_call(ir, contract): +def convert_type_of_high_and_internal_level_call(ir, contract): func = None sigs = get_sig(ir) for sig in sigs: diff --git a/slither/slithir/operations/internal_call.py b/slither/slithir/operations/internal_call.py index 0af0de2ff..5f2210e90 100644 --- a/slither/slithir/operations/internal_call.py +++ b/slither/slithir/operations/internal_call.py @@ -2,14 +2,20 @@ from slither.core.declarations.function import Function from slither.slithir.operations.call import Call from slither.slithir.operations.lvalue import OperationWithLValue from slither.core.variables.variable import Variable - +from slither.slithir.variables import Constant class InternalCall(Call, OperationWithLValue): - def __init__(self, function, nbr_arguments, result, type_call): - assert isinstance(function, Function) + def __init__(self, function, contract, nbr_arguments, result, type_call): super(InternalCall, self).__init__() - self._function = function + if isinstance(function, Function): + self._function = function + self._function_name = function.name + else: + isinstance(function, Constant) + self._function = None + self._function_name = function + self._contract = contract self._nbr_arguments = nbr_arguments self._type_call = type_call self._lvalue = result @@ -22,6 +28,18 @@ class InternalCall(Call, OperationWithLValue): def function(self): return self._function + @function.setter + def function(self, f): + self._function = f + + @property + def contract(self): + return self._contract + + @property + def function_name(self): + return self._function_name + @property def nbr_arguments(self): return self._nbr_arguments diff --git a/slither/slithir/utils/ssa.py b/slither/slithir/utils/ssa.py index 3ef917dbd..74238589d 100644 --- a/slither/slithir/utils/ssa.py +++ b/slither/slithir/utils/ssa.py @@ -565,7 +565,7 @@ def copy_ir(ir, *instances): nbr_arguments = ir.nbr_arguments lvalue = get_variable(ir, lambda x: x.lvalue, *instances) type_call = ir.type_call - new_ir = InternalCall(function, nbr_arguments, lvalue, type_call) + new_ir = InternalCall(function, function.contract, nbr_arguments, lvalue, type_call) new_ir.arguments = get_arguments(ir, *instances) return new_ir elif isinstance(ir, InternalDynamicCall): diff --git a/slither/visitors/slithir/expression_to_slithir.py b/slither/visitors/slithir/expression_to_slithir.py index a1f682f88..239c80ace 100644 --- a/slither/visitors/slithir/expression_to_slithir.py +++ b/slither/visitors/slithir/expression_to_slithir.py @@ -131,7 +131,7 @@ class ExpressionToSlithIR(ExpressionVisitor): val = TupleVariable(self._node) else: val = TemporaryVariable(self._node) - internal_call = InternalCall(called, len(args), val, expression.type_call) + internal_call = InternalCall(called, called.contract, len(args), val, expression.type_call) self._result.append(internal_call) set_val(expression, val) else: