diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index 6aeed770..561a75dd 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -17,7 +17,7 @@ from mythril.laser.ethereum.strategy.constraint_strategy import DelayConstraintS from mythril.laser.ethereum.strategy.beam import BeamSearch from mythril.laser.ethereum.natives import PRECOMPILE_COUNT from mythril.laser.ethereum.transaction.symbolic import ACTORS - +from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser from mythril.laser.plugin.loader import LaserPluginLoader from mythril.laser.plugin.plugins import ( @@ -100,6 +100,11 @@ class SymExecWrapper: else: raise ValueError("Invalid strategy argument supplied") + if args.incremental_txs is False: + tx_strategy = RfTxPrioritiser(contract) + else: + tx_strategy = None + creator_account = Account( hex(ACTORS.creator.value), "", dynamic_loader=None, contract_name=None ) @@ -128,6 +133,7 @@ class SymExecWrapper: transaction_count=transaction_count, requires_statespace=requires_statespace, beam_width=beam_width, + tx_strategy=tx_strategy, ) if loop_bound is not None: diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index a4faedaa..77a1acbe 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -63,6 +63,7 @@ class LaserEVM: iprof=None, use_reachability_check=True, beam_width=None, + tx_strategy=None, ) -> None: """ Initializes the laser evm object @@ -75,6 +76,9 @@ class LaserEVM: :param transaction_count: The amount of transactions to execute :param requires_statespace: Variable indicating whether the statespace should be recorded :param iprof: Instruction Profiler + :param use_reachability_check: Runs reachability check by solving constraints + :param beam_width: The beam width for search strategies + :param tx_strategy: A global tx prioritisation strategy """ self.execution_info: List[ExecutionInfo] = [] @@ -87,6 +91,7 @@ class LaserEVM: self.strategy = strategy(self.work_list, max_depth, beam_width=beam_width) self.max_depth = max_depth self.transaction_count = transaction_count + self.tx_strategy = tx_strategy self.execution_timeout = execution_timeout or 0 self.create_timeout = create_timeout or 0 @@ -221,20 +226,35 @@ class LaserEVM: """ for hook in self._start_exec_trans_hooks: hook() - - if self.executed_transactions is False: - self._execute_transactions(address) - + if self.tx_strategy is None: + if self.executed_transactions is False: + self.time = datetime.now() + self._execute_transactions_incremental( + address, txs=args.transaction_sequences + ) + else: + self.time = datetime.now() + self._execute_transactions_non_ordered(address) for hook in self._stop_exec_trans_hooks: hook() - def _execute_transactions(self, address): - """This function executes multiple transactions on the address + def _execute_transactions_ordered(self, address): + """ + This function executes multiple transactions non-incrementally, using some type priority ordering + + :param address: Address of the contract + :return: + """ + for txs in self.strategy: + log.info(f"Executing the sequence: {txs}") + self._execute_transactions_incremental(address, txs=tx) + + def _execute_transactions_incremental(self, address, txs=None): + """This function executes multiple transactions incrementally on the address :param address: Address of the contract :return: """ - self.time = datetime.now() for i in range(self.transaction_count): if len(self.open_states) == 0: @@ -267,9 +287,7 @@ class LaserEVM: i, len(self.open_states) ) ) - func_hashes = ( - args.transaction_sequences[i] if args.transaction_sequences else None - ) + func_hashes = txs[i] if txs else None if func_hashes: for itr, func_hash in enumerate(func_hashes): diff --git a/mythril/laser/ethereum/tx_prioritiser/__init__.py b/mythril/laser/ethereum/tx_prioritiser/__init__.py new file mode 100644 index 00000000..2c078ed7 --- /dev/null +++ b/mythril/laser/ethereum/tx_prioritiser/__init__.py @@ -0,0 +1 @@ +from .rf_prioritiser import RfTxPrioritiser diff --git a/mythril/laser/ethereum/tx_prioritiser/rf_prioritiser.py b/mythril/laser/ethereum/tx_prioritiser/rf_prioritiser.py new file mode 100644 index 00000000..0820e62c --- /dev/null +++ b/mythril/laser/ethereum/tx_prioritiser/rf_prioritiser.py @@ -0,0 +1,62 @@ +import pickle +from sklearn.ensemble import RandomForestClassifier +import numpy as np +from itertools import chain + +import logging + +log = logging.getLogger(__name__) + + +class RfTxPrioritiser: + def __init__(self, contract, depth=3, model_path=None): + self.rf_path = None + self.contract = contract + self.depth = depth + + with open(model_path, "rb") as file: + self.model = pickle.load(file) + if self.contract.features is None: + log.info( + "There are no available features. Rf based Tx Prioritisation turned off." + ) + return None + self.preprocessed_features = self.preprocess_features(self.contract.features) + self.recent_predictions = [] + + def preprocess_features(self, features_dict): + flat_features = [] + for function_name, function_features in features_dict.items(): + function_features_values = list(function_features.values()) + flat_features.extend(function_features_values) + + return np.array(flat_features).reshape(1, -1) + + def __next__(self, address): + predictions_sequence = [] + current_features = np.concatenate( + [ + self.preprocessed_features, + np.array(self.recent_predictions).reshape(1, -1), + ], + axis=1, + ) + + for i in range(self.depth): + predictions = self.model.predict_proba(current_features) + most_likely_next_tx = np.argmax(predictions, axis=1)[0] + predictions_sequence.append(most_likely_next_tx) + current_features = np.concatenate( + [ + self.preprocessed_features, + np.array( + self.recent_predictions + predictions_sequence[: i + 1] + ).reshape(1, -1), + ], + axis=1, + ) + + self.recent_predictions.extend(predictions_sequence) + while len(self.recent_predictions) > self.depth: + self.recent_predictions.pop(0) + return predictions_sequence diff --git a/mythril/solidity/features.py b/mythril/solidity/features.py index 5fae7cda..afe140f0 100644 --- a/mythril/solidity/features.py +++ b/mythril/solidity/features.py @@ -17,8 +17,6 @@ class SolidityFeatureExtractor: self.find_variables_in_if(modifier_node) ) - print(modifier_vars) - for node in function_nodes: function_name = self.get_function_name(node) contains_selfdestruct = self.contains_selfdestruct(node) @@ -28,7 +26,7 @@ class SolidityFeatureExtractor: contains_staticcall = self.contains_staticcall(node) all_require_vars = self.find_variables_in_require(node) ether_vars = self.extract_address_variable(node) - print(ether_vars) + for modifier in node.get("modifiers", []): all_require_vars.update(modifier_vars[modifier["modifierName"]["name"]]) is_payable = self.is_function_payable(node) diff --git a/mythril/solidity/soliditycontract.py b/mythril/solidity/soliditycontract.py index 06dc8b22..a7f368ff 100644 --- a/mythril/solidity/soliditycontract.py +++ b/mythril/solidity/soliditycontract.py @@ -7,6 +7,7 @@ import mythril.laser.ethereum.util as helper from mythril.ethereum.evmcontract import EVMContract from mythril.ethereum.util import get_solc_json from mythril.exceptions import NoContractFoundError +from mythril.solidity.features import SolidityFeatureExtractor log = logging.getLogger(__name__) @@ -189,6 +190,13 @@ class SolidityContract(EVMContract): self.solc_indices = self.get_solc_indices(input_file, data) self.solc_json = data self.input_file = input_file + if "ast" in data["sources"][str(input_file)]: + # Not available in old solidity versions, around ~0.4.11 + self.features = SolidityFeatureExtractor( + data["sources"][str(input_file)]["ast"] + ).extract_features() + else: + self.features = None has_contract = False # If a contract name has been specified, find the bytecode of that specific contract diff --git a/mythril/support/support_args.py b/mythril/support/support_args.py index 5722a8e2..9add0cd9 100644 --- a/mythril/support/support_args.py +++ b/mythril/support/support_args.py @@ -22,6 +22,7 @@ class Args(object, metaclass=Singleton): self.solc_args = None self.disable_coverage_strategy = False self.disable_mutation_pruner = False + self.incremental_txs = True args = Args() diff --git a/requirements.txt b/requirements.txt index 7180de1e..83c8c450 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ jinja2>=2.9 MarkupSafe<2.1.0 mock mypy-extensions<1.0.0 +numpy persistent>=4.2.0 py-flags py-evm==0.5.0a1 @@ -32,6 +33,7 @@ pytest_mock requests rlp<3 semantic_version +scikit-learn transaction>=2.2.1 typing-extensions<4,>=3.7.4 z3-solver<4.12.2.0,>=4.8.8.0 diff --git a/tests/laser/tx_prioritisation_test.py b/tests/laser/tx_prioritisation_test.py new file mode 100644 index 00000000..db7bd297 --- /dev/null +++ b/tests/laser/tx_prioritisation_test.py @@ -0,0 +1,59 @@ +import pytest +import numpy as np +from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser +from unittest.mock import Mock, patch, mock_open + + +def mock_predict_proba(X): + print(X) + if X[0][-1] == 1: + return np.array([[0.1, 0.7, 0.1, 0.1]]) + elif X[0][-1] == 2: + return np.array([[0.1, 0.1, 0.7, 0.1]]) + else: + return np.array([[0.1, 0.1, 0.1, 0.7]]) + + +class MockSolidityContract: + def __init__(self, features): + self.features = features + + +@pytest.fixture +def rftp_instance(): + contract = MockSolidityContract( + features={"function1": {"feature1": 1, "feature2": 2}} + ) + + with patch("pickle.load") as mock_pickle_load, patch("builtins.open", mock_open()): + mock_model = Mock() + mock_model.predict_proba = mock_predict_proba + mock_pickle_load.return_value = mock_model + + rftp = RfTxPrioritiser(contract=contract, model_path="path/to/mock/model.pkl") + + return rftp + + +def test_preprocess_features(rftp_instance): + expected_features = np.array([[1, 2]]) + assert np.array_equal(rftp_instance.preprocessed_features, expected_features) + + +@pytest.mark.parametrize( + "address,previous_predictions,expected_sequence", + [ + (1, [], [2, 2, 2]), + (2, [], [2, 2, 2]), + (1, [0], [3, 3, 3]), + (2, [1], [1, 1, 1]), + (3, [1, 2], [2, 2, 2]), + (1, [0, 2, 5], [3, 3, 3]), + ], +) +def test_next_method(rftp_instance, address, previous_predictions, expected_sequence): + rftp_instance.recent_predictions = previous_predictions + predictions_sequence = rftp_instance.__next__(address=address) + + assert len(predictions_sequence) == 3 + assert predictions_sequence == expected_sequence