"""Tests for the symbolic keccak modelling in ``keccak_function_manager``."""
import pytest
import z3

from mythril.laser.ethereum.function_managers import keccak_function_manager
from mythril.laser.smt import Solver, symbol_factory


def _keccak_of_doubled_keccak(*values):
    """Return ``[keccak(2 * keccak(v)) for v in values]``.

    Every inner hash is registered with the manager (in argument order)
    before any outer hash is created, so the sequence of ``create_keccak``
    calls is: inner hashes first, then outer hashes.
    """
    inner_hashes = [keccak_function_manager.create_keccak(v) for v in values]
    two = symbol_factory.BitVecVal(2, 256)
    return [keccak_function_manager.create_keccak(two * h) for h in inner_hashes]


@pytest.mark.parametrize(
    "input1, input2, expected",
    [
        # different concrete values, same width
        (symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(101, 8), z3.unsat),
        # same concrete value, different widths
        (symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 16), z3.unsat),
        # identical concrete inputs
        (symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 8), z3.sat),
        # two unconstrained symbols of the same width
        (
            symbol_factory.BitVecSym("N1", 256),
            symbol_factory.BitVecSym("N2", 256),
            z3.sat,
        ),
        # concrete value vs. unconstrained symbol of the same width
        (
            symbol_factory.BitVecVal(100, 256),
            symbol_factory.BitVecSym("N1", 256),
            z3.sat,
        ),
        # concrete value vs. symbol of a different width
        (
            symbol_factory.BitVecVal(100, 8),
            symbol_factory.BitVecSym("N1", 256),
            z3.unsat,
        ),
    ],
)
def test_keccak_basic(input1, input2, expected):
    """Check satisfiability of keccak(input1) == keccak(input2)."""
    keccak_function_manager.reset()
    s = Solver()

    o1 = keccak_function_manager.create_keccak(input1)
    o2 = keccak_function_manager.create_keccak(input2)
    s.add(keccak_function_manager.create_conditions())
    s.add(o1 == o2)

    assert s.check() == expected


def test_keccak_symbol_and_val():
    """Check keccak(100) == keccak(n) && n == 10 is unsatisfiable."""
    keccak_function_manager.reset()
    s = Solver()

    hundred = symbol_factory.BitVecVal(100, 256)
    n = symbol_factory.BitVecSym("n", 256)
    o1 = keccak_function_manager.create_keccak(hundred)
    o2 = keccak_function_manager.create_keccak(n)

    s.add(keccak_function_manager.create_conditions())
    s.add(o1 == o2)
    s.add(n == symbol_factory.BitVecVal(10, 256))

    assert s.check() == z3.unsat


def test_keccak_complex_eq():
    """Check keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b is unsat."""
    keccak_function_manager.reset()
    s = Solver()

    a = symbol_factory.BitVecSym("a", 160)
    b = symbol_factory.BitVecSym("b", 160)
    o1, o2 = _keccak_of_doubled_keccak(a, b)

    s.add(keccak_function_manager.create_conditions())
    s.add(o1 == o2)
    s.add(a != b)

    assert s.check() == z3.unsat


def test_keccak_complex_eq2():
    """Check keccak(keccak(b)*2) == keccak(keccak(a)*2) is satisfiable.

    This is deliberately not merged with ``test_keccak_complex_eq``: solving
    both incrementally on one solver means finding a model that is the
    opposite of the previous one, which takes a very long time.
    """
    keccak_function_manager.reset()
    s = Solver()

    a = symbol_factory.BitVecSym("a", 160)
    b = symbol_factory.BitVecSym("b", 160)
    o1, o2 = _keccak_of_doubled_keccak(a, b)

    s.add(keccak_function_manager.create_conditions())
    s.add(o1 == o2)

    assert s.check() == z3.sat


def test_keccak_simple_number():
    """Check keccak(a) == 10 is unsatisfiable."""
    keccak_function_manager.reset()
    s = Solver()

    a = symbol_factory.BitVecSym("a", 160)
    ten = symbol_factory.BitVecVal(10, 256)
    o = keccak_function_manager.create_keccak(a)

    s.add(keccak_function_manager.create_conditions())
    s.add(ten == o)

    assert s.check() == z3.unsat


def test_keccak_other_num():
    """Check keccak(keccak(a)*2) == b is satisfiable for a free symbol b."""
    keccak_function_manager.reset()
    s = Solver()

    a = symbol_factory.BitVecSym("a", 160)
    b = symbol_factory.BitVecSym("b", 256)
    (o,) = _keccak_of_doubled_keccak(a)

    s.add(keccak_function_manager.create_conditions())
    s.add(b == o)

    assert s.check() == z3.sat