Using cache search strategy (#1714)

* Fix JSON serialisation issue

* Add cache

* Add tests

* Remove a debug statement
pull/1718/head
Nikhil Parasaram 2 years ago committed by GitHub
parent 2fb67f66fa
commit 0fac8e3d6d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      mythril/analysis/symbolic.py
  2. 4
      mythril/interfaces/cli.py
  3. 12
      mythril/laser/ethereum/state/constraints.py
  4. 5
      mythril/laser/ethereum/strategy/__init__.py
  5. 64
      mythril/laser/ethereum/strategy/basic.py
  6. 11
      mythril/laser/ethereum/strategy/beam.py
  7. 47
      mythril/laser/ethereum/strategy/constraint_strategy.py
  8. 6
      mythril/laser/ethereum/svm.py
  9. 18
      mythril/support/model.py
  10. 46
      mythril/support/support_utils.py
  11. 14
      tests/integration_tests/analysis_tests.py

@ -13,6 +13,7 @@ from mythril.laser.ethereum.strategy.basic import (
ReturnWeightedRandomStrategy,
BasicSearchStrategy,
)
from mythril.laser.ethereum.strategy.constraint_strategy import DelayConstraintStrategy
from mythril.laser.ethereum.strategy.beam import BeamSearch
from mythril.laser.ethereum.natives import PRECOMPILE_COUNT
from mythril.laser.ethereum.transaction.symbolic import ACTORS
@ -94,6 +95,8 @@ class SymExecWrapper:
elif "beam-search: " in strategy:
beam_width = int(strategy.split("beam-search: ")[1])
s_strategy = BeamSearch
elif "delayed" in strategy:
s_strategy = DelayConstraintStrategy
else:
raise ValueError("Invalid strategy argument supplied")

@ -449,7 +449,7 @@ def add_analysis_args(options):
options.add_argument(
"--strategy",
choices=["dfs", "bfs", "naive-random", "weighted-random"],
choices=["dfs", "bfs", "naive-random", "weighted-random", "delayed"],
default="bfs",
help="Symbolic execution strategy",
)
@ -622,7 +622,7 @@ def validate_args(args: Namespace):
args.outform,
"The transaction sequence is in incorrect format, It should be "
"[list of possible function hashes in 1st transaction, "
"list of possible func hashes in 2nd tx, ..]"
"list of possible func hashes in 2nd tx, ...] "
"If any list is empty then all possible functions are considered for that transaction",
)
if len(args.transaction_sequences) != args.transaction_count:

@ -42,6 +42,18 @@ class Constraints(list):
return False
return True
def get_model(self, solver_timeout=None) -> bool:
"""
:param solver_timeout: The default timeout uses analysis timeout from args.solver_timeout
:return: True/False based on the existence of solution of constraints
"""
try:
return get_model(self, solver_timeout=solver_timeout)
except SolverTimeOutException:
return None
except UnsatError:
return None
def append(self, constraint: Union[bool, Bool]) -> None:
"""

@ -20,6 +20,9 @@ class BasicSearchStrategy(ABC):
""""""
raise NotImplementedError("Must be implemented by a subclass")
def run_check(self):
return True
def __next__(self):
try:
global_state = self.get_strategic_global_state()
@ -43,7 +46,7 @@ class CriterionSearchStrategy(BasicSearchStrategy):
if self._satisfied_criterion:
raise StopIteration
try:
global_state = self.get_strategic_global_state()
return self.get_strategic_global_state()
except StopIteration:
raise StopIteration

@ -20,6 +20,13 @@ class DepthFirstSearchStrategy(BasicSearchStrategy):
"""
return self.work_list.pop()
def view_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
return self.work_list[-1]
class BreadthFirstSearchStrategy(BasicSearchStrategy):
"""Implements a breadth first search strategy I.E.
@ -34,17 +41,44 @@ class BreadthFirstSearchStrategy(BasicSearchStrategy):
"""
return self.work_list.pop(0)
def view_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
return self.work_list[0]
class ReturnRandomNaivelyStrategy(BasicSearchStrategy):
"""chooses a random state from the worklist with equal likelihood."""
def __init__(self, work_list, max_depth, **kwargs):
super().__init__(work_list, max_depth, **kwargs)
self.previous_random_value = -1
def get_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
if len(self.work_list) > 0:
return self.work_list.pop(randrange(len(self.work_list)))
if self.previous_random_value == -1:
return self.work_list.pop(randrange(len(self.work_list)))
else:
new_state = self.work_list.pop(self.previous_random_value)
self.previous_random_value = -1
return new_state
else:
raise IndexError
def view_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
if len(self.work_list) > 0:
self.previous_random_value = randrange(len(self.work_list))
return self.work_list[self.previous_random_value]
else:
raise IndexError
@ -53,6 +87,10 @@ class ReturnWeightedRandomStrategy(BasicSearchStrategy):
"""chooses a random state from the worklist with likelihood based on
inverse proportion to depth."""
def __init__(self, work_list, max_depth, **kwargs):
super().__init__(work_list, max_depth, **kwargs)
self.previous_random_value = -1
def get_strategic_global_state(self) -> GlobalState:
"""
@ -61,6 +99,24 @@ class ReturnWeightedRandomStrategy(BasicSearchStrategy):
probability_distribution = [
1 / (global_state.mstate.depth + 1) for global_state in self.work_list
]
return self.work_list.pop(
choices(range(len(self.work_list)), probability_distribution)[0]
)
if self.previous_random_value != -1:
ns = self.work_list.pop(self.previous_random_value)
self.previous_random_value = -1
return ns
else:
return self.work_list.pop(
choices(range(len(self.work_list)), probability_distribution)[0]
)
def view_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
probability_distribution = [
1 / (global_state.mstate.depth + 1) for global_state in self.work_list
]
self.previous_random_value = choices(
range(len(self.work_list)), probability_distribution
)[0]
return self.work_list[self.previous_random_value]

@ -19,6 +19,17 @@ class BeamSearch(BasicSearchStrategy):
self.work_list.sort(key=lambda state: self.beam_priority(state), reverse=True)
del self.work_list[self.beam_width :]
def view_strategic_global_state(self) -> GlobalState:
"""
:return:
"""
self.sort_and_eliminate_states()
if len(self.work_list) > 0:
return self.work_list[0]
else:
raise IndexError
def get_strategic_global_state(self) -> GlobalState:
"""

@ -0,0 +1,47 @@
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.strategy.basic import BasicSearchStrategy
from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum.transaction import ContractCreationTransaction
from mythril.laser.smt import And, simplify
from mythril.support.support_utils import ModelCache
from typing import Dict, cast, List
from collections import OrderedDict
from copy import copy, deepcopy
from functools import lru_cache
import logging
import z3
from time import time
log = logging.getLogger(__name__)
class DelayConstraintStrategy(BasicSearchStrategy):
def __init__(self, work_list, max_depth, **kwargs):
super().__init__(work_list, max_depth)
self.model_cache = ModelCache()
self.pending_worklist = []
log.info("Loaded search strategy extension: DelayConstraintStrategy")
def get_strategic_global_state(self) -> GlobalState:
"""Returns the next state
:return: Global state
"""
while True:
if len(self.work_list) == 0:
state = self.pending_worklist.pop(0)
model = state.world_state.constraints.get_model()
if model is not None:
self.model_cache.put(model, 1)
return state
else:
state = self.work_list[0]
c_val = self.model_cache.check_quick_sat(
simplify(And(*state.world_state.constraints)).raw
)
if c_val == False:
self.pending_worklist.append(state)
self.work_list.pop(0)
else:
return self.work_list.pop(0)

@ -17,6 +17,7 @@ from mythril.laser.plugin.signals import PluginSkipWorldState, PluginSkipState
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy
from mythril.laser.ethereum.strategy.constraint_strategy import DelayConstraintStrategy
from abc import ABCMeta
from mythril.laser.ethereum.time_handler import time_handler
@ -309,7 +310,10 @@ class LaserEVM:
except NotImplementedError:
log.debug("Encountered unimplemented instruction")
continue
if len(new_states) > 1 and random.uniform(0, 1) < args.pruning_factor:
if self.strategy.run_check() and (
len(new_states) > 1 and random.uniform(0, 1) < args.pruning_factor
):
new_states = [
state
for state in new_states

@ -1,15 +1,21 @@
from functools import lru_cache
from z3 import sat, unknown
from z3 import sat, unknown, is_true
from pathlib import Path
from mythril.support.support_utils import ModelCache
from mythril.support.support_args import args
from mythril.laser.smt import Optimize
from mythril.laser.smt import Optimize, simplify, And
from mythril.laser.ethereum.time_handler import time_handler
from mythril.exceptions import UnsatError, SolverTimeOutException
import logging
from collections import OrderedDict
from copy import deepcopy
from time import time
log = logging.getLogger(__name__)
# LRU cache works great when used in powers of 2
model_cache = ModelCache()
@lru_cache(maxsize=2**23)
@ -42,6 +48,11 @@ def get_model(
constraints = constraints.get_all_constraints()
constraints = [constraint for constraint in constraints if type(constraint) != bool]
if len(maximize) + len(minimize) == 0:
ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw)
if ret_model:
return ret_model
for constraint in constraints:
s.add(constraint)
for e in minimize:
@ -63,6 +74,7 @@ def get_model(
result = s.check()
if result == sat:
model_cache.model_cache.put(s.model(), 1)
return s.model()
elif result == unknown:
log.debug("Timeout/Error encountered while solving expression using z3")

@ -1,8 +1,12 @@
"""This module contains utility functions for the Mythril support package."""
from collections import OrderedDict
from copy import deepcopy
from eth_hash.auto import keccak
from functools import lru_cache
from typing import Dict
from z3 import is_true, simplify, And
import logging
from eth_hash.auto import keccak
log = logging.getLogger(__name__)
@ -27,6 +31,46 @@ class Singleton(type):
return cls._instances[cls]
class LRUCache:
def __init__(self, size):
self.size = size
self.lru_cache = OrderedDict()
def get(self, key):
try:
value = self.lru_cache.pop(key)
self.lru_cache[key] = value
return value
except KeyError:
return -1
def put(self, key, value):
try:
self.lru_cache.pop(key)
except KeyError:
if len(self.lru_cache) >= self.size:
self.lru_cache.popitem(last=False)
self.lru_cache[key] = value
class ModelCache:
def __init__(self):
self.model_cache = LRUCache(size=100)
@lru_cache(maxsize=2**10)
def check_quick_sat(self, constraints) -> bool:
model_list = list(reversed(self.model_cache.lru_cache.keys()))
for model in reversed(self.model_cache.lru_cache.keys()):
model_copy = deepcopy(model)
if is_true(model_copy.eval(constraints, model_completion=True)):
self.model_cache.put(model, self.model_cache.get(model) + 1)
return model
return False
def put(self, key, value):
self.model_cache.put(key, value)
@lru_cache(maxsize=2**10)
def get_code_hash(code) -> str:
"""

@ -55,3 +55,17 @@ def test_analysis(file_name, tx_data, calldata):
0
]
assert test_case["steps"][tx_data["TX_OUTPUT"]]["input"] == calldata
@pytest.mark.parametrize("file_name, tx_data, calldata", test_data)
def test_analysis_delayed(file_name, tx_data, calldata):
bytecode_file = str(TESTDATA / "inputs" / file_name)
command = f"""python3 {MYTH} analyze -f {bytecode_file} -t {tx_data["TX_COUNT"]} -o jsonv2 -m {tx_data["MODULE"]} --solver-timeout 60000 --strategy delayed"""
output = json.loads(output_of(command))
assert len(output[0]["issues"]) == tx_data["ISSUE_COUNT"]
if calldata:
test_case = output[0]["issues"][tx_data["ISSUE_NUMBER"]]["extra"]["testCases"][
0
]
assert test_case["steps"][tx_data["TX_OUTPUT"]]["input"] == calldata

Loading…
Cancel
Save