diff --git a/mythril/analysis/modules/integer.py b/mythril/analysis/modules/integer.py index cbaf144d..b2e58ca8 100644 --- a/mythril/analysis/modules/integer.py +++ b/mythril/analysis/modules/integer.py @@ -230,7 +230,6 @@ class IntegerOverflowUnderflowModule(DetectionModule): or annotation.overflowing_state in state_annotation.ostates_seen ): continue - state_annotation.overflowing_state_annotations.append(annotation) state_annotation.ostates_seen.add(annotation.overflowing_state) @@ -248,7 +247,6 @@ class IntegerOverflowUnderflowModule(DetectionModule): or annotation.overflowing_state in state_annotation.ostates_seen ): continue - state_annotation.overflowing_state_annotations.append(annotation) state_annotation.ostates_seen.add(annotation.overflowing_state) @@ -311,7 +309,6 @@ class IntegerOverflowUnderflowModule(DetectionModule): try: constraints = state.mstate.constraints + [annotation.constraint] - transaction_sequence = solver.get_transaction_sequence( state, constraints ) diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index b258a623..ee31a8ff 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -158,7 +158,9 @@ class SymExecWrapper: contract.disassembly, dynamic_loader=dynloader, contract_name=contract.name, - concrete_storage=False, + concrete_storage=True + if (dynloader is not None and dynloader.storage_loading) + else False, ) world_state.put_account(account) self.laser.sym_exec(world_state=world_state, target_address=address.value) diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index c084be0b..19c54d1d 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -18,13 +18,11 @@ from mythril.laser.smt import ( ULT, UGT, BitVec, - is_true, is_false, URem, SRem, If, Bool, - Or, Not, LShR, ) @@ -41,7 +39,6 @@ from mythril.laser.ethereum.evm_exceptions import ( OutOfGasException, ) from mythril.laser.ethereum.gas import OPCODE_GAS -from mythril.laser.ethereum.keccak import KeccakFunctionManager from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.transaction import ( MessageCallTransaction, @@ -56,8 +53,6 @@ log = logging.getLogger(__name__) TT256 = 2 ** 256 TT256M1 = 2 ** 256 - 1 -keccak_function_manager = KeccakFunctionManager() - class StateTransition(object): """Decorator that handles global state copy and original return. @@ -193,7 +188,6 @@ class Instruction: if not post else getattr(self, op + "_" + "post", None) ) - if instruction_mutator is None: raise NotImplementedError @@ -417,6 +411,7 @@ class Instruction: * helper.pop_bitvec(global_state.mstate) ) ) + return [global_state] @StateTransition() @@ -893,7 +888,6 @@ class Instruction: :param global_state: :return: """ - global keccak_function_manager state = global_state.mstate op0, op1 = state.stack.pop(), state.stack.pop() @@ -948,7 +942,6 @@ class Instruction: ) log.debug("Created BitVecFunc hash.") - keccak_function_manager.add_keccak(result, state.memory[index]) else: keccak = utils.sha3(data.value.to_bytes(length, byteorder="big")) result = symbol_factory.BitVecFuncVal( @@ -1400,86 +1393,13 @@ class Instruction: :param global_state: :return: """ - global keccak_function_manager state = global_state.mstate index = state.stack.pop() log.debug("Storage access at index " + str(index)) - - try: - index = util.get_concrete_int(index) - return self._sload_helper(global_state, index) - - except TypeError: - if not keccak_function_manager.is_keccak(index): - return self._sload_helper(global_state, str(index)) - - storage_keys = global_state.environment.active_account.storage.keys() - keccak_keys = list(filter(keccak_function_manager.is_keccak, storage_keys)) - - results = [] # type: List[GlobalState] - constraints = [] - - for keccak_key in keccak_keys: - key_argument = keccak_function_manager.get_argument(keccak_key) - index_argument = keccak_function_manager.get_argument(index) - constraints.append((keccak_key, key_argument == index_argument)) - - for (keccak_key, constraint) in constraints: - if constraint in state.constraints: - results += self._sload_helper( - global_state, keccak_key, [constraint] - ) - if len(results) > 0: - return results - - for (keccak_key, constraint) in constraints: - results += self._sload_helper( - copy(global_state), keccak_key, [constraint] - ) - if len(results) > 0: - return results - - return self._sload_helper(global_state, str(index)) - - @staticmethod - def _sload_helper( - global_state: GlobalState, index: Union[str, int], constraints=None - ): - """ - - :param global_state: - :param index: - :param constraints: - :return: - """ - try: - data = global_state.environment.active_account.storage[index] - except KeyError: - data = global_state.new_bitvec("storage_" + str(index), 256) - global_state.environment.active_account.storage[index] = data - - if constraints is not None: - global_state.mstate.constraints += constraints - - global_state.mstate.stack.append(data) + state.stack.append(global_state.environment.active_account.storage[index]) return [global_state] - @staticmethod - def _get_constraints(keccak_keys, this_key, argument): - """ - - :param keccak_keys: - :param this_key: - :param argument: - """ - global keccak_function_manager - for keccak_key in keccak_keys: - if keccak_key == this_key: - continue - keccak_argument = keccak_function_manager.get_argument(keccak_key) - yield keccak_argument != argument - @StateTransition() def sstore_(self, global_state: GlobalState) -> List[GlobalState]: """ @@ -1487,90 +1407,10 @@ class Instruction: :param global_state: :return: """ - global keccak_function_manager state = global_state.mstate index, value = state.stack.pop(), state.stack.pop() log.debug("Write to storage[" + str(index) + "]") - - try: - index = util.get_concrete_int(index) - return self._sstore_helper(global_state, index, value) - except TypeError: - is_keccak = keccak_function_manager.is_keccak(index) - if not is_keccak: - return self._sstore_helper(global_state, str(index), value) - - storage_keys = global_state.environment.active_account.storage.keys() - keccak_keys = filter(keccak_function_manager.is_keccak, storage_keys) - - results = [] # type: List[GlobalState] - new = symbol_factory.Bool(False) - - for keccak_key in keccak_keys: - key_argument = keccak_function_manager.get_argument( - keccak_key - ) # type: Expression - index_argument = keccak_function_manager.get_argument( - index - ) # type: Expression - condition = key_argument == index_argument - condition = ( - condition - if type(condition) == bool - else is_true(simplify(cast(Bool, condition))) - ) - if condition: - return self._sstore_helper( - copy(global_state), - keccak_key, - value, - key_argument == index_argument, - ) - - results += self._sstore_helper( - copy(global_state), - keccak_key, - value, - key_argument == index_argument, - ) - - new = Or(new, cast(Bool, key_argument != index_argument)) - - if len(results) > 0: - results += self._sstore_helper( - copy(global_state), str(index), value, new - ) - return results - - return self._sstore_helper(global_state, str(index), value) - - @staticmethod - def _sstore_helper(global_state, index, value, constraint=None): - """ - - :param global_state: - :param index: - :param value: - :param constraint: - :return: - """ - try: - global_state.environment.active_account = deepcopy( - global_state.environment.active_account - ) - global_state.accounts[ - global_state.environment.active_account.address.value - ] = global_state.environment.active_account - - global_state.environment.active_account.storage[index] = ( - value if not isinstance(value, Expression) else simplify(value) - ) - except KeyError: - log.debug("Error writing to storage: Invalid index") - - if constraint is not None: - global_state.mstate.constraints.append(constraint) - + global_state.environment.active_account.storage[index] = value return [global_state] @StateTransition(increment_pc=False, enable_gas=False) diff --git a/mythril/laser/ethereum/keccak.py b/mythril/laser/ethereum/keccak.py deleted file mode 100644 index 47a39ed1..00000000 --- a/mythril/laser/ethereum/keccak.py +++ /dev/null @@ -1,38 +0,0 @@ -"""This module contains a function manager to deal with symbolic Keccak -values.""" -from mythril.laser.smt import Expression - - -class KeccakFunctionManager: - """A keccak function manager for symbolic expressions.""" - - def __init__(self): - """""" - self.keccak_expression_mapping = {} - - def is_keccak(self, expression: Expression) -> bool: - """ - - :param expression: - :return: - """ - return str(expression) in self.keccak_expression_mapping.keys() - - def get_argument(self, expression: Expression) -> Expression: - """ - - :param expression: - :return: - """ - if not self.is_keccak(expression): - raise ValueError("Expression is not a recognized keccac result") - return self.keccak_expression_mapping[str(expression)][1] - - def add_keccak(self, expression: Expression, argument: Expression) -> None: - """ - - :param expression: - :param argument: - """ - index = str(expression) - self.keccak_expression_mapping[index] = (expression, argument) diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 191f0641..d3e6f876 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -2,12 +2,20 @@ This includes classes representing accounts and their storage. """ -from copy import deepcopy, copy -from typing import Any, Dict, KeysView, Union - -from z3 import ExprRef - -from mythril.laser.smt import Array, symbol_factory, BitVec +from copy import copy, deepcopy +from typing import Any, Dict, Union, Tuple, cast + + +from mythril.laser.smt import ( + Array, + K, + BitVec, + simplify, + BitVecFunc, + Extract, + BaseArray, + Concat, +) from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory @@ -20,62 +28,105 @@ class Storage: :param concrete: bool indicating whether to interpret uninitialized storage as concrete versus symbolic """ - self._storage = {} # type: Dict[Union[int, str], Any] - self.concrete = concrete + if concrete: + self._standard_storage = K(256, 256, 0) # type: BaseArray + else: + self._standard_storage = Array("Storage", 256, 256) + self._map_storage = {} # type: Dict[BitVec, BaseArray] + + self.printable_storage = {} # type: Dict[BitVec, BitVec] + self.dynld = dynamic_loader self.address = address - def __getitem__(self, item: Union[str, int]) -> Any: - try: - return self._storage[item] - except KeyError: - if ( - self.address - and self.address.value != 0 - and (self.dynld and self.dynld.storage_loading) - ): - try: - self._storage[item] = symbol_factory.BitVecVal( - int( - self.dynld.read_storage( - contract_address=hex(self.address.value), - index=int(item), - ), - 16, + @staticmethod + def _sanitize(input_: BitVec) -> BitVec: + if input_.size() == 512: + return input_ + if input_.size() > 512: + return Extract(511, 0, input_) + else: + return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_) + + def __getitem__(self, item: BitVec) -> BitVec: + storage, is_keccak_storage = self._get_corresponding_storage(item) + if is_keccak_storage: + item = self._sanitize(cast(BitVecFunc, item).input_) + value = storage[item] + if ( + value.value == 0 + and self.address + and item.symbolic is False + and self.address.value != 0 + and (self.dynld and self.dynld.storage_loading) + ): + try: + storage[item] = symbol_factory.BitVecVal( + int( + self.dynld.read_storage( + contract_address=hex(self.address.value), + index=int(item.value), ), - 256, - ) - return self._storage[item] - except ValueError: - pass - - if self.concrete: - return symbol_factory.BitVecVal(0, 256) - - self._storage[item] = symbol_factory.BitVecSym( - "storage_{}_{}".format(str(item), str(self.address)), 256 - ) - return self._storage[item] - - def __setitem__(self, key: Union[int, str], value: Any) -> None: - self._storage[key] = value - - def keys(self) -> KeysView: - """ - - :return: - """ - return self._storage.keys() + 16, + ), + 256, + ) + self.printable_storage[item] = storage[item] + return storage[item] + except ValueError: + pass + + return simplify(storage[item]) + + @staticmethod + def get_map_index(key: BitVec) -> BitVec: + if ( + not isinstance(key, BitVecFunc) + or key.func_name != "keccak256" + or key.input_ is None + ): + return None + index = Extract(255, 0, key.input_) + return simplify(index) + + def _get_corresponding_storage(self, key: BitVec) -> Tuple[BaseArray, bool]: + index = self.get_map_index(key) + if index is None: + storage = self._standard_storage + is_keccak_storage = False + else: + storage_map = self._map_storage + try: + storage = storage_map[index] + except KeyError: + if isinstance(self._standard_storage, Array): + storage_map[index] = Array("Storage", 512, 256) + else: + storage_map[index] = K(512, 256, 0) + storage = storage_map[index] + is_keccak_storage = True + return storage, is_keccak_storage + + def __setitem__(self, key, value: Any) -> None: + storage, is_keccak_storage = self._get_corresponding_storage(key) + self.printable_storage[key] = value + if is_keccak_storage: + key = self._sanitize(key.input_) + storage[key] = value def __deepcopy__(self, memodict={}): + concrete = isinstance(self._standard_storage, K) storage = Storage( - concrete=self.concrete, address=self.address, dynamic_loader=self.dynld + concrete=concrete, address=self.address, dynamic_loader=self.dynld ) - storage._storage = copy(self._storage) + storage._standard_storage = deepcopy(self._standard_storage) + storage._map_storage = deepcopy(self._map_storage) + storage.print_storage = copy(self.printable_storage) return storage - def __str__(self): - return str(self._storage) + def __str__(self) -> str: + # TODO: Do something better here + return str(self.printable_storage) class Account: @@ -159,7 +210,7 @@ class Account: "storage": self.storage, } - def __deepcopy__(self, memodict={}): + def __copy__(self, memodict={}): new_account = Account( address=self.address, code=self.code, diff --git a/mythril/laser/ethereum/state/constraints.py b/mythril/laser/ethereum/state/constraints.py index 15e9c43f..73b33516 100644 --- a/mythril/laser/ethereum/state/constraints.py +++ b/mythril/laser/ethereum/state/constraints.py @@ -1,7 +1,7 @@ """This module contains the class used to represent state-change constraints in the call graph.""" -from mythril.laser.smt import Solver, Bool, symbol_factory +from mythril.laser.smt import Solver, Bool, symbol_factory, simplify from typing import Iterable, List, Optional, Union from z3 import unsat @@ -54,7 +54,9 @@ class Constraints(list): :param constraint: The constraint to be appended """ - constraint = constraint if isinstance(constraint, Bool) else Bool(constraint) + constraint = ( + simplify(constraint) if isinstance(constraint, Bool) else Bool(constraint) + ) super(Constraints, self).append(constraint) self._is_possible = None diff --git a/mythril/laser/ethereum/state/global_state.py b/mythril/laser/ethereum/state/global_state.py index e6a13cd3..86364a69 100644 --- a/mythril/laser/ethereum/state/global_state.py +++ b/mythril/laser/ethereum/state/global_state.py @@ -61,6 +61,7 @@ class GlobalState: environment = copy(self.environment) mstate = deepcopy(self.mstate) transaction_stack = copy(self.transaction_stack) + environment.active_account = world_state[environment.active_account.address] return GlobalState( world_state, environment, diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index d849d0c4..5a3f1ca1 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -72,7 +72,7 @@ class SymbolFactory(Generic[T, U]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value. @@ -91,7 +91,7 @@ class SymbolFactory(Generic[T, U]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value. @@ -140,7 +140,7 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a concrete value.""" raw = z3.BitVecVal(value, size) @@ -152,7 +152,7 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value.""" raw = z3.BitVec(name, size) diff --git a/mythril/laser/smt/array.py b/mythril/laser/smt/array.py index c2d658d2..00107df1 100644 --- a/mythril/laser/smt/array.py +++ b/mythril/laser/smt/array.py @@ -8,7 +8,8 @@ default values over a certain range. from typing import cast import z3 -from mythril.laser.smt.bitvec import BitVec +from mythril.laser.smt.bitvec import BitVec, If +from mythril.laser.smt.bool import Bool class BaseArray: @@ -24,6 +25,9 @@ class BaseArray: def __setitem__(self, key: BitVec, value: BitVec) -> None: """Sets an item in the array, key can be symbolic.""" + if isinstance(value, Bool): + value = If(value, 1, 0) + self.raw = z3.Store(self.raw, key.raw, value.raw) # type: ignore diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index f73ef504..257d2b46 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -79,7 +79,7 @@ class BitVecFunc(BitVec): self, raw: z3.BitVecRef, func_name: Optional[str], - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, annotations: Optional[Annotations] = None, ): """ @@ -205,7 +205,7 @@ class BitVecFunc(BitVec): :return: The resulting Bool """ return _comparison_helper( - self, other, operator.eq, default_value=True, inputs_equal=False + self, other, operator.ne, default_value=True, inputs_equal=False ) def __lshift__(self, other: Union[int, "BitVec"]) -> "BitVec": @@ -223,3 +223,6 @@ class BitVecFunc(BitVec): :return The resulting right shifted output: """ return _arithmetic_helper(self, other, operator.rshift) + + def __hash__(self) -> int: + return self.raw.__hash__() diff --git a/mythril/laser/smt/bool.py b/mythril/laser/smt/bool.py index 1dd1cc76..1859ab43 100644 --- a/mythril/laser/smt/bool.py +++ b/mythril/laser/smt/bool.py @@ -80,6 +80,9 @@ class Bool(Expression[z3.BoolRef]): else: return False + def __hash__(self) -> int: + return self.raw.__hash__() + def And(*args: Union[Bool, bool]) -> Bool: """Create an And expression.""" @@ -90,6 +93,13 @@ def And(*args: Union[Bool, bool]) -> Bool: return Bool(z3.And([a.raw for a in args_list]), union) +def Xor(a: Bool, b: Bool) -> Bool: + """Create an And expression.""" + + union = a.annotations + b.annotations + return Bool(z3.Xor(a.raw, b.raw), union) + + def Or(*args: Union[Bool, bool]) -> Bool: """Create an or expression. diff --git a/mythril/laser/smt/expression.py b/mythril/laser/smt/expression.py index 8e9e697e..ae98cd31 100644 --- a/mythril/laser/smt/expression.py +++ b/mythril/laser/smt/expression.py @@ -45,6 +45,12 @@ class Expression(Generic[T]): def __repr__(self) -> str: return repr(self.raw) + def size(self): + return self.raw.size() + + def __hash__(self) -> int: + return self.raw.__hash__() + G = TypeVar("G", bound=Expression) diff --git a/tests/instructions/codecopy_test.py b/tests/instructions/codecopy_test.py index 9d63efa8..8843199f 100644 --- a/tests/instructions/codecopy_test.py +++ b/tests/instructions/codecopy_test.py @@ -10,9 +10,13 @@ from mythril.laser.ethereum.transaction.transaction_models import MessageCallTra def test_codecopy_concrete(): # Arrange - active_account = Account("0x0", code=Disassembly("60606040")) - environment = Environment(active_account, None, None, None, None, None) - og_state = GlobalState(None, environment, None, MachineState(gas_limit=8000000)) + world_state = WorldState() + account = world_state.create_account(balance=10, address=101) + account.code = Disassembly("60606040") + environment = Environment(account, None, None, None, None, None) + og_state = GlobalState( + world_state, environment, None, MachineState(gas_limit=8000000) + ) og_state.transaction_stack.append( (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) ) diff --git a/tests/instructions/sar_test.py b/tests/instructions/sar_test.py index b618582f..3a23da13 100644 --- a/tests/instructions/sar_test.py +++ b/tests/instructions/sar_test.py @@ -2,6 +2,7 @@ import pytest from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.state.environment import Environment +from mythril.laser.ethereum.state.world_state import WorldState from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.global_state import GlobalState @@ -12,9 +13,11 @@ from mythril.laser.smt import symbol_factory, simplify def get_state(): - active_account = Account("0x0", code=Disassembly("60606040")) - environment = Environment(active_account, None, None, None, None, None) - state = GlobalState(None, environment, None, MachineState(gas_limit=8000000)) + world_state = WorldState() + account = world_state.create_account(balance=10, address=101) + account.code = Disassembly("60606040") + environment = Environment(account, None, None, None, None, None) + state = GlobalState(world_state, environment, None, MachineState(gas_limit=8000000)) state.transaction_stack.append( (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) ) diff --git a/tests/instructions/shl_test.py b/tests/instructions/shl_test.py index 0e36c088..fb1680a5 100644 --- a/tests/instructions/shl_test.py +++ b/tests/instructions/shl_test.py @@ -12,9 +12,11 @@ from mythril.laser.smt import symbol_factory, simplify def get_state(): - active_account = Account("0x0", code=Disassembly("60606040")) - environment = Environment(active_account, None, None, None, None, None) - state = GlobalState(None, environment, None, MachineState(gas_limit=8000000)) + world_state = WorldState() + account = world_state.create_account(balance=10, address=101) + account.code = Disassembly("60606040") + environment = Environment(account, None, None, None, None, None) + state = GlobalState(world_state, environment, None, MachineState(gas_limit=8000000)) state.transaction_stack.append( (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) ) diff --git a/tests/instructions/shr_test.py b/tests/instructions/shr_test.py index aaf370e2..f0f66787 100644 --- a/tests/instructions/shr_test.py +++ b/tests/instructions/shr_test.py @@ -12,9 +12,11 @@ from mythril.laser.smt import symbol_factory, simplify, LShR def get_state(): - active_account = Account("0x0", code=Disassembly("60606040")) - environment = Environment(active_account, None, None, None, None, None) - state = GlobalState(None, environment, None, MachineState(gas_limit=8000000)) + world_state = WorldState() + account = world_state.create_account(balance=10, address=101) + account.code = Disassembly("60606040") + environment = Environment(account, None, None, None, None, None) + state = GlobalState(world_state, environment, None, MachineState(gas_limit=8000000)) state.transaction_stack.append( (MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None) ) diff --git a/tests/laser/evm_testsuite/evm_test.py b/tests/laser/evm_testsuite/evm_test.py index 381438cb..244b5844 100644 --- a/tests/laser/evm_testsuite/evm_test.py +++ b/tests/laser/evm_testsuite/evm_test.py @@ -125,7 +125,8 @@ def test_vmtest( account.code = Disassembly(details["code"][2:]) account.nonce = int(details["nonce"], 16) for key, value in details["storage"].items(): - account.storage[int(key, 16)] = int(value, 16) + key_bitvec = symbol_factory.BitVecVal(int(key, 16), 256) + account.storage[key_bitvec] = symbol_factory.BitVecVal(int(value, 16), 256) world_state.put_account(account) account.set_balance(int(details["balance"], 16)) @@ -175,8 +176,7 @@ def test_vmtest( for index, value in details["storage"].items(): expected = int(value, 16) - actual = account.storage[int(index, 16)] - + actual = account.storage[symbol_factory.BitVecVal(int(index, 16), 256)] if isinstance(actual, Expression): actual = actual.value actual = 1 if actual is True else 0 if actual is False else actual diff --git a/tests/laser/state/storage_test.py b/tests/laser/state/storage_test.py index 35e31df3..2e8aa4f9 100644 --- a/tests/laser/state/storage_test.py +++ b/tests/laser/state/storage_test.py @@ -1,7 +1,10 @@ import pytest +from mythril.laser.smt import symbol_factory from mythril.laser.ethereum.state.account import Storage from mythril.laser.smt import Expression +BVV = symbol_factory.BitVecVal + storage_uninitialized_test_data = [({}, 1), ({1: 5}, 2), ({1: 5, 3: 10}, 2)] @@ -9,10 +12,11 @@ storage_uninitialized_test_data = [({}, 1), ({1: 5}, 2), ({1: 5, 3: 10}, 2)] def test_concrete_storage_uninitialized_index(initial_storage, key): # Arrange storage = Storage(concrete=True) - storage._storage = initial_storage + for k, val in initial_storage.items(): + storage[BVV(k, 256)] = BVV(val, 256) # Act - value = storage[key] + value = storage[BVV(key, 256)] # Assert assert value == 0 @@ -22,10 +26,11 @@ def test_concrete_storage_uninitialized_index(initial_storage, key): def test_symbolic_storage_uninitialized_index(initial_storage, key): # Arrange storage = Storage(concrete=False) - storage._storage = initial_storage + for k, val in initial_storage.items(): + storage[BVV(k, 256)] = BVV(val, 256) # Act - value = storage[key] + value = storage[BVV(key, 256)] # Assert assert isinstance(value, Expression) @@ -36,18 +41,18 @@ def test_storage_set_item(): storage = Storage() # Act - storage[1] = 13 + storage[BVV(1, 256)] = BVV(13, 256) # Assert - assert storage[1] == 13 + assert storage[BVV(1, 256)] == BVV(13, 256) def test_storage_change_item(): # Arrange storage = Storage() - storage._storage = {1: 12} + storage[BVV(1, 256)] = BVV(12, 256) # Act - storage[1] = 14 + storage[BVV(1, 256)] = BVV(14, 256) # Assert - assert storage[1] == 14 + assert storage[BVV(1, 256)] == BVV(14, 256)