diff --git a/mythril/solidity/soliditycontract.py b/mythril/solidity/soliditycontract.py index 675b93f9..4922c528 100644 --- a/mythril/solidity/soliditycontract.py +++ b/mythril/solidity/soliditycontract.py @@ -11,6 +11,65 @@ from mythril.exceptions import NoContractFoundError log = logging.getLogger(__name__) +class SolcAST: + def __init__(self, ast): + self.ast = ast + + @property + def node_type(self): + if "nodeType" in self.ast: + return self.ast["nodeType"] + if "name" in self.ast: + return self.ast["name"] + assert False, "Unknown AST type has been fed to SolcAST" + + @property + def abs_path(self): + if "absolutePath" in self.ast: + return self.ast["absolutePath"] + else: + return None + + @property + def nodes(self): + if "nodes" in self.ast: + return self.ast["nodes"] + if "children" in self.ast: + return self.ast["children"] + assert False, "Unknown AST type has been fed to SolcAST" + + def __next__(self): + yield self.ast.__next__() + + def __getitem__(self, item): + return self.ast[item] + + +class SolcSource: + def __init__(self, source): + self.source = source + + @property + def ast(self): + if "ast" in self.source: + return SolcAST(self.source["ast"]) + if "legacyAST" in self.source: + return SolcAST(self.source["legacyAST"]) + assert False, "Unknown source type has been fed to SolcSource" + + @property + def id(self): + return self.source["id"] + + @property + def name(self): + return self.source["name"] + + @property + def contents(self): + return self.source["contents"] + + class SourceMapping: def __init__(self, solidity_file_idx, offset, length, lineno, mapping): """Representation of a source mapping for a Solidity file.""" @@ -87,7 +146,7 @@ class SolidityContract(EVMContract): input_file, solc_settings_json=solc_settings_json, solc_binary=solc_binary ) - self.solc_indices = self.get_solc_indices(data) + self.solc_indices = self.get_solc_indices(input_file, data) self.solc_json = data self.input_file = input_file has_contract = False @@ -135,21 +194,22 @@ class SolidityContract(EVMContract): @staticmethod def get_sources(indices_data: Dict, source_data: Dict) -> None: """ - Get source indices mapping + Get source indices mapping. Function not needed for older solc versions. """ + if "generatedSources" not in source_data: return sources = source_data["generatedSources"] for source in sources: full_contract_src_maps = SolidityContract.get_full_contract_src_maps( - source["ast"] + SolcAST(source["ast"]) ) indices_data[source["id"]] = SolidityFile( source["name"], source["contents"], full_contract_src_maps ) @staticmethod - def get_solc_indices(data: Dict) -> Dict: + def get_solc_indices(input_file: str, data: Dict) -> Dict: """ Returns solc file indices """ @@ -161,31 +221,38 @@ class SolidityContract(EVMContract): indices, source_data["evm"]["deployedBytecode"] ) for source in data["sources"].values(): + source = SolcSource(source) full_contract_src_maps = SolidityContract.get_full_contract_src_maps( - source["ast"] + source.ast ) - with open(source["ast"]["absolutePath"], "rb") as f: + if source.ast.abs_path is not None: + abs_path = source.ast.abs_path + else: + abs_path = input_file + + with open(abs_path, "rb") as f: code = f.read() - indices[source["id"]] = SolidityFile( - source["ast"]["absolutePath"], + indices[source.id] = SolidityFile( + abs_path, code.decode("utf-8"), full_contract_src_maps, ) return indices @staticmethod - def get_full_contract_src_maps(ast: Dict) -> Set[str]: + def get_full_contract_src_maps(ast: SolcAST) -> Set[str]: """ Takes a solc AST and gets the src mappings for all the contracts defined in the top level of the ast :param ast: AST of the contract :return: The source maps """ + print source_maps = set() - if ast["nodeType"] == "SourceUnit": - for child in ast["nodes"]: + if ast.node_type == "SourceUnit": + for child in ast.nodes: if child.get("contractKind"): source_maps.add(child["src"]) - elif ast["nodeType"] == "YulBlock": + elif ast.node_type == "YulBlock": for child in ast["statements"]: source_maps.add(child["src"]) diff --git a/tests/integration_tests/old_version_test.py b/tests/integration_tests/old_version_test.py new file mode 100644 index 00000000..52b37517 --- /dev/null +++ b/tests/integration_tests/old_version_test.py @@ -0,0 +1,20 @@ +import pytest +import json +import sys + +from tests import PROJECT_DIR, TESTDATA +from utils import output_of + +MYTH = str(PROJECT_DIR / "myth") +test_data = ( + ("old_origin.sol", 1), + ("old_version.sol", 2), +) + + +@pytest.mark.parametrize("file_name, issues", test_data) +def test_analysis_old(file_name, issues): + file = str(TESTDATA / "input_contracts" / file_name) + command = f"python3 {MYTH} analyze {file} -o jsonv2" + output = json.loads(output_of(command)) + assert len(output[0]["issues"]) == issues diff --git a/tests/testdata/input_contracts/old_origin.sol b/tests/testdata/input_contracts/old_origin.sol new file mode 100644 index 00000000..020f31db --- /dev/null +++ b/tests/testdata/input_contracts/old_origin.sol @@ -0,0 +1,36 @@ +pragma solidity ^0.4.11; + + +contract Origin { + address public owner; + + + /** + * @dev The Ownable constructor sets the original `owner` of the contract to the sender + * account. + */ + function Origin() { + owner = msg.sender; + } + + + /** + * @dev Throws if called by any account other than the owner. + */ + modifier onlyOwner() { + require(tx.origin != owner); + _; + } + + + /** + * @dev Allows the current owner to transfer control of the contract to a newOwner. + * @param newOwner The address to transfer ownership to. + */ + function transferOwnership(address newOwner) public onlyOwner { + if (newOwner != address(0)) { + owner = newOwner; + } + } + +} diff --git a/tests/testdata/input_contracts/old_version.sol b/tests/testdata/input_contracts/old_version.sol new file mode 100644 index 00000000..95a27bc4 --- /dev/null +++ b/tests/testdata/input_contracts/old_version.sol @@ -0,0 +1,2 @@ +pragma solidity 0.4.11; +contract test { }