Merge branch 'develop' into specify-attacker-creator-address

specify-attacker-creator-address
Nathan 5 years ago committed by GitHub
commit ed06c3148e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      mythril/analysis/modules/exceptions.py
  2. 4
      mythril/analysis/modules/suicide.py
  3. 58
      mythril/analysis/solver.py
  4. 3
      mythril/laser/ethereum/cfg.py
  5. 38
      mythril/laser/ethereum/instructions.py
  6. 117
      mythril/laser/ethereum/keccak_function_manager.py
  7. 31
      mythril/laser/ethereum/state/account.py
  8. 5
      mythril/laser/ethereum/state/constraints.py
  9. 7
      mythril/laser/ethereum/strategy/extensions/bounded_loops.py
  10. 7
      mythril/laser/ethereum/svm.py
  11. 3
      mythril/laser/ethereum/transaction/concolic.py
  12. 2
      mythril/laser/ethereum/transaction/symbolic.py
  13. 22
      mythril/laser/smt/__init__.py
  14. 5
      mythril/laser/smt/bitvec_helper.py
  15. 25
      mythril/laser/smt/function.py
  16. 4
      mythril/laser/smt/solver/solver.py
  17. 5
      tests/instructions/create_test.py
  18. 11
      tests/laser/evm_testsuite/evm_test.py
  19. 138
      tests/laser/keccak_tests.py

@ -55,11 +55,9 @@ class ReachableExceptionsModule(DetectionModule):
"Note that explicit `assert()` should only be used to check invariants. "
"Use `require()` for regular input checking."
)
transaction_sequence = solver.get_transaction_sequence(
state, state.mstate.constraints
)
issue = Issue(
contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name,

@ -9,7 +9,7 @@ from mythril.laser.ethereum.transaction.transaction_models import (
ContractCreationTransaction,
)
import logging
import json
log = logging.getLogger(__name__)
@ -69,7 +69,6 @@ class SuicideModule(DetectionModule):
for tx in state.world_state.transaction_sequence:
if not isinstance(tx, ContractCreationTransaction):
constraints.append(tx.caller == ACTORS.attacker)
try:
try:
transaction_sequence = solver.get_transaction_sequence(
@ -85,7 +84,6 @@ class SuicideModule(DetectionModule):
state, state.mstate.constraints + constraints
)
description_tail = "Arbitrary senders can kill this contract."
issue = Issue(
contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name,

@ -1,12 +1,16 @@
"""This module contains analysis module helpers to solve path constraints."""
from functools import lru_cache
from typing import Dict, Tuple, Union
from typing import Dict, List, Tuple, Union
from z3 import sat, unknown, FuncInterp
import z3
from mythril.analysis.analysis_args import analysis_args
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.ethereum.keccak_function_manager import (
keccak_function_manager,
hash_matcher,
)
from mythril.laser.ethereum.transaction import BaseTransaction
from mythril.laser.smt import UGE, Optimize, symbol_factory
from mythril.laser.ethereum.time_handler import time_handler
@ -18,6 +22,7 @@ import logging
log = logging.getLogger(__name__)
# LRU cache works great when used in powers of 2
@lru_cache(maxsize=2 ** 23)
def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True):
@ -48,7 +53,6 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True
s.minimize(e)
for e in maximize:
s.maximize(e)
result = s.check()
if result == sat:
return s.model()
@ -97,7 +101,6 @@ def get_transaction_sequence(
tx_constraints, minimize = _set_minimisation_constraints(
transaction_sequence, constraints.copy(), [], 5000, global_state.world_state
)
try:
model = get_model(tx_constraints, minimize=minimize)
except UnsatError:
@ -122,12 +125,59 @@ def get_transaction_sequence(
).as_long()
concrete_initial_state = _get_concrete_state(initial_accounts, min_price_dict)
if isinstance(transaction_sequence[0], ContractCreationTransaction):
code = transaction_sequence[0].code
_replace_with_actual_sha(concrete_transactions, model, code)
else:
_replace_with_actual_sha(concrete_transactions, model)
steps = {"initialState": concrete_initial_state, "steps": concrete_transactions}
return steps
def _replace_with_actual_sha(
concrete_transactions: List[Dict[str, str]], model: z3.Model, code=None
):
for tx in concrete_transactions:
if hash_matcher not in tx["input"]:
continue
if code is not None and code.bytecode in tx["input"]:
s_index = len(code.bytecode) + 2
else:
s_index = 10
for i in range(s_index, len(tx["input"])):
data_slice = tx["input"][i : i + 64]
if hash_matcher not in data_slice or len(data_slice) != 64:
continue
find_input = symbol_factory.BitVecVal(int(data_slice, 16), 256)
input_ = None
for size in keccak_function_manager.store_function:
_, inverse = keccak_function_manager.get_function(size)
try:
input_ = symbol_factory.BitVecVal(
model.eval(inverse(find_input).raw).as_long(), size
)
except AttributeError:
continue
hex_input = hex(input_.value)[2:]
found = False
for new_tx in concrete_transactions:
if hex_input in new_tx["input"]:
found = True
break
if found:
break
if input_ is None:
continue
keccak = keccak_function_manager.find_concrete_keccak(input_)
hex_keccak = hex(keccak.value)[2:]
if len(hex_keccak) != 64:
hex_keccak = "0" * (64 - len(hex_keccak)) + hex_keccak
tx["input"] = tx["input"][:s_index] + tx["input"][s_index:].replace(
tx["input"][i : 64 + i], hex_keccak
)
def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]):
""" Gets a concrete state """
accounts = {}

@ -2,6 +2,7 @@
from enum import Enum
from typing import Dict, List, TYPE_CHECKING
from mythril.laser.ethereum.state.constraints import Constraints
from flags import Flags
if TYPE_CHECKING:
@ -46,7 +47,7 @@ class Node:
:param start_addr:
:param constraints:
"""
constraints = constraints if constraints else []
constraints = constraints if constraints else Constraints()
self.contract_name = contract_name
self.start_addr = start_addr
self.states = [] # type: List[GlobalState]

@ -34,6 +34,7 @@ from mythril.laser.ethereum.state.calldata import ConcreteCalldata, SymbolicCall
import mythril.laser.ethereum.util as helper
from mythril.laser.ethereum import util
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
from mythril.laser.ethereum.call import get_call_parameters, native_call, get_call_data
from mythril.laser.ethereum.evm_exceptions import (
VmException,
@ -982,7 +983,7 @@ class Instruction:
if isinstance(op0, Expression):
op0 = simplify(op0)
state.stack.append(
symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256)
symbol_factory.BitVecSym("KECCAC_mem[{}]".format(hash(op0)), 256)
)
gas_tuple = get_opcode_gas("SHA3")
state.min_gas_used += gas_tuple[0]
@ -996,40 +997,21 @@ class Instruction:
b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8)
for b in state.memory[index : index + length]
]
if len(data_list) > 1:
data = simplify(Concat(data_list))
elif len(data_list) == 1:
data = data_list[0]
else:
# length is 0; this only matters for input of the BitVecFuncVal
data = symbol_factory.BitVecVal(0, 1)
if data.symbolic:
annotations = set() # type: Set[Any]
for b in state.memory[index : index + length]:
if isinstance(b, BitVec):
annotations = annotations.union(b.annotations)
argument_hash = hash(state.memory[index])
result = symbol_factory.BitVecFuncSym(
"KECCAC[invhash({})]".format(hash(argument_hash)),
"keccak256",
256,
input_=data,
annotations=annotations,
)
log.debug("Created BitVecFunc hash.")
else:
keccak = utils.sha3(data.value.to_bytes(length, byteorder="big"))
result = symbol_factory.BitVecFuncVal(
util.concrete_int_from_bytes(keccak, 0), "keccak256", 256, input_=data
)
log.debug("Computed SHA3 Hash: " + str(binascii.hexlify(keccak)))
# TODO: handle finding x where func(x)==func("")
result = keccak_function_manager.get_empty_keccak_hash()
state.stack.append(result)
return [global_state]
result, condition = keccak_function_manager.create_keccak(data)
state.stack.append(result)
state.constraints.append(condition)
return [global_state]
@StateTransition()

@ -0,0 +1,117 @@
from ethereum import utils
from mythril.laser.smt import (
BitVec,
Function,
URem,
symbol_factory,
ULE,
And,
ULT,
Bool,
Or,
)
from typing import Dict, Tuple, List
TOTAL_PARTS = 10 ** 40
PART = (2 ** 256 - 1) // TOTAL_PARTS
INTERVAL_DIFFERENCE = 10 ** 30
hash_matcher = "fffffff" # This is usually the prefix for the hash in the output
class KeccakFunctionManager:
"""
A bunch of uninterpreted functions are considered like keccak256_160 ,...
where keccak256_160 means the input of keccak256() is 160 bit number.
the range of these functions are constrained to some mutually disjoint intervals
All the hashes modulo 64 are 0 as we need a spread among hashes for array type data structures
All the functions are kind of one to one due to constraint of the existence of inverse
for each encountered input.
For more info https://files.sri.inf.ethz.ch/website/papers/sp20-verx.pdf
"""
def __init__(self):
self.store_function = {} # type: Dict[int, Tuple[Function, Function]]
self.interval_hook_for_size = {} # type: Dict[int, int]
self._index_counter = TOTAL_PARTS - 34534
self.quick_inverse = {} # type: Dict[BitVec, BitVec] # This is for VMTests
@staticmethod
def find_concrete_keccak(data: BitVec) -> BitVec:
"""
Calculates concrete keccak
:param data: input bitvecval
:return: concrete keccak output
"""
keccak = symbol_factory.BitVecVal(
int.from_bytes(
utils.sha3(data.value.to_bytes(data.size() // 8, byteorder="big")),
"big",
),
256,
)
return keccak
def get_function(self, length: int) -> Tuple[Function, Function]:
"""
Returns the keccak functions for the corresponding length
:param length: input size
:return: tuple of keccak and it's inverse
"""
try:
func, inverse = self.store_function[length]
except KeyError:
func = Function("keccak256_{}".format(length), length, 256)
inverse = Function("keccak256_{}-1".format(length), 256, length)
self.store_function[length] = (func, inverse)
return func, inverse
@staticmethod
def get_empty_keccak_hash() -> BitVec:
"""
returns sha3("")
:return:
"""
val = 89477152217924674838424037953991966239322087453347756267410168184682657981552
return symbol_factory.BitVecVal(val, 256)
def create_keccak(self, data: BitVec) -> Tuple[BitVec, Bool]:
"""
Creates Keccak of the data
:param data: input
:return: Tuple of keccak and the condition it should satisfy
"""
length = data.size()
func, inverse = self.get_function(length)
condition = self._create_condition(func_input=data)
self.quick_inverse[func(data)] = data
return func(data), condition
def _create_condition(self, func_input: BitVec) -> Bool:
"""
Creates the constraints for hash
:param func_input: input of the hash
:return: condition
"""
length = func_input.size()
func, inv = self.get_function(length)
try:
index = self.interval_hook_for_size[length]
except KeyError:
self.interval_hook_for_size[length] = self._index_counter
index = self._index_counter
self._index_counter -= INTERVAL_DIFFERENCE
lower_bound = index * PART
upper_bound = lower_bound + PART
cond = And(
inv(func(func_input)) == func_input,
ULE(symbol_factory.BitVecVal(lower_bound, 256), func(func_input)),
ULT(func(func_input), symbol_factory.BitVecVal(upper_bound, 256)),
URem(func(func_input), symbol_factory.BitVecVal(64, 256)) == 0,
)
return cond
keccak_function_manager = KeccakFunctionManager()

@ -55,7 +55,6 @@ class Storage:
self._standard_storage = K(256, 256, 0) # type: BaseArray
else:
self._standard_storage = Array("Storage", 256, 256)
self._map_storage = {} # type: Dict[BitVec, BaseArray]
self.printable_storage = {} # type: Dict[BitVec, BitVec]
@ -73,10 +72,7 @@ class Storage:
return Concat(symbol_factory.BitVecVal(0, 512 - input_.size()), input_)
def __getitem__(self, item: BitVec) -> BitVec:
storage, is_keccak_storage = self._get_corresponding_storage(item)
if is_keccak_storage:
sanitized_item = self._sanitize(cast(BitVecFunc, item).input_)
else:
storage = self._standard_storage
sanitized_item = item
if (
self.address
@ -100,7 +96,6 @@ class Storage:
self.printable_storage[item] = storage[sanitized_item]
except ValueError as e:
log.debug("Couldn't read storage at %s: %s", item, e)
return simplify(storage[sanitized_item])
@staticmethod
@ -114,29 +109,12 @@ class Storage:
index = Extract(255, 0, key.input_)
return simplify(index)
def _get_corresponding_storage(self, key: BitVec) -> Tuple[BaseArray, bool]:
index = self.get_map_index(key)
if index is None:
storage = self._standard_storage
is_keccak_storage = False
else:
storage_map = self._map_storage
try:
storage = storage_map[index]
except KeyError:
if isinstance(self._standard_storage, Array):
storage_map[index] = Array("Storage", 512, 256)
else:
storage_map[index] = K(512, 256, 0)
storage = storage_map[index]
is_keccak_storage = True
return storage, is_keccak_storage
def _get_corresponding_storage(self, key: BitVec) -> BaseArray:
return self._standard_storage
def __setitem__(self, key, value: Any) -> None:
storage, is_keccak_storage = self._get_corresponding_storage(key)
storage = self._get_corresponding_storage(key)
self.printable_storage[key] = value
if is_keccak_storage:
key = self._sanitize(key.input_)
storage[key] = value
if key.symbolic is False:
self.storage_keys_loaded.add(int(key.value))
@ -147,7 +125,6 @@ class Storage:
concrete=concrete, address=self.address, dynamic_loader=self.dynld
)
storage._standard_storage = deepcopy(self._standard_storage)
storage._map_storage = deepcopy(self._map_storage)
storage.printable_storage = copy(self.printable_storage)
storage.storage_keys_loaded = copy(self.storage_keys_loaded)
return storage

@ -35,6 +35,7 @@ class Constraints(list):
"""
:return: True/False based on the existence of solution of constraints
"""
if self._is_possible is not None:
return self._is_possible
solver = Solver()
@ -109,8 +110,8 @@ class Constraints(list):
:param constraints:
:return:
"""
constraints = self._get_smt_bool_list(constraints)
super(Constraints, self).__iadd__(constraints)
list_constraints = self._get_smt_bool_list(constraints)
super(Constraints, self).__iadd__(list_constraints)
self._is_possible = None
return self

@ -3,6 +3,7 @@ from mythril.laser.ethereum.strategy.basic import BasicSearchStrategy
from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum.transaction import ContractCreationTransaction
from typing import Dict, cast, List
from copy import copy
import logging
@ -16,7 +17,9 @@ class JumpdestCountAnnotation(StateAnnotation):
self._reached_count = {} # type: Dict[str, int]
def __copy__(self):
return self
result = JumpdestCountAnnotation()
result._reached_count = copy(self._reached_count)
return result
class BoundedLoopsStrategy(BasicSearchStrategy):
@ -45,6 +48,7 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
:return: Global state
"""
while True:
state = self.super_strategy.get_strategic_global_state()
@ -56,7 +60,6 @@ class BoundedLoopsStrategy(BasicSearchStrategy):
if len(annotations) == 0:
annotation = JumpdestCountAnnotation()
log.debug("Adding JumpdestCountAnnotation to GlobalState")
state.annotate(annotation)
else:
annotation = annotations[0]

@ -20,6 +20,7 @@ from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy
from abc import ABCMeta
from mythril.laser.ethereum.time_handler import time_handler
from mythril.laser.ethereum.transaction import (
ContractCreationTransaction,
TransactionEndSignal,
@ -29,6 +30,7 @@ from mythril.laser.ethereum.transaction import (
)
from mythril.laser.smt import symbol_factory
log = logging.getLogger(__name__)
@ -206,6 +208,7 @@ class LaserEVM:
i, len(self.open_states)
)
)
for hook in self._start_sym_trans_hooks:
hook()
@ -245,7 +248,6 @@ class LaserEVM:
except NotImplementedError:
log.debug("Encountered unimplemented instruction")
continue
new_states = [
state for state in new_states if state.mstate.constraints.is_possible
]
@ -357,16 +359,15 @@ class LaserEVM:
]
log.debug("Ending transaction %s.", transaction)
if return_global_state is None:
if (
not isinstance(transaction, ContractCreationTransaction)
or transaction.return_data
) and not end_signal.revert:
check_potential_issues(global_state)
end_signal.global_state.world_state.node = global_state.node
self._add_world_state(end_signal.global_state)
new_global_states = []
else:
# First execute the post hook for the transaction ending instruction

@ -88,6 +88,9 @@ def _setup_global_state_for_execution(laser_evm, transaction) -> None:
condition=None,
)
)
global_state.mstate.constraints += transaction.world_state.node.constraints
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node
new_node.states.append(global_state)

@ -185,7 +185,7 @@ def _setup_global_state_for_execution(laser_evm, transaction: BaseTransaction) -
)
global_state.mstate.constraints += transaction.world_state.node.constraints
new_node.constraints = global_state.mstate.constraints.as_list
new_node.constraints = global_state.mstate.constraints
global_state.world_state.transaction_sequence.append(transaction)
global_state.node = new_node

@ -22,6 +22,7 @@ from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And
from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.function import Function
from mythril.laser.smt.solver import Solver, Optimize, SolverStatistics
from mythril.laser.smt.model import Model
from mythril.laser.smt.bool import Bool as SMTBool
@ -47,6 +48,16 @@ class SymbolFactory(Generic[T, U]):
"""
raise NotImplementedError
@staticmethod
def BoolSym(name: str, annotations: Annotations = None) -> T:
"""
Creates a boolean symbol
:param name: The name of the Bool variable
:param annotations: The annotations to initialize the bool with
:return: The freshly created Bool()
"""
raise NotImplementedError
@staticmethod
def BitVecVal(value: int, size: int, annotations: Annotations = None) -> U:
"""Creates a new bit vector with a concrete value.
@ -125,6 +136,17 @@ class _SmtSymbolFactory(SymbolFactory[SMTBool, BitVec]):
raw = z3.BoolVal(value)
return SMTBool(raw, annotations)
@staticmethod
def BoolSym(name: str, annotations: Annotations = None) -> SMTBool:
"""
Creates a boolean symbol
:param name: The name of the Bool variable
:param annotations: The annotations to initialize the bool with
:return: The freshly created Bool()
"""
raw = z3.Bool(name)
return SMTBool(raw, annotations)
@staticmethod
def BitVecVal(value: int, size: int, annotations: Annotations = None) -> BitVec:
"""Creates a new bit vector with a concrete value."""

@ -1,8 +1,7 @@
from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from operator import lshift, rshift, ne, eq
from typing import Union, overload, List, Set, cast, Any, Callable
import z3
from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.bool import Bool, Or
from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper

@ -0,0 +1,25 @@
from typing import cast
import z3
from mythril.laser.smt.bitvec import BitVec
class Function:
"""An uninterpreted function."""
def __init__(self, name: str, domain: int, value_range: int):
"""Initializes an uninterpreted function.
:param name: Name of the Function
:param domain: The domain for the Function (10 -> all the values that a bv of size 10 could take)
:param value_range: The range for the values of the function (10 -> all the values that a bv of size 10 could take)
"""
self.domain = z3.BitVecSort(domain)
self.range = z3.BitVecSort(value_range)
self.raw = z3.Function(name, self.domain, self.range)
def __call__(self, item: BitVec) -> BitVec:
"""Function accessor, item can be symbolic."""
return BitVec(
cast(z3.BitVecRef, self.raw(item.raw)), annotations=item.annotations
)

@ -45,14 +45,14 @@ class BaseSolver(Generic[T]):
self.add(*constraints)
@stat_smt_query
def check(self) -> z3.CheckSatResult:
def check(self, *args) -> z3.CheckSatResult:
"""Returns z3 smt check result.
Also suppresses the stdout when running z3 library's check() to avoid unnecessary output
:return: The evaluated result which is either of sat, unsat or unknown
"""
old_stdout = sys.stdout
sys.stdout = open(os.devnull, "w")
evaluate = self.raw.check()
evaluate = self.raw.check(args)
sys.stdout = old_stdout
return evaluate

@ -1,6 +1,6 @@
from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.cfg import Node
from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.machine_state import MachineState
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.world_state import WorldState
@ -30,11 +30,12 @@ def execute_create():
calldata = ConcreteCalldata(0, code_raw)
world_state = WorldState()
world_state.node = Node("Contract")
account = world_state.create_account(balance=1000000, address=101)
account.code = Disassembly("60a760006000f000")
environment = Environment(account, None, calldata, None, None, None)
og_state = GlobalState(
world_state, environment, None, MachineState(gas_limit=8000000)
world_state, environment, world_state.node, MachineState(gas_limit=8000000)
)
og_state.transaction_stack.append(
(MessageCallTransaction(world_state=WorldState(), gas_limit=8000000), None)

@ -1,10 +1,10 @@
from mythril.laser.ethereum.svm import LaserEVM
from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.transaction.concolic import execute_message_call
from mythril.laser.smt import Expression, BitVec, symbol_factory
from mythril.analysis.solver import get_model
from datetime import datetime
import binascii
@ -117,7 +117,6 @@ def test_vmtest(
# Arrange
if test_name in ignored_test_names:
return
world_state = WorldState()
for address, details in pre_condition.items():
@ -178,6 +177,14 @@ def test_vmtest(
expected = int(value, 16)
actual = account.storage[symbol_factory.BitVecVal(int(index, 16), 256)]
if isinstance(actual, Expression):
if (
actual.symbolic
and actual in keccak_function_manager.quick_inverse
):
actual = keccak_function_manager.find_concrete_keccak(
keccak_function_manager.quick_inverse[actual]
)
else:
actual = actual.value
actual = 1 if actual is True else 0 if actual is False else actual
else:

@ -0,0 +1,138 @@
from mythril.laser.smt import Solver, symbol_factory, And
from mythril.laser.ethereum.keccak_function_manager import keccak_function_manager
import z3
import pytest
@pytest.mark.parametrize(
"input1, input2, expected",
[
(symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(101, 8), z3.unsat),
(symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 16), z3.unsat),
(symbol_factory.BitVecVal(100, 8), symbol_factory.BitVecVal(100, 8), z3.sat),
(
symbol_factory.BitVecSym("N1", 256),
symbol_factory.BitVecSym("N2", 256),
z3.sat,
),
(
symbol_factory.BitVecVal(100, 256),
symbol_factory.BitVecSym("N1", 256),
z3.sat,
),
(
symbol_factory.BitVecVal(100, 8),
symbol_factory.BitVecSym("N1", 256),
z3.unsat,
),
],
)
def test_keccak_basic(input1, input2, expected):
s = Solver()
o1, c1 = keccak_function_manager.create_keccak(input1)
o2, c2 = keccak_function_manager.create_keccak(input2)
s.add(And(c1, c2))
s.add(o1 == o2)
assert s.check() == expected
def test_keccak_symbol_and_val():
"""
check keccak(100) == keccak(n) && n == 10
:return:
"""
s = Solver()
hundred = symbol_factory.BitVecVal(100, 256)
n = symbol_factory.BitVecSym("n", 256)
o1, c1 = keccak_function_manager.create_keccak(hundred)
o2, c2 = keccak_function_manager.create_keccak(n)
s.add(And(c1, c2))
s.add(o1 == o2)
s.add(n == symbol_factory.BitVecVal(10, 256))
assert s.check() == z3.unsat
def test_keccak_complex_eq():
"""
check for keccak(keccak(b)*2) == keccak(keccak(a)*2) && a != b
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 160)
o1, c1 = keccak_function_manager.create_keccak(a)
o2, c2 = keccak_function_manager.create_keccak(b)
s.add(And(c1, c2))
two = symbol_factory.BitVecVal(2, 256)
o1 = two * o1
o2 = two * o2
o1, c1 = keccak_function_manager.create_keccak(o1)
o2, c2 = keccak_function_manager.create_keccak(o2)
s.add(And(c1, c2))
s.add(o1 == o2)
s.add(a != b)
assert s.check() == z3.unsat
def test_keccak_complex_eq2():
"""
check for keccak(keccak(b)*2) == keccak(keccak(a)*2)
This isn't combined with prev test because incremental solving here requires extra-extra work
(solution is literally the opposite of prev one) so it will take forever to solve.
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 160)
o1, c1 = keccak_function_manager.create_keccak(a)
o2, c2 = keccak_function_manager.create_keccak(b)
s.add(And(c1, c2))
two = symbol_factory.BitVecVal(2, 256)
o1 = two * o1
o2 = two * o2
o1, c1 = keccak_function_manager.create_keccak(o1)
o2, c2 = keccak_function_manager.create_keccak(o2)
s.add(And(c1, c2))
s.add(o1 == o2)
assert s.check() == z3.sat
def test_keccak_simple_number():
"""
check for keccak(b) == 10
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
ten = symbol_factory.BitVecVal(10, 256)
o, c = keccak_function_manager.create_keccak(a)
s.add(c)
s.add(ten == o)
assert s.check() == z3.unsat
def test_keccak_other_num():
"""
check keccak(keccak(a)*2) == b
:return:
"""
s = Solver()
a = symbol_factory.BitVecSym("a", 160)
b = symbol_factory.BitVecSym("b", 256)
o, c = keccak_function_manager.create_keccak(a)
two = symbol_factory.BitVecVal(2, 256)
o = two * o
s.add(c)
o, c = keccak_function_manager.create_keccak(o)
s.add(c)
s.add(b == o)
assert s.check() == z3.sat
Loading…
Cancel
Save