diff --git a/mythril/ether/ethcontract.py b/mythril/ether/ethcontract.py index 24a859e8..b43b1919 100644 --- a/mythril/ether/ethcontract.py +++ b/mythril/ether/ethcontract.py @@ -18,7 +18,7 @@ class ETHContract(persistent.Persistent): self.code = code self.disassembly = Disassembly(code, enable_online_lookup=enable_online_lookup) - self.creation_disassemble = Disassembly(creation_code, enable_online_lookup=enable_online_lookup) + self.creation_disassembly = Disassembly(creation_code, enable_online_lookup=enable_online_lookup) def as_dict(self): diff --git a/mythril/ether/soliditycontract.py b/mythril/ether/soliditycontract.py index 22f742d8..0642ea21 100644 --- a/mythril/ether/soliditycontract.py +++ b/mythril/ether/soliditycontract.py @@ -92,10 +92,9 @@ class SolidityContract(ETHContract): super().__init__(code, creation_code, name=name) def get_source_info(self, address, constructor=False): - disassembly = self.creation_disassemble if constructor else self.disassembly + disassembly = self.creation_disassembly if constructor else self.disassembly mappings = self.constructor_mappings if constructor else self.mappings index = helper.get_instruction_index(disassembly.instruction_list, address) - solidity_file = self.solidity_files[mappings[index].solidity_file_idx] filename = solidity_file.filename diff --git a/tests/solidity_contract_test.py b/tests/solidity_contract_test.py index b1c02bf7..fb1a596a 100644 --- a/tests/solidity_contract_test.py +++ b/tests/solidity_contract_test.py @@ -26,3 +26,14 @@ class SolidityContractTest(BaseTestCase): self.assertEqual(code_info.filename, str(input_file)) self.assertEqual(code_info.lineno, 6) self.assertEqual(code_info.code, "msg.sender.transfer(1 ether)") + + def test_get_source_info_with_contract_name_specified_constructor(self): + input_file = TEST_FILES / "constructor_assert.sol" + contract = SolidityContract(str(input_file), name="AssertFail") + + code_info = contract.get_source_info(62, constructor=True) + + self.assertEqual(code_info.filename, str(input_file)) + self.assertEqual(code_info.lineno, 6) + self.assertEqual(code_info.code, "assert(var1>0)") +