Nikhil Parasaram 6 years ago committed by GitHub
commit b6fcc593c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 5
      mythril/analysis/modules/base.py
  2. 4
      mythril/analysis/modules/delegatecall.py
  3. 6
      mythril/analysis/modules/dependence_on_predictable_vars.py
  4. 12
      mythril/analysis/modules/integer.py
  5. 16
      mythril/analysis/modules/multiple_sends.py
  6. 25
      mythril/analysis/modules/unchecked_retval.py
  7. 2
      mythril/disassembler/asm.py
  8. 13
      mythril/disassembler/disassembly.py
  9. 34
      mythril/laser/ethereum/call.py
  10. 14
      mythril/laser/ethereum/cfg.py
  11. 3
      mythril/laser/ethereum/gas.py
  12. 57
      mythril/laser/ethereum/instructions.py
  13. 2
      mythril/laser/ethereum/keccak.py
  14. 45
      mythril/laser/ethereum/natives.py
  15. 10
      mythril/laser/ethereum/state/account.py
  16. 57
      mythril/laser/ethereum/state/calldata.py
  17. 2
      mythril/laser/ethereum/state/environment.py
  18. 11
      mythril/laser/ethereum/state/global_state.py
  19. 22
      mythril/laser/ethereum/state/machine_state.py
  20. 39
      mythril/laser/ethereum/state/memory.py
  21. 9
      mythril/laser/ethereum/state/world_state.py
  22. 8
      mythril/laser/ethereum/strategy/basic.py
  23. 38
      mythril/laser/ethereum/svm.py
  24. 25
      mythril/laser/ethereum/transaction/transaction_models.py
  25. 37
      mythril/laser/ethereum/util.py
  26. 19
      mythril/laser/smt/bitvec.py
  27. 10
      mythril/laser/smt/bool.py
  28. 5
      mythril/laser/smt/expression.py
  29. 2
      mythril/support/loader.py
  30. 11
      mythril/support/signatures.py
  31. 3
      mythril/support/support_utils.py
  32. 2
      setup.py
  33. 4
      tox.ini

@ -3,6 +3,7 @@ modules."""
import logging import logging
from typing import List from typing import List
from mythril.analysis.report import Issue
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -21,7 +22,7 @@ class DetectionModule:
entrypoint: str = "post", entrypoint: str = "post",
pre_hooks: List[str] = None, pre_hooks: List[str] = None,
post_hooks: List[str] = None, post_hooks: List[str] = None,
): ) -> None:
self.name = name self.name = name
self.swc_id = swc_id self.swc_id = swc_id
self.pre_hooks = pre_hooks if pre_hooks else [] self.pre_hooks = pre_hooks if pre_hooks else []
@ -33,7 +34,7 @@ class DetectionModule:
self.name, self.name,
) )
self.entrypoint = entrypoint self.entrypoint = entrypoint
self._issues = [] self._issues = [] # type: List[Issue]
@property @property
def issues(self): def issues(self):

@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
class DelegateCallModule(DetectionModule): class DelegateCallModule(DetectionModule):
"""This module detects calldata being forwarded using DELEGATECALL.""" """This module detects calldata being forwarded using DELEGATECALL."""
def __init__(self): def __init__(self) -> None:
"""""" """"""
super().__init__( super().__init__(
name="DELEGATECALL Usage in Fallback Function", name="DELEGATECALL Usage in Fallback Function",
@ -46,7 +46,7 @@ def _analyze_states(state: GlobalState) -> List[Issue]:
call = get_call_from_state(state) call = get_call_from_state(state)
if call is None: if call is None:
return [] return []
issues = [] issues = [] # type: List[Issue]
if call.type is not "DELEGATECALL": if call.type is not "DELEGATECALL":
return [] return []

@ -19,7 +19,7 @@ class PredictableDependenceModule(DetectionModule):
"""This module detects whether Ether is sent using predictable """This module detects whether Ether is sent using predictable
parameters.""" parameters."""
def __init__(self): def __init__(self) -> None:
"""""" """"""
super().__init__( super().__init__(
name="Dependence of Predictable Variables", name="Dependence of Predictable Variables",
@ -118,9 +118,9 @@ def _analyze_states(state: GlobalState) -> list:
m = re.search(r"blockhash\w+(\s-\s(\d+))*", str(constraint)) m = re.search(r"blockhash\w+(\s-\s(\d+))*", str(constraint))
if m and solve(call): if m and solve(call):
found = m.group(1) found_item = m.group(1)
if found: # block.blockhash(block.number - N) if found_item: # block.blockhash(block.number - N)
description = ( description = (
"The predictable expression 'block.blockhash(block.number - " "The predictable expression 'block.blockhash(block.number - "
+ m.group(2) + m.group(2)

@ -2,7 +2,7 @@
underflows.""" underflows."""
import json import json
from typing import Dict
from mythril.analysis import solver from mythril.analysis import solver
from mythril.analysis.report import Issue from mythril.analysis.report import Issue
from mythril.analysis.swc_data import INTEGER_OVERFLOW_AND_UNDERFLOW from mythril.analysis.swc_data import INTEGER_OVERFLOW_AND_UNDERFLOW
@ -27,7 +27,9 @@ log = logging.getLogger(__name__)
class OverUnderflowAnnotation: class OverUnderflowAnnotation:
def __init__(self, overflowing_state: GlobalState, operator: str, constraint): def __init__(
self, overflowing_state: GlobalState, operator: str, constraint
) -> None:
self.overflowing_state = overflowing_state self.overflowing_state = overflowing_state
self.operator = operator self.operator = operator
self.constraint = constraint self.constraint = constraint
@ -36,7 +38,7 @@ class OverUnderflowAnnotation:
class IntegerOverflowUnderflowModule(DetectionModule): class IntegerOverflowUnderflowModule(DetectionModule):
"""This module searches for integer over- and underflows.""" """This module searches for integer over- and underflows."""
def __init__(self): def __init__(self) -> None:
"""""" """"""
super().__init__( super().__init__(
name="Integer Overflow and Underflow", name="Integer Overflow and Underflow",
@ -49,8 +51,8 @@ class IntegerOverflowUnderflowModule(DetectionModule):
entrypoint="callback", entrypoint="callback",
pre_hooks=["ADD", "MUL", "SUB", "SSTORE", "JUMPI"], pre_hooks=["ADD", "MUL", "SUB", "SSTORE", "JUMPI"],
) )
self._overflow_cache = {} self._overflow_cache = {} # type: Dict[int, bool]
self._underflow_cache = {} self._underflow_cache = {} # type: Dict[int, bool]
def reset_module(self): def reset_module(self):
""" """

@ -1,7 +1,9 @@
"""This module contains the detection code to find multiple sends occurring in """This module contains the detection code to find multiple sends occurring in
a single transaction.""" a single transaction."""
from copy import copy from copy import copy
from typing import cast, List, Optional
from mythril.analysis.ops import Call
from mythril.analysis.report import Issue from mythril.analysis.report import Issue
from mythril.analysis.swc_data import MULTIPLE_SENDS from mythril.analysis.swc_data import MULTIPLE_SENDS
from mythril.analysis.modules.base import DetectionModule from mythril.analysis.modules.base import DetectionModule
@ -14,8 +16,8 @@ log = logging.getLogger(__name__)
class MultipleSendsAnnotation(StateAnnotation): class MultipleSendsAnnotation(StateAnnotation):
def __init__(self): def __init__(self) -> None:
self.calls = [] self.calls = [] # type: List[Optional[Call]]
def __copy__(self): def __copy__(self):
result = MultipleSendsAnnotation() result = MultipleSendsAnnotation()
@ -56,11 +58,17 @@ def _analyze_state(state: GlobalState):
node = state.node node = state.node
instruction = state.get_current_instruction() instruction = state.get_current_instruction()
annotations = [a for a in state.get_annotations(MultipleSendsAnnotation)] annotations = cast(
List[MultipleSendsAnnotation],
[a for a in state.get_annotations(MultipleSendsAnnotation)],
)
if len(annotations) == 0: if len(annotations) == 0:
log.debug("Creating annotation for state") log.debug("Creating annotation for state")
state.annotate(MultipleSendsAnnotation()) state.annotate(MultipleSendsAnnotation())
annotations = [a for a in state.get_annotations(MultipleSendsAnnotation)] annotations = cast(
List[MultipleSendsAnnotation],
[a for a in state.get_annotations(MultipleSendsAnnotation)],
)
calls = annotations[0].calls calls = annotations[0].calls

@ -1,12 +1,15 @@
"""This module contains detection code to find occurrences of calls whose """This module contains detection code to find occurrences of calls whose
return value remains unchecked.""" return value remains unchecked."""
from copy import copy from copy import copy
from typing import cast, List, Union, Mapping
from mythril.analysis import solver from mythril.analysis import solver
from mythril.analysis.report import Issue from mythril.analysis.report import Issue
from mythril.analysis.swc_data import UNCHECKED_RET_VAL from mythril.analysis.swc_data import UNCHECKED_RET_VAL
from mythril.analysis.modules.base import DetectionModule from mythril.analysis.modules.base import DetectionModule
from mythril.exceptions import UnsatError from mythril.exceptions import UnsatError
from mythril.laser.smt.bitvec import BitVec
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
@ -16,8 +19,8 @@ log = logging.getLogger(__name__)
class UncheckedRetvalAnnotation(StateAnnotation): class UncheckedRetvalAnnotation(StateAnnotation):
def __init__(self): def __init__(self) -> None:
self.retvals = [] self.retvals = [] # type: List[Mapping[str, Union[int, BitVec]]]
def __copy__(self): def __copy__(self):
result = UncheckedRetvalAnnotation() result = UncheckedRetvalAnnotation()
@ -60,10 +63,16 @@ def _analyze_state(state: GlobalState) -> list:
instruction = state.get_current_instruction() instruction = state.get_current_instruction()
node = state.node node = state.node
annotations = [a for a in state.get_annotations(UncheckedRetvalAnnotation)] annotations = cast(
List[UncheckedRetvalAnnotation],
[a for a in state.get_annotations(UncheckedRetvalAnnotation)],
)
if len(annotations) == 0: if len(annotations) == 0:
state.annotate(UncheckedRetvalAnnotation()) state.annotate(UncheckedRetvalAnnotation())
annotations = [a for a in state.get_annotations(UncheckedRetvalAnnotation)] annotations = cast(
List[UncheckedRetvalAnnotation],
[a for a in state.get_annotations(UncheckedRetvalAnnotation)],
)
retvals = annotations[0].retvals retvals = annotations[0].retvals
@ -103,7 +112,13 @@ def _analyze_state(state: GlobalState) -> list:
"opcode" "opcode"
] in ["CALL", "DELEGATECALL", "STATICCALL", "CALLCODE"] ] in ["CALL", "DELEGATECALL", "STATICCALL", "CALLCODE"]
retval = state.mstate.stack[-1] retval = state.mstate.stack[-1]
retvals.append({"address": state.instruction["address"] - 1, "retval": retval}) # Use Typed Dict after release of mypy 0.670 and remove type ignore
retvals.append(
{ # type: ignore
"address": state.instruction["address"] - 1,
"retval": retval,
}
)
return [] return []

@ -90,7 +90,7 @@ def is_sequence_match(pattern: list, instruction_list: list, index: int) -> bool
return True return True
def disassemble(bytecode: str) -> list: def disassemble(bytecode: bytes) -> list:
"""Disassembles evm bytecode and returns a list of instructions. """Disassembles evm bytecode and returns a list of instructions.
:param bytecode: :param bytecode:

@ -3,6 +3,8 @@ from mythril.ethereum import util
from mythril.disassembler import asm from mythril.disassembler import asm
from mythril.support.signatures import SignatureDB from mythril.support.signatures import SignatureDB
from typing import Dict, List, Tuple
class Disassembly(object): class Disassembly(object):
"""Disassembly class. """Disassembly class.
@ -14,7 +16,7 @@ class Disassembly(object):
- function entry point to function name mapping - function entry point to function name mapping
""" """
def __init__(self, code: str, enable_online_lookup: bool = False): def __init__(self, code: str, enable_online_lookup: bool = False) -> None:
""" """
:param code: :param code:
@ -23,9 +25,9 @@ class Disassembly(object):
self.bytecode = code self.bytecode = code
self.instruction_list = asm.disassemble(util.safe_decode(code)) self.instruction_list = asm.disassemble(util.safe_decode(code))
self.func_hashes = [] self.func_hashes = [] # type: List[str]
self.function_name_to_address = {} self.function_name_to_address = {} # type: Dict[str, int]
self.address_to_function_name = {} self.address_to_function_name = {} # type: Dict[int, str]
# open from default locations # open from default locations
# control if you want to have online signature hash lookups # control if you want to have online signature hash lookups
@ -41,7 +43,6 @@ class Disassembly(object):
index, self.instruction_list, signatures index, self.instruction_list, signatures
) )
self.func_hashes.append(function_hash) self.func_hashes.append(function_hash)
if jump_target is not None and function_name is not None: if jump_target is not None and function_name is not None:
self.function_name_to_address[function_name] = jump_target self.function_name_to_address[function_name] = jump_target
self.address_to_function_name[jump_target] = function_name self.address_to_function_name[jump_target] = function_name
@ -56,7 +57,7 @@ class Disassembly(object):
def get_function_info( def get_function_info(
index: int, instruction_list: list, signature_database: SignatureDB index: int, instruction_list: list, signature_database: SignatureDB
) -> (str, int, str): ) -> Tuple[str, int, str]:
"""Finds the function information for a call table entry Solidity uses the """Finds the function information for a call table entry Solidity uses the
first 4 bytes of the calldata to indicate which function the message call first 4 bytes of the calldata to indicate which function the message call
should execute The generated code that directs execution to the correct should execute The generated code that directs execution to the correct

@ -3,9 +3,9 @@ instructions.py to get the necessary elements from the stack and determine the
parameters for the new global state.""" parameters for the new global state."""
import logging import logging
from typing import Union, List from typing import Union, List, cast, Callable
from z3 import Z3Exception from z3 import Z3Exception
from mythril.laser.smt import BitVec
from mythril.laser.ethereum import natives from mythril.laser.ethereum import natives
from mythril.laser.ethereum.gas import OPCODE_GAS from mythril.laser.ethereum.gas import OPCODE_GAS
from mythril.laser.smt import simplify, Expression, symbol_factory from mythril.laser.smt import simplify, Expression, symbol_factory
@ -155,8 +155,8 @@ def get_callee_account(
def get_call_data( def get_call_data(
global_state: GlobalState, global_state: GlobalState,
memory_start: Union[int, Expression], memory_start: Union[int, BitVec],
memory_size: Union[int, Expression], memory_size: Union[int, BitVec],
): ):
"""Gets call_data from the global_state. """Gets call_data from the global_state.
@ -168,22 +168,28 @@ def get_call_data(
state = global_state.mstate state = global_state.mstate
transaction_id = "{}_internalcall".format(global_state.current_transaction.id) transaction_id = "{}_internalcall".format(global_state.current_transaction.id)
memory_start = ( memory_start = cast(
symbol_factory.BitVecVal(memory_start, 256) BitVec,
if isinstance(memory_start, int) (
else memory_start symbol_factory.BitVecVal(memory_start, 256)
if isinstance(memory_start, int)
else memory_start
),
) )
memory_size = ( memory_size = cast(
symbol_factory.BitVecVal(memory_size, 256) BitVec,
if isinstance(memory_size, int) (
else memory_size symbol_factory.BitVecVal(memory_size, 256)
if isinstance(memory_size, int)
else memory_size
),
) )
uses_entire_calldata = simplify( uses_entire_calldata = simplify(
memory_size - global_state.environment.calldata.calldatasize == 0 memory_size - global_state.environment.calldata.calldatasize == 0
) )
if uses_entire_calldata == True: if uses_entire_calldata is True:
return global_state.environment.calldata return global_state.environment.calldata
try: try:
@ -218,7 +224,7 @@ def native_call(
contract_list = ["ecrecover", "sha256", "ripemd160", "identity"] contract_list = ["ecrecover", "sha256", "ripemd160", "identity"]
call_address_int = int(callee_address, 16) call_address_int = int(callee_address, 16)
native_gas_min, native_gas_max = OPCODE_GAS["NATIVE_COST"]( native_gas_min, native_gas_max = cast(Callable, OPCODE_GAS["NATIVE_COST"])(
global_state.mstate.calculate_extension_size(mem_out_start, mem_out_sz), global_state.mstate.calculate_extension_size(mem_out_start, mem_out_sz),
contract_list[call_address_int - 1], contract_list[call_address_int - 1],
) )

@ -1,9 +1,12 @@
"""This module.""" """This module."""
from enum import Enum from enum import Enum
from typing import Dict from typing import Dict, List, TYPE_CHECKING
from flags import Flags from flags import Flags
if TYPE_CHECKING:
from mythril.laser.ethereum.state.global_state import GlobalState
gbl_next_uid = 0 # node counter gbl_next_uid = 0 # node counter
@ -20,6 +23,9 @@ class JumpType(Enum):
class NodeFlags(Flags): class NodeFlags(Flags):
"""A collection of flags to denote the type a call graph node can have.""" """A collection of flags to denote the type a call graph node can have."""
def __or__(self, other) -> "NodeFlags":
return super().__or__(other)
FUNC_ENTRY = 1 FUNC_ENTRY = 1
CALL_RETURN = 2 CALL_RETURN = 2
@ -33,7 +39,7 @@ class Node:
start_addr=0, start_addr=0,
constraints=None, constraints=None,
function_name="unknown", function_name="unknown",
): ) -> None:
""" """
:param contract_name: :param contract_name:
@ -43,7 +49,7 @@ class Node:
constraints = constraints if constraints else [] constraints = constraints if constraints else []
self.contract_name = contract_name self.contract_name = contract_name
self.start_addr = start_addr self.start_addr = start_addr
self.states = [] self.states = [] # type: List[GlobalState]
self.constraints = constraints self.constraints = constraints
self.function_name = function_name self.function_name = function_name
self.flags = NodeFlags() self.flags = NodeFlags()
@ -86,7 +92,7 @@ class Edge:
node_to: int, node_to: int,
edge_type=JumpType.UNCONDITIONAL, edge_type=JumpType.UNCONDITIONAL,
condition=None, condition=None,
): ) -> None:
""" """
:param node_from: :param node_from:

@ -2,6 +2,7 @@
table.""" table."""
from ethereum import opcodes from ethereum import opcodes
from ethereum.utils import ceil32 from ethereum.utils import ceil32
from typing import Callable, Dict, Tuple, Union
def calculate_native_gas(size: int, contract: str): def calculate_native_gas(size: int, contract: str):
@ -185,4 +186,4 @@ OPCODE_GAS = {
"SUICIDE": (5000, 30000), "SUICIDE": (5000, 30000),
"ASSERT_FAIL": (0, 0), "ASSERT_FAIL": (0, 0),
"INVALID": (0, 0), "INVALID": (0, 0),
} } # type: Dict[str, Union[Tuple[int, int], Callable]]

@ -4,7 +4,7 @@ import binascii
import logging import logging
from copy import copy, deepcopy from copy import copy, deepcopy
from typing import Callable, List, Union from typing import cast, Callable, List, Union, Tuple
from datetime import datetime from datetime import datetime
from ethereum import utils from ethereum import utils
@ -127,7 +127,7 @@ class StateTransition(object):
if not self.enable_gas: if not self.enable_gas:
return global_state return global_state
opcode = global_state.instruction["opcode"] opcode = global_state.instruction["opcode"]
min_gas, max_gas = OPCODE_GAS[opcode] min_gas, max_gas = cast(Tuple[int, int], OPCODE_GAS[opcode])
global_state.mstate.min_gas_used += min_gas global_state.mstate.min_gas_used += min_gas
global_state.mstate.max_gas_used += max_gas global_state.mstate.max_gas_used += max_gas
return global_state return global_state
@ -155,7 +155,7 @@ class Instruction:
"""Instruction class is used to mutate a state according to the current """Instruction class is used to mutate a state according to the current
instruction.""" instruction."""
def __init__(self, op_code: str, dynamic_loader: DynLoader, iprof=None): def __init__(self, op_code: str, dynamic_loader: DynLoader, iprof=None) -> None:
""" """
:param op_code: :param op_code:
@ -358,7 +358,7 @@ class Instruction:
symbol_factory.BitVecVal(0, 248), symbol_factory.BitVecVal(0, 248),
Extract(offset + 7, offset, op1), Extract(offset + 7, offset, op1),
) )
) ) # type: Union[int, Expression]
else: else:
result = 0 result = 0
except TypeError: except TypeError:
@ -717,17 +717,15 @@ class Instruction:
log.debug("Unsupported symbolic memory offset in CALLDATACOPY") log.debug("Unsupported symbolic memory offset in CALLDATACOPY")
return [global_state] return [global_state]
dstart_sym = False
try: try:
dstart = util.get_concrete_int(op1) dstart = util.get_concrete_int(op1) # type: Union[int, BitVec]
except TypeError: except TypeError:
log.debug("Unsupported symbolic calldata offset in CALLDATACOPY") log.debug("Unsupported symbolic calldata offset in CALLDATACOPY")
dstart = simplify(op1) dstart = simplify(op1)
dstart_sym = True
size_sym = False size_sym = False
try: try:
size = util.get_concrete_int(op2) size = util.get_concrete_int(op2) # type: Union[int, BitVec]
except TypeError: except TypeError:
log.debug("Unsupported symbolic size in CALLDATACOPY") log.debug("Unsupported symbolic size in CALLDATACOPY")
size = simplify(op2) size = simplify(op2)
@ -746,7 +744,7 @@ class Instruction:
8, 8,
) )
return [global_state] return [global_state]
size = cast(int, size)
if size > 0: if size > 0:
try: try:
state.mem_extend(mstart, size) state.mem_extend(mstart, size)
@ -778,7 +776,9 @@ class Instruction:
new_memory.append(value) new_memory.append(value)
i_data = ( i_data = (
i_data + 1 if isinstance(i_data, int) else simplify(i_data + 1) i_data + 1
if isinstance(i_data, int)
else simplify(cast(BitVec, i_data) + 1)
) )
for i in range(len(new_memory)): for i in range(len(new_memory)):
state.memory[i + mstart] = new_memory[i] state.memory[i + mstart] = new_memory[i]
@ -881,11 +881,12 @@ class Instruction:
state.stack.append( state.stack.append(
symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256) symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256)
) )
state.min_gas_used += OPCODE_GAS["SHA3"][0] gas_tuple = cast(Tuple, OPCODE_GAS["SHA3"])
state.max_gas_used += OPCODE_GAS["SHA3"][1] state.min_gas_used += gas_tuple[0]
state.max_gas_used += gas_tuple[1]
return [global_state] return [global_state]
min_gas, max_gas = OPCODE_GAS["SHA3_FUNC"](length) min_gas, max_gas = cast(Callable, OPCODE_GAS["SHA3_FUNC"])(length)
state.min_gas_used += min_gas state.min_gas_used += min_gas
state.max_gas_used += max_gas state.max_gas_used += max_gas
StateTransition.check_gas_usage_limit(global_state) StateTransition.check_gas_usage_limit(global_state)
@ -1268,7 +1269,9 @@ class Instruction:
state.mem_extend(offset, 1) state.mem_extend(offset, 1)
try: try:
value_to_write = util.get_concrete_int(value) ^ 0xFF value_to_write = (
util.get_concrete_int(value) ^ 0xFF
) # type: Union[int, BitVec]
except TypeError: # BitVec except TypeError: # BitVec
value_to_write = Extract(7, 0, value) value_to_write = Extract(7, 0, value)
log.debug("MSTORE8 to mem[" + str(offset) + "]: " + str(value_to_write)) log.debug("MSTORE8 to mem[" + str(offset) + "]: " + str(value_to_write))
@ -1301,7 +1304,7 @@ class Instruction:
storage_keys = global_state.environment.active_account.storage.keys() storage_keys = global_state.environment.active_account.storage.keys()
keccak_keys = list(filter(keccak_function_manager.is_keccak, storage_keys)) keccak_keys = list(filter(keccak_function_manager.is_keccak, storage_keys))
results = [] results = [] # type: List[GlobalState]
constraints = [] constraints = []
for keccak_key in keccak_keys: for keccak_key in keccak_keys:
@ -1328,7 +1331,7 @@ class Instruction:
@staticmethod @staticmethod
def _sload_helper( def _sload_helper(
global_state: GlobalState, index: Union[int, Expression], constraints=None global_state: GlobalState, index: Union[str, int], constraints=None
): ):
""" """
@ -1387,17 +1390,21 @@ class Instruction:
storage_keys = global_state.environment.active_account.storage.keys() storage_keys = global_state.environment.active_account.storage.keys()
keccak_keys = filter(keccak_function_manager.is_keccak, storage_keys) keccak_keys = filter(keccak_function_manager.is_keccak, storage_keys)
results = [] results = [] # type: List[GlobalState]
new = symbol_factory.Bool(False) new = symbol_factory.Bool(False)
for keccak_key in keccak_keys: for keccak_key in keccak_keys:
key_argument = keccak_function_manager.get_argument(keccak_key) key_argument = keccak_function_manager.get_argument(
index_argument = keccak_function_manager.get_argument(index) keccak_key
) # type: Expression
index_argument = keccak_function_manager.get_argument(
index
) # type: Expression
condition = key_argument == index_argument condition = key_argument == index_argument
condition = ( condition = (
condition condition
if type(condition) == bool if type(condition) == bool
else is_true(simplify(condition)) else is_true(simplify(cast(Bool, condition)))
) )
if condition: if condition:
return self._sstore_helper( return self._sstore_helper(
@ -1414,7 +1421,7 @@ class Instruction:
key_argument == index_argument, key_argument == index_argument,
) )
new = Or(new, key_argument != index_argument) new = Or(new, cast(Bool, key_argument != index_argument))
if len(results) > 0: if len(results) > 0:
results += self._sstore_helper( results += self._sstore_helper(
@ -1482,7 +1489,7 @@ class Instruction:
new_state = copy(global_state) new_state = copy(global_state)
# add JUMP gas cost # add JUMP gas cost
min_gas, max_gas = OPCODE_GAS["JUMP"] min_gas, max_gas = cast(Tuple[int, int], OPCODE_GAS["JUMP"])
new_state.mstate.min_gas_used += min_gas new_state.mstate.min_gas_used += min_gas
new_state.mstate.max_gas_used += max_gas new_state.mstate.max_gas_used += max_gas
@ -1501,7 +1508,7 @@ class Instruction:
""" """
state = global_state.mstate state = global_state.mstate
disassembly = global_state.environment.code disassembly = global_state.environment.code
min_gas, max_gas = OPCODE_GAS["JUMPI"] min_gas, max_gas = cast(Tuple[int, int], OPCODE_GAS["JUMPI"])
states = [] states = []
op0, condition = state.stack.pop(), state.stack.pop() op0, condition = state.stack.pop(), state.stack.pop()
@ -1910,12 +1917,12 @@ class Instruction:
try: try:
memory_out_offset = ( memory_out_offset = (
util.get_concrete_int(memory_out_offset) util.get_concrete_int(memory_out_offset)
if isinstance(memory_out_offset, ExprRef) if isinstance(memory_out_offset, Expression)
else memory_out_offset else memory_out_offset
) )
memory_out_size = ( memory_out_size = (
util.get_concrete_int(memory_out_size) util.get_concrete_int(memory_out_size)
if isinstance(memory_out_size, ExprRef) if isinstance(memory_out_size, Expression)
else memory_out_size else memory_out_size
) )
except TypeError: except TypeError:

@ -18,7 +18,7 @@ class KeccakFunctionManager:
""" """
return str(expression) in self.keccak_expression_mapping.keys() return str(expression) in self.keccak_expression_mapping.keys()
def get_argument(self, expression: str) -> Expression: def get_argument(self, expression: Expression) -> Expression:
""" """
:param expression: :param expression:

@ -9,7 +9,8 @@ from py_ecc.secp256k1 import N as secp256k1n
from rlp.utils import ALL_BYTES from rlp.utils import ALL_BYTES
from mythril.laser.ethereum.state.calldata import BaseCalldata, ConcreteCalldata from mythril.laser.ethereum.state.calldata import BaseCalldata, ConcreteCalldata
from mythril.laser.ethereum.util import bytearray_to_int, sha3 from mythril.laser.ethereum.util import bytearray_to_int
from ethereum.utils import sha3
from mythril.laser.smt import Concat, simplify from mythril.laser.smt import Concat, simplify
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -50,7 +51,7 @@ def extract32(data: bytearray, i: int) -> int:
return bytearray_to_int(o) return bytearray_to_int(o)
def ecrecover(data: Union[bytes, str, List[int]]) -> bytes: def ecrecover(data: List[int]) -> List[int]:
""" """
:param data: :param data:
@ -58,54 +59,54 @@ def ecrecover(data: Union[bytes, str, List[int]]) -> bytes:
""" """
# TODO: Add type hints # TODO: Add type hints
try: try:
data = bytearray(data) byte_data = bytearray(data)
v = extract32(data, 32) v = extract32(byte_data, 32)
r = extract32(data, 64) r = extract32(byte_data, 64)
s = extract32(data, 96) s = extract32(byte_data, 96)
except TypeError: except TypeError:
raise NativeContractException raise NativeContractException
message = b"".join([ALL_BYTES[x] for x in data[0:32]]) message = b"".join([ALL_BYTES[x] for x in byte_data[0:32]])
if r >= secp256k1n or s >= secp256k1n or v < 27 or v > 28: if r >= secp256k1n or s >= secp256k1n or v < 27 or v > 28:
return [] return []
try: try:
pub = ecrecover_to_pub(message, v, r, s) pub = ecrecover_to_pub(message, v, r, s)
except Exception as e: except Exception as e:
log.debug("An error has occured while extracting public key: " + e) log.debug("An error has occured while extracting public key: " + str(e))
return [] return []
o = [0] * 12 + [x for x in sha3(pub)[-20:]] o = [0] * 12 + [x for x in sha3(pub)[-20:]]
return o return list(bytearray(o))
def sha256(data: Union[bytes, str, List[int]]) -> bytes: def sha256(data: List[int]) -> List[int]:
""" """
:param data: :param data:
:return: :return:
""" """
try: try:
data = bytes(data) byte_data = bytes(data)
except TypeError: except TypeError:
raise NativeContractException raise NativeContractException
return hashlib.sha256(data).digest() return list(bytearray(hashlib.sha256(byte_data).digest()))
def ripemd160(data: Union[bytes, str, List[int]]) -> bytes: def ripemd160(data: List[int]) -> List[int]:
""" """
:param data: :param data:
:return: :return:
""" """
try: try:
data = bytes(data) bytes_data = bytes(data)
except TypeError: except TypeError:
raise NativeContractException raise NativeContractException
digest = hashlib.new("ripemd160", data).digest() digest = hashlib.new("ripemd160", bytes_data).digest()
padded = 12 * [0] + list(digest) padded = 12 * [0] + list(digest)
return bytes(padded) return list(bytearray(bytes(padded)))
def identity(data: Union[bytes, str, List[int]]) -> bytes: def identity(data: List[int]) -> List[int]:
""" """
:param data: :param data:
@ -117,13 +118,9 @@ def identity(data: Union[bytes, str, List[int]]) -> bytes:
# implementation would be byte indexed for the most # implementation would be byte indexed for the most
# part. # part.
return data return data
result = []
for i in range(0, len(data), 32):
result.append(simplify(Concat(data[i : i + 32])))
return result
def native_contracts(address: int, data: BaseCalldata): def native_contracts(address: int, data: BaseCalldata) -> List[int]:
"""Takes integer address 1, 2, 3, 4. """Takes integer address 1, 2, 3, 4.
:param address: :param address:
@ -133,8 +130,8 @@ def native_contracts(address: int, data: BaseCalldata):
functions = (ecrecover, sha256, ripemd160, identity) functions = (ecrecover, sha256, ripemd160, identity)
if isinstance(data, ConcreteCalldata): if isinstance(data, ConcreteCalldata):
data = data.concrete(None) concrete_data = data.concrete(None)
else: else:
raise NativeContractException() raise NativeContractException()
return functions[address - 1](data) return functions[address - 1](concrete_data)

@ -14,17 +14,17 @@ from mythril.laser.smt import symbol_factory
class Storage: class Storage:
"""Storage class represents the storage of an Account.""" """Storage class represents the storage of an Account."""
def __init__(self, concrete=False, address=None, dynamic_loader=None): def __init__(self, concrete=False, address=None, dynamic_loader=None) -> None:
"""Constructor for Storage. """Constructor for Storage.
:param concrete: bool indicating whether to interpret uninitialized storage as concrete versus symbolic :param concrete: bool indicating whether to interpret uninitialized storage as concrete versus symbolic
""" """
self._storage = {} self._storage = {} # type: Dict[Union[int, str], Any]
self.concrete = concrete self.concrete = concrete
self.dynld = dynamic_loader self.dynld = dynamic_loader
self.address = address self.address = address
def __getitem__(self, item: Union[int, slice]) -> Any: def __getitem__(self, item: Union[str, int]) -> Any:
try: try:
return self._storage[item] return self._storage[item]
except KeyError: except KeyError:
@ -51,7 +51,7 @@ class Storage:
self._storage[item] = symbol_factory.BitVecVal(0, 256) self._storage[item] = symbol_factory.BitVecVal(0, 256)
return self._storage[item] return self._storage[item]
def __setitem__(self, key: str, value: ExprRef) -> None: def __setitem__(self, key: Union[int, str], value: Any) -> None:
self._storage[key] = value self._storage[key] = value
def keys(self) -> KeysView: def keys(self) -> KeysView:
@ -73,7 +73,7 @@ class Account:
balance=None, balance=None,
concrete_storage=False, concrete_storage=False,
dynamic_loader=None, dynamic_loader=None,
): ) -> None:
"""Constructor for account. """Constructor for account.
:param address: Address of the account :param address: Address of the account

@ -1,7 +1,6 @@
"""This module declares classes to represent call data.""" """This module declares classes to represent call data."""
from typing import Union, Any from typing import cast, Union, Tuple, List
from mythril.laser.smt import K, Array, If, simplify, Concat, Expression, BitVec
from enum import Enum from enum import Enum
from typing import Any, Union from typing import Any, Union
@ -13,6 +12,7 @@ from mythril.laser.ethereum.util import get_concrete_int
from mythril.laser.smt import ( from mythril.laser.smt import (
Array, Array,
BitVec, BitVec,
Bool,
Concat, Concat,
Expression, Expression,
If, If,
@ -26,7 +26,7 @@ class BaseCalldata:
"""Base calldata class This represents the calldata provided when sending a """Base calldata class This represents the calldata provided when sending a
transaction to a contract.""" transaction to a contract."""
def __init__(self, tx_id): def __init__(self, tx_id: str) -> None:
""" """
:param tx_id: :param tx_id:
@ -34,7 +34,7 @@ class BaseCalldata:
self.tx_id = tx_id self.tx_id = tx_id
@property @property
def calldatasize(self) -> Expression: def calldatasize(self) -> BitVec:
""" """
:return: Calldata size for this calldata object :return: Calldata size for this calldata object
@ -53,7 +53,7 @@ class BaseCalldata:
parts = self[offset : offset + 32] parts = self[offset : offset + 32]
return simplify(Concat(parts)) return simplify(Concat(parts))
def __getitem__(self, item: Union[int, slice]) -> Any: def __getitem__(self, item: Union[int, slice, BitVec]) -> Any:
""" """
:param item: :param item:
@ -88,7 +88,7 @@ class BaseCalldata:
raise ValueError raise ValueError
def _load(self, item: Union[int, Expression]) -> Any: def _load(self, item: Union[int, BitVec]) -> Any:
""" """
:param item: :param item:
@ -96,7 +96,7 @@ class BaseCalldata:
raise NotImplementedError() raise NotImplementedError()
@property @property
def size(self) -> Union[Expression, int]: def size(self) -> Union[BitVec, int]:
"""Returns the exact size of this calldata, this is not normalized. """Returns the exact size of this calldata, this is not normalized.
:return: unnormalized call data size :return: unnormalized call data size
@ -114,7 +114,7 @@ class BaseCalldata:
class ConcreteCalldata(BaseCalldata): class ConcreteCalldata(BaseCalldata):
"""A concrete call data representation.""" """A concrete call data representation."""
def __init__(self, tx_id: int, calldata: list): def __init__(self, tx_id: str, calldata: list) -> None:
"""Initializes the ConcreteCalldata object. """Initializes the ConcreteCalldata object.
:param tx_id: Id of the transaction that the calldata is for. :param tx_id: Id of the transaction that the calldata is for.
@ -132,7 +132,7 @@ class ConcreteCalldata(BaseCalldata):
super().__init__(tx_id) super().__init__(tx_id)
def _load(self, item: Union[int, Expression]) -> BitVec: def _load(self, item: Union[int, BitVec]) -> BitVec:
""" """
:param item: :param item:
@ -161,7 +161,7 @@ class ConcreteCalldata(BaseCalldata):
class BasicConcreteCalldata(BaseCalldata): class BasicConcreteCalldata(BaseCalldata):
"""A base class to represent concrete call data.""" """A base class to represent concrete call data."""
def __init__(self, tx_id: int, calldata: list): def __init__(self, tx_id: str, calldata: list) -> None:
"""Initializes the ConcreteCalldata object, that doesn't use z3 arrays. """Initializes the ConcreteCalldata object, that doesn't use z3 arrays.
:param tx_id: Id of the transaction that the calldata is for. :param tx_id: Id of the transaction that the calldata is for.
@ -184,7 +184,7 @@ class BasicConcreteCalldata(BaseCalldata):
value = symbol_factory.BitVecVal(0x0, 8) value = symbol_factory.BitVecVal(0x0, 8)
for i in range(self.size): for i in range(self.size):
value = If(item == i, self._calldata[i], value) value = If(cast(Union[BitVec, Bool], item) == i, self._calldata[i], value)
return value return value
def concrete(self, model: Model) -> list: def concrete(self, model: Model) -> list:
@ -207,7 +207,7 @@ class BasicConcreteCalldata(BaseCalldata):
class SymbolicCalldata(BaseCalldata): class SymbolicCalldata(BaseCalldata):
"""A class for representing symbolic call data.""" """A class for representing symbolic call data."""
def __init__(self, tx_id: int): def __init__(self, tx_id: str) -> None:
"""Initializes the SymbolicCalldata object. """Initializes the SymbolicCalldata object.
:param tx_id: Id of the transaction that the calldata is for. :param tx_id: Id of the transaction that the calldata is for.
@ -216,7 +216,7 @@ class SymbolicCalldata(BaseCalldata):
self._calldata = Array("{}_calldata".format(tx_id), 256, 8) self._calldata = Array("{}_calldata".format(tx_id), 256, 8)
super().__init__(tx_id) super().__init__(tx_id)
def _load(self, item: Union[int, Expression]) -> Any: def _load(self, item: Union[int, BitVec]) -> Any:
""" """
:param item: :param item:
@ -226,7 +226,7 @@ class SymbolicCalldata(BaseCalldata):
return simplify( return simplify(
If( If(
item < self._size, item < self._size,
simplify(self._calldata[item]), simplify(self._calldata[cast(BitVec, item)]),
symbol_factory.BitVecVal(0, 8), symbol_factory.BitVecVal(0, 8),
) )
) )
@ -247,7 +247,7 @@ class SymbolicCalldata(BaseCalldata):
return result return result
@property @property
def size(self) -> Expression: def size(self) -> BitVec:
""" """
:return: :return:
@ -258,29 +258,34 @@ class SymbolicCalldata(BaseCalldata):
class BasicSymbolicCalldata(BaseCalldata): class BasicSymbolicCalldata(BaseCalldata):
"""A basic class representing symbolic call data.""" """A basic class representing symbolic call data."""
def __init__(self, tx_id: int): def __init__(self, tx_id: str) -> None:
"""Initializes the SymbolicCalldata object. """Initializes the SymbolicCalldata object.
:param tx_id: Id of the transaction that the calldata is for. :param tx_id: Id of the transaction that the calldata is for.
""" """
self._reads = [] self._reads = [] # type: List[Tuple[Union[int, BitVec], BitVec]]
self._size = BitVec(str(tx_id) + "_calldatasize", 256) self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256)
super().__init__(tx_id) super().__init__(tx_id)
def _load(self, item: Union[int, Expression], clean=False) -> Any: def _load(self, item: Union[int, BitVec], clean=False) -> Any:
x = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item expr_item = (
symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
) # type: BitVec
symbolic_base_value = If( symbolic_base_value = If(
x >= self._size, expr_item >= self._size,
symbol_factory.BitVecVal(0, 8), symbol_factory.BitVecVal(0, 8),
BitVec("{}_calldata_{}".format(self.tx_id, str(item)), 8), BitVec(
symbol_factory.BitVecSym(
"{}_calldata_{}".format(self.tx_id, str(item)), 8
)
),
) )
return_value = symbolic_base_value return_value = symbolic_base_value
for r_index, r_value in self._reads: for r_index, r_value in self._reads:
return_value = If(r_index == item, r_value, return_value) return_value = If(r_index == expr_item, r_value, return_value)
if not clean: if not clean:
self._reads.append((item, symbolic_base_value)) self._reads.append((expr_item, symbolic_base_value))
return simplify(return_value) return simplify(return_value)
def concrete(self, model: Model) -> list: def concrete(self, model: Model) -> list:
@ -299,7 +304,7 @@ class BasicSymbolicCalldata(BaseCalldata):
return result return result
@property @property
def size(self) -> Expression: def size(self) -> BitVec:
""" """
:return: :return:

@ -22,7 +22,7 @@ class Environment:
callvalue: ExprRef, callvalue: ExprRef,
origin: ExprRef, origin: ExprRef,
code=None, code=None,
): ) -> None:
""" """
:param active_account: :param active_account:

@ -1,5 +1,5 @@
"""This module contains a representation of the global execution state.""" """This module contains a representation of the global execution state."""
from typing import Dict, Union, List, Iterable from typing import Dict, Union, List, Iterable, TYPE_CHECKING
from copy import copy, deepcopy from copy import copy, deepcopy
from z3 import BitVec from z3 import BitVec
@ -10,6 +10,13 @@ from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.machine_state import MachineState from mythril.laser.ethereum.state.machine_state import MachineState
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
if TYPE_CHECKING:
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.transaction.transaction_models import (
MessageCallTransaction,
ContractCreationTransaction,
)
class GlobalState: class GlobalState:
"""GlobalState represents the current globalstate.""" """GlobalState represents the current globalstate."""
@ -23,7 +30,7 @@ class GlobalState:
transaction_stack=None, transaction_stack=None,
last_return_data=None, last_return_data=None,
annotations=None, annotations=None,
): ) -> None:
"""Constructor for GlobalState. """Constructor for GlobalState.
:param world_state: :param world_state:

@ -1,9 +1,9 @@
"""This module contains a representation of the EVM's machine state and its """This module contains a representation of the EVM's machine state and its
stack.""" stack."""
from copy import copy from copy import copy
from typing import Union, Any, List, Dict from typing import cast, Sized, Union, Any, List, Dict, Optional
from z3 import BitVec from mythril.laser.smt import BitVec, Expression
from ethereum import opcodes, utils from ethereum import opcodes, utils
from mythril.laser.ethereum.evm_exceptions import ( from mythril.laser.ethereum.evm_exceptions import (
@ -20,16 +20,14 @@ class MachineStack(list):
STACK_LIMIT = 1024 STACK_LIMIT = 1024
def __init__(self, default_list=None): def __init__(self, default_list=None) -> None:
""" """
:param default_list: :param default_list:
""" """
if default_list is None: super(MachineStack, self).__init__(default_list or [])
default_list = []
super(MachineStack, self).__init__(default_list)
def append(self, element: BitVec) -> None: def append(self, element: Union[int, Expression]) -> None:
""" """
:param element: element to be appended to the list :param element: element to be appended to the list
:function: appends the element to list if the size is less than STACK_LIMIT, else throws an error :function: appends the element to list if the size is less than STACK_LIMIT, else throws an error
@ -41,7 +39,7 @@ class MachineStack(list):
) )
super(MachineStack, self).append(element) super(MachineStack, self).append(element)
def pop(self, index=-1) -> BitVec: def pop(self, index=-1) -> Union[int, Expression]:
""" """
:param index:index to be popped, same as the list() class. :param index:index to be popped, same as the list() class.
:returns popped value :returns popped value
@ -90,12 +88,12 @@ class MachineState:
gas_limit: int, gas_limit: int,
pc=0, pc=0,
stack=None, stack=None,
memory=None, memory: Optional[Memory] = None,
constraints=None, constraints=None,
depth=0, depth=0,
max_gas_used=0, max_gas_used=0,
min_gas_used=0, min_gas_used=0,
): ) -> None:
"""Constructor for machineState. """Constructor for machineState.
:param gas_limit: :param gas_limit:
@ -164,7 +162,7 @@ class MachineState:
self.check_gas() self.check_gas()
self.memory.extend(m_extend) self.memory.extend(m_extend)
def memory_write(self, offset: int, data: List[int]) -> None: def memory_write(self, offset: int, data: List[Union[int, BitVec]]) -> None:
"""Writes data to memory starting at offset. """Writes data to memory starting at offset.
:param offset: :param offset:
@ -217,7 +215,7 @@ class MachineState:
:return: :return:
""" """
return len(self.memory) return len(cast(Sized, self.memory))
@property @property
def as_dict(self) -> Dict: def as_dict(self) -> Dict:

@ -1,5 +1,5 @@
"""This module contains a representation of a smart contract's memory.""" """This module contains a representation of a smart contract's memory."""
from typing import Union from typing import cast, List, Union, overload
from z3 import Z3Exception from z3 import Z3Exception
@ -20,7 +20,7 @@ class Memory:
def __init__(self): def __init__(self):
"""""" """"""
self._memory = [] self._memory = [] # type: List[Union[int, BitVec]]
def __len__(self): def __len__(self):
""" """
@ -50,12 +50,14 @@ class Memory:
), ),
256, 256,
) )
except: except TypeError:
result = simplify( result = simplify(
Concat( Concat(
[ [
b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8) b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8)
for b in self[index : index + 32] for b in cast(
List[Union[int, BitVec]], self[index : index + 32]
)
] ]
) )
) )
@ -79,8 +81,9 @@ class Memory:
else: else:
_bytes = util.concrete_int_to_bytes(value) _bytes = util.concrete_int_to_bytes(value)
assert len(_bytes) == 32 assert len(_bytes) == 32
self[index : index + 32] = _bytes self[index : index + 32] = list(bytearray(_bytes))
except (Z3Exception, AttributeError): # BitVector or BoolRef except (Z3Exception, AttributeError): # BitVector or BoolRef
value = cast(Union[BitVec, Bool], value)
if isinstance(value, Bool): if isinstance(value, Bool):
value_to_write = If( value_to_write = If(
value, value,
@ -94,7 +97,17 @@ class Memory:
for i in range(0, value_to_write.size(), 8): for i in range(0, value_to_write.size(), 8):
self[index + 31 - (i // 8)] = Extract(i + 7, i, value_to_write) self[index + 31 - (i // 8)] = Extract(i + 7, i, value_to_write)
def __getitem__(self, item: Union[int, slice]) -> Union[BitVec, int, list]: @overload
def __getitem__(self, item: int) -> Union[int, BitVec]:
...
@overload
def __getitem__(self, item: slice) -> List[Union[int, BitVec]]:
...
def __getitem__(
self, item: Union[int, slice]
) -> Union[BitVec, int, List[Union[int, BitVec]]]:
""" """
:param item: :param item:
@ -108,14 +121,18 @@ class Memory:
raise IndexError("Invalid Memory Slice") raise IndexError("Invalid Memory Slice")
if step is None: if step is None:
step = 1 step = 1
return [self[i] for i in range(start, stop, step)] return [cast(Union[int, BitVec], self[i]) for i in range(start, stop, step)]
try: try:
return self._memory[item] return self._memory[item]
except IndexError: except IndexError:
return 0 return 0
def __setitem__(self, key: Union[int, slice], value: Union[BitVec, int, list]): def __setitem__(
self,
key: Union[int, slice],
value: Union[BitVec, int, List[Union[int, BitVec]]],
):
""" """
:param key: :param key:
@ -130,13 +147,13 @@ class Memory:
raise IndexError("Invalid Memory Slice") raise IndexError("Invalid Memory Slice")
if step is None: if step is None:
step = 1 step = 1
assert type(value) == list
for i in range(0, stop - start, step): for i in range(0, stop - start, step):
self[start + i] = value[i] self[start + i] = cast(List[Union[int, BitVec]], value)[i]
else: else:
if isinstance(value, int): if isinstance(value, int):
assert 0 <= value <= 0xFF assert 0 <= value <= 0xFF
if isinstance(value, BitVec): if isinstance(value, BitVec):
assert value.size() == 8 assert value.size() == 8
self._memory[key] = value self._memory[key] = cast(Union[int, BitVec], value)

@ -1,11 +1,14 @@
"""This module contains a representation of the EVM's world state.""" """This module contains a representation of the EVM's world state."""
from copy import copy from copy import copy
from random import randint from random import randint
from typing import List, Iterator from typing import Dict, List, Iterator, Optional, TYPE_CHECKING
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
if TYPE_CHECKING:
from mythril.laser.ethereum.cfg import Node
class WorldState: class WorldState:
"""The WorldState class represents the world state as described in the """The WorldState class represents the world state as described in the
@ -19,8 +22,8 @@ class WorldState:
:param transaction_sequence: :param transaction_sequence:
:param annotations: :param annotations:
""" """
self.accounts = {} self.accounts = {} # type: Dict[str, Account]
self.node = None self.node = None # type: Optional['Node']
self.transaction_sequence = transaction_sequence or [] self.transaction_sequence = transaction_sequence or []
self._annotations = annotations or [] self._annotations = annotations or []

@ -1,5 +1,6 @@
"""This module implements basic symbolic execution search strategies.""" """This module implements basic symbolic execution search strategies."""
from random import randrange from random import randrange
from typing import List
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from . import BasicSearchStrategy from . import BasicSearchStrategy
@ -13,7 +14,10 @@ except ImportError:
from random import random from random import random
from bisect import bisect from bisect import bisect
def choices(population, weights=None): # TODO: Remove ignore after this has been fixed: https://github.com/python/mypy/issues/1297
def choices( # type: ignore
population: List, weights: List[int] = None
) -> List[int]:
"""Returns a random element out of the population based on weight. """Returns a random element out of the population based on weight.
If the relative weights or cumulative weights are not specified, If the relative weights or cumulative weights are not specified,
@ -21,7 +25,7 @@ except ImportError:
""" """
if weights is None: if weights is None:
return [population[int(random() * len(population))]] return [population[int(random() * len(population))]]
cum_weights = accumulate(weights) cum_weights = list(accumulate(weights))
return [ return [
population[ population[
bisect(cum_weights, random() * cum_weights[-1], 0, len(population) - 1) bisect(cum_weights, random() * cum_weights[-1], 0, len(population) - 1)

@ -4,7 +4,7 @@ from collections import defaultdict
from copy import copy from copy import copy
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import reduce from functools import reduce
from typing import Callable, Dict, List, Tuple, Union from typing import Callable, Dict, DefaultDict, List, Tuple, Union
from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType
from mythril.laser.ethereum.evm_exceptions import StackUnderflowException from mythril.laser.ethereum.evm_exceptions import StackUnderflowException
@ -56,7 +56,7 @@ class LaserEVM:
transaction_count=2, transaction_count=2,
requires_statespace=True, requires_statespace=True,
enable_iprof=False, enable_iprof=False,
): ) -> None:
""" """
:param accounts: :param accounts:
@ -73,12 +73,12 @@ class LaserEVM:
self.world_state = world_state self.world_state = world_state
self.open_states = [world_state] self.open_states = [world_state]
self.coverage = {} self.coverage = {} # type: Dict[str, Tuple[int, List[bool]]]
self.total_states = 0 self.total_states = 0
self.dynamic_loader = dynamic_loader self.dynamic_loader = dynamic_loader
self.work_list = [] self.work_list = [] # type: List[GlobalState]
self.strategy = strategy(self.work_list, max_depth) self.strategy = strategy(self.work_list, max_depth)
self.max_depth = max_depth self.max_depth = max_depth
self.transaction_count = transaction_count self.transaction_count = transaction_count
@ -88,14 +88,15 @@ class LaserEVM:
self.requires_statespace = requires_statespace self.requires_statespace = requires_statespace
if self.requires_statespace: if self.requires_statespace:
self.nodes = {} self.nodes = {} # type: Dict[int, Node]
self.edges = [] self.edges = [] # type: List[Edge]
self.time = None self.time = None # type: datetime
self.pre_hooks = defaultdict(list) self.pre_hooks = defaultdict(list) # type: DefaultDict[str, List[Callable]]
self.post_hooks = defaultdict(list) self.post_hooks = defaultdict(list) # type: DefaultDict[str, List[Callable]]
self._add_world_state_hooks = []
self._add_world_state_hooks = [] # type: List[Callable]
self.iprof = InstructionProfiler() if enable_iprof else None self.iprof = InstructionProfiler() if enable_iprof else None
log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader)) log.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader))
@ -153,11 +154,8 @@ class LaserEVM:
self.total_states, self.total_states,
) )
for code, coverage in self.coverage.items(): for code, coverage in self.coverage.items():
cov = ( cov = sum(coverage[1]) / float(coverage[0]) * 100
reduce(lambda sum_, val: sum_ + 1 if val else sum_, coverage[1])
/ float(coverage[0])
* 100
)
log.info("Achieved {:.2f}% coverage for code: {}".format(cov, code)) log.info("Achieved {:.2f}% coverage for code: {}".format(cov, code))
if self.iprof is not None: if self.iprof is not None:
@ -198,9 +196,7 @@ class LaserEVM:
""" """
total_covered_instructions = 0 total_covered_instructions = 0
for _, cv in self.coverage.items(): for _, cv in self.coverage.items():
total_covered_instructions += reduce( total_covered_instructions += sum(cv[1])
lambda sum_, val: sum_ + 1 if val else sum_, cv[1]
)
return total_covered_instructions return total_covered_instructions
def exec(self, create=False, track_gas=False) -> Union[List[GlobalState], None]: def exec(self, create=False, track_gas=False) -> Union[List[GlobalState], None]:
@ -210,7 +206,7 @@ class LaserEVM:
:param track_gas: :param track_gas:
:return: :return:
""" """
final_states = [] final_states = [] # type: List[GlobalState]
for global_state in self.strategy: for global_state in self.strategy:
if ( if (
self.create_timeout self.create_timeout
@ -385,10 +381,10 @@ class LaserEVM:
instruction_index = global_state.mstate.pc instruction_index = global_state.mstate.pc
if code not in self.coverage.keys(): if code not in self.coverage.keys():
self.coverage[code] = [ self.coverage[code] = (
number_of_instructions, number_of_instructions,
[False] * number_of_instructions, [False] * number_of_instructions,
] )
self.coverage[code][1][instruction_index] = True self.coverage[code][1][instruction_index] = True

@ -3,7 +3,7 @@ execution."""
import array import array
from z3 import ExprRef from z3 import ExprRef
from typing import Union from typing import Union, Optional, cast
from mythril.laser.ethereum.state.calldata import ConcreteCalldata from mythril.laser.ethereum.state.calldata import ConcreteCalldata
from mythril.laser.ethereum.state.account import Account from mythril.laser.ethereum.state.account import Account
@ -17,20 +17,20 @@ from mythril.laser.smt import symbol_factory
_next_transaction_id = 0 _next_transaction_id = 0
def get_next_transaction_id() -> int: def get_next_transaction_id() -> str:
""" """
:return: :return:
""" """
global _next_transaction_id global _next_transaction_id
_next_transaction_id += 1 _next_transaction_id += 1
return _next_transaction_id return str(_next_transaction_id)
class TransactionEndSignal(Exception): class TransactionEndSignal(Exception):
"""Exception raised when a transaction is finalized.""" """Exception raised when a transaction is finalized."""
def __init__(self, global_state: GlobalState, revert=False): def __init__(self, global_state: GlobalState, revert=False) -> None:
self.global_state = global_state self.global_state = global_state
self.revert = revert self.revert = revert
@ -42,7 +42,7 @@ class TransactionStartSignal(Exception):
self, self,
transaction: Union["MessageCallTransaction", "ContractCreationTransaction"], transaction: Union["MessageCallTransaction", "ContractCreationTransaction"],
op_code: str, op_code: str,
): ) -> None:
self.transaction = transaction self.transaction = transaction
self.op_code = op_code self.op_code = op_code
@ -56,14 +56,14 @@ class BaseTransaction:
callee_account: Account = None, callee_account: Account = None,
caller: ExprRef = None, caller: ExprRef = None,
call_data=None, call_data=None,
identifier=None, identifier: Optional[str] = None,
gas_price=None, gas_price=None,
gas_limit=None, gas_limit=None,
origin=None, origin=None,
code=None, code=None,
call_value=None, call_value=None,
init_call_data=True, init_call_data=True,
): ) -> None:
assert isinstance(world_state, WorldState) assert isinstance(world_state, WorldState)
self.world_state = world_state self.world_state = world_state
self.id = identifier or get_next_transaction_id() self.id = identifier or get_next_transaction_id()
@ -85,7 +85,7 @@ class BaseTransaction:
self.caller = caller self.caller = caller
self.callee_account = callee_account self.callee_account = callee_account
if call_data is None and init_call_data: if call_data is None and init_call_data:
self.call_data = SymbolicCalldata(self.id) self.call_data = SymbolicCalldata(self.id) # type: BaseCalldata
else: else:
self.call_data = ( self.call_data = (
call_data call_data
@ -99,7 +99,7 @@ class BaseTransaction:
else symbol_factory.BitVecSym("callvalue{}".format(identifier), 256) else symbol_factory.BitVecSym("callvalue{}".format(identifier), 256)
) )
self.return_data = None self.return_data = None # type: str
def initial_global_state_from_environment(self, environment, active_function): def initial_global_state_from_environment(self, environment, active_function):
""" """
@ -117,7 +117,7 @@ class BaseTransaction:
class MessageCallTransaction(BaseTransaction): class MessageCallTransaction(BaseTransaction):
"""Transaction object models an transaction.""" """Transaction object models an transaction."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
def initial_global_state(self) -> GlobalState: def initial_global_state(self) -> GlobalState:
@ -149,8 +149,9 @@ class MessageCallTransaction(BaseTransaction):
class ContractCreationTransaction(BaseTransaction): class ContractCreationTransaction(BaseTransaction):
"""Transaction object models an transaction.""" """Transaction object models an transaction."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs, init_call_data=False) # Remove ignore after https://github.com/python/mypy/issues/4335 is fixed
super().__init__(*args, **kwargs, init_call_data=False) # type: ignore
# TODO: set correct balance for new account # TODO: set correct balance for new account
self.callee_account = self.callee_account or self.world_state.create_account( self.callee_account = self.callee_account or self.world_state.create_account(
0, concrete_storage=True 0, concrete_storage=True

@ -1,9 +1,10 @@
"""This module contains various utility conversion functions and constants for """This module contains various utility conversion functions and constants for
LASER.""" LASER."""
import re import re
from typing import Dict, List, Union from typing import Dict, List, Union, TYPE_CHECKING, cast
import sha3 as _sha3 if TYPE_CHECKING:
from mythril.laser.ethereum.state.machine_state import MachineState
from mythril.laser.smt import BitVec, Bool, Expression, If, simplify, symbol_factory from mythril.laser.smt import BitVec, Bool, Expression, If, simplify, symbol_factory
@ -12,15 +13,6 @@ TT256M1 = 2 ** 256 - 1
TT255 = 2 ** 255 TT255 = 2 ** 255
def sha3(seed: str) -> bytes:
"""
:param seed:
:return:
"""
return _sha3.keccak_256(bytes(seed)).digest()
def safe_decode(hex_encoded_string: str) -> bytes: def safe_decode(hex_encoded_string: str) -> bytes:
""" """
@ -83,18 +75,16 @@ def pop_bitvec(state: "MachineState") -> BitVec:
item = state.stack.pop() item = state.stack.pop()
if type(item) == Bool: if isinstance(item, Bool):
return If( return If(
item, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256) cast(Bool, item),
symbol_factory.BitVecVal(1, 256),
symbol_factory.BitVecVal(0, 256),
) )
elif type(item) == bool: elif isinstance(item, int):
if item:
return symbol_factory.BitVecVal(1, 256)
else:
return symbol_factory.BitVecVal(0, 256)
elif type(item) == int:
return symbol_factory.BitVecVal(item, 256) return symbol_factory.BitVecVal(item, 256)
else: else:
item = cast(BitVec, item)
return simplify(item) return simplify(item)
@ -116,8 +106,12 @@ def get_concrete_int(item: Union[int, Expression]) -> int:
raise TypeError("Symbolic boolref encountered") raise TypeError("Symbolic boolref encountered")
return value return value
assert False, "Unhandled type {} encountered".format(str(type(item)))
def concrete_int_from_bytes(concrete_bytes: bytes, start_index: int) -> int: def concrete_int_from_bytes(
concrete_bytes: Union[List[Union[BitVec, int]], bytes], start_index: int
) -> int:
""" """
:param concrete_bytes: :param concrete_bytes:
@ -130,7 +124,8 @@ def concrete_int_from_bytes(concrete_bytes: bytes, start_index: int) -> int:
] ]
integer_bytes = concrete_bytes[start_index : start_index + 32] integer_bytes = concrete_bytes[start_index : start_index + 32]
return int.from_bytes(integer_bytes, byteorder="big") # The below statement is expected to fail in some circumstances whose error is caught
return int.from_bytes(integer_bytes, byteorder="big") # type: ignore
def concrete_int_to_bytes(val): def concrete_int_to_bytes(val):

@ -11,6 +11,7 @@ Annotations = List[Any]
# fmt: off # fmt: off
class BitVec(Expression[z3.BitVecRef]): class BitVec(Expression[z3.BitVecRef]):
"""A bit vector symbol.""" """A bit vector symbol."""
@ -139,6 +140,24 @@ class BitVec(Expression[z3.BitVecRef]):
union = self.annotations + other.annotations union = self.annotations + other.annotations
return Bool(self.raw > other.raw, annotations=union) return Bool(self.raw > other.raw, annotations=union)
def __le__(self, other: "BitVec") -> Bool:
"""Create a signed less than expression.
:param other:
:return:
"""
union = self.annotations + other.annotations
return Bool(self.raw <= other.raw, annotations=union)
def __ge__(self, other: "BitVec") -> Bool:
"""Create a signed greater than expression.
:param other:
:return:
"""
union = self.annotations + other.annotations
return Bool(self.raw >= other.raw, annotations=union)
# MYPY: fix complains about overriding __eq__ # MYPY: fix complains about overriding __eq__
def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore def __eq__(self, other: Union[int, "BitVec"]) -> Bool: # type: ignore
"""Create an equality expression. """Create an equality expression.

@ -1,7 +1,7 @@
"""This module provides classes for an SMT abstraction of boolean """This module provides classes for an SMT abstraction of boolean
expressions.""" expressions."""
from typing import Union, cast from typing import Union, cast, List
import z3 import z3
@ -81,13 +81,13 @@ class Bool(Expression[z3.BoolRef]):
return False return False
def And(*args: Bool) -> Bool: def And(*args: Union[Bool, bool]) -> Bool:
"""Create an And expression.""" """Create an And expression."""
union = [] union = []
args = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
for arg in args: for arg in args_list:
union.append(arg.annotations) union.append(arg.annotations)
return Bool(z3.And([a.raw for a in args]), union) return Bool(z3.And([a.raw for a in args_list]), union)
def Or(a: Bool, b: Bool) -> Bool: def Or(a: Bool, b: Bool) -> Bool:

@ -46,7 +46,10 @@ class Expression(Generic[T]):
return repr(self.raw) return repr(self.raw)
def simplify(expression: Expression) -> Expression: G = TypeVar("G", bound=Expression)
def simplify(expression: G) -> G:
"""Simplify the expression . """Simplify the expression .
:param expression: :param expression:

@ -22,7 +22,7 @@ class DynLoader:
self.contract_loading = contract_loading self.contract_loading = contract_loading
self.storage_loading = storage_loading self.storage_loading = storage_loading
def read_storage(self, contract_address, index): def read_storage(self, contract_address: str, index: int):
""" """
:param contract_address: :param contract_address:

@ -7,7 +7,7 @@ import sqlite3
import time import time
from collections import defaultdict from collections import defaultdict
from subprocess import PIPE, Popen from subprocess import PIPE, Popen
from typing import List from typing import List, Set, DefaultDict, Dict
from mythril.exceptions import CompilerError from mythril.exceptions import CompilerError
@ -45,7 +45,7 @@ def synchronized(sync_lock):
class Singleton(type): class Singleton(type):
"""A metaclass type implementing the singleton pattern.""" """A metaclass type implementing the singleton pattern."""
_instances = {} _instances = dict() # type: Dict[Singleton, Singleton]
@synchronized(lock) @synchronized(lock)
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
@ -60,6 +60,7 @@ class Singleton(type):
""" """
if cls not in cls._instances: if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls._instances[cls] return cls._instances[cls]
@ -120,12 +121,12 @@ class SignatureDB(object, metaclass=Singleton):
:param path: :param path:
""" """
self.enable_online_lookup = enable_online_lookup self.enable_online_lookup = enable_online_lookup
self.online_lookup_miss = set() self.online_lookup_miss = set() # type: Set[str]
self.online_lookup_timeout = 0 self.online_lookup_timeout = 0
# if we're analysing a Solidity file, store its hashes # if we're analysing a Solidity file, store its hashes
# here to prevent unnecessary lookups # here to prevent unnecessary lookups
self.solidity_sigs = defaultdict(list) self.solidity_sigs = defaultdict(list) # type: DefaultDict[str, List[str]]
if path is None: if path is None:
self.path = os.environ.get("MYTHRIL_DIR") or os.path.join( self.path = os.environ.get("MYTHRIL_DIR") or os.path.join(
os.path.expanduser("~"), ".mythril" os.path.expanduser("~"), ".mythril"
@ -225,7 +226,7 @@ class SignatureDB(object, metaclass=Singleton):
return text_sigs return text_sigs
except FourByteDirectoryOnlineLookupError as fbdole: except FourByteDirectoryOnlineLookupError as fbdole:
# wait at least 2 mins to try again # wait at least 2 mins to try again
self.online_lookup_timeout = time.time() + 2 * 60 self.online_lookup_timeout = int(time.time()) + 2 * 60
log.warning("Online lookup failed, not retrying for 2min: %s", fbdole) log.warning("Online lookup failed, not retrying for 2min: %s", fbdole)
return [] return []

@ -1,10 +1,11 @@
"""This module contains utility functions for the Mythril support package.""" """This module contains utility functions for the Mythril support package."""
from typing import Dict
class Singleton(type): class Singleton(type):
"""A metaclass type implementing the singleton pattern.""" """A metaclass type implementing the singleton pattern."""
_instances = {} _instances = {} # type: Dict
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
"""Delegate the call to an existing resource or a a new one. """Delegate the call to an existing resource or a a new one.

@ -98,7 +98,7 @@ setup(
"persistent>=4.2.0", "persistent>=4.2.0",
"ethereum-input-decoder>=0.2.2", "ethereum-input-decoder>=0.2.2",
], ],
tests_require=["pytest>=3.6.0", "pytest_mock", "pytest-cov"], tests_require=["mypy", "pytest>=3.6.0", "pytest_mock", "pytest-cov"],
python_requires=">=3.5", python_requires=">=3.5",
extras_require={}, extras_require={},
package_data={"mythril.analysis.templates": ["*"], "mythril.support.assets": ["*"]}, package_data={"mythril.analysis.templates": ["*"], "mythril.support.assets": ["*"]},

@ -19,12 +19,14 @@ basepython = python3.6
setenv = setenv =
COVERAGE_FILE = .coverage.{envname} COVERAGE_FILE = .coverage.{envname}
deps = deps =
mypy
pytest pytest
pytest-mock pytest-mock
pytest-cov pytest-cov
passenv = MYTHRIL_DIR = {homedir} passenv = MYTHRIL_DIR = {homedir}
whitelist_externals = mkdir whitelist_externals = mkdir
commands = commands =
mypy --follow-imports=silent --warn-unused-ignores --ignore-missing-imports --no-strict-optional mythril
mkdir -p {toxinidir}/tests/testdata/outputs_current/ mkdir -p {toxinidir}/tests/testdata/outputs_current/
mkdir -p {toxinidir}/tests/testdata/outputs_current_laser_result/ mkdir -p {toxinidir}/tests/testdata/outputs_current_laser_result/
py.test -v \ py.test -v \
@ -35,6 +37,8 @@ commands =
--junitxml={toxworkdir}/output/{envname}/junit.xml \ --junitxml={toxworkdir}/output/{envname}/junit.xml \
{posargs} {posargs}
[coverage:report] [coverage:report]
omit = omit =
*__init__.py *__init__.py

Loading…
Cancel
Save