* Misc fixes

* Fix singleton
pull/1800/head
Nikhil Parasaram 1 year ago committed by GitHub
parent fd48221682
commit bd27897533
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      mythril/analysis/module/modules/external_calls.py
  2. 6
      mythril/analysis/module/modules/integer.py
  3. 2
      mythril/analysis/module/modules/multiple_sends.py
  4. 4
      mythril/analysis/module/modules/state_change_external_calls.py
  5. 2
      mythril/analysis/module/util.py
  6. 6
      mythril/analysis/report.py
  7. 4
      mythril/analysis/security.py
  8. 6
      mythril/analysis/solver.py
  9. 4
      mythril/analysis/symbolic.py
  10. 2
      mythril/concolic/find_trace.py
  11. 4
      mythril/disassembler/asm.py
  12. 12
      mythril/disassembler/disassembly.py
  13. 5
      mythril/ethereum/util.py
  14. 2
      mythril/interfaces/cli.py
  15. 2
      mythril/laser/ethereum/cfg.py
  16. 2
      mythril/laser/ethereum/function_managers/keccak_function_manager.py
  17. 30
      mythril/laser/ethereum/instructions.py
  18. 4
      mythril/laser/ethereum/state/account.py
  19. 6
      mythril/laser/ethereum/state/calldata.py
  20. 4
      mythril/laser/ethereum/state/constraints.py
  21. 4
      mythril/laser/ethereum/state/memory.py
  22. 4
      mythril/laser/ethereum/state/world_state.py
  23. 4
      mythril/laser/ethereum/strategy/__init__.py
  24. 1
      mythril/laser/ethereum/strategy/constraint_strategy.py
  25. 4
      mythril/laser/ethereum/strategy/extensions/bounded_loops.py
  26. 6
      mythril/laser/ethereum/svm.py
  27. 8
      mythril/laser/ethereum/transaction/transaction_models.py
  28. 2
      mythril/laser/ethereum/util.py
  29. 6
      mythril/laser/plugin/loader.py
  30. 4
      mythril/laser/plugin/plugins/coverage/coverage_plugin.py
  31. 8
      mythril/laser/plugin/plugins/dependency_pruner.py
  32. 10
      mythril/laser/plugin/plugins/plugin_annotations.py
  33. 2
      mythril/laser/smt/bitvec_helper.py
  34. 4
      mythril/laser/smt/bool.py
  35. 2
      mythril/laser/smt/model.py
  36. 22
      mythril/laser/smt/solver/independence_solver.py
  37. 8
      mythril/laser/smt/solver/solver.py
  38. 6
      mythril/mythril/mythril_analyzer.py
  39. 4
      mythril/mythril/mythril_config.py
  40. 7
      mythril/mythril/mythril_disassembler.py
  41. 2
      mythril/plugin/discovery.py
  42. 2
      mythril/plugin/loader.py
  43. 6
      mythril/solidity/features.py
  44. 4
      mythril/support/loader.py
  45. 10
      mythril/support/model.py
  46. 6
      mythril/support/signatures.py
  47. 2
      mythril/support/source_support.py
  48. 11
      mythril/support/support_utils.py
  49. 1
      tests/features_test.py
  50. 4
      tests/laser/evm_testsuite/evm_test.py
  51. 1
      tests/laser/tx_prioritisation_test.py

@ -27,8 +27,8 @@ Search for external calls with unrestricted gas to a user-specified address.
def _is_precompile_call(global_state: GlobalState): def _is_precompile_call(global_state: GlobalState):
to = global_state.mstate.stack[-2] # type: BitVec to: BitVec = global_state.mstate.stack[-2]
constraints = copy(global_state.world_state.constraints) constraints: Constraints = copy(global_state.world_state.constraints)
constraints += [ constraints += [
Or( Or(
to < symbol_factory.BitVecVal(1, 256), to < symbol_factory.BitVecVal(1, 256),

@ -50,7 +50,7 @@ class OverUnderflowStateAnnotation(StateAnnotation):
"""State Annotation used if an overflow is both possible and used in the annotated path""" """State Annotation used if an overflow is both possible and used in the annotated path"""
def __init__(self) -> None: def __init__(self) -> None:
self.overflowing_state_annotations = set() # type: Set[OverUnderflowAnnotation] self.overflowing_state_annotations: Set[OverUnderflowAnnotation] = set()
def __copy__(self): def __copy__(self):
new_annotation = OverUnderflowStateAnnotation() new_annotation = OverUnderflowStateAnnotation()
@ -91,8 +91,8 @@ class IntegerArithmetics(DetectionModule):
""" """
super().__init__() super().__init__()
self._ostates_satisfiable = set() # type: Set[GlobalState] self._ostates_satisfiable: Set[GlobalState] = set()
self._ostates_unsatisfiable = set() # type: Set[GlobalState] self._ostates_unsatisfiable: Set[GlobalState] = set()
def reset_module(self): def reset_module(self):
""" """

@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
class MultipleSendsAnnotation(StateAnnotation): class MultipleSendsAnnotation(StateAnnotation):
def __init__(self) -> None: def __init__(self) -> None:
self.call_offsets = [] # type: List[int] self.call_offsets: List[int] = []
def __copy__(self): def __copy__(self):
result = MultipleSendsAnnotation() result = MultipleSendsAnnotation()

@ -29,7 +29,7 @@ STATE_READ_WRITE_LIST = ["SSTORE", "SLOAD", "CREATE", "CREATE2"]
class StateChangeCallsAnnotation(StateAnnotation): class StateChangeCallsAnnotation(StateAnnotation):
def __init__(self, call_state: GlobalState, user_defined_address: bool) -> None: def __init__(self, call_state: GlobalState, user_defined_address: bool) -> None:
self.call_state = call_state self.call_state = call_state
self.state_change_states = [] # type: List[GlobalState] self.state_change_states: List[GlobalState] = []
self.user_defined_address = user_defined_address self.user_defined_address = user_defined_address
def __copy__(self): def __copy__(self):
@ -165,7 +165,7 @@ class StateChangeAfterCall(DetectionModule):
# Record state changes following from a transfer of ether # Record state changes following from a transfer of ether
if op_code in CALL_LIST: if op_code in CALL_LIST:
value = global_state.mstate.stack[-3] # type: BitVec value: BitVec = global_state.mstate.stack[-3]
if StateChangeAfterCall._balance_change(value, global_state): if StateChangeAfterCall._balance_change(value, global_state):
for annotation in annotations: for annotation in annotations:
annotation.state_change_states.append(global_state) annotation.state_change_states.append(global_state)

@ -19,7 +19,7 @@ def get_detection_module_hooks(
:param hook_type: The type of hooks to retrieve (default: "pre") :param hook_type: The type of hooks to retrieve (default: "pre")
:return: Dictionary with discovered hooks :return: Dictionary with discovered hooks
""" """
hook_dict = defaultdict(list) # type: Mapping[str, List[Callable]] hook_dict: Mapping[str, List[Callable]] = defaultdict(list)
for module in modules: for module in modules:
hooks = module.pre_hooks if hook_type == "pre" else module.post_hooks hooks = module.pre_hooks if hook_type == "pre" else module.post_hooks

@ -254,9 +254,9 @@ class Report:
:param contracts: :param contracts:
:param exceptions: :param exceptions:
""" """
self.issues = {} # type: Dict[bytes, Issue] self.issues: Dict[bytes, Issue] = {}
self.solc_version = "" self.solc_version = ""
self.meta = {} # type: Dict[str, Any] self.meta: Dict[str, Any] = {}
self.source = Source() self.source = Source()
self.source.get_source_from_contracts_list(contracts) self.source.get_source_from_contracts_list(contracts)
self.exceptions = exceptions or [] self.exceptions = exceptions or []
@ -306,7 +306,7 @@ class Report:
def _get_exception_data(self) -> dict: def _get_exception_data(self) -> dict:
if not self.exceptions: if not self.exceptions:
return {} return {}
logs = [] # type: List[Dict] logs: List[Dict] = []
for exception in self.exceptions: for exception in self.exceptions:
logs += [{"level": "error", "hidden": True, "msg": exception}] logs += [{"level": "error", "hidden": True, "msg": exception}]
return {"logs": logs} return {"logs": logs}

@ -13,7 +13,7 @@ log = logging.getLogger(__name__)
def retrieve_callback_issues(white_list: Optional[List[str]] = None) -> List[Issue]: def retrieve_callback_issues(white_list: Optional[List[str]] = None) -> List[Issue]:
"""Get the issues discovered by callback type detection modules""" """Get the issues discovered by callback type detection modules"""
issues = [] # type: List[Issue] issues: List[Issue] = []
for module in ModuleLoader().get_detection_modules( for module in ModuleLoader().get_detection_modules(
entry_point=EntryPoint.CALLBACK, white_list=white_list entry_point=EntryPoint.CALLBACK, white_list=white_list
): ):
@ -34,7 +34,7 @@ def fire_lasers(statespace, white_list: Optional[List[str]] = None) -> List[Issu
""" """
log.info("Starting analysis") log.info("Starting analysis")
issues = [] # type: List[Issue] issues: List[Issue] = []
for module in ModuleLoader().get_detection_modules( for module in ModuleLoader().get_detection_modules(
entry_point=EntryPoint.POST, white_list=white_list entry_point=EntryPoint.POST, white_list=white_list
): ):

@ -36,7 +36,7 @@ def pretty_print_model(model):
ret = "" ret = ""
for d in model.decls(): for d in model.decls():
if type(model[d]) == FuncInterp: if isinstance(model[d], FuncInterp):
condition = model[d].as_list() condition = model[d].as_list()
ret += "%s: %s\n" % (d.name(), condition) ret += "%s: %s\n" % (d.name(), condition)
continue continue
@ -119,7 +119,7 @@ def _add_calldata_placeholder(
if not isinstance(transaction_sequence[0], ContractCreationTransaction): if not isinstance(transaction_sequence[0], ContractCreationTransaction):
return return
if type(transaction_sequence[0].code.bytecode) == tuple: if isinstance(transaction_sequence[0].code.bytecode, tuple):
code_len = len(transaction_sequence[0].code.bytecode) * 2 code_len = len(transaction_sequence[0].code.bytecode) * 2
else: else:
code_len = len(transaction_sequence[0].code.bytecode) code_len = len(transaction_sequence[0].code.bytecode)
@ -206,7 +206,7 @@ def _get_concrete_transaction(model: z3.Model, transaction: BaseTransaction):
) )
# Create concrete transaction dict # Create concrete transaction dict
concrete_transaction = dict() # type: Dict[str, str] concrete_transaction: Dict[str, str] = dict()
concrete_transaction["input"] = "0x" + input_ concrete_transaction["input"] = "0x" + input_
concrete_transaction["value"] = "0x%x" % value concrete_transaction["value"] = "0x%x" % value
# Fixme: base origin assignment on origin symbol # Fixme: base origin assignment on origin symbol

@ -85,7 +85,7 @@ class SymExecWrapper:
address = symbol_factory.BitVecVal(address, 256) address = symbol_factory.BitVecVal(address, 256)
beam_width = None beam_width = None
if strategy == "dfs": if strategy == "dfs":
s_strategy = DepthFirstSearchStrategy # type: Type[BasicSearchStrategy] s_strategy: Type[BasicSearchStrategy] = DepthFirstSearchStrategy
elif strategy == "bfs": elif strategy == "bfs":
s_strategy = BreadthFirstSearchStrategy s_strategy = BreadthFirstSearchStrategy
elif strategy == "naive-random": elif strategy == "naive-random":
@ -240,7 +240,7 @@ class SymExecWrapper:
# Parse calls to make them easily accessible # Parse calls to make them easily accessible
self.calls = [] # type: List[Call] self.calls: List[Call] = []
for key in self.nodes: for key in self.nodes:

@ -30,7 +30,7 @@ def setup_concrete_initial_state(concrete_data: ConcreteData) -> WorldState:
account = Account(address, concrete_storage=True) account = Account(address, concrete_storage=True)
account.code = Disassembly(details["code"][2:]) account.code = Disassembly(details["code"][2:])
account.nonce = details["nonce"] account.nonce = details["nonce"]
if type(details["storage"]) == str: if isinstance(type(details["storage"]), str):
details["storage"] = eval(details["storage"]) # type: ignore details["storage"] = eval(details["storage"]) # type: ignore
for key, value in details["storage"].items(): for key, value in details["storage"].items():
key_bitvec = symbol_factory.BitVecVal(int(key, 16), 256) key_bitvec = symbol_factory.BitVecVal(int(key, 16), 256)

@ -106,7 +106,7 @@ def disassemble(bytecode) -> list:
address = 0 address = 0
length = len(bytecode) length = len(bytecode)
if type(bytecode) == str: if isinstance(bytecode, str):
bytecode = util.safe_decode(bytecode) bytecode = util.safe_decode(bytecode)
length = len(bytecode) length = len(bytecode)
part_code = bytecode[-43:] part_code = bytecode[-43:]
@ -135,7 +135,7 @@ def disassemble(bytecode) -> list:
match = re.search(regex_PUSH, op_code) match = re.search(regex_PUSH, op_code)
if match: if match:
argument_bytes = bytecode[address + 1 : address + 1 + int(match.group(1))] argument_bytes = bytecode[address + 1 : address + 1 + int(match.group(1))]
if type(argument_bytes) == bytes: if isinstance(argument_bytes, bytes):
current_instruction.argument = "0x" + argument_bytes.hex() current_instruction.argument = "0x" + argument_bytes.hex()
else: else:
current_instruction.argument = argument_bytes current_instruction.argument = argument_bytes

@ -23,13 +23,13 @@ class Disassembly(object):
:param enable_online_lookup: :param enable_online_lookup:
""" """
self.bytecode = code self.bytecode = code
if type(code) == str: if isinstance(code, str):
self.instruction_list = asm.disassemble(util.safe_decode(code)) self.instruction_list = asm.disassemble(util.safe_decode(code))
else: else:
self.instruction_list = asm.disassemble(code) self.instruction_list = asm.disassemble(code)
self.func_hashes = [] # type: List[str] self.func_hashes: List[str] = []
self.function_name_to_address = {} # type: Dict[str, int] self.function_name_to_address: Dict[str, int] = {}
self.address_to_function_name = {} # type: Dict[int, str] self.address_to_function_name: Dict[int, str] = {}
self.enable_online_lookup = enable_online_lookup self.enable_online_lookup = enable_online_lookup
self.assign_bytecode(bytecode=code) self.assign_bytecode(bytecode=code)
@ -84,7 +84,7 @@ def get_function_info(
""" """
# Append with missing 0s at the beginning # Append with missing 0s at the beginning
if type(instruction_list[index]["argument"]) == tuple: if isinstance(instruction_list[index]["argument"], tuple):
try: try:
function_hash = "0x" + bytes( function_hash = "0x" + bytes(
instruction_list[index]["argument"] instruction_list[index]["argument"]
@ -105,7 +105,7 @@ def get_function_info(
try: try:
offset = instruction_list[index + 2]["argument"] offset = instruction_list[index + 2]["argument"]
if type(offset) == tuple: if isinstance(offset, tuple):
offset = bytes(offset).hex() offset = bytes(offset).hex()
entry_point = int(offset, 16) entry_point = int(offset, 16)
except (KeyError, IndexError): except (KeyError, IndexError):

@ -216,11 +216,6 @@ def extract_version(file: typing.Optional[str]):
continue continue
return str(version) return str(version)
else:
return None
return None
def extract_binary(file: str) -> str: def extract_binary(file: str) -> str:
file_data = None file_data = None

@ -322,7 +322,7 @@ def main() -> None:
formatter_class=RawTextHelpFormatter, formatter_class=RawTextHelpFormatter,
) )
list_detectors_parser = subparsers.add_parser( _ = subparsers.add_parser(
LIST_DETECTORS_COMMAND, LIST_DETECTORS_COMMAND,
parents=[output_parser], parents=[output_parser],
help="Lists available detection modules", help="Lists available detection modules",

@ -48,7 +48,7 @@ class Node:
constraints = constraints if constraints else Constraints() constraints = constraints if constraints else Constraints()
self.contract_name = contract_name self.contract_name = contract_name
self.start_addr = start_addr self.start_addr = start_addr
self.states = [] # type: List[GlobalState] self.states: List[GlobalState] = []
self.constraints = constraints self.constraints = constraints
self.function_name = function_name self.function_name = function_name
self.flags = NodeFlags() self.flags = NodeFlags()

@ -135,7 +135,7 @@ class KeccakFunctionManager:
:param model: The z3 model to query for concrete values :param model: The z3 model to query for concrete values
:return: A dictionary with concrete hashes { <hash_input_size> : [<concrete_hash>, <concrete_hash>]} :return: A dictionary with concrete hashes { <hash_input_size> : [<concrete_hash>, <concrete_hash>]}
""" """
concrete_hashes = {} # type: Dict[int, List[Optional[int]]] concrete_hashes: Dict[int, List[Optional[int]]] = {}
for size in self.hash_result_store: for size in self.hash_result_store:
concrete_hashes[size] = [] concrete_hashes[size] = []
for val in self.hash_result_store[size]: for val in self.hash_result_store[size]:

@ -293,14 +293,14 @@ class Instruction:
if length_of_value == 0: if length_of_value == 0:
global_state.mstate.stack.append(symbol_factory.BitVecVal(0, 256)) global_state.mstate.stack.append(symbol_factory.BitVecVal(0, 256))
elif type(push_value) == tuple: elif isinstance(push_value, tuple):
if type(push_value[0]) == int: if isinstance(push_value[0], int):
new_value = symbol_factory.BitVecVal(push_value[0], 8) new_value = symbol_factory.BitVecVal(push_value[0], 8)
else: else:
new_value = push_value[0] new_value = push_value[0]
if len(push_value) > 1: if len(push_value) > 1:
for val in push_value[1:]: for val in push_value[1:]:
if type(val) == int: if isinstance(val, int):
new_value = Concat(new_value, symbol_factory.BitVecVal(val, 8)) new_value = Concat(new_value, symbol_factory.BitVecVal(val, 8))
else: else:
new_value = Concat(new_value, val) new_value = Concat(new_value, val)
@ -442,12 +442,12 @@ class Instruction:
index = util.get_concrete_int(op0) index = util.get_concrete_int(op0)
offset = (31 - index) * 8 offset = (31 - index) * 8
if offset >= 0: if offset >= 0:
result = simplify( result: Union[int, Expression] = simplify(
Concat( Concat(
symbol_factory.BitVecVal(0, 248), symbol_factory.BitVecVal(0, 248),
Extract(offset + 7, offset, op1), Extract(offset + 7, offset, op1),
) )
) # type: Union[int, Expression] )
else: else:
result = 0 result = 0
except TypeError: except TypeError:
@ -818,13 +818,13 @@ class Instruction:
return [global_state] return [global_state]
try: try:
dstart = util.get_concrete_int(dstart) # type: Union[int, BitVec] dstart: Union[int, BitVec] = util.get_concrete_int(dstart)
except TypeError: except TypeError:
log.debug("Unsupported symbolic calldata offset in CALLDATACOPY") log.debug("Unsupported symbolic calldata offset in CALLDATACOPY")
dstart = simplify(dstart) dstart = simplify(dstart)
try: try:
size = util.get_concrete_int(size) # type: Union[int, BitVec] size: Union[int, BitVec] = util.get_concrete_int(size)
except TypeError: except TypeError:
log.debug("Unsupported symbolic size in CALLDATACOPY") log.debug("Unsupported symbolic size in CALLDATACOPY")
size = SYMBOLIC_CALLDATA_SIZE # The excess size will get overwritten size = SYMBOLIC_CALLDATA_SIZE # The excess size will get overwritten
@ -1087,7 +1087,7 @@ class Instruction:
global_state.mstate.stack.pop(), global_state.mstate.stack.pop(),
) )
code = global_state.environment.code.bytecode code = global_state.environment.code.bytecode
if code[0:2] == "0x": if code.startswith("0x"):
code = code[2:] code = code[2:]
code_size = len(code) // 2 code_size = len(code) // 2
if isinstance(global_state.current_transaction, ContractCreationTransaction): if isinstance(global_state.current_transaction, ContractCreationTransaction):
@ -1227,7 +1227,7 @@ class Instruction:
) )
return [global_state] return [global_state]
if code[0:2] == "0x": if isinstance(code, str) and code.startswith("0x"):
code = code[2:] code = code[2:]
for i in range(concrete_size): for i in range(concrete_size):
@ -1486,9 +1486,7 @@ class Instruction:
state.mem_extend(offset, 1) state.mem_extend(offset, 1)
try: try:
value_to_write = ( value_to_write: Union[int, BitVec] = util.get_concrete_int(value) % 256
util.get_concrete_int(value) % 256
) # type: Union[int, BitVec]
except TypeError: # BitVec except TypeError: # BitVec
value_to_write = Extract(7, 0, value) value_to_write = Extract(7, 0, value)
@ -1591,10 +1589,10 @@ class Instruction:
condi = simplify(condition) if isinstance(condition, Bool) else condition != 0 condi = simplify(condition) if isinstance(condition, Bool) else condition != 0
condi.simplify() condi.simplify()
negated_cond = (type(negated) == bool and negated) or ( negated_cond = (isinstance(negated, bool) and negated) or (
isinstance(negated, Bool) and not is_false(negated) isinstance(negated, Bool) and not is_false(negated)
) )
positive_cond = (type(condi) == bool and condi) or ( positive_cond = (isinstance(condi, bool) and condi) or (
isinstance(condi, Bool) and not is_false(condi) isinstance(condi, Bool) and not is_false(condi)
) )
@ -1734,7 +1732,7 @@ class Instruction:
world_state = global_state.world_state world_state = global_state.world_state
call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size) call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size)
code_raw = [] code_raw: List[int] = []
code_end = call_data.size code_end = call_data.size
size = call_data.size size = call_data.size
@ -1776,7 +1774,7 @@ class Instruction:
gas_price = environment.gasprice gas_price = environment.gasprice
origin = environment.origin origin = environment.origin
contract_address = None # type: Union[BitVec, int] contract_address: Union[BitVec, int] = None
Instruction._sha3_gas_helper(global_state, len(code_str[2:]) // 2) Instruction._sha3_gas_helper(global_state, len(code_str[2:]) // 2)
if create2_salt: if create2_salt:

@ -204,11 +204,11 @@ class Account:
} }
def serialised_code(self): def serialised_code(self):
if type(self.code.bytecode) == str: if isinstance(self.code.bytecode, str):
return self.code.bytecode return self.code.bytecode
new_code = "0x" new_code = "0x"
for byte in self.code.bytecode: for byte in self.code.bytecode:
if type(byte) == int: if isinstance(byte, int):
new_code += hex(byte) new_code += hex(byte)
else: else:
new_code += "<call_data>" new_code += "<call_data>"

@ -278,14 +278,14 @@ class BasicSymbolicCalldata(BaseCalldata):
:param tx_id: Id of the transaction that the calldata is for. :param tx_id: Id of the transaction that the calldata is for.
""" """
self._reads = [] # type: List[Tuple[Union[int, BitVec], BitVec]] self._reads: List[Tuple[Union[int, BitVec], BitVec]] = []
self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256) self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256)
super().__init__(tx_id) super().__init__(tx_id)
def _load(self, item: Union[int, BitVec], clean=False) -> Any: def _load(self, item: Union[int, BitVec], clean=False) -> Any:
expr_item = ( expr_item: BitVec = (
symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
) # type: BitVec )
symbolic_base_value = If( symbolic_base_value = If(
expr_item >= self._size, expr_item >= self._size,

@ -4,7 +4,7 @@ from mythril.exceptions import UnsatError, SolverTimeOutException
from mythril.laser.smt import symbol_factory, simplify, Bool from mythril.laser.smt import symbol_factory, simplify, Bool
from mythril.support.model import get_model from mythril.support.model import get_model
from mythril.laser.ethereum.function_managers import keccak_function_manager from mythril.laser.ethereum.function_managers import keccak_function_manager
from mythril.laser.smt.model import Model
from copy import copy from copy import copy
from typing import Iterable, List, Optional, Union from typing import Iterable, List, Optional, Union
@ -42,7 +42,7 @@ class Constraints(list):
return False return False
return True return True
def get_model(self, solver_timeout=None) -> bool: def get_model(self, solver_timeout=None) -> Optional[Model]:
""" """
:param solver_timeout: The default timeout uses analysis timeout from args.solver_timeout :param solver_timeout: The default timeout uses analysis timeout from args.solver_timeout
:return: True/False based on the existence of solution of constraints :return: True/False based on the existence of solution of constraints

@ -31,7 +31,7 @@ class Memory:
def __init__(self): def __init__(self):
"""""" """"""
self._msize = 0 self._msize = 0
self._memory = {} # type: Dict[BitVec, Union[int, BitVec]] self._memory: Dict[BitVec, Union[int, BitVec]] = {}
def __len__(self): def __len__(self):
""" """
@ -179,7 +179,7 @@ class Memory:
step = 1 step = 1
else: else:
assert False, "Currently mentioning step size is not supported" assert False, "Currently mentioning step size is not supported"
assert type(value) == list assert isinstance(value, list)
bvstart, bvstop, bvstep = ( bvstart, bvstop, bvstep = (
convert_bv(start), convert_bv(start),
convert_bv(stop), convert_bv(stop),

@ -29,12 +29,12 @@ class WorldState:
:param transaction_sequence: :param transaction_sequence:
:param annotations: :param annotations:
""" """
self._accounts = {} # type: Dict[int, Account] self._accounts: Dict[int, Account] = {}
self.balances = Array("balance", 256, 256) self.balances = Array("balance", 256, 256)
self.starting_balances = deepcopy(self.balances) self.starting_balances = deepcopy(self.balances)
self.constraints = constraints or Constraints() self.constraints = constraints or Constraints()
self.node = None # type: Optional['Node'] self.node: Optional["Node"] = None
self.transaction_sequence = transaction_sequence or [] self.transaction_sequence = transaction_sequence or []
self._annotations = annotations or [] self._annotations = annotations or []

@ -9,8 +9,8 @@ class BasicSearchStrategy(ABC):
""" """
def __init__(self, work_list, max_depth, **kwargs): def __init__(self, work_list, max_depth, **kwargs):
self.work_list = work_list # type: List[GlobalState] self.work_list: List[GlobalState] = work_list
self.max_depth = max_depth self.max_depth: int = max_depth
def __iter__(self): def __iter__(self):
return self return self

@ -28,7 +28,6 @@ class DelayConstraintStrategy(BasicSearchStrategy):
:return: Global state :return: Global state
""" """
while True:
while len(self.work_list) == 0: while len(self.work_list) == 0:
state = self.pending_worklist.pop(0) state = self.pending_worklist.pop(0)
model = state.world_state.constraints.get_model() model = state.world_state.constraints.get_model()

@ -14,8 +14,8 @@ class JumpdestCountAnnotation(StateAnnotation):
"""State annotation that counts the number of jumps per destination.""" """State annotation that counts the number of jumps per destination."""
def __init__(self) -> None: def __init__(self) -> None:
self._reached_count = {} # type: Dict[int, int] self._reached_count: Dict[int, int] = {}
self.trace = [] # type: List[int] self.trace: List[int] = []
def __copy__(self): def __copy__(self):
result = JumpdestCountAnnotation() result = JumpdestCountAnnotation()

@ -238,7 +238,7 @@ class LaserEVM:
for hook in self._stop_exec_trans_hooks: for hook in self._stop_exec_trans_hooks:
hook() hook()
def _execute_transactions_ordered(self, address): def _execute_transactions_non_ordered(self, address):
""" """
This function executes multiple transactions non-incrementally, using some type priority ordering This function executes multiple transactions non-incrementally, using some type priority ordering
@ -327,7 +327,7 @@ class LaserEVM:
:param track_gas: :param track_gas:
:return: :return:
""" """
final_states = [] # type: List[GlobalState] final_states: List[GlobalState] = []
for hook in self._start_exec_hooks: for hook in self._start_exec_hooks:
hook() hook()
@ -387,7 +387,7 @@ class LaserEVM:
# exceptional halt all changes should be discarded, and this world state would not provide us with a # exceptional halt all changes should be discarded, and this world state would not provide us with a
# previously unseen world state # previously unseen world state
log.debug("Encountered a VmException, ending path: `{}`".format(error_msg)) log.debug("Encountered a VmException, ending path: `{}`".format(error_msg))
new_global_states = [] # type: List[GlobalState] new_global_states: List[GlobalState] = []
else: else:
# First execute the post hook for the transaction ending instruction # First execute the post hook for the transaction ending instruction
self._execute_post_hook(op_code, [global_state]) self._execute_post_hook(op_code, [global_state])

@ -105,7 +105,7 @@ class BaseTransaction:
self.caller = caller self.caller = caller
self.callee_account = callee_account self.callee_account = callee_account
if call_data is None and init_call_data: if call_data is None and init_call_data:
self.call_data = SymbolicCalldata(self.id) # type: BaseCalldata self.call_data: BaseCalldata = SymbolicCalldata(self.id)
else: else:
self.call_data = ( self.call_data = (
call_data call_data
@ -119,7 +119,7 @@ class BaseTransaction:
else symbol_factory.BitVecSym(f"callvalue{identifier}", 256) else symbol_factory.BitVecSym(f"callvalue{identifier}", 256)
) )
self.static = static self.static = static
self.return_data = None # type: str self.return_data: Optional[ReturnData] = None
def initial_global_state_from_environment(self, environment, active_function): def initial_global_state_from_environment(self, environment, active_function):
""" """
@ -278,7 +278,9 @@ class ContractCreationTransaction(BaseTransaction):
tuple(return_data.return_data) tuple(return_data.return_data)
) )
return_data = str(hex(global_state.environment.active_account.address.value)) return_data = str(hex(global_state.environment.active_account.address.value))
self.return_data = ReturnData(return_data, len(return_data) // 2) self.return_data: Optional[ReturnData] = ReturnData(
return_data, symbol_factory.BitVecVal(len(return_data) // 2, 256)
)
assert global_state.environment.active_account.code.instruction_list != [] assert global_state.environment.active_account.code.instruction_list != []
raise TransactionEndSignal(global_state, revert=revert) raise TransactionEndSignal(global_state, revert=revert)

@ -143,7 +143,7 @@ def concrete_int_to_bytes(val):
:return: :return:
""" """
# logging.debug("concrete_int_to_bytes " + str(val)) # logging.debug("concrete_int_to_bytes " + str(val))
if type(val) == int: if isinstance(val, int):
return val.to_bytes(32, byteorder="big") return val.to_bytes(32, byteorder="big")
return simplify(val).value.to_bytes(32, byteorder="big") return simplify(val).value.to_bytes(32, byteorder="big")

@ -17,9 +17,9 @@ class LaserPluginLoader(object, metaclass=Singleton):
def __init__(self) -> None: def __init__(self) -> None:
"""Initializes the plugin loader""" """Initializes the plugin loader"""
self.laser_plugin_builders = {} # type: Dict[str, PluginBuilder] self.laser_plugin_builders: Dict[str, PluginBuilder] = {}
self.plugin_args = {} # type: Dict[str, Dict] self.plugin_args: Dict[str, Dict] = {}
self.plugin_list = {} # type: Dict[str, LaserPlugin] self.plugin_list: Dict[str, LaserPlugin] = {}
def add_args(self, plugin_name, **kwargs): def add_args(self, plugin_name, **kwargs):
self.plugin_args[plugin_name] = kwargs self.plugin_args[plugin_name] = kwargs

@ -30,7 +30,7 @@ class InstructionCoveragePlugin(LaserPlugin):
""" """
def __init__(self): def __init__(self):
self.coverage = {} # type: Dict[str, Tuple[int, List[bool]]] self.coverage: Dict[str, Tuple[int, List[bool]]] = {}
self.initial_coverage = 0 self.initial_coverage = 0
self.tx_id = 0 self.tx_id = 0
@ -54,7 +54,7 @@ class InstructionCoveragePlugin(LaserPlugin):
else: else:
cov_percentage = sum(code_cov[1]) / float(code_cov[0]) * 100 cov_percentage = sum(code_cov[1]) / float(code_cov[0]) * 100
string_code = code string_code = code
if type(code) == tuple: if isinstance(code, tuple):
try: try:
string_code = bytearray(code).hex() string_code = bytearray(code).hex()
except TypeError: except TypeError:

@ -95,10 +95,10 @@ class DependencyPruner(LaserPlugin):
def _reset(self): def _reset(self):
self.iteration = 0 self.iteration = 0
self.calls_on_path = {} # type: Dict[int, bool] self.calls_on_path: Dict[int, bool] = {}
self.sloads_on_path = {} # type: Dict[int, List[object]] self.sloads_on_path: Dict[int, List[object]] = {}
self.sstores_on_path = {} # type: Dict[int, List[object]] self.sstores_on_path: Dict[int, List[object]] = {}
self.storage_accessed_global = set() # type: Set self.storage_accessed_global: Set = set()
def update_sloads(self, path: List[int], target_location: object) -> None: def update_sloads(self, path: List[int], target_location: object) -> None:
"""Update the dependency map for the block offsets on the given path. """Update the dependency map for the block offsets on the given path.

@ -31,11 +31,11 @@ class DependencyAnnotation(MergeableStateAnnotation):
""" """
def __init__(self): def __init__(self):
self.storage_loaded = set() # type: Set self.storage_loaded: Set = set()
self.storage_written = {} # type: Dict[int, Set] self.storage_written: Dict[int, Set] = {}
self.has_call = False # type: bool self.has_call: bool = False
self.path = [0] # type: List self.path: List = [0]
self.blocks_seen = set() # type: Set[int] self.blocks_seen: Set[int] = set()
def __copy__(self): def __copy__(self):
result = DependencyAnnotation() result = DependencyAnnotation()

@ -190,7 +190,7 @@ def Sum(*args: BitVec) -> BitVec:
:return: :return:
""" """
raw = z3.Sum([a.raw for a in args]) raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations annotations: Annotations = set()
for bv in args: for bv in args:
annotations = annotations.union(bv.annotations) annotations = annotations.union(bv.annotations)

@ -97,7 +97,7 @@ class Bool(Expression[z3.BoolRef]):
def And(*args: Union[Bool, bool]) -> Bool: def And(*args: Union[Bool, bool]) -> Bool:
"""Create an And expression.""" """Create an And expression."""
annotations = set() # type: Set annotations: Set = set()
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
for arg in args_list: for arg in args_list:
annotations = annotations.union(arg.annotations) annotations = annotations.union(arg.annotations)
@ -119,7 +119,7 @@ def Or(*args: Union[Bool, bool]) -> Bool:
:return: :return:
""" """
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args] args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
annotations = set() # type: Set annotations: Set = set()
for arg in args_list: for arg in args_list:
annotations = annotations.union(arg.annotations) annotations = annotations.union(arg.annotations)
return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations) return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations)

@ -19,7 +19,7 @@ class Model:
def decls(self) -> List[z3.ExprRef]: def decls(self) -> List[z3.ExprRef]:
"""Get the declarations for this model""" """Get the declarations for this model"""
result = [] # type: List[z3.ExprRef] result: List[z3.ExprRef] = []
for internal_model in self.raw: for internal_model in self.raw:
result.extend(internal_model.decls()) result.extend(internal_model.decls())
return result return result

@ -31,8 +31,8 @@ class DependenceBucket:
:param variables: Variables contained in the conditions :param variables: Variables contained in the conditions
:param conditions: The conditions that are dependent on each other :param conditions: The conditions that are dependent on each other
""" """
self.variables = variables or [] # type: List[z3.ExprRef] self.variables: List[z3.ExprRef] = variables or []
self.conditions = conditions or [] # type: List[z3.ExprRef] self.conditions: List[z3.ExprRef] = conditions or []
class DependenceMap: class DependenceMap:
@ -40,8 +40,8 @@ class DependenceMap:
def __init__(self): def __init__(self):
"""Initializes a DependenceMap object""" """Initializes a DependenceMap object"""
self.buckets = [] # type: List[DependenceBucket] self.buckets: List[DependenceBucket] = []
self.variable_map = {} # type: Dict[str, DependenceBucket] self.variable_map: Dict[str, DependenceBucket] = {}
def add_condition(self, condition: z3.BoolRef) -> None: def add_condition(self, condition: z3.BoolRef) -> None:
""" """
@ -70,8 +70,8 @@ class DependenceMap:
def _merge_buckets(self, bucket_list: Set[DependenceBucket]) -> DependenceBucket: def _merge_buckets(self, bucket_list: Set[DependenceBucket]) -> DependenceBucket:
"""Merges the buckets in bucket list""" """Merges the buckets in bucket list"""
variables = [] # type: List[str] variables: List[str] = []
conditions = [] # type: List[z3.BoolRef] conditions: List[z3.BoolRef] = []
for bucket in bucket_list: for bucket in bucket_list:
self.buckets.remove(bucket) self.buckets.remove(bucket)
variables += bucket.variables variables += bucket.variables
@ -100,14 +100,14 @@ class IndependenceSolver:
""" """
self.raw.set(timeout=timeout) self.raw.set(timeout=timeout)
def add(self, *constraints: Tuple[Bool]) -> None: def add(self, *constraints: Bool) -> None:
"""Adds the constraints to this solver. """Adds the constraints to this solver.
:param constraints: constraints to add :param constraints: constraints to add
""" """
raw_constraints = [ raw_constraints: List[z3.BoolRef] = [
c.raw for c in cast(Tuple[Bool], constraints) c.raw for c in cast(Tuple[Bool], constraints)
] # type: List[z3.BoolRef] ]
self.constraints.extend(raw_constraints) self.constraints.extend(raw_constraints)
def append(self, *constraints: Tuple[Bool]) -> None: def append(self, *constraints: Tuple[Bool]) -> None:
@ -115,9 +115,9 @@ class IndependenceSolver:
:param constraints: constraints to add :param constraints: constraints to add
""" """
raw_constraints = [ raw_constraints: List[z3.BoolRef] = [
c.raw for c in cast(Tuple[Bool], constraints) c.raw for c in cast(Tuple[Bool], constraints)
] # type: List[z3.BoolRef] ]
self.constraints.extend(raw_constraints) self.constraints.extend(raw_constraints)
@stat_smt_query @stat_smt_query

@ -28,18 +28,18 @@ class BaseSolver(Generic[T]):
""" """
self.raw.set(timeout=timeout) self.raw.set(timeout=timeout)
def add(self, *constraints: List[Bool]) -> None: def add(self, *constraints: Bool) -> None:
"""Adds the constraints to this solver. """Adds the constraints to this solver.
:param constraints: :param constraints:
:return: :return:
""" """
z3_constraints = [ z3_constraints: Sequence[z3.BoolRef] = [
c.raw for c in cast(List[Bool], constraints) c.raw for c in cast(List[Bool], constraints)
] # type: Sequence[z3.BoolRef] ]
self.raw.add(z3_constraints) self.raw.add(z3_constraints)
def append(self, *constraints: List[Bool]) -> None: def append(self, *constraints: Bool) -> None:
"""Adds the constraints to this solver. """Adds the constraints to this solver.
:param constraints: :param constraints:

@ -47,7 +47,7 @@ class MythrilAnalyzer:
:param address: Address of the contract :param address: Address of the contract
""" """
self.eth = disassembler.eth self.eth = disassembler.eth
self.contracts = disassembler.contracts or [] # type: List[EVMContract] self.contracts: List[EVMContract] = disassembler.contracts or []
self.enable_online_lookup = disassembler.enable_online_lookup self.enable_online_lookup = disassembler.enable_online_lookup
self.use_onchain_data = not cmd_args.no_onchain_data self.use_onchain_data = not cmd_args.no_onchain_data
self.strategy = strategy self.strategy = strategy
@ -141,10 +141,10 @@ class MythrilAnalyzer:
:param transaction_count: The amount of transactions to be executed :param transaction_count: The amount of transactions to be executed
:return: The Report class which contains the all the issues/vulnerabilities :return: The Report class which contains the all the issues/vulnerabilities
""" """
all_issues = [] # type: List[Issue] all_issues: List[Issue] = []
SolverStatistics().enabled = True SolverStatistics().enabled = True
exceptions = [] exceptions = []
execution_info = None # type: Optional[List[ExecutionInfo]] execution_info: Optional[List[ExecutionInfo]] = None
for contract in self.contracts: for contract in self.contracts:
StartTime() # Reinitialize start time for new contracts StartTime() # Reinitialize start time for new contracts
try: try:

@ -22,11 +22,11 @@ class MythrilConfig:
""" """
def __init__(self): def __init__(self):
self.infura_id = os.getenv("INFURA_ID") # type: str self.infura_id: str = os.getenv("INFURA_ID")
self.mythril_dir = self.init_mythril_dir() self.mythril_dir = self.init_mythril_dir()
self.config_path = os.path.join(self.mythril_dir, "config.ini") self.config_path = os.path.join(self.mythril_dir, "config.ini")
self._init_config() self._init_config()
self.eth = None # type: Optional[EthJsonRpc] self.eth: Optional[EthJsonRpc] = None
def set_api_infura_id(self, id): def set_api_infura_id(self, id):
self.infura_id = id self.infura_id = id

@ -30,11 +30,11 @@ from mythril.solidity.soliditycontract import (
from mythril.support.support_args import args from mythril.support.support_args import args
def format_Warning(message, category, filename, lineno, line=""): def format_warning(message, category, filename, lineno, line=""):
return "{}: {}\n\n".format(str(filename), str(message)) return "{}: {}\n\n".format(str(filename), str(message))
warnings.formatwarning = format_Warning warnings.formatwarning = format_warning
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -62,7 +62,7 @@ class MythrilDisassembler:
self.eth = eth self.eth = eth
self.enable_online_lookup = enable_online_lookup self.enable_online_lookup = enable_online_lookup
self.sigs = signatures.SignatureDB(enable_online_lookup=enable_online_lookup) self.sigs = signatures.SignatureDB(enable_online_lookup=enable_online_lookup)
self.contracts = [] # type: List[EVMContract] self.contracts: List[EVMContract] = []
@staticmethod @staticmethod
def _init_solc_binary(version: str) -> Optional[str]: def _init_solc_binary(version: str) -> Optional[str]:
@ -190,7 +190,6 @@ class MythrilDisassembler:
build_dir = os.path.join(project_root, "artifacts", "contracts", "build-info") build_dir = os.path.join(project_root, "artifacts", "contracts", "build-info")
files = os.listdir(build_dir)
address = util.get_indexed_address(0) address = util.get_indexed_address(0)
files = sorted( files = sorted(

@ -12,7 +12,7 @@ class PluginDiscovery(object, metaclass=Singleton):
""" """
# Installed plugins structure. Retrieves all modules that have an entry point for mythril.plugins # Installed plugins structure. Retrieves all modules that have an entry point for mythril.plugins
_installed_plugins = None # type: Optional[Dict[str, Any]] _installed_plugins: Optional[Dict[str, Any]] = None
def init_installed_plugins(self): def init_installed_plugins(self):
self._installed_plugins = { self._installed_plugins = {

@ -27,7 +27,7 @@ class MythrilPluginLoader(object, metaclass=Singleton):
def __init__(self): def __init__(self):
log.info("Initializing mythril plugin loader") log.info("Initializing mythril plugin loader")
self.loaded_plugins = [] self.loaded_plugins = []
self.plugin_args = dict() # type: Dict[str, Dict] self.plugin_args: Dict[str, Dict] = dict()
self._load_default_enabled() self._load_default_enabled()
def set_args(self, plugin_name: str, **kwargs): def set_args(self, plugin_name: str, **kwargs):

@ -203,7 +203,7 @@ class SolidityFeatureExtractor:
return variables return variables
def extract_address_variable(self, node): def extract_address_variable(self, node):
if type(node) == int: if isinstance(node, int):
return set([]) return set([])
transfer_vars = set([]) transfer_vars = set([])
if ( if (
@ -211,12 +211,10 @@ class SolidityFeatureExtractor:
and node.get("expression", {}).get("nodeType") == "FunctionCall" and node.get("expression", {}).get("nodeType") == "FunctionCall"
): ):
expression = node["expression"].get("expression", None) expression = node["expression"].get("expression", None)
if expression is not None: if expression is not None and (
if (
expression["nodeType"] == "MemberAccess" expression["nodeType"] == "MemberAccess"
and expression["memberName"] in TRANSFER_METHODS and expression["memberName"] in TRANSFER_METHODS
): ):
print(expression)
address_variable = expression["expression"].get("name") address_variable = expression["expression"].get("name")
if address_variable: if address_variable:
transfer_vars.update(set([address_variable])) transfer_vars.update(set([address_variable]))

@ -40,7 +40,7 @@ class DynLoader:
value = self.eth.eth_getStorageAt( value = self.eth.eth_getStorageAt(
contract_address, position=index, block="latest" contract_address, position=index, block="latest"
) )
if value == "0x": if value.startswith("0x"):
value = "0x0000000000000000000000000000000000000000000000000000000000000000" value = "0x0000000000000000000000000000000000000000000000000000000000000000"
return value return value
@ -96,7 +96,7 @@ class DynLoader:
code = self.eth.eth_getCode(dependency_address) code = self.eth.eth_getCode(dependency_address)
if code == "0x": if code.startswith("0x"):
return None return None
else: else:
return Disassembly(code) return Disassembly(code)

@ -86,12 +86,16 @@ def get_model(
if solver_timeout <= 0: if solver_timeout <= 0:
raise SolverTimeOutException raise SolverTimeOutException
for constraint in constraints: for constraint in constraints:
if type(constraint) == bool and not constraint: if isinstance(constraint, bool) and not constraint:
raise UnsatError raise UnsatError
if type(constraints) != tuple: if isinstance(constraints, tuple) is False:
constraints = constraints.get_all_constraints() constraints = constraints.get_all_constraints()
constraints = [constraint for constraint in constraints if type(constraint) != bool] constraints = [
constraint
for constraint in constraints
if isinstance(constraint, bool) is False
]
if len(maximize) + len(minimize) == 0: if len(maximize) + len(minimize) == 0:
ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw) ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw)

@ -44,7 +44,7 @@ def synchronized(sync_lock):
class Singleton(type): class Singleton(type):
"""A metaclass type implementing the singleton pattern.""" """A metaclass type implementing the singleton pattern."""
_instances = dict() # type: Dict[Singleton, Singleton] _instances: Dict["Singleton", "Singleton"] = dict()
@synchronized(lock) @synchronized(lock)
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
@ -123,12 +123,12 @@ class SignatureDB(object, metaclass=Singleton):
:param path: :param path:
""" """
self.enable_online_lookup = enable_online_lookup self.enable_online_lookup = enable_online_lookup
self.online_lookup_miss = set() # type: Set[str] self.online_lookup_miss: Set[str] = set()
self.online_lookup_timeout = 0 self.online_lookup_timeout = 0
# if we're analysing a Solidity file, store its hashes # if we're analysing a Solidity file, store its hashes
# here to prevent unnecessary lookups # here to prevent unnecessary lookups
self.solidity_sigs = defaultdict(list) # type: DefaultDict[str, List[str]] self.solidity_sigs: DefaultDict[str, List[str]] = defaultdict(list)
if path is None: if path is None:
self.path = os.environ.get("MYTHRIL_DIR") or os.path.join( self.path = os.environ.get("MYTHRIL_DIR") or os.path.join(
os.path.expanduser("~"), ".mythril" os.path.expanduser("~"), ".mythril"

@ -38,7 +38,7 @@ class Source:
self.source_format = "evm-byzantium-bytecode" self.source_format = "evm-byzantium-bytecode"
self.source_type = ( self.source_type = (
"ethereum-address" "ethereum-address"
if len(contracts[0].name) == 42 and contracts[0].name[0:2] == "0x" if len(contracts[0].name) == 42 and contracts[0].name.startswith("0x")
else "raw-bytecode" else "raw-bytecode"
) )
for contract in contracts: for contract in contracts:

@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
class Singleton(type): class Singleton(type):
"""A metaclass type implementing the singleton pattern.""" """A metaclass type implementing the singleton pattern."""
_instances = {} # type: Dict _instances: Dict = {}
def __call__(cls, *args, **kwargs): def __call__(cls, *args, **kwargs):
"""Delegate the call to an existing resource or a a new one. """Delegate the call to an existing resource or a a new one.
@ -59,7 +59,6 @@ class ModelCache:
@lru_cache(maxsize=2**10) @lru_cache(maxsize=2**10)
def check_quick_sat(self, constraints) -> bool: def check_quick_sat(self, constraints) -> bool:
model_list = list(reversed(self.model_cache.lru_cache.keys()))
for model in reversed(self.model_cache.lru_cache.keys()): for model in reversed(self.model_cache.lru_cache.keys()):
model_copy = deepcopy(model) model_copy = deepcopy(model)
if is_true(model_copy.eval(constraints, model_completion=True)): if is_true(model_copy.eval(constraints, model_completion=True)):
@ -77,11 +76,11 @@ def get_code_hash(code) -> str:
:param code: bytecode :param code: bytecode
:return: Returns hash of the given bytecode :return: Returns hash of the given bytecode
""" """
if type(code) == tuple: if isinstance(code, tuple):
# Temporary hack, since we cannot find symbols of sha3 # Temporary hack, since we cannot find symbols of sha3
return str(hash(code)) return str(hash(code))
code = code[2:] if code[:2] == "0x" else code code = code[2:] if code.startswith("0x") else code
try: try:
hash_ = keccak(bytes.fromhex(code)) hash_ = keccak(bytes.fromhex(code))
return "0x" + hash_.hex() return "0x" + hash_.hex()
@ -91,8 +90,8 @@ def get_code_hash(code) -> str:
def sha3(value): def sha3(value):
if type(value) == str: if isinstance(value, str):
if value[:2] == "0x": if value.startswith("0x"):
new_hash = keccak(bytes.fromhex(value)) new_hash = keccak(bytes.fromhex(value))
else: else:
new_hash = keccak(value.encode()) new_hash = keccak(value.encode())

@ -61,7 +61,6 @@ test_cases = [
def test_feature_selfdestruct(file_name, num_funcs, func_name, field, expected_value): def test_feature_selfdestruct(file_name, num_funcs, func_name, field, expected_value):
input_file = TEST_FILES / file_name input_file = TEST_FILES / file_name
name = file_name.split(".")[0] name = file_name.split(".")[0]
print(name, name.capitalize())
if name[0].islower(): if name[0].islower():
name = name.capitalize() name = name.capitalize()
contract = SolidityContract(str(input_file), name=name, solc_binary=solc_binary) contract = SolidityContract(str(input_file), name=name, solc_binary=solc_binary)

@ -182,8 +182,8 @@ def test_vmtest(
actual = actual.value actual = actual.value
actual = 1 if actual is True else 0 if actual is False else actual actual = 1 if actual is True else 0 if actual is False else actual
else: else:
if type(actual) == bytes: if isinstance(actual, bytes):
actual = int(binascii.b2a_hex(actual), 16) actual = int(binascii.b2a_hex(actual), 16)
elif type(actual) == str: elif isinstance(actual, str):
actual = int(actual, 16) actual = int(actual, 16)
assert actual == expected assert actual == expected

@ -5,7 +5,6 @@ from unittest.mock import Mock, patch, mock_open
def mock_predict_proba(X): def mock_predict_proba(X):
print(X)
if X[0][-1] == 1: if X[0][-1] == 1:
return np.array([[0.1, 0.7, 0.1, 0.1]]) return np.array([[0.1, 0.7, 0.1, 0.1]])
elif X[0][-1] == 2: elif X[0][-1] == 2:

Loading…
Cancel
Save