Fix type hints

pull/1133/head
Nikhil 5 years ago
parent 541182db03
commit 556d138d50
  1. 38
      mythril/laser/ethereum/state/memory.py
  2. 28
      mythril/laser/smt/bitvecfunc.py

@ -1,6 +1,6 @@
"""This module contains a representation of a smart contract's memory."""
from copy import copy
from typing import cast, List, Union, overload
from typing import cast, Dict, List, Union, overload
from z3 import Z3Exception
from mythril.laser.ethereum import util
@ -27,7 +27,7 @@ class Memory:
def __init__(self):
""""""
self._msize = 0
self._memory = {}
self._memory = {} # type: Dict[BitVec, Union[int, BitVec]]
def __len__(self):
"""
@ -111,7 +111,7 @@ class Memory:
self[index + 31 - (i // 8)] = Extract(i + 7, i, value_to_write)
@overload
def __getitem__(self, item: int) -> Union[int, BitVec]:
def __getitem__(self, item: BitVec) -> Union[int, BitVec]:
...
@overload
@ -119,7 +119,7 @@ class Memory:
...
def __getitem__(
self, item: Union[int, slice]
self, item: Union[BitVec, slice]
) -> Union[BitVec, int, List[Union[int, BitVec]]]:
"""
@ -134,11 +134,15 @@ class Memory:
raise IndexError("Invalid Memory Slice")
if step is None:
step = 1
start, stop, step = convert_bv(start), convert_bv(stop), convert_bv(step)
bvstart, bvstop, bvstep = (
convert_bv(start),
convert_bv(stop),
convert_bv(step),
)
ret_lis = []
itr = symbol_factory.BitVecVal(0, 256)
while simplify(start + itr < stop) and itr <= 10000000:
ret_lis.append(self[start + step * itr])
while simplify(bvstart + itr < bvstop) and itr <= 10000000:
ret_lis.append(self[bvstart + bvstep * itr])
itr += 1
return ret_lis
@ -148,7 +152,7 @@ class Memory:
def __setitem__(
self,
key: Union[int, slice],
key: Union[int, BitVec, slice],
value: Union[BitVec, int, List[Union[int, BitVec]]],
):
"""
@ -168,18 +172,24 @@ class Memory:
else:
assert False, "Currently mentioning step size is not supported"
assert type(value) == list
start, stop, step = convert_bv(start), convert_bv(stop), convert_bv(step)
bvstart, bvstop, bvstep = (
convert_bv(start),
convert_bv(stop),
convert_bv(step),
)
itr = symbol_factory.BitVecVal(0, 256)
while simplify(start + itr < stop) and itr <= 10000000:
self[start + itr] = value[itr.value]
while simplify(bvstart + bvstep * itr < bvstop) and itr <= 10000000:
self[bvstart + itr * bvstep] = cast(List[Union[int, BitVec]], value)[
itr.value
]
itr += 1
else:
key = simplify(convert_bv(key))
if key >= len(self):
bv_key = simplify(convert_bv(key))
if bv_key >= len(self):
return
if isinstance(value, int):
assert 0 <= value <= 0xFF
if isinstance(value, BitVec):
assert value.size() == 8
self._memory[key] = cast(Union[int, BitVec], value)
self._memory[bv_key] = cast(Union[int, BitVec], value)

@ -134,56 +134,68 @@ class BitVecFunc(BitVec):
"""
return _arithmetic_helper(self, other, operator.and_)
def __or__(self, other: "BitVec") -> "BitVecFunc":
def __or__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create an or expression.
:param other: The int or BitVec to or with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.or_)
def __xor__(self, other: "BitVec") -> "BitVecFunc":
def __xor__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":
"""Create a xor expression.
:param other: The int or BitVec to xor with this BitVecFunc
:return: The resulting BitVecFunc
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _arithmetic_helper(self, other, operator.xor)
def __lt__(self, other: "BitVec") -> Bool:
def __lt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.lt, default_value=False, inputs_equal=False
)
def __gt__(self, other: "BitVec") -> Bool:
def __gt__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.gt, default_value=False, inputs_equal=False
)
def __le__(self, other: "BitVec") -> Bool:
def __le__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed less than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self < other, self == other)
def __ge__(self, other: "BitVec") -> Bool:
def __ge__(self, other: Union[int, "BitVec"]) -> Bool:
"""Create a signed greater than or equal to expression.
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return Or(self > other, self == other)
# MYPY: fix complains about overriding __eq__
@ -193,6 +205,8 @@ class BitVecFunc(BitVec):
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.eq, default_value=False, inputs_equal=True
)
@ -204,6 +218,8 @@ class BitVecFunc(BitVec):
:param other: The int or BitVec to compare to this BitVecFunc
:return: The resulting Bool
"""
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, self.size()))
return _comparison_helper(
self, other, operator.ne, default_value=True, inputs_equal=False
)

Loading…
Cancel
Save