mirror of https://github.com/ConsenSys/mythril
Add rf based tx prioritiser (#1798)
* Add rf based tx prioritiser * Handle older solc version feature extraction gracefullypull/1799/head
parent
ebd7df9601
commit
6fa927e296
@ -0,0 +1 @@ |
||||
from .rf_prioritiser import RfTxPrioritiser |
@ -0,0 +1,62 @@ |
||||
import pickle |
||||
from sklearn.ensemble import RandomForestClassifier |
||||
import numpy as np |
||||
from itertools import chain |
||||
|
||||
import logging |
||||
|
||||
log = logging.getLogger(__name__) |
||||
|
||||
|
||||
class RfTxPrioritiser: |
||||
def __init__(self, contract, depth=3, model_path=None): |
||||
self.rf_path = None |
||||
self.contract = contract |
||||
self.depth = depth |
||||
|
||||
with open(model_path, "rb") as file: |
||||
self.model = pickle.load(file) |
||||
if self.contract.features is None: |
||||
log.info( |
||||
"There are no available features. Rf based Tx Prioritisation turned off." |
||||
) |
||||
return None |
||||
self.preprocessed_features = self.preprocess_features(self.contract.features) |
||||
self.recent_predictions = [] |
||||
|
||||
def preprocess_features(self, features_dict): |
||||
flat_features = [] |
||||
for function_name, function_features in features_dict.items(): |
||||
function_features_values = list(function_features.values()) |
||||
flat_features.extend(function_features_values) |
||||
|
||||
return np.array(flat_features).reshape(1, -1) |
||||
|
||||
def __next__(self, address): |
||||
predictions_sequence = [] |
||||
current_features = np.concatenate( |
||||
[ |
||||
self.preprocessed_features, |
||||
np.array(self.recent_predictions).reshape(1, -1), |
||||
], |
||||
axis=1, |
||||
) |
||||
|
||||
for i in range(self.depth): |
||||
predictions = self.model.predict_proba(current_features) |
||||
most_likely_next_tx = np.argmax(predictions, axis=1)[0] |
||||
predictions_sequence.append(most_likely_next_tx) |
||||
current_features = np.concatenate( |
||||
[ |
||||
self.preprocessed_features, |
||||
np.array( |
||||
self.recent_predictions + predictions_sequence[: i + 1] |
||||
).reshape(1, -1), |
||||
], |
||||
axis=1, |
||||
) |
||||
|
||||
self.recent_predictions.extend(predictions_sequence) |
||||
while len(self.recent_predictions) > self.depth: |
||||
self.recent_predictions.pop(0) |
||||
return predictions_sequence |
@ -0,0 +1,59 @@ |
||||
import pytest |
||||
import numpy as np |
||||
from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser |
||||
from unittest.mock import Mock, patch, mock_open |
||||
|
||||
|
||||
def mock_predict_proba(X): |
||||
print(X) |
||||
if X[0][-1] == 1: |
||||
return np.array([[0.1, 0.7, 0.1, 0.1]]) |
||||
elif X[0][-1] == 2: |
||||
return np.array([[0.1, 0.1, 0.7, 0.1]]) |
||||
else: |
||||
return np.array([[0.1, 0.1, 0.1, 0.7]]) |
||||
|
||||
|
||||
class MockSolidityContract: |
||||
def __init__(self, features): |
||||
self.features = features |
||||
|
||||
|
||||
@pytest.fixture |
||||
def rftp_instance(): |
||||
contract = MockSolidityContract( |
||||
features={"function1": {"feature1": 1, "feature2": 2}} |
||||
) |
||||
|
||||
with patch("pickle.load") as mock_pickle_load, patch("builtins.open", mock_open()): |
||||
mock_model = Mock() |
||||
mock_model.predict_proba = mock_predict_proba |
||||
mock_pickle_load.return_value = mock_model |
||||
|
||||
rftp = RfTxPrioritiser(contract=contract, model_path="path/to/mock/model.pkl") |
||||
|
||||
return rftp |
||||
|
||||
|
||||
def test_preprocess_features(rftp_instance): |
||||
expected_features = np.array([[1, 2]]) |
||||
assert np.array_equal(rftp_instance.preprocessed_features, expected_features) |
||||
|
||||
|
||||
@pytest.mark.parametrize( |
||||
"address,previous_predictions,expected_sequence", |
||||
[ |
||||
(1, [], [2, 2, 2]), |
||||
(2, [], [2, 2, 2]), |
||||
(1, [0], [3, 3, 3]), |
||||
(2, [1], [1, 1, 1]), |
||||
(3, [1, 2], [2, 2, 2]), |
||||
(1, [0, 2, 5], [3, 3, 3]), |
||||
], |
||||
) |
||||
def test_next_method(rftp_instance, address, previous_predictions, expected_sequence): |
||||
rftp_instance.recent_predictions = previous_predictions |
||||
predictions_sequence = rftp_instance.__next__(address=address) |
||||
|
||||
assert len(predictions_sequence) == 3 |
||||
assert predictions_sequence == expected_sequence |
Loading…
Reference in new issue