Bye Bye complex code

remove/dos
Nikhil Parasaram 5 years ago
parent 9f2b2f759e
commit aaa8e132db
  1. 58
      mythril/laser/ethereum/state/account.py
  2. 62
      mythril/laser/smt/__init__.py
  3. 8
      mythril/laser/smt/bitvec_helper.py

@ -4,7 +4,7 @@ This includes classes representing accounts and their storage.
""" """
import logging import logging
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import Any, Dict, Union, Tuple, Set, cast from typing import Any, Dict, Union, Set
from mythril.laser.smt import ( from mythril.laser.smt import (
@ -12,10 +12,7 @@ from mythril.laser.smt import (
K, K,
BitVec, BitVec,
simplify, simplify,
BitVecFunc,
Extract,
BaseArray, BaseArray,
Concat,
) )
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
from mythril.laser.smt import symbol_factory from mythril.laser.smt import symbol_factory
@ -23,26 +20,6 @@ from mythril.laser.smt import symbol_factory
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
class StorageRegion:
def __getitem__(self, item):
raise NotImplementedError
def __setitem__(self, key, value):
raise NotImplementedError
class ArrayStorageRegion(StorageRegion):
""" An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions"""
pass
class IteStorageRegion(StorageRegion):
""" An IteStorageRegion is a storage region that uses Ite statements to implement a storage"""
pass
class Storage: class Storage:
"""Storage class represents the storage of an Account.""" """Storage class represents the storage of an Account."""
@ -62,18 +39,8 @@ class Storage:
self.storage_keys_loaded = set() # type: Set[int] self.storage_keys_loaded = set() # type: Set[int]
self.address = address self.address = address
@staticmethod
def _sanitize(input_: BitVec) -> BitVec:
if input_.size() == 512:
return input_
if input_.size() > 512:
return Extract(511, 0, input_)
else:
return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_)
def __getitem__(self, item: BitVec) -> BitVec: def __getitem__(self, item: BitVec) -> BitVec:
storage = self._standard_storage storage = self._standard_storage
sanitized_item = item
if ( if (
self.address self.address
and self.address.value != 0 and self.address.value != 0
@ -82,7 +49,7 @@ class Storage:
and (self.dynld and self.dynld.storage_loading) and (self.dynld and self.dynld.storage_loading)
): ):
try: try:
storage[sanitized_item] = symbol_factory.BitVecVal( storage[item] = symbol_factory.BitVecVal(
int( int(
self.dynld.read_storage( self.dynld.read_storage(
contract_address="0x{:040X}".format(self.address.value), contract_address="0x{:040X}".format(self.address.value),
@ -93,29 +60,14 @@ class Storage:
256, 256,
) )
self.storage_keys_loaded.add(int(item.value)) self.storage_keys_loaded.add(int(item.value))
self.printable_storage[item] = storage[sanitized_item] self.printable_storage[item] = storage[item]
except ValueError as e: except ValueError as e:
log.debug("Couldn't read storage at %s: %s", item, e) log.debug("Couldn't read storage at %s: %s", item, e)
return simplify(storage[sanitized_item]) return simplify(storage[item])
@staticmethod
def get_map_index(key: BitVec) -> BitVec:
if (
not isinstance(key, BitVecFunc)
or key.func_name != "keccak256"
or key.input_ is None
):
return None
index = Extract(255, 0, key.input_)
return simplify(index)
def _get_corresponding_storage(self, key: BitVec) -> BaseArray:
return self._standard_storage
def __setitem__(self, key, value: Any) -> None: def __setitem__(self, key, value: Any) -> None:
storage = self._get_corresponding_storage(key)
self.printable_storage[key] = value self.printable_storage[key] = value
storage[key] = value self._standard_storage[key] = value
if key.symbolic is False: if key.symbolic is False:
self.storage_keys_loaded.add(int(key.value)) self.storage_keys_loaded.add(int(key.value))

@ -79,44 +79,6 @@ class SymbolFactory(Generic[T, U]):
""" """
raise NotImplementedError() raise NotImplementedError()
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param value: The concrete value to set the bit vector to
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value.
:param name: The name of the symbolic bit vector
:param func_name: The name of the bit vector function
:param size: The size of the bit vector
:param annotations: The annotations to initialize the bit vector with
:param input_: The input to the bit vector function
:return: The freshly created bit vector function
"""
raise NotImplementedError()
class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]): class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
""" """
@ -158,30 +120,6 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
raw = z3.BitVec(name, size) raw = z3.BitVec(name, size)
return BitVec(raw, annotations) return BitVec(raw, annotations)
@staticmethod
def BitVecFuncVal(
value: int,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a concrete value."""
raw = z3.BitVecVal(value, size)
return BitVecFunc(raw, func_name, input_, annotations)
@staticmethod
def BitVecFuncSym(
name: str,
func_name: str,
size: int,
annotations: Annotations = None,
input_: "BitVec" = None,
) -> BitVecFunc:
"""Creates a new bit vector function with a symbolic value."""
raw = z3.BitVec(name, size)
return BitVecFunc(raw, func_name, input_, annotations)
class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]): class _Z3SymbolFactory(SymbolFactory[z3.BoolRef, z3.BitVecRef]):
""" """

@ -7,9 +7,7 @@ from mythril.laser.smt.bitvec import BitVec
Annotations = Set[Any] Annotations = Set[Any]
def _comparison_helper( def _comparison_helper(a: BitVec, b: BitVec, operation: Callable) -> Bool:
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations.union(b.annotations) annotations = a.annotations.union(b.annotations)
return Bool(operation(a.raw, b.raw), annotations) return Bool(operation(a.raw, b.raw), annotations)
@ -49,7 +47,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) return _comparison_helper(a, b, z3.UGT)
def UGE(a: BitVec, b: BitVec) -> Bool: def UGE(a: BitVec, b: BitVec) -> Bool:
@ -69,7 +67,7 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
:param b: :param b:
:return: :return:
""" """
return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) return _comparison_helper(a, b, z3.ULT)
def ULE(a: BitVec, b: BitVec) -> Bool: def ULE(a: BitVec, b: BitVec) -> Bool:

Loading…
Cancel
Save