diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 327fcc82..9ec82f05 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -6,7 +6,7 @@ from copy import deepcopy from typing import Any, Dict, Union -from mythril.laser.smt import Array, K, BitVec, simplify, BitVecFunc, Extract +from mythril.laser.smt import Array, K, BitVec, simplify, BitVecFunc, Extract, BaseArray from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory @@ -20,8 +20,8 @@ class Storage: :param concrete: bool indicating whether to interpret uninitialized storage as concrete versus symbolic """ if concrete: - self._standard_storage = K(256, 256, 0) - self._map_storage = {} + self._standard_storage = K(256, 256, 0) # type: BaseArray + self._map_storage = {} # type: Dict[BitVec, BaseArray] else: self._standard_storage = Array("Storage", 256, 256) self._map_storage = {} @@ -29,18 +29,22 @@ class Storage: self.dynld = dynamic_loader self.address = address - def __getitem__(self, item: Union[str, int]) -> Any: + def __getitem__(self, item: BitVec) -> Any: storage = self._get_corresponding_storage(item) return simplify(storage[item]) @staticmethod - def get_map_index(key): - if not isinstance(key, BitVecFunc) or key.func_name != "keccak256": + def get_map_index(key: BitVec) -> BitVec: + if ( + not isinstance(key, BitVecFunc) + or key.func_name != "keccak256" + or key.input_ is None + ): return None index = Extract(255, 0, key.input_) return simplify(index) - def _get_corresponding_storage(self, key): + def _get_corresponding_storage(self, key: BitVec) -> BaseArray: index = self.get_map_index(key) if index is None: storage = self._standard_storage @@ -68,7 +72,7 @@ class Storage: storage._map_storage = deepcopy(self._map_storage) return storage - def __str__(self): + def __str__(self) -> str: return str(self._standard_storage) diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index d849d0c4..5a3f1ca1 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -72,7 +72,7 @@ class SymbolFactory(Generic[T, U]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value. @@ -91,7 +91,7 @@ class SymbolFactory(Generic[T, U]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value. @@ -140,7 +140,7 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a concrete value.""" raw = z3.BitVecVal(value, size) @@ -152,7 +152,7 @@ class _SmtSymbolFactory(SymbolFactory[bool.Bool, BitVec]): func_name: str, size: int, annotations: Annotations = None, - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, ) -> BitVecFunc: """Creates a new bit vector function with a symbolic value.""" raw = z3.BitVec(name, size) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index cbf60d69..257d2b46 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -79,7 +79,7 @@ class BitVecFunc(BitVec): self, raw: z3.BitVecRef, func_name: Optional[str], - input_: Union[int, "BitVec"] = None, + input_: "BitVec" = None, annotations: Optional[Annotations] = None, ): """