diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index e1759915..0227acaf 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -27,6 +27,9 @@ from mythril.version import VERSION # logging.basicConfig(level=logging.DEBUG) +ANALYZE_LIST = ("a", "analyze") +DISASSEMBLE_LIST = ("d", "disassemble") + log = logging.getLogger(__name__) @@ -161,16 +164,22 @@ def main() -> None: ) subparsers = parser.add_subparsers(dest="command", help="Commands") - analyzer_parser = subparsers.add_parser( - "analyze", - help="Triggers the analysis of the smart contract", - parents=[rpc_parser, utilities_parser, input_parser, output_parser], - ) - disassemble_parser = subparsers.add_parser( - "disassemble", - help="Disassembles the smart contract", - parents=[rpc_parser, utilities_parser, input_parser, output_parser], - ) + for analyze_string in ANALYZE_LIST: + analyzer_parser = subparsers.add_parser( + analyze_string, + help="Triggers the analysis of the smart contract", + parents=[rpc_parser, utilities_parser, input_parser, output_parser], + ) + create_analyzer_parser(analyzer_parser) + + for disassemble_string in DISASSEMBLE_LIST: + disassemble_parser = subparsers.add_parser( + disassemble_string, + help="Disassembles the smart contract", + parents=[rpc_parser, utilities_parser, input_parser, output_parser], + ) + create_disassemble_parser(disassemble_parser) + read_storage_parser = subparsers.add_parser( "read-storage", help="Retrieves storage slots from a given address through rpc", @@ -195,8 +204,6 @@ def main() -> None: "version", parents=[output_parser], help="Outputs the version" ) - create_disassemble_parser(disassemble_parser) - create_analyzer_parser(analyzer_parser) create_read_storage_parser(read_storage_parser) create_hash_to_addr_parser(contract_hash_to_addr) create_func_to_hash_parser(contract_func_to_hash) @@ -390,7 +397,7 @@ def validate_args(args: Namespace): args.outform, "Invalid -v value, you can find valid values in usage" ) - if args.command == "analyze": + if args.command in ANALYZE_LIST: if args.query_signature and sigs.ethereum_input_decoder is None: exit_with_error( args.outform, @@ -412,7 +419,8 @@ def set_config(args: Namespace): """ config = MythrilConfig() if ( - args.command == "analyze" and (args.dynld or not args.no_onchain_storage_access) + args.command in ANALYZE_LIST + and (args.dynld or not args.no_onchain_storage_access) ) and not (args.rpc or args.i): config.set_api_from_config_path() @@ -474,7 +482,7 @@ def load_code(disassembler: MythrilDisassembler, args: Namespace): address, _ = disassembler.load_from_address(args.address) elif args.__dict__.get("solidity_file", False): # Compile Solidity source file(s) - if args.command == "analyze" and args.graph and len(args.solidity_file) > 1: + if args.command in ANALYZE_LIST and args.graph and len(args.solidity_file) > 1: exit_with_error( args.outform, "Cannot generate call graphs from multiple input files. Please do it one at a time.", @@ -513,15 +521,14 @@ def execute_command( print(storage) return - if args.command == "disassemble": - # or mythril.disassemble(mythril.contracts[0]) + if args.command in DISASSEMBLE_LIST: if disassembler.contracts[0].code: print("Runtime Disassembly: \n" + disassembler.contracts[0].get_easm()) if disassembler.contracts[0].creation_code: print("Disassembly: \n" + disassembler.contracts[0].get_creation_easm()) - elif args.command == "analyze": + elif args.command in ANALYZE_LIST: analyzer = MythrilAnalyzer( strategy=args.strategy, disassembler=disassembler, diff --git a/tests/cmd_line_test.py b/tests/cmd_line_test.py index a3ed1949..5fa6e3bb 100644 --- a/tests/cmd_line_test.py +++ b/tests/cmd_line_test.py @@ -58,6 +58,13 @@ class CommandLineToolTestCase(BaseTestCase): ) self.assertIn(""""success": false""", output_of(command)) + def test_storage(self): + solidity_file = str(TESTDATA / "input_contracts" / "origin.sol") + command = """python3 {} read-storage "438767356, 3" 0x76799f77587738bfeef09452df215b63d2cfb08a """.format( + MYTH + ) + self.assertIn("0x1a270efc", output_of(command)) + class TruffleTestCase(BaseTestCase): def test_analysis_truffle_project(self):