diff --git a/README.md b/README.md index 09c4fe55..addccfe9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Mythril +# Mythril [![Tweet](https://img.shields.io/twitter/url/http/shields.io.svg?style=social)](https://twitter.com/intent/tweet?text=Mythril%20-%20Security%20Analyzer%20for%20Ethereum%20Smart%20Contracts&url=https://www.github.com/ConsenSys/mythril) [![PyPI](https://badge.fury.io/py/mythril.svg)](https://pypi.python.org/pypi/mythril) [![Join the chat at https://gitter.im/ConsenSys/mythril](https://badges.gitter.im/ConsenSys/mythril.svg)](https://gitter.im/ConsenSys/mythril?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) ![Master Build Status](https://img.shields.io/circleci/project/github/ConsenSys/mythril/master.svg) diff --git a/mythril/analysis/modules/external_calls.py b/mythril/analysis/modules/external_calls.py index 5a2c60b1..884e9c7c 100644 --- a/mythril/analysis/modules/external_calls.py +++ b/mythril/analysis/modules/external_calls.py @@ -21,11 +21,11 @@ def search_children(statespace, node, start_index=0, depth=0, results=[]): if(depth < MAX_SEARCH_DEPTH): - nStates = len(node.states) + n_states = len(node.states) - if nStates > start_index: + if n_states > start_index: - for j in range(start_index, nStates): + for j in range(start_index, n_states): if node.states[j].get_current_instruction()['opcode'] == 'SSTORE': results.append(node.states[j].get_current_instruction()['address']) diff --git a/mythril/analysis/modules/unchecked_retval.py b/mythril/analysis/modules/unchecked_retval.py index 7ed9858d..0c6f6b09 100644 --- a/mythril/analysis/modules/unchecked_retval.py +++ b/mythril/analysis/modules/unchecked_retval.py @@ -57,9 +57,9 @@ def execute(statespace): else: - nStates = len(node.states) + n_states = len(node.states) - for idx in range(0, nStates - 1): # Ignore CALLs at last position in a node + for idx in range(0, n_states - 1): # Ignore CALLs at last position in a node state = node.states[idx] instr = state.get_current_instruction() diff --git a/mythril/ether/evm.py b/mythril/ether/evm.py index 29376e0d..449fcdcf 100644 --- a/mythril/ether/evm.py +++ b/mythril/ether/evm.py @@ -9,15 +9,15 @@ import re def trace(code, calldata = ""): - logHandlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage'] + log_handlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage'] output = StringIO() - streamHandler = StreamHandler(output) + stream_handler = StreamHandler(output) - for handler in logHandlers: + for handler in log_handlers: log_vm_op = get_logger(handler) log_vm_op.setLevel("TRACE") - log_vm_op.addHandler(streamHandler) + log_vm_op.addHandler(stream_handler) addr = bytes.fromhex('0123456789ABCDEF0123456789ABCDEF01234567') @@ -29,7 +29,7 @@ def trace(code, calldata = ""): res, gas, dat = vm.vm_execute(ext, message, util.safe_decode(code)) - streamHandler.flush() + stream_handler.flush() ret = output.getvalue() diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index ae7ef8fe..224bc58c 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -516,9 +516,44 @@ class Instruction: @instruction def codecopy_(self, global_state): - # FIXME: not implemented - state = global_state.mstate - start, s1, size = state.stack.pop(), state.stack.pop(), state.stack.pop() + memory_offset, code_offset, size = global_state.mstate.stack.pop(), global_state.mstate.stack.pop(), global_state.mstate.stack.pop() + + try: + concrete_memory_offset = helper.get_concrete_int(memory_offset) + except AttributeError: + logging.debug("Unsupported symbolic memory offset in CODECOPY") + return [global_state] + + try: + concrete_size = helper.get_concrete_int(size) + global_state.mstate.mem_extend(concrete_memory_offset, concrete_size) + except: + # except both attribute error and Exception + global_state.mstate.mem_extend(concrete_memory_offset, 1) + global_state.mstate.memory[concrete_memory_offset] = \ + BitVec("code({})".format(global_state.environment.active_account.contract_name), 256) + return [global_state] + + try: + concrete_code_offset = helper.get_concrete_int(code_offset) + except AttributeError: + logging.debug("Unsupported symbolic code offset in CODECOPY") + global_state.mstate.mem_extend(concrete_memory_offset, concrete_size) + for i in range(concrete_size): + global_state.mstate.memory[concrete_memory_offset + i] = \ + BitVec("code({})".format(global_state.environment.active_account.contract_name), 256) + return [global_state] + + bytecode = global_state.environment.active_account.code.bytecode + + for i in range(concrete_size): + try: + global_state.mstate.memory[concrete_memory_offset + i] =\ + int(bytecode[2*(concrete_code_offset + i): 2*(concrete_code_offset + i + 1)], 16) + except IndexError: + global_state.mstate.memory[concrete_memory_offset + i] = \ + BitVec("code({})".format(global_state.environment.active_account.contract_name), 256) + return [global_state] @instruction @@ -767,25 +802,25 @@ class Instruction: instr = disassembly.instruction_list[index] # True case - condi = condition if type(condition) == BoolRef else condition != 0 + condi = simplify(condition) if type(condition) == BoolRef else condition != 0 if instr['opcode'] == "JUMPDEST": - if (type(condi) == bool and condi) or (type(condi) == BoolRef and not is_false(simplify(condi))): + if (type(condi) == bool and condi) or (type(condi) == BoolRef and not is_false(condi)): new_state = copy(global_state) new_state.mstate.pc = index new_state.mstate.depth += 1 - new_state.mstate.constraints.append(simplify(condi)) + new_state.mstate.constraints.append(condi) states.append(new_state) else: logging.debug("Pruned unreachable states.") # False case - negated = Not(condition) if type(condition) == BoolRef else condition == 0 + negated = simplify(Not(condition)) if type(condition) == BoolRef else condition == 0 - if (type(negated) == bool and negated) or (type(negated) == BoolRef and not is_false(simplify(negated))): + if (type(negated) == bool and negated) or (type(negated) == BoolRef and not is_false(negated)): new_state = copy(global_state) new_state.mstate.depth += 1 - new_state.mstate.constraints.append(simplify(negated)) + new_state.mstate.constraints.append(negated) states.append(new_state) else: logging.debug("Pruned unreachable states.") diff --git a/mythril/laser/ethereum/state.py b/mythril/laser/ethereum/state.py index 032cf66e..5876977d 100644 --- a/mythril/laser/ethereum/state.py +++ b/mythril/laser/ethereum/state.py @@ -72,6 +72,7 @@ class Environment: self.active_function_name = "" self.address = BitVecVal(int(active_account.address, 16), 256) + self.code = active_account.code self.sender = sender @@ -81,6 +82,7 @@ class Environment: self.origin = origin self.callvalue = callvalue + def __str__(self): return str(self.as_dict) diff --git a/mythril/laser/ethereum/transaction.py b/mythril/laser/ethereum/transaction.py index ec927c16..32189a08 100644 --- a/mythril/laser/ethereum/transaction.py +++ b/mythril/laser/ethereum/transaction.py @@ -50,7 +50,7 @@ class MessageCall: evm.edges.append(Edge(open_world_state.node.uid, new_node.uid, edge_type=JumpType.Transaction, condition=None)) global_state = GlobalState(open_world_state.accounts, environment, new_node) - global_state.environment.active_function_name = 'fallback()' + global_state.environment.active_function_name = 'fallback' new_node.states.append(global_state) evm.work_list.append(global_state) diff --git a/mythril/leveldb/client.py b/mythril/leveldb/client.py index 231fe98f..3507264d 100644 --- a/mythril/leveldb/client.py +++ b/mythril/leveldb/client.py @@ -11,21 +11,21 @@ from mythril.ether.ethcontract import ETHContract # Per https://github.com/ethereum/go-ethereum/blob/master/core/database_util.go # prefixes and suffixes for keys in geth -headerPrefix = b'h' # headerPrefix + num (uint64 big endian) + hash -> header -bodyPrefix = b'b' # bodyPrefix + num (uint64 big endian) + hash -> block body -numSuffix = b'n' # headerPrefix + num (uint64 big endian) + numSuffix -> hash -blockHashPrefix = b'H' # blockHashPrefix + hash -> num (uint64 big endian) -blockReceiptsPrefix = b'r' # blockReceiptsPrefix + num (uint64 big endian) + hash -> block receipts +header_prefix = b'h' # header_prefix + num (uint64 big endian) + hash -> header +body_prefix = b'b' # body_prefix + num (uint64 big endian) + hash -> block body +num_suffix = b'n' # header_prefix + num (uint64 big endian) + num_suffix -> hash +block_hash_prefix = b'H' # block_hash_prefix + hash -> num (uint64 big endian) +block_receipts_prefix = b'r' # block_receipts_prefix + num (uint64 big endian) + hash -> block receipts # known geth keys -headHeaderKey = b'LastBlock' # head (latest) header hash +head_header_key = b'LastBlock' # head (latest) header hash # custom prefixes -addressPrefix = b'AM' # addressPrefix + hash -> address +address_prefix = b'AM' # address_prefix + hash -> address # custom keys -addressMappingHeadKey = b'accountMapping' # head (latest) number of indexed block -headHeaderKey = b'LastBlock' # head (latest) header hash +address_mapping_head_key = b'accountMapping' # head (latest) number of indexed block +head_header_key = b'LastBlock' # head (latest) header hash -def _formatBlockNumber(number): +def _format_block_number(number): ''' formats block number to uint64 big endian ''' @@ -46,87 +46,87 @@ class LevelDBReader(object): def __init__(self, db): self.db = db - self.headBlockHeader = None - self.headState = None + self.head_block_header = None + self.head_state = None def _get_head_state(self): ''' gets head state ''' - if not self.headState: + if not self.head_state: root = self._get_head_block().state_root - self.headState = State(self.db, root) - return self.headState + self.head_state = State(self.db, root) + return self.head_state def _get_account(self, address): ''' gets account by address ''' state = self._get_head_state() - accountAddress = binascii.a2b_hex(utils.remove_0x_head(address)) - return state.get_and_cache_account(accountAddress) + account_address = binascii.a2b_hex(utils.remove_0x_head(address)) + return state.get_and_cache_account(account_address) def _get_block_hash(self, number): ''' gets block hash by block number ''' - num = _formatBlockNumber(number) - hashKey = headerPrefix + num + numSuffix - return self.db.get(hashKey) + num = _format_block_number(number) + hash_key = header_prefix + num + num_suffix + return self.db.get(hash_key) def _get_head_block(self): ''' gets head block header ''' - if not self.headBlockHeader: - hash = self.db.get(headHeaderKey) + if not self.head_block_header: + hash = self.db.get(head_header_key) num = self._get_block_number(hash) - self.headBlockHeader = self._get_block_header(hash, num) + self.head_block_header = self._get_block_header(hash, num) # find header with valid state - while not self.db.get(self.headBlockHeader.state_root) and self.headBlockHeader.prevhash is not None: - hash = self.headBlockHeader.prevhash + while not self.db.get(self.head_block_header.state_root) and self.head_block_header.prevhash is not None: + hash = self.head_block_header.prevhash num = self._get_block_number(hash) - self.headBlockHeader = self._get_block_header(hash, num) + self.head_block_header = self._get_block_header(hash, num) - return self.headBlockHeader + return self.head_block_header def _get_block_number(self, hash): ''' gets block number by hash ''' - numberKey = blockHashPrefix + hash - return self.db.get(numberKey) + number_key = block_hash_prefix + hash + return self.db.get(number_key) def _get_block_header(self, hash, num): ''' get block header by block header hash & number ''' - headerKey = headerPrefix + num + hash - blockHeaderData = self.db.get(headerKey) - header = rlp.decode(blockHeaderData, sedes=BlockHeader) + header_key = header_prefix + num + hash + block_header_data = self.db.get(header_key) + header = rlp.decode(block_header_data, sedes=BlockHeader) return header def _get_address_by_hash(self, hash): ''' get mapped address by its hash ''' - addressKey = addressPrefix + hash - return self.db.get(addressKey) + address_key = address_prefix + hash + return self.db.get(address_key) def _get_last_indexed_number(self): ''' latest indexed block number ''' - return self.db.get(addressMappingHeadKey) + return self.db.get(address_mapping_head_key) def _get_block_receipts(self, hash, num): ''' get block transaction receipts by block header hash & number ''' - number = _formatBlockNumber(num) - receiptsKey = blockReceiptsPrefix + number + hash - receiptsData = self.db.get(receiptsKey) - receipts = rlp.decode(receiptsData, sedes=CountableList(ReceiptForStorage)) + number = _format_block_number(num) + receipts_key = block_receipts_prefix + number + hash + receipts_data = self.db.get(receipts_key) + receipts = rlp.decode(receipts_data, sedes=CountableList(ReceiptForStorage)) return receipts @@ -143,7 +143,7 @@ class LevelDBWriter(object): ''' sets latest indexed block number ''' - return self.db.put(addressMappingHeadKey, _formatBlockNumber(number)) + return self.db.put(address_mapping_head_key, _format_block_number(number)) def _start_writing(self): ''' @@ -161,8 +161,8 @@ class LevelDBWriter(object): ''' get block transaction receipts by block header hash & number ''' - addressKey = addressPrefix + utils.sha3(address) - self.wb.put(addressKey, address) + address_key = address_prefix + utils.sha3(address) + self.wb.put(address_key, address) class EthLevelDB(object): @@ -211,8 +211,8 @@ class EthLevelDB(object): tries to find corresponding account address ''' indexer = AccountIndexer(self) - addressHash = binascii.a2b_hex(utils.remove_0x_head(hash)) - address = indexer.get_contract_by_hash(addressHash) + address_hash = binascii.a2b_hex(utils.remove_0x_head(hash)) + address = indexer.get_contract_by_hash(address_hash) if address: return _encode_hex(address) else: @@ -223,18 +223,18 @@ class EthLevelDB(object): gets block header by block number ''' hash = self.reader._get_block_hash(number) - blockNumber = _formatBlockNumber(number) - return self.reader._get_block_header(hash, blockNumber) + block_number = _format_block_number(number) + return self.reader._get_block_header(hash, block_number) def eth_getBlockByNumber(self, number): ''' gets block body by block number ''' - blockHash = self.reader._get_block_hash(number) - blockNumber = _formatBlockNumber(number) - bodyKey = bodyPrefix + blockNumber + blockHash - blockData = self.db.get(bodyKey) - body = rlp.decode(blockData, sedes=Block) + block_hash = self.reader._get_block_hash(number) + block_number = _format_block_number(number) + body_key = body_prefix + block_number + block_hash + block_data = self.db.get(body_key) + body = rlp.decode(block_data, sedes=Block) return body def eth_getCode(self, address): diff --git a/mythril/leveldb/state.py b/mythril/leveldb/state.py index fbf17cf0..96360300 100644 --- a/mythril/leveldb/state.py +++ b/mythril/leveldb/state.py @@ -96,7 +96,7 @@ class State(): def __init__(self, db, root): self.db = db self.trie = Trie(self.db, root) - self.secureTrie = SecureTrie(self.trie) + self.secure_trie = SecureTrie(self.trie) self.journal = [] self.cache = {} @@ -106,7 +106,7 @@ class State(): ''' if address in self.cache: return self.cache[address] - rlpdata = self.secureTrie.get(address) + rlpdata = self.secure_trie.get(address) if rlpdata == trie.BLANK_NODE and len(address) == 32: # support for hashed addresses rlpdata = self.trie.get(address) if rlpdata != trie.BLANK_NODE: @@ -123,6 +123,6 @@ class State(): ''' iterates through trie to and yields non-blank leafs as accounts ''' - for addressHash, rlpdata in self.secureTrie.trie.iter_branch(): + for address_hash, rlpdata in self.secure_trie.trie.iter_branch(): if rlpdata != trie.BLANK_NODE: - yield rlp.decode(rlpdata, Account, db=self.db, address=addressHash) \ No newline at end of file + yield rlp.decode(rlpdata, Account, db=self.db, address=address_hash) \ No newline at end of file diff --git a/mythril/mythril.py b/mythril/mythril.py index e687e480..9b7b69b5 100644 --- a/mythril/mythril.py +++ b/mythril/mythril.py @@ -101,7 +101,7 @@ class Mythril(object): self.leveldb_dir = self._init_config() self.eth = None # ethereum API client - self.ethDb = None # ethereum LevelDB client + self.eth_db = None # ethereum LevelDB client self.contracts = [] # loaded contracts @@ -213,8 +213,8 @@ class Mythril(object): return solc_binary def set_api_leveldb(self, leveldb): - self.ethDb = EthLevelDB(leveldb) - self.eth = self.ethDb + self.eth_db = EthLevelDB(leveldb) + self.eth = self.eth_db return self.eth def set_api_rpc_infura(self): @@ -277,7 +277,7 @@ class Mythril(object): print("Address: " + addresses[i] + ", balance: " + str(balances[i])) try: - self.ethDb.search(search, search_callback) + self.eth_db.search(search, search_callback) except SyntaxError: raise CriticalError("Syntax error in search expression.") @@ -286,7 +286,7 @@ class Mythril(object): if not re.match(r'0x[a-fA-F0-9]{64}', hash): raise CriticalError("Invalid address hash. Expected format is '0x...'.") - print(self.ethDb.contract_hash_to_address(hash)) + print(self.eth_db.contract_hash_to_address(hash)) def load_from_bytecode(self, code): address = util.get_indexed_address(0) diff --git a/mythril/support/truffle.py b/mythril/support/truffle.py index d59879ab..ad4d4151 100644 --- a/mythril/support/truffle.py +++ b/mythril/support/truffle.py @@ -60,11 +60,11 @@ def analyze_truffle_project(args): disassembly = ethcontract.disassembly source = contractdata['source'] - deployedSourceMap = contractdata['deployedSourceMap'].split(";") + deployed_source_map = contractdata['deployedSourceMap'].split(";") mappings = [] - for item in deployedSourceMap: + for item in deployed_source_map: mapping = item.split(":") if len(mapping) > 0 and len(mapping[0]) > 0: diff --git a/tests/instructions/__init__.py b/tests/instructions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/instructions/codecopy_test.py b/tests/instructions/codecopy_test.py new file mode 100644 index 00000000..53ae049d --- /dev/null +++ b/tests/instructions/codecopy_test.py @@ -0,0 +1,20 @@ +from mythril.disassembler.disassembly import Disassembly +from mythril.laser.ethereum.state import MachineState, GlobalState, Environment, Account +from mythril.laser.ethereum.instructions import Instruction + + +def test_codecopy_concrete(): + # Arrange + active_account = Account("0x0", code= Disassembly("60606040")) + environment = Environment(active_account, None, None, None, None, None) + og_state = GlobalState(None, environment, None, MachineState(gas=10000000)) + + og_state.mstate.stack = [2, 2, 2] + instruction = Instruction("codecopy", dynamic_loader=None) + + # Act + new_state = instruction.evaluate(og_state)[0] + + # Assert + assert new_state.mstate.memory[2] == 96 + assert new_state.mstate.memory[3] == 64 diff --git a/tests/report_test.py b/tests/report_test.py index 7f1da181..a5c41f21 100644 --- a/tests/report_test.py +++ b/tests/report_test.py @@ -38,7 +38,7 @@ def _generate_report(input_file): def reports(): """ Fixture that analyses all reports""" pool = Pool(cpu_count()) - input_files = [f for f in TESTDATA_INPUTS.iterdir()] + input_files = sorted([f for f in TESTDATA_INPUTS.iterdir()]) results = pool.map(_generate_report, input_files) return results