Merge pull request #555 from dmuhs/fix/remove-unused

Remove unused variables and fix name shadowing
pull/598/head
JoranHonig 6 years ago committed by GitHub
commit 7994b9cd16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 4
      mythril/analysis/modules/dependence_on_predictable_vars.py
  2. 4
      mythril/analysis/modules/ether_send.py
  3. 2
      mythril/analysis/modules/multiple_sends.py
  4. 1
      mythril/analysis/symbolic.py
  5. 8
      mythril/analysis/traceexplore.py
  6. 115
      mythril/ether/evm.py
  7. 17
      mythril/ethereum/interface/leveldb/accountindexing.py
  8. 61
      mythril/ethereum/interface/leveldb/client.py
  9. 32
      mythril/ethereum/interface/leveldb/state.py
  10. 4
      mythril/interfaces/cli.py
  11. 2
      mythril/laser/ethereum/call.py
  12. 10
      mythril/laser/ethereum/instructions.py
  13. 2
      mythril/laser/ethereum/state.py
  14. 11
      mythril/mythril.py
  15. 2
      tests/analysis/test_delegatecall.py
  16. 2
      tests/laser/state/mstack_test.py
  17. 2
      tests/laser/transaction/symbolic_test.py
  18. 3
      tests/taint_runner_test.py

@ -120,8 +120,8 @@ def solve(call):
model = solver.get_model(call.node.constraints) model = solver.get_model(call.node.constraints)
logging.debug("[DEPENDENCE_ON_PREDICTABLE_VARS] MODEL: " + str(model)) logging.debug("[DEPENDENCE_ON_PREDICTABLE_VARS] MODEL: " + str(model))
for d in model.decls(): for decl in model.decls():
logging.debug("[DEPENDENCE_ON_PREDICTABLE_VARS] main model: %s = 0x%x" % (d.name(), model[d].as_long())) logging.debug("[DEPENDENCE_ON_PREDICTABLE_VARS] main model: %s = 0x%x" % (decl.name(), model[decl].as_long()))
return True return True
except UnsatError: except UnsatError:

@ -111,8 +111,8 @@ def execute(statespace):
try: try:
model = solver.get_model(node.constraints) model = solver.get_model(node.constraints)
for d in model.decls(): for decl in model.decls():
logging.debug("[ETHER_SEND] main model: %s = 0x%x" % (d.name(), model[d].as_long())) logging.debug("[ETHER_SEND] main model: %s = 0x%x" % (decl.name(), model[decl].as_long()))
debug = "SOLVER OUTPUT:\n" + solver.pretty_print_model(model) debug = "SOLVER OUTPUT:\n" + solver.pretty_print_model(model)

@ -38,7 +38,7 @@ def execute(statespace):
def _explore_nodes(call, statespace): def _explore_nodes(call, statespace):
children = _child_nodes(statespace, call.node) children = _child_nodes(statespace, call.node)
sending_children = list(filter(lambda call: call.node in children, statespace.calls)) sending_children = list(filter(lambda c: c.node in children, statespace.calls))
return sending_children return sending_children

@ -16,7 +16,6 @@ class SymExecWrapper:
def __init__(self, contract, address, strategy, dynloader=None, max_depth=22, def __init__(self, contract, address, strategy, dynloader=None, max_depth=22,
execution_timeout=None, create_timeout=None, max_transaction_count=3): execution_timeout=None, create_timeout=None, max_transaction_count=3):
s_strategy = None
if strategy == 'dfs': if strategy == 'dfs':
s_strategy = DepthFirstSearchStrategy s_strategy = DepthFirstSearchStrategy
elif strategy == 'bfs': elif strategy == 'bfs':

@ -13,8 +13,8 @@ colors = [
{'border': '#4753bf', 'background': '#3b46a1', 'highlight': {'border': '#fff', 'background': '#424db3'}}, {'border': '#4753bf', 'background': '#3b46a1', 'highlight': {'border': '#fff', 'background': '#424db3'}},
] ]
def get_serializable_statespace(statespace):
def get_serializable_statespace(statespace):
nodes = [] nodes = []
edges = [] edges = []
@ -40,10 +40,10 @@ def get_serializable_statespace(statespace):
color = color_map[node.get_cfg_dict()['contract_name']] color = color_map[node.get_cfg_dict()['contract_name']]
def get_state_accounts(state): def get_state_accounts(node_state):
state_accounts = [] state_accounts = []
for key in state.accounts: for key in node_state.accounts:
account = state.accounts[key].as_dict account = node_state.accounts[key].as_dict
account.pop('code', None) account.pop('code', None)
account['balance'] = str(account['balance']) account['balance'] = str(account['balance'])

@ -7,69 +7,52 @@ from io import StringIO
import re import re
def trace(code, calldata = ""): def trace(code, calldata=""):
log_handlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage']
log_handlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage'] output = StringIO()
stream_handler = StreamHandler(output)
output = StringIO()
stream_handler = StreamHandler(output) for handler in log_handlers:
log_vm_op = get_logger(handler)
for handler in log_handlers: log_vm_op.setLevel("TRACE")
log_vm_op = get_logger(handler) log_vm_op.addHandler(stream_handler)
log_vm_op.setLevel("TRACE")
log_vm_op.addHandler(stream_handler) addr = bytes.fromhex('0123456789ABCDEF0123456789ABCDEF01234567')
state = State()
addr = bytes.fromhex('0123456789ABCDEF0123456789ABCDEF01234567')
ext = messages.VMExt(state, transactions.Transaction(0, 0, 21000, addr, 0, addr))
state = State() message = vm.Message(addr, addr, 0, 21000, calldata)
vm.vm_execute(ext, message, util.safe_decode(code))
ext = messages.VMExt(state, transactions.Transaction(0, 0, 21000, addr, 0, addr)) stream_handler.flush()
ret = output.getvalue()
message = vm.Message(addr, addr, 0, 21000, calldata) lines = ret.split("\n")
res, gas, dat = vm.vm_execute(ext, message, util.safe_decode(code)) state_trace = []
for line in lines:
stream_handler.flush() m = re.search(r'pc=b\'(\d+)\'.*op=([A-Z0-9]+)', line)
if m:
ret = output.getvalue() pc = m.group(1)
op = m.group(2)
lines = ret.split("\n") m = re.match(r'.*stack=(\[.*?\])', line)
trace = [] if m:
stackitems = re.findall(r'b\'(\d+)\'', m.group(1))
for line in lines: stack = "["
m = re.search(r'pc=b\'(\d+)\'.*op=([A-Z0-9]+)', line) if len(stackitems):
for i in range(0, len(stackitems) - 1):
if m: stack += hex(int(stackitems[i])) + ", "
pc = m.group(1) stack += hex(int(stackitems[-1]))
op = m.group(2)
stack += "]"
m = re.match(r'.*stack=(\[.*?\])', line) else:
stack = "[]"
if m:
if re.match(r'^PUSH.*', op):
stackitems = re.findall(r'b\'(\d+)\'', m.group(1)) val = re.search(r'pushvalue=(\d+)', line).group(1)
pushvalue = hex(int(val))
stack = "[" state_trace.append({'pc': pc, 'op': op, 'stack': stack, 'pushvalue': pushvalue})
else:
if len(stackitems): state_trace.append({'pc': pc, 'op': op, 'stack': stack})
for i in range(0, len(stackitems) - 1): return state_trace
stack += hex(int(stackitems[i])) + ", "
stack += hex(int(stackitems[-1]))
stack += "]"
else:
stack = "[]"
if re.match(r'^PUSH.*', op):
val = re.search(r'pushvalue=(\d+)', line).group(1)
pushvalue = hex(int(val))
trace.append({'pc': pc, 'op': op, 'stack': stack, 'pushvalue': pushvalue})
else:
trace.append({'pc': pc, 'op': op, 'stack': stack})
return trace

@ -63,16 +63,15 @@ class AccountIndexer(object):
def get_contract_by_hash(self, contract_hash): def get_contract_by_hash(self, contract_hash):
""" """
get mapped address by its hash, if not found try indexing get mapped contract_address by its hash, if not found try indexing
""" """
address = self.db.reader._get_address_by_hash(contract_hash) contract_address = self.db.reader._get_address_by_hash(contract_hash)
if address is not None: if contract_address is not None:
return address return contract_address
else: else:
raise AddressNotFoundError raise AddressNotFoundError
return self.db.reader._get_address_by_hash(contract_hash)
def _process(self, startblock): def _process(self, startblock):
""" """
Processesing method Processesing method
@ -82,9 +81,9 @@ class AccountIndexer(object):
addresses = [] addresses = []
for blockNum in range(startblock, startblock + BATCH_SIZE): for blockNum in range(startblock, startblock + BATCH_SIZE):
hash = self.db.reader._get_block_hash(blockNum) block_hash = self.db.reader._get_block_hash(blockNum)
if hash is not None: if block_hash is not None:
receipts = self.db.reader._get_block_receipts(hash, blockNum) receipts = self.db.reader._get_block_receipts(block_hash, blockNum)
for receipt in receipts: for receipt in receipts:
if receipt.contractAddress is not None and not all(b == 0 for b in receipt.contractAddress): if receipt.contractAddress is not None and not all(b == 0 for b in receipt.contractAddress):

@ -79,52 +79,43 @@ class LevelDBReader(object):
gets head block header gets head block header
""" """
if not self.head_block_header: if not self.head_block_header:
hash = self.db.get(head_header_key) block_hash = self.db.get(head_header_key)
num = self._get_block_number(hash) num = self._get_block_number(block_hash)
self.head_block_header = self._get_block_header(hash, num) self.head_block_header = self._get_block_header(block_hash, num)
# find header with valid state # find header with valid state
while not self.db.get(self.head_block_header.state_root) and self.head_block_header.prevhash is not None: while not self.db.get(self.head_block_header.state_root) and self.head_block_header.prevhash is not None:
hash = self.head_block_header.prevhash block_hash = self.head_block_header.prevhash
num = self._get_block_number(hash) num = self._get_block_number(block_hash)
self.head_block_header = self._get_block_header(hash, num) self.head_block_header = self._get_block_header(block_hash, num)
return self.head_block_header return self.head_block_header
def _get_block_number(self, hash): def _get_block_number(self, block_hash):
""" """Get block number by its hash"""
gets block number by hash number_key = block_hash_prefix + block_hash
"""
number_key = block_hash_prefix + hash
return self.db.get(number_key) return self.db.get(number_key)
def _get_block_header(self, hash, num): def _get_block_header(self, block_hash, num):
""" """Get block header by block header hash & number"""
get block header by block header hash & number header_key = header_prefix + num + block_hash
"""
header_key = header_prefix + num + hash
block_header_data = self.db.get(header_key) block_header_data = self.db.get(header_key)
header = rlp.decode(block_header_data, sedes=BlockHeader) header = rlp.decode(block_header_data, sedes=BlockHeader)
return header return header
def _get_address_by_hash(self, hash): def _get_address_by_hash(self, block_hash):
""" """Get mapped address by its hash"""
get mapped address by its hash address_key = address_prefix + block_hash
"""
address_key = address_prefix + hash
return self.db.get(address_key) return self.db.get(address_key)
def _get_last_indexed_number(self): def _get_last_indexed_number(self):
""" """Get latest indexed block number"""
latest indexed block number
"""
return self.db.get(address_mapping_head_key) return self.db.get(address_mapping_head_key)
def _get_block_receipts(self, hash, num): def _get_block_receipts(self, block_hash, num):
""" """Get block transaction receipts by block header hash & number"""
get block transaction receipts by block header hash & number
"""
number = _format_block_number(num) number = _format_block_number(num)
receipts_key = block_receipts_prefix + number + hash receipts_key = block_receipts_prefix + number + block_hash
receipts_data = self.db.get(receipts_key) receipts_data = self.db.get(receipts_key)
receipts = rlp.decode(receipts_data, sedes=CountableList(ReceiptForStorage)) receipts = rlp.decode(receipts_data, sedes=CountableList(ReceiptForStorage))
return receipts return receipts
@ -216,12 +207,10 @@ class EthLevelDB(object):
if not cnt % 1000: if not cnt % 1000:
logging.info("Searched %d contracts" % cnt) logging.info("Searched %d contracts" % cnt)
def contract_hash_to_address(self, hash): def contract_hash_to_address(self, contract_hash):
""" """Tries to find corresponding account address"""
tries to find corresponding account address
"""
address_hash = binascii.a2b_hex(utils.remove_0x_head(hash)) address_hash = binascii.a2b_hex(utils.remove_0x_head(contract_hash))
indexer = AccountIndexer(self) indexer = AccountIndexer(self)
return _encode_hex(indexer.get_contract_by_hash(address_hash)) return _encode_hex(indexer.get_contract_by_hash(address_hash))
@ -230,9 +219,9 @@ class EthLevelDB(object):
""" """
gets block header by block number gets block header by block number
""" """
hash = self.reader._get_block_hash(number) block_hash = self.reader._get_block_hash(number)
block_number = _format_block_number(number) block_number = _format_block_number(number)
return self.reader._get_block_header(hash, block_number) return self.reader._get_block_header(block_hash, block_number)
def eth_getBlockByNumber(self, number): def eth_getBlockByNumber(self, number):
""" """

@ -43,9 +43,9 @@ class Account(rlp.Serializable):
('code_hash', hash32) ('code_hash', hash32)
] ]
def __init__(self, nonce, balance, storage, code_hash, db, address): def __init__(self, nonce, balance, storage, code_hash, db, addr):
self.db = db self.db = db
self.address = address self.address = addr
super(Account, self).__init__(nonce, balance, storage, code_hash) super(Account, self).__init__(nonce, balance, storage, code_hash)
self.storage_cache = {} self.storage_cache = {}
self.storage_trie = SecureTrie(Trie(self.db)) self.storage_trie = SecureTrie(Trie(self.db))
@ -73,12 +73,12 @@ class Account(rlp.Serializable):
return self.storage_cache[key] return self.storage_cache[key]
@classmethod @classmethod
def blank_account(cls, db, address, initial_nonce=0): def blank_account(cls, db, addr, initial_nonce=0):
""" """
creates a blank account creates a blank account
""" """
db.put(BLANK_HASH, b'') db.put(BLANK_HASH, b'')
o = cls(initial_nonce, 0, trie.BLANK_ROOT, BLANK_HASH, db, address) o = cls(initial_nonce, 0, trie.BLANK_ROOT, BLANK_HASH, db, addr)
o.existent_at_start = False o.existent_at_start = False
return o return o
@ -100,21 +100,21 @@ class State:
self.journal = [] self.journal = []
self.cache = {} self.cache = {}
def get_and_cache_account(self, address): def get_and_cache_account(self, addr):
""" """Gets and caches an account for an addres, creates blank if not found"""
gets and caches an account for an addres, creates blank if not found
""" if addr in self.cache:
if address in self.cache: return self.cache[addr]
return self.cache[address] rlpdata = self.secure_trie.get(addr)
rlpdata = self.secure_trie.get(address) if rlpdata == trie.BLANK_NODE and len(addr) == 32: # support for hashed addresses
if rlpdata == trie.BLANK_NODE and len(address) == 32: # support for hashed addresses rlpdata = self.trie.get(addr)
rlpdata = self.trie.get(address)
if rlpdata != trie.BLANK_NODE: if rlpdata != trie.BLANK_NODE:
o = rlp.decode(rlpdata, Account, db=self.db, address=address) o = rlp.decode(rlpdata, Account, db=self.db, address=addr)
else: else:
o = Account.blank_account( o = Account.blank_account(
self.db, address, 0) self.db, addr, 0)
self.cache[address] = o self.cache[addr] = o
o._mutable = True o._mutable = True
o._cached_rlp = None o._cached_rlp = None
return o return o

@ -17,8 +17,8 @@ from mythril.mythril import Mythril
from mythril.version import VERSION from mythril.version import VERSION
def exit_with_error(format, message): def exit_with_error(format_, message):
if format == 'text' or format == 'markdown': if format_ == 'text' or format_ == 'markdown':
print(message) print(message)
else: else:
result = {'success': False, 'error': str(message), 'issues': []} result = {'success': False, 'error': str(message), 'issues': []}

@ -99,7 +99,7 @@ def get_callee_account(global_state, callee_address, dynamic_loader):
try: try:
code = dynamic_loader.dynld(environment.active_account.address, callee_address) code = dynamic_loader.dynld(environment.active_account.address, callee_address)
except Exception as e: except Exception:
logging.debug("Unable to execute dynamic loader.") logging.debug("Unable to execute dynamic loader.")
raise ValueError() raise ValueError()
if code is None: if code is None:

@ -497,7 +497,6 @@ class Instruction:
global keccak_function_manager global keccak_function_manager
state = global_state.mstate state = global_state.mstate
environment = global_state.environment
op0, op1 = state.stack.pop(), state.stack.pop() op0, op1 = state.stack.pop(), state.stack.pop()
try: try:
@ -701,12 +700,9 @@ class Instruction:
try: try:
# Attempt to concretize value # Attempt to concretize value
_bytes = util.concrete_int_to_bytes(value) _bytes = util.concrete_int_to_bytes(value)
state.memory[mstart: mstart + len(_bytes)] = _bytes
state.memory[mstart:mstart+len(_bytes)] = _bytes except:
except (AttributeError, TypeError):
try: try:
state.memory[mstart] = value state.memory[mstart] = value
except TypeError: except TypeError:
@ -948,7 +944,7 @@ class Instruction:
state = global_state.mstate state = global_state.mstate
dpth = int(self.op_code[3:]) dpth = int(self.op_code[3:])
state.stack.pop(), state.stack.pop() state.stack.pop(), state.stack.pop()
[state.stack.pop() for x in range(dpth)] [state.stack.pop() for _ in range(dpth)]
# Not supported # Not supported
return [global_state] return [global_state]

@ -278,10 +278,10 @@ class GlobalState:
def new_bitvec(self, name, size=256): def new_bitvec(self, name, size=256):
transaction_id = self.current_transaction.id transaction_id = self.current_transaction.id
node_id = self.node.uid
return BitVec("{}_{}".format(transaction_id, name), size) return BitVec("{}_{}".format(transaction_id, name), size)
class WorldState: class WorldState:
""" """
The WorldState class represents the world state as described in the yellow paper The WorldState class represents the world state as described in the yellow paper

@ -87,7 +87,7 @@ class Mythril(object):
self.sigs = signatures.SignatureDb() self.sigs = signatures.SignatureDb()
try: try:
self.sigs.open() # tries mythril_dir/signatures.json by default (provide path= arg to make this configurable) self.sigs.open() # tries mythril_dir/signatures.json by default (provide path= arg to make this configurable)
except FileNotFoundError as fnfe: except FileNotFoundError:
logging.info( logging.info(
"No signature database found. Creating database if sigs are loaded in: " + self.sigs.signatures_file + "\n" + "No signature database found. Creating database if sigs are loaded in: " + self.sigs.signatures_file + "\n" +
"Consider replacing it with the pre-initialized database at https://raw.githubusercontent.com/ConsenSys/mythril/master/signatures.json") "Consider replacing it with the pre-initialized database at https://raw.githubusercontent.com/ConsenSys/mythril/master/signatures.json")
@ -261,8 +261,7 @@ class Mythril(object):
def search_db(self, search): def search_db(self, search):
def search_callback(contract, address, balance): def search_callback(_, address, balance):
print("Address: " + address + ", balance: " + str(balance)) print("Address: " + address + ", balance: " + str(balance))
try: try:
@ -290,10 +289,10 @@ class Mythril(object):
code = self.eth.eth_getCode(address) code = self.eth.eth_getCode(address)
except FileNotFoundError as e: except FileNotFoundError as e:
raise CriticalError("IPC error: " + str(e)) raise CriticalError("IPC error: " + str(e))
except ConnectionError as e: except ConnectionError:
raise CriticalError("Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly.") raise CriticalError("Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly.")
except Exception as e: except Exception as e:
raise CriticalError("IPC / RPC error: " + str(e)) raise CriticalError("IPC / RPC error: " + str(e))
else: else:
if code == "0x" or code == "0x0": if code == "0x" or code == "0x0":
raise CriticalError("Received an empty response from eth_getCode. Check the contract address and verify that you are on the correct chain.") raise CriticalError("Received an empty response from eth_getCode. Check the contract address and verify that you are on the correct chain.")
@ -435,7 +434,7 @@ class Mythril(object):
outtxt.append("{}: {}".format(hex(i), self.eth.eth_getStorageAt(address, i))) outtxt.append("{}: {}".format(hex(i), self.eth.eth_getStorageAt(address, i)))
except FileNotFoundError as e: except FileNotFoundError as e:
raise CriticalError("IPC error: " + str(e)) raise CriticalError("IPC error: " + str(e))
except ConnectionError as e: except ConnectionError:
raise CriticalError("Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly.") raise CriticalError("Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly.")
return '\n'.join(outtxt) return '\n'.join(outtxt)

@ -189,7 +189,7 @@ def test_delegate_call(sym_mock, concrete_mock, curr_instruction):
statespace.calls = [call] statespace.calls = [call]
# act # act
issues = execute(statespace) execute(statespace)
# assert # assert
assert concrete_mock.call_count == 1 assert concrete_mock.call_count == 1

@ -45,7 +45,7 @@ class MachineStackTest(BaseTestCase):
mstack = MachineStack([0, 1]) mstack = MachineStack([0, 1])
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
mstack = mstack + [2] mstack + [2]
@staticmethod @staticmethod
def test_mstack_no_support_iadd(): def test_mstack_no_support_iadd():

@ -49,7 +49,7 @@ def test_execute_contract_creation(mocked_setup: MagicMock):
mocked_setup.side_effect = _is_contract_creation mocked_setup.side_effect = _is_contract_creation
# Act # Act
new_account = execute_contract_creation(laser_evm, "606000") execute_contract_creation(laser_evm, "606000")
# Assert # Assert
# mocked_setup.assert_called() # mocked_setup.assert_called()

@ -6,6 +6,7 @@ from mythril.laser.ethereum.cfg import Node, Edge
from mythril.laser.ethereum.state import MachineState, Account, Environment, GlobalState from mythril.laser.ethereum.state import MachineState, Account, Environment, GlobalState
from mythril.laser.ethereum.svm import LaserEVM from mythril.laser.ethereum.svm import LaserEVM
def test_execute_state(mocker): def test_execute_state(mocker):
record = TaintRecord() record = TaintRecord()
record.stack = [True, False, True] record.stack = [True, False, True]
@ -54,8 +55,6 @@ def test_execute_node(mocker):
assert state_1 in record.states assert state_1 in record.states
def test_execute(mocker): def test_execute(mocker):
active_account = Account('0x00') active_account = Account('0x00')
environment = Environment(active_account, None, None, None, None, None) environment = Environment(active_account, None, None, None, None, None)

Loading…
Cancel
Save