diff --git a/mythril/support/truffle.py b/mythril/support/truffle.py index e1857f05..e37b2a50 100644 --- a/mythril/support/truffle.py +++ b/mythril/support/truffle.py @@ -1,5 +1,6 @@ import os import re +import sys import json from mythril.ether import util from mythril.ether.ethcontract import ETHContract @@ -14,67 +15,73 @@ def analyze_truffle_project(): build_dir = os.path.join(project_root, "build", "contracts") - contract_files = os.listdir(build_dir) + files = os.listdir(build_dir) - for contract_file in contract_files: + for filename in files: - with open(os.path.join(build_dir, contract_file)) as cf: - contractdata = json.load(cf) + if re.match(r'.*\.json$', filename): - name = contractdata['contractName'] - bytecode = contractdata['deployedBytecode'] + with open(os.path.join(build_dir, filename)) as cf: + contractdata = json.load(cf) - if (len(bytecode) < 4): - continue + try: + name = contractdata['contractName'] + bytecode = contractdata['deployedBytecode'] + except: + print("Unable to parse contract data. Please use Truffle 4 to compile your project.") + sys.exit() - ethcontract= ETHContract(bytecode, name=name, address = util.get_indexed_address(0)) - contracts = [ethcontract] + if (len(bytecode) < 4): + continue - states = StateSpace(contracts, max_depth = 10) - report = fire_lasers(states) + ethcontract= ETHContract(bytecode, name=name, address = util.get_indexed_address(0)) - # augment with source code + contracts = [ethcontract] - disassembly = ethcontract.get_disassembly() - source = contractdata['source'] + states = StateSpace(contracts, max_depth = 10) + report = fire_lasers(states) - deployedSourceMap = contractdata['deployedSourceMap'].split(";") + # augment with source code - mappings = [] - i = 0 + disassembly = ethcontract.get_disassembly() + source = contractdata['source'] - while(i < len(deployedSourceMap)): + deployedSourceMap = contractdata['deployedSourceMap'].split(";") - m = re.search(r"^(\d+):*(\d+)", deployedSourceMap[i]) + mappings = [] + i = 0 - if (m): - offset = m.group(1) - length = m.group(2) - else: - m = re.search(r"^:(\d+)", deployedSourceMap[i]) + while(i < len(deployedSourceMap)): + + m = re.search(r"^(\d+):*(\d+)", deployedSourceMap[i]) - if m: - length = m.group(1) + if (m): + offset = m.group(1) + length = m.group(2) + else: + m = re.search(r"^:(\d+)", deployedSourceMap[i]) - mappings.append((int(offset), int(length))) + if m: + length = m.group(1) - i += 1 + mappings.append((int(offset), int(length))) - for key, issue in report.issues.items(): + i += 1 - index = helper.get_instruction_index(disassembly.instruction_list, issue.pc) + for key, issue in report.issues.items(): - if index: - issue.code_start = mappings[index][0] - issue.code_length = mappings[index][1] - issue.code = source[mappings[index][0]: mappings[index][0] + mappings[index][1]] + index = helper.get_instruction_index(disassembly.instruction_list, issue.pc) - + if index: + issue.code_start = mappings[index][0] + issue.code_length = mappings[index][1] + issue.code = source[mappings[index][0]: mappings[index][0] + mappings[index][1]] - if len(report.issues): - print("Analysis result for " + name + ":\n" + report.as_text()) - else: - print("Analysis result for " + name + ": No issues found.") + + if len(report.issues): + print("Analysis result for " + name + ":\n" + report.as_text()) + else: + print("Analysis result for " + name + ": No issues found.")