Merge branch 'develop' into specify-attacker-creator-address

specify-attacker-creator-address
Nathan 5 years ago committed by GitHub
commit ed06c3148e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      mythril/analysis/modules/exceptions.py
  2. 4
      mythril/analysis/modules/suicide.py
  3. 58
      mythril/analysis/solver.py
  4. 3
      mythril/laser/ethereum/cfg.py
  5. 38
      mythril/laser/ethereum/instructions.py
  6. 117
      mythril/laser/ethereum/keccak_function_manager.py
  7. 31
      mythril/laser/ethereum/state/account.py
  8. 5
      mythril/laser/ethereum/state/constraints.py
  9. 7
      mythril/laser/ethereum/strategy/extensions/bounded_loops.py
  10. 7
      mythril/laser/ethereum/svm.py
  11. 3
      mythril/laser/ethereum/transaction/concolic.py
  12. 2
      mythril/laser/ethereum/transaction/symbolic.py
  13. 22
      mythril/laser/smt/__init__.py
  14. 5
      mythril/laser/smt/bitvec_helper.py
  15. 25
      mythril/laser/smt/function.py
  16. 4
      mythril/laser/smt/solver/solver.py
  17. 5
      tests/instructions/create_test.py
  18. 11
      tests/laser/evm_testsuite/evm_test.py
  19. 138
      tests/laser/keccak_tests.py

@ -55,11 +55,9 @@ class ReachableExceptionsModule(DetectionModule):
"Note that explicit `assert()` should only be used to check invariants. " "Note that explicit `assert()` should only be used to check invariants. "
"Use `require()` for regular input checking." "Use `require()` for regular input checking."
) )
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints state, state.mstate.constraints
) )
issue = Issue( issue = Issue(
contract=state.environment.active_account.contract_name, contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name, function_name=state.environment.active_function_name,

@ -9,7 +9,7 @@ from mythril.laser.ethereum.transaction.transaction_models import (
ContractCreationTransaction, ContractCreationTransaction,
) )
import logging import logging
import json
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -69,7 +69,6 @@ class SuicideModule(DetectionModule):
for tx in state.world_state.transaction_sequence: for tx in state.world_state.transaction_sequence:
if not isinstance(tx, ContractCreationTransaction): if not isinstance(tx, ContractCreationTransaction):
constraints.append(tx.caller == ACTORS.attacker) constraints.append(tx.caller == ACTORS.attacker)
try: try:
try: try:
transaction_sequence = solver.get_transaction_sequence( transaction_sequence = solver.get_transaction_sequence(
@ -85,7 +84,6 @@ class SuicideModule(DetectionModule):
state, state.mstate.constraints + constraints state, state.mstate.constraints + constraints
) )
description_tail = "Arbitrary senders can kill this contract." description_tail = "Arbitrary senders can kill this contract."
issue = Issue( issue = Issue(
contract=state.environment.active_account.contract_name, contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name, function_name=state.environment.active_function_name,

@ -1,12 +1,16 @@
"""This module contains analysis module helpers to solve path constraints.""" """This module contains analysis module helpers to solve path constraints."""
from functools import lru_cache from functools import lru_cache
from typing import Dict, Tuple, Union from typing import Dict, List, Tuple, Union
from z3 import sat, unknown, FuncInterp from z3 import sat, unknown, FuncInterp
import z3 import z3
from mythril.analysis.analysis_args import analysis_args from mythril.analysis.analysis_args import analysis_args
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.constraints import Constraints from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.ethereum.keccak_function_manager import (
keccak_function_manager,
hash_matcher,
)
from mythril.laser.ethereum.transaction import BaseTransaction from mythril.laser.ethereum.transaction import BaseTransaction
from mythril.laser.smt import UGE, Optimize, symbol_factory from mythril.laser.smt import UGE, Optimize, symbol_factory
from mythril.laser.ethereum.time_handler import time_handler from mythril.laser.ethereum.time_handler import time_handler
@ -18,6 +22,7 @@ import logging
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
# LRU cache works great when used in powers of 2 # LRU cache works great when used in powers of 2
@lru_cache(maxsize=2 ** 23) @lru_cache(maxsize=2 ** 23)
def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True): def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True):
@ -48,7 +53,6 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True
s.minimize(e) s.minimize(e)
for e in maximize: for e in maximize:
s.maximize(e) s.maximize(e)
result = s.check() result = s.check()
if result == sat: if result == sat:
return s.model() return s.model()
@ -97,7 +101,6 @@ def get_transaction_sequence(
tx_constraints, minimize = _set_minimisation_constraints( tx_constraints, minimize = _set_minimisation_constraints(
transaction_sequence, constraints.copy(), [], 5000, global_state.world_state transaction_sequence, constraints.copy(), [], 5000, global_state.world_state
) )
try: try:
model = get_model(tx_constraints, minimize=minimize) model = get_model(tx_constraints, minimize=minimize)
except UnsatError: except UnsatError:
@ -122,12 +125,59 @@ def get_transaction_sequence(
).as_long() ).as_long()
concrete_initial_state = _get_concrete_state(initial_accounts, min_price_dict) concrete_initial_state = _get_concrete_state(initial_accounts, min_price_dict)
if isinstance(transaction_sequence[0], ContractCreationTransaction):
code = transaction_sequence[0].code
_replace_with_actual_sha(concrete_transactions, model, code)
else:
_replace_with_actual_sha(concrete_transactions, model)
steps = {"initialState": concrete_initial_state, "steps": concrete_transactions} steps = {"initialState": concrete_initial_state, "steps": concrete_transactions}
return steps return steps
def _replace_with_actual_sha(
concrete_transactions: List[Dict[str, str]], model: z3.Model, code=None
):
for tx in concrete_transactions:
if hash_matcher not in tx["input"]:
continue
if code is not None and code.bytecode in tx["input"]:
s_index = len(code.bytecode) + 2
else:
s_index = 10
for i in range(s_index, len(tx["input"])):
data_slice = tx["input"][i : i + 64]
if hash_matcher not in data_slice or len(data_slice) != 64:
continue
find_input = symbol_factory.BitVecVal(int(data_slice, 16), 256)
input_ = None
for size in keccak_function_manager.store_function:
_, inverse = keccak_function_manager.get_function(size)
try:
input_ = symbol_factory.BitVecVal(
model.eval(inverse(find_input).raw).as_long(), size
)
except AttributeError:
continue
hex_input = hex(input_.value)[2:]
found = False
for new_tx in concrete_transactions:
if hex_input in new_tx["input"]:
found = True
break
if found:
break
if input_ is None:
continue
keccak = keccak_function_manager.find_concrete_keccak(input_)
hex_keccak = hex(keccak.value)[2:]
if len(hex_keccak) != 64:
hex_keccak = "0" * (64 - len(hex_keccak)) + hex_keccak
tx["input"] = tx["input"][:s_index] + tx["input"][s_index:].replace(
tx["input"][i : 64 + i], hex_keccak
)
def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]): def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]):
""" Gets a concrete state """ """ Gets a concrete state """
accounts = {} accounts = {}

@ -2,6 +2,7 @@
from enum import Enum from enum import Enum
from typing import Dict, List, TYPE_CHECKING from typing import Dict, List, TYPE_CHECKING
from mythril.laser.ethereum.state.constraints import Constraints
from flags import Flags from flags import Flags
if TYPE_CHECKING: if TYPE_CHECKING:
@ -46,7 +47,7 @@ class Node:
:param start_addr: :param start_addr:
:param constraints: :param constraints:
""" """
constraints = constraints if constraints else [] constraints = constraints if constraints else Constraints()
self.contract_name = contract_name self.contract_name = contract_name
self.start_addr = start_addr self.start_addr = start_addr
self.states = [] # type: List[GlobalState] self.states = [] # type: List[GlobalState]

@ -34,6 +34,7 @@ from mythril.laser.ethereum.state.calldata import ConcreteCalldata, SymbolicCall
import mythril.laser.ethereum.util as helper import mythril.laser.ethereum.util as helper
from mythril.laser.ethereum import util from mythril.laser.ethereum import util
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
from mythril.laser.ethereum.call import get_call_parameters, native_call, get_call_data from mythril.laser.ethereum.call import get_call_parameters, native_call, get_call_data
from mythril.laser.ethereum.evm_exceptions import ( from mythril.laser.ethereum.evm_exceptions import (
VmException, VmException,
@ -982,7 +983,7 @@ class Instruction:
if isinstance(op0, Expression): if isinstance(op0, Expression):
op0 = simplify(op0) op0 = simplify(op0)
state.stack.append( state.stack.append(
symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256) symbol_factory.BitVecSym("KECCAC_mem[{}]".format(hash(op0)), 256)
) )
gas_tuple = get_opcode_gas("SHA3") gas_tuple = get_opcode_gas("SHA3")
state.min_gas_used += gas_tuple[0] state.min_gas_used += gas_tuple[0]
@ -996,40 +997,21 @@ class Instruction:
b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8)
for b in state.memory[index : index + length] for b in state.memory[index : index + length]
] ]
if len(data_list) > 1: if len(data_list) > 1:
data = simplify(Concat(data_list)) data = simplify(Concat(data_list))
elif len(data_list) == 1: elif len(data_list) == 1:
data = data_list[0] data = data_list[0]
else: else:
# length is 0; this only matters for input of the BitVecFuncVal # TODO: handle finding x where func(x)==func("")
data = symbol_factory.BitVecVal(0, 1) result = keccak_function_manager.get_empty_keccak_hash()
state.stack.append(result)
if data.symbolic: return [global_state]
annotations = set() # type: Set[Any]
for b in state.memory[index : index + length]:
if isinstance(b, BitVec):
annotations = annotations.union(b.annotations)
argument_hash = hash(state.memory[index])
result = symbol_factory.BitVecFuncSym(
"KECCAC[invhash({})]".format(hash(argument_hash)),
"keccak256",
256,
input_=data,
annotations=annotations,
)
log.debug("Created BitVecFunc hash.")
else:
keccak = utils.sha3(data.value.to_bytes(length, byteorder="big"))
result = symbol_factory.BitVecFuncVal(
util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data
)
log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak)))
result, condition = keccak_function_manager.create_keccak(data)
state.stack.append(result) state.stack.append(result)
state.constraints.append(condition)
return [global_state] return [global_state]
@StateTransition() @StateTransition()

@ -0,0 +1,117 @@
from ethereum import utils
from mythril.laser.smt import (
BitVec,
Function,
URem,
symbol_factory,
ULE,
And,
ULT,
Bool,
Or,
)
from typing import Dict, Tuple, List
TOTAL_PARTS = 10 ** 40
PART = (2 ** 256 - 1) // TOTAL_PARTS
INTERVAL_DIFFERENCE = 10 ** 30
hash_matcher = "fffffff" # This is usually the prefix for the hash in the output
class KeccakFunctionManager:
"""
A bunch of uninterpreted functions are considered like keccak256_160 ,...
where keccak256_160 means the input of keccak256() is 160 bit number.
the range of these functions are constrained to some mutually disjoint intervals
All the hashes modulo 64 are 0 as we need a spread among hashes for array type data structures
All the functions are kind of one to one due to constraint of the existence of inverse
for each encountered input.
For more info https://files.sri.inf.ethz.ch/website/papers/sp20-verx.pdf
"""
def __init__(self):
self.store_function = {} # type: Dict[int, Tuple[Function, Function]]
self.interval_hook_for_size = {} # type: Dict[int, int]
self._index_counter = TOTAL_PARTS - 34534
self.quick_inverse = {} # type: Dict[BitVec, BitVec] # This is for VMTests
@staticmethod
def find_concrete_keccak(data: BitVec) -> BitVec:
"""
Calculates concrete keccak
:param data: input bitvecval
:return: concrete keccak output
"""
keccak = symbol_factory.BitVecVal(
int.from_bytes(
utils.sha3(data.value.to_bytes(data.size() // 8, byteorder="big")),
"big",
),
256,
)
return keccak
def get_function(self, length: int) -> Tuple[Function, Function]:
"""
Returns the keccak functions for the corresponding length
:param length: input size
:return: tuple of keccak and it's inverse
"""
try:
func, inverse = self.store_function[length]
except KeyError:
func = Function("keccak256_{}".format(length), length, 256)
inverse = Function("keccak256_{}-1".format(length), 256, length)
self.store_function[length] = (func, inverse)
return func, inverse
@staticmethod
def get_empty_keccak_hash() -> BitVec:
"""
returns sha3("")
:return:
"""
val = 89477152217924674838424037953991966239322087453347756267410168184682657981552
return symbol_factory.BitVecVal(val, 256)
def create_keccak(self, data: BitVec) -> Tuple[BitVec, Bool]:
"""
Creates Keccak of the data
:param data: input
:return: Tuple of keccak and the condition it should satisfy
"""
length = data.size()
func, inverse = self.get_function(length)
condition = self._create_condition(func_input=data)
self.quick_inverse[func(data)] = data
return func(data), condition
def _create_condition(self, func_input: BitVec) -> Bool:
"""
Creates the constraints for hash
:param func_input: input of the hash
:return: condition
"""
length = func_input.size()
func, inv = self.get_function(length)
try:
index = self.interval_hook_for_size[length]
except KeyError:
self.interval_hook_for_size[length] = self._index_counter
index = self._index_counter
self._index_counter -= INTERVAL_DIFFERENCE
lower_bound = index * PART
upper_bound = lower_bound + PART
cond = And(
inv(func(func_input)) == func_input,
ULE(symbol_factory.BitVecVal(lower_bound, 256), func(func_input)),
ULT(func(func_input), symbol_factory.BitVecVal(upper_bound, 256)),
URem(func(func_input), symbol_factory.BitVecVal(64, 256)) == 0,
)
return cond
keccak_function_manager = KeccakFunctionManager()

@ -55,7 +55,6 @@ class Storage:
self._standard_storage = K(256, 256, 0) # type: BaseArray self._standard_storage = K(256, 256, 0) # type: BaseArray
else: else:
self._standard_storage = Array("Storage", 256, 256) self._standard_storage = Array("Storage", 256, 256)
self._map_storage = {} # type: Dict[BitVec, BaseArray]
self.printable_storage = {} # type: Dict[BitVec, BitVec] self.printable_storage = {} # type: Dict[BitVec, BitVec]
@ -73,10 +72,7 @@ class Storage:
return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_)
def __getitem__(self, item: BitVec) -> BitVec: def __getitem__(self, item: BitVec) -> BitVec:
storage, is_keccak_storage = self._get_corresponding_storage(item) storage = self._standard_storage
if is_keccak_storage:
sanitized_item = self._sanitize(cast(BitVecFunc, item).input_)
else:
sanitized_item = item sanitized_item = item
if ( if (
self.address self.address
@ -100,7 +96,6 @@ class Storage:
self.printable_storage[item] = storage[sanitized_item] self.printable_storage[item] = storage[sanitized_item]
except ValueError as e: except ValueError as e:
log.debug("Couldn't read storage at %s: %s", item, e) log.debug("Couldn't read storage at %s: %s", item, e)
return simplify(storage[sanitized_item]) return simplify(storage[sanitized_item])
@staticmethod @staticmethod
@ -114,29 +109,12 @@ class Storage:
index = Extract(255, 0, key.input_) index = Extract(255, 0, key.input_)
return simplify(index) return simplify(index)
def _get_corresponding_storage(self, key: BitVec) -> Tuple[BaseArray, bool]: def _get_corresponding_storage(self, key: BitVec) -> BaseArray:
index = self.get_map_index(key) return self._standard_storage
if index is None:
storage = self._standard_storage
is_keccak_storage = False
else:
storage_map = self._map_storage
try:
storage = storage_map[index]
except KeyError:
if isinstance(self._standard_storage, Array):
storage_map[index] = Array("Storage", 512, 256)
else:
storage_map[index] = K(512, 256, 0)
storage = storage_map[index]
is_keccak_storage = True
return storage, is_keccak_storage
def __setitem__(self, key, value: Any) -> None: def __setitem__(self, key, value: Any) -> None:
storage, is_keccak_storage = self._get_corresponding_storage(key) storage = self._get_corresponding_storage(key)
self.printable_storage[key] = value self.printable_storage[key] = value
if is_keccak_storage:
key = self._sanitize(key.input_)
storage[key] = value storage[key] = value
if key.symbolic is False: if key.symbolic is False:
self.storage_keys_loaded.add(int(key.value)) self.storage_keys_loaded.add(int(key.value))
@ -147,7 +125,6 @@ class Storage:
concrete=concrete, address=self.address, dynamic_loader=self.dynld concrete=concrete, address=self.address, dynamic_loader=self.dynld
) )
storage._standard_storage = deepcopy(self._standard_storage) storage._standard_storage = deepcopy(self._standard_storage)
storage._map_storage = deepcopy(self._map_storage)
storage.printable_storage = copy(self.printable_storage) storage.printable_storage = copy(self.printable_storage)
storage.storage_keys_loaded = copy(self.storage_keys_loaded) storage.storage_keys_loaded = copy(self.storage_keys_loaded)
return storage return storage

@ -35,6 +35,7 @@ class Constraints(list):
""" """
:return: True/False based on the existence of solution of constraints :return: True/False based on the existence of solution of constraints
""" """
if self._is_possible is not None: if self._is_possible is not None:
return self._is_possible return self._is_possible
solver = Solver() solver = Solver()
@ -109,8 +110,8 @@ class Constraints(list):
:param constraints: :param constraints:
:return: :return:
""" """
constraints = self._get_smt_bool_list(constraints) list_constraints = self._get_smt_bool_list(constraints)
super(Constraints, self).__iadd__(constraints) super(Constraints, self).__iadd__(list_constraints)
self._is_possible = None self._is_possible = None
return self return self

@ -3,6 +3,7 @@ from mythril.laser.ethereum.strategy.basic import BasicSearchStrategy
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum.transaction import ContractCreationTransaction from mythril.laser.ethereum.transaction import ContractCreationTransaction
from typing import Dict, cast, List from typing import Dict, cast, List
from copy import copy
import logging import logging
@ -16,7 +17,9 @@ class JumpdestCountAnnotation(StateAnnotation):
self._reached_count = {} # type: Dict[str, int] self._reached_count = {} # type: Dict[str, int]
def __copy__(self): def __copy__(self):
return self result = JumpdestCountAnnotation()
result._reached_count = copy(self._reached_count)
return result
class BoundedLoopsStrategy(BasicSearchStrategy): class BoundedLoopsStrategy(BasicSearchStrategy):
@ -45,6 +48,7 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
:return: Global state :return: Global state
""" """
while True: while True:
state = self.super_strategy.get_strategic_global_state() state = self.super_strategy.get_strategic_global_state()
@ -56,7 +60,6 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
if len(annotations) == 0: if len(annotations) == 0:
annotation = JumpdestCountAnnotation() annotation = JumpdestCountAnnotation()
log.debug("Adding JumpdestCountAnnotation to GlobalState")
state.annotate(annotation) state.annotate(annotation)
else: else:
annotation = annotations[0] annotation = annotations[0]

@ -20,6 +20,7 @@ from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy
from abc import ABCMeta from abc import ABCMeta
from mythril.laser.ethereum.time_handler import time_handler from mythril.laser.ethereum.time_handler import time_handler
from mythril.laser.ethereum.transaction import ( from mythril.laser.ethereum.transaction import (
ContractCreationTransaction, ContractCreationTransaction,
TransactionEndSignal, TransactionEndSignal,
@ -29,6 +30,7 @@ from mythril.laser.ethereum.transaction import (
) )
from mythril.laser.smt import symbol_factory from mythril.laser.smt import symbol_factory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -206,6 +208,7 @@ class LaserEVM:
i, len(self.open_states) i, len(self.open_states)
) )
) )
for hook in self._start_sym_trans_hooks: for hook in self._start_sym_trans_hooks:
hook() hook()
@ -245,7 +248,6 @@ class LaserEVM:
except NotImplementedError: except NotImplementedError:
log.debug("Encountered unimplemented instruction") log.debug("Encountered unimplemented instruction")
continue continue
new_states = [ new_states = [
state for state in new_states if state.mstate.constraints.is_possible state for state in new_states if state.mstate.constraints.is_possible
] ]
@ -357,16 +359,15 @@ class LaserEVM:
] ]
log.debug("Ending transaction %s.", transaction) log.debug("Ending transaction %s.", transaction)
if return_global_state is None: if return_global_state is None:
if ( if (
not isinstance(transaction, ContractCreationTransaction) not isinstance(transaction, ContractCreationTransaction)
or transaction.return_data or transaction.return_data
) and not end_signal.revert: ) and not end_signal.revert:
check_potential_issues(global_state) check_potential_issues(global_state)
end_signal.global_state.world_state.node = global_state.node end_signal.global_state.world_state.node = global_state.node
self._add_world_state(end_signal.global_state) self._add_world_state(end_signal.global_state)
new_global_states = [] new_global_states = []
else: else:
# First execute the post hook for the transaction ending instruction # First execute the post hook for the transaction ending instruction

@ -88,6 +88,9 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None:
condition=None, condition=None,
) )
) )
global_state.mstate.constraints += transaction.world_state.node.constraints
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction) global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node global_state.node = new_node
new_node.states.append(global_state) new_node.states.append(global_state)

@ -185,7 +185,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) -
) )
global_state.mstate.constraints += transaction.world_state.node.constraints global_state.mstate.constraints += transaction.world_state.node.constraints
new_node.constraints = global_state.mstate.constraints.as_list new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction) global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node global_state.node = new_node

@ -22,6 +22,7 @@ from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.function import Function
from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics
from mythril.laser.smt.model import Model from mythril.laser.smt.model import Model
from mythril.laser.smt.bool import Bool as SMTBool from mythril.laser.smt.bool import Bool as SMTBool
@ -47,6 +48,16 @@ class SymbolFactory(Generic[T, U]):
""" """
raise NotImplementedError raise NotImplementedError
@staticmethod
def BoolSym(name: str, annotations: Annotations = None) -> T:
"""
Creates a boolean symbol
:param name: The name of the Bool variable
:param annotations: The annotations to initialize the bool with
:return: The freshly created Bool()
"""
raise NotImplementedError
@staticmethod @staticmethod
def BitVecVal(value: int, size: int, annotations: Annotations = None) -> U: def BitVecVal(value: int, size: int, annotations: Annotations = None) -> U:
"""Creates a new bit vector with a concrete value. """Creates a new bit vector with a concrete value.
@ -125,6 +136,17 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
raw = z3.BoolVal(value) raw = z3.BoolVal(value)
return SMTBool(raw, annotations) return SMTBool(raw, annotations)
@staticmethod
def BoolSym(name: str, annotations: Annotations = None) -> SMTBool:
"""
Creates a boolean symbol
:param name: The name of the Bool variable
:param annotations: The annotations to initialize the bool with
:return: The freshly created Bool()
"""
raw = z3.Bool(name)
return SMTBool(raw, annotations)
@staticmethod @staticmethod
def BitVecVal(value: int, size: int, annotations: Annotations = None) -> BitVec: def BitVecVal(value: int, size: int, annotations: Annotations = None) -> BitVec:
"""Creates a new bit vector with a concrete value.""" """Creates a new bit vector with a concrete value."""

@ -1,8 +1,7 @@
from typing import Union, overload, List, Set, cast, Any, Optional, Callable from typing import Union, overload, List, Set, cast, Any, Callable
from operator import lshift, rshift, ne, eq
import z3 import z3
from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper

@ -0,0 +1,25 @@
from typing import cast
import z3
from mythril.laser.smt.bitvec import BitVec
class Function:
"""An uninterpreted function."""
def __init__(self, name: str, domain: int, value_range: int):
"""Initializes an uninterpreted function.
:param name: Name of the Function
:param domain: The domain for the Function (10 -> all the values that a bv of size 10 could take)
:param value_range: The range for the values of the function (10 -> all the values that a bv of size 10 could take)
"""
self.domain = z3.BitVecSort(domain)
self.range = z3.BitVecSort(value_range)
self.raw = z3.Function(name, self.domain, self.range)
def __call__(self, item: BitVec) -> BitVec:
"""Function accessor, item can be symbolic."""
return BitVec(
cast(z3.BitVecRef, self.raw(item.raw)), annotations=item.annotations
)

@ -45,14 +45,14 @@ class BaseSolver(Generic[T]):
self.add(*constraints) self.add(*constraints)
@stat_smt_query @stat_smt_query
def check(self) -> z3.CheckSatResult: def check(self, *args) -> z3.CheckSatResult:
"""Returns z3 smt check result. """Returns z3 smt check result.
Also suppresses the stdout when running z3 library's check() to avoid unnecessary output Also suppresses the stdout when running z3 library's check() to avoid unnecessary output
:return: The evaluated result which is either of sat, unsat or unknown :return: The evaluated result which is either of sat, unsat or unknown
""" """
old_stdout = sys.stdout old_stdout = sys.stdout
sys.stdout = open(os.devnull, "w") sys.stdout = open(os.devnull, "w")
evaluate = self.raw.check() evaluate = self.raw.check(args)
sys.stdout = old_stdout sys.stdout = old_stdout
return evaluate return evaluate

@ -1,6 +1,6 @@
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.cfg import Node
from mythril.laser.ethereum.state.environment import Environment from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.machine_state import MachineState
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.world_state import WorldState
@ -30,11 +30,12 @@ def execute_create():
calldata = ConcreteCalldata(0, code_raw) calldata = ConcreteCalldata(0, code_raw)
world_state = WorldState() world_state = WorldState()
world_state.node = Node("Contract")
account = world_state.create_account(balance=1000000, address=101) account = world_state.create_account(balance=1000000, address=101)
account.code = Disassembly("60a760006000f000") account.code = Disassembly("60a760006000f000")
environment = Environment(account, None, calldata, None, None, None) environment = Environment(account, None, calldata, None, None, None)
og_state = GlobalState( og_state = GlobalState(
world_state, environment, None, MachineState(gas_limit=8000000) world_state, environment, world_state.node, MachineState(gas_limit=8000000)
) )
og_state.transaction_stack.append( og_state.transaction_stack.append(
(MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None)

@ -1,10 +1,10 @@
from mythril.laser.ethereum.svm import LaserEVM from mythril.laser.ethereum.svm import LaserEVM
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.transaction.concolic import execute_message_call from mythril.laser.ethereum.transaction.concolic import execute_message_call
from mythril.laser.smt import Expression, BitVec, symbol_factory from mythril.laser.smt import Expression, BitVec, symbol_factory
from mythril.analysis.solver import get_model
from datetime import datetime from datetime import datetime
import binascii import binascii
@ -117,7 +117,6 @@ def test_vmtest(
# Arrange # Arrange
if test_name in ignored_test_names: if test_name in ignored_test_names:
return return
world_state = WorldState() world_state = WorldState()
for address, details in pre_condition.items(): for address, details in pre_condition.items():
@ -178,6 +177,14 @@ def test_vmtest(
expected = int(value, 16) expected = int(value, 16)
actual = account.storage[symbol_factory.BitVecVal(int(index, 16), 256)] actual = account.storage[symbol_factory.BitVecVal(int(index, 16), 256)]
if isinstance(actual, Expression): if isinstance(actual, Expression):
if (
actual.symbolic
and actual in keccak_function_manager.quick_inverse
):
actual = keccak_function_manager.find_concrete_keccak(
keccak_function_manager.quick_inverse[actual]
)
else:
actual = actual.value actual = actual.value
actual = 1 if actual is True else 0 if actual is False else actual actual = 1 if actual is True else 0 if actual is False else actual
else: else:

@ -0,0 +1,138 @@
from mythril.laser.smt import Solver, symbol_factory, And
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
import z3
import pytest
@pytest.mark.parametrize(
"input1, input2, expected",
[
(symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(101, 8), z3.unsat),
(symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 16), z3.unsat),
(symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 8), z3.sat),
(
symbol_factory.BitVecSym("N1", 256),
symbol_factory.BitVecSym("N2", 256),
z3.sat,
),
(
symbol_factory.BitVecVal(100, 256),
symbol_factory.BitVecSym("N1", 256),
z3.sat,
),
(
symbol_factory.BitVecVal(100, 8),
symbol_factory.BitVecSym("N1", 256),
z3.unsat,
),
],
)
def test_keccak_basic(input1, input2, expected):
s = Solver()
o1, c1 = keccak_function_manager.create_keccak(input1)
o2, c2 = keccak_function_manager.create_keccak(input2)
s.add(And(c1, c2))
s.add(o1 == o2)
assert s.check() == expected
def test_keccak_symbol_and_val():
"""
check keccak(100) == keccak(n) && n == 10
:return:
"""
s = Solver()
hundred = symbol_factory.BitVecVal(100, 256)
n = symbol_factory.BitVecSym("n", 256)
o1, c1 = keccak_function_manager.create_keccak(hundred)
o2, c2 = keccak_function_manager.create_keccak(n)
s.add(And(c1, c2))
s.add(o1 == o2)
s.add(n == symbol_factory.BitVecVal(10, 256))
assert s.check() == z3.unsat
def test_keccak_complex_eq():
"""
check for keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 160)
o1, c1 = keccak_function_manager.create_keccak(a)
o2, c2 = keccak_function_manager.create_keccak(b)
s.add(And(c1, c2))
two = symbol_factory.BitVecVal(2, 256)
o1 = two * o1
o2 = two * o2
o1, c1 = keccak_function_manager.create_keccak(o1)
o2, c2 = keccak_function_manager.create_keccak(o2)
s.add(And(c1, c2))
s.add(o1 == o2)
s.add(a != b)
assert s.check() == z3.unsat
def test_keccak_complex_eq2():
"""
check for keccak(keccak(b)*2) == keccak(keccak(a)*2)
This isn't combined with prev test because incremental solving here requires extra-extra work
(solution is literally the opposite of prev one) so it will take forever to solve.
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 160)
o1, c1 = keccak_function_manager.create_keccak(a)
o2, c2 = keccak_function_manager.create_keccak(b)
s.add(And(c1, c2))
two = symbol_factory.BitVecVal(2, 256)
o1 = two * o1
o2 = two * o2
o1, c1 = keccak_function_manager.create_keccak(o1)
o2, c2 = keccak_function_manager.create_keccak(o2)
s.add(And(c1, c2))
s.add(o1 == o2)
assert s.check() == z3.sat
def test_keccak_simple_number():
"""
check for keccak(b) == 10
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
ten = symbol_factory.BitVecVal(10, 256)
o, c = keccak_function_manager.create_keccak(a)
s.add(c)
s.add(ten == o)
assert s.check() == z3.unsat
def test_keccak_other_num():
"""
check keccak(keccak(a)*2) == b
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 256)
o, c = keccak_function_manager.create_keccak(a)
two = symbol_factory.BitVecVal(2, 256)
o = two * o
s.add(c)
o, c = keccak_function_manager.create_keccak(o)
s.add(c)
s.add(b == o)
assert s.check() == z3.sat
Loading…
Cancel
Save