Add multithreaded contract sync, refactor RPC code

pull/94/head
Bernhard Mueller 7 years ago
parent 91b8b616b1
commit bf8301b66b
  1. 61
      myth
  2. 2
      mythril/analysis/modules/ether_send.py
  3. 104
      mythril/ether/contractstorage.py
  4. 14
      mythril/ipc/client.py
  5. 2
      mythril/rpc/client.py

61
myth

@ -79,7 +79,6 @@ utilities.add_argument('--solv', help='specify solidity compiler version. If not
options = parser.add_argument_group('options') options = parser.add_argument_group('options')
options.add_argument('-m', '--modules', help='Comma-separated list of security analysis modules', metavar='MODULES') options.add_argument('-m', '--modules', help='Comma-separated list of security analysis modules', metavar='MODULES')
options.add_argument('--sync-all', action='store_true', help='Also sync contracts with zero balance')
options.add_argument('--max-depth', type=int, default=12, help='Maximum recursion depth for symbolic execution') options.add_argument('--max-depth', type=int, default=12, help='Maximum recursion depth for symbolic execution')
options.add_argument('--solc-args', help='Extra arguments for solc') options.add_argument('--solc-args', help='Extra arguments for solc')
options.add_argument('--phrack', action='store_true', help='Phrack-style call graph') options.add_argument('--phrack', action='store_true', help='Phrack-style call graph')
@ -87,13 +86,10 @@ options.add_argument('--enable-physics', action='store_true', help='enable graph
options.add_argument('-v', type=int, help='log level (0-2)', metavar='LOG_LEVEL') options.add_argument('-v', type=int, help='log level (0-2)', metavar='LOG_LEVEL')
rpc = parser.add_argument_group('RPC options') rpc = parser.add_argument_group('RPC options')
rpc.add_argument('--rpc', help='connect via RPC', metavar='HOST:PORT') rpc.add_argument('-i', action='store_true', help='Preset: Infura Node service (Mainnet)')
rpc.add_argument('--rpc', help='custom RPC settings', metavar='HOST:PORT / ganache / infura-[network_name]')
rpc.add_argument('--rpctls', type=bool, default=False, help='RPC connection over TLS') rpc.add_argument('--rpctls', type=bool, default=False, help='RPC connection over TLS')
rpc.add_argument('--ganache', action='store_true', help='Preset: local Ganache') rpc.add_argument('--ipc', action='store_true', help='Connect via local IPC')
rpc.add_argument('-i', '--infura-mainnet', action='store_true', help='Preset: Infura Node service (Mainnet)')
rpc.add_argument('--infura-rinkeby', action='store_true', help='Preset: Infura Node service (Rinkeby)')
rpc.add_argument('--infura-kovan', action='store_true', help='Preset: Infura Node service (Kovan)')
rpc.add_argument('--infura-ropsten', action='store_true', help='Preset: Infura Node service (Ropsten)')
# Get config values # Get config values
@ -186,32 +182,53 @@ else:
# Establish RPC/IPC connection if necessary # Establish RPC/IPC connection if necessary
eth = None
if args.address or args.init_db: if args.address or args.init_db:
if args.infura_mainnet:
if args.i:
eth = EthJsonRpc('mainnet.infura.io', 443, True) eth = EthJsonRpc('mainnet.infura.io', 443, True)
elif args.infura_rinkeby: logging.info("Using INFURA for RPC queries")
eth = EthJsonRpc('rinkeby.infura.io', 443, True)
elif args.infura_kovan:
eth = EthJsonRpc('kovan.infura.io', 443, True)
elif args.infura_ropsten:
eth = EthJsonRpc('ropsten.infura.io', 443, True)
elif args.ganache:
eth = EthJsonRpc('localhost', 7545, False)
elif args.rpc: elif args.rpc:
if args.rpc == 'ganache':
rpcconfig = ('localhost', 7545, False)
else:
m = re.match(r'infura-(.*)', args.rpc)
if m and m.group(1) in ['mainnet', 'rinkeby', 'kovan', 'ropsten']:
rpcconfig = (m.group(1) + '.infura.io', 443, True)
else:
try: try:
host, port = args.rpc.split(":") host, port = args.rpc.split(":")
rpcconfig = (host, port, args.rpctls)
except ValueError: except ValueError:
exitWithError(args.outform, "Invalid RPC argument, use HOST:PORT") exitWithError(args.outform, "Invalid RPC argument, use HOST:PORT")
rpcconfig = (host, int(port), args.tls)
if (rpcconfig):
eth = EthJsonRpc(rpcconfig[0], int(rpcconfig[1]), rpcconfig[2])
logging.info("Using RPC settings: %s" % str(rpcconfig))
else: else:
tls = args.rpctls exitWithError(args.outform, "Invalid RPC settings, check help for details.")
eth = EthJsonRpc(host, int(port), tls)
else: elif args.ipc:
try: try:
eth = EthIpc() eth = EthIpc()
except Exception as e: except Exception as e:
exitWithError(args.outform, "IPC initialization failed. Please verify that your local Ethereum node is running, or use the -i flag to connect to INFURA. \n" + str(e)) exitWithError(args.outform, "IPC initialization failed. Please verify that your local Ethereum node is running, or use the -i flag to connect to INFURA. \n" + str(e))
else: # Default configuration if neither RPC or IPC are set
eth = EthJsonRpc('localhost', 8545)
logging.info("Using default RPC settings: http://localhost:8545")
# Database search ops # Database search ops
if args.search or args.init_db: if args.search or args.init_db:
@ -223,7 +240,7 @@ if args.search or args.init_db:
exitWithError(args.outform, "Syntax error in search expression.") exitWithError(args.outform, "Syntax error in search expression.")
elif args.init_db: elif args.init_db:
try: try:
contract_storage.initialize(eth, args.sync_all) contract_storage.initialize(eth)
except FileNotFoundError as e: except FileNotFoundError as e:
exitWithError(args.outform, "Error syncing database over IPC: " + str(e)) exitWithError(args.outform, "Error syncing database over IPC: " + str(e))
except ConnectionError as e: except ConnectionError as e:
@ -256,7 +273,7 @@ elif args.address:
except Exception as e: except Exception as e:
exitWithError(args.outform, "IPC / RPC error: " + str(e)) exitWithError(args.outform, "IPC / RPC error: " + str(e))
else: else:
if code == "0x": if code == "0x" or code == "0x0":
exitWithError(args.outform, "Received an empty response from eth_getCode. Check the contract address and verify that you are on the correct chain.") exitWithError(args.outform, "Received an empty response from eth_getCode. Check the contract address and verify that you are on the correct chain.")
else: else:
contracts.append(ETHContract(code, name=args.address)) contracts.append(ETHContract(code, name=args.address))

@ -55,7 +55,7 @@ def execute(statespace):
if (m): if (m):
idx = m.group(1) idx = m.group(1)
description += "a non-zero amount of Ether is sent to an address taken from storage slot " + str(idx) description += "a non-zero amount of Ether is sent to an address taken from storage slot " + str(idx) + ".\n"
func = statespace.find_storage_write(idx) func = statespace.find_storage_write(idx)

@ -1,15 +1,18 @@
from mythril.rpc.client import EthJsonRpc
from mythril.ipc.client import EthIpc
from mythril.ether.ethcontract import ETHContract, InstanceList
import hashlib
import os import os
import time import hashlib
import persistent import persistent
import persistent.list import persistent.list
import transaction import transaction
from BTrees.OOBTree import BTree from BTrees.OOBTree import BTree
import ZODB import ZODB
from ZODB import FileStorage from ZODB import FileStorage
from multiprocessing import Pool
import logging
from mythril.ether.ethcontract import ETHContract, InstanceList
BLOCKS_PER_THREAD = 256
NUM_THREADS = 8
def get_persistent_storage(db_dir=None): def get_persistent_storage(db_dir=None):
@ -42,80 +45,93 @@ class ContractStorage(persistent.Persistent):
self.contracts = BTree() self.contracts = BTree()
self.instance_lists = BTree() self.instance_lists = BTree()
self.last_block = 0 self.last_block = 0
self.eth = None
def get_contract_by_hash(self, contract_hash): def get_contract_by_hash(self, contract_hash):
return self.contracts[contract_hash] return self.contracts[contract_hash]
def sync_blocks(self, startblock):
logging.info("SYNC_BLOCKS %d to %d" % (startblock, startblock + BLOCKS_PER_THREAD))
def initialize(self, eth, sync_all): contracts = {}
if self.last_block:
blockNum = self.last_block
print("Resuming synchronization from block " + str(blockNum))
else:
blockNum = eth.eth_blockNumber()
print("Starting synchronization from latest block: " + str(blockNum))
'''
On INFURA, the latest block is not immediately available. Here is a workaround to allow for database sync over INFURA.
Note however that this is extremely slow, contracts should always be loaded from a local node.
'''
block = eth.eth_getBlockByNumber(blockNum)
if not block:
blockNum -= 2
while(blockNum > 0): for blockNum in range(startblock, startblock + BLOCKS_PER_THREAD):
block = self.eth.eth_getBlockByNumber(blockNum)
if not blockNum % 1000:
print("Processing block " + str(blockNum) + ", " + str(len(self.contracts.keys())) + " unique contracts in database")
block = eth.eth_getBlockByNumber(blockNum)
for tx in block['transactions']: for tx in block['transactions']:
if not tx['to']: if not tx['to']:
receipt = eth.eth_getTransactionReceipt(tx['hash']) receipt = self.eth.eth_getTransactionReceipt(tx['hash'])
if receipt is not None: if receipt is not None:
contract_address = receipt['contractAddress'] contract_address = receipt['contractAddress']
contract_code = eth.eth_getCode(contract_address) contract_code = self.eth.eth_getCode(contract_address)
contract_balance = eth.eth_getBalance(contract_address) contract_balance = self.eth.eth_getBalance(contract_address)
if not contract_balance and not sync_all: if not contract_balance:
# skip contracts with zero balance (disable with --sync-all)
continue continue
code = ETHContract(contract_code, tx['input']) ethcontract = ETHContract(contract_code, tx['input'])
m = hashlib.md5() m = hashlib.md5()
m.update(contract_code.encode('UTF-8')) m.update(contract_code.encode('UTF-8'))
contract_hash = m.digest() contract_hash = m.digest()
contracts[contract_hash] = {'ethcontract': ethcontract, 'address': contract_address, 'balance': contract_balance}
blockNum -= 1
return contracts
def initialize(self, eth):
self.eth = eth
if self.last_block:
blockNum = self.last_block
print("Resuming synchronization from block " + str(blockNum))
else:
blockNum = eth.eth_blockNumber()
print("Starting synchronization from latest block: " + str(blockNum))
processed = 0
while (blockNum > 0):
numbers = []
for i in range(1, NUM_THREADS + 1):
numbers.append(blockNum - (i * BLOCKS_PER_THREAD))
pool = Pool(NUM_THREADS)
results = pool.map(self.sync_blocks, numbers)
pool.close()
pool.join()
for result in results:
for (contract_hash, data) in result.items():
try: try:
self.contracts[contract_hash] self.contracts[contract_hash]
except KeyError: except KeyError:
self.contracts[contract_hash] = code self.contracts[contract_hash] = data['ethcontract']
m = InstanceList() m = InstanceList()
self.instance_lists[contract_hash] = m self.instance_lists[contract_hash] = m
self.instance_lists[contract_hash].add(contract_address, contract_balance) self.instance_lists[contract_hash].add(data['address'], data['balance'])
transaction.commit()
blockNum -= NUM_THREADS * BLOCKS_PER_THREAD
processed += NUM_THREADS * BLOCKS_PER_THREAD
self.last_block = blockNum self.last_block = blockNum
blockNum -= 1 transaction.commit()
print("%d blocks processed, %d unique contracts in database, next block: %d" % (processed, len(self.contracts), blockNum))
# If we've finished initializing the database, start over from the end of the chain if we want to initialize again # If we've finished initializing the database, start over from the end of the chain if we want to initialize again
self.last_block = 0 self.last_block = 0
transaction.commit()
def search(self, expression, callback_func): def search(self, expression, callback_func):

@ -25,13 +25,12 @@ class EthIpc(BaseClient):
if ipc_path is None: if ipc_path is None:
ipc_path = get_default_ipc_path(testnet) ipc_path = get_default_ipc_path(testnet)
self.ipc_path = ipc_path self.ipc_path = ipc_path
self._socket = self.get_socket()
def get_socket(self): def get_socket(self):
_socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) _socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
_socket.connect(self.ipc_path) _socket.connect(self.ipc_path)
# Tell the socket not to block on reads. # Tell the socket not to block on reads.
_socket.settimeout(2) _socket.settimeout(0.2)
return _socket return _socket
def _call(self, method, params=None, _id=1): def _call(self, method, params=None, _id=1):
@ -43,22 +42,25 @@ class EthIpc(BaseClient):
'id': _id, 'id': _id,
} }
request = to_bytes(json.dumps(data)) request = to_bytes(json.dumps(data))
_socket = self.get_socket()
for _ in range(3): for _ in range(3):
self._socket.sendall(request) _socket.sendall(request)
response_raw = "" response_raw = ""
while True: while True:
try: try:
response_raw += to_text(self._socket.recv(4096)) response_raw += to_text(_socket.recv(4096))
except socket.timeout: except socket.timeout:
break break
if response_raw == "": if response_raw == "":
self._socket.close() _socket.close()
self._socket = self.get_socket() _socket = self.get_socket()
continue continue
_socket.close()
break break
else: else:
raise ValueError("No JSON returned by socket") raise ValueError("No JSON returned by socket")

@ -16,6 +16,8 @@ JSON_MEDIA_TYPE = 'application/json'
''' '''
This code is adapted from: https://github.com/ConsenSys/ethjsonrpc This code is adapted from: https://github.com/ConsenSys/ethjsonrpc
''' '''
class EthJsonRpc(BaseClient): class EthJsonRpc(BaseClient):
''' '''
Ethereum JSON-RPC client class Ethereum JSON-RPC client class

Loading…
Cancel
Save