Remove capital letters from local variables

pull/388/head
Robert Valta 6 years ago
parent b6beed98d4
commit 51207a6cfc
  1. 6
      mythril/analysis/modules/external_calls.py
  2. 4
      mythril/analysis/modules/unchecked_retval.py
  3. 10
      mythril/ether/evm.py
  4. 102
      mythril/leveldb/client.py
  5. 8
      mythril/leveldb/state.py
  6. 10
      mythril/mythril.py
  7. 4
      mythril/support/truffle.py

@ -21,11 +21,11 @@ def search_children(statespace, node, start_index=0, depth=0, results=[]):
if(depth < MAX_SEARCH_DEPTH): if(depth < MAX_SEARCH_DEPTH):
nStates = len(node.states) n_states = len(node.states)
if nStates > start_index: if n_states > start_index:
for j in range(start_index, nStates): for j in range(start_index, n_states):
if node.states[j].get_current_instruction()['opcode'] == 'SSTORE': if node.states[j].get_current_instruction()['opcode'] == 'SSTORE':
results.append(node.states[j].get_current_instruction()['address']) results.append(node.states[j].get_current_instruction()['address'])

@ -57,9 +57,9 @@ def execute(statespace):
else: else:
nStates = len(node.states) n_states = len(node.states)
for idx in range(0, nStates - 1): # Ignore CALLs at last position in a node for idx in range(0, n_states - 1): # Ignore CALLs at last position in a node
state = node.states[idx] state = node.states[idx]
instr = state.get_current_instruction() instr = state.get_current_instruction()

@ -9,15 +9,15 @@ import re
def trace(code, calldata = ""): def trace(code, calldata = ""):
logHandlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage'] log_handlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage']
output = StringIO() output = StringIO()
streamHandler = StreamHandler(output) stream_handler = StreamHandler(output)
for handler in logHandlers: for handler in log_handlers:
log_vm_op = get_logger(handler) log_vm_op = get_logger(handler)
log_vm_op.setLevel("TRACE") log_vm_op.setLevel("TRACE")
log_vm_op.addHandler(streamHandler) log_vm_op.addHandler(stream_handler)
addr = bytes.fromhex('0123456789ABCDEF0123456789ABCDEF01234567') addr = bytes.fromhex('0123456789ABCDEF0123456789ABCDEF01234567')
@ -29,7 +29,7 @@ def trace(code, calldata = ""):
res, gas, dat = vm.vm_execute(ext, message, util.safe_decode(code)) res, gas, dat = vm.vm_execute(ext, message, util.safe_decode(code))
streamHandler.flush() stream_handler.flush()
ret = output.getvalue() ret = output.getvalue()

@ -11,21 +11,21 @@ from mythril.ether.ethcontract import ETHContract
# Per https://github.com/ethereum/go-ethereum/blob/master/core/database_util.go # Per https://github.com/ethereum/go-ethereum/blob/master/core/database_util.go
# prefixes and suffixes for keys in geth # prefixes and suffixes for keys in geth
headerPrefix = b'h' # headerPrefix + num (uint64 big endian) + hash -> header header_prefix = b'h' # header_prefix + num (uint64 big endian) + hash -> header
bodyPrefix = b'b' # bodyPrefix + num (uint64 big endian) + hash -> block body body_prefix = b'b' # body_prefix + num (uint64 big endian) + hash -> block body
numSuffix = b'n' # headerPrefix + num (uint64 big endian) + numSuffix -> hash num_suffix = b'n' # header_prefix + num (uint64 big endian) + num_suffix -> hash
blockHashPrefix = b'H' # blockHashPrefix + hash -> num (uint64 big endian) block_hash_prefix = b'H' # block_hash_prefix + hash -> num (uint64 big endian)
blockReceiptsPrefix = b'r' # blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts block_receipts_prefix = b'r' # block_receipts_prefix + num (uint64 big endian) + hash -> block receipts
# known geth keys # known geth keys
headHeaderKey = b'LastBlock' # head (latest) header hash head_header_key = b'last_block' # head (latest) header hash
# custom prefixes # custom prefixes
addressPrefix = b'AM' # addressPrefix + hash -> address address_prefix = b'AM' # address_prefix + hash -> address
# custom keys # custom keys
addressMappingHeadKey = b'accountMapping' # head (latest) number of indexed block address_mapping_head_key = b'account_mapping' # head (latest) number of indexed block
headHeaderKey = b'LastBlock' # head (latest) header hash head_header_key = b'last_block' # head (latest) header hash
def _formatBlockNumber(number): def _format_block_number(number):
''' '''
formats block number to uint64 big endian formats block number to uint64 big endian
''' '''
@ -46,87 +46,87 @@ class LevelDBReader(object):
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.headBlockHeader = None self.head_block_header = None
self.headState = None self.head_state = None
def _get_head_state(self): def _get_head_state(self):
''' '''
gets head state gets head state
''' '''
if not self.headState: if not self.head_state:
root = self._get_head_block().state_root root = self._get_head_block().state_root
self.headState = State(self.db, root) self.head_state = State(self.db, root)
return self.headState return self.head_state
def _get_account(self, address): def _get_account(self, address):
''' '''
gets account by address gets account by address
''' '''
state = self._get_head_state() state = self._get_head_state()
accountAddress = binascii.a2b_hex(utils.remove_0x_head(address)) account_address = binascii.a2b_hex(utils.remove_0x_head(address))
return state.get_and_cache_account(accountAddress) return state.get_and_cache_account(account_address)
def _get_block_hash(self, number): def _get_block_hash(self, number):
''' '''
gets block hash by block number gets block hash by block number
''' '''
num = _formatBlockNumber(number) num = _format_block_number(number)
hashKey = headerPrefix + num + numSuffix hash_key = header_prefix + num + num_suffix
return self.db.get(hashKey) return self.db.get(hash_key)
def _get_head_block(self): def _get_head_block(self):
''' '''
gets head block header gets head block header
''' '''
if not self.headBlockHeader: if not self.head_block_header:
hash = self.db.get(headHeaderKey) hash = self.db.get(head_header_key)
num = self._get_block_number(hash) num = self._get_block_number(hash)
self.headBlockHeader = self._get_block_header(hash, num) self.head_block_header = self._get_block_header(hash, num)
# find header with valid state # find header with valid state
while not self.db.get(self.headBlockHeader.state_root) and self.headBlockHeader.prevhash is not None: while not self.db.get(self.head_block_header.state_root) and self.head_block_header.prevhash is not None:
hash = self.headBlockHeader.prevhash hash = self.head_block_header.prevhash
num = self._get_block_number(hash) num = self._get_block_number(hash)
self.headBlockHeader = self._get_block_header(hash, num) self.head_block_header = self._get_block_header(hash, num)
return self.headBlockHeader return self.head_block_header
def _get_block_number(self, hash): def _get_block_number(self, hash):
''' '''
gets block number by hash gets block number by hash
''' '''
numberKey = blockHashPrefix + hash number_key = block_hash_prefix + hash
return self.db.get(numberKey) return self.db.get(number_key)
def _get_block_header(self, hash, num): def _get_block_header(self, hash, num):
''' '''
get block header by block header hash & number get block header by block header hash & number
''' '''
headerKey = headerPrefix + num + hash header_key = header_prefix + num + hash
blockHeaderData = self.db.get(headerKey) block_header_data = self.db.get(header_key)
header = rlp.decode(blockHeaderData, sedes=BlockHeader) header = rlp.decode(block_header_data, sedes=BlockHeader)
return header return header
def _get_address_by_hash(self, hash): def _get_address_by_hash(self, hash):
''' '''
get mapped address by its hash get mapped address by its hash
''' '''
addressKey = addressPrefix + hash address_key = address_prefix + hash
return self.db.get(addressKey) return self.db.get(address_key)
def _get_last_indexed_number(self): def _get_last_indexed_number(self):
''' '''
latest indexed block number latest indexed block number
''' '''
return self.db.get(addressMappingHeadKey) return self.db.get(address_mapping_head_key)
def _get_block_receipts(self, hash, num): def _get_block_receipts(self, hash, num):
''' '''
get block transaction receipts by block header hash & number get block transaction receipts by block header hash & number
''' '''
number = _formatBlockNumber(num) number = _format_block_number(num)
receiptsKey = blockReceiptsPrefix + number + hash receipts_key = block_receipts_prefix + number + hash
receiptsData = self.db.get(receiptsKey) receipts_data = self.db.get(receipts_key)
receipts = rlp.decode(receiptsData, sedes=CountableList(ReceiptForStorage)) receipts = rlp.decode(receipts_data, sedes=CountableList(ReceiptForStorage))
return receipts return receipts
@ -143,7 +143,7 @@ class LevelDBWriter(object):
''' '''
sets latest indexed block number sets latest indexed block number
''' '''
return self.db.put(addressMappingHeadKey, _formatBlockNumber(number)) return self.db.put(address_mapping_head_key, _format_block_number(number))
def _start_writing(self): def _start_writing(self):
''' '''
@ -161,8 +161,8 @@ class LevelDBWriter(object):
''' '''
get block transaction receipts by block header hash & number get block transaction receipts by block header hash & number
''' '''
addressKey = addressPrefix + utils.sha3(address) address_key = address_prefix + utils.sha3(address)
self.wb.put(addressKey, address) self.wb.put(address_key, address)
class EthLevelDB(object): class EthLevelDB(object):
@ -211,8 +211,8 @@ class EthLevelDB(object):
tries to find corresponding account address tries to find corresponding account address
''' '''
indexer = AccountIndexer(self) indexer = AccountIndexer(self)
addressHash = binascii.a2b_hex(utils.remove_0x_head(hash)) address_hash = binascii.a2b_hex(utils.remove_0x_head(hash))
address = indexer.get_contract_by_hash(addressHash) address = indexer.get_contract_by_hash(address_hash)
if address: if address:
return _encode_hex(address) return _encode_hex(address)
else: else:
@ -223,18 +223,18 @@ class EthLevelDB(object):
gets block header by block number gets block header by block number
''' '''
hash = self.reader._get_block_hash(number) hash = self.reader._get_block_hash(number)
blockNumber = _formatBlockNumber(number) block_number = _format_block_number(number)
return self.reader._get_block_header(hash, blockNumber) return self.reader._get_block_header(hash, block_number)
def eth_getBlockByNumber(self, number): def eth_getBlockByNumber(self, number):
''' '''
gets block body by block number gets block body by block number
''' '''
blockHash = self.reader._get_block_hash(number) block_hash = self.reader._get_block_hash(number)
blockNumber = _formatBlockNumber(number) block_number = _format_block_number(number)
bodyKey = bodyPrefix + blockNumber + blockHash body_key = body_prefix + block_number + block_hash
blockData = self.db.get(bodyKey) block_data = self.db.get(body_key)
body = rlp.decode(blockData, sedes=Block) body = rlp.decode(block_data, sedes=Block)
return body return body
def eth_getCode(self, address): def eth_getCode(self, address):

@ -96,7 +96,7 @@ class State():
def __init__(self, db, root): def __init__(self, db, root):
self.db = db self.db = db
self.trie = Trie(self.db, root) self.trie = Trie(self.db, root)
self.secureTrie = SecureTrie(self.trie) self.secure_trie = SecureTrie(self.trie)
self.journal = [] self.journal = []
self.cache = {} self.cache = {}
@ -106,7 +106,7 @@ class State():
''' '''
if address in self.cache: if address in self.cache:
return self.cache[address] return self.cache[address]
rlpdata = self.secureTrie.get(address) rlpdata = self.secure_trie.get(address)
if rlpdata == trie.BLANK_NODE and len(address) == 32: # support for hashed addresses if rlpdata == trie.BLANK_NODE and len(address) == 32: # support for hashed addresses
rlpdata = self.trie.get(address) rlpdata = self.trie.get(address)
if rlpdata != trie.BLANK_NODE: if rlpdata != trie.BLANK_NODE:
@ -123,6 +123,6 @@ class State():
''' '''
iterates through trie to and yields non-blank leafs as accounts iterates through trie to and yields non-blank leafs as accounts
''' '''
for addressHash, rlpdata in self.secureTrie.trie.iter_branch(): for address_hash, rlpdata in self.secure_trie.trie.iter_branch():
if rlpdata != trie.BLANK_NODE: if rlpdata != trie.BLANK_NODE:
yield rlp.decode(rlpdata, Account, db=self.db, address=addressHash) yield rlp.decode(rlpdata, Account, db=self.db, address=address_hash)

@ -101,7 +101,7 @@ class Mythril(object):
self.leveldb_dir = self._init_config() self.leveldb_dir = self._init_config()
self.eth = None # ethereum API client self.eth = None # ethereum API client
self.ethDb = None # ethereum LevelDB client self.eth_db = None # ethereum LevelDB client
self.contracts = [] # loaded contracts self.contracts = [] # loaded contracts
@ -213,8 +213,8 @@ class Mythril(object):
return solc_binary return solc_binary
def set_api_leveldb(self, leveldb): def set_api_leveldb(self, leveldb):
self.ethDb = EthLevelDB(leveldb) self.eth_db = EthLevelDB(leveldb)
self.eth = self.ethDb self.eth = self.eth_db
return self.eth return self.eth
def set_api_rpc_infura(self): def set_api_rpc_infura(self):
@ -277,7 +277,7 @@ class Mythril(object):
print("Address: " + addresses[i] + ", balance: " + str(balances[i])) print("Address: " + addresses[i] + ", balance: " + str(balances[i]))
try: try:
self.ethDb.search(search, search_callback) self.eth_db.search(search, search_callback)
except SyntaxError: except SyntaxError:
raise CriticalError("Syntax error in search expression.") raise CriticalError("Syntax error in search expression.")
@ -286,7 +286,7 @@ class Mythril(object):
if not re.match(r'0x[a-fA-F0-9]{64}', hash): if not re.match(r'0x[a-fA-F0-9]{64}', hash):
raise CriticalError("Invalid address hash. Expected format is '0x...'.") raise CriticalError("Invalid address hash. Expected format is '0x...'.")
print(self.ethDb.contract_hash_to_address(hash)) print(self.eth_db.contract_hash_to_address(hash))
def load_from_bytecode(self, code): def load_from_bytecode(self, code):
address = util.get_indexed_address(0) address = util.get_indexed_address(0)

@ -60,11 +60,11 @@ def analyze_truffle_project(args):
disassembly = ethcontract.disassembly disassembly = ethcontract.disassembly
source = contractdata['source'] source = contractdata['source']
deployedSourceMap = contractdata['deployedSourceMap'].split(";") deployed_source_map = contractdata['deployedSourceMap'].split(";")
mappings = [] mappings = []
for item in deployedSourceMap: for item in deployed_source_map:
mapping = item.split(":") mapping = item.split(":")
if len(mapping) > 0 and len(mapping[0]) > 0: if len(mapping) > 0 and len(mapping[0]) > 0:

Loading…
Cancel
Save