* Misc fixes

* Fix singleton
pull/1800/head
Nikhil Parasaram 1 year ago committed by GitHub
parent fd48221682
commit bd27897533
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      mythril/analysis/module/modules/external_calls.py
  2. 6
      mythril/analysis/module/modules/integer.py
  3. 2
      mythril/analysis/module/modules/multiple_sends.py
  4. 4
      mythril/analysis/module/modules/state_change_external_calls.py
  5. 2
      mythril/analysis/module/util.py
  6. 6
      mythril/analysis/report.py
  7. 4
      mythril/analysis/security.py
  8. 6
      mythril/analysis/solver.py
  9. 4
      mythril/analysis/symbolic.py
  10. 2
      mythril/concolic/find_trace.py
  11. 4
      mythril/disassembler/asm.py
  12. 12
      mythril/disassembler/disassembly.py
  13. 5
      mythril/ethereum/util.py
  14. 2
      mythril/interfaces/cli.py
  15. 2
      mythril/laser/ethereum/cfg.py
  16. 2
      mythril/laser/ethereum/function_managers/keccak_function_manager.py
  17. 30
      mythril/laser/ethereum/instructions.py
  18. 4
      mythril/laser/ethereum/state/account.py
  19. 6
      mythril/laser/ethereum/state/calldata.py
  20. 4
      mythril/laser/ethereum/state/constraints.py
  21. 4
      mythril/laser/ethereum/state/memory.py
  22. 4
      mythril/laser/ethereum/state/world_state.py
  23. 4
      mythril/laser/ethereum/strategy/__init__.py
  24. 17
      mythril/laser/ethereum/strategy/constraint_strategy.py
  25. 4
      mythril/laser/ethereum/strategy/extensions/bounded_loops.py
  26. 6
      mythril/laser/ethereum/svm.py
  27. 8
      mythril/laser/ethereum/transaction/transaction_models.py
  28. 2
      mythril/laser/ethereum/util.py
  29. 6
      mythril/laser/plugin/loader.py
  30. 4
      mythril/laser/plugin/plugins/coverage/coverage_plugin.py
  31. 8
      mythril/laser/plugin/plugins/dependency_pruner.py
  32. 10
      mythril/laser/plugin/plugins/plugin_annotations.py
  33. 2
      mythril/laser/smt/bitvec_helper.py
  34. 4
      mythril/laser/smt/bool.py
  35. 2
      mythril/laser/smt/model.py
  36. 22
      mythril/laser/smt/solver/independence_solver.py
  37. 8
      mythril/laser/smt/solver/solver.py
  38. 6
      mythril/mythril/mythril_analyzer.py
  39. 4
      mythril/mythril/mythril_config.py
  40. 7
      mythril/mythril/mythril_disassembler.py
  41. 2
      mythril/plugin/discovery.py
  42. 2
      mythril/plugin/loader.py
  43. 18
      mythril/solidity/features.py
  44. 4
      mythril/support/loader.py
  45. 10
      mythril/support/model.py
  46. 6
      mythril/support/signatures.py
  47. 2
      mythril/support/source_support.py
  48. 11
      mythril/support/support_utils.py
  49. 1
      tests/features_test.py
  50. 4
      tests/laser/evm_testsuite/evm_test.py
  51. 1
      tests/laser/tx_prioritisation_test.py

@ -27,8 +27,8 @@ Search for external calls with unrestricted gas to a user-specified address.
def _is_precompile_call(global_state: GlobalState):
to = global_state.mstate.stack[-2] # type: BitVec
constraints = copy(global_state.world_state.constraints)
to: BitVec = global_state.mstate.stack[-2]
constraints: Constraints = copy(global_state.world_state.constraints)
constraints += [
Or(
to < symbol_factory.BitVecVal(1, 256),

@ -50,7 +50,7 @@ class OverUnderflowStateAnnotation(StateAnnotation):
"""State Annotation used if an overflow is both possible and used in the annotated path"""
def __init__(self) -> None:
self.overflowing_state_annotations = set() # type: Set[OverUnderflowAnnotation]
self.overflowing_state_annotations: Set[OverUnderflowAnnotation] = set()
def __copy__(self):
new_annotation = OverUnderflowStateAnnotation()
@ -91,8 +91,8 @@ class IntegerArithmetics(DetectionModule):
"""
super().__init__()
self._ostates_satisfiable = set() # type: Set[GlobalState]
self._ostates_unsatisfiable = set() # type: Set[GlobalState]
self._ostates_satisfiable: Set[GlobalState] = set()
self._ostates_unsatisfiable: Set[GlobalState] = set()
def reset_module(self):
"""

@ -17,7 +17,7 @@ log = logging.getLogger(__name__)
class MultipleSendsAnnotation(StateAnnotation):
def __init__(self) -> None:
self.call_offsets = [] # type: List[int]
self.call_offsets: List[int] = []
def __copy__(self):
result = MultipleSendsAnnotation()

@ -29,7 +29,7 @@ STATE_READ_WRITE_LIST = ["SSTORE", "SLOAD", "CREATE", "CREATE2"]
class StateChangeCallsAnnotation(StateAnnotation):
def __init__(self, call_state: GlobalState, user_defined_address: bool) -> None:
self.call_state = call_state
self.state_change_states = [] # type: List[GlobalState]
self.state_change_states: List[GlobalState] = []
self.user_defined_address = user_defined_address
def __copy__(self):
@ -165,7 +165,7 @@ class StateChangeAfterCall(DetectionModule):
# Record state changes following from a transfer of ether
if op_code in CALL_LIST:
value = global_state.mstate.stack[-3] # type: BitVec
value: BitVec = global_state.mstate.stack[-3]
if StateChangeAfterCall._balance_change(value, global_state):
for annotation in annotations:
annotation.state_change_states.append(global_state)

@ -19,7 +19,7 @@ def get_detection_module_hooks(
:param hook_type: The type of hooks to retrieve (default: "pre")
:return: Dictionary with discovered hooks
"""
hook_dict = defaultdict(list) # type: Mapping[str, List[Callable]]
hook_dict: Mapping[str, List[Callable]] = defaultdict(list)
for module in modules:
hooks = module.pre_hooks if hook_type == "pre" else module.post_hooks

@ -254,9 +254,9 @@ class Report:
:param contracts:
:param exceptions:
"""
self.issues = {} # type: Dict[bytes, Issue]
self.issues: Dict[bytes, Issue] = {}
self.solc_version = ""
self.meta = {} # type: Dict[str, Any]
self.meta: Dict[str, Any] = {}
self.source = Source()
self.source.get_source_from_contracts_list(contracts)
self.exceptions = exceptions or []
@ -306,7 +306,7 @@ class Report:
def _get_exception_data(self) -> dict:
if not self.exceptions:
return {}
logs = [] # type: List[Dict]
logs: List[Dict] = []
for exception in self.exceptions:
logs += [{"level": "error", "hidden": True, "msg": exception}]
return {"logs": logs}

@ -13,7 +13,7 @@ log = logging.getLogger(__name__)
def retrieve_callback_issues(white_list: Optional[List[str]] = None) -> List[Issue]:
"""Get the issues discovered by callback type detection modules"""
issues = [] # type: List[Issue]
issues: List[Issue] = []
for module in ModuleLoader().get_detection_modules(
entry_point=EntryPoint.CALLBACK, white_list=white_list
):
@ -34,7 +34,7 @@ def fire_lasers(statespace, white_list: Optional[List[str]] = None) -> List[Issu
"""
log.info("Starting analysis")
issues = [] # type: List[Issue]
issues: List[Issue] = []
for module in ModuleLoader().get_detection_modules(
entry_point=EntryPoint.POST, white_list=white_list
):

@ -36,7 +36,7 @@ def pretty_print_model(model):
ret = ""
for d in model.decls():
if type(model[d]) == FuncInterp:
if isinstance(model[d], FuncInterp):
condition = model[d].as_list()
ret += "%s: %s\n" % (d.name(), condition)
continue
@ -119,7 +119,7 @@ def _add_calldata_placeholder(
if not isinstance(transaction_sequence[0], ContractCreationTransaction):
return
if type(transaction_sequence[0].code.bytecode) == tuple:
if isinstance(transaction_sequence[0].code.bytecode, tuple):
code_len = len(transaction_sequence[0].code.bytecode) * 2
else:
code_len = len(transaction_sequence[0].code.bytecode)
@ -206,7 +206,7 @@ def _get_concrete_transaction(model: z3.Model, transaction: BaseTransaction):
)
# Create concrete transaction dict
concrete_transaction = dict() # type: Dict[str, str]
concrete_transaction: Dict[str, str] = dict()
concrete_transaction["input"] = "0x" + input_
concrete_transaction["value"] = "0x%x" % value
# Fixme: base origin assignment on origin symbol

@ -85,7 +85,7 @@ class SymExecWrapper:
address = symbol_factory.BitVecVal(address, 256)
beam_width = None
if strategy == "dfs":
s_strategy = DepthFirstSearchStrategy # type: Type[BasicSearchStrategy]
s_strategy: Type[BasicSearchStrategy] = DepthFirstSearchStrategy
elif strategy == "bfs":
s_strategy = BreadthFirstSearchStrategy
elif strategy == "naive-random":
@ -240,7 +240,7 @@ class SymExecWrapper:
# Parse calls to make them easily accessible
self.calls = [] # type: List[Call]
self.calls: List[Call] = []
for key in self.nodes:

@ -30,7 +30,7 @@ def setup_concrete_initial_state(concrete_data: ConcreteData) -> WorldState:
account = Account(address, concrete_storage=True)
account.code = Disassembly(details["code"][2:])
account.nonce = details["nonce"]
if type(details["storage"]) == str:
if isinstance(type(details["storage"]), str):
details["storage"] = eval(details["storage"]) # type: ignore
for key, value in details["storage"].items():
key_bitvec = symbol_factory.BitVecVal(int(key, 16), 256)

@ -106,7 +106,7 @@ def disassemble(bytecode) -> list:
address = 0
length = len(bytecode)
if type(bytecode) == str:
if isinstance(bytecode, str):
bytecode = util.safe_decode(bytecode)
length = len(bytecode)
part_code = bytecode[-43:]
@ -135,7 +135,7 @@ def disassemble(bytecode) -> list:
match = re.search(regex_PUSH, op_code)
if match:
argument_bytes = bytecode[address + 1 : address + 1 + int(match.group(1))]
if type(argument_bytes) == bytes:
if isinstance(argument_bytes, bytes):
current_instruction.argument = "0x" + argument_bytes.hex()
else:
current_instruction.argument = argument_bytes

@ -23,13 +23,13 @@ class Disassembly(object):
:param enable_online_lookup:
"""
self.bytecode = code
if type(code) == str:
if isinstance(code, str):
self.instruction_list = asm.disassemble(util.safe_decode(code))
else:
self.instruction_list = asm.disassemble(code)
self.func_hashes = [] # type: List[str]
self.function_name_to_address = {} # type: Dict[str, int]
self.address_to_function_name = {} # type: Dict[int, str]
self.func_hashes: List[str] = []
self.function_name_to_address: Dict[str, int] = {}
self.address_to_function_name: Dict[int, str] = {}
self.enable_online_lookup = enable_online_lookup
self.assign_bytecode(bytecode=code)
@ -84,7 +84,7 @@ def get_function_info(
"""
# Append with missing 0s at the beginning
if type(instruction_list[index]["argument"]) == tuple:
if isinstance(instruction_list[index]["argument"], tuple):
try:
function_hash = "0x" + bytes(
instruction_list[index]["argument"]
@ -105,7 +105,7 @@ def get_function_info(
try:
offset = instruction_list[index + 2]["argument"]
if type(offset) == tuple:
if isinstance(offset, tuple):
offset = bytes(offset).hex()
entry_point = int(offset, 16)
except (KeyError, IndexError):

@ -216,11 +216,6 @@ def extract_version(file: typing.Optional[str]):
continue
return str(version)
else:
return None
return None
def extract_binary(file: str) -> str:
file_data = None

@ -322,7 +322,7 @@ def main() -> None:
formatter_class=RawTextHelpFormatter,
)
list_detectors_parser = subparsers.add_parser(
_ = subparsers.add_parser(
LIST_DETECTORS_COMMAND,
parents=[output_parser],
help="Lists available detection modules",

@ -48,7 +48,7 @@ class Node:
constraints = constraints if constraints else Constraints()
self.contract_name = contract_name
self.start_addr = start_addr
self.states = [] # type: List[GlobalState]
self.states: List[GlobalState] = []
self.constraints = constraints
self.function_name = function_name
self.flags = NodeFlags()

@ -135,7 +135,7 @@ class KeccakFunctionManager:
:param model: The z3 model to query for concrete values
:return: A dictionary with concrete hashes { <hash_input_size> : [<concrete_hash>, <concrete_hash>]}
"""
concrete_hashes = {} # type: Dict[int, List[Optional[int]]]
concrete_hashes: Dict[int, List[Optional[int]]] = {}
for size in self.hash_result_store:
concrete_hashes[size] = []
for val in self.hash_result_store[size]:

@ -293,14 +293,14 @@ class Instruction:
if length_of_value == 0:
global_state.mstate.stack.append(symbol_factory.BitVecVal(0, 256))
elif type(push_value) == tuple:
if type(push_value[0]) == int:
elif isinstance(push_value, tuple):
if isinstance(push_value[0], int):
new_value = symbol_factory.BitVecVal(push_value[0], 8)
else:
new_value = push_value[0]
if len(push_value) > 1:
for val in push_value[1:]:
if type(val) == int:
if isinstance(val, int):
new_value = Concat(new_value, symbol_factory.BitVecVal(val, 8))
else:
new_value = Concat(new_value, val)
@ -442,12 +442,12 @@ class Instruction:
index = util.get_concrete_int(op0)
offset = (31 - index) * 8
if offset >= 0:
result = simplify(
result: Union[int, Expression] = simplify(
Concat(
symbol_factory.BitVecVal(0, 248),
Extract(offset + 7, offset, op1),
)
) # type: Union[int, Expression]
)
else:
result = 0
except TypeError:
@ -818,13 +818,13 @@ class Instruction:
return [global_state]
try:
dstart = util.get_concrete_int(dstart) # type: Union[int, BitVec]
dstart: Union[int, BitVec] = util.get_concrete_int(dstart)
except TypeError:
log.debug("Unsupported symbolic calldata offset in CALLDATACOPY")
dstart = simplify(dstart)
try:
size = util.get_concrete_int(size) # type: Union[int, BitVec]
size: Union[int, BitVec] = util.get_concrete_int(size)
except TypeError:
log.debug("Unsupported symbolic size in CALLDATACOPY")
size = SYMBOLIC_CALLDATA_SIZE # The excess size will get overwritten
@ -1087,7 +1087,7 @@ class Instruction:
global_state.mstate.stack.pop(),
)
code = global_state.environment.code.bytecode
if code[0:2] == "0x":
if code.startswith("0x"):
code = code[2:]
code_size = len(code) // 2
if isinstance(global_state.current_transaction, ContractCreationTransaction):
@ -1227,7 +1227,7 @@ class Instruction:
)
return [global_state]
if code[0:2] == "0x":
if isinstance(code, str) and code.startswith("0x"):
code = code[2:]
for i in range(concrete_size):
@ -1486,9 +1486,7 @@ class Instruction:
state.mem_extend(offset, 1)
try:
value_to_write = (
util.get_concrete_int(value) % 256
) # type: Union[int, BitVec]
value_to_write: Union[int, BitVec] = util.get_concrete_int(value) % 256
except TypeError: # BitVec
value_to_write = Extract(7, 0, value)
@ -1591,10 +1589,10 @@ class Instruction:
condi = simplify(condition) if isinstance(condition, Bool) else condition != 0
condi.simplify()
negated_cond = (type(negated) == bool and negated) or (
negated_cond = (isinstance(negated, bool) and negated) or (
isinstance(negated, Bool) and not is_false(negated)
)
positive_cond = (type(condi) == bool and condi) or (
positive_cond = (isinstance(condi, bool) and condi) or (
isinstance(condi, Bool) and not is_false(condi)
)
@ -1734,7 +1732,7 @@ class Instruction:
world_state = global_state.world_state
call_data = get_call_data(global_state, mem_offset, mem_offset + mem_size)
code_raw = []
code_raw: List[int] = []
code_end = call_data.size
size = call_data.size
@ -1776,7 +1774,7 @@ class Instruction:
gas_price = environment.gasprice
origin = environment.origin
contract_address = None # type: Union[BitVec, int]
contract_address: Union[BitVec, int] = None
Instruction._sha3_gas_helper(global_state, len(code_str[2:]) // 2)
if create2_salt:

@ -204,11 +204,11 @@ class Account:
}
def serialised_code(self):
if type(self.code.bytecode) == str:
if isinstance(self.code.bytecode, str):
return self.code.bytecode
new_code = "0x"
for byte in self.code.bytecode:
if type(byte) == int:
if isinstance(byte, int):
new_code += hex(byte)
else:
new_code += "<call_data>"

@ -278,14 +278,14 @@ class BasicSymbolicCalldata(BaseCalldata):
:param tx_id: Id of the transaction that the calldata is for.
"""
self._reads = [] # type: List[Tuple[Union[int, BitVec], BitVec]]
self._reads: List[Tuple[Union[int, BitVec], BitVec]] = []
self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256)
super().__init__(tx_id)
def _load(self, item: Union[int, BitVec], clean=False) -> Any:
expr_item = (
expr_item: BitVec = (
symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
) # type: BitVec
)
symbolic_base_value = If(
expr_item >= self._size,

@ -4,7 +4,7 @@ from mythril.exceptions import UnsatError, SolverTimeOutException
from mythril.laser.smt import symbol_factory, simplify, Bool
from mythril.support.model import get_model
from mythril.laser.ethereum.function_managers import keccak_function_manager
from mythril.laser.smt.model import Model
from copy import copy
from typing import Iterable, List, Optional, Union
@ -42,7 +42,7 @@ class Constraints(list):
return False
return True
def get_model(self, solver_timeout=None) -> bool:
def get_model(self, solver_timeout=None) -> Optional[Model]:
"""
:param solver_timeout: The default timeout uses analysis timeout from args.solver_timeout
:return: True/False based on the existence of solution of constraints

@ -31,7 +31,7 @@ class Memory:
def __init__(self):
""""""
self._msize = 0
self._memory = {} # type: Dict[BitVec, Union[int, BitVec]]
self._memory: Dict[BitVec, Union[int, BitVec]] = {}
def __len__(self):
"""
@ -179,7 +179,7 @@ class Memory:
step = 1
else:
assert False, "Currently mentioning step size is not supported"
assert type(value) == list
assert isinstance(value, list)
bvstart, bvstop, bvstep = (
convert_bv(start),
convert_bv(stop),

@ -29,12 +29,12 @@ class WorldState:
:param transaction_sequence:
:param annotations:
"""
self._accounts = {} # type: Dict[int, Account]
self._accounts: Dict[int, Account] = {}
self.balances = Array("balance", 256, 256)
self.starting_balances = deepcopy(self.balances)
self.constraints = constraints or Constraints()
self.node = None # type: Optional['Node']
self.node: Optional["Node"] = None
self.transaction_sequence = transaction_sequence or []
self._annotations = annotations or []

@ -9,8 +9,8 @@ class BasicSearchStrategy(ABC):
"""
def __init__(self, work_list, max_depth, **kwargs):
self.work_list = work_list # type: List[GlobalState]
self.max_depth = max_depth
self.work_list: List[GlobalState] = work_list
self.max_depth: int = max_depth
def __iter__(self):
return self

@ -28,12 +28,11 @@ class DelayConstraintStrategy(BasicSearchStrategy):
:return: Global state
"""
while True:
while len(self.work_list) == 0:
state = self.pending_worklist.pop(0)
model = state.world_state.constraints.get_model()
if model is not None:
self.model_cache.put(model, 1)
self.work_list.append(state)
state = self.work_list.pop(0)
return state
while len(self.work_list) == 0:
state = self.pending_worklist.pop(0)
model = state.world_state.constraints.get_model()
if model is not None:
self.model_cache.put(model, 1)
self.work_list.append(state)
state = self.work_list.pop(0)
return state

@ -14,8 +14,8 @@ class JumpdestCountAnnotation(StateAnnotation):
"""State annotation that counts the number of jumps per destination."""
def __init__(self) -> None:
self._reached_count = {} # type: Dict[int, int]
self.trace = [] # type: List[int]
self._reached_count: Dict[int, int] = {}
self.trace: List[int] = []
def __copy__(self):
result = JumpdestCountAnnotation()

@ -238,7 +238,7 @@ class LaserEVM:
for hook in self._stop_exec_trans_hooks:
hook()
def _execute_transactions_ordered(self, address):
def _execute_transactions_non_ordered(self, address):
"""
This function executes multiple transactions non-incrementally, using some type priority ordering
@ -327,7 +327,7 @@ class LaserEVM:
:param track_gas:
:return:
"""
final_states = [] # type: List[GlobalState]
final_states: List[GlobalState] = []
for hook in self._start_exec_hooks:
hook()
@ -387,7 +387,7 @@ class LaserEVM:
# exceptional halt all changes should be discarded, and this world state would not provide us with a
# previously unseen world state
log.debug("Encountered a VmException, ending path: `{}`".format(error_msg))
new_global_states = [] # type: List[GlobalState]
new_global_states: List[GlobalState] = []
else:
# First execute the post hook for the transaction ending instruction
self._execute_post_hook(op_code, [global_state])

@ -105,7 +105,7 @@ class BaseTransaction:
self.caller = caller
self.callee_account = callee_account
if call_data is None and init_call_data:
self.call_data = SymbolicCalldata(self.id) # type: BaseCalldata
self.call_data: BaseCalldata = SymbolicCalldata(self.id)
else:
self.call_data = (
call_data
@ -119,7 +119,7 @@ class BaseTransaction:
else symbol_factory.BitVecSym(f"callvalue{identifier}", 256)
)
self.static = static
self.return_data = None # type: str
self.return_data: Optional[ReturnData] = None
def initial_global_state_from_environment(self, environment, active_function):
"""
@ -278,7 +278,9 @@ class ContractCreationTransaction(BaseTransaction):
tuple(return_data.return_data)
)
return_data = str(hex(global_state.environment.active_account.address.value))
self.return_data = ReturnData(return_data, len(return_data) // 2)
self.return_data: Optional[ReturnData] = ReturnData(
return_data, symbol_factory.BitVecVal(len(return_data) // 2, 256)
)
assert global_state.environment.active_account.code.instruction_list != []
raise TransactionEndSignal(global_state, revert=revert)

@ -143,7 +143,7 @@ def concrete_int_to_bytes(val):
:return:
"""
# logging.debug("concrete_int_to_bytes " + str(val))
if type(val) == int:
if isinstance(val, int):
return val.to_bytes(32, byteorder="big")
return simplify(val).value.to_bytes(32, byteorder="big")

@ -17,9 +17,9 @@ class LaserPluginLoader(object, metaclass=Singleton):
def __init__(self) -> None:
"""Initializes the plugin loader"""
self.laser_plugin_builders = {} # type: Dict[str, PluginBuilder]
self.plugin_args = {} # type: Dict[str, Dict]
self.plugin_list = {} # type: Dict[str, LaserPlugin]
self.laser_plugin_builders: Dict[str, PluginBuilder] = {}
self.plugin_args: Dict[str, Dict] = {}
self.plugin_list: Dict[str, LaserPlugin] = {}
def add_args(self, plugin_name, **kwargs):
self.plugin_args[plugin_name] = kwargs

@ -30,7 +30,7 @@ class InstructionCoveragePlugin(LaserPlugin):
"""
def __init__(self):
self.coverage = {} # type: Dict[str, Tuple[int, List[bool]]]
self.coverage: Dict[str, Tuple[int, List[bool]]] = {}
self.initial_coverage = 0
self.tx_id = 0
@ -54,7 +54,7 @@ class InstructionCoveragePlugin(LaserPlugin):
else:
cov_percentage = sum(code_cov[1]) / float(code_cov[0]) * 100
string_code = code
if type(code) == tuple:
if isinstance(code, tuple):
try:
string_code = bytearray(code).hex()
except TypeError:

@ -95,10 +95,10 @@ class DependencyPruner(LaserPlugin):
def _reset(self):
self.iteration = 0
self.calls_on_path = {} # type: Dict[int, bool]
self.sloads_on_path = {} # type: Dict[int, List[object]]
self.sstores_on_path = {} # type: Dict[int, List[object]]
self.storage_accessed_global = set() # type: Set
self.calls_on_path: Dict[int, bool] = {}
self.sloads_on_path: Dict[int, List[object]] = {}
self.sstores_on_path: Dict[int, List[object]] = {}
self.storage_accessed_global: Set = set()
def update_sloads(self, path: List[int], target_location: object) -> None:
"""Update the dependency map for the block offsets on the given path.

@ -31,11 +31,11 @@ class DependencyAnnotation(MergeableStateAnnotation):
"""
def __init__(self):
self.storage_loaded = set() # type: Set
self.storage_written = {} # type: Dict[int, Set]
self.has_call = False # type: bool
self.path = [0] # type: List
self.blocks_seen = set() # type: Set[int]
self.storage_loaded: Set = set()
self.storage_written: Dict[int, Set] = {}
self.has_call: bool = False
self.path: List = [0]
self.blocks_seen: Set[int] = set()
def __copy__(self):
result = DependencyAnnotation()

@ -190,7 +190,7 @@ def Sum(*args: BitVec) -> BitVec:
:return:
"""
raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations
annotations: Annotations = set()
for bv in args:
annotations = annotations.union(bv.annotations)

@ -97,7 +97,7 @@ class Bool(Expression[z3.BoolRef]):
def And(*args: Union[Bool, bool]) -> Bool:
"""Create an And expression."""
annotations = set() # type: Set
annotations: Set = set()
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
for arg in args_list:
annotations = annotations.union(arg.annotations)
@ -119,7 +119,7 @@ def Or(*args: Union[Bool, bool]) -> Bool:
:return:
"""
args_list = [arg if isinstance(arg, Bool) else Bool(arg) for arg in args]
annotations = set() # type: Set
annotations: Set = set()
for arg in args_list:
annotations = annotations.union(arg.annotations)
return Bool(z3.Or([a.raw for a in args_list]), annotations=annotations)

@ -19,7 +19,7 @@ class Model:
def decls(self) -> List[z3.ExprRef]:
"""Get the declarations for this model"""
result = [] # type: List[z3.ExprRef]
result: List[z3.ExprRef] = []
for internal_model in self.raw:
result.extend(internal_model.decls())
return result

@ -31,8 +31,8 @@ class DependenceBucket:
:param variables: Variables contained in the conditions
:param conditions: The conditions that are dependent on each other
"""
self.variables = variables or [] # type: List[z3.ExprRef]
self.conditions = conditions or [] # type: List[z3.ExprRef]
self.variables: List[z3.ExprRef] = variables or []
self.conditions: List[z3.ExprRef] = conditions or []
class DependenceMap:
@ -40,8 +40,8 @@ class DependenceMap:
def __init__(self):
"""Initializes a DependenceMap object"""
self.buckets = [] # type: List[DependenceBucket]
self.variable_map = {} # type: Dict[str, DependenceBucket]
self.buckets: List[DependenceBucket] = []
self.variable_map: Dict[str, DependenceBucket] = {}
def add_condition(self, condition: z3.BoolRef) -> None:
"""
@ -70,8 +70,8 @@ class DependenceMap:
def _merge_buckets(self, bucket_list: Set[DependenceBucket]) -> DependenceBucket:
"""Merges the buckets in bucket list"""
variables = [] # type: List[str]
conditions = [] # type: List[z3.BoolRef]
variables: List[str] = []
conditions: List[z3.BoolRef] = []
for bucket in bucket_list:
self.buckets.remove(bucket)
variables += bucket.variables
@ -100,14 +100,14 @@ class IndependenceSolver:
"""
self.raw.set(timeout=timeout)
def add(self, *constraints: Tuple[Bool]) -> None:
def add(self, *constraints: Bool) -> None:
"""Adds the constraints to this solver.
:param constraints: constraints to add
"""
raw_constraints = [
raw_constraints: List[z3.BoolRef] = [
c.raw for c in cast(Tuple[Bool], constraints)
] # type: List[z3.BoolRef]
]
self.constraints.extend(raw_constraints)
def append(self, *constraints: Tuple[Bool]) -> None:
@ -115,9 +115,9 @@ class IndependenceSolver:
:param constraints: constraints to add
"""
raw_constraints = [
raw_constraints: List[z3.BoolRef] = [
c.raw for c in cast(Tuple[Bool], constraints)
] # type: List[z3.BoolRef]
]
self.constraints.extend(raw_constraints)
@stat_smt_query

@ -28,18 +28,18 @@ class BaseSolver(Generic[T]):
"""
self.raw.set(timeout=timeout)
def add(self, *constraints: List[Bool]) -> None:
def add(self, *constraints: Bool) -> None:
"""Adds the constraints to this solver.
:param constraints:
:return:
"""
z3_constraints = [
z3_constraints: Sequence[z3.BoolRef] = [
c.raw for c in cast(List[Bool], constraints)
] # type: Sequence[z3.BoolRef]
]
self.raw.add(z3_constraints)
def append(self, *constraints: List[Bool]) -> None:
def append(self, *constraints: Bool) -> None:
"""Adds the constraints to this solver.
:param constraints:

@ -47,7 +47,7 @@ class MythrilAnalyzer:
:param address: Address of the contract
"""
self.eth = disassembler.eth
self.contracts = disassembler.contracts or [] # type: List[EVMContract]
self.contracts: List[EVMContract] = disassembler.contracts or []
self.enable_online_lookup = disassembler.enable_online_lookup
self.use_onchain_data = not cmd_args.no_onchain_data
self.strategy = strategy
@ -141,10 +141,10 @@ class MythrilAnalyzer:
:param transaction_count: The amount of transactions to be executed
:return: The Report class which contains the all the issues/vulnerabilities
"""
all_issues = [] # type: List[Issue]
all_issues: List[Issue] = []
SolverStatistics().enabled = True
exceptions = []
execution_info = None # type: Optional[List[ExecutionInfo]]
execution_info: Optional[List[ExecutionInfo]] = None
for contract in self.contracts:
StartTime() # Reinitialize start time for new contracts
try:

@ -22,11 +22,11 @@ class MythrilConfig:
"""
def __init__(self):
self.infura_id = os.getenv("INFURA_ID") # type: str
self.infura_id: str = os.getenv("INFURA_ID")
self.mythril_dir = self.init_mythril_dir()
self.config_path = os.path.join(self.mythril_dir, "config.ini")
self._init_config()
self.eth = None # type: Optional[EthJsonRpc]
self.eth: Optional[EthJsonRpc] = None
def set_api_infura_id(self, id):
self.infura_id = id

@ -30,11 +30,11 @@ from mythril.solidity.soliditycontract import (
from mythril.support.support_args import args
def format_Warning(message, category, filename, lineno, line=""):
def format_warning(message, category, filename, lineno, line=""):
return "{}: {}\n\n".format(str(filename), str(message))
warnings.formatwarning = format_Warning
warnings.formatwarning = format_warning
log = logging.getLogger(__name__)
@ -62,7 +62,7 @@ class MythrilDisassembler:
self.eth = eth
self.enable_online_lookup = enable_online_lookup
self.sigs = signatures.SignatureDB(enable_online_lookup=enable_online_lookup)
self.contracts = [] # type: List[EVMContract]
self.contracts: List[EVMContract] = []
@staticmethod
def _init_solc_binary(version: str) -> Optional[str]:
@ -190,7 +190,6 @@ class MythrilDisassembler:
build_dir = os.path.join(project_root, "artifacts", "contracts", "build-info")
files = os.listdir(build_dir)
address = util.get_indexed_address(0)
files = sorted(

@ -12,7 +12,7 @@ class PluginDiscovery(object, metaclass=Singleton):
"""
# Installed plugins structure. Retrieves all modules that have an entry point for mythril.plugins
_installed_plugins = None # type: Optional[Dict[str, Any]]
_installed_plugins: Optional[Dict[str, Any]] = None
def init_installed_plugins(self):
self._installed_plugins = {

@ -27,7 +27,7 @@ class MythrilPluginLoader(object, metaclass=Singleton):
def __init__(self):
log.info("Initializing mythril plugin loader")
self.loaded_plugins = []
self.plugin_args = dict() # type: Dict[str, Dict]
self.plugin_args: Dict[str, Dict] = dict()
self._load_default_enabled()
def set_args(self, plugin_name: str, **kwargs):

@ -203,7 +203,7 @@ class SolidityFeatureExtractor:
return variables
def extract_address_variable(self, node):
if type(node) == int:
if isinstance(node, int):
return set([])
transfer_vars = set([])
if (
@ -211,15 +211,13 @@ class SolidityFeatureExtractor:
and node.get("expression", {}).get("nodeType") == "FunctionCall"
):
expression = node["expression"].get("expression", None)
if expression is not None:
if (
expression["nodeType"] == "MemberAccess"
and expression["memberName"] in TRANSFER_METHODS
):
print(expression)
address_variable = expression["expression"].get("name")
if address_variable:
transfer_vars.update(set([address_variable]))
if expression is not None and (
expression["nodeType"] == "MemberAccess"
and expression["memberName"] in TRANSFER_METHODS
):
address_variable = expression["expression"].get("name")
if address_variable:
transfer_vars.update(set([address_variable]))
for key, value in node.items():
if isinstance(value, dict):

@ -40,7 +40,7 @@ class DynLoader:
value = self.eth.eth_getStorageAt(
contract_address, position=index, block="latest"
)
if value == "0x":
if value.startswith("0x"):
value = "0x0000000000000000000000000000000000000000000000000000000000000000"
return value
@ -96,7 +96,7 @@ class DynLoader:
code = self.eth.eth_getCode(dependency_address)
if code == "0x":
if code.startswith("0x"):
return None
else:
return Disassembly(code)

@ -86,12 +86,16 @@ def get_model(
if solver_timeout <= 0:
raise SolverTimeOutException
for constraint in constraints:
if type(constraint) == bool and not constraint:
if isinstance(constraint, bool) and not constraint:
raise UnsatError
if type(constraints) != tuple:
if isinstance(constraints, tuple) is False:
constraints = constraints.get_all_constraints()
constraints = [constraint for constraint in constraints if type(constraint) != bool]
constraints = [
constraint
for constraint in constraints
if isinstance(constraint, bool) is False
]
if len(maximize) + len(minimize) == 0:
ret_model = model_cache.check_quick_sat(simplify(And(*constraints)).raw)

@ -44,7 +44,7 @@ def synchronized(sync_lock):
class Singleton(type):
"""A metaclass type implementing the singleton pattern."""
_instances = dict() # type: Dict[Singleton, Singleton]
_instances: Dict["Singleton", "Singleton"] = dict()
@synchronized(lock)
def __call__(cls, *args, **kwargs):
@ -123,12 +123,12 @@ class SignatureDB(object, metaclass=Singleton):
:param path:
"""
self.enable_online_lookup = enable_online_lookup
self.online_lookup_miss = set() # type: Set[str]
self.online_lookup_miss: Set[str] = set()
self.online_lookup_timeout = 0
# if we're analysing a Solidity file, store its hashes
# here to prevent unnecessary lookups
self.solidity_sigs = defaultdict(list) # type: DefaultDict[str, List[str]]
self.solidity_sigs: DefaultDict[str, List[str]] = defaultdict(list)
if path is None:
self.path = os.environ.get("MYTHRIL_DIR") or os.path.join(
os.path.expanduser("~"), ".mythril"

@ -38,7 +38,7 @@ class Source:
self.source_format = "evm-byzantium-bytecode"
self.source_type = (
"ethereum-address"
if len(contracts[0].name) == 42 and contracts[0].name[0:2] == "0x"
if len(contracts[0].name) == 42 and contracts[0].name.startswith("0x")
else "raw-bytecode"
)
for contract in contracts:

@ -14,7 +14,7 @@ log = logging.getLogger(__name__)
class Singleton(type):
"""A metaclass type implementing the singleton pattern."""
_instances = {} # type: Dict
_instances: Dict = {}
def __call__(cls, *args, **kwargs):
"""Delegate the call to an existing resource or a a new one.
@ -59,7 +59,6 @@ class ModelCache:
@lru_cache(maxsize=2**10)
def check_quick_sat(self, constraints) -> bool:
model_list = list(reversed(self.model_cache.lru_cache.keys()))
for model in reversed(self.model_cache.lru_cache.keys()):
model_copy = deepcopy(model)
if is_true(model_copy.eval(constraints, model_completion=True)):
@ -77,11 +76,11 @@ def get_code_hash(code) -> str:
:param code: bytecode
:return: Returns hash of the given bytecode
"""
if type(code) == tuple:
if isinstance(code, tuple):
# Temporary hack, since we cannot find symbols of sha3
return str(hash(code))
code = code[2:] if code[:2] == "0x" else code
code = code[2:] if code.startswith("0x") else code
try:
hash_ = keccak(bytes.fromhex(code))
return "0x" + hash_.hex()
@ -91,8 +90,8 @@ def get_code_hash(code) -> str:
def sha3(value):
if type(value) == str:
if value[:2] == "0x":
if isinstance(value, str):
if value.startswith("0x"):
new_hash = keccak(bytes.fromhex(value))
else:
new_hash = keccak(value.encode())

@ -61,7 +61,6 @@ test_cases = [
def test_feature_selfdestruct(file_name, num_funcs, func_name, field, expected_value):
input_file = TEST_FILES / file_name
name = file_name.split(".")[0]
print(name, name.capitalize())
if name[0].islower():
name = name.capitalize()
contract = SolidityContract(str(input_file), name=name, solc_binary=solc_binary)

@ -182,8 +182,8 @@ def test_vmtest(
actual = actual.value
actual = 1 if actual is True else 0 if actual is False else actual
else:
if type(actual) == bytes:
if isinstance(actual, bytes):
actual = int(binascii.b2a_hex(actual), 16)
elif type(actual) == str:
elif isinstance(actual, str):
actual = int(actual, 16)
assert actual == expected

@ -5,7 +5,6 @@ from unittest.mock import Mock, patch, mock_open
def mock_predict_proba(X):
print(X)
if X[0][-1] == 1:
return np.array([[0.1, 0.7, 0.1, 0.1]])
elif X[0][-1] == 2:

Loading…
Cancel
Save