diff --git a/slither/core/cfg/node.py b/slither/core/cfg/node.py index 0115a358b..cdee33684 100644 --- a/slither/core/cfg/node.py +++ b/slither/core/cfg/node.py @@ -15,9 +15,9 @@ from slither.slithir.convert import convert_expression from slither.slithir.operations import (Balance, HighLevelCall, Index, InternalCall, Length, LibraryCall, LowLevelCall, Member, - OperationWithLValue, SolidityCall) + OperationWithLValue, SolidityCall, Phi, PhiCallback) from slither.slithir.variables import (Constant, ReferenceVariable, - TemporaryVariable, TupleVariable) + TemporaryVariable, TupleVariable, StateIRVariable, LocalIRVariable) from slither.visitors.expression.expression_printer import ExpressionPrinter from slither.visitors.expression.read_var import ReadVar from slither.visitors.expression.write_var import WriteVar @@ -125,8 +125,13 @@ class Node(SourceMapping, ChildFunction): self._expression = None self._variable_declaration = None self._node_id = node_id + self._vars_written = [] self._vars_read = [] + + self._ssa_vars_written = [] + self._ssa_vars_read = [] + self._internal_calls = [] self._solidity_calls = [] self._high_level_calls = [] @@ -139,9 +144,15 @@ class Node(SourceMapping, ChildFunction): self._state_vars_read = [] self._solidity_vars_read = [] + self._ssa_state_vars_written = [] + self._ssa_state_vars_read = [] + self._local_vars_read = [] self._local_vars_written = [] + self._ssa_local_vars_read = [] + self._ssa_local_vars_written = [] + self._expression_vars_written = [] self._expression_vars_read = [] self._expression_calls = [] @@ -484,22 +495,22 @@ class Node(SourceMapping, ChildFunction): self._find_read_write_call() + @staticmethod + def _is_slithir_var(var): + return isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) + def _find_read_write_call(self): - def is_slithir_var(var): - return isinstance(var, (Constant, ReferenceVariable, TemporaryVariable, TupleVariable)) for ir in self.irs: - self._vars_read += [v for v in ir.read if not is_slithir_var(v)] + self._vars_read += [v for v in ir.read if not self._is_slithir_var(v)] if isinstance(ir, OperationWithLValue): if isinstance(ir, (Index, Member, Length, Balance)): continue # Don't consider Member and Index operations -> ReferenceVariable var = ir.lvalue - # If its a reference, we loop until finding the origin if isinstance(var, (ReferenceVariable)): - while isinstance(var, ReferenceVariable): - var = var.points_to + var = var.points_to_origin # Only store non-slithIR variables - if not is_slithir_var(var) and var: + if not self._is_slithir_var(var) and var: self._vars_written.append(var) if isinstance(ir, InternalCall): @@ -532,4 +543,51 @@ class Node(SourceMapping, ChildFunction): self._high_level_calls = list(set(self._high_level_calls)) self._low_level_calls = list(set(self._low_level_calls)) + @staticmethod + def _convert_ssa(v): + if isinstance(v, StateIRVariable): + contract = v.contract + non_ssa_var = contract.get_state_variable_from_name(v.name) + return non_ssa_var + assert isinstance(v, LocalIRVariable) + function = v.function + non_ssa_var = function.get_local_variable_from_name(v.name) + return non_ssa_var + + def update_read_write_using_ssa(self): + if not self.expression: + return + for ir in self.irs_ssa: + self._ssa_vars_read += [v for v in ir.read if isinstance(v, + (StateIRVariable, + LocalIRVariable))] + if isinstance(ir, OperationWithLValue): + if isinstance(ir, (Index, Member, Length, Balance)): + continue # Don't consider Member and Index operations -> ReferenceVariable + var = ir.lvalue + if isinstance(var, (ReferenceVariable)): + var = var.points_to_origin + # Only store non-slithIR variables + if var and isinstance(var, (StateIRVariable, LocalIRVariable)): + if isinstance(ir, (PhiCallback)): + continue + self._ssa_vars_written.append(var) + + self._ssa_vars_read = list(set(self._ssa_vars_read)) + self._ssa_state_vars_read = [v for v in self._ssa_vars_read if isinstance(v, StateVariable)] + self._ssa_local_vars_read = [v for v in self._ssa_vars_read if isinstance(v, LocalVariable)] + self._ssa_vars_written = list(set(self._ssa_vars_written)) + self._ssa_state_vars_written = [v for v in self._ssa_vars_written if isinstance(v, StateVariable)] + self._ssa_local_vars_written = [v for v in self._ssa_vars_written if isinstance(v, LocalVariable)] + + vars_read = [self._convert_ssa(x) for x in self._ssa_vars_read] + vars_written = [self._convert_ssa(x) for x in self._ssa_vars_written] + + self._vars_read += [v for v in vars_read if v not in self._vars_read] + self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)] + self._local_vars_read = [v for v in self._vars_read if isinstance(v, LocalVariable)] + + self._vars_written += [v for v in vars_written if v not in self._vars_written] + self._state_vars_written = [v for v in self._vars_written if isinstance(v, StateVariable)] + self._local_vars_written = [v for v in self._vars_written if isinstance(v, LocalVariable)] diff --git a/slither/core/declarations/contract.py b/slither/core/declarations/contract.py index 86ec9f4fc..9994f0046 100644 --- a/slither/core/declarations/contract.py +++ b/slither/core/declarations/contract.py @@ -329,6 +329,10 @@ class Contract(ChildSlither, SourceMapping): 'transferFrom(address,address,uint256)' in full_names and\ 'approve(address,uint256)' in full_names + def update_read_write_using_ssa(self): + for function in self.functions + self.modifiers: + function.update_read_write_using_ssa() + def get_summary(self): """ Return the function summary diff --git a/slither/core/declarations/function.py b/slither/core/declarations/function.py index 7efcb0ace..71fd6ff18 100644 --- a/slither/core/declarations/function.py +++ b/slither/core/declarations/function.py @@ -669,3 +669,13 @@ class Function(ChildContract, SourceMapping): conditional_vars = self.all_conditional_solidity_variables_read() args_vars = self.all_solidity_variables_used_as_args() return SolidityVariableComposed('msg.sender') in conditional_vars + args_vars + + def get_local_variable_from_name(self, variable_name): + """ + Return a local variable from a name + Args: + varible_name (str): name of the variable + Returns: + LocalVariable + """ + return next((v for v in self.variables if v.name == variable_name), None) diff --git a/slither/detectors/reentrancy/reentrancy.py b/slither/detectors/reentrancy/reentrancy.py index cdfafe46b..cb09ddf9f 100644 --- a/slither/detectors/reentrancy/reentrancy.py +++ b/slither/detectors/reentrancy/reentrancy.py @@ -10,7 +10,6 @@ from slither.core.declarations import Function, SolidityFunction from slither.core.expressions import UnaryOperation, UnaryOperationType from slither.detectors.abstract_detector import (AbstractDetector, DetectorClassification) -from slither.visitors.expression.export_values import ExportValues from slither.slithir.operations import (HighLevelCall, LowLevelCall, LibraryCall, Send, Transfer) diff --git a/slither/solc_parsing/declarations/contract.py b/slither/solc_parsing/declarations/contract.py index c35e884f3..f3cb4d8df 100644 --- a/slither/solc_parsing/declarations/contract.py +++ b/slither/solc_parsing/declarations/contract.py @@ -390,5 +390,6 @@ class ContractSolc04(Contract): for func in self.functions + self.modifiers: func.fix_phi(last_state_variables_instances, initial_state_variables_instances) + def __hash__(self): return self._id diff --git a/slither/solc_parsing/declarations/function.py b/slither/solc_parsing/declarations/function.py index fc783b53c..9aa2e2663 100644 --- a/slither/solc_parsing/declarations/function.py +++ b/slither/solc_parsing/declarations/function.py @@ -914,7 +914,11 @@ class FunctionSolc(Function): compute_dominance_frontier(self.nodes) transform_slithir_vars_to_ssa(self) add_ssa_ir(self, all_ssa_state_variables_instances, all_written_state_variables) - + + def update_read_write_using_ssa(self): + for node in self.nodes: + node.update_read_write_using_ssa() + self._analyze_read_write() def split_ternary_node(self, node, condition, true_expr, false_expr): condition_node = self._new_node(NodeType.IF, node.source_mapping) @@ -961,3 +965,4 @@ class FunctionSolc(Function): self._nodes = [n for n in self._nodes if n.node_id != node.node_id] + diff --git a/slither/solc_parsing/slitherSolc.py b/slither/solc_parsing/slitherSolc.py index 91628309e..f8bec0ab8 100644 --- a/slither/solc_parsing/slitherSolc.py +++ b/slither/solc_parsing/slitherSolc.py @@ -315,4 +315,5 @@ class SlitherSolc(Slither): contract.convert_expression_to_slithir() for contract in self.contracts: contract.fix_phi() + contract.update_read_write_using_ssa()