From cd1ec36dadc82ee81824ce53b00aa8c92e4fed11 Mon Sep 17 00:00:00 2001 From: alpharush <0xalpharush@protonmail.com> Date: Thu, 16 Mar 2023 14:25:23 -0500 Subject: [PATCH] include salt in NewContract op, add test infra --- .github/workflows/features.yml | 1 + slither/slithir/operations/high_level_call.py | 2 +- slither/slithir/operations/new_contract.py | 4 +- tests/slithir/operation_reads.sol | 17 +++++++ tests/slithir/test_operation_reads.py | 49 +++++++++++++++++++ 5 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 tests/slithir/operation_reads.sol create mode 100644 tests/slithir/test_operation_reads.py diff --git a/.github/workflows/features.yml b/.github/workflows/features.yml index b11a137d9..2c112e0aa 100644 --- a/.github/workflows/features.yml +++ b/.github/workflows/features.yml @@ -50,6 +50,7 @@ jobs: pytest tests/test_features.py pytest tests/test_constant_folding.py pytest tests/slithir/test_ternary_expressions.py + pytest tests/slithir/test_operation_reads.py pytest tests/test_functions_ids.py pytest tests/test_function.py pytest tests/test_source_mapping.py diff --git a/slither/slithir/operations/high_level_call.py b/slither/slithir/operations/high_level_call.py index 93fb73bd4..a12917519 100644 --- a/slither/slithir/operations/high_level_call.py +++ b/slither/slithir/operations/high_level_call.py @@ -76,7 +76,7 @@ class HighLevelCall(Call, OperationWithLValue): def read(self) -> List[SourceMapping]: all_read = [self.destination, self.call_gas, self.call_value] + self._unroll(self.arguments) # remove None - return [x for x in all_read if x] + [self.destination] + return [x for x in all_read if x] @property def destination(self) -> SourceMapping: diff --git a/slither/slithir/operations/new_contract.py b/slither/slithir/operations/new_contract.py index 879d12df6..08ddbd960 100644 --- a/slither/slithir/operations/new_contract.py +++ b/slither/slithir/operations/new_contract.py @@ -52,7 +52,9 @@ class NewContract(Call, OperationWithLValue): # pylint: disable=too-many-instan @property def read(self) -> List[Any]: - return self._unroll(self.arguments) + all_read = [self.call_salt, self.call_value] + self._unroll(self.arguments) + # remove None + return [x for x in all_read if x] @property def contract_created(self) -> Contract: diff --git a/tests/slithir/operation_reads.sol b/tests/slithir/operation_reads.sol new file mode 100644 index 000000000..22adc2288 --- /dev/null +++ b/tests/slithir/operation_reads.sol @@ -0,0 +1,17 @@ + +contract Placeholder { + constructor() payable {} +} + +contract NewContract { + bytes32 internal constant state_variable_read = bytes32(0); + + function readAllStateVariables() external { + new Placeholder{salt: state_variable_read} (); + } + + function readAllLocalVariables() external { + bytes32 local_variable_read = bytes32(0); + new Placeholder{salt: local_variable_read} (); + } +} \ No newline at end of file diff --git a/tests/slithir/test_operation_reads.py b/tests/slithir/test_operation_reads.py new file mode 100644 index 000000000..aa183333f --- /dev/null +++ b/tests/slithir/test_operation_reads.py @@ -0,0 +1,49 @@ +from collections import namedtuple +from slither import Slither +from slither.slithir.operations import Operation, NewContract + + +def check_num_local_vars_read(function, slithir_op: Operation, num_reads_expected: int): + for node in function.nodes: + for operation in node.irs: + if isinstance(operation, slithir_op): + assert len(operation.read) == num_reads_expected + assert len(node.local_variables_read) == num_reads_expected + + +def check_num_states_vars_read(function, slithir_op: Operation, num_reads_expected: int): + for node in function.nodes: + for operation in node.irs: + if isinstance(operation, slithir_op): + assert len(operation.read) == num_reads_expected + assert len(node.state_variables_read) == num_reads_expected + + +OperationTest = namedtuple("OperationTest", "contract_name slithir_op") + +OPERATION_TEST = [OperationTest("NewContract", NewContract)] + + +def test_operation_reads() -> None: + """ + Every slithir operation has its own contract and reads all local and state variables in readAllLocalVariables and readAllStateVariables, respectively. + """ + slither = Slither("./tests/slithir/operation_reads.sol") + + for op_test in OPERATION_TEST: + print(op_test) + available = slither.get_contract_from_name(op_test.contract_name) + assert len(available) == 1 + target = available[0] + + num_state_variables = len(target.state_variables_ordered) + state_function = target.get_function_from_signature("readAllStateVariables()") + check_num_states_vars_read(state_function, op_test.slithir_op, num_state_variables) + + local_function = target.get_function_from_signature("readAllLocalVariables()") + num_local_vars = len(local_function.local_variables) + check_num_local_vars_read(local_function, op_test.slithir_op, num_local_vars) + + +if __name__ == "__main__": + test_operation_reads()