Add rf based tx prioritiser (#1798)

* Add rf based tx prioritiser

* Handle older solc version feature extraction gracefully
pull/1799/head
Nikhil Parasaram 1 year ago committed by GitHub
parent ebd7df9601
commit 6fa927e296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 8
      mythril/analysis/symbolic.py
  2. 38
      mythril/laser/ethereum/svm.py
  3. 1
      mythril/laser/ethereum/tx_prioritiser/__init__.py
  4. 62
      mythril/laser/ethereum/tx_prioritiser/rf_prioritiser.py
  5. 4
      mythril/solidity/features.py
  6. 8
      mythril/solidity/soliditycontract.py
  7. 1
      mythril/support/support_args.py
  8. 2
      requirements.txt
  9. 59
      tests/laser/tx_prioritisation_test.py

@ -17,7 +17,7 @@ from mythril.laser.ethereum.strategy.constraint_strategy import DelayConstraintS
from mythril.laser.ethereum.strategy.beam import BeamSearch
from mythril.laser.ethereum.natives import PRECOMPILE_COUNT
from mythril.laser.ethereum.transaction.symbolic import ACTORS
from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser
from mythril.laser.plugin.loader import LaserPluginLoader
from mythril.laser.plugin.plugins import (
@ -100,6 +100,11 @@ class SymExecWrapper:
else:
raise ValueError("Invalid strategy argument supplied")
if args.incremental_txs is False:
tx_strategy = RfTxPrioritiser(contract)
else:
tx_strategy = None
creator_account = Account(
hex(ACTORS.creator.value), "", dynamic_loader=None, contract_name=None
)
@ -128,6 +133,7 @@ class SymExecWrapper:
transaction_count=transaction_count,
requires_statespace=requires_statespace,
beam_width=beam_width,
tx_strategy=tx_strategy,
)
if loop_bound is not None:

@ -63,6 +63,7 @@ class LaserEVM:
iprof=None,
use_reachability_check=True,
beam_width=None,
tx_strategy=None,
) -> None:
"""
Initializes the laser evm object
@ -75,6 +76,9 @@ class LaserEVM:
:param transaction_count: The amount of transactions to execute
:param requires_statespace: Variable indicating whether the statespace should be recorded
:param iprof: Instruction Profiler
:param use_reachability_check: Runs reachability check by solving constraints
:param beam_width: The beam width for search strategies
:param tx_strategy: A global tx prioritisation strategy
"""
self.execution_info: List[ExecutionInfo] = []
@ -87,6 +91,7 @@ class LaserEVM:
self.strategy = strategy(self.work_list, max_depth, beam_width=beam_width)
self.max_depth = max_depth
self.transaction_count = transaction_count
self.tx_strategy = tx_strategy
self.execution_timeout = execution_timeout or 0
self.create_timeout = create_timeout or 0
@ -221,20 +226,35 @@ class LaserEVM:
"""
for hook in self._start_exec_trans_hooks:
hook()
if self.executed_transactions is False:
self._execute_transactions(address)
if self.tx_strategy is None:
if self.executed_transactions is False:
self.time = datetime.now()
self._execute_transactions_incremental(
address, txs=args.transaction_sequences
)
else:
self.time = datetime.now()
self._execute_transactions_non_ordered(address)
for hook in self._stop_exec_trans_hooks:
hook()
def _execute_transactions(self, address):
"""This function executes multiple transactions on the address
def _execute_transactions_ordered(self, address):
"""
This function executes multiple transactions non-incrementally, using some type priority ordering
:param address: Address of the contract
:return:
"""
for txs in self.strategy:
log.info(f"Executing the sequence: {txs}")
self._execute_transactions_incremental(address, txs=tx)
def _execute_transactions_incremental(self, address, txs=None):
"""This function executes multiple transactions incrementally on the address
:param address: Address of the contract
:return:
"""
self.time = datetime.now()
for i in range(self.transaction_count):
if len(self.open_states) == 0:
@ -267,9 +287,7 @@ class LaserEVM:
i, len(self.open_states)
)
)
func_hashes = (
args.transaction_sequences[i] if args.transaction_sequences else None
)
func_hashes = txs[i] if txs else None
if func_hashes:
for itr, func_hash in enumerate(func_hashes):

@ -0,0 +1 @@
from .rf_prioritiser import RfTxPrioritiser

@ -0,0 +1,62 @@
import pickle
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from itertools import chain
import logging
log = logging.getLogger(__name__)
class RfTxPrioritiser:
def __init__(self, contract, depth=3, model_path=None):
self.rf_path = None
self.contract = contract
self.depth = depth
with open(model_path, "rb") as file:
self.model = pickle.load(file)
if self.contract.features is None:
log.info(
"There are no available features. Rf based Tx Prioritisation turned off."
)
return None
self.preprocessed_features = self.preprocess_features(self.contract.features)
self.recent_predictions = []
def preprocess_features(self, features_dict):
flat_features = []
for function_name, function_features in features_dict.items():
function_features_values = list(function_features.values())
flat_features.extend(function_features_values)
return np.array(flat_features).reshape(1, -1)
def __next__(self, address):
predictions_sequence = []
current_features = np.concatenate(
[
self.preprocessed_features,
np.array(self.recent_predictions).reshape(1, -1),
],
axis=1,
)
for i in range(self.depth):
predictions = self.model.predict_proba(current_features)
most_likely_next_tx = np.argmax(predictions, axis=1)[0]
predictions_sequence.append(most_likely_next_tx)
current_features = np.concatenate(
[
self.preprocessed_features,
np.array(
self.recent_predictions + predictions_sequence[: i + 1]
).reshape(1, -1),
],
axis=1,
)
self.recent_predictions.extend(predictions_sequence)
while len(self.recent_predictions) > self.depth:
self.recent_predictions.pop(0)
return predictions_sequence

@ -17,8 +17,6 @@ class SolidityFeatureExtractor:
self.find_variables_in_if(modifier_node)
)
print(modifier_vars)
for node in function_nodes:
function_name = self.get_function_name(node)
contains_selfdestruct = self.contains_selfdestruct(node)
@ -28,7 +26,7 @@ class SolidityFeatureExtractor:
contains_staticcall = self.contains_staticcall(node)
all_require_vars = self.find_variables_in_require(node)
ether_vars = self.extract_address_variable(node)
print(ether_vars)
for modifier in node.get("modifiers", []):
all_require_vars.update(modifier_vars[modifier["modifierName"]["name"]])
is_payable = self.is_function_payable(node)

@ -7,6 +7,7 @@ import mythril.laser.ethereum.util as helper
from mythril.ethereum.evmcontract import EVMContract
from mythril.ethereum.util import get_solc_json
from mythril.exceptions import NoContractFoundError
from mythril.solidity.features import SolidityFeatureExtractor
log = logging.getLogger(__name__)
@ -189,6 +190,13 @@ class SolidityContract(EVMContract):
self.solc_indices = self.get_solc_indices(input_file, data)
self.solc_json = data
self.input_file = input_file
if "ast" in data["sources"][str(input_file)]:
# Not available in old solidity versions, around ~0.4.11
self.features = SolidityFeatureExtractor(
data["sources"][str(input_file)]["ast"]
).extract_features()
else:
self.features = None
has_contract = False
# If a contract name has been specified, find the bytecode of that specific contract

@ -22,6 +22,7 @@ class Args(object, metaclass=Singleton):
self.solc_args = None
self.disable_coverage_strategy = False
self.disable_mutation_pruner = False
self.incremental_txs = True
args = Args()

@ -20,6 +20,7 @@ jinja2>=2.9
MarkupSafe<2.1.0
mock
mypy-extensions<1.0.0
numpy
persistent>=4.2.0
py-flags
py-evm==0.5.0a1
@ -32,6 +33,7 @@ pytest_mock
requests
rlp<3
semantic_version
scikit-learn
transaction>=2.2.1
typing-extensions<4,>=3.7.4
z3-solver<4.12.2.0,>=4.8.8.0

@ -0,0 +1,59 @@
import pytest
import numpy as np
from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser
from unittest.mock import Mock, patch, mock_open
def mock_predict_proba(X):
print(X)
if X[0][-1] == 1:
return np.array([[0.1, 0.7, 0.1, 0.1]])
elif X[0][-1] == 2:
return np.array([[0.1, 0.1, 0.7, 0.1]])
else:
return np.array([[0.1, 0.1, 0.1, 0.7]])
class MockSolidityContract:
def __init__(self, features):
self.features = features
@pytest.fixture
def rftp_instance():
contract = MockSolidityContract(
features={"function1": {"feature1": 1, "feature2": 2}}
)
with patch("pickle.load") as mock_pickle_load, patch("builtins.open", mock_open()):
mock_model = Mock()
mock_model.predict_proba = mock_predict_proba
mock_pickle_load.return_value = mock_model
rftp = RfTxPrioritiser(contract=contract, model_path="path/to/mock/model.pkl")
return rftp
def test_preprocess_features(rftp_instance):
expected_features = np.array([[1, 2]])
assert np.array_equal(rftp_instance.preprocessed_features, expected_features)
@pytest.mark.parametrize(
"address,previous_predictions,expected_sequence",
[
(1, [], [2, 2, 2]),
(2, [], [2, 2, 2]),
(1, [0], [3, 3, 3]),
(2, [1], [1, 1, 1]),
(3, [1, 2], [2, 2, 2]),
(1, [0, 2, 5], [3, 3, 3]),
],
)
def test_next_method(rftp_instance, address, previous_predictions, expected_sequence):
rftp_instance.recent_predictions = previous_predictions
predictions_sequence = rftp_instance.__next__(address=address)
assert len(predictions_sequence) == 3
assert predictions_sequence == expected_sequence
Loading…
Cancel
Save