mirror of https://github.com/ConsenSys/mythril
Add rf based tx prioritiser (#1798)
* Add rf based tx prioritiser * Handle older solc version feature extraction gracefullypull/1799/head
parent
ebd7df9601
commit
6fa927e296
@ -0,0 +1 @@ |
|||||||
|
from .rf_prioritiser import RfTxPrioritiser |
@ -0,0 +1,62 @@ |
|||||||
|
import pickle |
||||||
|
from sklearn.ensemble import RandomForestClassifier |
||||||
|
import numpy as np |
||||||
|
from itertools import chain |
||||||
|
|
||||||
|
import logging |
||||||
|
|
||||||
|
log = logging.getLogger(__name__) |
||||||
|
|
||||||
|
|
||||||
|
class RfTxPrioritiser: |
||||||
|
def __init__(self, contract, depth=3, model_path=None): |
||||||
|
self.rf_path = None |
||||||
|
self.contract = contract |
||||||
|
self.depth = depth |
||||||
|
|
||||||
|
with open(model_path, "rb") as file: |
||||||
|
self.model = pickle.load(file) |
||||||
|
if self.contract.features is None: |
||||||
|
log.info( |
||||||
|
"There are no available features. Rf based Tx Prioritisation turned off." |
||||||
|
) |
||||||
|
return None |
||||||
|
self.preprocessed_features = self.preprocess_features(self.contract.features) |
||||||
|
self.recent_predictions = [] |
||||||
|
|
||||||
|
def preprocess_features(self, features_dict): |
||||||
|
flat_features = [] |
||||||
|
for function_name, function_features in features_dict.items(): |
||||||
|
function_features_values = list(function_features.values()) |
||||||
|
flat_features.extend(function_features_values) |
||||||
|
|
||||||
|
return np.array(flat_features).reshape(1, -1) |
||||||
|
|
||||||
|
def __next__(self, address): |
||||||
|
predictions_sequence = [] |
||||||
|
current_features = np.concatenate( |
||||||
|
[ |
||||||
|
self.preprocessed_features, |
||||||
|
np.array(self.recent_predictions).reshape(1, -1), |
||||||
|
], |
||||||
|
axis=1, |
||||||
|
) |
||||||
|
|
||||||
|
for i in range(self.depth): |
||||||
|
predictions = self.model.predict_proba(current_features) |
||||||
|
most_likely_next_tx = np.argmax(predictions, axis=1)[0] |
||||||
|
predictions_sequence.append(most_likely_next_tx) |
||||||
|
current_features = np.concatenate( |
||||||
|
[ |
||||||
|
self.preprocessed_features, |
||||||
|
np.array( |
||||||
|
self.recent_predictions + predictions_sequence[: i + 1] |
||||||
|
).reshape(1, -1), |
||||||
|
], |
||||||
|
axis=1, |
||||||
|
) |
||||||
|
|
||||||
|
self.recent_predictions.extend(predictions_sequence) |
||||||
|
while len(self.recent_predictions) > self.depth: |
||||||
|
self.recent_predictions.pop(0) |
||||||
|
return predictions_sequence |
@ -0,0 +1,59 @@ |
|||||||
|
import pytest |
||||||
|
import numpy as np |
||||||
|
from mythril.laser.ethereum.tx_prioritiser import RfTxPrioritiser |
||||||
|
from unittest.mock import Mock, patch, mock_open |
||||||
|
|
||||||
|
|
||||||
|
def mock_predict_proba(X): |
||||||
|
print(X) |
||||||
|
if X[0][-1] == 1: |
||||||
|
return np.array([[0.1, 0.7, 0.1, 0.1]]) |
||||||
|
elif X[0][-1] == 2: |
||||||
|
return np.array([[0.1, 0.1, 0.7, 0.1]]) |
||||||
|
else: |
||||||
|
return np.array([[0.1, 0.1, 0.1, 0.7]]) |
||||||
|
|
||||||
|
|
||||||
|
class MockSolidityContract: |
||||||
|
def __init__(self, features): |
||||||
|
self.features = features |
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture |
||||||
|
def rftp_instance(): |
||||||
|
contract = MockSolidityContract( |
||||||
|
features={"function1": {"feature1": 1, "feature2": 2}} |
||||||
|
) |
||||||
|
|
||||||
|
with patch("pickle.load") as mock_pickle_load, patch("builtins.open", mock_open()): |
||||||
|
mock_model = Mock() |
||||||
|
mock_model.predict_proba = mock_predict_proba |
||||||
|
mock_pickle_load.return_value = mock_model |
||||||
|
|
||||||
|
rftp = RfTxPrioritiser(contract=contract, model_path="path/to/mock/model.pkl") |
||||||
|
|
||||||
|
return rftp |
||||||
|
|
||||||
|
|
||||||
|
def test_preprocess_features(rftp_instance): |
||||||
|
expected_features = np.array([[1, 2]]) |
||||||
|
assert np.array_equal(rftp_instance.preprocessed_features, expected_features) |
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize( |
||||||
|
"address,previous_predictions,expected_sequence", |
||||||
|
[ |
||||||
|
(1, [], [2, 2, 2]), |
||||||
|
(2, [], [2, 2, 2]), |
||||||
|
(1, [0], [3, 3, 3]), |
||||||
|
(2, [1], [1, 1, 1]), |
||||||
|
(3, [1, 2], [2, 2, 2]), |
||||||
|
(1, [0, 2, 5], [3, 3, 3]), |
||||||
|
], |
||||||
|
) |
||||||
|
def test_next_method(rftp_instance, address, previous_predictions, expected_sequence): |
||||||
|
rftp_instance.recent_predictions = previous_predictions |
||||||
|
predictions_sequence = rftp_instance.__next__(address=address) |
||||||
|
|
||||||
|
assert len(predictions_sequence) == 3 |
||||||
|
assert predictions_sequence == expected_sequence |
Loading…
Reference in new issue