diff --git a/mythril/ethereum/util.py b/mythril/ethereum/util.py index abe8a99f..e1c64fce 100644 --- a/mythril/ethereum/util.py +++ b/mythril/ethereum/util.py @@ -50,7 +50,7 @@ def get_solc_json(file, solc_binary="solc", solc_settings_json=None): settings = json.load(f) settings.update( { - "optimizer": {"enabled": True}, + "optimizer": {"enabled": False}, "outputSelection": { "*": { "": ["ast"], diff --git a/mythril/laser/ethereum/function_managers/keccak_function_manager.py b/mythril/laser/ethereum/function_managers/keccak_function_manager.py index 9bac0b81..6cf7bb0f 100644 --- a/mythril/laser/ethereum/function_managers/keccak_function_manager.py +++ b/mythril/laser/ethereum/function_managers/keccak_function_manager.py @@ -14,6 +14,7 @@ from typing import Dict, Tuple, List, Optional import logging + TOTAL_PARTS = 10 ** 40 PART = (2 ** 256 - 1) // TOTAL_PARTS INTERVAL_DIFFERENCE = 10 ** 30 @@ -34,12 +35,22 @@ class KeccakFunctionManager: hash_matcher = "fffffff" # This is usually the prefix for the hash in the output def __init__(self): - self.store_function = {} # type: Dict[int, Tuple[Function, Function]] - self.interval_hook_for_size = {} # type: Dict[int, int] + self.store_function: Dict[int, Tuple[Function, Function]] = {} + self.interval_hook_for_size: Dict[int, int] = {} self._index_counter = TOTAL_PARTS - 34534 - self.hash_result_store = {} # type: Dict[int, List[BitVec]] - self.quick_inverse = {} # type: Dict[BitVec, BitVec] # This is for VMTests - self.concrete_hashes = {} # type: Dict[BitVec, BitVec] + self.hash_result_store: Dict[int, List[BitVec]] = {} + + self.quick_inverse: Dict[BitVec, BitVec] = {} # This is for VMTests + self.concrete_hashes: Dict[BitVec, BitVec] = {} + self.symbolic_inputs: Dict[int, List[BitVec]] = {} + + def reset(self): + self.store_function = {} + self.interval_hook_for_size = {} + self.hash_result_store: Dict[int, List[BitVec]] = {} + self.quick_inverse = {} + self.concrete_hashes = {} + self.symbolic_inputs = {} @staticmethod def find_concrete_keccak(data: BitVec) -> BitVec: @@ -81,7 +92,7 @@ class KeccakFunctionManager: val = 89477152217924674838424037953991966239322087453347756267410168184682657981552 return symbol_factory.BitVecVal(val, 256) - def create_keccak(self, data: BitVec) -> Tuple[BitVec, Bool]: + def create_keccak(self, data: BitVec) -> BitVec: """ Creates Keccak of the data :param data: input @@ -93,13 +104,31 @@ class KeccakFunctionManager: if data.symbolic is False: concrete_hash = self.find_concrete_keccak(data) self.concrete_hashes[data] = concrete_hash - # This condition is essential to avoid some edge cases - condition = And(func(data) == concrete_hash, inverse(func(data)) == data) - return concrete_hash, condition + return concrete_hash + + if length not in self.symbolic_inputs: + self.symbolic_inputs[length] = [] - condition = self._create_condition(func_input=data) + self.symbolic_inputs[length].append(data) self.hash_result_store[length].append(func(data)) - return func(data), condition + return func(data) + + def create_conditions(self) -> Bool: + condition = symbol_factory.Bool(True) + for inputs_list in self.symbolic_inputs.values(): + for symbolic_input in inputs_list: + condition = And( + condition, self._create_condition(func_input=symbolic_input) + ) + for concrete_input, concrete_hash in self.concrete_hashes.items(): + func, inverse = self.get_function(concrete_input.size()) + condition = And( + condition, + func(concrete_input) == concrete_hash, + inverse(func(concrete_input)) == concrete_input, + ) + + return condition def get_concrete_hash_data(self, model) -> Dict[int, List[Optional[int]]]: """ @@ -145,8 +174,9 @@ class KeccakFunctionManager: ) concrete_cond = symbol_factory.Bool(False) for key, keccak in self.concrete_hashes.items(): - hash_eq = And(func(func_input) == keccak, key == func_input) - concrete_cond = Or(concrete_cond, hash_eq) + if key.size() == func_input.size(): + hash_eq = And(func(func_input) == keccak, key == func_input) + concrete_cond = Or(concrete_cond, hash_eq) return And(inv(func(func_input)) == func_input, Or(cond, concrete_cond)) diff --git a/mythril/laser/ethereum/instructions.py b/mythril/laser/ethereum/instructions.py index adc3be60..4c68eb33 100644 --- a/mythril/laser/ethereum/instructions.py +++ b/mythril/laser/ethereum/instructions.py @@ -1017,9 +1017,8 @@ class Instruction: state.stack.append(result) return [global_state] - result, condition = keccak_function_manager.create_keccak(data) + result = keccak_function_manager.create_keccak(data) state.stack.append(result) - global_state.world_state.constraints.append(condition) return [global_state] @@ -1538,6 +1537,7 @@ class Instruction: states = [] op0, condition = state.stack.pop(), state.stack.pop() + try: jump_addr = util.get_concrete_int(op0) except TypeError: @@ -1740,7 +1740,7 @@ class Instruction: if create2_salt.size() != 256: pad = symbol_factory.BitVecVal(0, 256 - create2_salt.size()) create2_salt = Concat(pad, create2_salt) - address, constraint = keccak_function_manager.create_keccak( + address = keccak_function_manager.create_keccak( Concat( symbol_factory.BitVecVal(255, 8), caller, @@ -1749,7 +1749,7 @@ class Instruction: ) ) contract_address = Extract(255, 96, address) - global_state.world_state.constraints.append(constraint) + else: salt = hex(create2_salt.value)[2:] salt = "0" * (64 - len(salt)) + salt diff --git a/mythril/laser/ethereum/state/constraints.py b/mythril/laser/ethereum/state/constraints.py index d9fc55cc..4ca59bf3 100644 --- a/mythril/laser/ethereum/state/constraints.py +++ b/mythril/laser/ethereum/state/constraints.py @@ -4,6 +4,7 @@ from mythril.exceptions import UnsatError from mythril.laser.smt import symbol_factory, simplify, Bool from mythril.support.model import get_model from typing import Iterable, List, Optional, Union +from mythril.laser.ethereum.function_managers import keccak_function_manager class Constraints(list): @@ -29,7 +30,7 @@ class Constraints(list): """ try: - get_model(tuple(self[:])) + get_model(self) except UnsatError: return False return True @@ -108,5 +109,8 @@ class Constraints(list): for constraint in constraints ] + def get_all_constraints(self): + return self[:] + [keccak_function_manager.create_conditions()] + def __hash__(self): return tuple(self[:]).__hash__() diff --git a/mythril/laser/ethereum/svm.py b/mythril/laser/ethereum/svm.py index 917f123c..2e467f43 100644 --- a/mythril/laser/ethereum/svm.py +++ b/mythril/laser/ethereum/svm.py @@ -255,12 +255,14 @@ class LaserEVM: except NotImplementedError: log.debug("Encountered unimplemented instruction") continue + if args.sparse_pruning is False: new_states = [ state for state in new_states if state.world_state.constraints.is_possible ] + self.manage_cfg(op_code, new_states) # TODO: What about op_code is None? if new_states: self.work_list += new_states diff --git a/mythril/support/model.py b/mythril/support/model.py index ee91ff7a..2c6a8fe6 100644 --- a/mythril/support/model.py +++ b/mythril/support/model.py @@ -32,7 +32,7 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True for constraint in constraints: if type(constraint) == bool and not constraint: raise UnsatError - + constraints = constraints.get_all_constraints() constraints = [constraint for constraint in constraints if type(constraint) != bool] for constraint in constraints: diff --git a/tests/integration_tests/test_safe_functions.py b/tests/integration_tests/test_safe_functions.py index 34d8db14..f6538cea 100644 --- a/tests/integration_tests/test_safe_functions.py +++ b/tests/integration_tests/test_safe_functions.py @@ -8,10 +8,16 @@ from tests import PROJECT_DIR, TESTDATA MYTH = str(PROJECT_DIR / "myth") test_data = ( ("suicide.sol", [], "0.5.0"), - ("overflow.sol", ["balanceOf(address)"], "0.5.0"), + ("overflow.sol", ["balanceOf(address)", "totalSupply()"], "0.5.0"), ( "ether_send.sol", - ["crowdfunding()", "withdrawfunds()", "owner()", "balances(address)"], + [ + "crowdfunding()", + "withdrawfunds()", + "owner()", + "balances(address)", + "getBalance()", + ], "0.5.0", ), ) diff --git a/tests/laser/keccak_tests.py b/tests/laser/keccak_tests.py index 0f83346d..6dabc37a 100644 --- a/tests/laser/keccak_tests.py +++ b/tests/laser/keccak_tests.py @@ -29,10 +29,10 @@ import pytest ) def test_keccak_basic(input1, input2, expected): s = Solver() - - o1, c1 = keccak_function_manager.create_keccak(input1) - o2, c2 = keccak_function_manager.create_keccak(input2) - s.add(And(c1, c2)) + keccak_function_manager.reset() + o1 = keccak_function_manager.create_keccak(input1) + o2 = keccak_function_manager.create_keccak(input2) + s.add(keccak_function_manager.create_conditions()) s.add(o1 == o2) assert s.check() == expected @@ -44,11 +44,13 @@ def test_keccak_symbol_and_val(): :return: """ s = Solver() + keccak_function_manager.reset() hundred = symbol_factory.BitVecVal(100, 256) n = symbol_factory.BitVecSym("n", 256) - o1, c1 = keccak_function_manager.create_keccak(hundred) - o2, c2 = keccak_function_manager.create_keccak(n) - s.add(And(c1, c2)) + o1 = keccak_function_manager.create_keccak(hundred) + o2 = keccak_function_manager.create_keccak(n) + + s.add(keccak_function_manager.create_conditions()) s.add(o1 == o2) s.add(n == symbol_factory.BitVecVal(10, 256)) assert s.check() == z3.unsat @@ -59,19 +61,20 @@ def test_keccak_complex_eq(): check for keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b :return: """ + keccak_function_manager.reset() s = Solver() a = symbol_factory.BitVecSym("a", 160) b = symbol_factory.BitVecSym("b", 160) - o1, c1 = keccak_function_manager.create_keccak(a) - o2, c2 = keccak_function_manager.create_keccak(b) - s.add(And(c1, c2)) + o1 = keccak_function_manager.create_keccak(a) + o2 = keccak_function_manager.create_keccak(b) + two = symbol_factory.BitVecVal(2, 256) o1 = two * o1 o2 = two * o2 - o1, c1 = keccak_function_manager.create_keccak(o1) - o2, c2 = keccak_function_manager.create_keccak(o2) + o1 = keccak_function_manager.create_keccak(o1) + o2 = keccak_function_manager.create_keccak(o2) - s.add(And(c1, c2)) + s.add(keccak_function_manager.create_conditions()) s.add(o1 == o2) s.add(a != b) @@ -85,19 +88,20 @@ def test_keccak_complex_eq2(): (solution is literally the opposite of prev one) so it will take forever to solve. :return: """ + keccak_function_manager.reset() s = Solver() a = symbol_factory.BitVecSym("a", 160) b = symbol_factory.BitVecSym("b", 160) - o1, c1 = keccak_function_manager.create_keccak(a) - o2, c2 = keccak_function_manager.create_keccak(b) - s.add(And(c1, c2)) + o1 = keccak_function_manager.create_keccak(a) + o2 = keccak_function_manager.create_keccak(b) + two = symbol_factory.BitVecVal(2, 256) o1 = two * o1 o2 = two * o2 - o1, c1 = keccak_function_manager.create_keccak(o1) - o2, c2 = keccak_function_manager.create_keccak(o2) + o1 = keccak_function_manager.create_keccak(o1) + o2 = keccak_function_manager.create_keccak(o2) - s.add(And(c1, c2)) + s.add(keccak_function_manager.create_conditions()) s.add(o1 == o2) assert s.check() == z3.sat @@ -108,12 +112,13 @@ def test_keccak_simple_number(): check for keccak(b) == 10 :return: """ + keccak_function_manager.reset() s = Solver() a = symbol_factory.BitVecSym("a", 160) ten = symbol_factory.BitVecVal(10, 256) - o, c = keccak_function_manager.create_keccak(a) + o = keccak_function_manager.create_keccak(a) - s.add(c) + s.add(keccak_function_manager.create_conditions()) s.add(ten == o) assert s.check() == z3.unsat @@ -124,15 +129,17 @@ def test_keccak_other_num(): check keccak(keccak(a)*2) == b :return: """ + keccak_function_manager.reset() s = Solver() a = symbol_factory.BitVecSym("a", 160) b = symbol_factory.BitVecSym("b", 256) - o, c = keccak_function_manager.create_keccak(a) + o = keccak_function_manager.create_keccak(a) two = symbol_factory.BitVecVal(2, 256) o = two * o - s.add(c) - o, c = keccak_function_manager.create_keccak(o) - s.add(c) + + o = keccak_function_manager.create_keccak(o) + + s.add(keccak_function_manager.create_conditions()) s.add(b == o) assert s.check() == z3.sat diff --git a/tests/solidity_contract_test.py b/tests/solidity_contract_test.py index 9e3976a4..bde93c20 100644 --- a/tests/solidity_contract_test.py +++ b/tests/solidity_contract_test.py @@ -36,7 +36,7 @@ class SolidityContractTest(BaseTestCase): str(input_file), name="AssertFail", solc_binary=solc_binary ) - code_info = contract.get_source_info(58, constructor=True) + code_info = contract.get_source_info(75, constructor=True) self.assertEqual(code_info.filename, str(input_file)) self.assertEqual(code_info.lineno, 6)