Merge branch 'develop' into model-balances

model-balances
Nathan 5 years ago
commit 01c7afd6a7
  1. 39
      README.md
  2. 16
      all_tests.sh
  3. 2
      mythril/__version__.py
  4. 30
      mythril/analysis/analysis_args.py
  5. 5
      mythril/analysis/modules/deprecated_ops.py
  6. 3
      mythril/analysis/modules/dos.py
  7. 3
      mythril/analysis/modules/ether_thief.py
  8. 4
      mythril/analysis/modules/state_change_external_calls.py
  9. 26
      mythril/analysis/report.py
  10. 7
      mythril/analysis/solver.py
  11. 2
      mythril/analysis/symbolic.py
  12. 2
      mythril/analysis/templates/report_as_markdown.jinja2
  13. 2
      mythril/analysis/templates/report_as_text.jinja2
  14. 6
      mythril/disassembler/disassembly.py
  15. 13
      mythril/interfaces/cli.py
  16. 7
      mythril/interfaces/old_cli.py
  17. 22
      mythril/laser/ethereum/state/account.py
  18. 7
      mythril/laser/smt/__init__.py
  19. 3
      mythril/laser/smt/array.py
  20. 276
      mythril/laser/smt/bitvec.py
  21. 291
      mythril/laser/smt/bitvec_helper.py
  22. 68
      mythril/laser/smt/bitvecfunc.py
  23. 5
      mythril/mythril/mythril_analyzer.py
  24. 52
      mythril/support/loader.py
  25. 165
      tests/laser/smt/bitvecfunc_test.py

@ -36,6 +36,45 @@ See the [Wiki](https://github.com/ConsenSys/mythril/wiki/Installation-and-Setup)
## Usage ## Usage
Run:
```
$ myth analyze <solidity-file>
```
Or:
```
$ myth analyze -a <contract-address>
```
Specify the maximum number of transaction to explore with `-t <number>`. You can also set a timeout with `--execution-timeout <seconds>`. Example ([source code](https://gist.github.com/b-mueller/2b251297ce88aa7628680f50f177a81a#file-killbilly-sol)):
```
==== Unprotected Selfdestruct ====
SWC ID: 106
Severity: High
Contract: KillBilly
Function name: commencekilling()
PC address: 354
Estimated Gas Usage: 574 - 999
The contract can be killed by anyone.
Anyone can kill this contract and withdraw its balance to an arbitrary address.
--------------------
In file: killbilly.sol:22
selfdestruct(msg.sender)
--------------------
Transaction Sequence:
Caller: [CREATOR], data: [CONTRACT CREATION], value: 0x0
Caller: [ATTACKER], function: killerize(address), txdata: 0x9fa299ccbebebebebebebebebebebebedeadbeefdeadbeefdeadbeefdeadbeefdeadbeef, value: 0x0
Caller: [ATTACKER], function: activatekillability(), txdata: 0x84057065, value: 0x0
Caller: [ATTACKER], function: commencekilling(), txdata: 0x7c11da20, value: 0x0
```
Instructions for using Mythril are found on the [Wiki](https://github.com/ConsenSys/mythril/wiki). Instructions for using Mythril are found on the [Wiki](https://github.com/ConsenSys/mythril/wiki).
For support or general discussions please join the Mythril community on [Discord](https://discord.gg/E3YrVtG). For support or general discussions please join the Mythril community on [Discord](https://discord.gg/E3YrVtG).

@ -7,22 +7,6 @@ assert sys.version_info[0:2] >= (3,5), \
"""Please make sure you are using Python 3.5 or later. """Please make sure you are using Python 3.5 or later.
You ran with {}""".format(sys.version)' || exit $? You ran with {}""".format(sys.version)' || exit $?
echo "Checking solc version..."
out=$(solc --version) || {
echo 2>&1 "Please make sure you have solc installed, version 0.4.21 or greater"
}
case $out in
*Version:\ 0.4.2[1-9]* )
echo $out
;;
* )
echo $out
echo "Please make sure your solc version is at least 0.4.21"
exit 1
;;
esac
echo "Checking that truffle is installed..." echo "Checking that truffle is installed..."
if ! which truffle ; then if ! which truffle ; then
echo "Please make sure you have etherum truffle installed (npm install -g truffle)" echo "Please make sure you have etherum truffle installed (npm install -g truffle)"

@ -4,4 +4,4 @@ This file is suitable for sourcing inside POSIX shell, e.g. bash as well
as for importing into Python. as for importing into Python.
""" """
__version__ = "v0.21.9" __version__ = "v0.21.12"

@ -0,0 +1,30 @@
from mythril.support.support_utils import Singleton
class AnalysisArgs(object, metaclass=Singleton):
"""
This module helps in preventing args being sent through multiple of classes to reach analysis modules
"""
def __init__(self):
self._loop_bound = 3
self._solver_timeout = 10000
def set_loop_bound(self, loop_bound: int):
if loop_bound is not None:
self._loop_bound = loop_bound
def set_solver_timeout(self, solver_timeout: int):
if solver_timeout is not None:
self._solver_timeout = solver_timeout
@property
def loop_bound(self):
return self._loop_bound
@property
def solver_timeout(self):
return self._solver_timeout
analysis_args = AnalysisArgs()

@ -35,6 +35,7 @@ class DeprecatedOperationsModule(DetectionModule):
if state.get_current_instruction()["address"] in self._cache: if state.get_current_instruction()["address"] in self._cache:
return return
issues = self._analyze_state(state) issues = self._analyze_state(state)
for issue in issues: for issue in issues:
self._cache.add(issue.address) self._cache.add(issue.address)
self._issues.extend(issues) self._issues.extend(issues)
@ -74,13 +75,13 @@ class DeprecatedOperationsModule(DetectionModule):
) )
swc_id = DEPRECATED_FUNCTIONS_USAGE swc_id = DEPRECATED_FUNCTIONS_USAGE
else: else:
return return []
try: try:
transaction_sequence = get_transaction_sequence( transaction_sequence = get_transaction_sequence(
state, state.mstate.constraints state, state.mstate.constraints
) )
except UnsatError: except UnsatError:
return return []
issue = Issue( issue = Issue(
contract=state.environment.active_account.contract_name, contract=state.environment.active_account.contract_name,
function_name=state.environment.active_function_name, function_name=state.environment.active_function_name,

@ -7,6 +7,7 @@ from mythril.analysis.swc_data import DOS_WITH_BLOCK_GAS_LIMIT
from mythril.analysis.report import Issue from mythril.analysis.report import Issue
from mythril.analysis.modules.base import DetectionModule from mythril.analysis.modules.base import DetectionModule
from mythril.analysis.solver import get_transaction_sequence, UnsatError from mythril.analysis.solver import get_transaction_sequence, UnsatError
from mythril.analysis.analysis_args import analysis_args
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum.state.annotation import StateAnnotation
from mythril.laser.ethereum import util from mythril.laser.ethereum import util
@ -90,7 +91,7 @@ class DosModule(DetectionModule):
else: else:
annotation.jump_targets[target] = 1 annotation.jump_targets[target] = 1
if annotation.jump_targets[target] > 2: if annotation.jump_targets[target] > min(2, analysis_args.loop_bound - 1):
annotation.loop_start = address annotation.loop_start = address
elif annotation.loop_start is not None: elif annotation.loop_start is not None:

@ -15,8 +15,7 @@ from mythril.analysis.swc_data import UNPROTECTED_ETHER_WITHDRAWAL
from mythril.exceptions import UnsatError from mythril.exceptions import UnsatError
from mythril.laser.ethereum.transaction import ContractCreationTransaction from mythril.laser.ethereum.transaction import ContractCreationTransaction
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.smt import UGT, Sum, symbol_factory, BVAddNoOverflow from mythril.laser.smt import UGT, Sum, symbol_factory, BVAddNoOverflow, If
from mythril.laser.smt.bitvec import If
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

@ -171,7 +171,9 @@ class StateChange(DetectionModule):
for annotation in annotations: for annotation in annotations:
if not annotation.state_change_states: if not annotation.state_change_states:
continue continue
vulnerabilities.append(annotation.get_issue(global_state)) issue = annotation.get_issue(global_state)
if issue:
vulnerabilities.append(issue)
return vulnerabilities return vulnerabilities
@staticmethod @staticmethod

@ -11,6 +11,7 @@ from mythril.analysis.swc_data import SWC_TO_TITLE
from mythril.support.source_support import Source from mythril.support.source_support import Source
from mythril.support.start_time import StartTime from mythril.support.start_time import StartTime
from mythril.support.support_utils import get_code_hash from mythril.support.support_utils import get_code_hash
from mythril.support.signatures import SignatureDB
from time import time from time import time
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -151,6 +152,30 @@ class Issue:
else: else:
self.source_mapping = self.address self.source_mapping = self.address
def resolve_function_names(self):
""" Resolves function names for each step """
if (
self.transaction_sequence is None
or "steps" not in self.transaction_sequence
):
return
signatures = SignatureDB()
for step in self.transaction_sequence["steps"]:
_hash = step["input"][:10]
try:
sig = signatures.get(_hash)
if len(sig) > 0:
step["name"] = sig[0]
else:
step["name"] = "unknown"
except ValueError:
step["name"] = "unknown"
class Report: class Report:
"""A report containing the content of multiple issues.""" """A report containing the content of multiple issues."""
@ -187,6 +212,7 @@ class Report:
""" """
m = hashlib.md5() m = hashlib.md5()
m.update((issue.contract + str(issue.address) + issue.title).encode("utf-8")) m.update((issue.contract + str(issue.address) + issue.title).encode("utf-8"))
issue.resolve_function_names()
self.issues[m.digest()] = issue self.issues[m.digest()] = issue
def as_text(self): def as_text(self):

@ -4,6 +4,7 @@ from typing import Dict, Tuple, Union
from z3 import sat, unknown, FuncInterp from z3 import sat, unknown, FuncInterp
import z3 import z3
from mythril.analysis.analysis_args import analysis_args
from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.constraints import Constraints from mythril.laser.ethereum.state.constraints import Constraints
from mythril.laser.ethereum.transaction import BaseTransaction from mythril.laser.ethereum.transaction import BaseTransaction
@ -29,7 +30,7 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True
:return: :return:
""" """
s = Optimize() s = Optimize()
timeout = 100000 timeout = analysis_args.solver_timeout
if enforce_execution_time: if enforce_execution_time:
timeout = min(timeout, time_handler.time_remaining() - 500) timeout = min(timeout, time_handler.time_remaining() - 500)
if timeout <= 0: if timeout <= 0:
@ -132,8 +133,8 @@ def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]):
data = dict() # type: Dict[str, Union[int, str]] data = dict() # type: Dict[str, Union[int, str]]
data["nonce"] = account.nonce data["nonce"] = account.nonce
data["code"] = account.code.bytecode data["code"] = account.code.bytecode
data["storage"] = str(account.storage) data["storage"] = account.storage.printable_storage
data["balance"] = min_price_dict.get(address, 0) data["balance"] = hex(min_price_dict.get(address, 0))
accounts[hex(address)] = data accounts[hex(address)] = data
return {"accounts": accounts} return {"accounts": accounts}

@ -47,7 +47,7 @@ class SymExecWrapper:
dynloader=None, dynloader=None,
max_depth=22, max_depth=22,
execution_timeout=None, execution_timeout=None,
loop_bound=2, loop_bound=3,
create_timeout=None, create_timeout=None,
transaction_count=2, transaction_count=2,
modules=(), modules=(),

@ -32,7 +32,7 @@ In file: {{ issue.filename }}:{{ issue.lineno }}
{% if step == issue.tx_sequence.steps[0] and step.input != "0x" and step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %} {% if step == issue.tx_sequence.steps[0] and step.input != "0x" and step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}
Caller: [CREATOR], data: [CONTRACT CREATION], value: {{ step.value }} Caller: [CREATOR], data: [CONTRACT CREATION], value: {{ step.value }}
{% else %} {% else %}
Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, data: {{ step.input }}, value: {{ step.value }} Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, function: {{ step.name }}, txdata: {{ step.input }}, value: {{ step.value }}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% endif %} {% endif %}

@ -27,7 +27,7 @@ Transaction Sequence:
{% if step == issue.tx_sequence.steps[0] and step.input != "0x" and step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %} {% if step == issue.tx_sequence.steps[0] and step.input != "0x" and step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}
Caller: [CREATOR], data: [CONTRACT CREATION], value: {{ step.value }} Caller: [CREATOR], data: [CONTRACT CREATION], value: {{ step.value }}
{% else %} {% else %}
Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, data: {{ step.input }}, value: {{ step.value }} Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, function: {{ step.name }}, txdata: {{ step.input }}, value: {{ step.value }}
{% endif %} {% endif %}
{% endfor %} {% endfor %}
{% endif %} {% endif %}

@ -84,10 +84,8 @@ def get_function_info(
# Append with missing 0s at the beginning # Append with missing 0s at the beginning
function_hash = "0x" + instruction_list[index]["argument"][2:].rjust(8, "0") function_hash = "0x" + instruction_list[index]["argument"][2:].rjust(8, "0")
function_names = signature_database.get(function_hash) function_names = signature_database.get(function_hash)
if len(function_names) > 1:
# In this case there was an ambiguous result if len(function_names) > 0:
function_name = "[{}] (ambiguous)".format(", ".join(function_names))
elif len(function_names) == 1:
function_name = function_names[0] function_name = function_names[0]
else: else:
function_name = "_function_" + function_hash function_name = "_function_" + function_hash

@ -346,11 +346,17 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser):
default=86400, default=86400,
help="The amount of seconds to spend on symbolic execution", help="The amount of seconds to spend on symbolic execution",
) )
options.add_argument(
"--solver-timeout",
type=int,
default=10000,
help="The maximum amount of time(in milli seconds) the solver spends for queries from analysis modules",
)
options.add_argument( options.add_argument(
"--create-timeout", "--create-timeout",
type=int, type=int,
default=10, default=10,
help="The amount of seconds to spend on " "the initial contract creation", help="The amount of seconds to spend on the initial contract creation",
) )
options.add_argument( options.add_argument(
"-l", "-l",
@ -360,8 +366,9 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser):
) )
options.add_argument( options.add_argument(
"--no-onchain-storage-access", "--no-onchain-storage-access",
"--no-onchain-access",
action="store_true", action="store_true",
help="turns off getting the data from onchain contracts", help="turns off getting the data from onchain contracts (both loading storage and contract code)",
) )
options.add_argument( options.add_argument(
@ -552,6 +559,8 @@ def execute_command(
enable_iprof=args.enable_iprof, enable_iprof=args.enable_iprof,
disable_dependency_pruning=args.disable_dependency_pruning, disable_dependency_pruning=args.disable_dependency_pruning,
onchain_storage_access=not args.no_onchain_storage_access, onchain_storage_access=not args.no_onchain_storage_access,
solver_timeout=args.solver_timeout,
requires_dynld=not args.no_onchain_storage_access,
) )
if not disassembler.contracts: if not disassembler.contracts:

@ -213,6 +213,12 @@ def create_parser(parser: argparse.ArgumentParser) -> None:
default=2, default=2,
help="Maximum number of transactions issued by laser", help="Maximum number of transactions issued by laser",
) )
options.add_argument(
"--solver-timeout",
type=int,
default=10000,
help="The maximum amount of time(in milli seconds) the solver spends for queries from analysis modules",
)
options.add_argument( options.add_argument(
"--execution-timeout", "--execution-timeout",
type=int, type=int,
@ -419,6 +425,7 @@ def execute_command(
enable_iprof=args.enable_iprof, enable_iprof=args.enable_iprof,
disable_dependency_pruning=args.disable_dependency_pruning, disable_dependency_pruning=args.disable_dependency_pruning,
onchain_storage_access=not args.no_onchain_storage_access, onchain_storage_access=not args.no_onchain_storage_access,
solver_timeout=args.solver_timeout,
) )
if args.disassemble: if args.disassemble:

@ -20,6 +20,26 @@ from mythril.disassembler.disassembly import Disassembly
from mythril.laser.smt import symbol_factory from mythril.laser.smt import symbol_factory
class StorageRegion:
def __getitem__(self, item):
raise NotImplementedError
def __setitem__(self, key, value):
raise NotImplementedError
class ArrayStorageRegion(StorageRegion):
""" An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions"""
pass
class IteStorageRegion(StorageRegion):
""" An IteStorageRegion is a storage region that uses Ite statements to implement a storage"""
pass
class Storage: class Storage:
"""Storage class represents the storage of an Account.""" """Storage class represents the storage of an Account."""
@ -114,7 +134,7 @@ class Storage:
key = self._sanitize(key.input_) key = self._sanitize(key.input_)
storage[key] = value storage[key] = value
def __deepcopy__(self, memodict={}): def __deepcopy__(self, memodict=dict()):
concrete = isinstance(self._standard_storage, K) concrete = isinstance(self._standard_storage, K)
storage = Storage( storage = Storage(
concrete=concrete, address=self.address, dynamic_loader=self.dynld concrete=concrete, address=self.address, dynamic_loader=self.dynld

@ -1,8 +1,10 @@
from mythril.laser.smt.bitvec import ( from mythril.laser.smt.bitvec import BitVec
BitVec,
from mythril.laser.smt.bitvec_helper import (
If, If,
UGT, UGT,
ULT, ULT,
ULE,
Concat, Concat,
Extract, Extract,
URem, URem,
@ -15,6 +17,7 @@ from mythril.laser.smt.bitvec import (
BVSubNoUnderflow, BVSubNoUnderflow,
LShR, LShR,
) )
from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And

@ -8,7 +8,8 @@ default values over a certain range.
from typing import cast from typing import cast
import z3 import z3
from mythril.laser.smt.bitvec import BitVec, If from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvec_helper import If
from mythril.laser.smt.bool import Bool from mythril.laser.smt.bool import Bool

@ -1,10 +1,11 @@
"""This module provides classes for an SMT abstraction of bit vectors.""" """This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from operator import lshift, rshift, ne, eq from operator import lshift, rshift, ne, eq
from typing import Union, Set, cast, Any, Optional, Callable
import z3 import z3
from mythril.laser.smt.bool import Bool, And, Or from mythril.laser.smt.bool import Bool
from mythril.laser.smt.expression import Expression from mythril.laser.smt.expression import Expression
Annotations = Set[Any] Annotations = Set[Any]
@ -276,276 +277,5 @@ class BitVec(Expression[z3.BitVecRef]):
return self.raw.__hash__() return self.raw.__hash__()
def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(operation(a.raw, b.raw), annotations=annotations)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(default_value), annotations=annotations)
return And(
Bool(operation(a.raw, b.raw), annotations=annotations),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
elif isinstance(a, BitVecFunc):
return BitVecFunc(
raw=raw, func_name=a.func_name, input_=a.input_, annotations=union
)
elif isinstance(b, BitVecFunc):
return BitVecFunc(
raw=raw, func_name=b.func_name, input_=b.input_, annotations=union
)
return BitVec(raw, annotations=union)
def LShR(a: BitVec, b: BitVec):
return _arithmetic_helper(a, b, z3.LShR)
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
"""Create an if-then-else expression.
:param a:
:param b:
:param c:
:return:
"""
# TODO: Handle BitVecFunc
if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations.union(b.annotations).union(c.annotations)
return BitVec(z3.If(a.raw, b.raw, c.raw), union)
def UGT(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater than expression.
:param a:
:param b:
:return:
"""
return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False)
def UGE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater or equals expression.
:param a:
:param b:
:return:
"""
return Or(UGT(a, b), a == b)
def ULT(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned less than expression.
:param a:
:param b:
:return:
"""
return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False)
def ULE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned less than expression.
:param a:
:param b:
:return:
"""
return Or(ULT(a, b), a == b)
@overload
def Concat(*args: List[BitVec]) -> BitVec: ...
@overload
def Concat(*args: BitVec) -> BitVec: ...
def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
"""Create a concatenation expression.
:param args:
:return:
"""
# The following statement is used if a list is provided as an argument to concat
if len(args) == 1 and isinstance(args[0], list):
bvs = args[0] # type: List[BitVec]
else:
bvs = cast(List[BitVec], args)
nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations
bitvecfunc = False
for bv in bvs:
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfunc = True
if bitvecfunc:
# Added this not so good and misleading NOTATION to help with this
str_hash = ",".join(["hashed({})".format(hash(bv)) for bv in bvs])
input_string = "MisleadingNotationConcat({})".format(str_hash)
return BitVecFunc(
raw=nraw, func_name="Hybrid", input_=BitVec(z3.BitVec(input_string, 256), annotations=annotations)
)
return BitVec(nraw, annotations)
def Extract(high: int, low: int, bv: BitVec) -> BitVec:
"""Create an extract expression.
:param high:
:param low:
:param bv:
:return:
"""
raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
input_string = "MisleadingNotationExtract({}, {}, hashed({}))".format(high, low, hash(bv))
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=raw, func_name="Hybrid", input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations)
)
return BitVec(raw, annotations=bv.annotations)
def URem(a: BitVec, b: BitVec) -> BitVec:
"""Create an unsigned remainder expression.
:param a:
:param b:
:return:
"""
return _arithmetic_helper(a, b, z3.URem)
def SRem(a: BitVec, b: BitVec) -> BitVec:
"""Create a signed remainder expression.
:param a:
:param b:
:return:
"""
return _arithmetic_helper(a, b, z3.SRem)
def UDiv(a: BitVec, b: BitVec) -> BitVec:
"""Create an unsigned division expression.
:param a:
:param b:
:return:
"""
return _arithmetic_helper(a, b, z3.UDiv)
def Sum(*args: BitVec) -> BitVec:
"""Create sum expression.
:return:
"""
raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations
bitvecfuncs = []
for bv in args:
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)
if len(bitvecfuncs) >= 2:
return BitVecFunc(raw=raw, func_name="Hybrid", input_=None, annotations=annotations)
elif len(bitvecfuncs) == 1:
return BitVecFunc(
raw=raw,
func_name=bitvecfuncs[0].func_name,
input_=bitvecfuncs[0].input_,
annotations=annotations,
)
return BitVec(raw, annotations)
def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
"""Creates predicate that verifies that the addition doesn't overflow.
:param a:
:param b:
:param signed:
:return:
"""
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed))
def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
"""Creates predicate that verifies that the multiplication doesn't
overflow.
:param a:
:param b:
:param signed:
:return:
"""
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
def BVSubNoUnderflow(
a: Union[BitVec, int], b: Union[BitVec, int], signed: bool
) -> Bool:
"""Creates predicate that verifies that the subtraction doesn't overflow.
:param a:
:param b:
:param signed:
:return:
"""
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))
# TODO: Fix circular import issues # TODO: Fix circular import issues
from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.bitvecfunc import BitVecFunc

@ -0,0 +1,291 @@
from typing import Union, overload, List, Set, cast, Any, Optional, Callable
from operator import lshift, rshift, ne, eq
import z3
from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.bitvec import BitVec
from mythril.laser.smt.bitvecfunc import BitVecFunc
from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper
from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper
Annotations = Set[Any]
def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_comparison_helper(a, b, operation, default_value, inputs_equal)
return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw)
union = a.annotations.union(b.annotations)
if isinstance(a, BitVecFunc):
return _func_arithmetic_helper(a, b, operation)
elif isinstance(b, BitVecFunc):
return _func_arithmetic_helper(b, a, operation)
return BitVec(raw, annotations=union)
def LShR(a: BitVec, b: BitVec):
return _arithmetic_helper(a, b, z3.LShR)
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
"""Create an if-then-else expression.
:param a:
:param b:
:param c:
:return:
"""
# TODO: Handle BitVecFunc
if not isinstance(a, Bool):
a = Bool(z3.BoolVal(a))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
if not isinstance(c, BitVec):
c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations.union(b.annotations).union(c.annotations)
bvf = [] # type: List[BitVecFunc]
if isinstance(a, BitVecFunc):
bvf += [a]
if isinstance(b, BitVecFunc):
bvf += [b]
if isinstance(c, BitVecFunc):
bvf += [c]
if bvf:
raw = z3.If(a.raw, b.raw, c.raw)
nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf
return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions)
return BitVec(z3.If(a.raw, b.raw, c.raw), union)
def UGT(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater than expression.
:param a:
:param b:
:return:
"""
return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False)
def UGE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned greater or equals expression.
:param a:
:param b:
:return:
"""
return Or(UGT(a, b), a == b)
def ULT(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned less than expression.
:param a:
:param b:
:return:
"""
return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False)
def ULE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned less than expression.
:param a:
:param b:
:return:
"""
return Or(ULT(a, b), a == b)
@overload
def Concat(*args: List[BitVec]) -> BitVec:
...
@overload
def Concat(*args: BitVec) -> BitVec:
...
def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
"""Create a concatenation expression.
:param args:
:return:
"""
# The following statement is used if a list is provided as an argument to concat
if len(args) == 1 and isinstance(args[0], list):
bvs = args[0] # type: List[BitVec]
else:
bvs = cast(List[BitVec], args)
nraw = z3.Concat([a.raw for a in bvs])
annotations = set() # type: Annotations
nested_functions = [] # type: List[BitVecFunc]
for bv in bvs:
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
nested_functions += bv.nested_functions
nested_functions += [bv]
if nested_functions:
return BitVecFunc(
raw=nraw,
func_name="Hybrid",
input_=BitVec(z3.BitVec("", 256), annotations=annotations),
nested_functions=nested_functions,
)
return BitVec(nraw, annotations)
def Extract(high: int, low: int, bv: BitVec) -> BitVec:
"""Create an extract expression.
:param high:
:param low:
:param bv:
:return:
"""
raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
input_string = ""
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations),
nested_functions=bv.nested_functions + [bv],
)
return BitVec(raw, annotations=bv.annotations)
def URem(a: BitVec, b: BitVec) -> BitVec:
"""Create an unsigned remainder expression.
:param a:
:param b:
:return:
"""
return _arithmetic_helper(a, b, z3.URem)
def SRem(a: BitVec, b: BitVec) -> BitVec:
"""Create a signed remainder expression.
:param a:
:param b:
:return:
"""
return _arithmetic_helper(a, b, z3.SRem)
def UDiv(a: BitVec, b: BitVec) -> BitVec:
"""Create an unsigned division expression.
:param a:
:param b:
:return:
"""
return _arithmetic_helper(a, b, z3.UDiv)
def Sum(*args: BitVec) -> BitVec:
"""Create sum expression.
:return:
"""
raw = z3.Sum([a.raw for a in args])
annotations = set() # type: Annotations
bitvecfuncs = []
for bv in args:
annotations = annotations.union(bv.annotations)
if isinstance(bv, BitVecFunc):
bitvecfuncs.append(bv)
nested_functions = [
nf for func in bitvecfuncs for nf in func.nested_functions
] + bitvecfuncs
if len(bitvecfuncs) >= 2:
return BitVecFunc(
raw=raw,
func_name="Hybrid",
input_=None,
annotations=annotations,
nested_functions=nested_functions,
)
elif len(bitvecfuncs) == 1:
return BitVecFunc(
raw=raw,
func_name=bitvecfuncs[0].func_name,
input_=bitvecfuncs[0].input_,
annotations=annotations,
nested_functions=nested_functions,
)
return BitVec(raw, annotations)
def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
"""Creates predicate that verifies that the addition doesn't overflow.
:param a:
:param b:
:param signed:
:return:
"""
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed))
def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool:
"""Creates predicate that verifies that the multiplication doesn't
overflow.
:param a:
:param b:
:param signed:
:return:
"""
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
def BVSubNoUnderflow(
a: Union[BitVec, int], b: Union[BitVec, int], signed: bool
) -> Bool:
"""Creates predicate that verifies that the subtraction doesn't overflow.
:param a:
:param b:
:param signed:
:return:
"""
if not isinstance(a, BitVec):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, BitVec):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))

@ -1,11 +1,10 @@
from typing import Optional, Union, cast, Callable import operator
from itertools import product
from typing import Optional, Union, cast, Callable, List
import z3 import z3
from mythril.laser.smt.bitvec import BitVec, Bool, And, Annotations from mythril.laser.smt.bitvec import BitVec, Annotations
from mythril.laser.smt.bool import Or from mythril.laser.smt.bool import Or, Bool, And
import operator
def _arithmetic_helper( def _arithmetic_helper(
@ -26,18 +25,19 @@ def _arithmetic_helper(
union = a.annotations.union(b.annotations) union = a.annotations.union(b.annotations)
if isinstance(b, BitVecFunc): if isinstance(b, BitVecFunc):
# TODO: Find better value to set input and name to in this case?
input_string = "MisleadingNotationop(invhash({}) {} invhash({})".format(
hash(a), operation, hash(b)
)
return BitVecFunc( return BitVecFunc(
raw=raw, raw=raw,
func_name="Hybrid", func_name="Hybrid",
input_=BitVec(z3.BitVec(input_string, 256), annotations=union), input_=BitVec(z3.BitVec("", 256), annotations=union),
nested_functions=a.nested_functions + b.nested_functions + [a, b],
) )
return BitVecFunc( return BitVecFunc(
raw=raw, func_name=a.func_name, input_=a.input_, annotations=union raw=raw,
func_name=a.func_name,
input_=a.input_,
annotations=union,
nested_functions=a.nested_functions + [a],
) )
@ -62,18 +62,55 @@ def _comparison_helper(
union = a.annotations.union(b.annotations) union = a.annotations.union(b.annotations)
if not a.symbolic and not b.symbolic: if not a.symbolic and not b.symbolic:
if operation == z3.UGT:
operation = operator.gt
if operation == z3.ULT:
operation = operator.lt
return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union)
if ( if (
not isinstance(b, BitVecFunc) not isinstance(b, BitVecFunc)
or not a.func_name or not a.func_name
or not a.input_ or not a.input_
or not a.func_name == b.func_name or not a.func_name == b.func_name
or str(operation) not in ("<built-in function eq>", "<built-in function ne>")
): ):
return Bool(z3.BoolVal(default_value), annotations=union) return Bool(z3.BoolVal(default_value), annotations=union)
condition = True
for a_nest, b_nest in product(a.nested_functions, b.nested_functions):
if a_nest.func_name != b_nest.func_name:
continue
if a_nest.func_name == "Hybrid":
continue
# a.input (eq/neq) b.input ==> a == b
if inputs_equal:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ == b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ == b_nest.input_).raw,
),
)
else:
condition = z3.And(
condition,
z3.Or(
z3.Not((a_nest.input_ != b_nest.input_).raw),
(a_nest.raw == b_nest.raw),
),
z3.Or(
z3.Not((a_nest.raw == b_nest.raw)),
(a_nest.input_ != b_nest.input_).raw,
),
)
return And( return And(
Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union), Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union),
Bool(condition) if b.nested_functions else Bool(True),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
) )
@ -87,6 +124,7 @@ class BitVecFunc(BitVec):
func_name: Optional[str], func_name: Optional[str],
input_: "BitVec" = None, input_: "BitVec" = None,
annotations: Optional[Annotations] = None, annotations: Optional[Annotations] = None,
nested_functions: Optional[List["BitVecFunc"]] = None,
): ):
""" """
@ -98,6 +136,10 @@ class BitVecFunc(BitVec):
self.func_name = func_name self.func_name = func_name
self.input_ = input_ self.input_ = input_
self.nested_functions = nested_functions or []
self.nested_functions = list(dict.fromkeys(self.nested_functions))
if isinstance(input_, BitVecFunc):
self.nested_functions.extend(input_.nested_functions)
super().__init__(raw, annotations) super().__init__(raw, annotations)
def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc":

@ -10,6 +10,7 @@ from mythril.support.source_support import Source
from mythril.support.loader import DynLoader from mythril.support.loader import DynLoader
from mythril.analysis.symbolic import SymExecWrapper from mythril.analysis.symbolic import SymExecWrapper
from mythril.analysis.callgraph import generate_graph from mythril.analysis.callgraph import generate_graph
from mythril.analysis.analysis_args import analysis_args
from mythril.analysis.traceexplore import get_serializable_statespace from mythril.analysis.traceexplore import get_serializable_statespace
from mythril.analysis.security import fire_lasers, retrieve_callback_issues from mythril.analysis.security import fire_lasers, retrieve_callback_issues
from mythril.analysis.report import Report, Issue from mythril.analysis.report import Report, Issue
@ -39,6 +40,7 @@ class MythrilAnalyzer:
create_timeout: Optional[int] = None, create_timeout: Optional[int] = None,
enable_iprof: bool = False, enable_iprof: bool = False,
disable_dependency_pruning: bool = False, disable_dependency_pruning: bool = False,
solver_timeout: Optional[int] = None,
): ):
""" """
@ -60,6 +62,9 @@ class MythrilAnalyzer:
self.enable_iprof = enable_iprof self.enable_iprof = enable_iprof
self.disable_dependency_pruning = disable_dependency_pruning self.disable_dependency_pruning = disable_dependency_pruning
analysis_args.set_loop_bound(loop_bound)
analysis_args.set_solver_timeout(solver_timeout)
def dump_statespace(self, contract: EVMContract = None) -> str: def dump_statespace(self, contract: EVMContract = None) -> str:
""" """
Returns serializable statespace of the contract Returns serializable statespace of the contract

@ -3,6 +3,11 @@ and dependencies."""
from mythril.disassembler.disassembly import Disassembly from mythril.disassembler.disassembly import Disassembly
import logging import logging
import re import re
import functools
from mythril.ethereum.interface.rpc.client import EthJsonRpc
from typing import Optional
LRU_CACHE_SIZE = 4096
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -10,7 +15,9 @@ log = logging.getLogger(__name__)
class DynLoader: class DynLoader:
"""The dynamic loader class.""" """The dynamic loader class."""
def __init__(self, eth, contract_loading=True, storage_loading=True): def __init__(
self, eth: Optional[EthJsonRpc], contract_loading=True, storage_loading=True
):
""" """
:param eth: :param eth:
@ -18,11 +25,11 @@ class DynLoader:
:param storage_loading: :param storage_loading:
""" """
self.eth = eth self.eth = eth
self.storage_cache = {}
self.contract_loading = contract_loading self.contract_loading = contract_loading
self.storage_loading = storage_loading self.storage_loading = storage_loading
def read_storage(self, contract_address: str, index: int): @functools.lru_cache(LRU_CACHE_SIZE)
def read_storage(self, contract_address: str, index: int) -> str:
""" """
:param contract_address: :param contract_address:
@ -30,43 +37,28 @@ class DynLoader:
:return: :return:
""" """
if not self.storage_loading: if not self.storage_loading:
raise Exception( raise ValueError(
"Cannot load from the storage when the storage_loading flag is false" "Cannot load from the storage when the storage_loading flag is false"
) )
if not self.eth:
raise ValueError("Cannot load from the storage when eth is None")
try: return self.eth.eth_getStorageAt(
contract_ref = self.storage_cache[contract_address] contract_address, position=index, block="latest"
data = contract_ref[index] )
except KeyError:
self.storage_cache[contract_address] = {}
data = self.eth.eth_getStorageAt(
contract_address, position=index, block="latest"
)
self.storage_cache[contract_address][index] = data
except IndexError:
data = self.eth.eth_getStorageAt(
contract_address, position=index, block="latest"
)
self.storage_cache[contract_address][index] = data
return data
def dynld(self, dependency_address): @functools.lru_cache(LRU_CACHE_SIZE)
def dynld(self, dependency_address: str) -> Optional[Disassembly]:
""" """
:param dependency_address: :param dependency_address:
:return: :return:
""" """
if not self.contract_loading: if not self.contract_loading:
raise ValueError("Cannot load contract when contract_loading flag is false") raise ValueError("Cannot load contract when contract_loading flag is false")
if not self.eth:
raise ValueError("Cannot load from the storage when eth is None")
log.debug("Dynld at contract " + dependency_address) log.debug("Dynld at contract %s", dependency_address)
# Ensure that dependency_address is the correct length, with 0s prepended as needed. # Ensure that dependency_address is the correct length, with 0s prepended as needed.
dependency_address = ( dependency_address = (
@ -81,7 +73,7 @@ class DynLoader:
else: else:
return None return None
log.debug("Dependency address: " + dependency_address) log.debug("Dependency address: %s", dependency_address)
code = self.eth.eth_getCode(dependency_address) code = self.eth.eth_getCode(dependency_address)

@ -1,4 +1,4 @@
from mythril.laser.smt import Solver, symbol_factory, bitvec from mythril.laser.smt import Solver, symbol_factory, UGT, UGE, ULT, ULE
import z3 import z3
import pytest import pytest
@ -42,10 +42,10 @@ def test_bitvecfunc_arithmetic(operation, expected):
(operator.le, z3.sat), (operator.le, z3.sat),
(operator.gt, z3.unsat), (operator.gt, z3.unsat),
(operator.ge, z3.sat), (operator.ge, z3.sat),
(bitvec.UGT, z3.unsat), (UGT, z3.unsat),
(bitvec.UGE, z3.sat), (UGE, z3.sat),
(bitvec.ULT, z3.unsat), (ULT, z3.unsat),
(bitvec.ULE, z3.sat), (ULE, z3.sat),
], ],
) )
def test_bitvecfunc_bitvecfunc_comparison(operation, expected): def test_bitvecfunc_bitvecfunc_comparison(operation, expected):
@ -80,3 +80,158 @@ def test_bitvecfunc_bitvecfuncval_comparison():
# Assert # Assert
assert s.check() == z3.sat assert s.check() == z3.sat
assert s.model().eval(input2.raw) == 1337 assert s.model().eval(input2.raw) == 1337
def test_bitvecfunc_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 == input2)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
def test_bitvecfunc_unequal_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 != input2)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_ext_nested_comparison():
# arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 == input2)
s.add(input3 == input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
def test_bitvecfunc_ext_unequal_nested_comparison():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 == input2)
s.add(input3 != input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_ext_unequal_nested_comparison_f():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
input3 = symbol_factory.BitVecSym("input3", 256)
input4 = symbol_factory.BitVecSym("input4", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4)
# Act
s.add(input1 != input2)
s.add(input3 == input4)
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.unsat
def test_bitvecfunc_find_input():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
# Act
s.add(input1 == symbol_factory.BitVecVal(1, 256))
s.add(bvf1 == bvf2)
# Assert
assert s.check() == z3.sat
assert s.model()[input2.raw] == 1
def test_bitvecfunc_nested_find_input():
# Arrange
s = Solver()
input1 = symbol_factory.BitVecSym("input1", 256)
input2 = symbol_factory.BitVecSym("input2", 256)
bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1)
bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2)
bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3)
# Act
s.add(input1 == symbol_factory.BitVecVal(123, 256))
s.add(bvf2 == bvf4)
# Assert
assert s.check() == z3.sat
assert s.model()[input2.raw] == 123

Loading…
Cancel
Save