Fixes issues by considering future concrete hashes (#1547)

* Fix issues with hashing

* Fix offsets and edge cases
pull/1549/head
Nikhil Parasaram 3 years ago committed by GitHub
parent f35f1df509
commit 9d7873621e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      mythril/ethereum/util.py
  2. 52
      mythril/laser/ethereum/function_managers/keccak_function_manager.py
  3. 8
      mythril/laser/ethereum/instructions.py
  4. 6
      mythril/laser/ethereum/state/constraints.py
  5. 2
      mythril/laser/ethereum/svm.py
  6. 2
      mythril/support/model.py
  7. 10
      tests/integration_tests/test_safe_functions.py
  8. 57
      tests/laser/keccak_tests.py
  9. 2
      tests/solidity_contract_test.py

@ -50,7 +50,7 @@ def get_solc_json(file, solc_binary="solc", solc_settings_json=None):
settings = json.load(f) settings = json.load(f)
settings.update( settings.update(
{ {
"optimizer": {"enabled": True}, "optimizer": {"enabled": False},
"outputSelection": { "outputSelection": {
"*": { "*": {
"": ["ast"], "": ["ast"],

@ -14,6 +14,7 @@ from typing import Dict, Tuple, List, Optional
import logging import logging
TOTAL_PARTS = 10 ** 40 TOTAL_PARTS = 10 ** 40
PART = (2 ** 256 - 1) // TOTAL_PARTS PART = (2 ** 256 - 1) // TOTAL_PARTS
INTERVAL_DIFFERENCE = 10 ** 30 INTERVAL_DIFFERENCE = 10 ** 30
@ -34,12 +35,22 @@ class KeccakFunctionManager:
hash_matcher = "fffffff" # This is usually the prefix for the hash in the output hash_matcher = "fffffff" # This is usually the prefix for the hash in the output
def __init__(self): def __init__(self):
self.store_function = {} # type: Dict[int, Tuple[Function, Function]] self.store_function: Dict[int, Tuple[Function, Function]] = {}
self.interval_hook_for_size = {} # type: Dict[int, int] self.interval_hook_for_size: Dict[int, int] = {}
self._index_counter = TOTAL_PARTS - 34534 self._index_counter = TOTAL_PARTS - 34534
self.hash_result_store = {} # type: Dict[int, List[BitVec]] self.hash_result_store: Dict[int, List[BitVec]] = {}
self.quick_inverse = {} # type: Dict[BitVec, BitVec] # This is for VMTests
self.concrete_hashes = {} # type: Dict[BitVec, BitVec] self.quick_inverse: Dict[BitVec, BitVec] = {} # This is for VMTests
self.concrete_hashes: Dict[BitVec, BitVec] = {}
self.symbolic_inputs: Dict[int, List[BitVec]] = {}
def reset(self):
self.store_function = {}
self.interval_hook_for_size = {}
self.hash_result_store: Dict[int, List[BitVec]] = {}
self.quick_inverse = {}
self.concrete_hashes = {}
self.symbolic_inputs = {}
@staticmethod @staticmethod
def find_concrete_keccak(data: BitVec) -> BitVec: def find_concrete_keccak(data: BitVec) -> BitVec:
@ -81,7 +92,7 @@ class KeccakFunctionManager:
val = 89477152217924674838424037953991966239322087453347756267410168184682657981552 val = 89477152217924674838424037953991966239322087453347756267410168184682657981552
return symbol_factory.BitVecVal(val, 256) return symbol_factory.BitVecVal(val, 256)
def create_keccak(self, data: BitVec) -> Tuple[BitVec, Bool]: def create_keccak(self, data: BitVec) -> BitVec:
""" """
Creates Keccak of the data Creates Keccak of the data
:param data: input :param data: input
@ -93,13 +104,31 @@ class KeccakFunctionManager:
if data.symbolic is False: if data.symbolic is False:
concrete_hash = self.find_concrete_keccak(data) concrete_hash = self.find_concrete_keccak(data)
self.concrete_hashes[data] = concrete_hash self.concrete_hashes[data] = concrete_hash
# This condition is essential to avoid some edge cases return concrete_hash
condition = And(func(data) == concrete_hash, inverse(func(data)) == data)
return concrete_hash, condition if length not in self.symbolic_inputs:
self.symbolic_inputs[length] = []
condition = self._create_condition(func_input=data) self.symbolic_inputs[length].append(data)
self.hash_result_store[length].append(func(data)) self.hash_result_store[length].append(func(data))
return func(data), condition return func(data)
def create_conditions(self) -> Bool:
condition = symbol_factory.Bool(True)
for inputs_list in self.symbolic_inputs.values():
for symbolic_input in inputs_list:
condition = And(
condition, self._create_condition(func_input=symbolic_input)
)
for concrete_input, concrete_hash in self.concrete_hashes.items():
func, inverse = self.get_function(concrete_input.size())
condition = And(
condition,
func(concrete_input) == concrete_hash,
inverse(func(concrete_input)) == concrete_input,
)
return condition
def get_concrete_hash_data(self, model) -> Dict[int, List[Optional[int]]]: def get_concrete_hash_data(self, model) -> Dict[int, List[Optional[int]]]:
""" """
@ -145,6 +174,7 @@ class KeccakFunctionManager:
) )
concrete_cond = symbol_factory.Bool(False) concrete_cond = symbol_factory.Bool(False)
for key, keccak in self.concrete_hashes.items(): for key, keccak in self.concrete_hashes.items():
if key.size() == func_input.size():
hash_eq = And(func(func_input) == keccak, key == func_input) hash_eq = And(func(func_input) == keccak, key == func_input)
concrete_cond = Or(concrete_cond, hash_eq) concrete_cond = Or(concrete_cond, hash_eq)
return And(inv(func(func_input)) == func_input, Or(cond, concrete_cond)) return And(inv(func(func_input)) == func_input, Or(cond, concrete_cond))

@ -1017,9 +1017,8 @@ class Instruction:
state.stack.append(result) state.stack.append(result)
return [global_state] return [global_state]
result, condition = keccak_function_manager.create_keccak(data) result = keccak_function_manager.create_keccak(data)
state.stack.append(result) state.stack.append(result)
global_state.world_state.constraints.append(condition)
return [global_state] return [global_state]
@ -1538,6 +1537,7 @@ class Instruction:
states = [] states = []
op0, condition = state.stack.pop(), state.stack.pop() op0, condition = state.stack.pop(), state.stack.pop()
try: try:
jump_addr = util.get_concrete_int(op0) jump_addr = util.get_concrete_int(op0)
except TypeError: except TypeError:
@ -1740,7 +1740,7 @@ class Instruction:
if create2_salt.size() != 256: if create2_salt.size() != 256:
pad = symbol_factory.BitVecVal(0, 256 - create2_salt.size()) pad = symbol_factory.BitVecVal(0, 256 - create2_salt.size())
create2_salt = Concat(pad, create2_salt) create2_salt = Concat(pad, create2_salt)
address, constraint = keccak_function_manager.create_keccak( address = keccak_function_manager.create_keccak(
Concat( Concat(
symbol_factory.BitVecVal(255, 8), symbol_factory.BitVecVal(255, 8),
caller, caller,
@ -1749,7 +1749,7 @@ class Instruction:
) )
) )
contract_address = Extract(255, 96, address) contract_address = Extract(255, 96, address)
global_state.world_state.constraints.append(constraint)
else: else:
salt = hex(create2_salt.value)[2:] salt = hex(create2_salt.value)[2:]
salt = "0" * (64 - len(salt)) + salt salt = "0" * (64 - len(salt)) + salt

@ -4,6 +4,7 @@ from mythril.exceptions import UnsatError
from mythril.laser.smt import symbol_factory, simplify, Bool from mythril.laser.smt import symbol_factory, simplify, Bool
from mythril.support.model import get_model from mythril.support.model import get_model
from typing import Iterable, List, Optional, Union from typing import Iterable, List, Optional, Union
from mythril.laser.ethereum.function_managers import keccak_function_manager
class Constraints(list): class Constraints(list):
@ -29,7 +30,7 @@ class Constraints(list):
""" """
try: try:
get_model(tuple(self[:])) get_model(self)
except UnsatError: except UnsatError:
return False return False
return True return True
@ -108,5 +109,8 @@ class Constraints(list):
for constraint in constraints for constraint in constraints
] ]
def get_all_constraints(self):
return self[:] + [keccak_function_manager.create_conditions()]
def __hash__(self): def __hash__(self):
return tuple(self[:]).__hash__() return tuple(self[:]).__hash__()

@ -255,12 +255,14 @@ class LaserEVM:
except NotImplementedError: except NotImplementedError:
log.debug("Encountered unimplemented instruction") log.debug("Encountered unimplemented instruction")
continue continue
if args.sparse_pruning is False: if args.sparse_pruning is False:
new_states = [ new_states = [
state state
for state in new_states for state in new_states
if state.world_state.constraints.is_possible if state.world_state.constraints.is_possible
] ]
self.manage_cfg(op_code, new_states) # TODO: What about op_code is None? self.manage_cfg(op_code, new_states) # TODO: What about op_code is None?
if new_states: if new_states:
self.work_list += new_states self.work_list += new_states

@ -32,7 +32,7 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True
for constraint in constraints: for constraint in constraints:
if type(constraint) == bool and not constraint: if type(constraint) == bool and not constraint:
raise UnsatError raise UnsatError
constraints = constraints.get_all_constraints()
constraints = [constraint for constraint in constraints if type(constraint) != bool] constraints = [constraint for constraint in constraints if type(constraint) != bool]
for constraint in constraints: for constraint in constraints:

@ -8,10 +8,16 @@ from tests import PROJECT_DIR, TESTDATA
MYTH = str(PROJECT_DIR / "myth") MYTH = str(PROJECT_DIR / "myth")
test_data = ( test_data = (
("suicide.sol", [], "0.5.0"), ("suicide.sol", [], "0.5.0"),
("overflow.sol", ["balanceOf(address)"], "0.5.0"), ("overflow.sol", ["balanceOf(address)", "totalSupply()"], "0.5.0"),
( (
"ether_send.sol", "ether_send.sol",
["crowdfunding()", "withdrawfunds()", "owner()", "balances(address)"], [
"crowdfunding()",
"withdrawfunds()",
"owner()",
"balances(address)",
"getBalance()",
],
"0.5.0", "0.5.0",
), ),
) )

@ -29,10 +29,10 @@ import pytest
) )
def test_keccak_basic(input1, input2, expected): def test_keccak_basic(input1, input2, expected):
s = Solver() s = Solver()
keccak_function_manager.reset()
o1, c1 = keccak_function_manager.create_keccak(input1) o1 = keccak_function_manager.create_keccak(input1)
o2, c2 = keccak_function_manager.create_keccak(input2) o2 = keccak_function_manager.create_keccak(input2)
s.add(And(c1, c2)) s.add(keccak_function_manager.create_conditions())
s.add(o1 == o2) s.add(o1 == o2)
assert s.check() == expected assert s.check() == expected
@ -44,11 +44,13 @@ def test_keccak_symbol_and_val():
:return: :return:
""" """
s = Solver() s = Solver()
keccak_function_manager.reset()
hundred = symbol_factory.BitVecVal(100, 256) hundred = symbol_factory.BitVecVal(100, 256)
n = symbol_factory.BitVecSym("n", 256) n = symbol_factory.BitVecSym("n", 256)
o1, c1 = keccak_function_manager.create_keccak(hundred) o1 = keccak_function_manager.create_keccak(hundred)
o2, c2 = keccak_function_manager.create_keccak(n) o2 = keccak_function_manager.create_keccak(n)
s.add(And(c1, c2))
s.add(keccak_function_manager.create_conditions())
s.add(o1 == o2) s.add(o1 == o2)
s.add(n == symbol_factory.BitVecVal(10, 256)) s.add(n == symbol_factory.BitVecVal(10, 256))
assert s.check() == z3.unsat assert s.check() == z3.unsat
@ -59,19 +61,20 @@ def test_keccak_complex_eq():
check for keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b check for keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b
:return: :return:
""" """
keccak_function_manager.reset()
s = Solver() s = Solver()
a = symbol_factory.BitVecSym("a", 160) a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 160) b = symbol_factory.BitVecSym("b", 160)
o1, c1 = keccak_function_manager.create_keccak(a) o1 = keccak_function_manager.create_keccak(a)
o2, c2 = keccak_function_manager.create_keccak(b) o2 = keccak_function_manager.create_keccak(b)
s.add(And(c1, c2))
two = symbol_factory.BitVecVal(2, 256) two = symbol_factory.BitVecVal(2, 256)
o1 = two * o1 o1 = two * o1
o2 = two * o2 o2 = two * o2
o1, c1 = keccak_function_manager.create_keccak(o1) o1 = keccak_function_manager.create_keccak(o1)
o2, c2 = keccak_function_manager.create_keccak(o2) o2 = keccak_function_manager.create_keccak(o2)
s.add(And(c1, c2)) s.add(keccak_function_manager.create_conditions())
s.add(o1 == o2) s.add(o1 == o2)
s.add(a != b) s.add(a != b)
@ -85,19 +88,20 @@ def test_keccak_complex_eq2():
(solution is literally the opposite of prev one) so it will take forever to solve. (solution is literally the opposite of prev one) so it will take forever to solve.
:return: :return:
""" """
keccak_function_manager.reset()
s = Solver() s = Solver()
a = symbol_factory.BitVecSym("a", 160) a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 160) b = symbol_factory.BitVecSym("b", 160)
o1, c1 = keccak_function_manager.create_keccak(a) o1 = keccak_function_manager.create_keccak(a)
o2, c2 = keccak_function_manager.create_keccak(b) o2 = keccak_function_manager.create_keccak(b)
s.add(And(c1, c2))
two = symbol_factory.BitVecVal(2, 256) two = symbol_factory.BitVecVal(2, 256)
o1 = two * o1 o1 = two * o1
o2 = two * o2 o2 = two * o2
o1, c1 = keccak_function_manager.create_keccak(o1) o1 = keccak_function_manager.create_keccak(o1)
o2, c2 = keccak_function_manager.create_keccak(o2) o2 = keccak_function_manager.create_keccak(o2)
s.add(And(c1, c2)) s.add(keccak_function_manager.create_conditions())
s.add(o1 == o2) s.add(o1 == o2)
assert s.check() == z3.sat assert s.check() == z3.sat
@ -108,12 +112,13 @@ def test_keccak_simple_number():
check for keccak(b) == 10 check for keccak(b) == 10
:return: :return:
""" """
keccak_function_manager.reset()
s = Solver() s = Solver()
a = symbol_factory.BitVecSym("a", 160) a = symbol_factory.BitVecSym("a", 160)
ten = symbol_factory.BitVecVal(10, 256) ten = symbol_factory.BitVecVal(10, 256)
o, c = keccak_function_manager.create_keccak(a) o = keccak_function_manager.create_keccak(a)
s.add(c) s.add(keccak_function_manager.create_conditions())
s.add(ten == o) s.add(ten == o)
assert s.check() == z3.unsat assert s.check() == z3.unsat
@ -124,15 +129,17 @@ def test_keccak_other_num():
check keccak(keccak(a)*2) == b check keccak(keccak(a)*2) == b
:return: :return:
""" """
keccak_function_manager.reset()
s = Solver() s = Solver()
a = symbol_factory.BitVecSym("a", 160) a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 256) b = symbol_factory.BitVecSym("b", 256)
o, c = keccak_function_manager.create_keccak(a) o = keccak_function_manager.create_keccak(a)
two = symbol_factory.BitVecVal(2, 256) two = symbol_factory.BitVecVal(2, 256)
o = two * o o = two * o
s.add(c)
o, c = keccak_function_manager.create_keccak(o) o = keccak_function_manager.create_keccak(o)
s.add(c)
s.add(keccak_function_manager.create_conditions())
s.add(b == o) s.add(b == o)
assert s.check() == z3.sat assert s.check() == z3.sat

@ -36,7 +36,7 @@ class SolidityContractTest(BaseTestCase):
str(input_file), name="AssertFail", solc_binary=solc_binary str(input_file), name="AssertFail", solc_binary=solc_binary
) )
code_info = contract.get_source_info(58, constructor=True) code_info = contract.get_source_info(75, constructor=True)
self.assertEqual(code_info.filename, str(input_file)) self.assertEqual(code_info.filename, str(input_file))
self.assertEqual(code_info.lineno, 6) self.assertEqual(code_info.lineno, 6)

Loading…
Cancel
Save