Add new tests for SolidityContract

pull/481/head
Nikhil Parasaram 6 years ago
parent 7387f14ba1
commit e218c20e91
  1. 2
      mythril/ether/ethcontract.py
  2. 3
      mythril/ether/soliditycontract.py
  3. 11
      tests/solidity_contract_test.py

@ -18,7 +18,7 @@ class ETHContract(persistent.Persistent):
self.code = code
self.disassembly = Disassembly(code, enable_online_lookup=enable_online_lookup)
self.creation_disassemble = Disassembly(creation_code, enable_online_lookup=enable_online_lookup)
self.creation_disassembly = Disassembly(creation_code, enable_online_lookup=enable_online_lookup)
def as_dict(self):

@ -92,10 +92,9 @@ class SolidityContract(ETHContract):
super().__init__(code, creation_code, name=name)
def get_source_info(self, address, constructor=False):
disassembly = self.creation_disassemble if constructor else self.disassembly
disassembly = self.creation_disassembly if constructor else self.disassembly
mappings = self.constructor_mappings if constructor else self.mappings
index = helper.get_instruction_index(disassembly.instruction_list, address)
solidity_file = self.solidity_files[mappings[index].solidity_file_idx]
filename = solidity_file.filename

@ -26,3 +26,14 @@ class SolidityContractTest(BaseTestCase):
self.assertEqual(code_info.filename, str(input_file))
self.assertEqual(code_info.lineno, 6)
self.assertEqual(code_info.code, "msg.sender.transfer(1 ether)")
def test_get_source_info_with_contract_name_specified_constructor(self):
input_file = TEST_FILES / "constructor_assert.sol"
contract = SolidityContract(str(input_file), name="AssertFail")
code_info = contract.get_source_info(62, constructor=True)
self.assertEqual(code_info.filename, str(input_file))
self.assertEqual(code_info.lineno, 6)
self.assertEqual(code_info.code, "assert(var1>0)")

Loading…
Cancel
Save