diff --git a/mythril/__init__.py b/mythril/__init__.py index c0494623..bfd74099 100644 --- a/mythril/__init__.py +++ b/mythril/__init__.py @@ -1,6 +1,6 @@ # We use RsT document formatting in docstring. For example :param to mark parameters. # See PEP 287 -__docformat__ = 'restructuredtext' +__docformat__ = "restructuredtext" # Accept mythril.VERSION to get mythril's current version number from .version import VERSION # NOQA diff --git a/mythril/analysis/callgraph.py b/mythril/analysis/callgraph.py index d9cd2a80..b5dd1ee3 100644 --- a/mythril/analysis/callgraph.py +++ b/mythril/analysis/callgraph.py @@ -5,87 +5,119 @@ from mythril.laser.ethereum.svm import NodeFlags import z3 default_opts = { - 'autoResize': True, - 'height': '100%', - 'width': '100%', - 'manipulation': False, - 'layout': { - 'improvedLayout': True, - 'hierarchical': { - 'enabled': True, - 'levelSeparation': 450, - 'nodeSpacing': 200, - 'treeSpacing': 100, - 'blockShifting': True, - 'edgeMinimization': True, - 'parentCentralization': False, - 'direction': 'LR', - 'sortMethod': 'directed' - } + "autoResize": True, + "height": "100%", + "width": "100%", + "manipulation": False, + "layout": { + "improvedLayout": True, + "hierarchical": { + "enabled": True, + "levelSeparation": 450, + "nodeSpacing": 200, + "treeSpacing": 100, + "blockShifting": True, + "edgeMinimization": True, + "parentCentralization": False, + "direction": "LR", + "sortMethod": "directed", + }, }, - 'nodes': { - 'color': '#000000', - 'borderWidth': 1, - 'borderWidthSelected': 2, - 'chosen': True, - 'shape': 'box', - 'font': {'align': 'left', 'color': '#FFFFFF'}, + "nodes": { + "color": "#000000", + "borderWidth": 1, + "borderWidthSelected": 2, + "chosen": True, + "shape": "box", + "font": {"align": "left", "color": "#FFFFFF"}, }, - 'edges': { - 'font': { - 'color': '#FFFFFF', - 'face': 'arial', - 'background': 'none', - 'strokeWidth': 0, - 'strokeColor': '#ffffff', - 'align': 'horizontal', - 'multi': False, - 'vadjust': 0, + "edges": { + "font": { + "color": "#FFFFFF", + "face": "arial", + "background": "none", + "strokeWidth": 0, + "strokeColor": "#ffffff", + "align": "horizontal", + "multi": False, + "vadjust": 0, } }, - 'physics': {'enabled': False} + "physics": {"enabled": False}, } phrack_opts = { - 'nodes': { - 'color': '#000000', - 'borderWidth': 1, - 'borderWidthSelected': 1, - 'shapeProperties': { - 'borderDashes': False, - 'borderRadius': 0, - }, - 'chosen': True, - 'shape': 'box', - 'font': {'face': 'courier new', 'align': 'left', 'color': '#000000'}, + "nodes": { + "color": "#000000", + "borderWidth": 1, + "borderWidthSelected": 1, + "shapeProperties": {"borderDashes": False, "borderRadius": 0}, + "chosen": True, + "shape": "box", + "font": {"face": "courier new", "align": "left", "color": "#000000"}, }, - 'edges': { - 'font': { - 'color': '#000000', - 'face': 'courier new', - 'background': 'none', - 'strokeWidth': 0, - 'strokeColor': '#ffffff', - 'align': 'horizontal', - 'multi': False, - 'vadjust': 0, + "edges": { + "font": { + "color": "#000000", + "face": "courier new", + "background": "none", + "strokeWidth": 0, + "strokeColor": "#ffffff", + "align": "horizontal", + "multi": False, + "vadjust": 0, } - } + }, } default_colors = [ - {'border': '#26996f', 'background': '#2f7e5b', 'highlight': {'border': '#26996f', 'background': '#28a16f'}}, - {'border': '#9e42b3', 'background': '#842899', 'highlight': {'border': '#9e42b3', 'background': '#933da6'}}, - {'border': '#b82323', 'background': '#991d1d', 'highlight': {'border': '#b82323', 'background': '#a61f1f'}}, - {'border': '#4753bf', 'background': '#3b46a1', 'highlight': {'border': '#4753bf', 'background': '#424db3'}}, - {'border': '#26996f', 'background': '#2f7e5b', 'highlight': {'border': '#26996f', 'background': '#28a16f'}}, - {'border': '#9e42b3', 'background': '#842899', 'highlight': {'border': '#9e42b3', 'background': '#933da6'}}, - {'border': '#b82323', 'background': '#991d1d', 'highlight': {'border': '#b82323', 'background': '#a61f1f'}}, - {'border': '#4753bf', 'background': '#3b46a1', 'highlight': {'border': '#4753bf', 'background': '#424db3'}}, + { + "border": "#26996f", + "background": "#2f7e5b", + "highlight": {"border": "#26996f", "background": "#28a16f"}, + }, + { + "border": "#9e42b3", + "background": "#842899", + "highlight": {"border": "#9e42b3", "background": "#933da6"}, + }, + { + "border": "#b82323", + "background": "#991d1d", + "highlight": {"border": "#b82323", "background": "#a61f1f"}, + }, + { + "border": "#4753bf", + "background": "#3b46a1", + "highlight": {"border": "#4753bf", "background": "#424db3"}, + }, + { + "border": "#26996f", + "background": "#2f7e5b", + "highlight": {"border": "#26996f", "background": "#28a16f"}, + }, + { + "border": "#9e42b3", + "background": "#842899", + "highlight": {"border": "#9e42b3", "background": "#933da6"}, + }, + { + "border": "#b82323", + "background": "#991d1d", + "highlight": {"border": "#b82323", "background": "#a61f1f"}, + }, + { + "border": "#4753bf", + "background": "#3b46a1", + "highlight": {"border": "#4753bf", "background": "#424db3"}, + }, ] -phrack_color = {'border': '#000000', 'background': '#ffffff', - 'highlight': {'border': '#000000', 'background': '#ffffff'}} +phrack_color = { + "border": "#000000", + "background": "#ffffff", + "highlight": {"border": "#000000", "background": "#ffffff"}, +} def extract_nodes(statespace, color_map): @@ -95,28 +127,43 @@ def extract_nodes(statespace, color_map): instructions = [state.get_current_instruction() for state in node.states] code_split = [] for instruction in instructions: - if instruction['opcode'].startswith("PUSH"): - code_line = "%d %s %s" % (instruction['address'], instruction['opcode'], instruction['argument']) - elif instruction['opcode'].startswith("JUMPDEST") and NodeFlags.FUNC_ENTRY in node.flags and instruction['address'] == node.start_addr: + if instruction["opcode"].startswith("PUSH"): + code_line = "%d %s %s" % ( + instruction["address"], + instruction["opcode"], + instruction["argument"], + ) + elif ( + instruction["opcode"].startswith("JUMPDEST") + and NodeFlags.FUNC_ENTRY in node.flags + and instruction["address"] == node.start_addr + ): code_line = node.function_name else: - code_line = "%d %s" % (instruction['address'], instruction['opcode']) + code_line = "%d %s" % (instruction["address"], instruction["opcode"]) - code_line = re.sub("([0-9a-f]{8})[0-9a-f]+", lambda m: m.group(1) + "(...)", code_line) + code_line = re.sub( + "([0-9a-f]{8})[0-9a-f]+", lambda m: m.group(1) + "(...)", code_line + ) code_split.append(code_line) - truncated_code = '\n'.join(code_split) if (len(code_split) < 7) \ - else '\n'.join(code_split[:6]) + "\n(click to expand +)" - - nodes.append({ - 'id': str(node_key), - 'color': color_map[node.get_cfg_dict()['contract_name']], - 'size': 150, - 'fullLabel': '\n'.join(code_split), - 'label': truncated_code, - 'truncLabel': truncated_code, - 'isExpanded': False - }) + truncated_code = ( + "\n".join(code_split) + if (len(code_split) < 7) + else "\n".join(code_split[:6]) + "\n(click to expand +)" + ) + + nodes.append( + { + "id": str(node_key), + "color": color_map[node.get_cfg_dict()["contract_name"]], + "size": 150, + "fullLabel": "\n".join(code_split), + "label": truncated_code, + "truncLabel": truncated_code, + "isExpanded": False, + } + ) return nodes @@ -131,21 +178,33 @@ def extract_edges(statespace): except z3.Z3Exception: label = str(edge.condition).replace("\n", "") - label = re.sub(r'([^_])([\d]{2}\d+)', lambda m: m.group(1) + hex(int(m.group(2))), label) - - edges.append({ - 'from': str(edge.as_dict['from']), - 'to': str(edge.as_dict['to']), - 'arrows': 'to', - 'label': label, - 'smooth': {'type': 'cubicBezier'} - }) + label = re.sub( + r"([^_])([\d]{2}\d+)", lambda m: m.group(1) + hex(int(m.group(2))), label + ) + + edges.append( + { + "from": str(edge.as_dict["from"]), + "to": str(edge.as_dict["to"]), + "arrows": "to", + "label": label, + "smooth": {"type": "cubicBezier"}, + } + ) return edges -def generate_graph(statespace, title="Mythril / Ethereum LASER Symbolic VM", physics=False, phrackify=False): - env = Environment(loader=PackageLoader('mythril.analysis'), autoescape=select_autoescape(['html', 'xml'])) - template = env.get_template('callgraph.html') +def generate_graph( + statespace, + title="Mythril / Ethereum LASER Symbolic VM", + physics=False, + phrackify=False, +): + env = Environment( + loader=PackageLoader("mythril.analysis"), + autoescape=select_autoescape(["html", "xml"]), + ) + template = env.get_template("callgraph.html") graph_opts = default_opts accounts = statespace.accounts @@ -154,13 +213,17 @@ def generate_graph(statespace, title="Mythril / Ethereum LASER Symbolic VM", phy color_map = {accounts[k].contract_name: phrack_color for k in accounts} graph_opts.update(phrack_opts) else: - color_map = {accounts[k].contract_name: default_colors[i % len(default_colors)] for i, k in enumerate(accounts)} + color_map = { + accounts[k].contract_name: default_colors[i % len(default_colors)] + for i, k in enumerate(accounts) + } - graph_opts['physics']['enabled'] = physics + graph_opts["physics"]["enabled"] = physics - return template.render(title=title, - nodes=extract_nodes(statespace, color_map), - edges=extract_edges(statespace), - phrackify=phrackify, - opts=graph_opts - ) + return template.render( + title=title, + nodes=extract_nodes(statespace, color_map), + edges=extract_edges(statespace), + phrackify=phrackify, + opts=graph_opts, + ) diff --git a/mythril/analysis/modules/delegatecall.py b/mythril/analysis/modules/delegatecall.py index d676f8f2..eeb81e48 100644 --- a/mythril/analysis/modules/delegatecall.py +++ b/mythril/analysis/modules/delegatecall.py @@ -5,11 +5,11 @@ from mythril.analysis.report import Issue import logging -''' +""" MODULE DESCRIPTION: Check for invocations of delegatecall(msg.data) in the fallback function. -''' +""" def execute(statespace): @@ -28,7 +28,7 @@ def execute(statespace): continue state = call.state - address = state.get_current_instruction()['address'] + address = state.get_current_instruction()["address"] meminstart = get_variable(state.mstate.stack[-3]) if meminstart.type == VarType.CONCRETE: @@ -41,17 +41,23 @@ def execute(statespace): def _concrete_call(call, state, address, meminstart): - if not re.search(r'calldata.*_0', str(state.mstate.memory[meminstart.val])): + if not re.search(r"calldata.*_0", str(state.mstate.memory[meminstart.val])): return [] - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, address=address, - swc_id=DELEGATECALL_TO_UNTRUSTED_CONTRACT, title="Call data forwarded with delegatecall()", - _type="Informational") - - issue.description = \ - "This contract forwards its call data via DELEGATECALL in its fallback function. " \ - "This means that any function in the called contract can be executed. Note that the callee contract will have " \ + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + swc_id=DELEGATECALL_TO_UNTRUSTED_CONTRACT, + title="Call data forwarded with delegatecall()", + _type="Informational", + ) + + issue.description = ( + "This contract forwards its call data via DELEGATECALL in its fallback function. " + "This means that any function in the called contract can be executed. Note that the callee contract will have " "access to the storage of the calling contract.\n " + ) target = hex(call.to.val) if call.to.type == VarType.CONCRETE else str(call.to) issue.description += "DELEGATECALL target: {}".format(target) @@ -60,23 +66,34 @@ def _concrete_call(call, state, address, meminstart): def _symbolic_call(call, state, address, statespace): - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, address=address, - swc_id=DELEGATECALL_TO_UNTRUSTED_CONTRACT, title=call.type + " to a user-supplied address") + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + swc_id=DELEGATECALL_TO_UNTRUSTED_CONTRACT, + title=call.type + " to a user-supplied address", + ) if "calldata" in str(call.to): - issue.description = \ - "This contract delegates execution to a contract address obtained from calldata. " + issue.description = "This contract delegates execution to a contract address obtained from calldata. " else: - m = re.search(r'storage_([a-z0-9_&^]+)', str(call.to)) + m = re.search(r"storage_([a-z0-9_&^]+)", str(call.to)) if m: idx = m.group(1) - func = statespace.find_storage_write(state.environment.active_account.address, idx) + func = statespace.find_storage_write( + state.environment.active_account.address, idx + ) if func: - issue.description = "This contract delegates execution to a contract address in storage slot " + str( - idx) + ". This storage slot can be written to by calling the function `" + func + "`. " + issue.description = ( + "This contract delegates execution to a contract address in storage slot " + + str(idx) + + ". This storage slot can be written to by calling the function `" + + func + + "`. " + ) else: logging.debug("[DELEGATECALL] No storage writes to index " + str(idx)) diff --git a/mythril/analysis/modules/dependence_on_predictable_vars.py b/mythril/analysis/modules/dependence_on_predictable_vars.py index 56cc3156..f3fcf41a 100644 --- a/mythril/analysis/modules/dependence_on_predictable_vars.py +++ b/mythril/analysis/modules/dependence_on_predictable_vars.py @@ -7,7 +7,7 @@ from mythril.analysis.swc_data import TIMESTAMP_DEPENDENCE, PREDICTABLE_VARS_DEP from mythril.exceptions import UnsatError import logging -''' +""" MODULE DESCRIPTION: Check for CALLs that send >0 Ether as a result of computation based on predictable variables such as @@ -17,7 +17,7 @@ TODO: - block.blockhash(block.number-1) - block.blockhash(some_block_past_256_blocks_from_now)==0 - external source of random numbers (e.g. Oraclize) -''' +""" def execute(statespace): @@ -37,7 +37,7 @@ def execute(statespace): if call.value.type == VarType.CONCRETE and call.value.val == 0: continue - address = call.state.get_current_instruction()['address'] + address = call.state.get_current_instruction()["address"] description = "In the function `" + call.node.function_name + "` " description += "the following predictable state variables are used to determine Ether recipient:\n" @@ -56,10 +56,20 @@ def execute(statespace): for item in found: description += "- block.{}\n".format(item) if solve(call): - swc_type = TIMESTAMP_DEPENDENCE if item == 'timestamp' else PREDICTABLE_VARS_DEPENDENCE - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, address=address, - swc_id=swc_type, title="Dependence on predictable environment variable", - _type="Warning", description=description) + swc_type = ( + TIMESTAMP_DEPENDENCE + if item == "timestamp" + else PREDICTABLE_VARS_DEPENDENCE + ) + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + swc_id=swc_type, + title="Dependence on predictable environment variable", + _type="Warning", + description=description, + ) issues.append(issue) # Second check: blockhash @@ -68,48 +78,75 @@ def execute(statespace): if "blockhash" in str(constraint): description = "In the function `" + call.node.function_name + "` " if "number" in str(constraint): - m = re.search(r'blockhash\w+(\s-\s(\d+))*', str(constraint)) + m = re.search(r"blockhash\w+(\s-\s(\d+))*", str(constraint)) if m and solve(call): found = m.group(1) if found: # block.blockhash(block.number - N) - description += "predictable expression 'block.blockhash(block.number - " + m.group(2) + \ - ")' is used to determine Ether recipient" + description += ( + "predictable expression 'block.blockhash(block.number - " + + m.group(2) + + ")' is used to determine Ether recipient" + ) if int(m.group(2)) > 255: - description += ", this expression will always be equal to zero." - elif "storage" in str(constraint): # block.blockhash(block.number - storage_0) - description += "predictable expression 'block.blockhash(block.number - " + \ - "some_storage_var)' is used to determine Ether recipient" + description += ( + ", this expression will always be equal to zero." + ) + elif "storage" in str( + constraint + ): # block.blockhash(block.number - storage_0) + description += ( + "predictable expression 'block.blockhash(block.number - " + + "some_storage_var)' is used to determine Ether recipient" + ) else: # block.blockhash(block.number) - description += "predictable expression 'block.blockhash(block.number)'" + \ - " is used to determine Ether recipient" - description += ", this expression will always be equal to zero." - - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, - address=address, title="Dependence on predictable variable", - _type="Warning", description=description, swc_id=PREDICTABLE_VARS_DEPENDENCE) + description += ( + "predictable expression 'block.blockhash(block.number)'" + + " is used to determine Ether recipient" + ) + description += ( + ", this expression will always be equal to zero." + ) + + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + title="Dependence on predictable variable", + _type="Warning", + description=description, + swc_id=PREDICTABLE_VARS_DEPENDENCE, + ) issues.append(issue) break else: - r = re.search(r'storage_([a-z0-9_&^]+)', str(constraint)) + r = re.search(r"storage_([a-z0-9_&^]+)", str(constraint)) if r: # block.blockhash(storage_0) - ''' + """ We actually can do better here by adding a constraint blockhash_block_storage_0 == 0 and checking model satisfiability. When this is done, severity can be raised from 'Informational' to 'Warning'. Checking that storage at given index can be tainted is not necessary, since it usually contains block.number of the 'commit' transaction in commit-reveal workflow. - ''' + """ index = r.group(1) if index and solve(call): - description += 'block.blockhash() is calculated using a value from storage ' \ - 'at index {}'.format(index) - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, - address=address, title="Dependence on predictable variable", - _type="Informational", description=description, swc_id=PREDICTABLE_VARS_DEPENDENCE) + description += ( + "block.blockhash() is calculated using a value from storage " + "at index {}".format(index) + ) + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + title="Dependence on predictable variable", + _type="Informational", + description=description, + swc_id=PREDICTABLE_VARS_DEPENDENCE, + ) issues.append(issue) break return issues @@ -121,7 +158,10 @@ def solve(call): logging.debug("[DEPENDENCE_ON_PREDICTABLE_VARS] MODEL: " + str(model)) for d in model.decls(): - logging.debug("[DEPENDENCE_ON_PREDICTABLE_VARS] main model: %s = 0x%x" % (d.name(), model[d].as_long())) + logging.debug( + "[DEPENDENCE_ON_PREDICTABLE_VARS] main model: %s = 0x%x" + % (d.name(), model[d].as_long()) + ) return True except UnsatError: diff --git a/mythril/analysis/modules/deprecated_ops.py b/mythril/analysis/modules/deprecated_ops.py index 2b187e2d..ec38e272 100644 --- a/mythril/analysis/modules/deprecated_ops.py +++ b/mythril/analysis/modules/deprecated_ops.py @@ -3,11 +3,11 @@ from mythril.analysis.swc_data import TX_ORIGIN_USAGE import logging -''' +""" MODULE DESCRIPTION: Check for constraints on tx.origin (i.e., access to some functionality is restricted to a specific origin). -''' +""" def execute(statespace): @@ -23,14 +23,24 @@ def execute(statespace): instruction = state.get_current_instruction() - if instruction['opcode'] == "ORIGIN": - description = "The function `{}` retrieves the transaction origin (tx.origin) using the ORIGIN opcode. " \ - "Use msg.sender instead.\nSee also: " \ - "https://solidity.readthedocs.io/en/develop/security-considerations.html#tx-origin".format(node.function_name) - - issue = Issue(contract=node.contract_name, function=node.function_name, address=instruction['address'], - title="Use of tx.origin", _type="Warning", swc_id=TX_ORIGIN_USAGE, - description=description) + if instruction["opcode"] == "ORIGIN": + description = ( + "The function `{}` retrieves the transaction origin (tx.origin) using the ORIGIN opcode. " + "Use msg.sender instead.\nSee also: " + "https://solidity.readthedocs.io/en/develop/security-considerations.html#tx-origin".format( + node.function_name + ) + ) + + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=instruction["address"], + title="Use of tx.origin", + _type="Warning", + swc_id=TX_ORIGIN_USAGE, + description=description, + ) issues.append(issue) return issues diff --git a/mythril/analysis/modules/ether_send.py b/mythril/analysis/modules/ether_send.py index bfb0d057..fa967f2f 100644 --- a/mythril/analysis/modules/ether_send.py +++ b/mythril/analysis/modules/ether_send.py @@ -8,13 +8,13 @@ import re import logging -''' +""" MODULE DESCRIPTION: Check for CALLs that send >0 Ether to either the transaction sender, or to an address provided as a function argument. If msg.sender is checked against a value in storage, check whether that storage index is tainted (i.e. there's an unconstrained write to that index). -''' +""" def execute(statespace): @@ -26,7 +26,7 @@ def execute(statespace): for call in statespace.calls: state = call.state - address = state.get_current_instruction()['address'] + address = state.get_current_instruction()["address"] if "callvalue" in str(call.value): logging.debug("[ETHER_SEND] Skipping refund function") @@ -41,26 +41,38 @@ def execute(statespace): description = "In the function `" + call.node.function_name + "` " - if re.search(r'caller', str(call.to)): + if re.search(r"caller", str(call.to)): description += "a non-zero amount of Ether is sent to msg.sender.\n" interesting = True - elif re.search(r'calldata', str(call.to)): + elif re.search(r"calldata", str(call.to)): description += "a non-zero amount of Ether is sent to an address taken from function arguments.\n" interesting = True else: - m = re.search(r'storage_([a-z0-9_&^]+)', str(call.to)) + m = re.search(r"storage_([a-z0-9_&^]+)", str(call.to)) if m: idx = m.group(1) - description += "a non-zero amount of Ether is sent to an address taken from storage slot " + str(idx) + ".\n" + description += ( + "a non-zero amount of Ether is sent to an address taken from storage slot " + + str(idx) + + ".\n" + ) - func = statespace.find_storage_write(state.environment.active_account.address, idx) + func = statespace.find_storage_write( + state.environment.active_account.address, idx + ) if func: - description += "There is a check on storage index " + str(idx) + ". This storage slot can be written to by calling the function `" + func + "`.\n" + description += ( + "There is a check on storage index " + + str(idx) + + ". This storage slot can be written to by calling the function `" + + func + + "`.\n" + ) interesting = True else: logging.debug("[ETHER_SEND] No storage writes to index " + str(idx)) @@ -80,31 +92,45 @@ def execute(statespace): index += 1 logging.debug("[ETHER_SEND] Constraint: " + str(constraint)) - m = re.search(r'storage_([a-z0-9_&^]+)', str(constraint)) + m = re.search(r"storage_([a-z0-9_&^]+)", str(constraint)) if m: constrained = True idx = m.group(1) - func = statespace.find_storage_write(state.environment.active_account.address, idx) + func = statespace.find_storage_write( + state.environment.active_account.address, idx + ) if func: - description += "\nThere is a check on storage index " + str(idx) + ". This storage slot can be written to by calling the function `" + func + "`." + description += ( + "\nThere is a check on storage index " + + str(idx) + + ". This storage slot can be written to by calling the function `" + + func + + "`." + ) else: - logging.debug("[ETHER_SEND] No storage writes to index " + str(idx)) + logging.debug( + "[ETHER_SEND] No storage writes to index " + str(idx) + ) can_solve = False break # CALLER may also be constrained to hardcoded address. I.e. 'caller' and some integer - elif re.search(r"caller", str(constraint)) and re.search(r'[0-9]{20}', str(constraint)): + elif re.search(r"caller", str(constraint)) and re.search( + r"[0-9]{20}", str(constraint) + ): constrained = True can_solve = False break if not constrained: - description += "It seems that this function can be called without restrictions." + description += ( + "It seems that this function can be called without restrictions." + ) if can_solve: @@ -112,13 +138,23 @@ def execute(statespace): model = solver.get_model(node.constraints) for d in model.decls(): - logging.debug("[ETHER_SEND] main model: %s = 0x%x" % (d.name(), model[d].as_long())) + logging.debug( + "[ETHER_SEND] main model: %s = 0x%x" + % (d.name(), model[d].as_long()) + ) debug = "SOLVER OUTPUT:\n" + solver.pretty_print_model(model) - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, address=address, - title="Ether send", _type="Warning", swc_id=UNPROTECTED_ETHER_WITHDRAWAL, - description=description, debug=debug) + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + title="Ether send", + _type="Warning", + swc_id=UNPROTECTED_ETHER_WITHDRAWAL, + description=description, + debug=debug, + ) issues.append(issue) except UnsatError: diff --git a/mythril/analysis/modules/exceptions.py b/mythril/analysis/modules/exceptions.py index ac0723e0..a9e63205 100644 --- a/mythril/analysis/modules/exceptions.py +++ b/mythril/analysis/modules/exceptions.py @@ -5,12 +5,12 @@ from mythril.analysis import solver import logging -''' +""" MODULE DESCRIPTION: Checks whether any exception states are reachable. -''' +""" def execute(statespace): @@ -25,25 +25,40 @@ def execute(statespace): for state in node.states: instruction = state.get_current_instruction() - if instruction['opcode'] == "ASSERT_FAIL": + if instruction["opcode"] == "ASSERT_FAIL": try: model = solver.get_model(node.constraints) - address = state.get_current_instruction()['address'] - - description = "A reachable exception (opcode 0xfe) has been detected. " \ - "This can be caused by type errors, division by zero, " \ - "out-of-bounds array access, or assert violations. " - description += "This is acceptable in most situations. " \ - "Note however that `assert()` should only be used to check invariants. " \ - "Use `require()` for regular input checking. " - - debug = "The exception is triggered under the following conditions:\n\n" + address = state.get_current_instruction()["address"] + + description = ( + "A reachable exception (opcode 0xfe) has been detected. " + "This can be caused by type errors, division by zero, " + "out-of-bounds array access, or assert violations. " + ) + description += ( + "This is acceptable in most situations. " + "Note however that `assert()` should only be used to check invariants. " + "Use `require()` for regular input checking. " + ) + + debug = ( + "The exception is triggered under the following conditions:\n\n" + ) debug += solver.pretty_print_model(model) - issues.append(Issue(contract=node.contract_name, function=node.function_name, address=address, - swc_id=ASSERT_VIOLATION, title="Exception state", _type="Informational", - description=description, debug=debug)) + issues.append( + Issue( + contract=node.contract_name, + function=node.function_name, + address=address, + swc_id=ASSERT_VIOLATION, + title="Exception state", + _type="Informational", + description=description, + debug=debug, + ) + ) except UnsatError: logging.debug("[EXCEPTIONS] no model found") diff --git a/mythril/analysis/modules/external_calls.py b/mythril/analysis/modules/external_calls.py index 2f9350af..1b45205a 100644 --- a/mythril/analysis/modules/external_calls.py +++ b/mythril/analysis/modules/external_calls.py @@ -7,11 +7,11 @@ import re import logging -''' +""" MODULE DESCRIPTION: Check for call.value()() to external addresses -''' +""" MAX_SEARCH_DEPTH = 64 @@ -28,8 +28,8 @@ def search_children(statespace, node, start_index=0, depth=0, results=None): if n_states > start_index: for j in range(start_index, n_states): - if node.states[j].get_current_instruction()['opcode'] == 'SSTORE': - results.append(node.states[j].get_current_instruction()['address']) + if node.states[j].get_current_instruction()["opcode"] == "SSTORE": + results.append(node.states[j].get_current_instruction()["address"]) children = [] @@ -39,7 +39,9 @@ def search_children(statespace, node, start_index=0, depth=0, results=None): if len(children): for node in children: - return search_children(statespace, node, depth=depth + 1, results=results) + return search_children( + statespace, node, depth=depth + 1, results=results + ) return results @@ -54,13 +56,20 @@ def execute(statespace): for call in statespace.calls: state = call.state - address = state.get_current_instruction()['address'] + address = state.get_current_instruction()["address"] if call.type == "CALL": - logging.info("[EXTERNAL_CALLS] Call to: %s, value = %s, gas = %s" % (str(call.to), str(call.value), str(call.gas))) + logging.info( + "[EXTERNAL_CALLS] Call to: %s, value = %s, gas = %s" + % (str(call.to), str(call.value), str(call.gas)) + ) - if call.to.type == VarType.SYMBOLIC and (call.gas.type == VarType.CONCRETE and call.gas.val > 2300) or (call.gas.type == VarType.SYMBOLIC and "2300" not in str(call.gas)): + if ( + call.to.type == VarType.SYMBOLIC + and (call.gas.type == VarType.CONCRETE and call.gas.val > 2300) + or (call.gas.type == VarType.SYMBOLIC and "2300" not in str(call.gas)) + ): description = "This contract executes a message call to " @@ -76,59 +85,96 @@ def execute(statespace): user_supplied = True else: - m = re.search(r'storage_([a-z0-9_&^]+)', str(call.to)) + m = re.search(r"storage_([a-z0-9_&^]+)", str(call.to)) if m: idx = m.group(1) - func = statespace.find_storage_write(state.environment.active_account.address, idx) + func = statespace.find_storage_write( + state.environment.active_account.address, idx + ) if func: - description += \ - "an address found at storage slot " + str(idx) + ". " + \ - "This storage slot can be written to by calling the function `" + func + "`. " + description += ( + "an address found at storage slot " + + str(idx) + + ". " + + "This storage slot can be written to by calling the function `" + + func + + "`. " + ) user_supplied = True if user_supplied: - description += "Generally, it is not recommended to call user-supplied addresses using Solidity's call() construct. " \ - "Note that attackers might leverage reentrancy attacks to exploit race conditions or manipulate this contract's state." - - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, - address=address, title="Message call to external contract", _type="Warning", - description=description, swc_id=REENTRANCY) + description += ( + "Generally, it is not recommended to call user-supplied addresses using Solidity's call() construct. " + "Note that attackers might leverage reentrancy attacks to exploit race conditions or manipulate this contract's state." + ) + + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + title="Message call to external contract", + _type="Warning", + description=description, + swc_id=REENTRANCY, + ) else: description += "to another contract. Make sure that the called contract is trusted and does not execute user-supplied code." - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, address=address, - title="Message call to external contract", _type="Informational", - description=description, swc_id=REENTRANCY) + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + title="Message call to external contract", + _type="Informational", + description=description, + swc_id=REENTRANCY, + ) issues.append(issue) if address not in calls_visited: calls_visited.append(address) - logging.debug("[EXTERNAL_CALLS] Checking for state changes starting from " + call.node.function_name) + logging.debug( + "[EXTERNAL_CALLS] Checking for state changes starting from " + + call.node.function_name + ) # Check for SSTORE in remaining instructions in current node & nodes down the CFG - state_change_addresses = search_children(statespace, call.node, call.state_index + 1, depth=0, results=[]) + state_change_addresses = search_children( + statespace, call.node, call.state_index + 1, depth=0, results=[] + ) - logging.debug("[EXTERNAL_CALLS] Detected state changes at addresses: " + str(state_change_addresses)) + logging.debug( + "[EXTERNAL_CALLS] Detected state changes at addresses: " + + str(state_change_addresses) + ) if len(state_change_addresses): for address in state_change_addresses: - description = "The contract account state is changed after an external call. " \ - "Consider that the called contract could re-enter the function before this " \ - "state change takes place. This can lead to business logic vulnerabilities." - - issue = Issue(contract=call.node.contract_name, function=call.node.function_name, - address=address, title="State change after external call", _type="Warning", - description=description, swc_id=REENTRANCY) + description = ( + "The contract account state is changed after an external call. " + "Consider that the called contract could re-enter the function before this " + "state change takes place. This can lead to business logic vulnerabilities." + ) + + issue = Issue( + contract=call.node.contract_name, + function=call.node.function_name, + address=address, + title="State change after external call", + _type="Warning", + description=description, + swc_id=REENTRANCY, + ) issues.append(issue) return issues diff --git a/mythril/analysis/modules/integer.py b/mythril/analysis/modules/integer.py index 48707c11..091b8cfb 100644 --- a/mythril/analysis/modules/integer.py +++ b/mythril/analysis/modules/integer.py @@ -9,13 +9,13 @@ import re import copy import logging -''' +""" MODULE DESCRIPTION: Check for integer underflows. For every SUB instruction, check if there's a possible state where op1 > op0. For every ADD, MUL instruction, check if there's a possible state where op1 + op0 > 2^32 - 1 -''' +""" def execute(statespace): @@ -50,7 +50,7 @@ def _check_integer_overflow(statespace, state, node): # Check the instruction instruction = state.get_current_instruction() - if instruction['opcode'] not in ("ADD", "MUL"): + if instruction["opcode"] not in ("ADD", "MUL"): return issues # Formulate overflow constraints @@ -70,7 +70,7 @@ def _check_integer_overflow(statespace, state, node): op1 = BitVecVal(op1, 256) # Formulate expression - if instruction['opcode'] == "ADD": + if instruction["opcode"] == "ADD": expr = op0 + op1 else: expr = op1 * op0 @@ -83,27 +83,42 @@ def _check_integer_overflow(statespace, state, node): logging.debug("[INTEGER_OVERFLOW] no model found") return issues - if not _verify_integer_overflow(statespace, node, expr, state, model, constraint, op0, op1): + if not _verify_integer_overflow( + statespace, node, expr, state, model, constraint, op0, op1 + ): return issues # Build issue - issue = Issue(contract=node.contract_name, function=node.function_name, address=instruction['address'], - swc_id=INTEGER_OVERFLOW_AND_UNDERFLOW, title="Integer Overflow", _type="Warning") - - issue.description = "A possible integer overflow exists in the function `{}`.\n" \ - "The addition or multiplication may result in a value higher than the maximum representable integer.".format( - node.function_name) + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=instruction["address"], + swc_id=INTEGER_OVERFLOW_AND_UNDERFLOW, + title="Integer Overflow", + _type="Warning", + ) + + issue.description = ( + "A possible integer overflow exists in the function `{}`.\n" + "The addition or multiplication may result in a value higher than the maximum representable integer.".format( + node.function_name + ) + ) issue.debug = solver.pretty_print_model(model) issues.append(issue) return issues -def _verify_integer_overflow(statespace, node, expr, state, model, constraint, op0, op1): +def _verify_integer_overflow( + statespace, node, expr, state, model, constraint, op0, op1 +): """ Verifies existence of integer overflow """ # If we get to this point then there has been an integer overflow # Find out if the overflowed value is actually used - interesting_usages = _search_children(statespace, node, expr, constraint=[constraint], index=node.states.index(state)) + interesting_usages = _search_children( + statespace, node, expr, constraint=[constraint], index=node.states.index(state) + ) # Stop if it isn't if len(interesting_usages) == 0: @@ -111,6 +126,7 @@ def _verify_integer_overflow(statespace, node, expr, state, model, constraint, o return _try_constraints(node.constraints, [Not(constraint)]) is not None + def _try_constraints(constraints, new_constraints): """ Tries new constraints @@ -135,7 +151,7 @@ def _check_integer_underflow(statespace, state, node): """ issues = [] instruction = state.get_current_instruction() - if instruction['opcode'] == "SUB": + if instruction["opcode"] == "SUB": stack = state.mstate.stack @@ -150,15 +166,22 @@ def _check_integer_underflow(statespace, state, node): # Pattern 2: (256*If(1 & storage_0 == 0, 1, 0)) - 1, this would underlow if storage_0 = 0 if type(op0) == int and type(op1) == int: return [] - if re.search(r'calldatasize_', str(op0)): + if re.search(r"calldatasize_", str(op0)): return [] - if re.search(r'256\*.*If\(1', str(op0), re.DOTALL) or re.search(r'256\*.*If\(1', str(op1), re.DOTALL): + if re.search(r"256\*.*If\(1", str(op0), re.DOTALL) or re.search( + r"256\*.*If\(1", str(op1), re.DOTALL + ): return [] - if re.search(r'32 \+.*calldata', str(op0), re.DOTALL) or re.search(r'32 \+.*calldata', str(op1), re.DOTALL): + if re.search(r"32 \+.*calldata", str(op0), re.DOTALL) or re.search( + r"32 \+.*calldata", str(op1), re.DOTALL + ): return [] - logging.debug("[INTEGER_UNDERFLOW] Checking SUB {0}, {1} at address {2}".format(str(op0), str(op1), - str(instruction['address']))) + logging.debug( + "[INTEGER_UNDERFLOW] Checking SUB {0}, {1} at address {2}".format( + str(op0), str(op1), str(instruction["address"]) + ) + ) allowed_types = [int, BitVecRef, BitVecNumRef] if type(op0) in allowed_types and type(op1) in allowed_types: @@ -170,17 +193,29 @@ def _check_integer_underflow(statespace, state, node): # If we get to this point then there has been an integer overflow # Find out if the overflowed value is actually used - interesting_usages = _search_children(statespace, node, (op0 - op1), index=node.states.index(state)) + interesting_usages = _search_children( + statespace, node, (op0 - op1), index=node.states.index(state) + ) # Stop if it isn't if len(interesting_usages) == 0: return issues - issue = Issue(contract=node.contract_name, function=node.function_name, address=instruction['address'], - swc_id=INTEGER_OVERFLOW_AND_UNDERFLOW, title="Integer Underflow", _type="Warning") - - issue.description = "A possible integer underflow exists in the function `" + node.function_name + "`.\n" \ - "The subtraction may result in a value < 0." + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=instruction["address"], + swc_id=INTEGER_OVERFLOW_AND_UNDERFLOW, + title="Integer Underflow", + _type="Warning", + ) + + issue.description = ( + "A possible integer underflow exists in the function `" + + node.function_name + + "`.\n" + "The subtraction may result in a value < 0." + ) issue.debug = solver.pretty_print_model(model) issues.append(issue) @@ -192,29 +227,39 @@ def _check_integer_underflow(statespace, state, node): def _check_usage(state, taint_result): """Delegates checks to _check_{instruction_name}()""" - opcode = state.get_current_instruction()['opcode'] + opcode = state.get_current_instruction()["opcode"] - if opcode == 'JUMPI': + if opcode == "JUMPI": if _check_jumpi(state, taint_result): return [state] - elif opcode == 'SSTORE': + elif opcode == "SSTORE": if _check_sstore(state, taint_result): return [state] return [] + def _check_jumpi(state, taint_result): """ Check if conditional jump is dependent on the result of expression""" - assert state.get_current_instruction()['opcode'] == 'JUMPI' + assert state.get_current_instruction()["opcode"] == "JUMPI" return taint_result.check(state, -2) def _check_sstore(state, taint_result): """ Check if store operation is dependent on the result of expression""" - assert state.get_current_instruction()['opcode'] == 'SSTORE' + assert state.get_current_instruction()["opcode"] == "SSTORE" return taint_result.check(state, -2) -def _search_children(statespace, node, expression, taint_result=None, constraint=None, index=0, depth=0, max_depth=64): +def _search_children( + statespace, + node, + expression, + taint_result=None, + constraint=None, + index=0, + depth=0, + max_depth=64, +): """ Checks the statespace for children states, with JUMPI or SSTORE instuctions, for dependency on expression @@ -236,7 +281,9 @@ def _search_children(statespace, node, expression, taint_result=None, constraint state = node.states[index] taint_stack = [False for _ in state.mstate.stack] taint_stack[-1] = True - taint_result = TaintRunner.execute(statespace, node, state, initial_stack=taint_stack) + taint_result = TaintRunner.execute( + statespace, node, state, initial_stack=taint_stack + ) results = [] @@ -247,25 +294,31 @@ def _search_children(statespace, node, expression, taint_result=None, constraint for j in range(index, len(node.states)): current_state = node.states[j] current_instruction = current_state.get_current_instruction() - if current_instruction['opcode'] in ('JUMPI', 'SSTORE'): + if current_instruction["opcode"] in ("JUMPI", "SSTORE"): element = _check_usage(current_state, taint_result) if len(element) < 1: continue if _check_requires(element[0], node, statespace, constraint): - continue + continue results += element # Recursively search children - children = \ - [ - statespace.nodes[edge.node_to] - for edge in statespace.edges - if edge.node_from == node.uid - # and _try_constraints(statespace.nodes[edge.node_to].constraints, constraint) is not None - ] + children = [ + statespace.nodes[edge.node_to] + for edge in statespace.edges + if edge.node_from == node.uid + # and _try_constraints(statespace.nodes[edge.node_to].constraints, constraint) is not None + ] for child in children: - results += _search_children(statespace, child, expression, taint_result, depth=depth + 1, max_depth=max_depth) + results += _search_children( + statespace, + child, + expression, + taint_result, + depth=depth + 1, + max_depth=max_depth, + ) return results @@ -273,16 +326,16 @@ def _search_children(statespace, node, expression, taint_result=None, constraint def _check_requires(state, node, statespace, constraint): """Checks if usage of overflowed statement results in a revert statement""" instruction = state.get_current_instruction() - if instruction['opcode'] is not "JUMPI": + if instruction["opcode"] is not "JUMPI": return False children = [ - statespace.nodes[edge.node_to] - for edge in statespace.edges - if edge.node_from == node.uid - ] + statespace.nodes[edge.node_to] + for edge in statespace.edges + if edge.node_from == node.uid + ] for child in children: - opcodes = [s.get_current_instruction()['opcode'] for s in child.states] + opcodes = [s.get_current_instruction()["opcode"] for s in child.states] if "REVERT" in opcodes or "ASSERT_FAIL" in opcodes: return True # I added the following case, bc of false positives if the max depth is not high enough diff --git a/mythril/analysis/modules/multiple_sends.py b/mythril/analysis/modules/multiple_sends.py index f00a5064..459405dd 100644 --- a/mythril/analysis/modules/multiple_sends.py +++ b/mythril/analysis/modules/multiple_sends.py @@ -1,6 +1,7 @@ from mythril.analysis.report import Issue from mythril.analysis.swc_data import * from mythril.laser.ethereum.cfg import JumpType + """ MODULE DESCRIPTION: @@ -21,16 +22,24 @@ def execute(statespace): if len(findings) > 0: node = call.node instruction = call.state.get_current_instruction() - issue = Issue(contract=node.contract_name, function=node.function_name, address=instruction['address'], - swc_id=MULTIPLE_SENDS, title="Multiple Calls", _type="Informational") - - issue.description = \ - "Multiple sends exist in one transaction. Try to isolate each external call into its own transaction," \ + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=instruction["address"], + swc_id=MULTIPLE_SENDS, + title="Multiple Calls", + _type="Informational", + ) + + issue.description = ( + "Multiple sends exist in one transaction. Try to isolate each external call into its own transaction," " as external calls can fail accidentally or deliberately.\nConsecutive calls: \n" + ) for finding in findings: - issue.description += \ - "Call at address: {}\n".format(finding.state.get_current_instruction()['address']) + issue.description += "Call at address: {}\n".format( + finding.state.get_current_instruction()["address"] + ) issues.append(issue) return issues @@ -38,21 +47,30 @@ def execute(statespace): def _explore_nodes(call, statespace): children = _child_nodes(statespace, call.node) - sending_children = list(filter(lambda call: call.node in children, statespace.calls)) + sending_children = list( + filter(lambda call: call.node in children, statespace.calls) + ) return sending_children def _explore_states(call, statespace): other_calls = list( - filter(lambda other: other.node == call.node and other.state_index > call.state_index, statespace.calls) + filter( + lambda other: other.node == call.node + and other.state_index > call.state_index, + statespace.calls, ) + ) return other_calls def _child_nodes(statespace, node): result = [] - children = [statespace.nodes[edge.node_to] for edge in statespace.edges if edge.node_from == node.uid - and edge.type != JumpType.Transaction] + children = [ + statespace.nodes[edge.node_to] + for edge in statespace.edges + if edge.node_from == node.uid and edge.type != JumpType.Transaction + ] for child in children: result.append(child) diff --git a/mythril/analysis/modules/suicide.py b/mythril/analysis/modules/suicide.py index e2185fe5..e2955486 100644 --- a/mythril/analysis/modules/suicide.py +++ b/mythril/analysis/modules/suicide.py @@ -6,12 +6,12 @@ from mythril.exceptions import UnsatError import logging -''' +""" MODULE DESCRIPTION: Check for SUICIDE instructions that either can be reached by anyone, or where msg.sender is checked against a tainted storage index (i.e. there's a write to that index is unconstrained by msg.sender). -''' +""" def execute(state_space): @@ -33,13 +33,15 @@ def _analyze_state(state, node): issues = [] instruction = state.get_current_instruction() - if instruction['opcode'] != "SUICIDE": + if instruction["opcode"] != "SUICIDE": return [] to = state.mstate.stack[-1] logging.debug("[UNCHECKED_SUICIDE] suicide in function " + node.function_name) - description = "The function `" + node.function_name + "` executes the SUICIDE instruction. " + description = ( + "The function `" + node.function_name + "` executes the SUICIDE instruction. " + ) if "caller" in str(to): description += "The remaining Ether is sent to the caller's address.\n" @@ -56,19 +58,30 @@ def _analyze_state(state, node): if len(state.world_state.transaction_sequence) > 1: creator = state.world_state.transaction_sequence[0].caller for transaction in state.world_state.transaction_sequence[1:]: - not_creator_constraints.append(Not(Extract(159, 0, transaction.caller) == Extract(159, 0, creator))) - not_creator_constraints.append(Not(Extract(159, 0, transaction.caller) == 0)) + not_creator_constraints.append( + Not(Extract(159, 0, transaction.caller) == Extract(159, 0, creator)) + ) + not_creator_constraints.append( + Not(Extract(159, 0, transaction.caller) == 0) + ) try: model = solver.get_model(node.constraints + not_creator_constraints) debug = "SOLVER OUTPUT:\n" + solver.pretty_print_model(model) - issue = Issue(contract=node.contract_name, function=node.function_name, address=instruction['address'], - swc_id=UNPROTECTED_SELFDESTRUCT, title="Unchecked SUICIDE", _type="Warning", - description=description, debug=debug) + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=instruction["address"], + swc_id=UNPROTECTED_SELFDESTRUCT, + title="Unchecked SUICIDE", + _type="Warning", + description=description, + debug=debug, + ) issues.append(issue) except UnsatError: - logging.debug("[UNCHECKED_SUICIDE] no model found") + logging.debug("[UNCHECKED_SUICIDE] no model found") return issues diff --git a/mythril/analysis/modules/transaction_order_dependence.py b/mythril/analysis/modules/transaction_order_dependence.py index f6621293..b8d34cd2 100644 --- a/mythril/analysis/modules/transaction_order_dependence.py +++ b/mythril/analysis/modules/transaction_order_dependence.py @@ -7,12 +7,12 @@ from mythril.analysis.report import Issue from mythril.analysis.swc_data import TX_ORDER_DEPENDENCE from mythril.exceptions import UnsatError -''' +""" MODULE DESCRIPTION: This module finds the existance of transaction order dependence vulnerabilities. The following webpage contains an extensive description of the vulnerability: https://consensys.github.io/smart-contract-best-practices/known_attacks/#transaction-ordering-dependence-tod-front-running -''' +""" def execute(statespace): @@ -24,19 +24,29 @@ def execute(statespace): for call in statespace.calls: # Do analysis interesting_storages = list(_get_influencing_storages(call)) - changing_sstores = list(_get_influencing_sstores(statespace, interesting_storages)) + changing_sstores = list( + _get_influencing_sstores(statespace, interesting_storages) + ) # Build issue if necessary if len(changing_sstores) > 0: node = call.node instruction = call.state.get_current_instruction() - issue = Issue(contract=node.contract_name, function=node.function_name, address=instruction['address'], - title="Transaction order dependence", swc_id=TX_ORDER_DEPENDENCE, _type="Warning") - - issue.description = \ - "A possible transaction order dependence vulnerability exists in function {}. The value or " \ - "direction of the call statement is determined from a tainted storage location"\ - .format(node.function_name) + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=instruction["address"], + title="Transaction order dependence", + swc_id=TX_ORDER_DEPENDENCE, + _type="Warning", + ) + + issue.description = ( + "A possible transaction order dependence vulnerability exists in function {}. The value or " + "direction of the call statement is determined from a tainted storage location".format( + node.function_name + ) + ) issues.append(issue) return issues @@ -65,7 +75,7 @@ def _get_storage_variable(storage, state): :param state: state to retrieve the variable from :return: z3 object representing storage """ - index = int(re.search('[0-9]+', storage).group()) + index = int(re.search("[0-9]+", storage).group()) try: return state.environment.active_account.storage[index] except KeyError: @@ -85,6 +95,7 @@ def _can_change(constraints, variable): except AttributeError: return False + def _get_influencing_storages(call): """ Examines a Call object and returns an iterator of all storages that influence the call value or direction""" state = call.state @@ -108,7 +119,7 @@ def _get_influencing_storages(call): def _get_influencing_sstores(statespace, interesting_storages): """ Gets sstore (state, node) tuples that write to interesting_storages""" - for sstore_state, node in _get_states_with_opcode(statespace, 'SSTORE'): + for sstore_state, node in _get_states_with_opcode(statespace, "SSTORE"): index, value = sstore_state.mstate.stack[-1], sstore_state.mstate.stack[-2] try: index = util.get_concrete_int(index) diff --git a/mythril/analysis/modules/unchecked_retval.py b/mythril/analysis/modules/unchecked_retval.py index 5ee2b327..4effcbbe 100644 --- a/mythril/analysis/modules/unchecked_retval.py +++ b/mythril/analysis/modules/unchecked_retval.py @@ -6,7 +6,7 @@ import logging import re -''' +""" MODULE DESCRIPTION: Test whether CALL return value is checked. @@ -22,7 +22,7 @@ For low-level-calls this check is omitted. E.g.: c.call.value(0)(bytes4(sha3("ping(uint256)")),1); -''' +""" def execute(statespace): @@ -43,19 +43,27 @@ def execute(statespace): instr = state.get_current_instruction() - if instr['opcode'] == 'ISZERO' and re.search(r'retval', str(state.mstate.stack[-1])): + if instr["opcode"] == "ISZERO" and re.search( + r"retval", str(state.mstate.stack[-1]) + ): retval_checked = True break if not retval_checked: - address = state.get_current_instruction()['address'] - issue = Issue(contract=node.contract_name, function=node.function_name, address=address, - title="Unchecked CALL return value", swc_id=UNCHECKED_RET_VAL) - - issue.description = \ - "The return value of an external call is not checked. " \ + address = state.get_current_instruction()["address"] + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=address, + title="Unchecked CALL return value", + swc_id=UNCHECKED_RET_VAL, + ) + + issue.description = ( + "The return value of an external call is not checked. " "Note that execution continue even if the called contract throws." + ) issues.append(issue) @@ -63,12 +71,14 @@ def execute(statespace): n_states = len(node.states) - for idx in range(0, n_states - 1): # Ignore CALLs at last position in a node + for idx in range( + 0, n_states - 1 + ): # Ignore CALLs at last position in a node state = node.states[idx] instr = state.get_current_instruction() - if instr['opcode'] == 'CALL': + if instr["opcode"] == "CALL": retval_checked = False @@ -78,7 +88,9 @@ def execute(statespace): _state = node.states[_idx] _instr = _state.get_current_instruction() - if _instr['opcode'] == 'ISZERO' and re.search(r'retval', str(_state .mstate.stack[-1])): + if _instr["opcode"] == "ISZERO" and re.search( + r"retval", str(_state.mstate.stack[-1]) + ): retval_checked = True break @@ -87,13 +99,19 @@ def execute(statespace): if not retval_checked: - address = instr['address'] - issue = Issue(contract=node.contract_name, function=node.function_name, - address=address, title="Unchecked CALL return value", swc_id=UNCHECKED_RET_VAL) - - issue.description = \ - "The return value of an external call is not checked. " \ + address = instr["address"] + issue = Issue( + contract=node.contract_name, + function=node.function_name, + address=address, + title="Unchecked CALL return value", + swc_id=UNCHECKED_RET_VAL, + ) + + issue.description = ( + "The return value of an external call is not checked. " "Note that execution continue even if the called contract throws." + ) issues.append(issue) diff --git a/mythril/analysis/ops.py b/mythril/analysis/ops.py index b2329294..6c3cfea1 100644 --- a/mythril/analysis/ops.py +++ b/mythril/analysis/ops.py @@ -9,7 +9,6 @@ class VarType(Enum): class Variable: - def __init__(self, val, _type): self.val = val self.type = _type @@ -26,7 +25,6 @@ def get_variable(i): class Op: - def __init__(self, node, state, state_index): self.node = node self.state = state @@ -34,8 +32,17 @@ class Op: class Call(Op): - - def __init__(self, node, state, state_index, _type, to, gas, value=Variable(0, VarType.CONCRETE), data=None): + def __init__( + self, + node, + state, + state_index, + _type, + to, + gas, + value=Variable(0, VarType.CONCRETE), + data=None, + ): super().__init__(node, state, state_index) self.to = to @@ -46,7 +53,6 @@ class Call(Op): class SStore(Op): - def __init__(self, node, state, state_index, value): super().__init__(node, state, state_index) self.value = value diff --git a/mythril/analysis/report.py b/mythril/analysis/report.py index a6cfd59d..652d7d65 100644 --- a/mythril/analysis/report.py +++ b/mythril/analysis/report.py @@ -5,8 +5,17 @@ from jinja2 import PackageLoader, Environment class Issue: - - def __init__(self, contract, function, address, swc_id, title, _type="Informational", description="", debug=""): + def __init__( + self, + contract, + function, + address, + swc_id, + title, + _type="Informational", + description="", + debug="", + ): self.title = title self.contract = contract @@ -20,32 +29,43 @@ class Issue: self.code = None self.lineno = None - @property def as_dict(self): - issue = {'title': self.title, 'swc_id': self.swc_id, 'contract': self.contract, 'description': self.description, - 'function': self.function, 'type': self.type, 'address': self.address, 'debug': self.debug} + issue = { + "title": self.title, + "swc_id": self.swc_id, + "contract": self.contract, + "description": self.description, + "function": self.function, + "type": self.type, + "address": self.address, + "debug": self.debug, + } if self.filename and self.lineno: - issue['filename'] = self.filename - issue['lineno'] = self.lineno + issue["filename"] = self.filename + issue["lineno"] = self.lineno if self.code: - issue['code'] = self.code + issue["code"] = self.code return issue def add_code_info(self, contract): if self.address: - codeinfo = contract.get_source_info(self.address, constructor=(self.function == 'constructor')) + codeinfo = contract.get_source_info( + self.address, constructor=(self.function == "constructor") + ) self.filename = codeinfo.filename self.code = codeinfo.code self.lineno = codeinfo.lineno class Report: - environment = Environment(loader=PackageLoader('mythril.analysis'), trim_blocks=True) + environment = Environment( + loader=PackageLoader("mythril.analysis"), trim_blocks=True + ) def __init__(self, verbose=False): self.issues = {} @@ -54,26 +74,30 @@ class Report: def sorted_issues(self): issue_list = [issue.as_dict for key, issue in self.issues.items()] - return sorted(issue_list, key=operator.itemgetter('address', 'title')) + return sorted(issue_list, key=operator.itemgetter("address", "title")) def append_issue(self, issue): m = hashlib.md5() - m.update((issue.contract + str(issue.address) + issue.title).encode('utf-8')) + m.update((issue.contract + str(issue.address) + issue.title).encode("utf-8")) self.issues[m.digest()] = issue def as_text(self): name = self._file_name() - template = Report.environment.get_template('report_as_text.jinja2') - return template.render(filename=name, issues=self.sorted_issues(), verbose=self.verbose) + template = Report.environment.get_template("report_as_text.jinja2") + return template.render( + filename=name, issues=self.sorted_issues(), verbose=self.verbose + ) def as_json(self): - result = {'success': True, 'error': None, 'issues': self.sorted_issues()} + result = {"success": True, "error": None, "issues": self.sorted_issues()} return json.dumps(result, sort_keys=True) def as_markdown(self): filename = self._file_name() - template = Report.environment.get_template('report_as_markdown.jinja2') - return template.render(filename=filename, issues=self.sorted_issues(), verbose=self.verbose) + template = Report.environment.get_template("report_as_markdown.jinja2") + return template.render( + filename=filename, issues=self.sorted_issues(), verbose=self.verbose + ) def _file_name(self): if len(self.issues.values()) > 0: diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 9c5a2dd0..afa4ef5b 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -2,6 +2,7 @@ from z3 import Solver, simplify, sat, unknown from mythril.exceptions import UnsatError import logging + def get_model(constraints): s = Solver() s.set("timeout", 100000) @@ -27,6 +28,6 @@ def pretty_print_model(model): except: condition = str(simplify(model[d])) - ret += ("%s: %s\n" % (d.name(), condition)) + ret += "%s: %s\n" % (d.name(), condition) return ret diff --git a/mythril/analysis/swc_data.py b/mythril/analysis/swc_data.py index 3ab8f9b4..2f4d5ded 100644 --- a/mythril/analysis/swc_data.py +++ b/mythril/analysis/swc_data.py @@ -1,25 +1,27 @@ -DEFAULT_FUNCTION_VISIBILITY = '100' -INTEGER_OVERFLOW_AND_UNDERFLOW = '101' -OUTDATED_COMPILER_VERSION = '102' -FLOATING_PRAGMA = '103' -UNCHECKED_RET_VAL = '104' -UNPROTECTED_ETHER_WITHDRAWAL = '105' -UNPROTECTED_SELFDESTRUCT = '106' -REENTRANCY = '107' -DEFAULT_STATE_VARIABLE_VISIBILITY = '108' -UNINITIALIZED_STORAGE_POINTER = '109' -ASSERT_VIOLATION = '110' -DEPRICATED_FUNCTIONS_USAGE = '111' -DELEGATECALL_TO_UNTRUSTED_CONTRACT = '112' -MULTIPLE_SENDS = '113' -TX_ORDER_DEPENDENCE = '114' -TX_ORIGIN_USAGE = '115' -TIMESTAMP_DEPENDENCE = '116' +DEFAULT_FUNCTION_VISIBILITY = "100" +INTEGER_OVERFLOW_AND_UNDERFLOW = "101" +OUTDATED_COMPILER_VERSION = "102" +FLOATING_PRAGMA = "103" +UNCHECKED_RET_VAL = "104" +UNPROTECTED_ETHER_WITHDRAWAL = "105" +UNPROTECTED_SELFDESTRUCT = "106" +REENTRANCY = "107" +DEFAULT_STATE_VARIABLE_VISIBILITY = "108" +UNINITIALIZED_STORAGE_POINTER = "109" +ASSERT_VIOLATION = "110" +DEPRICATED_FUNCTIONS_USAGE = "111" +DELEGATECALL_TO_UNTRUSTED_CONTRACT = "112" +MULTIPLE_SENDS = "113" +TX_ORDER_DEPENDENCE = "114" +TX_ORIGIN_USAGE = "115" +TIMESTAMP_DEPENDENCE = "116" # TODO: SWC ID 116 is missing, Add it if it's added to the https://github.com/SmartContractSecurity/SWC-registry -INCORRECT_CONSTRUCTOR_NAME = '118' -SHADOWING_STATE_VARIABLES = '119' -WEAK_RANDOMNESS = '120' -SIGNATURE_REPLAY = '121' -IMPROPER_VERIFICATION_BASED_ON_MSG_SENDER = '122' +INCORRECT_CONSTRUCTOR_NAME = "118" +SHADOWING_STATE_VARIABLES = "119" +WEAK_RANDOMNESS = "120" +SIGNATURE_REPLAY = "121" +IMPROPER_VERIFICATION_BASED_ON_MSG_SENDER = "122" -PREDICTABLE_VARS_DEPENDENCE = 'N/A' # TODO: Add the swc id when this is added to the SWC Registry +PREDICTABLE_VARS_DEPENDENCE = ( + "N/A" +) # TODO: Add the swc id when this is added to the SWC Registry diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index c9431257..871f0d61 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -4,7 +4,10 @@ from mythril.ether.soliditycontract import SolidityContract import copy import logging from .ops import get_variable, SStore, Call, VarType -from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy, BreadthFirstSearchStrategy +from mythril.laser.ethereum.strategy.basic import ( + DepthFirstSearchStrategy, + BreadthFirstSearchStrategy, +) class SymExecWrapper: @@ -13,28 +16,49 @@ class SymExecWrapper: Wrapper class for the LASER Symbolic virtual machine. Symbolically executes the code and does a bit of pre-analysis for convenience. """ - def __init__(self, contract, address, strategy, dynloader=None, max_depth=22, - execution_timeout=None, create_timeout=None, max_transaction_count=3): + def __init__( + self, + contract, + address, + strategy, + dynloader=None, + max_depth=22, + execution_timeout=None, + create_timeout=None, + max_transaction_count=3, + ): s_strategy = None - if strategy == 'dfs': + if strategy == "dfs": s_strategy = DepthFirstSearchStrategy - elif strategy == 'bfs': + elif strategy == "bfs": s_strategy = BreadthFirstSearchStrategy else: raise ValueError("Invalid strategy argument supplied") - account = Account(address, contract.disassembly, dynamic_loader=dynloader, contract_name=contract.name) + account = Account( + address, + contract.disassembly, + dynamic_loader=dynloader, + contract_name=contract.name, + ) self.accounts = {address: account} - self.laser = svm.LaserEVM(self.accounts, dynamic_loader=dynloader, max_depth=max_depth, - execution_timeout=execution_timeout, strategy=s_strategy, - create_timeout=create_timeout, - max_transaction_count=max_transaction_count) + self.laser = svm.LaserEVM( + self.accounts, + dynamic_loader=dynloader, + max_depth=max_depth, + execution_timeout=execution_timeout, + strategy=s_strategy, + create_timeout=create_timeout, + max_transaction_count=max_transaction_count, + ) if isinstance(contract, SolidityContract): - self.laser.sym_exec(creation_code=contract.creation_code, contract_name=contract.name) + self.laser.sym_exec( + creation_code=contract.creation_code, contract_name=contract.name + ) else: self.laser.sym_exec(address) @@ -54,31 +78,72 @@ class SymExecWrapper: instruction = state.get_current_instruction() - op = instruction['opcode'] + op = instruction["opcode"] - if op in ('CALL', 'CALLCODE', 'DELEGATECALL', 'STATICCALL'): + if op in ("CALL", "CALLCODE", "DELEGATECALL", "STATICCALL"): stack = state.mstate.stack - if op in ('CALL', 'CALLCODE'): - gas, to, value, meminstart, meminsz, memoutstart, memoutsz = \ - get_variable(stack[-1]), get_variable(stack[-2]), get_variable(stack[-3]), get_variable(stack[-4]), get_variable(stack[-5]), get_variable(stack[-6]), get_variable(stack[-7]) + if op in ("CALL", "CALLCODE"): + gas, to, value, meminstart, meminsz, memoutstart, memoutsz = ( + get_variable(stack[-1]), + get_variable(stack[-2]), + get_variable(stack[-3]), + get_variable(stack[-4]), + get_variable(stack[-5]), + get_variable(stack[-6]), + get_variable(stack[-7]), + ) if to.type == VarType.CONCRETE and to.val < 5: - # ignore prebuilts - continue - - if meminstart.type == VarType.CONCRETE and meminsz.type == VarType.CONCRETE: - self.calls.append(Call(self.nodes[key], state, state_index, op, to, gas, value, state.mstate.memory[meminstart.val:meminsz.val * 4])) + # ignore prebuilts + continue + + if ( + meminstart.type == VarType.CONCRETE + and meminsz.type == VarType.CONCRETE + ): + self.calls.append( + Call( + self.nodes[key], + state, + state_index, + op, + to, + gas, + value, + state.mstate.memory[ + meminstart.val : meminsz.val * 4 + ], + ) + ) else: - self.calls.append(Call(self.nodes[key], state, state_index, op, to, gas, value)) + self.calls.append( + Call( + self.nodes[key], + state, + state_index, + op, + to, + gas, + value, + ) + ) else: - gas, to, meminstart, meminsz, memoutstart, memoutsz = \ - get_variable(stack[-1]), get_variable(stack[-2]), get_variable(stack[-3]), get_variable(stack[-4]), get_variable(stack[-5]), get_variable(stack[-6]) - - self.calls.append(Call(self.nodes[key], state, state_index, op, to, gas)) - - elif op == 'SSTORE': + gas, to, meminstart, meminsz, memoutstart, memoutsz = ( + get_variable(stack[-1]), + get_variable(stack[-2]), + get_variable(stack[-3]), + get_variable(stack[-4]), + get_variable(stack[-5]), + get_variable(stack[-6]), + ) + + self.calls.append( + Call(self.nodes[key], state, state_index, op, to, gas) + ) + + elif op == "SSTORE": stack = copy.deepcopy(state.mstate.stack) address = state.environment.active_account.address @@ -90,9 +155,13 @@ class SymExecWrapper: self.sstors[address] = {} try: - self.sstors[address][str(index)].append(SStore(self.nodes[key], state, state_index, value)) + self.sstors[address][str(index)].append( + SStore(self.nodes[key], state, state_index, value) + ) except KeyError: - self.sstors[address][str(index)] = [SStore(self.nodes[key], state, state_index, value)] + self.sstors[address][str(index)] = [ + SStore(self.nodes[key], state, state_index, value) + ] state_index += 1 diff --git a/mythril/analysis/traceexplore.py b/mythril/analysis/traceexplore.py index 62d9ea50..d3e87cab 100644 --- a/mythril/analysis/traceexplore.py +++ b/mythril/analysis/traceexplore.py @@ -3,32 +3,65 @@ from mythril.laser.ethereum.svm import NodeFlags import re colors = [ - {'border': '#26996f', 'background': '#2f7e5b', 'highlight': {'border': '#fff', 'background': '#28a16f'}}, - {'border': '#9e42b3', 'background': '#842899', 'highlight': {'border': '#fff', 'background': '#933da6'}}, - {'border': '#b82323', 'background': '#991d1d', 'highlight': {'border': '#fff', 'background': '#a61f1f'}}, - {'border': '#4753bf', 'background': '#3b46a1', 'highlight': {'border': '#fff', 'background': '#424db3'}}, - {'border': '#26996f', 'background': '#2f7e5b', 'highlight': {'border': '#fff', 'background': '#28a16f'}}, - {'border': '#9e42b3', 'background': '#842899', 'highlight': {'border': '#fff', 'background': '#933da6'}}, - {'border': '#b82323', 'background': '#991d1d', 'highlight': {'border': '#fff', 'background': '#a61f1f'}}, - {'border': '#4753bf', 'background': '#3b46a1', 'highlight': {'border': '#fff', 'background': '#424db3'}}, + { + "border": "#26996f", + "background": "#2f7e5b", + "highlight": {"border": "#fff", "background": "#28a16f"}, + }, + { + "border": "#9e42b3", + "background": "#842899", + "highlight": {"border": "#fff", "background": "#933da6"}, + }, + { + "border": "#b82323", + "background": "#991d1d", + "highlight": {"border": "#fff", "background": "#a61f1f"}, + }, + { + "border": "#4753bf", + "background": "#3b46a1", + "highlight": {"border": "#fff", "background": "#424db3"}, + }, + { + "border": "#26996f", + "background": "#2f7e5b", + "highlight": {"border": "#fff", "background": "#28a16f"}, + }, + { + "border": "#9e42b3", + "background": "#842899", + "highlight": {"border": "#fff", "background": "#933da6"}, + }, + { + "border": "#b82323", + "background": "#991d1d", + "highlight": {"border": "#fff", "background": "#a61f1f"}, + }, + { + "border": "#4753bf", + "background": "#3b46a1", + "highlight": {"border": "#fff", "background": "#424db3"}, + }, ] + def get_serializable_statespace(statespace): - + nodes = [] edges = [] - + color_map = {} i = 0 for k in statespace.accounts: color_map[statespace.accounts[k].contract_name] = colors[i] i += 1 - + for node_key in statespace.nodes: node = statespace.nodes[node_key] - code = node.get_cfg_dict()['code'] + code = node.get_cfg_dict()["code"] code = re.sub("([0-9a-f]{8})[0-9a-f]+", lambda m: m.group(1) + "(...)", code) if NodeFlags.FUNC_ENTRY in node.flags: @@ -36,47 +69,51 @@ def get_serializable_statespace(statespace): code_split = code.split("\\n") - truncated_code = code if (len(code_split) < 7) else "\\n".join(code_split[:6]) + "\\n(click to expand +)" + truncated_code = ( + code + if (len(code_split) < 7) + else "\\n".join(code_split[:6]) + "\\n(click to expand +)" + ) + + color = color_map[node.get_cfg_dict()["contract_name"]] - color = color_map[node.get_cfg_dict()['contract_name']] - def get_state_accounts(state): state_accounts = [] for key in state.accounts: account = state.accounts[key].as_dict - account.pop('code', None) - account['balance'] = str(account['balance']) - + account.pop("code", None) + account["balance"] = str(account["balance"]) + storage = {} - for storage_key in account['storage']: - storage[str(storage_key)] = str(account['storage'][storage_key]) - - state_accounts.append({ - 'address': key, - 'storage': storage - }) - return state_accounts - - states = [{'machine': x.mstate.as_dict, 'accounts': get_state_accounts(x)} for x in node.states] - + for storage_key in account["storage"]: + storage[str(storage_key)] = str(account["storage"][storage_key]) + + state_accounts.append({"address": key, "storage": storage}) + return state_accounts + + states = [ + {"machine": x.mstate.as_dict, "accounts": get_state_accounts(x)} + for x in node.states + ] + for state in states: - state['machine']['stack'] = [str(s) for s in state['machine']['stack']] - state['machine']['memory'] = [str(m) for m in state['machine']['memory']] - - truncated_code = truncated_code.replace('\\n', '\n') - code = code.replace('\\n', '\n') - + state["machine"]["stack"] = [str(s) for s in state["machine"]["stack"]] + state["machine"]["memory"] = [str(m) for m in state["machine"]["memory"]] + + truncated_code = truncated_code.replace("\\n", "\n") + code = code.replace("\\n", "\n") + s_node = { - 'id': str(node_key), - 'func': str(node.function_name), - 'label': truncated_code, - 'code': code, - 'truncated': truncated_code, - 'states': states, - 'color': color, - 'instructions': code.split('\n') + "id": str(node_key), + "func": str(node.function_name), + "label": truncated_code, + "code": code, + "truncated": truncated_code, + "states": states, + "color": color, + "instructions": code.split("\n"), } - + nodes.append(s_node) for edge in statespace.edges: @@ -90,20 +127,19 @@ def get_serializable_statespace(statespace): except Z3Exception: label = str(edge.condition).replace("\n", "") - label = re.sub("([^_])([\d]{2}\d+)", lambda m: m.group(1) + hex(int(m.group(2))), label) + label = re.sub( + "([^_])([\d]{2}\d+)", lambda m: m.group(1) + hex(int(m.group(2))), label + ) code = re.sub("([0-9a-f]{8})[0-9a-f]+", lambda m: m.group(1) + "(...)", code) s_edge = { - 'from': str(edge.as_dict['from']), - 'to': str(edge.as_dict['to']), - 'arrows': 'to', - 'label': label, - 'smooth': { 'type': "cubicBezier" } + "from": str(edge.as_dict["from"]), + "to": str(edge.as_dict["to"]), + "arrows": "to", + "label": label, + "smooth": {"type": "cubicBezier"}, } - + edges.append(s_edge) - return { - 'edges': edges, - 'nodes': nodes - } + return {"edges": edges, "nodes": nodes} diff --git a/mythril/disassembler/disassembly.py b/mythril/disassembler/disassembly.py index 93ec0c27..6d3ef817 100644 --- a/mythril/disassembler/disassembly.py +++ b/mythril/disassembler/disassembly.py @@ -4,7 +4,6 @@ import logging class Disassembly(object): - def __init__(self, code, enable_online_lookup=True): self.instruction_list = asm.disassemble(util.safe_decode(code)) self.func_hashes = [] @@ -12,20 +11,25 @@ class Disassembly(object): self.addr_to_func = {} self.bytecode = code - signatures = SignatureDb(enable_online_lookup=enable_online_lookup) # control if you want to have online sighash lookups + signatures = SignatureDb( + enable_online_lookup=enable_online_lookup + ) # control if you want to have online sighash lookups try: signatures.open() # open from default locations except FileNotFoundError: - logging.info("Missing function signature file. Resolving of function names from signature file disabled.") + logging.info( + "Missing function signature file. Resolving of function names from signature file disabled." + ) # Parse jump table & resolve function names # Need to take from PUSH1 to PUSH4 because solc seems to remove excess 0s at the beginning for optimizing - jmptable_indices = asm.find_opcode_sequence([("PUSH1", "PUSH2", "PUSH3", "PUSH4"), ("EQ",)], - self.instruction_list) + jmptable_indices = asm.find_opcode_sequence( + [("PUSH1", "PUSH2", "PUSH3", "PUSH4"), ("EQ",)], self.instruction_list + ) for i in jmptable_indices: - func_hash = self.instruction_list[i]['argument'] + func_hash = self.instruction_list[i]["argument"] # Append with missing 0s at the beginning func_hash = "0x" + func_hash[2:].rjust(8, "0") @@ -37,7 +41,9 @@ class Disassembly(object): func_names = signatures.get(func_hash) if len(func_names) > 1: # ambigious result - func_name = "**ambiguous** %s" % func_names[0] # return first hit but note that result was ambiguous + func_name = ( + "**ambiguous** %s" % func_names[0] + ) # return first hit but note that result was ambiguous else: # only one item func_name = func_names[0] @@ -45,7 +51,7 @@ class Disassembly(object): func_name = "_function_" + func_hash try: - offset = self.instruction_list[i + 2]['argument'] + offset = self.instruction_list[i + 2]["argument"] jump_target = int(offset, 16) self.func_to_addr[func_name] = jump_target diff --git a/mythril/ether/asm.py b/mythril/ether/asm.py index 985b2f07..4540bd72 100644 --- a/mythril/ether/asm.py +++ b/mythril/ether/asm.py @@ -4,11 +4,11 @@ from ethereum.opcodes import opcodes from mythril.ether import util -regex_PUSH = re.compile('^PUSH(\d*)$') +regex_PUSH = re.compile("^PUSH(\d*)$") # Additional mnemonic to catch failed assertions -opcodes[254] = ['ASSERT_FAIL', 0, 0, 0] +opcodes[254] = ["ASSERT_FAIL", 0, 0, 0] def instruction_list_to_easm(instruction_list): @@ -16,10 +16,10 @@ def instruction_list_to_easm(instruction_list): for instruction in instruction_list: - easm += str(instruction['address']) + " " + instruction['opcode'] + easm += str(instruction["address"]) + " " + instruction["opcode"] - if 'argument' in instruction: - easm += " " + instruction['argument'] + if "argument" in instruction: + easm += " " + instruction["argument"] easm += "\n" @@ -28,11 +28,11 @@ def instruction_list_to_easm(instruction_list): def easm_to_instruction_list(easm): - regex_CODELINE = re.compile('^([A-Z0-9]+)(?:\s+([0-9a-fA-Fx]+))?$') + regex_CODELINE = re.compile("^([A-Z0-9]+)(?:\s+([0-9a-fA-Fx]+))?$") instruction_list = [] - codelines = easm.split('\n') + codelines = easm.split("\n") for line in codelines: @@ -42,10 +42,10 @@ def easm_to_instruction_list(easm): # Invalid code line continue - instruction = {'opcode': m.group(1)} + instruction = {"opcode": m.group(1)} if m.group(2): - instruction['argument'] = m.group(2)[2:] + instruction["argument"] = m.group(2)[2:] instruction_list.append(instruction) @@ -70,13 +70,13 @@ def find_opcode_sequence(pattern, instruction_list): for i in range(0, len(instruction_list) - pattern_length + 1): - if instruction_list[i]['opcode'] in pattern[0]: + if instruction_list[i]["opcode"] in pattern[0]: matched = True for j in range(1, len(pattern)): - if not (instruction_list[i + j]['opcode'] in pattern[j]): + if not (instruction_list[i + j]["opcode"] in pattern[j]): matched = False break @@ -99,7 +99,7 @@ def disassemble(bytecode): while addr < length: - instruction = {'address': addr} + instruction = {"address": addr} try: if sys.version_info > (3, 0): @@ -110,21 +110,20 @@ def disassemble(bytecode): except KeyError: # invalid opcode - instruction_list.append({'address': addr, 'opcode': "INVALID"}) + instruction_list.append({"address": addr, "opcode": "INVALID"}) addr += 1 continue - instruction['opcode'] = opcode[0] + instruction["opcode"] = opcode[0] m = re.search(regex_PUSH, opcode[0]) if m: - argument = bytecode[addr+1:addr+1+int(m.group(1))] - instruction['argument'] = "0x" + argument.hex() + argument = bytecode[addr + 1 : addr + 1 + int(m.group(1))] + instruction["argument"] = "0x" + argument.hex() addr += int(m.group(1)) - instruction_list.append(instruction) addr += 1 @@ -139,14 +138,14 @@ def assemble(instruction_list): for instruction in instruction_list: try: - opcode = get_opcode_from_name(instruction['opcode']) + opcode = get_opcode_from_name(instruction["opcode"]) except RuntimeError: - opcode = 0xbb + opcode = 0xBB - bytecode += opcode.to_bytes(1, byteorder='big') + bytecode += opcode.to_bytes(1, byteorder="big") - if 'argument' in instruction: + if "argument" in instruction: - bytecode += util.safe_decode(instruction['argument']) + bytecode += util.safe_decode(instruction["argument"]) return bytecode diff --git a/mythril/ether/ethcontract.py b/mythril/ether/ethcontract.py index 801e063b..b2a55106 100644 --- a/mythril/ether/ethcontract.py +++ b/mythril/ether/ethcontract.py @@ -5,30 +5,33 @@ import re class ETHContract(persistent.Persistent): + def __init__( + self, code, creation_code="", name="Unknown", enable_online_lookup=True + ): - def __init__(self, code, creation_code="", name="Unknown", enable_online_lookup=True): - # Workaround: We currently do not support compile-time linking. # Dynamic contract addresses of the format __[contract-name]_____________ are replaced with a generic address # Apply this for creation_code & code - creation_code = re.sub(r'(_{2}.{38})', 'aa' * 20, creation_code) - code = re.sub(r'(_{2}.{38})', 'aa' * 20, code) + creation_code = re.sub(r"(_{2}.{38})", "aa" * 20, creation_code) + code = re.sub(r"(_{2}.{38})", "aa" * 20, code) self.creation_code = creation_code self.name = name self.code = code self.disassembly = Disassembly(code, enable_online_lookup=enable_online_lookup) - self.creation_disassembly = Disassembly(creation_code, enable_online_lookup=enable_online_lookup) + self.creation_disassembly = Disassembly( + creation_code, enable_online_lookup=enable_online_lookup + ) def as_dict(self): return { - 'address': self.address, - 'name': self.name, - 'code': self.code, - 'creation_code': self.creation_code, - 'disassembly': self.disassembly + "address": self.address, + "name": self.name, + "code": self.code, + "creation_code": self.creation_code, + "disassembly": self.disassembly, } def get_easm(self): @@ -37,7 +40,7 @@ class ETHContract(persistent.Persistent): def matches_expression(self, expression): - str_eval = '' + str_eval = "" easm_code = None tokens = re.split("\s+(and|or|not)\s+", expression, re.IGNORECASE) @@ -48,23 +51,23 @@ class ETHContract(persistent.Persistent): str_eval += " " + token + " " continue - m = re.match(r'^code#([a-zA-Z0-9\s,\[\]]+)#', token) + m = re.match(r"^code#([a-zA-Z0-9\s,\[\]]+)#", token) if m: if easm_code is None: easm_code = self.get_easm() code = m.group(1).replace(",", "\\n") - str_eval += "\"" + code + "\" in easm_code" + str_eval += '"' + code + '" in easm_code' continue - m = re.match(r'^func#([a-zA-Z0-9\s_,(\\)\[\]]+)#$', token) + m = re.match(r"^func#([a-zA-Z0-9\s_,(\\)\[\]]+)#$", token) if m: sign_hash = "0x" + utils.sha3(m.group(1))[:4].hex() - str_eval += "\"" + sign_hash + "\" in self.disassembly.func_hashes" + str_eval += '"' + sign_hash + '" in self.disassembly.func_hashes' continue diff --git a/mythril/ether/evm.py b/mythril/ether/evm.py index 0bcc7206..020dcd94 100644 --- a/mythril/ether/evm.py +++ b/mythril/ether/evm.py @@ -7,69 +7,76 @@ from io import StringIO import re -def trace(code, calldata = ""): +def trace(code, calldata=""): - log_handlers = ['eth.vm.op', 'eth.vm.op.stack', 'eth.vm.op.memory', 'eth.vm.op.storage'] + log_handlers = [ + "eth.vm.op", + "eth.vm.op.stack", + "eth.vm.op.memory", + "eth.vm.op.storage", + ] - output = StringIO() - stream_handler = StreamHandler(output) + output = StringIO() + stream_handler = StreamHandler(output) - for handler in log_handlers: - log_vm_op = get_logger(handler) - log_vm_op.setLevel("TRACE") - log_vm_op.addHandler(stream_handler) + for handler in log_handlers: + log_vm_op = get_logger(handler) + log_vm_op.setLevel("TRACE") + log_vm_op.addHandler(stream_handler) - addr = bytes.fromhex('0123456789ABCDEF0123456789ABCDEF01234567') + addr = bytes.fromhex("0123456789ABCDEF0123456789ABCDEF01234567") - state = State() + state = State() - ext = messages.VMExt(state, transactions.Transaction(0, 0, 21000, addr, 0, addr)) + ext = messages.VMExt(state, transactions.Transaction(0, 0, 21000, addr, 0, addr)) - message = vm.Message(addr, addr, 0, 21000, calldata) + message = vm.Message(addr, addr, 0, 21000, calldata) - res, gas, dat = vm.vm_execute(ext, message, util.safe_decode(code)) + res, gas, dat = vm.vm_execute(ext, message, util.safe_decode(code)) - stream_handler.flush() + stream_handler.flush() - ret = output.getvalue() + ret = output.getvalue() - lines = ret.split("\n") + lines = ret.split("\n") - trace = [] + trace = [] - for line in lines: + for line in lines: - m = re.search(r'pc=b\'(\d+)\'.*op=([A-Z0-9]+)', line) + m = re.search(r"pc=b\'(\d+)\'.*op=([A-Z0-9]+)", line) - if m: - pc = m.group(1) - op = m.group(2) + if m: + pc = m.group(1) + op = m.group(2) - m = re.match(r'.*stack=(\[.*?\])', line) - - if m: + m = re.match(r".*stack=(\[.*?\])", line) - stackitems = re.findall(r'b\'(\d+)\'', m.group(1)) + if m: - stack = "[" + stackitems = re.findall(r"b\'(\d+)\'", m.group(1)) - if len(stackitems): + stack = "[" - for i in range(0, len(stackitems) - 1): - stack += hex(int(stackitems[i])) + ", " + if len(stackitems): - stack += hex(int(stackitems[-1])) + for i in range(0, len(stackitems) - 1): + stack += hex(int(stackitems[i])) + ", " - stack += "]" + stack += hex(int(stackitems[-1])) - else: - stack = "[]" + stack += "]" - if re.match(r'^PUSH.*', op): - val = re.search(r'pushvalue=(\d+)', line).group(1) - pushvalue = hex(int(val)) - trace.append({'pc': pc, 'op': op, 'stack': stack, 'pushvalue': pushvalue}) - else: - trace.append({'pc': pc, 'op': op, 'stack': stack}) + else: + stack = "[]" - return trace + if re.match(r"^PUSH.*", op): + val = re.search(r"pushvalue=(\d+)", line).group(1) + pushvalue = hex(int(val)) + trace.append( + {"pc": pc, "op": op, "stack": stack, "pushvalue": pushvalue} + ) + else: + trace.append({"pc": pc, "op": op, "stack": stack}) + + return trace diff --git a/mythril/ether/soliditycontract.py b/mythril/ether/soliditycontract.py index 0642ea21..14995b8b 100644 --- a/mythril/ether/soliditycontract.py +++ b/mythril/ether/soliditycontract.py @@ -5,7 +5,6 @@ from mythril.exceptions import NoContractFoundError class SourceMapping: - def __init__(self, solidity_file_idx, offset, length, lineno): self.solidity_file_idx = solidity_file_idx self.offset = offset @@ -14,14 +13,12 @@ class SourceMapping: class SolidityFile: - def __init__(self, filename, data): self.filename = filename self.data = data class SourceCodeInfo: - def __init__(self, filename, lineno, code): self.filename = filename self.lineno = lineno @@ -30,22 +27,21 @@ class SourceCodeInfo: def get_contracts_from_file(input_file, solc_args=None): data = get_solc_json(input_file, solc_args=solc_args) - for key, contract in data['contracts'].items(): + for key, contract in data["contracts"].items(): filename, name = key.split(":") - if filename == input_file and len(contract['bin-runtime']): + if filename == input_file and len(contract["bin-runtime"]): yield SolidityContract(input_file, name, solc_args) class SolidityContract(ETHContract): - def __init__(self, input_file, name=None, solc_args=None): data = get_solc_json(input_file, solc_args=solc_args) self.solidity_files = [] - for filename in data['sourceList']: - with open(filename, 'r', encoding='utf-8') as file: + for filename in data["sourceList"]: + with open(filename, "r", encoding="utf-8") as file: code = file.read() self.solidity_files.append(SolidityFile(filename, code)) @@ -55,28 +51,32 @@ class SolidityContract(ETHContract): srcmap_constructor = [] srcmap = [] if name: - for key, contract in sorted(data['contracts'].items()): + for key, contract in sorted(data["contracts"].items()): filename, _name = key.split(":") - if filename == input_file and name == _name and len(contract['bin-runtime']): - code = contract['bin-runtime'] - creation_code = contract['bin'] - srcmap = contract['srcmap-runtime'].split(";") - srcmap_constructor = contract['srcmap'].split(";") + if ( + filename == input_file + and name == _name + and len(contract["bin-runtime"]) + ): + code = contract["bin-runtime"] + creation_code = contract["bin"] + srcmap = contract["srcmap-runtime"].split(";") + srcmap_constructor = contract["srcmap"].split(";") has_contract = True break # If no contract name is specified, get the last bytecode entry for the input file else: - for key, contract in sorted(data['contracts'].items()): + for key, contract in sorted(data["contracts"].items()): filename, name = key.split(":") - if filename == input_file and len(contract['bin-runtime']): - code = contract['bin-runtime'] - creation_code = contract['bin'] - srcmap = contract['srcmap-runtime'].split(";") - srcmap_constructor = contract['srcmap'].split(";") + if filename == input_file and len(contract["bin-runtime"]): + code = contract["bin-runtime"] + creation_code = contract["bin"] + srcmap = contract["srcmap-runtime"].split(";") + srcmap_constructor = contract["srcmap"].split(";") has_contract = True if not has_contract: @@ -102,7 +102,9 @@ class SolidityContract(ETHContract): offset = mappings[index].offset length = mappings[index].length - code = solidity_file.data.encode('utf-8')[offset:offset + length].decode('utf-8', errors="ignore") + code = solidity_file.data.encode("utf-8")[offset : offset + length].decode( + "utf-8", errors="ignore" + ) lineno = mappings[index].lineno return SourceCodeInfo(filename, lineno, code) @@ -120,6 +122,11 @@ class SolidityContract(ETHContract): if len(mapping) > 2 and len(mapping[2]) > 0: idx = int(mapping[2]) - lineno = self.solidity_files[idx].data.encode('utf-8')[0:offset].count('\n'.encode('utf-8')) + 1 + lineno = ( + self.solidity_files[idx] + .data.encode("utf-8")[0:offset] + .count("\n".encode("utf-8")) + + 1 + ) mappings.append(SourceMapping(idx, offset, length, lineno)) diff --git a/mythril/ether/util.py b/mythril/ether/util.py index 6b351665..93cb2c06 100644 --- a/mythril/ether/util.py +++ b/mythril/ether/util.py @@ -39,9 +39,14 @@ def get_solc_json(file, solc_binary="solc", solc_args=None): ret = p.returncode if ret != 0: - raise CompilerError("Solc experienced a fatal error (code %d).\n\n%s" % (ret, stderr.decode('UTF-8'))) + raise CompilerError( + "Solc experienced a fatal error (code %d).\n\n%s" + % (ret, stderr.decode("UTF-8")) + ) except FileNotFoundError: - raise CompilerError("Compiler not found. Make sure that solc is installed and in PATH, or set the SOLC environment variable.") + raise CompilerError( + "Compiler not found. Make sure that solc is installed and in PATH, or set the SOLC environment variable." + ) out = stdout.decode("UTF-8") @@ -59,7 +64,7 @@ def encode_calldata(func_name, arg_types, args): def get_random_address(): - return binascii.b2a_hex(os.urandom(20)).decode('UTF-8') + return binascii.b2a_hex(os.urandom(20)).decode("UTF-8") def get_indexed_address(index): @@ -67,7 +72,9 @@ def get_indexed_address(index): def solc_exists(version): - solc_binary = os.path.join(os.environ['HOME'], ".py-solc/solc-v" + version, "bin/solc") + solc_binary = os.path.join( + os.environ["HOME"], ".py-solc/solc-v" + version, "bin/solc" + ) if os.path.exists(solc_binary): return True else: diff --git a/mythril/ethereum/interface/leveldb/accountindexing.py b/mythril/ethereum/interface/leveldb/accountindexing.py index 7578afd1..21203740 100644 --- a/mythril/ethereum/interface/leveldb/accountindexing.py +++ b/mythril/ethereum/interface/leveldb/accountindexing.py @@ -39,13 +39,13 @@ class ReceiptForStorage(rlp.Serializable): """ fields = [ - ('state_root', binary), - ('cumulative_gas_used', big_endian_int), - ('bloom', int256), - ('tx_hash', hash32), - ('contractAddress', address), - ('logs', CountableList(Log)), - ('gas_used', big_endian_int) + ("state_root", binary), + ("cumulative_gas_used", big_endian_int), + ("bloom", int256), + ("tx_hash", hash32), + ("contractAddress", address), + ("logs", CountableList(Log)), + ("gas_used", big_endian_int), ] @@ -77,7 +77,9 @@ class AccountIndexer(object): """ Processesing method """ - logging.debug("Processing blocks %d to %d" % (startblock, startblock + BATCH_SIZE)) + logging.debug( + "Processing blocks %d to %d" % (startblock, startblock + BATCH_SIZE) + ) addresses = [] @@ -87,7 +89,9 @@ class AccountIndexer(object): receipts = self.db.reader._get_block_receipts(hash, blockNum) for receipt in receipts: - if receipt.contractAddress is not None and not all(b == 0 for b in receipt.contractAddress): + if receipt.contractAddress is not None and not all( + b == 0 for b in receipt.contractAddress + ): addresses.append(receipt.contractAddress) else: if len(addresses) == 0: @@ -113,15 +117,21 @@ class AccountIndexer(object): # in fast sync head block is at 0 (e.g. in fastSync), we can't use it to determine length if self.lastBlock is not None and self.lastBlock == 0: - self.lastBlock = 2e+9 + self.lastBlock = 2e9 - if self.lastBlock is None or (self.lastProcessedBlock is not None and self.lastBlock <= self.lastProcessedBlock): + if self.lastBlock is None or ( + self.lastProcessedBlock is not None + and self.lastBlock <= self.lastProcessedBlock + ): return blockNum = 0 if self.lastProcessedBlock is not None: blockNum = self.lastProcessedBlock + 1 - print("Updating hash-to-address index from block " + str(self.lastProcessedBlock)) + print( + "Updating hash-to-address index from block " + + str(self.lastProcessedBlock) + ) else: print("Starting hash-to-address index") @@ -148,7 +158,10 @@ class AccountIndexer(object): blockNum = min(blockNum + BATCH_SIZE, self.lastBlock + 1) cost_time = time.time() - ether.start_time - print("%d blocks processed (in %d seconds), %d unique addresses found, next block: %d" % (processed, cost_time, count, min(self.lastBlock, blockNum))) + print( + "%d blocks processed (in %d seconds), %d unique addresses found, next block: %d" + % (processed, cost_time, count, min(self.lastBlock, blockNum)) + ) self.lastProcessedBlock = blockNum - 1 self.db.writer._set_last_indexed_number(self.lastProcessedBlock) diff --git a/mythril/ethereum/interface/leveldb/client.py b/mythril/ethereum/interface/leveldb/client.py index a1b4323b..b3500baf 100644 --- a/mythril/ethereum/interface/leveldb/client.py +++ b/mythril/ethereum/interface/leveldb/client.py @@ -1,7 +1,10 @@ import binascii import rlp from mythril.ethereum.interface.leveldb.accountindexing import CountableList -from mythril.ethereum.interface.leveldb.accountindexing import ReceiptForStorage, AccountIndexer +from mythril.ethereum.interface.leveldb.accountindexing import ( + ReceiptForStorage, + AccountIndexer, +) import logging from ethereum import utils from ethereum.block import BlockHeader, Block @@ -12,17 +15,19 @@ from mythril.exceptions import AddressNotFoundError # Per https://github.com/ethereum/go-ethereum/blob/master/core/rawdb/schema.go # prefixes and suffixes for keys in geth -header_prefix = b'h' # header_prefix + num (uint64 big endian) + hash -> header -body_prefix = b'b' # body_prefix + num (uint64 big endian) + hash -> block body -num_suffix = b'n' # header_prefix + num (uint64 big endian) + num_suffix -> hash -block_hash_prefix = b'H' # block_hash_prefix + hash -> num (uint64 big endian) -block_receipts_prefix = b'r' # block_receipts_prefix + num (uint64 big endian) + hash -> block receipts +header_prefix = b"h" # header_prefix + num (uint64 big endian) + hash -> header +body_prefix = b"b" # body_prefix + num (uint64 big endian) + hash -> block body +num_suffix = b"n" # header_prefix + num (uint64 big endian) + num_suffix -> hash +block_hash_prefix = b"H" # block_hash_prefix + hash -> num (uint64 big endian) +block_receipts_prefix = ( + b"r" +) # block_receipts_prefix + num (uint64 big endian) + hash -> block receipts # known geth keys -head_header_key = b'LastBlock' # head (latest) header hash +head_header_key = b"LastBlock" # head (latest) header hash # custom prefixes -address_prefix = b'AM' # address_prefix + hash -> address +address_prefix = b"AM" # address_prefix + hash -> address # custom keys -address_mapping_head_key = b'accountMapping' # head (latest) number of indexed block +address_mapping_head_key = b"accountMapping" # head (latest) number of indexed block def _format_block_number(number): @@ -36,7 +41,7 @@ def _encode_hex(v): """ encodes hash as hex """ - return '0x' + utils.encode_hex(v) + return "0x" + utils.encode_hex(v) class LevelDBReader(object): @@ -83,7 +88,10 @@ class LevelDBReader(object): num = self._get_block_number(hash) self.head_block_header = self._get_block_header(hash, num) # find header with valid state - while not self.db.get(self.head_block_header.state_root) and self.head_block_header.prevhash is not None: + while ( + not self.db.get(self.head_block_header.state_root) + and self.head_block_header.prevhash is not None + ): hash = self.head_block_header.prevhash num = self._get_block_number(hash) self.head_block_header = self._get_block_header(hash, num) @@ -201,11 +209,11 @@ class EthLevelDB(object): try: address = _encode_hex(indexer.get_contract_by_hash(address_hash)) except AddressNotFoundError: - ''' + """ The hash->address mapping does not exist in our index. If the index is up-to-date, this likely means that the contract was created by an internal transaction. Skip this contract as right now we don't have a good solution for this. - ''' + """ continue @@ -264,4 +272,6 @@ class EthLevelDB(object): gets account storage data at position """ account = self.reader._get_account(address) - return _encode_hex(utils.zpad(utils.encode_int(account.get_storage_data(position)), 32)) + return _encode_hex( + utils.zpad(utils.encode_int(account.get_storage_data(position)), 32) + ) diff --git a/mythril/ethereum/interface/leveldb/state.py b/mythril/ethereum/interface/leveldb/state.py index e8f86331..678c36af 100644 --- a/mythril/ethereum/interface/leveldb/state.py +++ b/mythril/ethereum/interface/leveldb/state.py @@ -1,24 +1,39 @@ import rlp import binascii -from ethereum.utils import normalize_address, hash32, trie_root, \ - big_endian_int, address, int256, encode_hex, encode_int, \ - big_endian_to_int, int_to_addr, zpad, parse_as_bin, parse_as_int, \ - decode_hex, sha3, is_string, is_numeric +from ethereum.utils import ( + normalize_address, + hash32, + trie_root, + big_endian_int, + address, + int256, + encode_hex, + encode_int, + big_endian_to_int, + int_to_addr, + zpad, + parse_as_bin, + parse_as_int, + decode_hex, + sha3, + is_string, + is_numeric, +) from rlp.sedes import big_endian_int, Binary, binary, CountableList from ethereum import utils from ethereum import trie from ethereum.trie import Trie from ethereum.securetrie import SecureTrie -BLANK_HASH = utils.sha3(b'') -BLANK_ROOT = utils.sha3rlp(b'') +BLANK_HASH = utils.sha3(b"") +BLANK_ROOT = utils.sha3rlp(b"") STATE_DEFAULTS = { "txindex": 0, "gas_used": 0, "gas_limit": 3141592, "block_number": 0, - "block_coinbase": '\x00' * 20, + "block_coinbase": "\x00" * 20, "block_difficulty": 1, "timestamp": 0, "logs": [], @@ -37,10 +52,10 @@ class Account(rlp.Serializable): """ fields = [ - ('nonce', big_endian_int), - ('balance', big_endian_int), - ('storage', trie_root), - ('code_hash', hash32) + ("nonce", big_endian_int), + ("balance", big_endian_int), + ("storage", trie_root), + ("code_hash", hash32), ] def __init__(self, nonce, balance, storage, code_hash, db, address): @@ -69,7 +84,8 @@ class Account(rlp.Serializable): if key not in self.storage_cache: v = self.storage_trie.get(utils.encode_int32(key)) self.storage_cache[key] = utils.big_endian_to_int( - rlp.decode(v) if v else b'') + rlp.decode(v) if v else b"" + ) return self.storage_cache[key] @classmethod @@ -77,7 +93,7 @@ class Account(rlp.Serializable): """ creates a blank account """ - db.put(BLANK_HASH, b'') + db.put(BLANK_HASH, b"") o = cls(initial_nonce, 0, trie.BLANK_ROOT, BLANK_HASH, db, address) o.existent_at_start = False return o @@ -88,6 +104,7 @@ class Account(rlp.Serializable): """ return self.nonce == 0 and self.balance == 0 and self.code_hash == BLANK_HASH + class State: """ adjusted state from ethereum.state @@ -107,13 +124,14 @@ class State: if address in self.cache: return self.cache[address] rlpdata = self.secure_trie.get(address) - if rlpdata == trie.BLANK_NODE and len(address) == 32: # support for hashed addresses + if ( + rlpdata == trie.BLANK_NODE and len(address) == 32 + ): # support for hashed addresses rlpdata = self.trie.get(address) if rlpdata != trie.BLANK_NODE: o = rlp.decode(rlpdata, Account, db=self.db, address=address) else: - o = Account.blank_account( - self.db, address, 0) + o = Account.blank_account(self.db, address, 0) self.cache[address] = o o._mutable = True o._cached_rlp = None diff --git a/mythril/ethereum/interface/rpc/base_client.py b/mythril/ethereum/interface/rpc/base_client.py index 9234ecf9..e0c6e387 100644 --- a/mythril/ethereum/interface/rpc/base_client.py +++ b/mythril/ethereum/interface/rpc/base_client.py @@ -1,4 +1,4 @@ -from abc import (abstractmethod) +from abc import abstractmethod from .constants import BLOCK_TAGS, BLOCK_TAG_LATEST from .utils import hex_to_dec, validate_block @@ -8,13 +8,14 @@ ETH_DEFAULT_RPC_PORT = 8545 PARITY_DEFAULT_RPC_PORT = 8545 PYETHAPP_DEFAULT_RPC_PORT = 4000 MAX_RETRIES = 3 -JSON_MEDIA_TYPE = 'application/json' +JSON_MEDIA_TYPE = "application/json" -''' +""" This code is adapted from: https://github.com/ConsenSys/ethjsonrpc -''' -class BaseClient(object): +""" + +class BaseClient(object): @abstractmethod def _call(self, method, params=None, _id=1): pass @@ -25,7 +26,7 @@ class BaseClient(object): TESTED """ - return self._call('eth_coinbase') + return self._call("eth_coinbase") def eth_blockNumber(self): """ @@ -33,7 +34,7 @@ class BaseClient(object): TESTED """ - return hex_to_dec(self._call('eth_blockNumber')) + return hex_to_dec(self._call("eth_blockNumber")) def eth_getBalance(self, address=None, block=BLOCK_TAG_LATEST): """ @@ -43,7 +44,7 @@ class BaseClient(object): """ address = address or self.eth_coinbase() block = validate_block(block) - return hex_to_dec(self._call('eth_getBalance', [address, block])) + return hex_to_dec(self._call("eth_getBalance", [address, block])) def eth_getStorageAt(self, address=None, position=0, block=BLOCK_TAG_LATEST): """ @@ -52,7 +53,7 @@ class BaseClient(object): TESTED """ block = validate_block(block) - return self._call('eth_getStorageAt', [address, hex(position), block]) + return self._call("eth_getStorageAt", [address, hex(position), block]) def eth_getCode(self, address, default_block=BLOCK_TAG_LATEST): """ @@ -63,7 +64,7 @@ class BaseClient(object): if isinstance(default_block, str): if default_block not in BLOCK_TAGS: raise ValueError - return self._call('eth_getCode', [address, default_block]) + return self._call("eth_getCode", [address, default_block]) def eth_getBlockByNumber(self, block=BLOCK_TAG_LATEST, tx_objects=True): """ @@ -72,7 +73,7 @@ class BaseClient(object): TESTED """ block = validate_block(block) - return self._call('eth_getBlockByNumber', [block, tx_objects]) + return self._call("eth_getBlockByNumber", [block, tx_objects]) def eth_getTransactionReceipt(self, tx_hash): """ @@ -80,4 +81,4 @@ class BaseClient(object): TESTED """ - return self._call('eth_getTransactionReceipt', [tx_hash]) + return self._call("eth_getTransactionReceipt", [tx_hash]) diff --git a/mythril/ethereum/interface/rpc/client.py b/mythril/ethereum/interface/rpc/client.py index 1545092f..97b5af46 100644 --- a/mythril/ethereum/interface/rpc/client.py +++ b/mythril/ethereum/interface/rpc/client.py @@ -3,7 +3,12 @@ import logging import requests from requests.adapters import HTTPAdapter from requests.exceptions import ConnectionError as RequestsConnectionError -from .exceptions import (ConnectionError, BadStatusCodeError, BadJsonError, BadResponseError) +from .exceptions import ( + ConnectionError, + BadStatusCodeError, + BadJsonError, + BadResponseError, +) from .base_client import BaseClient GETH_DEFAULT_RPC_PORT = 8545 @@ -11,17 +16,19 @@ ETH_DEFAULT_RPC_PORT = 8545 PARITY_DEFAULT_RPC_PORT = 8545 PYETHAPP_DEFAULT_RPC_PORT = 4000 MAX_RETRIES = 3 -JSON_MEDIA_TYPE = 'application/json' +JSON_MEDIA_TYPE = "application/json" -''' +""" This code is adapted from: https://github.com/ConsenSys/ethjsonrpc -''' +""" + + class EthJsonRpc(BaseClient): """ Ethereum JSON-RPC client class """ - def __init__(self, host='localhost', port=GETH_DEFAULT_RPC_PORT, tls=False): + def __init__(self, host="localhost", port=GETH_DEFAULT_RPC_PORT, tls=False): self.host = host self.port = port self.tls = tls @@ -31,17 +38,12 @@ class EthJsonRpc(BaseClient): def _call(self, method, params=None, _id=1): params = params or [] - data = { - 'jsonrpc': '2.0', - 'method': method, - 'params': params, - 'id': _id, - } - scheme = 'http' + data = {"jsonrpc": "2.0", "method": method, "params": params, "id": _id} + scheme = "http" if self.tls: - scheme += 's' - url = '{}://{}:{}'.format(scheme, self.host, self.port) - headers = {'Content-Type': JSON_MEDIA_TYPE} + scheme += "s" + url = "{}://{}:{}".format(scheme, self.host, self.port) + headers = {"Content-Type": JSON_MEDIA_TYPE} logging.debug("rpc send: %s" % json.dumps(data)) try: r = self.session.post(url, headers=headers, data=json.dumps(data)) @@ -55,7 +57,7 @@ class EthJsonRpc(BaseClient): except ValueError: raise BadJsonError(r.text) try: - return response['result'] + return response["result"] except KeyError: raise BadResponseError(response) diff --git a/mythril/ethereum/interface/rpc/constants.py b/mythril/ethereum/interface/rpc/constants.py index 414c6d1c..a0dcd254 100644 --- a/mythril/ethereum/interface/rpc/constants.py +++ b/mythril/ethereum/interface/rpc/constants.py @@ -1,8 +1,4 @@ -BLOCK_TAG_EARLIEST = 'earliest' -BLOCK_TAG_LATEST = 'latest' -BLOCK_TAG_PENDING = 'pending' -BLOCK_TAGS = ( - BLOCK_TAG_EARLIEST, - BLOCK_TAG_LATEST, - BLOCK_TAG_PENDING, -) +BLOCK_TAG_EARLIEST = "earliest" +BLOCK_TAG_LATEST = "latest" +BLOCK_TAG_PENDING = "pending" +BLOCK_TAGS = (BLOCK_TAG_EARLIEST, BLOCK_TAG_LATEST, BLOCK_TAG_PENDING) diff --git a/mythril/ethereum/interface/rpc/utils.py b/mythril/ethereum/interface/rpc/utils.py index 5f98fcea..c0f701a2 100644 --- a/mythril/ethereum/interface/rpc/utils.py +++ b/mythril/ethereum/interface/rpc/utils.py @@ -13,12 +13,13 @@ def clean_hex(d): Convert decimal to hex and remove the "L" suffix that is appended to large numbers """ - return hex(d).rstrip('L') + return hex(d).rstrip("L") + def validate_block(block): if isinstance(block, str): if block not in BLOCK_TAGS: - raise ValueError('invalid block tag') + raise ValueError("invalid block tag") if isinstance(block, int): block = hex(block) return block @@ -28,11 +29,11 @@ def wei_to_ether(wei): """ Convert wei to ether """ - return 1.0 * wei / 10**18 + return 1.0 * wei / 10 ** 18 def ether_to_wei(ether): """ Convert ether to wei """ - return ether * 10**18 + return ether * 10 ** 18 diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index 1784add4..08a8d4e7 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -18,98 +18,218 @@ from mythril.version import VERSION def exit_with_error(format, message): - if format == 'text' or format == 'markdown': + if format == "text" or format == "markdown": print(message) else: - result = {'success': False, 'error': str(message), 'issues': []} + result = {"success": False, "error": str(message), "issues": []} print(json.dumps(result)) sys.exit() def main(): - parser = argparse.ArgumentParser(description='Security analysis of Ethereum smart contracts') - parser.add_argument("solidity_file", nargs='*') - - commands = parser.add_argument_group('commands') - commands.add_argument('-g', '--graph', help='generate a control flow graph') - commands.add_argument('-V', '--version', action='store_true', - help='print the Mythril version number and exit') - commands.add_argument('-x', '--fire-lasers', action='store_true', - help='detect vulnerabilities, use with -c, -a or solidity file(s)') - commands.add_argument('-t', '--truffle', action='store_true', - help='analyze a truffle project (run from project dir)') - commands.add_argument('-d', '--disassemble', action='store_true', help='print disassembly') - commands.add_argument('-j', '--statespace-json', help='dumps the statespace json', metavar='OUTPUT_FILE') - - inputs = parser.add_argument_group('input arguments') - inputs.add_argument('-c', '--code', help='hex-encoded bytecode string ("6060604052...")', metavar='BYTECODE') - inputs.add_argument('-f', '--codefile', help='file containing hex-encoded bytecode string', - metavar='BYTECODEFILE', type=argparse.FileType('r')) - inputs.add_argument('-a', '--address', help='pull contract from the blockchain', metavar='CONTRACT_ADDRESS') - inputs.add_argument('-l', '--dynld', action='store_true', help='auto-load dependencies from the blockchain') - - outputs = parser.add_argument_group('output formats') - outputs.add_argument('-o', '--outform', choices=['text', 'markdown', 'json'], default='text', - help='report output format', metavar='') - outputs.add_argument('--verbose-report', action='store_true', help='Include debugging information in report') - - database = parser.add_argument_group('local contracts database') - database.add_argument('-s', '--search', help='search the contract database', metavar='EXPRESSION') - database.add_argument('--leveldb-dir', help='specify leveldb directory for search or direct access operations', metavar='LEVELDB_PATH') - - utilities = parser.add_argument_group('utilities') - utilities.add_argument('--hash', help='calculate function signature hash', metavar='SIGNATURE') - utilities.add_argument('--storage', help='read state variables from storage index, use with -a', - metavar='INDEX,NUM_SLOTS,[array] / mapping,INDEX,[KEY1, KEY2...]') - utilities.add_argument('--solv', - help='specify solidity compiler version. If not present, will try to install it (Experimental)', - metavar='SOLV') - utilities.add_argument('--contract-hash-to-address', help='returns corresponding address for a contract address hash', metavar='SHA3_TO_LOOK_FOR') - - options = parser.add_argument_group('options') - options.add_argument('-m', '--modules', help='Comma-separated list of security analysis modules', metavar='MODULES') - options.add_argument('--max-depth', type=int, default=22, help='Maximum recursion depth for symbolic execution') - options.add_argument('--max-transaction-count', type=int, default=3, help='Maximum number of transactions issued by laser') - options.add_argument('--strategy', choices=['dfs', 'bfs'], default='dfs', help='Symbolic execution strategy') - options.add_argument('--execution-timeout', type=int, default=600, help="The amount of seconds to spend on symbolic execution") - options.add_argument('--create-timeout', type=int, default=10, help="The amount of seconds to spend on " - "the initial contract creation") - options.add_argument('--solc-args', help='Extra arguments for solc') - options.add_argument('--phrack', action='store_true', help='Phrack-style call graph') - options.add_argument('--enable-physics', action='store_true', help='enable graph physics simulation') - options.add_argument('-v', type=int, help='log level (0-2)', metavar='LOG_LEVEL') - - rpc = parser.add_argument_group('RPC options') - rpc.add_argument('-i', action='store_true', help='Preset: Infura Node service (Mainnet)') - rpc.add_argument('--rpc', help='custom RPC settings', metavar='HOST:PORT / ganache / infura-[network_name]') - rpc.add_argument('--rpctls', type=bool, default=False, help='RPC connection over TLS') + parser = argparse.ArgumentParser( + description="Security analysis of Ethereum smart contracts" + ) + parser.add_argument("solidity_file", nargs="*") + + commands = parser.add_argument_group("commands") + commands.add_argument("-g", "--graph", help="generate a control flow graph") + commands.add_argument( + "-V", + "--version", + action="store_true", + help="print the Mythril version number and exit", + ) + commands.add_argument( + "-x", + "--fire-lasers", + action="store_true", + help="detect vulnerabilities, use with -c, -a or solidity file(s)", + ) + commands.add_argument( + "-t", + "--truffle", + action="store_true", + help="analyze a truffle project (run from project dir)", + ) + commands.add_argument( + "-d", "--disassemble", action="store_true", help="print disassembly" + ) + commands.add_argument( + "-j", + "--statespace-json", + help="dumps the statespace json", + metavar="OUTPUT_FILE", + ) + + inputs = parser.add_argument_group("input arguments") + inputs.add_argument( + "-c", + "--code", + help='hex-encoded bytecode string ("6060604052...")', + metavar="BYTECODE", + ) + inputs.add_argument( + "-f", + "--codefile", + help="file containing hex-encoded bytecode string", + metavar="BYTECODEFILE", + type=argparse.FileType("r"), + ) + inputs.add_argument( + "-a", + "--address", + help="pull contract from the blockchain", + metavar="CONTRACT_ADDRESS", + ) + inputs.add_argument( + "-l", + "--dynld", + action="store_true", + help="auto-load dependencies from the blockchain", + ) + + outputs = parser.add_argument_group("output formats") + outputs.add_argument( + "-o", + "--outform", + choices=["text", "markdown", "json"], + default="text", + help="report output format", + metavar="", + ) + outputs.add_argument( + "--verbose-report", + action="store_true", + help="Include debugging information in report", + ) + + database = parser.add_argument_group("local contracts database") + database.add_argument( + "-s", "--search", help="search the contract database", metavar="EXPRESSION" + ) + database.add_argument( + "--leveldb-dir", + help="specify leveldb directory for search or direct access operations", + metavar="LEVELDB_PATH", + ) + + utilities = parser.add_argument_group("utilities") + utilities.add_argument( + "--hash", help="calculate function signature hash", metavar="SIGNATURE" + ) + utilities.add_argument( + "--storage", + help="read state variables from storage index, use with -a", + metavar="INDEX,NUM_SLOTS,[array] / mapping,INDEX,[KEY1, KEY2...]", + ) + utilities.add_argument( + "--solv", + help="specify solidity compiler version. If not present, will try to install it (Experimental)", + metavar="SOLV", + ) + utilities.add_argument( + "--contract-hash-to-address", + help="returns corresponding address for a contract address hash", + metavar="SHA3_TO_LOOK_FOR", + ) + + options = parser.add_argument_group("options") + options.add_argument( + "-m", + "--modules", + help="Comma-separated list of security analysis modules", + metavar="MODULES", + ) + options.add_argument( + "--max-depth", + type=int, + default=22, + help="Maximum recursion depth for symbolic execution", + ) + options.add_argument( + "--max-transaction-count", + type=int, + default=3, + help="Maximum number of transactions issued by laser", + ) + options.add_argument( + "--strategy", + choices=["dfs", "bfs"], + default="dfs", + help="Symbolic execution strategy", + ) + options.add_argument( + "--execution-timeout", + type=int, + default=600, + help="The amount of seconds to spend on symbolic execution", + ) + options.add_argument( + "--create-timeout", + type=int, + default=10, + help="The amount of seconds to spend on " "the initial contract creation", + ) + options.add_argument("--solc-args", help="Extra arguments for solc") + options.add_argument( + "--phrack", action="store_true", help="Phrack-style call graph" + ) + options.add_argument( + "--enable-physics", action="store_true", help="enable graph physics simulation" + ) + options.add_argument("-v", type=int, help="log level (0-2)", metavar="LOG_LEVEL") + + rpc = parser.add_argument_group("RPC options") + rpc.add_argument( + "-i", action="store_true", help="Preset: Infura Node service (Mainnet)" + ) + rpc.add_argument( + "--rpc", + help="custom RPC settings", + metavar="HOST:PORT / ganache / infura-[network_name]", + ) + rpc.add_argument( + "--rpctls", type=bool, default=False, help="RPC connection over TLS" + ) # Get config values args = parser.parse_args() if args.version: - if args.outform == 'json': - print(json.dumps({'version_str': VERSION})) + if args.outform == "json": + print(json.dumps({"version_str": VERSION})) else: print("Mythril version {}".format(VERSION)) sys.exit() # Parse cmdline args - if not (args.search or args.hash or args.disassemble or args.graph or args.fire_lasers - or args.storage or args.truffle or args.statespace_json or args.contract_hash_to_address): + if not ( + args.search + or args.hash + or args.disassemble + or args.graph + or args.fire_lasers + or args.storage + or args.truffle + or args.statespace_json + or args.contract_hash_to_address + ): parser.print_help() sys.exit() if args.v: if 0 <= args.v < 3: coloredlogs.install( - fmt='%(name)s[%(process)d] %(levelname)s %(message)s', - level=[logging.NOTSET, logging.INFO, logging.DEBUG][args.v] + fmt="%(name)s[%(process)d] %(levelname)s %(message)s", + level=[logging.NOTSET, logging.INFO, logging.DEBUG][args.v], ) else: - exit_with_error(args.outform, "Invalid -v value, you can find valid values in usage") + exit_with_error( + args.outform, "Invalid -v value, you can find valid values in usage" + ) # -- commands -- if args.hash: @@ -121,8 +241,7 @@ def main(): # infura = None, rpc = None, rpctls = None # solc_args = None, dynld = None, max_recursion_depth = 12): - mythril = Mythril(solv=args.solv, dynld=args.dynld, - solc_args=args.solc_args) + mythril = Mythril(solv=args.solv, dynld=args.dynld, solc_args=args.solc_args) if args.dynld and not (args.rpc or args.i): mythril.set_api_from_config_path() @@ -136,7 +255,9 @@ def main(): mythril.set_api_rpc_localhost() elif args.search or args.contract_hash_to_address: # Open LevelDB if necessary - mythril.set_api_leveldb(mythril.leveldb_dir if not args.leveldb_dir else args.leveldb_dir) + mythril.set_api_leveldb( + mythril.leveldb_dir if not args.leveldb_dir else args.leveldb_dir + ) if args.search: # Database search ops @@ -158,7 +279,8 @@ def main(): mythril.analyze_truffle_project(args) except FileNotFoundError: print( - "Build directory not found. Make sure that you start the analysis from the project root, and that 'truffle compile' has executed successfully.") + "Build directory not found. Make sure that you start the analysis from the project root, and that 'truffle compile' has executed successfully." + ) sys.exit() # Load / compile input contracts @@ -168,7 +290,7 @@ def main(): # Load from bytecode address, _ = mythril.load_from_bytecode(args.code) elif args.codefile: - bytecode = ''.join([l.strip() for l in args.codefile if len(l.strip()) > 0]) + bytecode = "".join([l.strip() for l in args.codefile if len(l.strip()) > 0]) address, _ = mythril.load_from_bytecode(bytecode) elif args.address: # Get bytecode from a contract address @@ -176,37 +298,55 @@ def main(): elif args.solidity_file: # Compile Solidity source file(s) if args.graph and len(args.solidity_file) > 1: - exit_with_error(args.outform, - "Cannot generate call graphs from multiple input files. Please do it one at a time.") + exit_with_error( + args.outform, + "Cannot generate call graphs from multiple input files. Please do it one at a time.", + ) address, _ = mythril.load_from_solidity(args.solidity_file) # list of files else: - exit_with_error(args.outform, - "No input bytecode. Please provide EVM code via -c BYTECODE, -a ADDRESS, or -i SOLIDITY_FILES") + exit_with_error( + args.outform, + "No input bytecode. Please provide EVM code via -c BYTECODE, -a ADDRESS, or -i SOLIDITY_FILES", + ) # Commands if args.storage: if not args.address: - exit_with_error(args.outform, - "To read storage, provide the address of a deployed contract with the -a option.") - - storage = mythril.get_state_variable_from_storage(address=address, - params=[a.strip() for a in args.storage.strip().split(",")]) + exit_with_error( + args.outform, + "To read storage, provide the address of a deployed contract with the -a option.", + ) + + storage = mythril.get_state_variable_from_storage( + address=address, + params=[a.strip() for a in args.storage.strip().split(",")], + ) print(storage) elif args.disassemble: - easm_text = mythril.contracts[0].get_easm() # or mythril.disassemble(mythril.contracts[0]) + easm_text = mythril.contracts[ + 0 + ].get_easm() # or mythril.disassemble(mythril.contracts[0]) sys.stdout.write(easm_text) elif args.graph or args.fire_lasers: if not mythril.contracts: - exit_with_error(args.outform, "input files do not contain any valid contracts") + exit_with_error( + args.outform, "input files do not contain any valid contracts" + ) if args.graph: - html = mythril.graph_html(strategy=args.strategy, contract=mythril.contracts[0], address=address, - enable_physics=args.enable_physics, phrackify=args.phrack, - max_depth=args.max_depth, execution_timeout=args.execution_timeout, - create_timeout=args.create_timeout) + html = mythril.graph_html( + strategy=args.strategy, + contract=mythril.contracts[0], + address=address, + enable_physics=args.enable_physics, + phrackify=args.phrack, + max_depth=args.max_depth, + execution_timeout=args.execution_timeout, + create_timeout=args.create_timeout, + ) try: with open(args.graph, "w") as f: @@ -215,27 +355,40 @@ def main(): exit_with_error(args.outform, "Error saving graph: " + str(e)) else: - report = mythril.fire_lasers(strategy=args.strategy, address=address, - modules=[m.strip() for m in args.modules.strip().split(",")] if args.modules else [], - verbose_report=args.verbose_report, - max_depth=args.max_depth, execution_timeout=args.execution_timeout, - create_timeout=args.create_timeout, - max_transaction_count=args.max_transaction_count) + report = mythril.fire_lasers( + strategy=args.strategy, + address=address, + modules=[m.strip() for m in args.modules.strip().split(",")] + if args.modules + else [], + verbose_report=args.verbose_report, + max_depth=args.max_depth, + execution_timeout=args.execution_timeout, + create_timeout=args.create_timeout, + max_transaction_count=args.max_transaction_count, + ) outputs = { - 'json': report.as_json(), - 'text': report.as_text(), - 'markdown': report.as_markdown() + "json": report.as_json(), + "text": report.as_text(), + "markdown": report.as_markdown(), } print(outputs[args.outform]) elif args.statespace_json: if not mythril.contracts: - exit_with_error(args.outform, "input files do not contain any valid contracts") - - statespace = mythril.dump_statespace(strategy=args.strategy, contract=mythril.contracts[0], address=address, - max_depth=args.max_depth, execution_timeout=args.execution_timeout, - create_timeout=args.create_timeout) + exit_with_error( + args.outform, "input files do not contain any valid contracts" + ) + + statespace = mythril.dump_statespace( + strategy=args.strategy, + contract=mythril.contracts[0], + address=address, + max_depth=args.max_depth, + execution_timeout=args.execution_timeout, + create_timeout=args.create_timeout, + ) try: with open(args.statespace_json, "w") as f: diff --git a/mythril/laser/ethereum/call.py b/mythril/laser/ethereum/call.py index c5200145..589d44a9 100644 --- a/mythril/laser/ethereum/call.py +++ b/mythril/laser/ethereum/call.py @@ -12,7 +12,9 @@ to get the necessary elements from the stack and determine the parameters for th """ -def get_call_parameters(global_state: GlobalState, dynamic_loader: DynLoader, with_value=False): +def get_call_parameters( + global_state: GlobalState, dynamic_loader: DynLoader, with_value=False +): """ Gets call parameters from global state Pops the values from the stack and determines output parameters @@ -23,21 +25,40 @@ def get_call_parameters(global_state: GlobalState, dynamic_loader: DynLoader, wi """ gas, to = global_state.mstate.pop(2) value = global_state.mstate.pop() if with_value else 0 - memory_input_offset, memory_input_size, memory_out_offset, memory_out_size = global_state.mstate.pop(4) + memory_input_offset, memory_input_size, memory_out_offset, memory_out_size = global_state.mstate.pop( + 4 + ) callee_address = get_callee_address(global_state, dynamic_loader, to) callee_account = None - call_data, call_data_type = get_call_data(global_state, memory_input_offset, memory_input_size, False) + call_data, call_data_type = get_call_data( + global_state, memory_input_offset, memory_input_size, False + ) if int(callee_address, 16) >= 5 or int(callee_address, 16) == 0: - call_data, call_data_type = get_call_data(global_state, memory_input_offset, memory_input_size) - callee_account = get_callee_account(global_state, callee_address, dynamic_loader) - - return callee_address, callee_account, call_data, value, call_data_type, gas, memory_out_offset, memory_out_size - - -def get_callee_address(global_state: GlobalState, dynamic_loader: DynLoader, symbolic_to_address: BitVecRef): + call_data, call_data_type = get_call_data( + global_state, memory_input_offset, memory_input_size + ) + callee_account = get_callee_account( + global_state, callee_address, dynamic_loader + ) + + return ( + callee_address, + callee_account, + call_data, + value, + call_data_type, + gas, + memory_out_offset, + memory_out_size, + ) + + +def get_callee_address( + global_state: GlobalState, dynamic_loader: DynLoader, symbolic_to_address: BitVecRef +): """ Gets the address of the callee :param global_state: state to look in @@ -52,7 +73,7 @@ def get_callee_address(global_state: GlobalState, dynamic_loader: DynLoader, sym except TypeError: logging.debug("Symbolic call encountered") - match = re.search(r'storage_(\d+)', str(simplify(symbolic_to_address))) + match = re.search(r"storage_(\d+)", str(simplify(symbolic_to_address))) logging.debug("CALL to: " + str(simplify(symbolic_to_address))) if match is None or dynamic_loader is None: @@ -63,7 +84,9 @@ def get_callee_address(global_state: GlobalState, dynamic_loader: DynLoader, sym # attempt to read the contract address from instance storage try: - callee_address = dynamic_loader.read_storage(environment.active_account.address, index) + callee_address = dynamic_loader.read_storage( + environment.active_account.address, index + ) # TODO: verify whether this happens or not except: logging.debug("Error accessing contract storage.") @@ -76,7 +99,9 @@ def get_callee_address(global_state: GlobalState, dynamic_loader: DynLoader, sym return callee_address -def get_callee_account(global_state: GlobalState, callee_address: str, dynamic_loader: DynLoader): +def get_callee_account( + global_state: GlobalState, callee_address: str, dynamic_loader: DynLoader +): """ Gets the callees account from the global_state :param global_state: state to look in @@ -108,7 +133,9 @@ def get_callee_account(global_state: GlobalState, callee_address: str, dynamic_l raise ValueError() logging.debug("Dependency loaded: " + callee_address) - callee_account = Account(callee_address, code, callee_address, dynamic_loader=dynamic_loader) + callee_account = Account( + callee_address, code, callee_address, dynamic_loader=dynamic_loader + ) accounts[callee_address] = callee_account return callee_account @@ -118,7 +145,7 @@ def get_call_data( global_state: GlobalState, memory_start: Union[int, BitVecNumRef, BoolRef], memory_size: Union[int, BitVecNumRef, BoolRef], - pad=True + pad=True, ): """ Gets call_data from the global_state @@ -132,7 +159,11 @@ def get_call_data( try: # TODO: This only allows for either fully concrete or fully symbolic calldata. # Improve management of memory and callata to support a mix between both types. - call_data = state.memory[util.get_concrete_int(memory_start):util.get_concrete_int(memory_start + memory_size)] + call_data = state.memory[ + util.get_concrete_int(memory_start) : util.get_concrete_int( + memory_start + memory_size + ) + ] if len(call_data) < 32 and pad: call_data += [0] * (32 - len(call_data)) call_data_type = CalldataType.CONCRETE diff --git a/mythril/laser/ethereum/cfg.py b/mythril/laser/ethereum/cfg.py index 73e5c868..5ccf5287 100644 --- a/mythril/laser/ethereum/cfg.py +++ b/mythril/laser/ethereum/cfg.py @@ -38,21 +38,27 @@ class Node: code = "" for state in self.states: instruction = state.get_current_instruction() - code += str(instruction['address']) + " " + instruction['opcode'] - if instruction['opcode'].startswith("PUSH"): - code += " " + instruction['argument'] + code += str(instruction["address"]) + " " + instruction["opcode"] + if instruction["opcode"].startswith("PUSH"): + code += " " + instruction["argument"] code += "\\n" return dict( contract_name=self.contract_name, start_addr=self.start_addr, function_name=self.function_name, - code=code + code=code, ) class Edge: - def __init__(self, node_from: int, node_to: int, edge_type=JumpType.UNCONDITIONAL, condition=None): + def __init__( + self, + node_from: int, + node_to: int, + edge_type=JumpType.UNCONDITIONAL, + condition=None, + ): self.node_from = node_from self.node_to = node_to self.type = edge_type @@ -63,4 +69,4 @@ class Edge: @property def as_dict(self) -> Dict[str, int]: - return {"from": self.node_from, 'to': self.node_to} + return {"from": self.node_from, "to": self.node_to} diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index aa3ec758..f0494ec1 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -4,19 +4,46 @@ from copy import copy, deepcopy from typing import Callable, List from ethereum import utils -from z3 import Extract, UDiv, simplify, Concat, ULT, UGT, BitVecNumRef, Not, \ - is_false, is_expr, ExprRef, URem, SRem, BitVec, Solver, is_true, BitVecVal, If, BoolRef, Or +from z3 import ( + Extract, + UDiv, + simplify, + Concat, + ULT, + UGT, + BitVecNumRef, + Not, + is_false, + is_expr, + ExprRef, + URem, + SRem, + BitVec, + Solver, + is_true, + BitVecVal, + If, + BoolRef, + Or, +) import mythril.laser.ethereum.natives as natives import mythril.laser.ethereum.util as helper from mythril.laser.ethereum import util from mythril.laser.ethereum.call import get_call_parameters -from mythril.laser.ethereum.evm_exceptions import VmException, StackUnderflowException, InvalidJumpDestination, \ - InvalidInstruction +from mythril.laser.ethereum.evm_exceptions import ( + VmException, + StackUnderflowException, + InvalidJumpDestination, + InvalidInstruction, +) from mythril.laser.ethereum.keccak import KeccakFunctionManager from mythril.laser.ethereum.state import GlobalState, CalldataType -from mythril.laser.ethereum.transaction import MessageCallTransaction, TransactionStartSignal, \ - ContractCreationTransaction +from mythril.laser.ethereum.transaction import ( + MessageCallTransaction, + TransactionStartSignal, + ContractCreationTransaction, +) from mythril.support.loader import DynLoader TT256 = 2 ** 256 @@ -48,13 +75,12 @@ class StateTransition(object): return states def __call__(self, func: Callable) -> Callable: - def wrapper(func_obj: "Instruction", global_state: GlobalState) -> List[GlobalState]: - new_global_states = self.call_on_state_copy( - func, - func_obj, - global_state - ) + def wrapper( + func_obj: "Instruction", global_state: GlobalState + ) -> List[GlobalState]: + new_global_states = self.call_on_state_copy(func, func_obj, global_state) return self.increment_states_pc(new_global_states) + return wrapper @@ -81,7 +107,11 @@ class Instruction: elif self.op_code.startswith("LOG"): op = "log" - instruction_mutator = getattr(self, op + '_', None) if not post else getattr(self, op + '_' + 'post', None) + instruction_mutator = ( + getattr(self, op + "_", None) + if not post + else getattr(self, op + "_" + "post", None) + ) if instruction_mutator is None: raise NotImplementedError @@ -95,12 +125,12 @@ class Instruction: @StateTransition() def push_(self, global_state: GlobalState) -> List[GlobalState]: push_instruction = global_state.get_current_instruction() - push_value = push_instruction['argument'][2:] + push_value = push_instruction["argument"][2:] try: - length_of_value = 2*int(push_instruction['opcode'][4:]) + length_of_value = 2 * int(push_instruction["opcode"][4:]) except ValueError: - raise VmException('Invalid Push instruction') + raise VmException("Invalid Push instruction") push_value += "0" * max(length_of_value - len(push_value), 0) global_state.mstate.stack.append(BitVecVal(int(push_value, 16), 256)) @@ -108,7 +138,7 @@ class Instruction: @StateTransition() def dup_(self, global_state: GlobalState) -> List[GlobalState]: - value = int(global_state.get_current_instruction()['opcode'][3:], 10) + value = int(global_state.get_current_instruction()["opcode"][3:], 10) global_state.mstate.stack.append(global_state.mstate.stack[-value]) return [global_state] @@ -174,12 +204,16 @@ class Instruction: index = util.get_concrete_int(op0) offset = (31 - index) * 8 if offset >= 0: - result = simplify(Concat(BitVecVal(0, 248), Extract(offset + 7, offset, op1))) + result = simplify( + Concat(BitVecVal(0, 248), Extract(offset + 7, offset, op1)) + ) else: result = 0 except TypeError: logging.debug("BYTE: Unsupported symbolic byte offset") - result = global_state.new_bitvec(str(simplify(op1)) + "[" + str(simplify(op0)) + "]", 256) + result = global_state.new_bitvec( + str(simplify(op1)) + "[" + str(simplify(op0)) + "]", 256 + ) mstate.stack.append(result) return [global_state] @@ -188,24 +222,39 @@ class Instruction: @StateTransition() def add_(self, global_state: GlobalState) -> List[GlobalState]: global_state.mstate.stack.append( - (helper.pop_bitvec(global_state.mstate) + helper.pop_bitvec(global_state.mstate))) + ( + helper.pop_bitvec(global_state.mstate) + + helper.pop_bitvec(global_state.mstate) + ) + ) return [global_state] @StateTransition() def sub_(self, global_state: GlobalState) -> List[GlobalState]: global_state.mstate.stack.append( - (helper.pop_bitvec(global_state.mstate) - helper.pop_bitvec(global_state.mstate))) + ( + helper.pop_bitvec(global_state.mstate) + - helper.pop_bitvec(global_state.mstate) + ) + ) return [global_state] @StateTransition() def mul_(self, global_state: GlobalState) -> List[GlobalState]: global_state.mstate.stack.append( - (helper.pop_bitvec(global_state.mstate) * helper.pop_bitvec(global_state.mstate))) + ( + helper.pop_bitvec(global_state.mstate) + * helper.pop_bitvec(global_state.mstate) + ) + ) return [global_state] @StateTransition() def div_(self, global_state: GlobalState) -> List[GlobalState]: - op0, op1 = util.pop_bitvec(global_state.mstate), util.pop_bitvec(global_state.mstate) + op0, op1 = ( + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + ) if op1 == 0: global_state.mstate.stack.append(BitVecVal(0, 256)) else: @@ -214,7 +263,10 @@ class Instruction: @StateTransition() def sdiv_(self, global_state: GlobalState) -> List[GlobalState]: - s0, s1 = util.pop_bitvec(global_state.mstate), util.pop_bitvec(global_state.mstate) + s0, s1 = ( + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + ) if s1 == 0: global_state.mstate.stack.append(BitVecVal(0, 256)) else: @@ -223,27 +275,39 @@ class Instruction: @StateTransition() def mod_(self, global_state: GlobalState) -> List[GlobalState]: - s0, s1 = util.pop_bitvec(global_state.mstate), util.pop_bitvec(global_state.mstate) + s0, s1 = ( + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + ) global_state.mstate.stack.append(0 if s1 == 0 else URem(s0, s1)) return [global_state] @StateTransition() def smod_(self, global_state: GlobalState) -> List[GlobalState]: - s0, s1 = util.pop_bitvec(global_state.mstate), util.pop_bitvec(global_state.mstate) + s0, s1 = ( + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + ) global_state.mstate.stack.append(0 if s1 == 0 else SRem(s0, s1)) return [global_state] @StateTransition() def addmod_(self, global_state: GlobalState) -> List[GlobalState]: - s0, s1, s2 = util.pop_bitvec(global_state.mstate), util.pop_bitvec(global_state.mstate), util.pop_bitvec( - global_state.mstate) + s0, s1, s2 = ( + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + ) global_state.mstate.stack.append(URem(URem(s0, s2) + URem(s1, s2), s2)) return [global_state] @StateTransition() def mulmod_(self, global_state: GlobalState) -> List[GlobalState]: - s0, s1, s2 = util.pop_bitvec(global_state.mstate), util.pop_bitvec(global_state.mstate), util.pop_bitvec( - global_state.mstate) + s0, s1, s2 = ( + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + util.pop_bitvec(global_state.mstate), + ) global_state.mstate.stack.append(URem(URem(s0, s2) * URem(s1, s2), s2)) return [global_state] @@ -253,9 +317,14 @@ class Instruction: base, exponent = util.pop_bitvec(state), util.pop_bitvec(state) if (type(base) != BitVecNumRef) or (type(exponent) != BitVecNumRef): - state.stack.append(global_state.new_bitvec("(" + str(simplify(base)) + ")**(" + str(simplify(exponent)) + ")", 256)) + state.stack.append( + global_state.new_bitvec( + "(" + str(simplify(base)) + ")**(" + str(simplify(exponent)) + ")", + 256, + ) + ) else: - state.stack.append(pow(base.as_long(), exponent.as_long(), 2**256)) + state.stack.append(pow(base.as_long(), exponent.as_long(), 2 ** 256)) return [global_state] @@ -359,31 +428,69 @@ class Instruction: b = environment.calldata[offset] except TypeError: logging.debug("CALLDATALOAD: Unsupported symbolic index") - state.stack.append(global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(simplify(op0)) + "]", 256)) + state.stack.append( + global_state.new_bitvec( + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(simplify(op0)) + + "]", + 256, + ) + ) return [global_state] except IndexError: logging.debug("Calldata not set, using symbolic variable instead") - state.stack.append(global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(simplify(op0)) + "]", 256)) + state.stack.append( + global_state.new_bitvec( + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(simplify(op0)) + + "]", + 256, + ) + ) return [global_state] if type(b) == int: try: - val = b''.join([calldata.to_bytes(1, byteorder='big') for calldata in - environment.calldata[offset:offset+32]]) + val = b"".join( + [ + calldata.to_bytes(1, byteorder="big") + for calldata in environment.calldata[offset : offset + 32] + ] + ) - logging.debug("Final value: " + str(int.from_bytes(val, byteorder='big'))) - state.stack.append(BitVecVal(int.from_bytes(val, byteorder='big'), 256)) + logging.debug( + "Final value: " + str(int.from_bytes(val, byteorder="big")) + ) + state.stack.append(BitVecVal(int.from_bytes(val, byteorder="big"), 256)) except (TypeError, AttributeError): - state.stack.append(global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(simplify(op0)) + "]", 256)) + state.stack.append( + global_state.new_bitvec( + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(simplify(op0)) + + "]", + 256, + ) + ) else: # symbolic variable - state.stack.append(global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(simplify(op0)) + "]", 256)) + state.stack.append( + global_state.new_bitvec( + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(simplify(op0)) + + "]", + 256, + ) + ) return [global_state] @@ -392,7 +499,11 @@ class Instruction: state = global_state.mstate environment = global_state.environment if environment.calldata_type == CalldataType.SYMBOLIC: - state.stack.append(global_state.new_bitvec("calldatasize_" + environment.active_account.contract_name, 256)) + state.stack.append( + global_state.new_bitvec( + "calldatasize_" + environment.active_account.contract_name, 256 + ) + ) else: state.stack.append(BitVecVal(len(environment.calldata), 256)) return [global_state] @@ -428,19 +539,38 @@ class Instruction: if dstart_sym or size_sym: state.mem_extend(mstart, 1) state.memory[mstart] = global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(dstart) + ": + " + str( - size) + "]", 256) + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(dstart) + + ": + " + + str(size) + + "]", + 256, + ) return [global_state] if size > 0: try: state.mem_extend(mstart, size) except TypeError: - logging.debug("Memory allocation error: mstart = " + str(mstart) + ", size = " + str(size)) + logging.debug( + "Memory allocation error: mstart = " + + str(mstart) + + ", size = " + + str(size) + ) state.mem_extend(mstart, 1) state.memory[mstart] = global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(dstart) + ": + " + str( - size) + "]", 256) + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(dstart) + + ": + " + + str(size) + + "]", + 256, + ) return [global_state] try: @@ -453,8 +583,15 @@ class Instruction: logging.debug("Exception copying calldata to memory") state.memory[mstart] = global_state.new_bitvec( - "calldata_" + str(environment.active_account.contract_name) + "[" + str(dstart) + ": + " + str( - size) + "]", 256) + "calldata_" + + str(environment.active_account.contract_name) + + "[" + + str(dstart) + + ": + " + + str(size) + + "]", + 256, + ) return [global_state] # Environment @@ -513,8 +650,12 @@ class Instruction: try: state.mem_extend(index, length) - data = b''.join([util.get_concrete_int(i).to_bytes(1, byteorder='big') - for i in state.memory[index: index + length]]) + data = b"".join( + [ + util.get_concrete_int(i).to_bytes(1, byteorder="big") + for i in state.memory[index : index + length] + ] + ) except TypeError: argument = str(state.memory[index]).replace(" ", "_") @@ -537,7 +678,11 @@ class Instruction: @StateTransition() def codecopy_(self, global_state: GlobalState) -> List[GlobalState]: - memory_offset, code_offset, size = global_state.mstate.stack.pop(), global_state.mstate.stack.pop(), global_state.mstate.stack.pop() + memory_offset, code_offset, size = ( + global_state.mstate.stack.pop(), + global_state.mstate.stack.pop(), + global_state.mstate.stack.pop(), + ) try: concrete_memory_offset = helper.get_concrete_int(memory_offset) @@ -551,8 +696,14 @@ class Instruction: except TypeError: # except both attribute error and Exception global_state.mstate.mem_extend(concrete_memory_offset, 1) - global_state.mstate.memory[concrete_memory_offset] = \ - global_state.new_bitvec("code({})".format(global_state.environment.active_account.contract_name), 256) + global_state.mstate.memory[ + concrete_memory_offset + ] = global_state.new_bitvec( + "code({})".format( + global_state.environment.active_account.contract_name + ), + 256, + ) return [global_state] try: @@ -561,26 +712,52 @@ class Instruction: logging.debug("Unsupported symbolic code offset in CODECOPY") global_state.mstate.mem_extend(concrete_memory_offset, concrete_size) for i in range(concrete_size): - global_state.mstate.memory[concrete_memory_offset + i] = \ - global_state.new_bitvec("code({})".format(global_state.environment.active_account.contract_name), 256) + global_state.mstate.memory[ + concrete_memory_offset + i + ] = global_state.new_bitvec( + "code({})".format( + global_state.environment.active_account.contract_name + ), + 256, + ) return [global_state] bytecode = global_state.environment.code.bytecode - if concrete_size == 0 and isinstance(global_state.current_transaction, ContractCreationTransaction): + if concrete_size == 0 and isinstance( + global_state.current_transaction, ContractCreationTransaction + ): if concrete_code_offset >= len(global_state.environment.code.bytecode) // 2: global_state.mstate.mem_extend(concrete_memory_offset, 1) - global_state.mstate.memory[concrete_memory_offset] = \ - global_state.new_bitvec("code({})".format(global_state.environment.active_account.contract_name), 256) + global_state.mstate.memory[ + concrete_memory_offset + ] = global_state.new_bitvec( + "code({})".format( + global_state.environment.active_account.contract_name + ), + 256, + ) return [global_state] for i in range(concrete_size): if 2 * (concrete_code_offset + i + 1) <= len(bytecode): - global_state.mstate.memory[concrete_memory_offset + i] =\ - int(bytecode[2*(concrete_code_offset + i): 2*(concrete_code_offset + i + 1)], 16) + global_state.mstate.memory[concrete_memory_offset + i] = int( + bytecode[ + 2 + * (concrete_code_offset + i) : 2 + * (concrete_code_offset + i + 1) + ], + 16, + ) else: - global_state.mstate.memory[concrete_memory_offset + i] = \ - global_state.new_bitvec("code({})".format(global_state.environment.active_account.contract_name), 256) + global_state.mstate.memory[ + concrete_memory_offset + i + ] = global_state.new_bitvec( + "code({})".format( + global_state.environment.active_account.contract_name + ), + 256, + ) return [global_state] @@ -627,7 +804,9 @@ class Instruction: def blockhash_(self, global_state: GlobalState) -> List[GlobalState]: state = global_state.mstate blocknumber = state.stack.pop() - state.stack.append(global_state.new_bitvec("blockhash_block_" + str(blocknumber), 256)) + state.stack.append( + global_state.new_bitvec("blockhash_block_" + str(blocknumber), 256) + ) return [global_state] @StateTransition() @@ -647,7 +826,9 @@ class Instruction: @StateTransition() def difficulty_(self, global_state: GlobalState) -> List[GlobalState]: - global_state.mstate.stack.append(global_state.new_bitvec("block_difficulty", 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("block_difficulty", 256) + ) return [global_state] @StateTransition() @@ -697,7 +878,9 @@ class Instruction: try: state.mem_extend(mstart, 32) except Exception: - logging.debug("Error extending memory, mstart = " + str(mstart) + ", size = 32") + logging.debug( + "Error extending memory, mstart = " + str(mstart) + ", size = 32" + ) logging.debug("MSTORE to mem[" + str(mstart) + "]: " + str(value)) @@ -706,7 +889,7 @@ class Instruction: _bytes = util.concrete_int_to_bytes(value) - state.memory[mstart:mstart+len(_bytes)] = _bytes + state.memory[mstart : mstart + len(_bytes)] = _bytes except (AttributeError, TypeError): try: @@ -735,7 +918,7 @@ class Instruction: @StateTransition() def sload_(self, global_state: GlobalState) -> List[GlobalState]: global keccak_function_manager - + state = global_state.mstate index = state.stack.pop() logging.debug("Storage access at index " + str(index)) @@ -761,15 +944,19 @@ class Instruction: for (keccak_key, constraint) in constraints: if constraint in state.constraints: - results += self._sload_helper(global_state, keccak_key, [constraint]) + results += self._sload_helper( + global_state, keccak_key, [constraint] + ) if len(results) > 0: return results for (keccak_key, constraint) in constraints: - results += self._sload_helper(copy(global_state), keccak_key, [constraint]) + results += self._sload_helper( + copy(global_state), keccak_key, [constraint] + ) if len(results) > 0: return results - + return self._sload_helper(global_state, str(index)) @staticmethod @@ -825,14 +1012,26 @@ class Instruction: index_argument = keccak_function_manager.get_argument(index) if is_true(key_argument == index_argument): - return self._sstore_helper(copy(global_state), keccak_key, value, key_argument == index_argument) - - results += self._sstore_helper(copy(global_state), keccak_key, value, key_argument == index_argument) + return self._sstore_helper( + copy(global_state), + keccak_key, + value, + key_argument == index_argument, + ) + + results += self._sstore_helper( + copy(global_state), + keccak_key, + value, + key_argument == index_argument, + ) new = Or(new, key_argument != index_argument) if len(results) > 0: - results += self._sstore_helper(copy(global_state), str(index), value, new) + results += self._sstore_helper( + copy(global_state), str(index), value, new + ) return results return self._sstore_helper(global_state, str(index), value) @@ -840,12 +1039,16 @@ class Instruction: @staticmethod def _sstore_helper(global_state, index, value, constraint=None): try: - global_state.environment.active_account = deepcopy(global_state.environment.active_account) + global_state.environment.active_account = deepcopy( + global_state.environment.active_account + ) global_state.accounts[ - global_state.environment.active_account.address] = global_state.environment.active_account + global_state.environment.active_account.address + ] = global_state.environment.active_account - global_state.environment.active_account.storage[index] =\ + global_state.environment.active_account.storage[index] = ( value if not isinstance(value, ExprRef) else simplify(value) + ) except KeyError: logging.debug("Error writing to storage: Invalid index") @@ -869,10 +1072,12 @@ class Instruction: if index is None: raise InvalidJumpDestination("JUMP to invalid address") - op_code = disassembly.instruction_list[index]['opcode'] + op_code = disassembly.instruction_list[index]["opcode"] if op_code != "JUMPDEST": - raise InvalidJumpDestination("Skipping JUMP to invalid destination (not JUMPDEST): " + str(jump_addr)) + raise InvalidJumpDestination( + "Skipping JUMP to invalid destination (not JUMPDEST): " + str(jump_addr) + ) new_state = copy(global_state) new_state.mstate.pc = index @@ -896,9 +1101,13 @@ class Instruction: return [global_state] # False case - negated = simplify(Not(condition)) if type(condition) == BoolRef else condition == 0 + negated = ( + simplify(Not(condition)) if type(condition) == BoolRef else condition == 0 + ) - if (type(negated) == bool and negated) or (type(negated) == BoolRef and not is_false(negated)): + if (type(negated) == bool and negated) or ( + type(negated) == BoolRef and not is_false(negated) + ): new_state = copy(global_state) new_state.mstate.depth += 1 new_state.mstate.pc += 1 @@ -918,8 +1127,10 @@ class Instruction: instr = disassembly.instruction_list[index] condi = simplify(condition) if type(condition) == BoolRef else condition != 0 - if instr['opcode'] == "JUMPDEST": - if (type(condi) == bool and condi) or (type(condi) == BoolRef and not is_false(condi)): + if instr["opcode"] == "JUMPDEST": + if (type(condi) == bool and condi) or ( + type(condi) == BoolRef and not is_false(condi) + ): new_state = copy(global_state) new_state.mstate.pc = index new_state.mstate.depth += 1 @@ -969,7 +1180,9 @@ class Instruction: offset, length = state.stack.pop(), state.stack.pop() return_data = [global_state.new_bitvec("return_data", 256)] try: - return_data = state.memory[util.get_concrete_int(offset):util.get_concrete_int(offset + length)] + return_data = state.memory[ + util.get_concrete_int(offset) : util.get_concrete_int(offset + length) + ] except TypeError: logging.debug("Return with symbolic length or offset. Not supported") global_state.current_transaction.end(global_state, return_data) @@ -981,14 +1194,16 @@ class Instruction: # Often the target of the suicide instruction will be symbolic # If it isn't then well transfer the balance to the indicated contract if isinstance(target, BitVecNumRef): - target = '0x' + hex(target.as_long())[-40:] + target = "0x" + hex(target.as_long())[-40:] if isinstance(target, str): try: - global_state.world_state[target].balance += global_state.environment.active_account.balance + global_state.world_state[ + target + ].balance += global_state.environment.active_account.balance except KeyError: global_state.world_state.create_account( address=target, - balance=global_state.environment.active_account.balance + balance=global_state.environment.active_account.balance, ) global_state.environment.active_account.balance = 0 @@ -1002,10 +1217,14 @@ class Instruction: offset, length = state.stack.pop(), state.stack.pop() return_data = [global_state.new_bitvec("return_data", 256)] try: - return_data = state.memory[util.get_concrete_int(offset):util.get_concrete_int(offset + length)] + return_data = state.memory[ + util.get_concrete_int(offset) : util.get_concrete_int(offset + length) + ] except TypeError: logging.debug("Return with symbolic length or offset. Not supported") - global_state.current_transaction.end(global_state, return_data=return_data, revert=True) + global_state.current_transaction.end( + global_state, return_data=return_data, revert=True + ) @StateTransition() def assert_fail_(self, global_state: GlobalState): @@ -1028,15 +1247,22 @@ class Instruction: try: callee_address, callee_account, call_data, value, call_data_type, gas, memory_out_offset, memory_out_size = get_call_parameters( - global_state, self.dynamic_loader, True) + global_state, self.dynamic_loader, True + ) except ValueError as e: logging.debug( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format(e) + "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( + e + ) ) # TODO: decide what to do in this case - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) return [global_state] - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) if 0 < int(callee_address, 16) < 5: logging.info("Native contract called: " + callee_address) @@ -1056,27 +1282,38 @@ class Instruction: try: data = natives.native_contracts(call_address_int, call_data) except natives.NativeContractException: - contract_list = ['ecerecover', 'sha256', 'ripemd160', 'identity'] + contract_list = ["ecerecover", "sha256", "ripemd160", "identity"] for i in range(mem_out_sz): - global_state.mstate.memory[mem_out_start + i] = global_state.new_bitvec(contract_list[call_address_int - 1] + - "(" + str(call_data) + ")", 256) + global_state.mstate.memory[ + mem_out_start + i + ] = global_state.new_bitvec( + contract_list[call_address_int - 1] + + "(" + + str(call_data) + + ")", + 256, + ) return [global_state] - for i in range(min(len(data), mem_out_sz)): # If more data is used then it's chopped off + for i in range( + min(len(data), mem_out_sz) + ): # If more data is used then it's chopped off global_state.mstate.memory[mem_out_start + i] = data[i] # TODO: maybe use BitVec here constrained to 1 return [global_state] - transaction = MessageCallTransaction(global_state.world_state, - callee_account, - BitVecVal(int(environment.active_account.address, 16), 256), - call_data=call_data, - gas_price=environment.gasprice, - call_value=value, - origin=environment.origin, - call_data_type=call_data_type) + transaction = MessageCallTransaction( + global_state.world_state, + callee_account, + BitVecVal(int(environment.active_account.address, 16), 256), + call_data=call_data, + gas_price=environment.gasprice, + call_value=value, + origin=environment.origin, + call_data_type=call_data_type, + ) raise TransactionStartSignal(transaction, self.op_code) @StateTransition() @@ -1085,36 +1322,57 @@ class Instruction: try: _, _, _, _, _, _, memory_out_offset, memory_out_size = get_call_parameters( - global_state, self.dynamic_loader, True) + global_state, self.dynamic_loader, True + ) except ValueError as e: logging.info( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format(e) + "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( + e + ) + ) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) ) - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) return [global_state] if global_state.last_return_data is None: # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr['address']), 256) + return_value = global_state.new_bitvec( + "retval_" + str(instr["address"]), 256 + ) global_state.mstate.stack.append(return_value) global_state.mstate.constraints.append(return_value == 0) return [global_state] try: - memory_out_offset = util.get_concrete_int(memory_out_offset) if isinstance(memory_out_offset, ExprRef) else memory_out_offset - memory_out_size = util.get_concrete_int(memory_out_size) if isinstance(memory_out_size, ExprRef) else memory_out_size + memory_out_offset = ( + util.get_concrete_int(memory_out_offset) + if isinstance(memory_out_offset, ExprRef) + else memory_out_offset + ) + memory_out_size = ( + util.get_concrete_int(memory_out_size) + if isinstance(memory_out_size, ExprRef) + else memory_out_size + ) except TypeError: - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) return [global_state] # Copy memory - global_state.mstate.mem_extend(memory_out_offset, min(memory_out_size, len(global_state.last_return_data))) + global_state.mstate.mem_extend( + memory_out_offset, min(memory_out_size, len(global_state.last_return_data)) + ) for i in range(min(memory_out_size, len(global_state.last_return_data))): - global_state.mstate.memory[i + memory_out_offset] = global_state.last_return_data[i] + global_state.mstate.memory[ + i + memory_out_offset + ] = global_state.last_return_data[i] # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr['address']), 256) + return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) global_state.mstate.stack.append(return_value) global_state.mstate.constraints.append(return_value == 1) @@ -1127,24 +1385,30 @@ class Instruction: try: callee_address, callee_account, call_data, value, call_data_type, gas, _, _ = get_call_parameters( - global_state, self.dynamic_loader, True) + global_state, self.dynamic_loader, True + ) except ValueError as e: logging.info( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format(e) + "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( + e + ) + ) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) ) - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) return [global_state] - transaction = MessageCallTransaction(global_state.world_state, - environment.active_account, - environment.address, - call_data=call_data, - gas_price=environment.gasprice, - call_value=value, - origin=environment.origin, - call_data_type=call_data_type, - code=callee_account.code - ) + transaction = MessageCallTransaction( + global_state.world_state, + environment.active_account, + environment.address, + call_data=call_data, + gas_price=environment.gasprice, + call_value=value, + origin=environment.origin, + call_data_type=call_data_type, + code=callee_account.code, + ) raise TransactionStartSignal(transaction, self.op_code) @StateTransition() @@ -1153,109 +1417,152 @@ class Instruction: try: _, _, _, _, _, _, memory_out_offset, memory_out_size = get_call_parameters( - global_state, self.dynamic_loader, True) + global_state, self.dynamic_loader, True + ) except ValueError as e: logging.info( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format(e) + "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( + e + ) + ) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) ) - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) return [global_state] if global_state.last_return_data is None: # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr['address']), 256) + return_value = global_state.new_bitvec( + "retval_" + str(instr["address"]), 256 + ) global_state.mstate.stack.append(return_value) global_state.mstate.constraints.append(return_value == 0) return [global_state] try: - memory_out_offset = util.get_concrete_int(memory_out_offset) if isinstance(memory_out_offset, ExprRef) else memory_out_offset - memory_out_size = util.get_concrete_int(memory_out_size) if isinstance(memory_out_size, ExprRef) else memory_out_size + memory_out_offset = ( + util.get_concrete_int(memory_out_offset) + if isinstance(memory_out_offset, ExprRef) + else memory_out_offset + ) + memory_out_size = ( + util.get_concrete_int(memory_out_size) + if isinstance(memory_out_size, ExprRef) + else memory_out_size + ) except TypeError: - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) return [global_state] # Copy memory - global_state.mstate.mem_extend(memory_out_offset, min(memory_out_size, len(global_state.last_return_data))) + global_state.mstate.mem_extend( + memory_out_offset, min(memory_out_size, len(global_state.last_return_data)) + ) for i in range(min(memory_out_size, len(global_state.last_return_data))): - global_state.mstate.memory[i + memory_out_offset] = global_state.last_return_data[i] + global_state.mstate.memory[ + i + memory_out_offset + ] = global_state.last_return_data[i] # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr['address']), 256) + return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) global_state.mstate.stack.append(return_value) global_state.mstate.constraints.append(return_value == 1) return [global_state] - @StateTransition() def delegatecall_(self, global_state: GlobalState) -> List[GlobalState]: instr = global_state.get_current_instruction() environment = global_state.environment try: - callee_address, callee_account, call_data, _, call_data_type, gas, _, _ = get_call_parameters(global_state, - self.dynamic_loader) + callee_address, callee_account, call_data, _, call_data_type, gas, _, _ = get_call_parameters( + global_state, self.dynamic_loader + ) except ValueError as e: logging.info( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format(e) + "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( + e + ) + ) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) ) - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) return [global_state] - transaction = MessageCallTransaction(global_state.world_state, - environment.active_account, - environment.sender, - call_data, - gas_price=environment.gasprice, - call_value=environment.callvalue, - origin=environment.origin, - call_data_type=call_data_type, - code=callee_account.code - ) + transaction = MessageCallTransaction( + global_state.world_state, + environment.active_account, + environment.sender, + call_data, + gas_price=environment.gasprice, + call_value=environment.callvalue, + origin=environment.origin, + call_data_type=call_data_type, + code=callee_account.code, + ) raise TransactionStartSignal(transaction, self.op_code) - @StateTransition() def delegatecall_post(self, global_state: GlobalState) -> List[GlobalState]: instr = global_state.get_current_instruction() try: - _, _, _, _, _, _, memory_out_offset, memory_out_size =\ - get_call_parameters(global_state, self.dynamic_loader) + _, _, _, _, _, _, memory_out_offset, memory_out_size = get_call_parameters( + global_state, self.dynamic_loader + ) except ValueError as e: logging.info( - "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format(e) + "Could not determine required parameters for call, putting fresh symbol on the stack. \n{}".format( + e + ) + ) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) ) - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) return [global_state] if global_state.last_return_data is None: # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr['address']), 256) + return_value = global_state.new_bitvec( + "retval_" + str(instr["address"]), 256 + ) global_state.mstate.stack.append(return_value) global_state.mstate.constraints.append(return_value == 0) return [global_state] try: - memory_out_offset = util.get_concrete_int(memory_out_offset) if isinstance(memory_out_offset, - ExprRef) else memory_out_offset - memory_out_size = util.get_concrete_int(memory_out_size) if isinstance(memory_out_size, - ExprRef) else memory_out_size + memory_out_offset = ( + util.get_concrete_int(memory_out_offset) + if isinstance(memory_out_offset, ExprRef) + else memory_out_offset + ) + memory_out_size = ( + util.get_concrete_int(memory_out_size) + if isinstance(memory_out_size, ExprRef) + else memory_out_size + ) except TypeError: - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) return [global_state] # Copy memory - global_state.mstate.mem_extend(memory_out_offset, - min(memory_out_size, len(global_state.last_return_data))) + global_state.mstate.mem_extend( + memory_out_offset, min(memory_out_size, len(global_state.last_return_data)) + ) for i in range(min(memory_out_size, len(global_state.last_return_data))): - global_state.mstate.memory[i + memory_out_offset] = global_state.last_return_data[i] + global_state.mstate.memory[ + i + memory_out_offset + ] = global_state.last_return_data[i] # Put return value on stack - return_value = global_state.new_bitvec("retval_" + str(instr['address']), 256) + return_value = global_state.new_bitvec("retval_" + str(instr["address"]), 256) global_state.mstate.stack.append(return_value) global_state.mstate.constraints.append(return_value == 1) @@ -1265,6 +1572,7 @@ class Instruction: def staticcall_(self, global_state: GlobalState) -> List[GlobalState]: # TODO: implement me instr = global_state.get_current_instruction() - global_state.mstate.stack.append(global_state.new_bitvec("retval_" + str(instr['address']), 256)) + global_state.mstate.stack.append( + global_state.new_bitvec("retval_" + str(instr["address"]), 256) + ) return [global_state] - diff --git a/mythril/laser/ethereum/natives.py b/mythril/laser/ethereum/natives.py index ae697085..1e379fbe 100644 --- a/mythril/laser/ethereum/natives.py +++ b/mythril/laser/ethereum/natives.py @@ -16,10 +16,12 @@ class NativeContractException(Exception): pass -def int_to_32bytes(i: int) -> bytes: # used because int can't fit as bytes function's input +def int_to_32bytes( + i: int +) -> bytes: # used because int can't fit as bytes function's input o = [0] * 32 for x in range(32): - o[31 - x] = i & 0xff + o[31 - x] = i & 0xFF i >>= 8 return bytes(o) @@ -27,7 +29,7 @@ def int_to_32bytes(i: int) -> bytes: # used because int can't fit as bytes fun def extract32(data: bytearray, i: int) -> int: if i >= len(data): return 0 - o = data[i: min(i + 32, len(data))] + o = data[i : min(i + 32, len(data))] o.extend(bytearray(32 - len(o))) return bytearray_to_int(o) @@ -42,13 +44,13 @@ def ecrecover(data: str) -> List: except TypeError: raise NativeContractException - message = b''.join([ALL_BYTES[x] for x in data[0:32]]) + message = b"".join([ALL_BYTES[x] for x in data[0:32]]) if r >= secp256k1n or s >= secp256k1n or v < 27 or v > 28: return [] try: pub = ecrecover_to_pub(message, v, r, s) except Exception as e: - logging.info("An error has occured while extracting public key: "+e) + logging.info("An error has occured while extracting public key: " + e) return [] o = [0] * 12 + [x for x in sha3(pub)[-20:]] return o @@ -67,7 +69,7 @@ def ripemd160(data: Union[bytes, str]) -> bytes: data = bytes(data) except TypeError: raise NativeContractException - digest = hashlib.new('ripemd160', data).digest() + digest = hashlib.new("ripemd160", data).digest() padded = 12 * [0] + list(digest) return bytes(padded) @@ -85,4 +87,4 @@ def native_contracts(address: int, data: List): takes integer address 1, 2, 3, 4 """ functions = (ecrecover, sha256, ripemd160, identity) - return functions[address-1](data) + return functions[address - 1](data) diff --git a/mythril/laser/ethereum/state.py b/mythril/laser/ethereum/state.py index 6344218f..6046b82d 100644 --- a/mythril/laser/ethereum/state.py +++ b/mythril/laser/ethereum/state.py @@ -5,7 +5,10 @@ from copy import copy, deepcopy from enum import Enum from random import randint from typing import KeysView, Dict, List, Union, Any -from mythril.laser.ethereum.evm_exceptions import StackOverflowException, StackUnderflowException +from mythril.laser.ethereum.evm_exceptions import ( + StackOverflowException, + StackUnderflowException, +) class CalldataType(Enum): @@ -17,6 +20,7 @@ class Storage: """ Storage class represents the storage of an Account """ + def __init__(self, concrete=False, address=None, dynamic_loader=None): """ Constructor for Storage @@ -33,7 +37,12 @@ class Storage: except KeyError: if self.address and int(self.address[2:], 16) != 0 and self.dynld: try: - self._storage[item] = int(self.dynld.read_storage(contract_address=self.address, index=int(item)), 16) + self._storage[item] = int( + self.dynld.read_storage( + contract_address=self.address, index=int(item) + ), + 16, + ) return self._storage[item] except ValueError: pass @@ -53,8 +62,16 @@ class Account: """ Account class representing ethereum accounts """ - def __init__(self, address: str, code=None, contract_name="unknown", balance=None, concrete_storage=False, - dynamic_loader=None): + + def __init__( + self, + address: str, + code=None, + contract_name="unknown", + balance=None, + concrete_storage=False, + dynamic_loader=None, + ): """ Constructor for account :param address: Address of the account @@ -66,7 +83,9 @@ class Account: self.nonce = 0 self.code = code or Disassembly("") self.balance = balance if balance else BitVec("balance", 256) - self.storage = Storage(concrete_storage, address=address, dynamic_loader=dynamic_loader) + self.storage = Storage( + concrete_storage, address=address, dynamic_loader=dynamic_loader + ) # Metadata self.address = address @@ -85,13 +104,19 @@ class Account: @property def as_dict(self) -> Dict: - return {'nonce': self.nonce, 'code': self.code, 'balance': self.balance, 'storage': self.storage} + return { + "nonce": self.nonce, + "code": self.code, + "balance": self.balance, + "storage": self.storage, + } class Environment: """ The environment class represents the current execution environment for the symbolic executor """ + def __init__( self, active_account: Account, @@ -125,15 +150,22 @@ class Environment: @property def as_dict(self) -> Dict: - return dict(active_account=self.active_account, sender=self.sender, calldata=self.calldata, - gasprice=self.gasprice, callvalue=self.callvalue, origin=self.origin, - calldata_type=self.calldata_type) + return dict( + active_account=self.active_account, + sender=self.sender, + calldata=self.calldata, + gasprice=self.gasprice, + callvalue=self.callvalue, + origin=self.origin, + calldata_type=self.calldata_type, + ) class MachineStack(list): """ Defines EVM stack, overrides the default list to handle overflows """ + STACK_LIMIT = 1024 def __init__(self, default_list=None): @@ -147,8 +179,10 @@ class MachineStack(list): :function: appends the element to list if the size is less than STACK_LIMIT, else throws an error """ if super(MachineStack, self).__len__() >= self.STACK_LIMIT: - raise StackOverflowException("Reached the EVM stack limit of {}, you can't append more " - "elements".format(self.STACK_LIMIT)) + raise StackOverflowException( + "Reached the EVM stack limit of {}, you can't append more " + "elements".format(self.STACK_LIMIT) + ) super(MachineStack, self).append(element) def pop(self, index=-1) -> BitVecNumRef: @@ -167,25 +201,28 @@ class MachineStack(list): try: return super(MachineStack, self).__getitem__(item) except IndexError: - raise StackUnderflowException("Trying to access a stack element which doesn't exist") + raise StackUnderflowException( + "Trying to access a stack element which doesn't exist" + ) def __add__(self, other): """ Implement list concatenation if needed """ - raise NotImplementedError('Implement this if needed') + raise NotImplementedError("Implement this if needed") def __iadd__(self, other): """ Implement list concatenation if needed """ - raise NotImplementedError('Implement this if needed') + raise NotImplementedError("Implement this if needed") class MachineState: """ MachineState represents current machine state also referenced to as \mu """ + def __init__(self, gas: int): """ Constructor for machineState """ self.pc = 0 @@ -203,13 +240,13 @@ class MachineState: """ if self.memory_size > start + size: return - m_extend = (start + size - self.memory_size) + m_extend = start + size - self.memory_size self.memory.extend(bytearray(m_extend)) def memory_write(self, offset: int, data: List[int]) -> None: """ Writes data to memory starting at offset """ self.mem_extend(offset, len(data)) - self.memory[offset:offset+len(data)] = data + self.memory[offset : offset + len(data)] = data def pop(self, amount=1) -> Union[BitVecRef, List[BitVecRef]]: """ Pops amount elements from the stack""" @@ -229,13 +266,20 @@ class MachineState: @property def as_dict(self) -> Dict: - return dict(pc=self.pc, stack=self.stack, memory=self.memory, memsize=self.memory_size, gas=self.gas) + return dict( + pc=self.pc, + stack=self.stack, + memory=self.memory, + memsize=self.memory_size, + gas=self.gas, + ) class GlobalState: """ GlobalState represents the current globalstate """ + def __init__( self, world_state: "WorldState", @@ -243,7 +287,7 @@ class GlobalState: node: Node, machine_state=None, transaction_stack=None, - last_return_data=None + last_return_data=None, ): """ Constructor for GlobalState""" self.node = node @@ -259,8 +303,14 @@ class GlobalState: environment = copy(self.environment) mstate = deepcopy(self.mstate) transaction_stack = copy(self.transaction_stack) - return GlobalState(world_state, environment, self.node, mstate, transaction_stack=transaction_stack, - last_return_data=self.last_return_data) + return GlobalState( + world_state, + environment, + self.node, + mstate, + transaction_stack=transaction_stack, + last_return_data=self.last_return_data, + ) @property def accounts(self) -> Dict: @@ -274,7 +324,9 @@ class GlobalState: return instructions[self.mstate.pc] @property - def current_transaction(self) -> Union["MessageCallTransaction", "ContractCreationTransaction", None]: + def current_transaction( + self + ) -> Union["MessageCallTransaction", "ContractCreationTransaction", None]: # TODO: Remove circular to transaction package to import Transaction classes try: return self.transaction_stack[-1][0] @@ -295,6 +347,7 @@ class WorldState: """ The WorldState class represents the world state as described in the yellow paper """ + def __init__(self, transaction_sequence=None): """ Constructor for the world state. Initializes the accounts record @@ -317,7 +370,9 @@ class WorldState: new_world_state.node = self.node return new_world_state - def create_account(self, balance=0, address=None, concrete_storage=False, dynamic_loader=None) -> Account: + def create_account( + self, balance=0, address=None, concrete_storage=False, dynamic_loader=None + ) -> Account: """ Create non-contract account :param address: The account's address @@ -327,7 +382,12 @@ class WorldState: :return: The new account """ address = address if address else self._generate_new_address() - new_account = Account(address, balance=balance, dynamic_loader=dynamic_loader, concrete_storage=concrete_storage) + new_account = Account( + address, + balance=balance, + dynamic_loader=dynamic_loader, + concrete_storage=concrete_storage, + ) self._put_account(new_account) return new_account @@ -340,14 +400,16 @@ class WorldState: :return: The new account """ # TODO: Add type hints - new_account = Account(self._generate_new_address(), code=contract_code, balance=0) + new_account = Account( + self._generate_new_address(), code=contract_code, balance=0 + ) new_account.storage = storage self._put_account(new_account) def _generate_new_address(self) -> str: """ Generates a new address for the global state""" while True: - address = '0x' + ''.join([str(hex(randint(0, 16)))[-1] for _ in range(20)]) + address = "0x" + "".join([str(hex(randint(0, 16)))[-1] for _ in range(20)]) if address not in self.accounts.keys(): return address diff --git a/mythril/laser/ethereum/strategy/basic.py b/mythril/laser/ethereum/strategy/basic.py index 18d16374..b67f602f 100644 --- a/mythril/laser/ethereum/strategy/basic.py +++ b/mythril/laser/ethereum/strategy/basic.py @@ -10,6 +10,7 @@ class DepthFirstSearchStrategy: Implements a depth first search strategy I.E. Follow one path to a leaf, and then continue to the next one """ + def __init__(self, work_list: List[GlobalState], max_depth: float): self.work_list = work_list self.max_depth = max_depth @@ -35,6 +36,7 @@ class BreadthFirstSearchStrategy: Implements a breadth first search strategy I.E. Execute all states of a "level" before continuing """ + def __init__(self, work_list: List[GlobalState], max_depth: float): self.work_list = work_list self.max_depth = max_depth @@ -53,4 +55,3 @@ class BreadthFirstSearchStrategy: return global_state except IndexError: raise StopIteration() - diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 5739484d..818b8d3f 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -1,8 +1,11 @@ import logging from typing import List, Tuple, Union, Callable from mythril.laser.ethereum.state import WorldState, GlobalState -from mythril.laser.ethereum.transaction import TransactionStartSignal, TransactionEndSignal, \ - ContractCreationTransaction +from mythril.laser.ethereum.transaction import ( + TransactionStartSignal, + TransactionEndSignal, + ContractCreationTransaction, +) from mythril.laser.ethereum.evm_exceptions import StackUnderflowException from mythril.laser.ethereum.instructions import Instruction from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType @@ -10,7 +13,10 @@ from mythril.laser.ethereum.state import Account from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy from datetime import datetime, timedelta from copy import copy -from mythril.laser.ethereum.transaction import execute_contract_creation, execute_message_call +from mythril.laser.ethereum.transaction import ( + execute_contract_creation, + execute_message_call, +) from functools import reduce from mythril.laser.ethereum.evm_exceptions import VmException @@ -19,9 +25,9 @@ class SVMError(Exception): pass -''' +""" Main symbolic execution engine. -''' +""" class LaserEVM: @@ -33,11 +39,11 @@ class LaserEVM: self, accounts: List[Account], dynamic_loader=None, - max_depth=float('inf'), + max_depth=float("inf"), execution_timeout=60, create_timeout=10, strategy=DepthFirstSearchStrategy, - max_transaction_count=3 + max_transaction_count=3, ): world_state = WorldState() world_state.accounts = accounts @@ -65,13 +71,21 @@ class LaserEVM: self.pre_hooks = {} self.post_hooks = {} - logging.info("LASER EVM initialized with dynamic loader: " + str(dynamic_loader)) + logging.info( + "LASER EVM initialized with dynamic loader: " + str(dynamic_loader) + ) @property def accounts(self) -> List[Account]: return self.world_state.accounts - def sym_exec(self, main_address=None, creation_code=None, contract_name=None, max_transactions=3) -> None: + def sym_exec( + self, + main_address=None, + creation_code=None, + contract_name=None, + max_transactions=3, + ) -> None: logging.debug("Starting LASER execution") self.time = datetime.now() @@ -80,11 +94,19 @@ class LaserEVM: execute_message_call(self, main_address) elif creation_code: logging.info("Starting contract creation transaction") - created_account = execute_contract_creation(self, creation_code, contract_name) - logging.info("Finished contract creation, found {} open states".format(len(self.open_states))) + created_account = execute_contract_creation( + self, creation_code, contract_name + ) + logging.info( + "Finished contract creation, found {} open states".format( + len(self.open_states) + ) + ) if len(self.open_states) == 0: - logging.warning("No contract was created during the execution of contract creation " - "Increase the resources for creation execution (--max-depth or --create_timeout)") + logging.warning( + "No contract was created during the execution of contract creation " + "Increase the resources for creation execution (--max-depth or --create_timeout)" + ) # Reset code coverage self.coverage = {} @@ -92,7 +114,9 @@ class LaserEVM: initial_coverage = self._get_covered_instructions() self.time = datetime.now() - logging.info("Starting message call transaction, iteration: {}".format(i)) + logging.info( + "Starting message call transaction, iteration: {}".format(i) + ) execute_message_call(self, created_account.address) end_coverage = self._get_covered_instructions() @@ -100,22 +124,36 @@ class LaserEVM: break logging.info("Finished symbolic execution") - logging.info("%d nodes, %d edges, %d total states", len(self.nodes), len(self.edges), self.total_states) + logging.info( + "%d nodes, %d edges, %d total states", + len(self.nodes), + len(self.edges), + self.total_states, + ) for code, coverage in self.coverage.items(): - cov = reduce(lambda sum_, val: sum_ + 1 if val else sum_, coverage[1]) / float(coverage[0]) * 100 + cov = ( + reduce(lambda sum_, val: sum_ + 1 if val else sum_, coverage[1]) + / float(coverage[0]) + * 100 + ) logging.info("Achieved {} coverage for code: {}".format(cov, code)) def _get_covered_instructions(self) -> int: """ Gets the total number of covered instructions for all accounts in the svm""" total_covered_instructions = 0 for _, cv in self.coverage.items(): - total_covered_instructions += reduce(lambda sum_, val: sum_ + 1 if val else sum_, cv[1]) + total_covered_instructions += reduce( + lambda sum_, val: sum_ + 1 if val else sum_, cv[1] + ) return total_covered_instructions - def exec(self, create=False)-> None: + def exec(self, create=False) -> None: for global_state in self.strategy: if self.execution_timeout and not create: - if self.time + timedelta(seconds=self.execution_timeout) <= datetime.now(): + if ( + self.time + timedelta(seconds=self.execution_timeout) + <= datetime.now() + ): return elif self.create_timeout and create: if self.time + timedelta(seconds=self.create_timeout) <= datetime.now(): @@ -132,10 +170,12 @@ class LaserEVM: self.work_list += new_states self.total_states += len(new_states) - def execute_state(self, global_state: GlobalState) -> Tuple[List[GlobalState], Union[str, None]]: + def execute_state( + self, global_state: GlobalState + ) -> Tuple[List[GlobalState], Union[str, None]]: instructions = global_state.environment.code.instruction_list try: - op_code = instructions[global_state.mstate.pc]['opcode'] + op_code = instructions[global_state.mstate.pc]["opcode"] except IndexError: self.open_states.append(global_state.world_state) return [], None @@ -143,7 +183,9 @@ class LaserEVM: self._execute_pre_hook(op_code, global_state) try: self._measure_coverage(global_state) - new_global_states = Instruction(op_code, self.dynamic_loader).evaluate(global_state) + new_global_states = Instruction(op_code, self.dynamic_loader).evaluate( + global_state + ) except VmException as e: transaction, return_global_state = global_state.transaction_stack.pop() @@ -152,29 +194,42 @@ class LaserEVM: # In this case we don't put an unmodified world state in the open_states list Since in the case of an # exceptional halt all changes should be discarded, and this world state would not provide us with a # previously unseen world state - logging.debug("Encountered a VmException, ending path: `{}`".format(str(e))) + logging.debug( + "Encountered a VmException, ending path: `{}`".format(str(e)) + ) new_global_states = [] else: # First execute the post hook for the transaction ending instruction self._execute_post_hook(op_code, [global_state]) - new_global_states = self._end_message_call(return_global_state, global_state, - revert_changes=True, return_data=None) + new_global_states = self._end_message_call( + return_global_state, + global_state, + revert_changes=True, + return_data=None, + ) except TransactionStartSignal as start_signal: # Setup new global state new_global_state = start_signal.transaction.initial_global_state() - new_global_state.transaction_stack = copy(global_state.transaction_stack) + [(start_signal.transaction, global_state)] + new_global_state.transaction_stack = copy( + global_state.transaction_stack + ) + [(start_signal.transaction, global_state)] new_global_state.node = global_state.node new_global_state.mstate.constraints = global_state.mstate.constraints return [new_global_state], op_code except TransactionEndSignal as end_signal: - transaction, return_global_state = end_signal.global_state.transaction_stack.pop() + transaction, return_global_state = ( + end_signal.global_state.transaction_stack.pop() + ) if return_global_state is None: - if (not isinstance(transaction, ContractCreationTransaction) or transaction.return_data) and not end_signal.revert: + if ( + not isinstance(transaction, ContractCreationTransaction) + or transaction.return_data + ) and not end_signal.revert: end_signal.global_state.world_state.node = global_state.node self.open_states.append(end_signal.global_state.world_state) new_global_states = [] @@ -182,9 +237,12 @@ class LaserEVM: # First execute the post hook for the transaction ending instruction self._execute_post_hook(op_code, [end_signal.global_state]) - new_global_states = self._end_message_call(return_global_state, global_state, - revert_changes=False or end_signal.revert, - return_data=transaction.return_data) + new_global_states = self._end_message_call( + return_global_state, + global_state, + revert_changes=False or end_signal.revert, + return_data=transaction.return_data, + ) self._execute_post_hook(op_code, new_global_states) @@ -195,20 +253,25 @@ class LaserEVM: return_global_state: GlobalState, global_state: GlobalState, revert_changes=False, - return_data=None + return_data=None, ) -> List[GlobalState]: # Resume execution of the transaction initializing instruction - op_code = return_global_state.environment.code.instruction_list[return_global_state.mstate.pc]['opcode'] + op_code = return_global_state.environment.code.instruction_list[ + return_global_state.mstate.pc + ]["opcode"] # Set execution result in the return_state return_global_state.last_return_data = return_data if not revert_changes: return_global_state.world_state = copy(global_state.world_state) - return_global_state.environment.active_account = \ - global_state.accounts[return_global_state.environment.active_account.address] + return_global_state.environment.active_account = global_state.accounts[ + return_global_state.environment.active_account.address + ] # Execute the post instruction handler - new_global_states = Instruction(op_code, self.dynamic_loader).evaluate(return_global_state, True) + new_global_states = Instruction(op_code, self.dynamic_loader).evaluate( + return_global_state, True + ) # In order to get a nice call graph we need to set the nodes here for state in new_global_states: @@ -222,7 +285,10 @@ class LaserEVM: instruction_index = global_state.mstate.pc if code not in self.coverage.keys(): - self.coverage[code] = [number_of_instructions, [False]*number_of_instructions] + self.coverage[code] = [ + number_of_instructions, + [False] * number_of_instructions, + ] self.coverage[code][1][instruction_index] = True @@ -233,19 +299,27 @@ class LaserEVM: self._new_node_state(state) elif opcode == "JUMPI": for state in new_states: - self._new_node_state(state, JumpType.CONDITIONAL, state.mstate.constraints[-1]) + self._new_node_state( + state, JumpType.CONDITIONAL, state.mstate.constraints[-1] + ) elif opcode in ("SLOAD", "SSTORE") and len(new_states) > 1: for state in new_states: - self._new_node_state(state, JumpType.CONDITIONAL, state.mstate.constraints[-1]) + self._new_node_state( + state, JumpType.CONDITIONAL, state.mstate.constraints[-1] + ) - elif opcode in ("CALL", 'CALLCODE', 'DELEGATECALL', 'STATICCALL'): + elif opcode in ("CALL", "CALLCODE", "DELEGATECALL", "STATICCALL"): assert len(new_states) <= 1 for state in new_states: self._new_node_state(state, JumpType.CALL) # Keep track of added contracts so the graph can be generated properly - if state.environment.active_account.contract_name not in self.world_state.accounts.keys(): + if ( + state.environment.active_account.contract_name + not in self.world_state.accounts.keys() + ): self.world_state.accounts[ - state.environment.active_account.address] = state.environment.active_account + state.environment.active_account.address + ] = state.environment.active_account elif opcode == "RETURN": for state in new_states: self._new_node_state(state, JumpType.RETURN) @@ -253,25 +327,29 @@ class LaserEVM: for state in new_states: state.node.states.append(state) - def _new_node_state(self, state: GlobalState, edge_type=JumpType.UNCONDITIONAL, condition=None) -> None: + def _new_node_state( + self, state: GlobalState, edge_type=JumpType.UNCONDITIONAL, condition=None + ) -> None: new_node = Node(state.environment.active_account.contract_name) old_node = state.node state.node = new_node new_node.constraints = state.mstate.constraints self.nodes[new_node.uid] = new_node - self.edges.append(Edge(old_node.uid, new_node.uid, edge_type=edge_type, condition=condition)) + self.edges.append( + Edge(old_node.uid, new_node.uid, edge_type=edge_type, condition=condition) + ) if edge_type == JumpType.RETURN: new_node.flags |= NodeFlags.CALL_RETURN elif edge_type == JumpType.CALL: try: - if 'retval' in str(state.mstate.stack[-1]): + if "retval" in str(state.mstate.stack[-1]): new_node.flags |= NodeFlags.CALL_RETURN else: new_node.flags |= NodeFlags.FUNC_ENTRY except StackUnderflowException: new_node.flags |= NodeFlags.FUNC_ENTRY - address = state.environment.code.instruction_list[state.mstate.pc]['address'] + address = state.environment.code.instruction_list[state.mstate.pc]["address"] environment = state.environment disassembly = environment.code @@ -282,7 +360,11 @@ class LaserEVM: new_node.flags |= NodeFlags.FUNC_ENTRY logging.debug( - "- Entering function " + environment.active_account.contract_name + ":" + new_node.function_name) + "- Entering function " + + environment.active_account.contract_name + + ":" + + new_node.function_name + ) elif address == 0: environment.active_function_name = "fallback" @@ -294,7 +376,9 @@ class LaserEVM: for hook in self.pre_hooks[op_code]: hook(global_state) - def _execute_post_hook(self, op_code: str, global_states: List[GlobalState]) -> None: + def _execute_post_hook( + self, op_code: str, global_states: List[GlobalState] + ) -> None: if op_code not in self.post_hooks.keys(): return diff --git a/mythril/laser/ethereum/taint_analysis.py b/mythril/laser/ethereum/taint_analysis.py index 502d9f6a..d014e78f 100644 --- a/mythril/laser/ethereum/taint_analysis.py +++ b/mythril/laser/ethereum/taint_analysis.py @@ -88,7 +88,9 @@ class TaintRunner: """ @staticmethod - def execute(statespace: SymExecWrapper, node: Node, state: GlobalState, initial_stack=None) -> TaintResult: + def execute( + statespace: SymExecWrapper, node: Node, state: GlobalState, initial_stack=None + ) -> TaintResult: """ Runs taint analysis on the statespace :param statespace: symbolic statespace to run taint analysis on @@ -115,9 +117,11 @@ class TaintRunner: records = TaintRunner.execute_node(node, record, index) result.add_records(records) - if len(records) == 0: # continue if there is no record to work on + if len(records) == 0: # continue if there is no record to work on continue - children = TaintRunner.children(node, statespace, environment, transaction_stack_length) + children = TaintRunner.children( + node, statespace, environment, transaction_stack_length + ) for child in children: current_nodes.append((child, records[-1], 0)) return result @@ -127,19 +131,33 @@ class TaintRunner: node: Node, statespace: SymExecWrapper, environment: Environment, - transaction_stack_length: int + transaction_stack_length: int, ) -> List[Node]: - direct_children = [statespace.nodes[edge.node_to] for edge in statespace.edges if edge.node_from == node.uid and edge.type != JumpType.Transaction] + direct_children = [ + statespace.nodes[edge.node_to] + for edge in statespace.edges + if edge.node_from == node.uid and edge.type != JumpType.Transaction + ] children = [] for child in direct_children: - if all(len(state.transaction_stack) == transaction_stack_length for state in child.states): + if all( + len(state.transaction_stack) == transaction_stack_length + for state in child.states + ): children.append(child) - elif all(len(state.transaction_stack) > transaction_stack_length for state in child.states): - children += TaintRunner.children(child, statespace, environment, transaction_stack_length) + elif all( + len(state.transaction_stack) > transaction_stack_length + for state in child.states + ): + children += TaintRunner.children( + child, statespace, environment, transaction_stack_length + ) return children @staticmethod - def execute_node(node: Node, last_record: TaintRecord, state_index=0) -> List[TaintRecord]: + def execute_node( + node: Node, last_record: TaintRecord, state_index=0 + ) -> List[TaintRecord]: """ Runs taint analysis on a given node :param node: node to analyse @@ -161,7 +179,7 @@ class TaintRunner: new_record = record.clone() # Apply Change - op = state.get_current_instruction()['opcode'] + op = state.get_current_instruction()["opcode"] if op in TaintRunner.stack_taint_table.keys(): mutator = TaintRunner.stack_taint_table[op] @@ -182,7 +200,7 @@ class TaintRunner: TaintRunner.mutate_sstore(new_record, state.mstate.stack[-1]) elif op.startswith("LOG"): TaintRunner.mutate_log(new_record, op) - elif op in ('CALL', 'CALLCODE', 'DELEGATECALL', 'STATICCALL'): + elif op in ("CALL", "CALLCODE", "DELEGATECALL", "STATICCALL"): TaintRunner.mutate_call(new_record, op) else: logging.debug("Unknown operation encountered: {}".format(op)) @@ -274,7 +292,7 @@ class TaintRunner: @staticmethod def mutate_call(record: TaintRecord, op: str) -> None: pops = 6 - if op in ('CALL', 'CALLCODE'): + if op in ("CALL", "CALLCODE"): pops += 1 for _ in range(pops): record.stack.pop() @@ -283,55 +301,55 @@ class TaintRunner: stack_taint_table = { # instruction: (taint source, taint target) - 'POP': (1, 0), - 'ADD': (2, 1), - 'MUL': (2, 1), - 'SUB': (2, 1), - 'AND': (2, 1), - 'OR': (2, 1), - 'XOR': (2, 1), - 'NOT': (1, 1), - 'BYTE': (2, 1), - 'DIV': (2, 1), - 'MOD': (2, 1), - 'SDIV': (2, 1), - 'SMOD': (2, 1), - 'ADDMOD': (3, 1), - 'MULMOD': (3, 1), - 'EXP': (2, 1), - 'SIGNEXTEND': (2, 1), - 'LT': (2, 1), - 'GT': (2, 1), - 'SLT': (2, 1), - 'SGT': (2, 1), - 'EQ': (2, 1), - 'ISZERO': (1, 1), - 'CALLVALUE': (0, 1), - 'CALLDATALOAD': (1, 1), - 'CALLDATACOPY': (3, 0), # todo - 'CALLDATASIZE': (0, 1), - 'ADDRESS': (0, 1), - 'BALANCE': (1, 1), - 'ORIGIN': (0, 1), - 'CALLER': (0, 1), - 'CODESIZE': (0, 1), - 'SHA3': (2, 1), - 'GASPRICE': (0, 1), - 'CODECOPY': (3, 0), - 'EXTCODESIZE': (1, 1), - 'EXTCODECOPY': (4, 0), - 'RETURNDATASIZE': (0, 1), - 'BLOCKHASH': (1, 1), - 'COINBASE': (0, 1), - 'TIMESTAMP': (0, 1), - 'NUMBER': (0, 1), - 'DIFFICULTY': (0, 1), - 'GASLIMIT': (0, 1), - 'JUMP': (1, 0), - 'JUMPI': (2, 0), - 'PC': (0, 1), - 'MSIZE': (0, 1), - 'GAS': (0, 1), - 'CREATE': (3, 1), - 'RETURN': (2, 0) + "POP": (1, 0), + "ADD": (2, 1), + "MUL": (2, 1), + "SUB": (2, 1), + "AND": (2, 1), + "OR": (2, 1), + "XOR": (2, 1), + "NOT": (1, 1), + "BYTE": (2, 1), + "DIV": (2, 1), + "MOD": (2, 1), + "SDIV": (2, 1), + "SMOD": (2, 1), + "ADDMOD": (3, 1), + "MULMOD": (3, 1), + "EXP": (2, 1), + "SIGNEXTEND": (2, 1), + "LT": (2, 1), + "GT": (2, 1), + "SLT": (2, 1), + "SGT": (2, 1), + "EQ": (2, 1), + "ISZERO": (1, 1), + "CALLVALUE": (0, 1), + "CALLDATALOAD": (1, 1), + "CALLDATACOPY": (3, 0), # todo + "CALLDATASIZE": (0, 1), + "ADDRESS": (0, 1), + "BALANCE": (1, 1), + "ORIGIN": (0, 1), + "CALLER": (0, 1), + "CODESIZE": (0, 1), + "SHA3": (2, 1), + "GASPRICE": (0, 1), + "CODECOPY": (3, 0), + "EXTCODESIZE": (1, 1), + "EXTCODECOPY": (4, 0), + "RETURNDATASIZE": (0, 1), + "BLOCKHASH": (1, 1), + "COINBASE": (0, 1), + "TIMESTAMP": (0, 1), + "NUMBER": (0, 1), + "DIFFICULTY": (0, 1), + "GASLIMIT": (0, 1), + "JUMP": (1, 0), + "JUMPI": (2, 0), + "PC": (0, 1), + "MSIZE": (0, 1), + "GAS": (0, 1), + "CREATE": (3, 1), + "RETURN": (2, 0), } diff --git a/mythril/laser/ethereum/transaction/__init__.py b/mythril/laser/ethereum/transaction/__init__.py index ef953047..8ab310fc 100644 --- a/mythril/laser/ethereum/transaction/__init__.py +++ b/mythril/laser/ethereum/transaction/__init__.py @@ -1,2 +1,5 @@ from mythril.laser.ethereum.transaction.transaction_models import * -from mythril.laser.ethereum.transaction.symbolic import execute_message_call, execute_contract_creation +from mythril.laser.ethereum.transaction.symbolic import ( + execute_message_call, + execute_contract_creation, +) diff --git a/mythril/laser/ethereum/transaction/concolic.py b/mythril/laser/ethereum/transaction/concolic.py index cac7efb0..597de21d 100644 --- a/mythril/laser/ethereum/transaction/concolic.py +++ b/mythril/laser/ethereum/transaction/concolic.py @@ -1,11 +1,31 @@ -from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction, ContractCreationTransaction, get_next_transaction_id +from mythril.laser.ethereum.transaction.transaction_models import ( + MessageCallTransaction, + ContractCreationTransaction, + get_next_transaction_id, +) from z3 import BitVec -from mythril.laser.ethereum.state import GlobalState, Environment, CalldataType, Account, WorldState +from mythril.laser.ethereum.state import ( + GlobalState, + Environment, + CalldataType, + Account, + WorldState, +) from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.cfg import Node, Edge, JumpType -def execute_message_call(laser_evm, callee_address, caller_address, origin_address, code, data, gas, gas_price, value) -> None: +def execute_message_call( + laser_evm, + callee_address, + caller_address, + origin_address, + code, + data, + gas, + gas_price, + value, +) -> None: """ Executes a message call transaction from all open states """ # TODO: Resolve circular import between .transaction and ..svm to import LaserEVM here open_states = laser_evm.open_states[:] @@ -22,7 +42,7 @@ def execute_message_call(laser_evm, callee_address, caller_address, origin_addre call_value=value, origin=origin_address, call_data_type=CalldataType.SYMBOLIC, - code=Disassembly(code) + code=Disassembly(code), ) _setup_global_state_for_execution(laser_evm, transaction) @@ -40,8 +60,14 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None: laser_evm.nodes[new_node.uid] = new_node if transaction.world_state.node: - laser_evm.edges.append(Edge(transaction.world_state.node.uid, new_node.uid, edge_type=JumpType.Transaction, - condition=None)) + laser_evm.edges.append( + Edge( + transaction.world_state.node.uid, + new_node.uid, + edge_type=JumpType.Transaction, + condition=None, + ) + ) global_state.node = new_node new_node.states.append(global_state) laser_evm.work_list.append(global_state) diff --git a/mythril/laser/ethereum/transaction/symbolic.py b/mythril/laser/ethereum/transaction/symbolic.py index fc2f0671..540b2b6c 100644 --- a/mythril/laser/ethereum/transaction/symbolic.py +++ b/mythril/laser/ethereum/transaction/symbolic.py @@ -4,8 +4,11 @@ from logging import debug from mythril.disassembler.disassembly import Disassembly from mythril.laser.ethereum.cfg import Node, Edge, JumpType from mythril.laser.ethereum.state import CalldataType, Account -from mythril.laser.ethereum.transaction.transaction_models import MessageCallTransaction, ContractCreationTransaction,\ - get_next_transaction_id +from mythril.laser.ethereum.transaction.transaction_models import ( + MessageCallTransaction, + ContractCreationTransaction, + get_next_transaction_id, +) def execute_message_call(laser_evm, callee_address: str) -> None: @@ -36,13 +39,17 @@ def execute_message_call(laser_evm, callee_address: str) -> None: laser_evm.exec() -def execute_contract_creation(laser_evm, contract_initialization_code, contract_name=None) -> Account: +def execute_contract_creation( + laser_evm, contract_initialization_code, contract_name=None +) -> Account: """ Executes a contract creation transaction from all open states""" # TODO: Resolve circular import between .transaction and ..svm to import LaserEVM here open_states = laser_evm.open_states[:] del laser_evm.open_states[:] - new_account = laser_evm.world_state.create_account(0, concrete_storage=True, dynamic_loader=None) + new_account = laser_evm.world_state.create_account( + 0, concrete_storage=True, dynamic_loader=None + ) if contract_name: new_account.contract_name = contract_name @@ -58,7 +65,7 @@ def execute_contract_creation(laser_evm, contract_initialization_code, contract_ BitVec("gas_price{}".format(next_transaction_id), 256), BitVec("call_value{}".format(next_transaction_id), 256), BitVec("origin{}".format(next_transaction_id), 256), - CalldataType.SYMBOLIC + CalldataType.SYMBOLIC, ) _setup_global_state_for_execution(laser_evm, transaction) laser_evm.exec(True) @@ -76,8 +83,14 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None: laser_evm.nodes[new_node.uid] = new_node if transaction.world_state.node: - laser_evm.edges.append(Edge(transaction.world_state.node.uid, new_node.uid, edge_type=JumpType.Transaction, - condition=None)) + laser_evm.edges.append( + Edge( + transaction.world_state.node.uid, + new_node.uid, + edge_type=JumpType.Transaction, + condition=None, + ) + ) global_state.mstate.constraints = transaction.world_state.node.constraints new_node.constraints = global_state.mstate.constraints diff --git a/mythril/laser/ethereum/transaction/transaction_models.py b/mythril/laser/ethereum/transaction/transaction_models.py index 39a98649..50a2e540 100644 --- a/mythril/laser/ethereum/transaction/transaction_models.py +++ b/mythril/laser/ethereum/transaction/transaction_models.py @@ -16,6 +16,7 @@ def get_next_transaction_id() -> int: class TransactionEndSignal(Exception): """ Exception raised when a transaction is finalized""" + def __init__(self, global_state: GlobalState, revert=False): self.global_state = global_state self.revert = revert @@ -23,10 +24,11 @@ class TransactionEndSignal(Exception): class TransactionStartSignal(Exception): """ Exception raised when a new transaction is started""" + def __init__( self, transaction: Union["MessageCallTransaction", "ContractCreationTransaction"], - op_code: str + op_code: str, ): self.transaction = transaction self.op_code = op_code @@ -34,28 +36,44 @@ class TransactionStartSignal(Exception): class MessageCallTransaction: """ Transaction object models an transaction""" - def __init__(self, - world_state: WorldState, - callee_account: Account, - caller: BitVecNumRef, - call_data=(), - identifier=None, - gas_price=None, - call_value=None, - origin=None, - call_data_type=None, - code=None - ): + + def __init__( + self, + world_state: WorldState, + callee_account: Account, + caller: BitVecNumRef, + call_data=(), + identifier=None, + gas_price=None, + call_value=None, + origin=None, + call_data_type=None, + code=None, + ): assert isinstance(world_state, WorldState) self.id = identifier or get_next_transaction_id() self.world_state = world_state self.callee_account = callee_account self.caller = caller self.call_data = call_data - self.gas_price = BitVec("gasprice{}".format(identifier), 256) if gas_price is None else gas_price - self.call_value = BitVec("callvalue{}".format(identifier), 256) if call_value is None else call_value - self.origin = BitVec("origin{}".format(identifier), 256) if origin is None else origin - self.call_data_type = BitVec("call_data_type{}".format(identifier), 256) if call_data_type is None else call_data_type + self.gas_price = ( + BitVec("gasprice{}".format(identifier), 256) + if gas_price is None + else gas_price + ) + self.call_value = ( + BitVec("callvalue{}".format(identifier), 256) + if call_value is None + else call_value + ) + self.origin = ( + BitVec("origin{}".format(identifier), 256) if origin is None else origin + ) + self.call_data_type = ( + BitVec("call_data_type{}".format(identifier), 256) + if call_data_type is None + else call_data_type + ) self.code = code self.return_data = None @@ -73,7 +91,7 @@ class MessageCallTransaction: ) global_state = GlobalState(self.world_state, environment, None) - global_state.environment.active_function_name = 'fallback' + global_state.environment.active_function_name = "fallback" return global_state @@ -84,30 +102,50 @@ class MessageCallTransaction: class ContractCreationTransaction: """ Transaction object models an transaction""" - def __init__(self, - world_state: WorldState, - caller: BitVecNumRef, - identifier=None, - callee_account=None, - code=None, - call_data=(), - gas_price=None, - call_value=None, - origin=None, - call_data_type=None, - ): + + def __init__( + self, + world_state: WorldState, + caller: BitVecNumRef, + identifier=None, + callee_account=None, + code=None, + call_data=(), + gas_price=None, + call_value=None, + origin=None, + call_data_type=None, + ): assert isinstance(world_state, WorldState) self.id = identifier or get_next_transaction_id() self.world_state = world_state # TODO: set correct balance for new account - self.callee_account = callee_account if callee_account else world_state.create_account(0, concrete_storage=True) + self.callee_account = ( + callee_account + if callee_account + else world_state.create_account(0, concrete_storage=True) + ) self.caller = caller - self.gas_price = BitVec("gasprice{}".format(identifier), 256) if gas_price is None else gas_price - self.call_value = BitVec("callvalue{}".format(identifier), 256) if call_value is None else call_value - self.origin = BitVec("origin{}".format(identifier), 256) if origin is None else origin - self.call_data_type = BitVec("call_data_type{}".format(identifier), 256) if call_data_type is None else call_data_type + self.gas_price = ( + BitVec("gasprice{}".format(identifier), 256) + if gas_price is None + else gas_price + ) + self.call_value = ( + BitVec("callvalue{}".format(identifier), 256) + if call_value is None + else call_value + ) + self.origin = ( + BitVec("origin{}".format(identifier), 256) if origin is None else origin + ) + self.call_data_type = ( + BitVec("call_data_type{}".format(identifier), 256) + if call_data_type is None + else call_data_type + ) self.call_data = call_data self.origin = origin @@ -128,16 +166,19 @@ class ContractCreationTransaction: ) global_state = GlobalState(self.world_state, environment, None) - global_state.environment.active_function_name = 'constructor' + global_state.environment.active_function_name = "constructor" return global_state def end(self, global_state: GlobalState, return_data=None, revert=False) -> None: - if not all([isinstance(element, int) for element in return_data]) or len(return_data) == 0: + if ( + not all([isinstance(element, int) for element in return_data]) + or len(return_data) == 0 + ): self.return_data = None raise TransactionEndSignal(global_state) - contract_code = bytes.hex(array.array('B', return_data).tostring()) + contract_code = bytes.hex(array.array("B", return_data).tostring()) global_state.environment.active_account.code = Disassembly(contract_code) self.return_data = global_state.environment.active_account.address diff --git a/mythril/laser/ethereum/util.py b/mythril/laser/ethereum/util.py index a7adc718..d27fa89f 100644 --- a/mythril/laser/ethereum/util.py +++ b/mythril/laser/ethereum/util.py @@ -27,10 +27,12 @@ def to_signed(i: int) -> int: return i if i < TT255 else i - TT256 -def get_instruction_index(instruction_list: List[Dict], address: int) -> Union[int, None]: +def get_instruction_index( + instruction_list: List[Dict], address: int +) -> Union[int, None]: index = 0 for instr in instruction_list: - if instr['address'] == address: + if instr["address"] == address: return index index += 1 return None @@ -40,7 +42,7 @@ def get_trace_line(instr: Dict, state: MachineState) -> str: stack = str(state.stack[::-1]) # stack = re.sub("(\d+)", lambda m: hex(int(m.group(1))), stack) stack = re.sub("\n", "", stack) - return str(instr['address']) + " " + instr['opcode'] + "\tSTACK: " + stack + return str(instr["address"]) + " " + instr["opcode"] + "\tSTACK: " + stack def pop_bitvec(state: MachineState) -> BitVecVal: @@ -83,15 +85,15 @@ def get_concrete_int(item: Union[int, BitVecNumRef, BoolRef]) -> int: def concrete_int_from_bytes(_bytes: bytes, start_index: int) -> int: - b = _bytes[start_index:start_index+32] - val = int.from_bytes(b, byteorder='big') + b = _bytes[start_index : start_index + 32] + val = int.from_bytes(b, byteorder="big") return val def concrete_int_to_bytes(val: int) -> bytes: if isinstance(val, int): - return val.to_bytes(32, byteorder='big') - return (simplify(val).as_long()).to_bytes(32, byteorder='big') + return val.to_bytes(32, byteorder="big") + return (simplify(val).as_long()).to_bytes(32, byteorder="big") def bytearray_to_int(arr: bytearray) -> int: @@ -99,4 +101,3 @@ def bytearray_to_int(arr: bytearray) -> int: for a in arr: o = (o << 8) + a return o - diff --git a/mythril/mythril.py b/mythril/mythril.py index 123ab904..d917d7f8 100644 --- a/mythril/mythril.py +++ b/mythril/mythril.py @@ -35,6 +35,7 @@ from mythril.ethereum.interface.leveldb.client import EthLevelDB # logging.basicConfig(level=logging.DEBUG) + class Mythril(object): """ Mythril main interface class. @@ -75,8 +76,8 @@ class Mythril(object): mythril.get_state_variable_from_storage(args) """ - def __init__(self, solv=None, - solc_args=None, dynld=False): + + def __init__(self, solv=None, solc_args=None, dynld=False): self.solv = solv self.solc_args = solc_args @@ -89,26 +90,34 @@ class Mythril(object): self.sigs.open() # tries mythril_dir/signatures.json by default (provide path= arg to make this configurable) except FileNotFoundError as fnfe: logging.info( - "No signature database found. Creating database if sigs are loaded in: " + self.sigs.signatures_file + "\n" + - "Consider replacing it with the pre-initialized database at https://raw.githubusercontent.com/ConsenSys/mythril/master/signatures.json") + "No signature database found. Creating database if sigs are loaded in: " + + self.sigs.signatures_file + + "\n" + + "Consider replacing it with the pre-initialized database at https://raw.githubusercontent.com/ConsenSys/mythril/master/signatures.json" + ) except json.JSONDecodeError as jde: - raise CriticalError("Invalid JSON in signatures file " + self.sigs.signatures_file + "\n" + str(jde)) + raise CriticalError( + "Invalid JSON in signatures file " + + self.sigs.signatures_file + + "\n" + + str(jde) + ) self.solc_binary = self._init_solc_binary(solv) - self.config_path = os.path.join(self.mythril_dir, 'config.ini') + self.config_path = os.path.join(self.mythril_dir, "config.ini") self.leveldb_dir = self._init_config() - self.eth = None # ethereum API client - self.eth_db = None # ethereum LevelDB client + self.eth = None # ethereum API client + self.eth_db = None # ethereum LevelDB client self.contracts = [] # loaded contracts @staticmethod def _init_mythril_dir(): try: - mythril_dir = os.environ['MYTHRIL_DIR'] + mythril_dir = os.environ["MYTHRIL_DIR"] except KeyError: - mythril_dir = os.path.join(os.path.expanduser('~'), ".mythril") + mythril_dir = os.path.join(os.path.expanduser("~"), ".mythril") # Initialize data directory and signature database @@ -126,59 +135,75 @@ class Mythril(object): """ system = platform.system().lower() - leveldb_fallback_dir = os.path.expanduser('~') + leveldb_fallback_dir = os.path.expanduser("~") if system.startswith("darwin"): - leveldb_fallback_dir = os.path.join(leveldb_fallback_dir, "Library", "Ethereum") + leveldb_fallback_dir = os.path.join( + leveldb_fallback_dir, "Library", "Ethereum" + ) elif system.startswith("windows"): - leveldb_fallback_dir = os.path.join(leveldb_fallback_dir, "AppData", "Roaming", "Ethereum") + leveldb_fallback_dir = os.path.join( + leveldb_fallback_dir, "AppData", "Roaming", "Ethereum" + ) else: leveldb_fallback_dir = os.path.join(leveldb_fallback_dir, ".ethereum") leveldb_fallback_dir = os.path.join(leveldb_fallback_dir, "geth", "chaindata") if not os.path.exists(self.config_path): logging.info("No config file found. Creating default: " + self.config_path) - open(self.config_path, 'a').close() + open(self.config_path, "a").close() config = ConfigParser(allow_no_value=True) config.optionxform = str - config.read(self.config_path, 'utf-8') - if 'defaults' not in config.sections(): + config.read(self.config_path, "utf-8") + if "defaults" not in config.sections(): self._add_default_options(config) - if not config.has_option('defaults', 'leveldb_dir'): + if not config.has_option("defaults", "leveldb_dir"): self._add_leveldb_option(config, leveldb_fallback_dir) - if not config.has_option('defaults', 'dynamic_loading'): + if not config.has_option("defaults", "dynamic_loading"): self._add_dynamic_loading_option(config) - with codecs.open(self.config_path, 'w', 'utf-8') as fp: + with codecs.open(self.config_path, "w", "utf-8") as fp: config.write(fp) - leveldb_dir = config.get('defaults', 'leveldb_dir', fallback=leveldb_fallback_dir) + leveldb_dir = config.get( + "defaults", "leveldb_dir", fallback=leveldb_fallback_dir + ) return os.path.expanduser(leveldb_dir) @staticmethod def _add_default_options(config): - config.add_section('defaults') + config.add_section("defaults") @staticmethod def _add_leveldb_option(config, leveldb_fallback_dir): - config.set('defaults', "#Default chaindata locations:") - config.set('defaults', "#– Mac: ~/Library/Ethereum/geth/chaindata") - config.set('defaults', "#– Linux: ~/.ethereum/geth/chaindata") - config.set('defaults', "#– Windows: %USERPROFILE%\\AppData\\Roaming\\Ethereum\\geth\\chaindata") - config.set('defaults', 'leveldb_dir', leveldb_fallback_dir) + config.set("defaults", "#Default chaindata locations:") + config.set("defaults", "#– Mac: ~/Library/Ethereum/geth/chaindata") + config.set("defaults", "#– Linux: ~/.ethereum/geth/chaindata") + config.set( + "defaults", + "#– Windows: %USERPROFILE%\\AppData\\Roaming\\Ethereum\\geth\\chaindata", + ) + config.set("defaults", "leveldb_dir", leveldb_fallback_dir) @staticmethod def _add_dynamic_loading_option(config): - config.set('defaults', '#– To connect to Infura use dynamic_loading: infura') - config.set('defaults', '#– To connect to Rpc use ' - 'dynamic_loading: HOST:PORT / ganache / infura-[network_name]') - config.set('defaults', '#– To connect to local host use dynamic_loading: localhost') - config.set('defaults', 'dynamic_loading', 'infura') + config.set("defaults", "#– To connect to Infura use dynamic_loading: infura") + config.set( + "defaults", + "#– To connect to Rpc use " + "dynamic_loading: HOST:PORT / ganache / infura-[network_name]", + ) + config.set( + "defaults", "#– To connect to local host use dynamic_loading: localhost" + ) + config.set("defaults", "dynamic_loading", "infura") def analyze_truffle_project(self, *args, **kwargs): - return analyze_truffle_project(self.sigs, *args, **kwargs) # just passthru by passing signatures for now + return analyze_truffle_project( + self.sigs, *args, **kwargs + ) # just passthru by passing signatures for now @staticmethod def _init_solc_binary(version): @@ -188,27 +213,31 @@ class Mythril(object): if version: # tried converting input to semver, seemed not necessary so just slicing for now if version == str(solc.main.get_solc_version())[:6]: - logging.info('Given version matches installed version') + logging.info("Given version matches installed version") try: - solc_binary = os.environ['SOLC'] + solc_binary = os.environ["SOLC"] except KeyError: - solc_binary = 'solc' + solc_binary = "solc" else: if util.solc_exists(version): - logging.info('Given version is already installed') + logging.info("Given version is already installed") else: try: - solc.install_solc('v' + version) + solc.install_solc("v" + version) except SolcError: - raise CriticalError("There was an error when trying to install the specified solc version") + raise CriticalError( + "There was an error when trying to install the specified solc version" + ) - solc_binary = os.path.join(os.environ['HOME'], ".py-solc/solc-v" + version, "bin/solc") + solc_binary = os.path.join( + os.environ["HOME"], ".py-solc/solc-v" + version, "bin/solc" + ) logging.info("Setting the compiler to " + str(solc_binary)) else: try: - solc_binary = os.environ['SOLC'] + solc_binary = os.environ["SOLC"] except KeyError: - solc_binary = 'solc' + solc_binary = "solc" return solc_binary def set_api_leveldb(self, leveldb): @@ -217,22 +246,24 @@ class Mythril(object): return self.eth def set_api_rpc_infura(self): - self.eth = EthJsonRpc('mainnet.infura.io', 443, True) + self.eth = EthJsonRpc("mainnet.infura.io", 443, True) logging.info("Using INFURA for RPC queries") def set_api_rpc(self, rpc=None, rpctls=False): - if rpc == 'ganache': - rpcconfig = ('localhost', 8545, False) + if rpc == "ganache": + rpcconfig = ("localhost", 8545, False) else: - m = re.match(r'infura-(.*)', rpc) - if m and m.group(1) in ['mainnet', 'rinkeby', 'kovan', 'ropsten']: - rpcconfig = (m.group(1) + '.infura.io', 443, True) + m = re.match(r"infura-(.*)", rpc) + if m and m.group(1) in ["mainnet", "rinkeby", "kovan", "ropsten"]: + rpcconfig = (m.group(1) + ".infura.io", 443, True) else: try: host, port = rpc.split(":") rpcconfig = (host, int(port), rpctls) except ValueError: - raise CriticalError("Invalid RPC argument, use 'ganache', 'infura-[network]' or 'HOST:PORT'") + raise CriticalError( + "Invalid RPC argument, use 'ganache', 'infura-[network]' or 'HOST:PORT'" + ) if rpcconfig: self.eth = EthJsonRpc(rpcconfig[0], int(rpcconfig[1]), rpcconfig[2]) @@ -241,26 +272,25 @@ class Mythril(object): raise CriticalError("Invalid RPC settings, check help for details.") def set_api_rpc_localhost(self): - self.eth = EthJsonRpc('localhost', 8545) + self.eth = EthJsonRpc("localhost", 8545) logging.info("Using default RPC settings: http://localhost:8545") def set_api_from_config_path(self): config = ConfigParser(allow_no_value=False) config.optionxform = str - config.read(self.config_path, 'utf-8') - if config.has_option('defaults', 'dynamic_loading'): - dynamic_loading = config.get('defaults', 'dynamic_loading') + config.read(self.config_path, "utf-8") + if config.has_option("defaults", "dynamic_loading"): + dynamic_loading = config.get("defaults", "dynamic_loading") else: - dynamic_loading = 'infura' - if dynamic_loading == 'infura': + dynamic_loading = "infura" + if dynamic_loading == "infura": self.set_api_rpc_infura() - elif dynamic_loading == 'localhost': + elif dynamic_loading == "localhost": self.set_api_rpc_localhost() else: self.set_api_rpc(dynamic_loading) def search_db(self, search): - def search_callback(contract, address, balance): print("Address: " + address + ", balance: " + str(balance)) @@ -272,8 +302,8 @@ class Mythril(object): raise CriticalError("Syntax error in search expression.") def contract_hash_to_address(self, hash): - if not re.match(r'0x[a-fA-F0-9]{64}', hash): - raise CriticalError("Invalid address hash. Expected format is '0x...'.") + if not re.match(r"0x[a-fA-F0-9]{64}", hash): + raise CriticalError("Invalid address hash. Expected format is '0x...'.") print(self.eth_db.contract_hash_to_address(hash)) @@ -283,20 +313,24 @@ class Mythril(object): return address, self.contracts[-1] # return address and contract object def load_from_address(self, address): - if not re.match(r'0x[a-fA-F0-9]{40}', address): - raise CriticalError("Invalid contract address. Expected format is '0x...'.") + if not re.match(r"0x[a-fA-F0-9]{40}", address): + raise CriticalError("Invalid contract address. Expected format is '0x...'.") try: code = self.eth.eth_getCode(address) except FileNotFoundError as e: - raise CriticalError("IPC error: " + str(e)) + raise CriticalError("IPC error: " + str(e)) except ConnectionError as e: - raise CriticalError("Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly.") + raise CriticalError( + "Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly." + ) except Exception as e: - raise CriticalError("IPC / RPC error: " + str(e)) + raise CriticalError("IPC / RPC error: " + str(e)) else: if code == "0x" or code == "0x0": - raise CriticalError("Received an empty response from eth_getCode. Check the contract address and verify that you are on the correct chain.") + raise CriticalError( + "Received an empty response from eth_getCode. Check the contract address and verify that you are on the correct chain." + ) else: self.contracts.append(ETHContract(code, name=address)) return address, self.contracts[-1] # return address and contract object @@ -319,57 +353,105 @@ class Mythril(object): try: # import signatures from solidity source - self.sigs.import_from_solidity_source(file, solc_binary=self.solc_binary, solc_args=self.solc_args) + self.sigs.import_from_solidity_source( + file, solc_binary=self.solc_binary, solc_args=self.solc_args + ) # Save updated function signatures self.sigs.write() # dump signatures to disk (previously opened file or default location) if contract_name is not None: - contract = SolidityContract(file, contract_name, solc_args=self.solc_args) + contract = SolidityContract( + file, contract_name, solc_args=self.solc_args + ) self.contracts.append(contract) contracts.append(contract) else: - for contract in get_contracts_from_file(file, solc_args=self.solc_args): + for contract in get_contracts_from_file( + file, solc_args=self.solc_args + ): self.contracts.append(contract) contracts.append(contract) - except FileNotFoundError: raise CriticalError("Input file not found: " + file) except CompilerError as e: raise CriticalError(e) except NoContractFoundError: - logging.info("The file " + file + " does not contain a compilable contract.") - + logging.info( + "The file " + file + " does not contain a compilable contract." + ) return address, contracts - def dump_statespace(self, strategy, contract, address=None, max_depth=None, - execution_timeout=None, create_timeout=None): - - sym = SymExecWrapper(contract, address, strategy, - dynloader=DynLoader(self.eth) if self.dynld else None, - max_depth=max_depth, execution_timeout=execution_timeout, create_timeout=create_timeout) + def dump_statespace( + self, + strategy, + contract, + address=None, + max_depth=None, + execution_timeout=None, + create_timeout=None, + ): + + sym = SymExecWrapper( + contract, + address, + strategy, + dynloader=DynLoader(self.eth) if self.dynld else None, + max_depth=max_depth, + execution_timeout=execution_timeout, + create_timeout=create_timeout, + ) return get_serializable_statespace(sym) - def graph_html(self, strategy, contract, address, max_depth=None, enable_physics=False, - phrackify=False, execution_timeout=None, create_timeout=None): - sym = SymExecWrapper(contract, address, strategy, - dynloader=DynLoader(self.eth) if self.dynld else None, - max_depth=max_depth, execution_timeout=execution_timeout, create_timeout=create_timeout) + def graph_html( + self, + strategy, + contract, + address, + max_depth=None, + enable_physics=False, + phrackify=False, + execution_timeout=None, + create_timeout=None, + ): + sym = SymExecWrapper( + contract, + address, + strategy, + dynloader=DynLoader(self.eth) if self.dynld else None, + max_depth=max_depth, + execution_timeout=execution_timeout, + create_timeout=create_timeout, + ) return generate_graph(sym, physics=enable_physics, phrackify=phrackify) - def fire_lasers(self, strategy, contracts=None, address=None, - modules=None, verbose_report=False, max_depth=None, execution_timeout=None, create_timeout=None, - max_transaction_count=None): + def fire_lasers( + self, + strategy, + contracts=None, + address=None, + modules=None, + verbose_report=False, + max_depth=None, + execution_timeout=None, + create_timeout=None, + max_transaction_count=None, + ): all_issues = [] - for contract in (contracts or self.contracts): - sym = SymExecWrapper(contract, address, strategy, - dynloader=DynLoader(self.eth) if self.dynld else None, - max_depth=max_depth, execution_timeout=execution_timeout, - create_timeout=create_timeout, - max_transaction_count=max_transaction_count) + for contract in contracts or self.contracts: + sym = SymExecWrapper( + contract, + address, + strategy, + dynloader=DynLoader(self.eth) if self.dynld else None, + max_depth=max_depth, + execution_timeout=execution_timeout, + create_timeout=create_timeout, + max_transaction_count=max_transaction_count, + ) issues = fire_lasers(sym, modules) @@ -393,13 +475,18 @@ class Mythril(object): try: if params[0] == "mapping": if len(params) < 3: - raise CriticalError("Invalid number of parameters.") + raise CriticalError("Invalid number of parameters.") position = int(params[1]) position_formatted = utils.zpad(utils.int_to_big_endian(position), 32) for i in range(2, len(params)): - key = bytes(params[i], 'utf8') + key = bytes(params[i], "utf8") key_formatted = utils.rzpad(key, 32) - mappings.append(int.from_bytes(utils.sha3(key_formatted + position_formatted), byteorder='big')) + mappings.append( + int.from_bytes( + utils.sha3(key_formatted + position_formatted), + byteorder="big", + ) + ) length = len(mappings) if length == 1: @@ -407,37 +494,58 @@ class Mythril(object): else: if len(params) >= 4: - raise CriticalError("Invalid number of parameters.") + raise CriticalError("Invalid number of parameters.") if len(params) >= 1: position = int(params[0]) if len(params) >= 2: length = int(params[1]) if len(params) == 3 and params[2] == "array": - position_formatted = utils.zpad(utils.int_to_big_endian(position), 32) - position = int.from_bytes(utils.sha3(position_formatted), byteorder='big') + position_formatted = utils.zpad( + utils.int_to_big_endian(position), 32 + ) + position = int.from_bytes( + utils.sha3(position_formatted), byteorder="big" + ) except ValueError: - raise CriticalError("Invalid storage index. Please provide a numeric value.") + raise CriticalError( + "Invalid storage index. Please provide a numeric value." + ) outtxt = [] try: if length == 1: - outtxt.append("{}: {}".format(position, self.eth.eth_getStorageAt(address, position))) + outtxt.append( + "{}: {}".format( + position, self.eth.eth_getStorageAt(address, position) + ) + ) else: if len(mappings) > 0: for i in range(0, len(mappings)): position = mappings[i] - outtxt.append("{}: {}".format(hex(position), self.eth.eth_getStorageAt(address, position))) + outtxt.append( + "{}: {}".format( + hex(position), + self.eth.eth_getStorageAt(address, position), + ) + ) else: for i in range(position, position + length): - outtxt.append("{}: {}".format(hex(i), self.eth.eth_getStorageAt(address, i))) + outtxt.append( + "{}: {}".format( + hex(i), self.eth.eth_getStorageAt(address, i) + ) + ) except FileNotFoundError as e: - raise CriticalError("IPC error: " + str(e)) + raise CriticalError("IPC error: " + str(e)) except ConnectionError as e: - raise CriticalError("Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly.") - return '\n'.join(outtxt) + raise CriticalError( + "Could not connect to RPC server. Make sure that your node is running and that RPC parameters are set correctly." + ) + return "\n".join(outtxt) @staticmethod def disassemble(contract): diff --git a/mythril/support/loader.py b/mythril/support/loader.py index d219d17c..ac5a5d2e 100644 --- a/mythril/support/loader.py +++ b/mythril/support/loader.py @@ -4,7 +4,6 @@ import re class DynLoader: - def __init__(self, eth): self.eth = eth self.storage_cache = {} @@ -19,13 +18,17 @@ class DynLoader: self.storage_cache[contract_address] = {} - data = self.eth.eth_getStorageAt(contract_address, position=index, block='latest') + data = self.eth.eth_getStorageAt( + contract_address, position=index, block="latest" + ) self.storage_cache[contract_address][index] = data except IndexError: - data = self.eth.eth_getStorageAt(contract_address, position=index, block='latest') + data = self.eth.eth_getStorageAt( + contract_address, position=index, block="latest" + ) self.storage_cache[contract_address][index] = data @@ -33,9 +36,11 @@ class DynLoader: def dynld(self, contract_address, dependency_address): - logging.info("Dynld at contract " + contract_address + ": " + dependency_address) + logging.info( + "Dynld at contract " + contract_address + ": " + dependency_address + ) - m = re.match(r'^(0x[0-9a-fA-F]{40})$', dependency_address) + m = re.match(r"^(0x[0-9a-fA-F]{40})$", dependency_address) if m: dependency_address = m.group(1) diff --git a/mythril/support/signatures.py b/mythril/support/signatures.py index 7c454089..4bc1880c 100644 --- a/mythril/support/signatures.py +++ b/mythril/support/signatures.py @@ -27,14 +27,15 @@ try: import fcntl def lock_file(f, exclusive=False): - if f.mode == 'r' and exclusive: - raise Exception('Please use non exclusive mode for reading') + if f.mode == "r" and exclusive: + raise Exception("Please use non exclusive mode for reading") flag = fcntl.LOCK_EX if exclusive else fcntl.LOCK_SH fcntl.lockf(f, flag) def unlock_file(f): return + except ImportError: # Windows file locking # TODO: confirm the existence or non existence of shared locks in windows msvcrt and make changes based on that @@ -44,8 +45,8 @@ except ImportError: return os.path.getsize(os.path.realpath(f.name)) def lock_file(f, exclusive=False): - if f.mode == 'r' and exclusive: - raise Exception('Please use non exclusive mode for reading') + if f.mode == "r" and exclusive: + raise Exception("Please use non exclusive mode for reading") msvcrt.locking(f.fileno(), msvcrt.LK_RLCK, file_size(f)) def unlock_file(f): @@ -53,7 +54,6 @@ except ImportError: class SignatureDb(object): - def __init__(self, enable_online_lookup=True): """ Constr @@ -61,9 +61,15 @@ class SignatureDb(object): """ self.signatures = {} # signatures in-mem cache self.signatures_file = None - self.enable_online_lookup = enable_online_lookup # enable online funcsig resolving - self.online_lookup_miss = set() # temporarily track misses from onlinedb to avoid requesting the same non-existent sighash multiple times - self.online_directory_unavailable_until = 0 # flag the online directory as unavailable for some time + self.enable_online_lookup = ( + enable_online_lookup + ) # enable online funcsig resolving + self.online_lookup_miss = ( + set() + ) # temporarily track misses from onlinedb to avoid requesting the same non-existent sighash multiple times + self.online_directory_unavailable_until = ( + 0 + ) # flag the online directory as unavailable for some time def open(self, path=None): """ @@ -75,15 +81,19 @@ class SignatureDb(object): if not path: # try default locations try: - mythril_dir = os.environ['MYTHRIL_DIR'] + mythril_dir = os.environ["MYTHRIL_DIR"] except KeyError: - mythril_dir = os.path.join(os.path.expanduser('~'), ".mythril") - path = os.path.join(mythril_dir, 'signatures.json') + mythril_dir = os.path.join(os.path.expanduser("~"), ".mythril") + path = os.path.join(mythril_dir, "signatures.json") - self.signatures_file = path # store early to allow error handling to access the place we tried to load the file + self.signatures_file = ( + path + ) # store early to allow error handling to access the place we tried to load the file if not os.path.exists(path): logging.debug("Signatures: file not found: %s" % path) - raise FileNotFoundError("Missing function signature file. Resolving of function names disabled.") + raise FileNotFoundError( + "Missing function signature file. Resolving of function names disabled." + ) with open(path, "r") as f: lock_file(f) @@ -122,16 +132,18 @@ class SignatureDb(object): finally: unlock_file(f) - sigs.update(self.signatures) # reload file and merge cached sigs into what we load from file + sigs.update( + self.signatures + ) # reload file and merge cached sigs into what we load from file self.signatures = sigs if directory and not os.path.exists(directory): - os.makedirs(directory) # create folder structure if not existS + os.makedirs(directory) # create folder structure if not existS - if not os.path.exists(path): # creates signatures.json file if it doesn't exist + if not os.path.exists(path): # creates signatures.json file if it doesn't exist open(path, "w").close() - with open(path, "r+") as f: # placing 'w+' here will result in race conditions + with open(path, "r+") as f: # placing 'w+' here will result in race conditions lock_file(f, exclusive=True) try: json.dump(self.signatures, f) @@ -151,11 +163,20 @@ class SignatureDb(object): """ if not sighash.startswith("0x"): sighash = "0x%s" % sighash # normalize sighash format - if self.enable_online_lookup and not self.signatures.get(sighash) and sighash not in self.online_lookup_miss and time.time() > self.online_directory_unavailable_until: + if ( + self.enable_online_lookup + and not self.signatures.get(sighash) + and sighash not in self.online_lookup_miss + and time.time() > self.online_directory_unavailable_until + ): # online lookup enabled, and signature not in cache, sighash was not a miss earlier, and online directory not down - logging.debug("Signatures: performing online lookup for sighash %r" % sighash) + logging.debug( + "Signatures: performing online lookup for sighash %r" % sighash + ) try: - funcsigs = SignatureDb.lookup_online(sighash, timeout=timeout) # might return multiple sigs + funcsigs = SignatureDb.lookup_online( + sighash, timeout=timeout + ) # might return multiple sigs if funcsigs: # only store if we get at least one result self.signatures[sighash] = funcsigs @@ -163,8 +184,13 @@ class SignatureDb(object): # miss self.online_lookup_miss.add(sighash) except FourByteDirectoryOnlineLookupError as fbdole: - self.online_directory_unavailable_until = time.time() + 2 * 60 # wait at least 2 mins to try again - logging.warning("online function signature lookup not available. will not try to lookup hash for the next 2 minutes. exception: %r" % fbdole) + self.online_directory_unavailable_until = ( + time.time() + 2 * 60 + ) # wait at least 2 mins to try again + logging.warning( + "online function signature lookup not available. will not try to lookup hash for the next 2 minutes. exception: %r" + % fbdole + ) if type(self.signatures[sighash]) != list: return [self.signatures[sighash]] return self.signatures[sighash] # raise keyerror @@ -177,13 +203,19 @@ class SignatureDb(object): """ return self.get(sighash=item) - def import_from_solidity_source(self, file_path, solc_binary="solc", solc_args=None): + def import_from_solidity_source( + self, file_path, solc_binary="solc", solc_args=None + ): """ Import Function Signatures from solidity source files :param file_path: solidity source code file path :return: self """ - self.signatures.update(SignatureDb.get_sigs_from_file(file_path, solc_binary=solc_binary, solc_args=solc_args)) + self.signatures.update( + SignatureDb.get_sigs_from_file( + file_path, solc_binary=solc_binary, solc_args=solc_args + ) + ) return self @staticmethod @@ -201,9 +233,11 @@ class SignatureDb(object): """ if not ethereum_input_decoder: return None - return list(ethereum_input_decoder.decoder.FourByteDirectory.lookup_signatures(sighash, - timeout=timeout, - proxies=proxies)) + return list( + ethereum_input_decoder.decoder.FourByteDirectory.lookup_signatures( + sighash, timeout=timeout, proxies=proxies + ) + ) @staticmethod def get_sigs_from_file(file_name, solc_binary="solc", solc_args=None): @@ -221,13 +255,19 @@ class SignatureDb(object): ret = p.returncode if ret != 0: - raise CompilerError("Solc experienced a fatal error (code %d).\n\n%s" % (ret, stderr.decode('UTF-8'))) + raise CompilerError( + "Solc experienced a fatal error (code %d).\n\n%s" + % (ret, stderr.decode("UTF-8")) + ) except FileNotFoundError: raise CompilerError( - "Compiler not found. Make sure that solc is installed and in PATH, or set the SOLC environment variable.") - stdout = stdout.decode('unicode_escape').split('\n') + "Compiler not found. Make sure that solc is installed and in PATH, or set the SOLC environment variable." + ) + stdout = stdout.decode("unicode_escape").split("\n") for line in stdout: - if '(' in line and ')' in line and ":" in line: # the ':' need not be checked but just to be sure - sigs["0x"+line.split(':')[0]] = [line.split(":")[1].strip()] + if ( + "(" in line and ")" in line and ":" in line + ): # the ':' need not be checked but just to be sure + sigs["0x" + line.split(":")[0]] = [line.split(":")[1].strip()] logging.debug("Signatures: found %d signatures after parsing" % len(sigs)) return sigs diff --git a/mythril/support/truffle.py b/mythril/support/truffle.py index 6dd13b2e..a6d0a537 100644 --- a/mythril/support/truffle.py +++ b/mythril/support/truffle.py @@ -25,36 +25,49 @@ def analyze_truffle_project(sigs, args): for filename in files: - if re.match(r'.*\.json$', filename) and filename != "Migrations.json": + if re.match(r".*\.json$", filename) and filename != "Migrations.json": with open(os.path.join(build_dir, filename)) as cf: contractdata = json.load(cf) try: - name = contractdata['contractName'] - bytecode = contractdata['deployedBytecode'] - filename = PurePath(contractdata['sourcePath']).name + name = contractdata["contractName"] + bytecode = contractdata["deployedBytecode"] + filename = PurePath(contractdata["sourcePath"]).name except KeyError: - print("Unable to parse contract data. Please use Truffle 4 to compile your project.") + print( + "Unable to parse contract data. Please use Truffle 4 to compile your project." + ) sys.exit() if len(bytecode) < 4: continue - sigs.import_from_solidity_source(contractdata['sourcePath'], solc_args=args.solc_args) + sigs.import_from_solidity_source( + contractdata["sourcePath"], solc_args=args.solc_args + ) sigs.write() ethcontract = ETHContract(bytecode, name=name) address = util.get_indexed_address(0) - sym = SymExecWrapper(ethcontract, address, args.strategy, max_depth=args.max_depth, - create_timeout=args.create_timeout, execution_timeout=args.execution_timeout) + sym = SymExecWrapper( + ethcontract, + address, + args.strategy, + max_depth=args.max_depth, + create_timeout=args.create_timeout, + execution_timeout=args.execution_timeout, + ) issues = fire_lasers(sym) if not len(issues): - if args.outform == 'text' or args.outform == 'markdown': + if args.outform == "text" or args.outform == "markdown": print("# Analysis result for " + name + "\n\nNo issues found.") else: - result = {'contract': name, 'result': {'success': True, 'error': None, 'issues': []}} + result = { + "contract": name, + "result": {"success": True, "error": None, "issues": []}, + } print(json.dumps(result)) else: @@ -62,9 +75,9 @@ def analyze_truffle_project(sigs, args): # augment with source code disassembly = ethcontract.disassembly - source = contractdata['source'] + source = contractdata["source"] - deployed_source_map = contractdata['deployedSourceMap'].split(";") + deployed_source_map = contractdata["deployedSourceMap"].split(";") mappings = [] @@ -80,34 +93,49 @@ def analyze_truffle_project(sigs, args): if len(mapping) > 2 and len(mapping[2]) > 0: idx = int(mapping[2]) - lineno = source.encode('utf-8')[0:offset].count('\n'.encode('utf-8')) + 1 + lineno = ( + source.encode("utf-8")[0:offset].count("\n".encode("utf-8")) + 1 + ) mappings.append(SourceMapping(idx, offset, length, lineno)) for issue in issues: - index = get_instruction_index(disassembly.instruction_list, issue.address) + index = get_instruction_index( + disassembly.instruction_list, issue.address + ) if index: - try: - offset = mappings[index].offset - length = mappings[index].length - - issue.filename = filename - issue.code = source.encode('utf-8')[offset:offset + length].decode('utf-8') - issue.lineno = mappings[index].lineno - except IndexError: - logging.debug("No code mapping at index %d", index) + try: + offset = mappings[index].offset + length = mappings[index].length + + issue.filename = filename + issue.code = source.encode("utf-8")[ + offset : offset + length + ].decode("utf-8") + issue.lineno = mappings[index].lineno + except IndexError: + logging.debug("No code mapping at index %d", index) report.append_issue(issue) - if args.outform == 'json': + if args.outform == "json": - result = {'contract': name, 'result': {'success': True, 'error': None, 'issues': list(map(lambda x: x.as_dict, issues))}} + result = { + "contract": name, + "result": { + "success": True, + "error": None, + "issues": list(map(lambda x: x.as_dict, issues)), + }, + } print(json.dumps(result)) else: - if args.outform == 'text': - print("# Analysis result for " + name + ":\n\n" + report.as_text()) - elif args.outform == 'markdown': + if args.outform == "text": + print( + "# Analysis result for " + name + ":\n\n" + report.as_text() + ) + elif args.outform == "markdown": print(report.as_markdown()) diff --git a/mythril/version.py b/mythril/version.py index 01e7ce46..72d9c8d0 100644 --- a/mythril/version.py +++ b/mythril/version.py @@ -1,3 +1,3 @@ # This file is suitable for sourcing inside POSIX shell, e.g. bash as # well as for importing into Python -VERSION="v0.18.12" # NOQA +VERSION = "v0.18.12" # NOQA diff --git a/setup.py b/setup.py index 2c92ce4a..fe9ad5f5 100755 --- a/setup.py +++ b/setup.py @@ -22,19 +22,22 @@ VERSION = None # Package version (vX.Y.Z). It must match git tag being used for CircleCI # deployment; otherwise the build will failed. -version_path = (Path(__file__).parent / 'mythril' / 'version.py').absolute() -exec(open(str(version_path), 'r').read()) +version_path = (Path(__file__).parent / "mythril" / "version.py").absolute() +exec(open(str(version_path), "r").read()) class VerifyVersionCommand(install): """Custom command to verify that the git tag matches our version""" - description = 'verify that the git tag matches our version' + + description = "verify that the git tag matches our version" def run(self): - tag = os.getenv('CIRCLE_TAG') + tag = os.getenv("CIRCLE_TAG") if tag != VERSION: - info = "Git tag: {0} does not match the version of this app: {1}".format(tag, VERSION) + info = "Git tag: {0} does not match the version of this app: {1}".format( + tag, VERSION + ) sys.exit(info) @@ -44,91 +47,61 @@ def read_file(fname): :param fname: path relative to setup.py :return: file contents """ - with open(os.path.join(os.path.dirname(__file__), fname), 'r') as fd: + with open(os.path.join(os.path.dirname(__file__), fname), "r") as fd: return fd.read() setup( - name='mythril', - + name="mythril", version=VERSION[1:], - - description='Security analysis tool for Ethereum smart contracts', + description="Security analysis tool for Ethereum smart contracts", long_description=read_file("README.md") if os.path.isfile("README.md") else "", - long_description_content_type='text/markdown', # requires twine and recent setuptools - - url='https://github.com/b-mueller/mythril', - - author='Bernhard Mueller', - author_email='bernhard.mueller11@gmail.com', - - license='MIT', - + long_description_content_type="text/markdown", # requires twine and recent setuptools + url="https://github.com/b-mueller/mythril", + author="Bernhard Mueller", + author_email="bernhard.mueller11@gmail.com", + license="MIT", classifiers=[ - 'Development Status :: 3 - Alpha', - - 'Intended Audience :: Science/Research', - 'Topic :: Software Development :: Disassemblers', - - 'License :: OSI Approved :: MIT License', - - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', + "Development Status :: 3 - Alpha", + "Intended Audience :: Science/Research", + "Topic :: Software Development :: Disassemblers", + "License :: OSI Approved :: MIT License", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", ], - - keywords='hacking disassembler security ethereum', - - packages=find_packages(exclude=['contrib', 'docs', 'tests']), - + keywords="hacking disassembler security ethereum", + packages=find_packages(exclude=["contrib", "docs", "tests"]), install_requires=[ - 'coloredlogs>=10.0', - 'ethereum>=2.3.2', - 'z3-solver>=4.5', - 'requests', - 'py-solc', - 'plyvel', - 'eth_abi>=1.0.0', - 'eth-utils>=1.0.1', - 'eth-account>=0.1.0a2', - 'eth-hash>=0.1.0', - 'eth-keyfile>=0.5.1', - 'eth-keys>=0.2.0b3', - 'eth-rlp>=0.1.0', - 'eth-tester>=0.1.0b21', - 'eth-typing>=1.3.0,<2.0.0', - 'coverage', - 'jinja2>=2.9', - 'rlp>=1.0.1', - 'transaction>=2.2.1', - 'py-flags', - 'mock', - 'configparser>=3.5.0', - 'persistent>=4.2.0' + "coloredlogs>=10.0", + "ethereum>=2.3.2", + "z3-solver>=4.5", + "requests", + "py-solc", + "plyvel", + "eth_abi>=1.0.0", + "eth-utils>=1.0.1", + "eth-account>=0.1.0a2", + "eth-hash>=0.1.0", + "eth-keyfile>=0.5.1", + "eth-keys>=0.2.0b3", + "eth-rlp>=0.1.0", + "eth-tester>=0.1.0b21", + "eth-typing>=1.3.0,<2.0.0", + "coverage", + "jinja2>=2.9", + "rlp>=1.0.1", + "transaction>=2.2.1", + "py-flags", + "mock", + "configparser>=3.5.0", + "persistent>=4.2.0", ], - - tests_require=[ - 'pytest>=3.6.0', - 'pytest_mock', - 'pytest-cov' - ], - - python_requires='>=3.5', - - extras_require={ - }, - - package_data={ - 'mythril.analysis.templates': ['*'] - }, - + tests_require=["pytest>=3.6.0", "pytest_mock", "pytest-cov"], + python_requires=">=3.5", + extras_require={}, + package_data={"mythril.analysis.templates": ["*"]}, include_package_data=True, - - entry_points={ - 'console_scripts': ["myth=mythril.interfaces.cli:main"], - }, - - cmdclass={ - 'verify': VerifyVersionCommand, - } + entry_points={"console_scripts": ["myth=mythril.interfaces.cli:main"]}, + cmdclass={"verify": VerifyVersionCommand}, ) diff --git a/tests/__init__.py b/tests/__init__.py index 8aab9cad..c8418bbe 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -16,15 +16,17 @@ MYTHRIL_DIR = TESTS_DIR / "mythril_dir" class BaseTestCase(TestCase): - def setUp(self): self.changed_files = [] - self.ori_mythril_dir = getattr(os.environ, 'MYTHRIL_DIR', '') - os.environ['MYTHRIL_DIR'] = str(MYTHRIL_DIR) - shutil.copyfile(str(MYTHRIL_DIR / "signatures.json.example"), str(MYTHRIL_DIR / "signatures.json")) + self.ori_mythril_dir = getattr(os.environ, "MYTHRIL_DIR", "") + os.environ["MYTHRIL_DIR"] = str(MYTHRIL_DIR) + shutil.copyfile( + str(MYTHRIL_DIR / "signatures.json.example"), + str(MYTHRIL_DIR / "signatures.json"), + ) def tearDown(self): - os.environ['MYTHRIL_DIR'] = self.ori_mythril_dir + os.environ["MYTHRIL_DIR"] = self.ori_mythril_dir os.remove(str(MYTHRIL_DIR / "signatures.json")) def compare_files_error_message(self): @@ -41,4 +43,6 @@ class BaseTestCase(TestCase): self.changed_files.append((input_file, output_expected, output_current)) def assert_and_show_changed_files(self): - self.assertEqual(0, len(self.changed_files), msg=self.compare_files_error_message()) + self.assertEqual( + 0, len(self.changed_files), msg=self.compare_files_error_message() + ) diff --git a/tests/analysis/test_delegatecall.py b/tests/analysis/test_delegatecall.py index 3fb02020..a363f0da 100644 --- a/tests/analysis/test_delegatecall.py +++ b/tests/analysis/test_delegatecall.py @@ -1,4 +1,8 @@ -from mythril.analysis.modules.delegatecall import execute, _concrete_call, _symbolic_call +from mythril.analysis.modules.delegatecall import ( + execute, + _concrete_call, + _symbolic_call, +) from mythril.analysis.ops import Call, Variable, VarType from mythril.analysis.symbolic import SymExecWrapper from mythril.laser.ethereum.cfg import Node @@ -32,11 +36,14 @@ def test_concrete_call(): assert issue.contract == node.contract_name assert issue.function == node.function_name assert issue.title == "Call data forwarded with delegatecall()" - assert issue.type == 'Informational' - assert issue.description == "This contract forwards its call data via DELEGATECALL in its fallback function." \ - " This means that any function in the called contract can be executed." \ - " Note that the callee contract will have access to the storage of the " \ - "calling contract.\n DELEGATECALL target: 0x1" + assert issue.type == "Informational" + assert ( + issue.description + == "This contract forwards its call data via DELEGATECALL in its fallback function." + " This means that any function in the called contract can be executed." + " Note that the callee contract will have access to the storage of the " + "calling contract.\n DELEGATECALL target: 0x1" + ) def test_concrete_call_symbolic_to(): @@ -63,11 +70,14 @@ def test_concrete_call_symbolic_to(): assert issue.contract == node.contract_name assert issue.function == node.function_name assert issue.title == "Call data forwarded with delegatecall()" - assert issue.type == 'Informational' - assert issue.description == "This contract forwards its call data via DELEGATECALL in its fallback function." \ - " This means that any function in the called contract can be executed." \ - " Note that the callee contract will have access to the storage of the " \ - "calling contract.\n DELEGATECALL target: calldata_3" + assert issue.type == "Informational" + assert ( + issue.description + == "This contract forwards its call data via DELEGATECALL in its fallback function." + " This means that any function in the called contract can be executed." + " Note that the callee contract will have access to the storage of the " + "calling contract.\n DELEGATECALL target: calldata_3" + ) def test_concrete_call_not_calldata(): @@ -92,7 +102,6 @@ def test_symbolic_call_storage_to(mocker): state = GlobalState(None, environment, None) state.mstate.memory = ["placeholder", "calldata_bling_0"] - node = Node("example") node.contract_name = "the contract name" node.function_name = "the function name" @@ -100,14 +109,12 @@ def test_symbolic_call_storage_to(mocker): to = Variable("storage_1", VarType.SYMBOLIC) call = Call(node, state, None, "Type: ", to, None) - mocker.patch.object(SymExecWrapper, "__init__", lambda x, y: None) statespace = SymExecWrapper(1) - mocker.patch.object(statespace, 'find_storage_write') + mocker.patch.object(statespace, "find_storage_write") statespace.find_storage_write.return_value = "Function name" - # act issues = _symbolic_call(call, state, address, statespace) @@ -116,11 +123,14 @@ def test_symbolic_call_storage_to(mocker): assert issue.address == address assert issue.contract == node.contract_name assert issue.function == node.function_name - assert issue.title == 'Type: to a user-supplied address' - assert issue.type == 'Informational' - assert issue.description == 'This contract delegates execution to a contract address in storage slot 1.' \ - ' This storage slot can be written to by calling the function `Function name`. ' \ - 'Be aware that the called contract gets unrestricted access to this contract\'s state.' + assert issue.title == "Type: to a user-supplied address" + assert issue.type == "Informational" + assert ( + issue.description + == "This contract delegates execution to a contract address in storage slot 1." + " This storage slot can be written to by calling the function `Function name`. " + "Be aware that the called contract gets unrestricted access to this contract's state." + ) def test_symbolic_call_calldata_to(mocker): @@ -130,7 +140,6 @@ def test_symbolic_call_calldata_to(mocker): state = GlobalState(None, None, None) state.mstate.memory = ["placeholder", "calldata_bling_0"] - node = Node("example") node.contract_name = "the contract name" node.function_name = "the function name" @@ -138,14 +147,12 @@ def test_symbolic_call_calldata_to(mocker): to = Variable("calldata", VarType.SYMBOLIC) call = Call(node, state, None, "Type: ", to, None) - mocker.patch.object(SymExecWrapper, "__init__", lambda x, y: None) statespace = SymExecWrapper(1) - mocker.patch.object(statespace, 'find_storage_write') + mocker.patch.object(statespace, "find_storage_write") statespace.find_storage_write.return_value = "Function name" - # act issues = _symbolic_call(call, state, address, statespace) @@ -154,29 +161,32 @@ def test_symbolic_call_calldata_to(mocker): assert issue.address == address assert issue.contract == node.contract_name assert issue.function == node.function_name - assert issue.title == 'Type: to a user-supplied address' - assert issue.type == 'Informational' - assert issue.description == 'This contract delegates execution to a contract address obtained from calldata. ' \ - 'Be aware that the called contract gets unrestricted access to this contract\'s state.' - - -@patch('mythril.laser.ethereum.state.GlobalState.get_current_instruction') -@patch('mythril.analysis.modules.delegatecall._concrete_call') -@patch('mythril.analysis.modules.delegatecall._symbolic_call') + assert issue.title == "Type: to a user-supplied address" + assert issue.type == "Informational" + assert ( + issue.description + == "This contract delegates execution to a contract address obtained from calldata. " + "Be aware that the called contract gets unrestricted access to this contract's state." + ) + + +@patch("mythril.laser.ethereum.state.GlobalState.get_current_instruction") +@patch("mythril.analysis.modules.delegatecall._concrete_call") +@patch("mythril.analysis.modules.delegatecall._symbolic_call") def test_delegate_call(sym_mock, concrete_mock, curr_instruction): # arrange # sym_mock = mocker.patch.object(delegatecall, "_symbolic_call") # concrete_mock = mocker.patch.object(delegatecall, "_concrete_call") sym_mock.return_value = [] concrete_mock.return_value = [] - curr_instruction.return_value = {'address': '0x10'} + curr_instruction.return_value = {"address": "0x10"} - active_account = Account('0x10') + active_account = Account("0x10") environment = Environment(active_account, None, None, None, None, None) state = GlobalState(None, environment, Node) state.mstate.memory = ["placeholder", "calldata_bling_0"] state.mstate.stack = [1, 2, 3] - assert state.get_current_instruction() == {'address': '0x10'} + assert state.get_current_instruction() == {"address": "0x10"} node = Node("example") node.contract_name = "the contract name" @@ -196,8 +206,8 @@ def test_delegate_call(sym_mock, concrete_mock, curr_instruction): assert sym_mock.call_count == 1 -@patch('mythril.analysis.modules.delegatecall._concrete_call') -@patch('mythril.analysis.modules.delegatecall._symbolic_call') +@patch("mythril.analysis.modules.delegatecall._concrete_call") +@patch("mythril.analysis.modules.delegatecall._symbolic_call") def test_delegate_call_not_delegate(sym_mock, concrete_mock): # arrange # sym_mock = mocker.patch.object(delegatecall, "_symbolic_call") @@ -223,8 +233,8 @@ def test_delegate_call_not_delegate(sym_mock, concrete_mock): assert sym_mock.call_count == 0 -@patch('mythril.analysis.modules.delegatecall._concrete_call') -@patch('mythril.analysis.modules.delegatecall._symbolic_call') +@patch("mythril.analysis.modules.delegatecall._concrete_call") +@patch("mythril.analysis.modules.delegatecall._symbolic_call") def test_delegate_call_not_fallback(sym_mock, concrete_mock): # arrange # sym_mock = mocker.patch.object(delegatecall, "_symbolic_call") diff --git a/tests/cmd_line_test.py b/tests/cmd_line_test.py index 9c5033aa..91c8714d 100644 --- a/tests/cmd_line_test.py +++ b/tests/cmd_line_test.py @@ -3,50 +3,64 @@ from tests import * MYTH = str(PROJECT_DIR / "myth") + def output_of(command): return check_output(command, shell=True).decode("UTF-8") -class CommandLineToolTestCase(BaseTestCase): +class CommandLineToolTestCase(BaseTestCase): def test_disassemble_code_correctly(self): command = "python3 {} MYTH -d -c 0x5050".format(MYTH) - self.assertEqual('0 POP\n1 POP\n', output_of(command)) + self.assertEqual("0 POP\n1 POP\n", output_of(command)) def test_disassemble_solidity_file_correctly(self): - solidity_file = str(TESTDATA / 'input_contracts'/ 'metacoin.sol') + solidity_file = str(TESTDATA / "input_contracts" / "metacoin.sol") command = "python3 {} -d {}".format(MYTH, solidity_file) - self.assertIn('2 PUSH1 0x40\n4 MSTORE', output_of(command)) + self.assertIn("2 PUSH1 0x40\n4 MSTORE", output_of(command)) def test_hash_a_function_correctly(self): command = "python3 {} --hash 'setOwner(address)'".format(MYTH) - self.assertEqual('0x13af4035\n', output_of(command)) + self.assertEqual("0x13af4035\n", output_of(command)) -class TruffleTestCase(BaseTestCase): +class TruffleTestCase(BaseTestCase): def test_analysis_truffle_project(self): truffle_project_root = str(TESTS_DIR / "truffle_project") - command = "cd {}; truffle compile; python3 {} --truffle".format(truffle_project_root, MYTH) - self.assertIn("In the function `withdrawfunds()` a non-zero amount of Ether is sent to msg.sender.", output_of(command)) + command = "cd {}; truffle compile; python3 {} --truffle".format( + truffle_project_root, MYTH + ) + self.assertIn( + "In the function `withdrawfunds()` a non-zero amount of Ether is sent to msg.sender.", + output_of(command), + ) -class InfuraTestCase(BaseTestCase): +class InfuraTestCase(BaseTestCase): def test_infura_mainnet(self): - command = "python3 {} --rpc infura-mainnet -d -a 0x2a0c0dbecc7e4d658f48e01e3fa353f44050c208".format(MYTH) + command = "python3 {} --rpc infura-mainnet -d -a 0x2a0c0dbecc7e4d658f48e01e3fa353f44050c208".format( + MYTH + ) output = output_of(command) self.assertIn("0 PUSH1 0x60\n2 PUSH1 0x40\n4 MSTORE", output) self.assertIn("7278 POP\n7279 POP\n7280 JUMP\n7281 STOP", output) def test_infura_rinkeby(self): - command = "python3 {} --rpc infura-rinkeby -d -a 0xB6f2bFED892a662bBF26258ceDD443f50Fa307F5".format(MYTH) + command = "python3 {} --rpc infura-rinkeby -d -a 0xB6f2bFED892a662bBF26258ceDD443f50Fa307F5".format( + MYTH + ) output = output_of(command) self.assertIn("34 JUMPDEST\n35 CALLVALUE", output) def test_infura_kovan(self): - command = "python3 {} --rpc infura-kovan -d -a 0xE6bBF9B5A3451242F82f8cd458675092617a1235".format(MYTH) + command = "python3 {} --rpc infura-kovan -d -a 0xE6bBF9B5A3451242F82f8cd458675092617a1235".format( + MYTH + ) output = output_of(command) self.assertIn("9999 PUSH1 0x00\n10001 NOT\n10002 AND\n10003 PUSH1 0x00", output) def test_infura_ropsten(self): - command = "python3 {} --rpc infura-ropsten -d -a 0x6e0E0e02377Bc1d90E8a7c21f12BA385C2C35f78".format(MYTH) + command = "python3 {} --rpc infura-ropsten -d -a 0x6e0E0e02377Bc1d90E8a7c21f12BA385C2C35f78".format( + MYTH + ) output = output_of(command) self.assertIn("1821 PUSH1 0x01\n1823 PUSH2 0x070c", output) diff --git a/tests/disassembler_test.py b/tests/disassembler_test.py index 30020b2e..b663b074 100644 --- a/tests/disassembler_test.py +++ b/tests/disassembler_test.py @@ -2,13 +2,14 @@ from mythril.disassembler.disassembly import Disassembly from mythril.ether import util from tests import * + def _compile_to_code(input_file): compiled = util.get_solc_json(str(input_file)) - code = list(compiled['contracts'].values())[0]['bin-runtime'] + code = list(compiled["contracts"].values())[0]["bin-runtime"] return code -class DisassemblerTestCase(BaseTestCase): +class DisassemblerTestCase(BaseTestCase): def test_instruction_list(self): code = "0x606060405236156100ca5763ffffffff60e060020a600035041663054f7d9c81146100d3578063095c21e3146100f45780630ba50baa146101165780631a3719321461012857806366529e3f14610153578063682789a81461017257806389f21efc146101915780638da5cb5b146101ac5780638f4ffcb1146101d55780639a1f2a5714610240578063b5f522f71461025b578063bd94b005146102b6578063c5ab5a13146102c8578063cc424839146102f1578063deb077b914610303578063f3fef3a314610322575b6100d15b5b565b005b34610000576100e0610340565b604080519115158252519081900360200190f35b3461000057610104600435610361565b60408051918252519081900360200190f35b34610000576100d1600435610382565b005b3461000057610104600160a060020a03600435166103b0565b60408051918252519081900360200190f35b346100005761010461041e565b60408051918252519081900360200190f35b3461000057610104610424565b60408051918252519081900360200190f35b34610000576100d1600160a060020a036004351661042b565b005b34610000576101b961046f565b60408051600160a060020a039092168252519081900360200190f35b3461000057604080516020600460643581810135601f81018490048402850184019095528484526100d1948235600160a060020a039081169560248035966044359093169594608494929391019190819084018382808284375094965061048595505050505050565b005b34610000576100d1600160a060020a03600435166106e7565b005b346100005761026b60043561072b565b60408051600160a060020a0390991689526020890197909752878701959095526060870193909352608086019190915260a085015260c084015260e083015251908190036101000190f35b34610000576100d160043561077a565b005b34610000576101b9610830565b60408051600160a060020a039092168252519081900360200190f35b34610000576100d160043561083f565b005b34610000576101046108a1565b60408051918252519081900360200190f35b34610000576100d1600160a060020a03600435166024356108a7565b005b60015474010000000000000000000000000000000000000000900460ff1681565b600681815481101561000057906000526020600020900160005b5054905081565b600054600160a060020a036301000000909104811690331681146103a557610000565b60038290555b5b5050565b6005546040805160006020918201819052825160e260020a631d010437028152600160a060020a03868116600483015293519194939093169263740410dc92602480830193919282900301818787803b156100005760325a03f115610000575050604051519150505b919050565b60035481565b6006545b90565b600054600160a060020a0363010000009091048116903316811461044e57610000565b60018054600160a060020a031916600160a060020a0384161790555b5b5050565b60005463010000009004600160a060020a031681565b6000600060006000600060006000600160149054906101000a900460ff16156104ad57610000565b87600081518110156100005760209101015160005460f860020a918290048202975002600160f860020a031990811690871614156105405760009450600196505b600587101561053057878781518110156100005790602001015160f860020a900460f860020a0260f860020a900485610100020194505b6001909601956104ee565b61053b8b868c610955565b6106d6565b600054610100900460f860020a02600160f860020a0319908116908716141561069e57506001955060009250829150819050805b60058710156105b657878781518110156100005790602001015160f860020a900460f860020a0260f860020a900481610100020190505b600190960195610574565b600596505b60098710156105fd57878781518110156100005790602001015160f860020a900460f860020a0260f860020a900484610100020193505b6001909601956105bb565b600996505b600d87101561064457878781518110156100005790602001015160f860020a900460f860020a0260f860020a900483610100020192505b600190960195610602565b600d96505b601187101561068b57878781518110156100005790602001015160f860020a900460f860020a0260f860020a900482610100020191505b600190960195610649565b61053b8b828c878787610bc4565b6106d6565b60005462010000900460f860020a02600160f860020a031990811690871614156106d15761053b8b8b610e8e565b6106d6565b610000565b5b5b5b5b5050505050505050505050565b600054600160a060020a0363010000009091048116903316811461070a57610000565b60058054600160a060020a031916600160a060020a0384161790555b5b5050565b600760208190526000918252604090912080546001820154600283015460038401546004850154600586015460068701549690970154600160a060020a03909516969395929491939092909188565b600081815260076020526040812054600160a060020a0390811690331681146107a257610000565b600083815260076020526040902080546004820154600583015460038401549395506107dc93600160a060020a039093169291029061105e565b50600060058301556107ed83611151565b6040805184815290517fb5dc9baf0cb4e7e4759fa12eadebddf9316e26147d5a9ae150c4228d5a1dd23f9181900360200190a161082933611244565b5b5b505050565b600154600160a060020a031681565b600054600160a060020a0363010000009091048116903316811461086257610000565b600080546040516301000000909104600160a060020a0316916108fc851502918591818181858888f1935050505015156103ab57610000565b5b5b5050565b60025481565b600054600160a060020a036301000000909104811690331681146108ca57610000565b6000805460408051602090810184905281517fa9059cbb0000000000000000000000000000000000000000000000000000000081526301000000909304600160a060020a0390811660048501526024840187905291519187169363a9059cbb9360448082019492918390030190829087803b156100005760325a03f115610000575050505b5b505050565b610100604051908101604052806000600160a060020a03168152602001600081526020016000815260200160008152602001600081526020016000815260200160008152602001600081525060006007600085815260200190815260200160002061010060405190810160405290816000820160009054906101000a9004600160a060020a0316600160a060020a0316600160a060020a0316815260200160018201548152602001600282015481526020016003820154815260200160048201548152602001600582015481526020016006820154815260200160078201548152505091508260001415610a4857610000565b8160400151838115610000570615610a5f57610000565b6002548410610a6d57610000565b8160400151838115610000570490508160a00151811115610a8d57610000565b610a9c8584846020015161128e565b1515610aa757610000565b60a082018051829003815260008581526007602081815260409283902086518154600160a060020a031916600160a060020a038216178255918701516001820181905593870151600282015560608701516003820155608087015160048201559351600585015560c0860151600685015560e08601519390910192909255610b319190859061105e565b1515610b3c57610000565b610b518582846080015102846060015161105e565b1515610b5c57610000565b60a0820151158015610b71575060c082015115155b15610b7f57610b7f84611151565b5b6040805185815290517fb5dc9baf0cb4e7e4759fa12eadebddf9316e26147d5a9ae150c4228d5a1dd23f9181900360200190a1610bbc85611244565b5b5050505050565b831515610bd057610000565b82851415610bdd57610000565b801580610be8575081155b15610bf257610000565b80848115610000570615610c0557610000565b6005546040805160006020918201819052825160e260020a631d010437028152600160a060020a038b8116600483015293518695949094169363740410dc9360248084019491938390030190829087803b156100005760325a03f11561000057505050604051805190501015610c7a57610000565b610c8586858761128e565b1515610c9057610000565b600554604080517fbe0140a6000000000000000000000000000000000000000000000000000000008152600160a060020a03898116600483015260006024830181905260448301869052925193169263be0140a69260648084019391929182900301818387803b156100005760325a03f115610000575050506101006040519081016040528087600160a060020a03168152602001848152602001838152602001868152602001828681156100005704815260200182815260200160068054905081526020014281525060076000600254815260200190815260200160002060008201518160000160006101000a815481600160a060020a030219169083600160a060020a031602179055506020820151816001015560408201518160020155606082015181600301556080820151816004015560a0820151816005015560c0820151816006015560e0820151816007015590505060068054806001018281815481835581811511610e2757600083815260209020610e279181019083015b80821115610e235760008155600101610e0f565b5090565b5b505050916000526020600020900160005b50600280549182905560018201905560408051918252517fb5dc9baf0cb4e7e4759fa12eadebddf9316e26147d5a9ae150c4228d5a1dd23f92509081900360200190a1610e8586611244565b5b505050505050565b600354818115610000570615610ea357610000565b600160009054906101000a9004600160a060020a0316600160a060020a031663cf35bdd060016000604051602001526040518263ffffffff1660e060020a02815260040180828152602001915050602060405180830381600087803b156100005760325a03f115610000575050604080518051600080546020938401829052845160e060020a6323b872dd028152600160a060020a038981166004830152630100000090920482166024820152604481018890529451921694506323b872dd936064808201949392918390030190829087803b156100005760325a03f1156100005750506040515115159050610f9857610000565b600554600354600160a060020a039091169063be0140a6908490600190858115610000576040805160e060020a63ffffffff8816028152600160a060020a039095166004860152921515602485015204604483015251606480830192600092919082900301818387803b156100005760325a03f1156100005750505061101d82611244565b60408051600160a060020a038416815290517f30a29a0aa75376a69254bb98dbd11db423b7e8c3473fb5bf0fcba60bcbc42c4b9181900360200190a15b5050565b600081151561106c57610000565b6001546040805160006020918201819052825160e460020a630cf35bdd028152600481018790529251600160a060020a039094169363cf35bdd09360248082019493918390030190829087803b156100005760325a03f1156100005750505060405180519050600160a060020a031663a9059cbb85856000604051602001526040518363ffffffff1660e060020a0281526004018083600160a060020a0316600160a060020a0316815260200182815260200192505050602060405180830381600087803b156100005760325a03f115610000575050604051519150505b9392505050565b6000818152600760205260409020600690810154815490919060001981019081101561000057906000526020600020900160005b5054600682815481101561000057906000526020600020900160005b50556006805460001981018083559091908280158290116111e7576000838152602090206111e79181019083015b80821115610e235760008155600101610e0f565b5090565b5b50506006548314915061122d9050578060076000600684815481101561000057906000526020600020900160005b505481526020810191909152604001600020600601555b6000828152600760205260408120600601555b5050565b60045481600160a060020a031631101561128957600454604051600160a060020a0383169180156108fc02916000818181858888f19350505050151561128957610000565b5b5b50565b600081151561129c57610000565b6001546040805160006020918201819052825160e460020a630cf35bdd028152600481018790529251600160a060020a039094169363cf35bdd09360248082019493918390030190829087803b156100005760325a03f11561000057505060408051805160006020928301819052835160e060020a6323b872dd028152600160a060020a038a811660048301523081166024830152604482018a905294519490921694506323b872dd93606480840194939192918390030190829087803b156100005760325a03f115610000575050604051519150505b93925050505600a165627a7a723058204dee0e1bf170a9d122508f3e876c4a84893b12a7345591521af4ca737bb765000029" disassembly = Disassembly(code) diff --git a/tests/ethcontract_test.py b/tests/ethcontract_test.py index e3ad31fc..a2cdc5a6 100644 --- a/tests/ethcontract_test.py +++ b/tests/ethcontract_test.py @@ -3,33 +3,45 @@ from mythril.ether.ethcontract import ETHContract class ETHContractTestCase(unittest.TestCase): - def setUp(self): self.code = "0x60606040525b603c5b60006010603e565b9050593681016040523660008237602060003683856040603f5a0204f41560545760206000f35bfe5b50565b005b73c3b2ae46792547a96b9f84405e36d0e07edcd05c5b905600a165627a7a7230582062a884f947232ada573f95940cce9c8bfb7e4e14e21df5af4e884941afb55e590029" self.creation_code = "0x60606040525b603c5b60006010603e565b9050593681016040523660008237602060003683856040603f5a0204f41560545760206000f35bfe5b50565b005b73c3b2ae46792547a96b9f84405e36d0e07edcd05c5b905600a165627a7a7230582062a884f947232ada573f95940cce9c8bfb7e4e14e21df5af4e884941afb55e590029" -class Getinstruction_listTestCase(ETHContractTestCase): +class Getinstruction_listTestCase(ETHContractTestCase): def runTest(self): contract = ETHContract(self.code, self.creation_code) disassembly = contract.disassembly - self.assertEqual(len(disassembly.instruction_list), 53, 'Error disassembling code using ETHContract.get_instruction_list()') + self.assertEqual( + len(disassembly.instruction_list), + 53, + "Error disassembling code using ETHContract.get_instruction_list()", + ) -class GetEASMTestCase(ETHContractTestCase): +class GetEASMTestCase(ETHContractTestCase): def runTest(self): contract = ETHContract(self.code) instruction_list = contract.get_easm() - self.assertTrue("PUSH1 0x60" in instruction_list, 'Error obtaining EASM code through ETHContract.get_easm()') + self.assertTrue( + "PUSH1 0x60" in instruction_list, + "Error obtaining EASM code through ETHContract.get_easm()", + ) -class MatchesExpressionTestCase(ETHContractTestCase): +class MatchesExpressionTestCase(ETHContractTestCase): def runTest(self): contract = ETHContract(self.code) - self.assertTrue(contract.matches_expression("code#PUSH1# or code#PUSH1#"), 'Unexpected result in expression matching') - self.assertFalse(contract.matches_expression("func#abcdef#"), 'Unexpected result in expression matching') + self.assertTrue( + contract.matches_expression("code#PUSH1# or code#PUSH1#"), + "Unexpected result in expression matching", + ) + self.assertFalse( + contract.matches_expression("func#abcdef#"), + "Unexpected result in expression matching", + ) diff --git a/tests/graph_test.py b/tests/graph_test.py index ba6eab53..fa280bf1 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -5,23 +5,31 @@ from mythril.ether.soliditycontract import ETHContract from tests import * import re -class GraphTest(BaseTestCase): +class GraphTest(BaseTestCase): def test_generate_graph(self): for input_file in TESTDATA_INPUTS.iterdir(): - output_expected = TESTDATA_OUTPUTS_EXPECTED / (input_file.name + ".graph.html") - output_current = TESTDATA_OUTPUTS_CURRENT / (input_file.name + ".graph.html") + output_expected = TESTDATA_OUTPUTS_EXPECTED / ( + input_file.name + ".graph.html" + ) + output_current = TESTDATA_OUTPUTS_CURRENT / ( + input_file.name + ".graph.html" + ) contract = ETHContract(input_file.read_text()) - sym = SymExecWrapper(contract, address=(util.get_indexed_address(0)), strategy="dfs") + sym = SymExecWrapper( + contract, address=(util.get_indexed_address(0)), strategy="dfs" + ) html = generate_graph(sym) output_current.write_text(html) - lines_expected = re.findall(r"'label': '.*'", str(output_current.read_text())) - lines_found = re.findall(r"'label': '.*'", str(output_current.read_text())) + lines_expected = re.findall( + r"'label': '.*'", str(output_current.read_text()) + ) + lines_found = re.findall(r"'label': '.*'", str(output_current.read_text())) if not (lines_expected == lines_found): self.found_changed_files(input_file, output_expected, output_current) diff --git a/tests/instructions/codecopy_test.py b/tests/instructions/codecopy_test.py index 53ae049d..8c43f2f5 100644 --- a/tests/instructions/codecopy_test.py +++ b/tests/instructions/codecopy_test.py @@ -5,7 +5,7 @@ from mythril.laser.ethereum.instructions import Instruction def test_codecopy_concrete(): # Arrange - active_account = Account("0x0", code= Disassembly("60606040")) + active_account = Account("0x0", code=Disassembly("60606040")) environment = Environment(active_account, None, None, None, None, None) og_state = GlobalState(None, environment, None, MachineState(gas=10000000)) diff --git a/tests/laser/evm_testsuite/evm_test.py b/tests/laser/evm_testsuite/evm_test.py index 82c2c973..7e96a8c5 100644 --- a/tests/laser/evm_testsuite/evm_test.py +++ b/tests/laser/evm_testsuite/evm_test.py @@ -10,9 +10,14 @@ import json from pathlib import Path import pytest -evm_test_dir = Path(__file__).parent / 'VMTests' +evm_test_dir = Path(__file__).parent / "VMTests" -test_types = ['vmArithmeticTest', 'vmBitwiseLogicOperation', 'vmPushDupSwapTest', 'vmTests'] +test_types = [ + "vmArithmeticTest", + "vmBitwiseLogicOperation", + "vmPushDupSwapTest", + "vmTests", +] def load_test_data(designations): @@ -24,27 +29,33 @@ def load_test_data(designations): top_level = json.load(file) for test_name, data in top_level.items(): - pre_condition = data['pre'] + pre_condition = data["pre"] - action = data['exec'] + action = data["exec"] - post_condition = data.get('post', {}) + post_condition = data.get("post", {}) - return_data.append((test_name, pre_condition, action, post_condition)) + return_data.append( + (test_name, pre_condition, action, post_condition) + ) return return_data -@pytest.mark.parametrize("test_name, pre_condition, action, post_condition", load_test_data(test_types)) -def test_vmtest(test_name: str, pre_condition: dict, action: dict, post_condition: dict) -> None: +@pytest.mark.parametrize( + "test_name, pre_condition, action, post_condition", load_test_data(test_types) +) +def test_vmtest( + test_name: str, pre_condition: dict, action: dict, post_condition: dict +) -> None: # Arrange accounts = {} for address, details in pre_condition.items(): account = Account(address) - account.code = Disassembly(details['code'][2:]) - account.balance = int(details['balance'], 16) - account.nonce = int(details['nonce'], 16) + account.code = Disassembly(details["code"][2:]) + account.balance = int(details["balance"], 16) + account.nonce = int(details["nonce"], 16) accounts[address] = account @@ -54,14 +65,14 @@ def test_vmtest(test_name: str, pre_condition: dict, action: dict, post_conditio laser_evm.time = datetime.now() execute_message_call( laser_evm, - callee_address=action['address'], - caller_address=action['caller'], - origin_address=action['origin'], - code=action['code'][2:], - gas=action['gas'], - data=binascii.a2b_hex(action['data'][2:]), - gas_price=int(action['gasPrice'], 16), - value=int(action['value'], 16) + callee_address=action["address"], + caller_address=action["caller"], + origin_address=action["origin"], + code=action["code"][2:], + gas=action["gas"], + data=binascii.a2b_hex(action["data"][2:]), + gas_price=int(action["gasPrice"], 16), + value=int(action["value"], 16), ) # Assert @@ -76,10 +87,10 @@ def test_vmtest(test_name: str, pre_condition: dict, action: dict, post_conditio for address, details in post_condition.items(): account = world_state[address] - assert account.nonce == int(details['nonce'], 16) - assert account.code.bytecode == details['code'][2:] + assert account.nonce == int(details["nonce"], 16) + assert account.code.bytecode == details["code"][2:] - for index, value in details['storage'].items(): + for index, value in details["storage"].items(): expected = int(value, 16) actual = get_concrete_int(account.storage[int(index, 16)]) assert actual == expected diff --git a/tests/laser/state/mstack_test.py b/tests/laser/state/mstack_test.py index 7ccd51c1..9724fd38 100644 --- a/tests/laser/state/mstack_test.py +++ b/tests/laser/state/mstack_test.py @@ -6,11 +6,10 @@ from tests import BaseTestCase class MachineStackTest(BaseTestCase): - @staticmethod def test_mstack_constructor(): mstack = MachineStack([1, 2]) - assert(mstack == [1, 2]) + assert mstack == [1, 2] @staticmethod def test_mstack_append_single_element(): @@ -18,7 +17,7 @@ class MachineStackTest(BaseTestCase): mstack.append(0) - assert(mstack == [0]) + assert mstack == [0] @staticmethod def test_mstack_append_multiple_elements(): @@ -53,4 +52,3 @@ class MachineStackTest(BaseTestCase): with pytest.raises(NotImplementedError): mstack += mstack - diff --git a/tests/laser/state/mstate_test.py b/tests/laser/state/mstate_test.py index 04101188..bdf0280f 100644 --- a/tests/laser/state/mstate_test.py +++ b/tests/laser/state/mstate_test.py @@ -2,14 +2,12 @@ import pytest from mythril.laser.ethereum.state import MachineState from mythril.laser.ethereum.evm_exceptions import StackUnderflowException -memory_extension_test_data = [ - (0, 0, 10), - (0, 30, 10), - (100, 22, 8) -] +memory_extension_test_data = [(0, 0, 10), (0, 30, 10), (100, 22, 8)] -@pytest.mark.parametrize("initial_size,start,extension_size", memory_extension_test_data) +@pytest.mark.parametrize( + "initial_size,start,extension_size", memory_extension_test_data +) def test_memory_extension(initial_size, start, extension_size): # Arrange machine_state = MachineState(0) @@ -23,12 +21,7 @@ def test_memory_extension(initial_size, start, extension_size): assert machine_state.memory_size == max(initial_size, start + extension_size) -stack_pop_too_many_test_data = [ - (0, 1), - (0, 2), - (5, 1), - (5, 10) -] +stack_pop_too_many_test_data = [(0, 1), (0, 2), (5, 1), (5, 10)] @pytest.mark.parametrize("initial_size,overflow", stack_pop_too_many_test_data) @@ -44,7 +37,7 @@ def test_stack_pop_too_many(initial_size, overflow): stack_pop_test_data = [ ([1, 2, 3], 2, [3, 2]), - ([1, 3, 4, 7, 7, 1, 2], 5, [2, 1, 7, 7, 4]) + ([1, 3, 4, 7, 7, 1, 2], 5, [2, 1, 7, 7, 4]), ] @@ -79,7 +72,7 @@ def test_stack_multiple_pop_(): def test_stack_single_pop(): # Arrange machine_state = MachineState(0) - machine_state.stack = [1,2,3] + machine_state.stack = [1, 2, 3] # Act result = machine_state.pop() @@ -88,22 +81,18 @@ def test_stack_single_pop(): assert isinstance(result, int) -memory_write_test_data = [ - (5, 10, [1, 2, 3]), - (0, 0, [3, 4]), - (20, 1, [2, 4, 10]) -] +memory_write_test_data = [(5, 10, [1, 2, 3]), (0, 0, [3, 4]), (20, 1, [2, 4, 10])] @pytest.mark.parametrize("initial_size, memory_offset, data", memory_write_test_data) def test_memory_write(initial_size, memory_offset, data): # Arrange machine_state = MachineState(0) - machine_state.memory = [0]*initial_size + machine_state.memory = [0] * initial_size # Act machine_state.memory_write(memory_offset, data) # Assert - assert len(machine_state.memory) == max(initial_size, memory_offset+len(data)) - assert machine_state.memory[memory_offset:memory_offset+len(data)] == data + assert len(machine_state.memory) == max(initial_size, memory_offset + len(data)) + assert machine_state.memory[memory_offset : memory_offset + len(data)] == data diff --git a/tests/laser/state/storage_test.py b/tests/laser/state/storage_test.py index f664f0a8..55e8f75b 100644 --- a/tests/laser/state/storage_test.py +++ b/tests/laser/state/storage_test.py @@ -2,11 +2,7 @@ import pytest from mythril.laser.ethereum.state import Storage from z3 import ExprRef -storage_uninitialized_test_data = [ - ({}, 1), - ({1: 5}, 2), - ({1: 5, 3: 10}, 2) -] +storage_uninitialized_test_data = [({}, 1), ({1: 5}, 2), ({1: 5, 3: 10}, 2)] @pytest.mark.parametrize("initial_storage,key", storage_uninitialized_test_data) diff --git a/tests/laser/test_transaction.py b/tests/laser/test_transaction.py index db6971ec..1aca7090 100644 --- a/tests/laser/test_transaction.py +++ b/tests/laser/test_transaction.py @@ -8,19 +8,20 @@ def test_intercontract_call(): # Arrange cfg.gbl_next_uid = 0 - caller_code = Disassembly("6080604052348015600f57600080fd5b5073deadbeefdeadbeefdeadbeefdeadbeefdeadbeef73ffffffffffffffffffffffffffffffffffffffff166389627e13336040518263ffffffff167c0100000000000000000000000000000000000000000000000000000000028152600401808273ffffffffffffffffffffffffffffffffffffffff1673ffffffffffffffffffffffffffffffffffffffff168152602001915050602060405180830381600087803b15801560be57600080fd5b505af115801560d1573d6000803e3d6000fd5b505050506040513d602081101560e657600080fd5b8101908080519060200190929190505050500000a165627a7a72305820fdb1e90f0d9775c94820e516970e0d41380a94624fa963c556145e8fb645d4c90029") + caller_code = Disassembly( + "6080604052348015600f57600080fd5b5073deadbeefdeadbeefdeadbeefdeadbeefdeadbeef73ffffffffffffffffffffffffffffffffffffffff166389627e13336040518263ffffffff167c0100000000000000000000000000000000000000000000000000000000028152600401808273ffffffffffffffffffffffffffffffffffffffff1673ffffffffffffffffffffffffffffffffffffffff168152602001915050602060405180830381600087803b15801560be57600080fd5b505af115801560d1573d6000803e3d6000fd5b505050506040513d602081101560e657600080fd5b8101908080519060200190929190505050500000a165627a7a72305820fdb1e90f0d9775c94820e516970e0d41380a94624fa963c556145e8fb645d4c90029" + ) caller_address = "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" - callee_code = Disassembly("608060405260043610603f576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806389627e13146044575b600080fd5b348015604f57600080fd5b506082600480360381019080803573ffffffffffffffffffffffffffffffffffffffff1690602001909291905050506084565b005b8073ffffffffffffffffffffffffffffffffffffffff166108fc3073ffffffffffffffffffffffffffffffffffffffff16319081150290604051600060405180830381858888f1935050505015801560e0573d6000803e3d6000fd5b50505600a165627a7a72305820a6b1335d6f994632bc9a7092d0eaa425de3dea05e015af8a94ad70b3969e117a0029") + callee_code = Disassembly( + "608060405260043610603f576000357c0100000000000000000000000000000000000000000000000000000000900463ffffffff16806389627e13146044575b600080fd5b348015604f57600080fd5b506082600480360381019080803573ffffffffffffffffffffffffffffffffffffffff1690602001909291905050506084565b005b8073ffffffffffffffffffffffffffffffffffffffff166108fc3073ffffffffffffffffffffffffffffffffffffffff16319081150290604051600060405180830381858888f1935050505015801560e0573d6000803e3d6000fd5b50505600a165627a7a72305820a6b1335d6f994632bc9a7092d0eaa425de3dea05e015af8a94ad70b3969e117a0029" + ) callee_address = "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" caller_account = Account(caller_address, caller_code, contract_name="Caller") callee_account = Account(callee_address, callee_code, contract_name="Callee") - accounts = { - caller_address: caller_account, - callee_address: callee_account - } + accounts = {caller_address: caller_account, callee_address: callee_account} laser = svm.LaserEVM(accounts) @@ -30,11 +31,11 @@ def test_intercontract_call(): # Assert # Initial node starts in contract caller assert len(laser.nodes.keys()) > 0 - assert laser.nodes[0].contract_name == 'Caller' + assert laser.nodes[0].contract_name == "Caller" # At one point we call into contract callee for node in laser.nodes.values(): - if node.contract_name == 'Callee': + if node.contract_name == "Callee": assert len(node.states[0].transaction_stack) > 1 return diff --git a/tests/laser/transaction/create_transaction_test.py b/tests/laser/transaction/create_transaction_test.py index dcecdfb8..c013043a 100644 --- a/tests/laser/transaction/create_transaction_test.py +++ b/tests/laser/transaction/create_transaction_test.py @@ -10,7 +10,7 @@ from mythril.analysis.symbolic import SymExecWrapper def test_create(): - contract = SolidityContract(str(tests.TESTDATA_INPUTS_CONTRACTS / 'calls.sol')) + contract = SolidityContract(str(tests.TESTDATA_INPUTS_CONTRACTS / "calls.sol")) laser_evm = svm.LaserEVM({}) @@ -27,12 +27,15 @@ def test_create(): found_instruction = created_account_code.instruction_list[i] actual_instruction = actual_code.instruction_list[i] - assert found_instruction['opcode'] == actual_instruction['opcode'] + assert found_instruction["opcode"] == actual_instruction["opcode"] + def test_sym_exec(): - contract = SolidityContract(str(tests.TESTDATA_INPUTS_CONTRACTS / 'calls.sol')) + contract = SolidityContract(str(tests.TESTDATA_INPUTS_CONTRACTS / "calls.sol")) - sym = SymExecWrapper(contract, address=(util.get_indexed_address(0)), strategy="dfs") + sym = SymExecWrapper( + contract, address=(util.get_indexed_address(0)), strategy="dfs" + ) issues = fire_lasers(sym) assert len(issues) != 0 diff --git a/tests/laser/transaction/symbolic_test.py b/tests/laser/transaction/symbolic_test.py index 6c27b327..e14e5618 100644 --- a/tests/laser/transaction/symbolic_test.py +++ b/tests/laser/transaction/symbolic_test.py @@ -1,5 +1,11 @@ -from mythril.laser.ethereum.transaction.symbolic import execute_message_call, execute_contract_creation -from mythril.laser.ethereum.transaction import MessageCallTransaction, ContractCreationTransaction +from mythril.laser.ethereum.transaction.symbolic import ( + execute_message_call, + execute_contract_creation, +) +from mythril.laser.ethereum.transaction import ( + MessageCallTransaction, + ContractCreationTransaction, +) from mythril.laser.ethereum.svm import LaserEVM from mythril.laser.ethereum.state import WorldState, Account import unittest.mock as mock @@ -14,7 +20,9 @@ def _is_contract_creation(_, transaction): assert isinstance(transaction, ContractCreationTransaction) -@mock.patch("mythril.laser.ethereum.transaction.symbolic._setup_global_state_for_execution") +@mock.patch( + "mythril.laser.ethereum.transaction.symbolic._setup_global_state_for_execution" +) def test_execute_message_call(mocked_setup: MagicMock): # Arrange laser_evm = LaserEVM({}) @@ -39,7 +47,9 @@ def test_execute_message_call(mocked_setup: MagicMock): assert len(laser_evm.open_states) == 0 -@mock.patch("mythril.laser.ethereum.transaction.symbolic._setup_global_state_for_execution") +@mock.patch( + "mythril.laser.ethereum.transaction.symbolic._setup_global_state_for_execution" +) def test_execute_contract_creation(mocked_setup: MagicMock): # Arrange laser_evm = LaserEVM({}) @@ -57,4 +67,3 @@ def test_execute_contract_creation(mocked_setup: MagicMock): # laser_evm.exec.assert_called_once() assert laser_evm.exec.call_count == 1 assert len(laser_evm.open_states) == 0 - diff --git a/tests/native_test.py b/tests/native_test.py index bcb0eab9..ccf8edd4 100644 --- a/tests/native_test.py +++ b/tests/native_test.py @@ -14,7 +14,10 @@ ECRECOVER_TEST = [(0, False) for _ in range(9)] IDENTITY_TEST = [(0, False) for _ in range(4)] -SHA256_TEST[0] = (5555555555555555, True) #These are Random numbers to check whether the 'if condition' is entered or not(True means entered) +SHA256_TEST[0] = ( + 5555555555555555, + True, +) # These are Random numbers to check whether the 'if condition' is entered or not(True means entered) SHA256_TEST[1] = (323232325445454546, True) SHA256_TEST[2] = (34756834765834658, False) SHA256_TEST[3] = (8756476956956795876987, True) @@ -45,12 +48,13 @@ IDENTITY_TEST[1] = (476934798798347, False) IDENTITY_TEST[2] = (7346948379483769, True) IDENTITY_TEST[3] = (83269476937987, False) + def _all_info(laser): accounts = {} for address, _account in laser.world_state.accounts.items(): account = _account.as_dict account["code"] = account["code"].instruction_list - account['balance'] = str(account['balance']) + account["balance"] = str(account["balance"]) accounts[address] = account nodes = {} @@ -62,56 +66,58 @@ def _all_info(laser): elif isinstance(state, GlobalState): environment = state.environment.as_dict environment["active_account"] = environment["active_account"].address - states.append({ - 'accounts': state.accounts.keys(), - 'environment': environment, - 'mstate': state.mstate.as_dict - }) + states.append( + { + "accounts": state.accounts.keys(), + "environment": environment, + "mstate": state.mstate.as_dict, + } + ) nodes[uid] = { - 'uid': node.uid, - 'contract_name': node.contract_name, - 'start_addr': node.start_addr, - 'states': states, - 'constraints': node.constraints, - 'function_name': node.function_name, - 'flags': str(node.flags) + "uid": node.uid, + "contract_name": node.contract_name, + "start_addr": node.start_addr, + "states": states, + "constraints": node.constraints, + "function_name": node.function_name, + "flags": str(node.flags), } edges = [edge.as_dict for edge in laser.edges] return { - 'accounts': accounts, - 'nodes': nodes, - 'edges': edges, - 'total_states': laser.total_states, - 'max_depth': laser.max_depth + "accounts": accounts, + "nodes": nodes, + "edges": edges, + "total_states": laser.total_states, + "max_depth": laser.max_depth, } + def _test_natives(laser_info, test_list, test_name): success = 0 for i, j in test_list: if (str(i) in laser_info) == j: success += 1 else: - print("Failed: "+str(i)+" "+str(j)) - assert(success == len(test_list)) + print("Failed: " + str(i) + " " + str(j)) + assert success == len(test_list) class NativeTests(BaseTestCase): @staticmethod def runTest(): - disassembly = SolidityContract('./tests/native_tests.sol').disassembly + disassembly = SolidityContract("./tests/native_tests.sol").disassembly account = Account("0x0000000000000000000000000000000000000000", disassembly) accounts = {account.address: account} - laser = svm.LaserEVM(accounts, max_depth = 100) + laser = svm.LaserEVM(accounts, max_depth=100) laser.sym_exec(account.address) laser_info = str(_all_info(laser)) - print('\n') - - _test_natives(laser_info, SHA256_TEST, 'SHA256') - _test_natives(laser_info, RIPEMD160_TEST, 'RIPEMD160') - _test_natives(laser_info, ECRECOVER_TEST, 'ECRECOVER') - _test_natives(laser_info, IDENTITY_TEST, 'IDENTITY') + print("\n") + _test_natives(laser_info, SHA256_TEST, "SHA256") + _test_natives(laser_info, RIPEMD160_TEST, "RIPEMD160") + _test_natives(laser_info, ECRECOVER_TEST, "ECRECOVER") + _test_natives(laser_info, IDENTITY_TEST, "IDENTITY") diff --git a/tests/report_test.py b/tests/report_test.py index 69edc860..26aa5be6 100644 --- a/tests/report_test.py +++ b/tests/report_test.py @@ -24,7 +24,12 @@ def _fix_debug_data(json_str): def _generate_report(input_file): contract = ETHContract(input_file.read_text()) - sym = SymExecWrapper(contract, address=(util.get_indexed_address(0)), strategy="dfs", execution_timeout=30) + sym = SymExecWrapper( + contract, + address=(util.get_indexed_address(0)), + strategy="dfs", + execution_timeout=30, + ) issues = fire_lasers(sym) report = Report() @@ -34,7 +39,7 @@ def _generate_report(input_file): return report, input_file -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def reports(): """ Fixture that analyses all reports""" pool = Pool(cpu_count()) @@ -48,14 +53,25 @@ def _assert_empty(changed_files, postfix): """ Asserts there are no changed files and otherwise builds error message""" message = "" for input_file in changed_files: - output_expected = (TESTDATA_OUTPUTS_EXPECTED / (input_file.name + postfix)).read_text().splitlines(1) - output_current = (TESTDATA_OUTPUTS_CURRENT / (input_file.name + postfix)).read_text().splitlines(1) - - difference = ''.join(difflib.unified_diff(output_expected, output_current)) - message += "Found differing file for input: {} \n Difference: \n {} \n".format(str(input_file), str(difference)) + output_expected = ( + (TESTDATA_OUTPUTS_EXPECTED / (input_file.name + postfix)) + .read_text() + .splitlines(1) + ) + output_current = ( + (TESTDATA_OUTPUTS_CURRENT / (input_file.name + postfix)) + .read_text() + .splitlines(1) + ) + + difference = "".join(difflib.unified_diff(output_expected, output_current)) + message += "Found differing file for input: {} \n Difference: \n {} \n".format( + str(input_file), str(difference) + ) assert message == "", message + def _assert_empty_json(changed_files): """ Asserts there are no changed files and otherwise builds error message""" postfix = ".json" @@ -71,15 +87,20 @@ def _assert_empty_json(changed_files): return obj for input_file in changed_files: - output_expected = json.loads((TESTDATA_OUTPUTS_EXPECTED / (input_file.name + postfix)).read_text()) - output_current = json.loads((TESTDATA_OUTPUTS_CURRENT / (input_file.name + postfix)).read_text()) - - if not ordered(output_expected.items()) == ordered(output_current.items()): + output_expected = json.loads( + (TESTDATA_OUTPUTS_EXPECTED / (input_file.name + postfix)).read_text() + ) + output_current = json.loads( + (TESTDATA_OUTPUTS_CURRENT / (input_file.name + postfix)).read_text() + ) + + if not ordered(output_expected.items()) == ordered(output_current.items()): expected.append(output_expected) actual.append(output_current) assert expected == actual + def _get_changed_files(postfix, report_builder, reports): """ Returns a generator for all unexpected changes in generated reports @@ -99,6 +120,7 @@ def _get_changed_files(postfix, report_builder, reports): def _get_changed_files_json(report_builder, reports): postfix = ".json" + def ordered(obj): if isinstance(obj, dict): return sorted((k, ordered(v)) for k, v in obj.items()) @@ -112,17 +134,33 @@ def _get_changed_files_json(report_builder, reports): output_current = TESTDATA_OUTPUTS_CURRENT / (input_file.name + postfix) output_current.write_text(report_builder(report)) - if not ordered(json.loads(output_expected.read_text())) == ordered(json.loads(output_current.read_text())): + if not ordered(json.loads(output_expected.read_text())) == ordered( + json.loads(output_current.read_text()) + ): yield input_file def test_json_report(reports): - _assert_empty_json(_get_changed_files_json(lambda report: _fix_path(_fix_debug_data(report.as_json())).strip(), reports)) + _assert_empty_json( + _get_changed_files_json( + lambda report: _fix_path(_fix_debug_data(report.as_json())).strip(), reports + ) + ) def test_markdown_report(reports): - _assert_empty(_get_changed_files('.markdown', lambda report: _fix_path(report.as_markdown()), reports), '.markdown') + _assert_empty( + _get_changed_files( + ".markdown", lambda report: _fix_path(report.as_markdown()), reports + ), + ".markdown", + ) def test_text_report(reports): - _assert_empty(_get_changed_files('.text', lambda report: _fix_path(report.as_text()), reports), '.text') + _assert_empty( + _get_changed_files( + ".text", lambda report: _fix_path(report.as_text()), reports + ), + ".text", + ) diff --git a/tests/rpc_test.py b/tests/rpc_test.py index 564c7da8..d5d908a2 100644 --- a/tests/rpc_test.py +++ b/tests/rpc_test.py @@ -2,6 +2,7 @@ from unittest import TestCase from mythril.ethereum.interface.rpc.client import EthJsonRpc + class RpcTest(TestCase): client = None @@ -18,26 +19,49 @@ class RpcTest(TestCase): def test_eth_blockNumber(self): block_number = self.client.eth_blockNumber() - self.assertGreater(block_number, 0, "we have made sure the blockNumber is > 0 for testing") + self.assertGreater( + block_number, 0, "we have made sure the blockNumber is > 0 for testing" + ) def test_eth_getBalance(self): - balance = self.client.eth_getBalance(address="0x0000000000000000000000000000000000000000") - self.assertGreater(balance, 10000000, "specified address should have a lot of balance") + balance = self.client.eth_getBalance( + address="0x0000000000000000000000000000000000000000" + ) + self.assertGreater( + balance, 10000000, "specified address should have a lot of balance" + ) def test_eth_getStorageAt(self): - storage = self.client.eth_getStorageAt(address="0x0000000000000000000000000000000000000000") - self.assertEqual(storage, '0x0000000000000000000000000000000000000000000000000000000000000000') + storage = self.client.eth_getStorageAt( + address="0x0000000000000000000000000000000000000000" + ) + self.assertEqual( + storage, + "0x0000000000000000000000000000000000000000000000000000000000000000", + ) def test_eth_getBlockByNumber(self): block = self.client.eth_getBlockByNumber(0) - self.assertEqual(block["extraData"], "0x11bbe8db4e347b4e8c937c1c8370e4b5ed33adb3db69cbdb7a38e1e50b1b82fa", "the data of the first block should be right") + self.assertEqual( + block["extraData"], + "0x11bbe8db4e347b4e8c937c1c8370e4b5ed33adb3db69cbdb7a38e1e50b1b82fa", + "the data of the first block should be right", + ) def test_eth_getCode(self): # TODO: can't find a proper address for getting code - code = self.client.eth_getCode(address="0x0000000000000000000000000000000000000001") + code = self.client.eth_getCode( + address="0x0000000000000000000000000000000000000001" + ) self.assertEqual(code, "0x") def test_eth_getTransactionReceipt(self): - transaction = self.client.eth_getTransactionReceipt(tx_hash="0xe363505adc6b2996111f8bd774f8653a61d244cc6567b5372c2e781c6c63b737") - self.assertEqual(transaction["from"], "0x22f2dcff5ad78c3eb6850b5cb951127b659522e6") - self.assertEqual(transaction["to"], "0x0000000000000000000000000000000000000000") + transaction = self.client.eth_getTransactionReceipt( + tx_hash="0xe363505adc6b2996111f8bd774f8653a61d244cc6567b5372c2e781c6c63b737" + ) + self.assertEqual( + transaction["from"], "0x22f2dcff5ad78c3eb6850b5cb951127b659522e6" + ) + self.assertEqual( + transaction["to"], "0x0000000000000000000000000000000000000000" + ) diff --git a/tests/solidity_contract_test.py b/tests/solidity_contract_test.py index fb1a596a..5513c4c1 100644 --- a/tests/solidity_contract_test.py +++ b/tests/solidity_contract_test.py @@ -5,8 +5,8 @@ from tests import BaseTestCase TEST_FILES = Path(__file__).parent / "testdata/input_contracts" -class SolidityContractTest(BaseTestCase): +class SolidityContractTest(BaseTestCase): def test_get_source_info_without_name_gets_latest_contract_info(self): input_file = TEST_FILES / "multi_contracts.sol" contract = SolidityContract(str(input_file)) @@ -36,4 +36,3 @@ class SolidityContractTest(BaseTestCase): self.assertEqual(code_info.filename, str(input_file)) self.assertEqual(code_info.lineno, 6) self.assertEqual(code_info.code, "assert(var1>0)") - diff --git a/tests/svm_test.py b/tests/svm_test.py index 79c94b85..ba580440 100644 --- a/tests/svm_test.py +++ b/tests/svm_test.py @@ -21,7 +21,7 @@ def _all_info(laser): for address, _account in laser.world_state.accounts.items(): account = _account.as_dict account["code"] = account["code"].instruction_list - account['balance'] = str(account['balance']) + account["balance"] = str(account["balance"]) accounts[address] = account nodes = {} @@ -33,35 +33,36 @@ def _all_info(laser): elif isinstance(state, GlobalState): environment = state.environment.as_dict environment["active_account"] = environment["active_account"].address - states.append({ - 'accounts': state.accounts.keys(), - 'environment': environment, - 'mstate': state.mstate.as_dict - }) + states.append( + { + "accounts": state.accounts.keys(), + "environment": environment, + "mstate": state.mstate.as_dict, + } + ) nodes[uid] = { - 'uid': node.uid, - 'contract_name': node.contract_name, - 'start_addr': node.start_addr, - 'states': states, - 'constraints': node.constraints, - 'function_name': node.function_name, - 'flags': str(node.flags) + "uid": node.uid, + "contract_name": node.contract_name, + "start_addr": node.start_addr, + "states": states, + "constraints": node.constraints, + "function_name": node.function_name, + "flags": str(node.flags), } edges = [edge.as_dict for edge in laser.edges] return { - 'accounts': accounts, - 'nodes': nodes, - 'edges': edges, - 'total_states': laser.total_states, - 'max_depth': laser.max_depth + "accounts": accounts, + "nodes": nodes, + "edges": edges, + "total_states": laser.total_states, + "max_depth": laser.max_depth, } class SVMTestCase(BaseTestCase): - def setUp(self): super(SVMTestCase, self).setUp() svm.gbl_next_uid = 0 @@ -70,8 +71,12 @@ class SVMTestCase(BaseTestCase): for input_file in TESTDATA_INPUTS_CONTRACTS.iterdir(): if input_file.name == "weak_random.sol": continue - output_expected = TESTDATA_OUTPUTS_EXPECTED_LASER_RESULT / (input_file.name + ".json") - output_current = TESTDATA_OUTPUTS_CURRENT_LASER_RESULT / (input_file.name + ".json") + output_expected = TESTDATA_OUTPUTS_EXPECTED_LASER_RESULT / ( + input_file.name + ".json" + ) + output_current = TESTDATA_OUTPUTS_CURRENT_LASER_RESULT / ( + input_file.name + ".json" + ) disassembly = SolidityContract(str(input_file)).disassembly account = Account("0x0000000000000000000000000000000000000000", disassembly) @@ -81,7 +86,9 @@ class SVMTestCase(BaseTestCase): laser.sym_exec(account.address) laser_info = _all_info(laser) - output_current.write_text(json.dumps(laser_info, cls=LaserEncoder, indent=4)) + output_current.write_text( + json.dumps(laser_info, cls=LaserEncoder, indent=4) + ) if not (output_expected.read_text() == output_expected.read_text()): self.found_changed_files(input_file, output_expected, output_current) diff --git a/tests/taint_mutate_stack_test.py b/tests/taint_mutate_stack_test.py index 94964350..6414340b 100644 --- a/tests/taint_mutate_stack_test.py +++ b/tests/taint_mutate_stack_test.py @@ -1,12 +1,13 @@ from mythril.laser.ethereum.taint_analysis import * + def test_mutate_not_tainted(): # Arrange record = TaintRecord() record.stack = [True, False, False] # Act - TaintRunner.mutate_stack(record, (2,1)) + TaintRunner.mutate_stack(record, (2, 1)) # Assert assert record.stack_tainted(0) diff --git a/tests/taint_result_test.py b/tests/taint_result_test.py index 65909cbc..c692ea80 100644 --- a/tests/taint_result_test.py +++ b/tests/taint_result_test.py @@ -7,7 +7,7 @@ def test_result_state(): taint_result = TaintResult() record = TaintRecord() state = GlobalState(2, None, None) - state.mstate.stack = [1,2,3] + state.mstate.stack = [1, 2, 3] record.add_state(state) record.stack = [False, False, False] # act @@ -24,8 +24,7 @@ def test_result_no_state(): taint_result = TaintResult() record = TaintRecord() state = GlobalState(2, None, None) - state.mstate.stack = [1,2,3] - + state.mstate.stack = [1, 2, 3] # act taint_result.add_records([record]) diff --git a/tests/taint_runner_test.py b/tests/taint_runner_test.py index 17ef00eb..51208bc1 100644 --- a/tests/taint_runner_test.py +++ b/tests/taint_runner_test.py @@ -6,13 +6,14 @@ from mythril.laser.ethereum.cfg import Node, Edge from mythril.laser.ethereum.state import MachineState, Account, Environment, GlobalState from mythril.laser.ethereum.svm import LaserEVM + def test_execute_state(mocker): record = TaintRecord() record.stack = [True, False, True] state = GlobalState(None, None, None) state.mstate.stack = [1, 2, 3] - mocker.patch.object(state, 'get_current_instruction') + mocker.patch.object(state, "get_current_instruction") state.get_current_instruction.return_value = {"opcode": "ADD"} # Act @@ -30,12 +31,12 @@ def test_execute_node(mocker): state_1 = GlobalState(None, None, None) state_1.mstate.stack = [1, 2, 3, 1] state_1.mstate.pc = 1 - mocker.patch.object(state_1, 'get_current_instruction') + mocker.patch.object(state_1, "get_current_instruction") state_1.get_current_instruction.return_value = {"opcode": "SWAP1"} state_2 = GlobalState(None, 1, None) state_2.mstate.stack = [1, 2, 4, 1] - mocker.patch.object(state_2, 'get_current_instruction') + mocker.patch.object(state_2, "get_current_instruction") state_2.get_current_instruction.return_value = {"opcode": "ADD"} node = Node("Test contract") @@ -54,19 +55,17 @@ def test_execute_node(mocker): assert state_1 in record.states - - def test_execute(mocker): - active_account = Account('0x00') + active_account = Account("0x00") environment = Environment(active_account, None, None, None, None, None) state_1 = GlobalState(None, environment, None, MachineState(gas=10000000)) state_1.mstate.stack = [1, 2] - mocker.patch.object(state_1, 'get_current_instruction') + mocker.patch.object(state_1, "get_current_instruction") state_1.get_current_instruction.return_value = {"opcode": "PUSH"} state_2 = GlobalState(None, environment, None, MachineState(gas=10000000)) state_2.mstate.stack = [1, 2, 3] - mocker.patch.object(state_2, 'get_current_instruction') + mocker.patch.object(state_2, "get_current_instruction") state_2.get_current_instruction.return_value = {"opcode": "ADD"} node_1 = Node("Test contract") @@ -74,7 +73,7 @@ def test_execute(mocker): state_3 = GlobalState(None, environment, None, MachineState(gas=10000000)) state_3.mstate.stack = [1, 2] - mocker.patch.object(state_3, 'get_current_instruction') + mocker.patch.object(state_3, "get_current_instruction") state_3.get_current_instruction.return_value = {"opcode": "ADD"} node_2 = Node("Test contract") diff --git a/tests/test_cli_opts.py b/tests/test_cli_opts.py index 8a07f3e1..5de6cdcd 100644 --- a/tests/test_cli_opts.py +++ b/tests/test_cli_opts.py @@ -4,22 +4,23 @@ import json import sys + def test_version_opt(capsys): # Check that "myth --version" returns a string with the word # "version" in it - sys.argv = ['mythril', '--version'] + sys.argv = ["mythril", "--version"] with pytest.raises(SystemExit) as pytest_wrapped_e: main() assert pytest_wrapped_e.type == SystemExit captured = capsys.readouterr() - assert captured.out.find(' version ') >= 1 + assert captured.out.find(" version ") >= 1 # Check that "myth --version -o json" returns a JSON object - sys.argv = ['mythril', '--version', '-o', 'json'] + sys.argv = ["mythril", "--version", "-o", "json"] with pytest.raises(SystemExit) as pytest_wrapped_e: main() assert pytest_wrapped_e.type == SystemExit captured = capsys.readouterr() d = json.loads(captured.out) assert isinstance(d, dict) - assert d['version_str'] + assert d["version_str"] diff --git a/tests/testdata/compile.py b/tests/testdata/compile.py index 51522e06..3811daeb 100644 --- a/tests/testdata/compile.py +++ b/tests/testdata/compile.py @@ -4,12 +4,12 @@ from mythril.ether.soliditycontract import SolidityContract # Recompiles all the to be tested contracts root = Path(__file__).parent -input = root / 'input_contracts' -output = root / 'inputs' +input = root / "input_contracts" +output = root / "inputs" for contract in input.iterdir(): sol = SolidityContract(str(contract)) code = sol.code - output_file = (output / "{}.o".format(contract.name)) + output_file = output / "{}.o".format(contract.name) output_file.write_text(code)