Use UIP over BitVecFunc()

pull/1220/head
Nikhil 5 years ago
parent 00dc912f2c
commit 475ad946b8
  1. 30
      mythril/laser/ethereum/instructions.py
  2. 58
      mythril/laser/ethereum/keccak_function_manager.py
  3. 31
      mythril/laser/ethereum/state/account.py
  4. 1
      mythril/laser/ethereum/svm.py
  5. 1
      mythril/laser/smt/__init__.py
  6. 25
      mythril/laser/smt/function.py

@ -31,6 +31,7 @@ from mythril.laser.smt import symbol_factory
import mythril.laser.ethereum.util as helper import mythril.laser.ethereum.util as helper
from mythril.laser.ethereum import util from mythril.laser.ethereum import util
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
from mythril.laser.ethereum.call import get_call_parameters, native_call from mythril.laser.ethereum.call import get_call_parameters, native_call
from mythril.laser.ethereum.evm_exceptions import ( from mythril.laser.ethereum.evm_exceptions import (
VmException, VmException,
@ -947,33 +948,10 @@ class Instruction:
else: else:
# length is 0; this only matters for input of the BitVecFuncVal # length is 0; this only matters for input of the BitVecFuncVal
data = symbol_factory.BitVecVal(0, 1) data = symbol_factory.BitVecVal(0, 1)
result, constraints = keccak_function_manager.create_keccak(data, length)
if data.symbolic:
annotations = set() # type: Set[Any]
for b in state.memory[index : index + length]:
if isinstance(b, BitVec):
annotations = annotations.union(b.annotations)
argument_hash = hash(state.memory[index])
result = symbol_factory.BitVecFuncSym(
"KECCAC[invhash({})]".format(hash(argument_hash)),
"keccak256",
256,
input_=data,
annotations=annotations,
)
log.debug("Created BitVecFunc hash.")
else:
keccak = utils.sha3(data.value.to_bytes(length, byteorder="big"))
result = symbol_factory.BitVecFuncVal(
util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data
)
log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak)))
state.stack.append(result) state.stack.append(result)
state.constraints += constraints
return [global_state] return [global_state]
@StateTransition() @StateTransition()

@ -0,0 +1,58 @@
from ethereum import utils
from mythril.laser.smt import BitVec, Function, URem, symbol_factory, ULE, And, ULT, Or
TOTAL_PARTS = 10 ** 40
PART = (2 ** 256 - 1) // TOTAL_PARTS
INTERVAL_DIFFERENCE = 10 ** 30
class KeccakFunctionManager:
def __init__(self):
self.sizes = {}
self.size_index = {}
self.index_counter = TOTAL_PARTS - 34534
self.size_values = {}
def create_keccak(self, data: BitVec, length: int):
length = length * 8
assert length == data.size()
try:
func, inverse = self.sizes[length]
except KeyError:
func = Function("keccak256_{}".format(length), length, 256)
inverse = Function("keccak256_{}-1".format(length), 256, length)
self.sizes[length] = (func, inverse)
self.size_values[length] = []
constraints = []
if data.symbolic is False:
keccak = symbol_factory.BitVecVal(
utils.sha3(data.value.to_bytes(length // 8, byteorder="big")), 256
)
constraints.append(func(data) == keccak)
constraints.append(inverse(func(data)) == data)
if data.symbolic is False:
return func(data), constraints
constraints.append(URem(func(data), symbol_factory.BitVecVal(63, 256)) == 0)
try:
index = self.size_index[length]
except KeyError:
self.size_index[length] = self.index_counter
index = self.index_counter
self.index_counter -= INTERVAL_DIFFERENCE
lower_bound = index * PART
upper_bound = (index + 1) * PART
condition = And(
ULE(symbol_factory.BitVecVal(lower_bound, 256), func(data)),
ULT(func(data), symbol_factory.BitVecVal(upper_bound, 256)),
)
for val in self.size_values[length]:
condition = Or(condition, func(data) == val)
constraints.append(condition)
return func(data), constraints
keccak_function_manager = KeccakFunctionManager()

@ -55,7 +55,6 @@ class Storage:
self._standard_storage = K(256, 256, 0) # type: BaseArray self._standard_storage = K(256, 256, 0) # type: BaseArray
else: else:
self._standard_storage = Array("Storage", 256, 256) self._standard_storage = Array("Storage", 256, 256)
self._map_storage = {} # type: Dict[BitVec, BaseArray]
self.printable_storage = {} # type: Dict[BitVec, BitVec] self.printable_storage = {} # type: Dict[BitVec, BitVec]
@ -73,10 +72,7 @@ class Storage:
return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_)
def __getitem__(self, item: BitVec) -> BitVec: def __getitem__(self, item: BitVec) -> BitVec:
storage, is_keccak_storage = self._get_corresponding_storage(item) storage = self._standard_storage
if is_keccak_storage:
sanitized_item = self._sanitize(cast(BitVecFunc, item).input_)
else:
sanitized_item = item sanitized_item = item
if ( if (
self.address self.address
@ -100,7 +96,6 @@ class Storage:
self.printable_storage[item] = storage[sanitized_item] self.printable_storage[item] = storage[sanitized_item]
except ValueError as e: except ValueError as e:
log.debug("Couldn't read storage at %s: %s", item, e) log.debug("Couldn't read storage at %s: %s", item, e)
return simplify(storage[sanitized_item]) return simplify(storage[sanitized_item])
@staticmethod @staticmethod
@ -114,29 +109,12 @@ class Storage:
index = Extract(255, 0, key.input_) index = Extract(255, 0, key.input_)
return simplify(index) return simplify(index)
def _get_corresponding_storage(self, key: BitVec) -> Tuple[BaseArray, bool]: def _get_corresponding_storage(self, key: BitVec) -> BaseArray:
index = self.get_map_index(key) return self._standard_storage
if index is None:
storage = self._standard_storage
is_keccak_storage = False
else:
storage_map = self._map_storage
try:
storage = storage_map[index]
except KeyError:
if isinstance(self._standard_storage, Array):
storage_map[index] = Array("Storage", 512, 256)
else:
storage_map[index] = K(512, 256, 0)
storage = storage_map[index]
is_keccak_storage = True
return storage, is_keccak_storage
def __setitem__(self, key, value: Any) -> None: def __setitem__(self, key, value: Any) -> None:
storage, is_keccak_storage = self._get_corresponding_storage(key) storage = self._get_corresponding_storage(key)
self.printable_storage[key] = value self.printable_storage[key] = value
if is_keccak_storage:
key = self._sanitize(key.input_)
storage[key] = value storage[key] = value
if key.symbolic is False: if key.symbolic is False:
self.storage_keys_loaded.add(int(key.value)) self.storage_keys_loaded.add(int(key.value))
@ -147,7 +125,6 @@ class Storage:
concrete=concrete, address=self.address, dynamic_loader=self.dynld concrete=concrete, address=self.address, dynamic_loader=self.dynld
) )
storage._standard_storage = deepcopy(self._standard_storage) storage._standard_storage = deepcopy(self._standard_storage)
storage._map_storage = deepcopy(self._map_storage)
storage.printable_storage = copy(self.printable_storage) storage.printable_storage = copy(self.printable_storage)
storage.storage_keys_loaded = copy(self.storage_keys_loaded) storage.storage_keys_loaded = copy(self.storage_keys_loaded)
return storage return storage

@ -241,7 +241,6 @@ class LaserEVM:
except NotImplementedError: except NotImplementedError:
log.debug("Encountered unimplemented instruction") log.debug("Encountered unimplemented instruction")
continue continue
new_states = [ new_states = [
state for state in new_states if state.mstate.constraints.is_possible state for state in new_states if state.mstate.constraints.is_possible
] ]

@ -22,6 +22,7 @@ from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.function import Function
from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics
from mythril.laser.smt.model import Model from mythril.laser.smt.model import Model
from mythril.laser.smt.bool import Bool as SMTBool from mythril.laser.smt.bool import Bool as SMTBool

@ -0,0 +1,25 @@
from typing import cast
import z3
from mythril.laser.smt.bitvec import BitVec
class Function:
"""An uninterpreted function."""
def __init__(self, name: str, domain: int, value_range: int):
"""Initializes an uninterpreted function.
:param name: Name of the Function
:param domain: The domain for the Function (10 -> all the values that a bv of size 10 could take)
:param value_range: The range for the values of the function (10 -> all the values that a bv of size 10 could take)
"""
self.domain = z3.BitVecSort(domain)
self.range = z3.BitVecSort(value_range)
self.raw = z3.Function(name, self.domain, self.range)
def __call__(self, item: BitVec) -> BitVec:
"""Function accessor, item can be symbolic."""
return BitVec(
cast(z3.BitVecRef, self.raw(item.raw)), annotations=item.annotations
) # type: ignore
Loading…
Cancel
Save