From c44923c75a16e1f8a017974b589b2a2f9106bec6 Mon Sep 17 00:00:00 2001 From: Bernhard Mueller Date: Fri, 29 Sep 2017 15:54:38 +0700 Subject: [PATCH] Update backend database to ZODB --- .gitignore | 2 +- contractstorage.py | 72 ++++++++++++++-------------- database/leveldb.py | 112 ++++++++++++++++++++++++++++++++++++++++++++ ethcontract.py | 26 ++++++++-- mythril | 20 +++++++- 5 files changed, 190 insertions(+), 42 deletions(-) create mode 100644 database/leveldb.py diff --git a/.gitignore b/.gitignore index cbef371d..68d89dde 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,4 @@ build dist contracts.json hunt* -utils +*.fs* diff --git a/contractstorage.py b/contractstorage.py index 0e9a90ee..74fedec2 100644 --- a/contractstorage.py +++ b/contractstorage.py @@ -1,23 +1,36 @@ from rpc.client import EthJsonRpc -from ethcontract import ETHContract +from ethcontract import ETHCode, AddressesByCodeHash, CodeHashByAddress +from ether import util from ethereum import utils -from tinydb import TinyDB, Query import codecs import hashlib import re +import ZODB +import persistent +import persistent.list +import transaction +from BTrees.OOBTree import BTree -class ContractStorage: +class ContractStorage(persistent.Persistent): def __init__(self): - self.db = TinyDB('./contracts.json') - + self.contracts = BTree() + self.address_to_hash_map = BTree() + self.hash_to_addresses_map = BTree() + self.last_block = 0 def initialize(self, rpchost, rpcport): eth = EthJsonRpc(rpchost, rpcport) - blockNum = eth.eth_blockNumber() + if self.last_block: + blockNum = self.last_block + print("Resuming synchronization from block " + str(blockNum)) + else: + + blockNum = eth.eth_blockNumber() + print("Starting synchronization from latest block: " + str(blockNum)) while(blockNum > 0): @@ -35,51 +48,42 @@ class ContractStorage: contract_address = receipt['contractAddress'] contract_code = eth.eth_getCode(contract_address) - - m = hashlib.md5() - - m.update(contract_code.encode('UTF-8')) - - 
contract_hash = codecs.encode(m.digest(), 'hex_codec') - contract_id = contract_hash.decode("utf-8") - contract_balance = eth.eth_getBalance(contract_address) - Contract = Query() + code = ETHCode(contract_code) - new_instance = {'address': contract_address, 'balance': contract_balance} - - s = self.db.search(Contract.id == contract_id) - - if not len(s): - - self.db.insert({'id': contract_id, 'code': contract_code, 'instances': [new_instance]}) + m = hashlib.md5() + m.update(contract_code.encode('UTF-8')) + contract_hash = m.digest() - else: + try: + self.contracts[contract_hash] + except KeyError: + self.contracts[contract_hash] = code - instances = s[0]['instances'] + m = CodeHashByAddress(contract_hash, contract_balance) + self.address_to_hash_map[contract_address] = m - instances.append(new_instance) + m = AddressesByCodeHash(contract_address, contract_balance) + self.hash_to_addresses_map[contract_hash] = m - self.db.update({'instances': instances}, Contract.id == contract_id) + transaction.commit() + self.last_block = blockNum blockNum -= 1 - def get_contract_code_by_address(self, address): - Contract = Query() - Instance = Query() + def get_all(self): + return self.contracts - ret = self.db.search(Contract.instances.any(Instance.address == address)) + def get_contract_code_by_address(self, address): - return ret[0]['code'] + pass def search(self, expression, callback_func): - all_contracts = self.db.all() - matches = re.findall(r'func\[([a-zA-Z0-9\s,()]+)\]', expression) for m in matches: @@ -89,7 +93,7 @@ class ContractStorage: expression = expression.replace(m, sign_hash) - for c in all_contracts: + for c in self.contracts: for instance in c['instances']: diff --git a/database/leveldb.py b/database/leveldb.py new file mode 100644 index 00000000..0912aa6d --- /dev/null +++ b/database/leveldb.py @@ -0,0 +1,112 @@ +from ethereum.db import BaseDB +import leveldb +from ethereum import slogging + +slogging.set_level('db', 'debug') +log = 
slogging.get_logger('db') + +compress = decompress = lambda x: x + + +class LevelDB(BaseDB): + """ + filename the database directory + block_cache_size (default: 8 * (2 << 20)) maximum allowed size for the block cache in bytes + write_buffer_size (default 2 * (2 << 20)) + block_size (default: 4096) unit of transfer for the block cache in bytes + max_open_files: (default: 1000) + create_if_missing (default: True) if True, creates a new database if none exists + error_if_exists (default: False) if True, raises an error if the database exists + paranoid_checks (default: False) if True, raises an error as soon as an internal + corruption is detected + """ + + max_open_files = 32000 + block_cache_size = 8 * 1024**2 + write_buffer_size = 4 * 1024**2 + + def __init__(self, dbfile): + self.uncommitted = dict() + log.info('opening LevelDB', + path=dbfile, + block_cache_size=self.block_cache_size, + write_buffer_size=self.write_buffer_size, + max_open_files=self.max_open_files) + self.dbfile = dbfile + self.db = leveldb.LevelDB(dbfile, max_open_files=self.max_open_files) + self.commit_counter = 0 + + def reopen(self): + del self.db + self.db = leveldb.LevelDB(self.dbfile) + + def get(self, key): + log.trace('getting entry', key=key.encode('hex')[:8]) + if key in self.uncommitted: + if self.uncommitted[key] is None: + raise KeyError("key not in db") + log.trace('from uncommitted') + return self.uncommitted[key] + log.trace('from db') + o = decompress(self.db.Get(key)) + self.uncommitted[key] = o + return o + + def put(self, key, value): + log.trace('putting entry', key=key.encode('hex')[:8], len=len(value)) + self.uncommitted[key] = value + + def commit(self): + log.debug('committing', db=self) + batch = leveldb.WriteBatch() + for k, v in self.uncommitted.items(): + if v is None: + batch.Delete(k) + else: + batch.Put(k, compress(v)) + self.db.Write(batch, sync=False) + self.uncommitted.clear() + log.debug('committed', db=self, num=len(self.uncommitted)) + # 
self.commit_counter += 1 + # if self.commit_counter % 100 == 0: + # self.reopen() + + def delete(self, key): + log.trace('deleting entry', key=key) + self.uncommitted[key] = None + + def _has_key(self, key): + try: + self.get(key) + return True + except KeyError: + return False + + def __contains__(self, key): + return self._has_key(key) + + def __eq__(self, other): + return isinstance(other, self.__class__) and self.db == other.db + + def __repr__(self): + return '<DB at %d uncommitted=%d>' % (id(self.db), len(self.uncommitted)) + + def inc_refcount(self, key, value): + self.put(key, value) + + def dec_refcount(self, key): + pass + + def revert_refcount_changes(self, epoch): + pass + + def commit_refcount_changes(self, epoch): + pass + + def cleanup(self, epoch): + pass + + def put_temporarily(self, key, value): + self.inc_refcount(key, value) + self.dec_refcount(key) + diff --git a/ethcontract.py b/ethcontract.py index 417af570..6dc95c9e 100644 --- a/ethcontract.py +++ b/ethcontract.py @@ -1,15 +1,14 @@ from ether import asm, util import re +import persistent -class ETHContract: - def __init__(self, code = "", balance = 0): +class ETHCode(persistent.Persistent): - self.disassembly = asm.disassemble(util.safe_decode(code)) - self.easm_code = asm.disassembly_to_easm(self.disassembly) - self.balance = balance + def __init__(self, code = ""): + self.disassembly = asm.disassemble(util.safe_decode(code)) def matches_expression(self, expression): @@ -45,3 +44,20 @@ class ETHContract: return eval(str_eval) + +class CodeHashByAddress(persistent.Persistent): + + def __init__(self, code_hash, balance = 0): + self.code_hash = code_hash + self.balance = balance + +class AddressesByCodeHash(persistent.Persistent): + + def __init__(self, address, balance = 0): + self.addresses = [address] + self.balances = [balance] + + def add(self, address, balance = 0): + self.addresses.append(address) + self.balances.append(balance) + self._p_changed = True diff --git a/mythril b/mythril index 731dc33c..17a271d9 
100755 --- a/mythril +++ b/mythril @@ -10,6 +10,9 @@ from contractstorage import ContractStorage import sys import argparse from rpc.client import EthJsonRpc +import ZODB +from ZODB import FileStorage +import os def searchCallback(address): @@ -35,8 +38,21 @@ parser.add_argument('--rpchost', default='127.0.0.1', help='RPC host') parser.add_argument('--rpcport', type=int, default=8545, help='RPC port') -storage = ContractStorage() +app_root = os.path.dirname(os.path.realpath(__file__)) +db_path = os.path.join(app_root, "database", "contractstorage.fs") +storage = FileStorage.FileStorage(db_path) +db = ZODB.DB(storage) +connection = db.open() +storage_root = connection.root() + +try: + contract_storage = storage_root['contractStorage'] +except KeyError: + contract_storage = ContractStorage() + storage_root['contractStorage'] = contract_storage + +print(len(contract_storage.get_all())) args = parser.parse_args() @@ -104,7 +120,7 @@ elif (args.search): storage.search(args.search, searchCallback) elif (args.init_db): - storage.initialize(args.rpchost, args.rpcport) + contract_storage.initialize(args.rpchost, args.rpcport) elif (args.hash): print(utils.sha3(args.hash)[:4].hex())