diff --git a/README.md b/README.md index 586d6fd8..4eee32a3 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,45 @@ See the [Wiki](https://github.com/ConsenSys/mythril/wiki/Installation-and-Setup) ## Usage +Run: + +``` +$ myth analyze +``` + +Or: + +``` +$ myth analyze -a +``` + +Specify the maximum number of transaction to explore with `-t `. You can also set a timeout with `--execution-timeout `. Example ([source code](https://gist.github.com/b-mueller/2b251297ce88aa7628680f50f177a81a#file-killbilly-sol)): + +``` +==== Unprotected Selfdestruct ==== +SWC ID: 106 +Severity: High +Contract: KillBilly +Function name: commencekilling() +PC address: 354 +Estimated Gas Usage: 574 - 999 +The contract can be killed by anyone. +Anyone can kill this contract and withdraw its balance to an arbitrary address. +-------------------- +In file: killbilly.sol:22 + +selfdestruct(msg.sender) + +-------------------- +Transaction Sequence: + +Caller: [CREATOR], data: [CONTRACT CREATION], value: 0x0 +Caller: [ATTACKER], function: killerize(address), txdata: 0x9fa299ccbebebebebebebebebebebebedeadbeefdeadbeefdeadbeefdeadbeefdeadbeef, value: 0x0 +Caller: [ATTACKER], function: activatekillability(), txdata: 0x84057065, value: 0x0 +Caller: [ATTACKER], function: commencekilling(), txdata: 0x7c11da20, value: 0x0 +``` + + Instructions for using Mythril are found on the [Wiki](https://github.com/ConsenSys/mythril/wiki). For support or general discussions please join the Mythril community on [Discord](https://discord.gg/E3YrVtG). diff --git a/all_tests.sh b/all_tests.sh index 61cd7d67..7fd0ae81 100755 --- a/all_tests.sh +++ b/all_tests.sh @@ -7,22 +7,6 @@ assert sys.version_info[0:2] >= (3,5), \ """Please make sure you are using Python 3.5 or later. You ran with {}""".format(sys.version)' || exit $? -echo "Checking solc version..." -out=$(solc --version) || { - echo 2>&1 "Please make sure you have solc installed, version 0.4.21 or greater" - - } -case $out in - *Version:\ 0.4.2[1-9]* ) - echo $out - ;; - * ) - echo $out - echo "Please make sure your solc version is at least 0.4.21" - exit 1 - ;; -esac - echo "Checking that truffle is installed..." if ! which truffle ; then echo "Please make sure you have etherum truffle installed (npm install -g truffle)" diff --git a/mythril/__version__.py b/mythril/__version__.py index 069a8c91..3490630e 100644 --- a/mythril/__version__.py +++ b/mythril/__version__.py @@ -4,4 +4,4 @@ This file is suitable for sourcing inside POSIX shell, e.g. bash as well as for importing into Python. """ -__version__ = "v0.21.9" +__version__ = "v0.21.12" diff --git a/mythril/analysis/analysis_args.py b/mythril/analysis/analysis_args.py new file mode 100644 index 00000000..fcd5fc36 --- /dev/null +++ b/mythril/analysis/analysis_args.py @@ -0,0 +1,30 @@ +from mythril.support.support_utils import Singleton + + +class AnalysisArgs(object, metaclass=Singleton): + """ + This module helps in preventing args being sent through multiple of classes to reach analysis modules + """ + + def __init__(self): + self._loop_bound = 3 + self._solver_timeout = 10000 + + def set_loop_bound(self, loop_bound: int): + if loop_bound is not None: + self._loop_bound = loop_bound + + def set_solver_timeout(self, solver_timeout: int): + if solver_timeout is not None: + self._solver_timeout = solver_timeout + + @property + def loop_bound(self): + return self._loop_bound + + @property + def solver_timeout(self): + return self._solver_timeout + + +analysis_args = AnalysisArgs() diff --git a/mythril/analysis/modules/deprecated_ops.py b/mythril/analysis/modules/deprecated_ops.py index a0b69f98..7e495b5b 100644 --- a/mythril/analysis/modules/deprecated_ops.py +++ b/mythril/analysis/modules/deprecated_ops.py @@ -35,6 +35,7 @@ class DeprecatedOperationsModule(DetectionModule): if state.get_current_instruction()["address"] in self._cache: return issues = self._analyze_state(state) + for issue in issues: self._cache.add(issue.address) self._issues.extend(issues) @@ -74,13 +75,13 @@ class DeprecatedOperationsModule(DetectionModule): ) swc_id = DEPRECATED_FUNCTIONS_USAGE else: - return + return [] try: transaction_sequence = get_transaction_sequence( state, state.mstate.constraints ) except UnsatError: - return + return [] issue = Issue( contract=state.environment.active_account.contract_name, function_name=state.environment.active_function_name, diff --git a/mythril/analysis/modules/dos.py b/mythril/analysis/modules/dos.py index 20426727..2ee08abe 100644 --- a/mythril/analysis/modules/dos.py +++ b/mythril/analysis/modules/dos.py @@ -7,6 +7,7 @@ from mythril.analysis.swc_data import DOS_WITH_BLOCK_GAS_LIMIT from mythril.analysis.report import Issue from mythril.analysis.modules.base import DetectionModule from mythril.analysis.solver import get_transaction_sequence, UnsatError +from mythril.analysis.analysis_args import analysis_args from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.annotation import StateAnnotation from mythril.laser.ethereum import util @@ -90,7 +91,7 @@ class DosModule(DetectionModule): else: annotation.jump_targets[target] = 1 - if annotation.jump_targets[target] > 2: + if annotation.jump_targets[target] > min(2, analysis_args.loop_bound - 1): annotation.loop_start = address elif annotation.loop_start is not None: diff --git a/mythril/analysis/modules/ether_thief.py b/mythril/analysis/modules/ether_thief.py index b5dc8c02..95fbe2fb 100644 --- a/mythril/analysis/modules/ether_thief.py +++ b/mythril/analysis/modules/ether_thief.py @@ -15,8 +15,7 @@ from mythril.analysis.swc_data import UNPROTECTED_ETHER_WITHDRAWAL from mythril.exceptions import UnsatError from mythril.laser.ethereum.transaction import ContractCreationTransaction from mythril.laser.ethereum.state.global_state import GlobalState -from mythril.laser.smt import UGT, Sum, symbol_factory, BVAddNoOverflow -from mythril.laser.smt.bitvec import If +from mythril.laser.smt import UGT, Sum, symbol_factory, BVAddNoOverflow, If log = logging.getLogger(__name__) diff --git a/mythril/analysis/modules/state_change_external_calls.py b/mythril/analysis/modules/state_change_external_calls.py index 9b331dd0..c410b4c0 100644 --- a/mythril/analysis/modules/state_change_external_calls.py +++ b/mythril/analysis/modules/state_change_external_calls.py @@ -171,7 +171,9 @@ class StateChange(DetectionModule): for annotation in annotations: if not annotation.state_change_states: continue - vulnerabilities.append(annotation.get_issue(global_state)) + issue = annotation.get_issue(global_state) + if issue: + vulnerabilities.append(issue) return vulnerabilities @staticmethod diff --git a/mythril/analysis/report.py b/mythril/analysis/report.py index 7491b614..1a39c0ff 100644 --- a/mythril/analysis/report.py +++ b/mythril/analysis/report.py @@ -11,6 +11,7 @@ from mythril.analysis.swc_data import SWC_TO_TITLE from mythril.support.source_support import Source from mythril.support.start_time import StartTime from mythril.support.support_utils import get_code_hash +from mythril.support.signatures import SignatureDB from time import time log = logging.getLogger(__name__) @@ -151,6 +152,30 @@ class Issue: else: self.source_mapping = self.address + def resolve_function_names(self): + """ Resolves function names for each step """ + + if ( + self.transaction_sequence is None + or "steps" not in self.transaction_sequence + ): + return + + signatures = SignatureDB() + + for step in self.transaction_sequence["steps"]: + _hash = step["input"][:10] + + try: + sig = signatures.get(_hash) + + if len(sig) > 0: + step["name"] = sig[0] + else: + step["name"] = "unknown" + except ValueError: + step["name"] = "unknown" + class Report: """A report containing the content of multiple issues.""" @@ -187,6 +212,7 @@ class Report: """ m = hashlib.md5() m.update((issue.contract + str(issue.address) + issue.title).encode("utf-8")) + issue.resolve_function_names() self.issues[m.digest()] = issue def as_text(self): diff --git a/mythril/analysis/solver.py b/mythril/analysis/solver.py index 104dbb4c..6a14fb68 100644 --- a/mythril/analysis/solver.py +++ b/mythril/analysis/solver.py @@ -4,6 +4,7 @@ from typing import Dict, Tuple, Union from z3 import sat, unknown, FuncInterp import z3 +from mythril.analysis.analysis_args import analysis_args from mythril.laser.ethereum.state.global_state import GlobalState from mythril.laser.ethereum.state.constraints import Constraints from mythril.laser.ethereum.transaction import BaseTransaction @@ -29,7 +30,7 @@ def get_model(constraints, minimize=(), maximize=(), enforce_execution_time=True :return: """ s = Optimize() - timeout = 100000 + timeout = analysis_args.solver_timeout if enforce_execution_time: timeout = min(timeout, time_handler.time_remaining() - 500) if timeout <= 0: @@ -132,8 +133,8 @@ def _get_concrete_state(initial_accounts: Dict, min_price_dict: Dict[str, int]): data = dict() # type: Dict[str, Union[int, str]] data["nonce"] = account.nonce data["code"] = account.code.bytecode - data["storage"] = str(account.storage) - data["balance"] = min_price_dict.get(address, 0) + data["storage"] = account.storage.printable_storage + data["balance"] = hex(min_price_dict.get(address, 0)) accounts[hex(address)] = data return {"accounts": accounts} diff --git a/mythril/analysis/symbolic.py b/mythril/analysis/symbolic.py index e6c2bf24..8528a51f 100644 --- a/mythril/analysis/symbolic.py +++ b/mythril/analysis/symbolic.py @@ -47,7 +47,7 @@ class SymExecWrapper: dynloader=None, max_depth=22, execution_timeout=None, - loop_bound=2, + loop_bound=3, create_timeout=None, transaction_count=2, modules=(), diff --git a/mythril/analysis/templates/report_as_markdown.jinja2 b/mythril/analysis/templates/report_as_markdown.jinja2 index 4fd73f84..4583eb26 100644 --- a/mythril/analysis/templates/report_as_markdown.jinja2 +++ b/mythril/analysis/templates/report_as_markdown.jinja2 @@ -32,7 +32,7 @@ In file: {{ issue.filename }}:{{ issue.lineno }} {% if step == issue.tx_sequence.steps[0] and step.input != "0x" and step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %} Caller: [CREATOR], data: [CONTRACT CREATION], value: {{ step.value }} {% else %} -Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, data: {{ step.input }}, value: {{ step.value }} +Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, function: {{ step.name }}, txdata: {{ step.input }}, value: {{ step.value }} {% endif %} {% endfor %} {% endif %} diff --git a/mythril/analysis/templates/report_as_text.jinja2 b/mythril/analysis/templates/report_as_text.jinja2 index 776b661d..8fa2454a 100644 --- a/mythril/analysis/templates/report_as_text.jinja2 +++ b/mythril/analysis/templates/report_as_text.jinja2 @@ -27,7 +27,7 @@ Transaction Sequence: {% if step == issue.tx_sequence.steps[0] and step.input != "0x" and step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %} Caller: [CREATOR], data: [CONTRACT CREATION], value: {{ step.value }} {% else %} -Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, data: {{ step.input }}, value: {{ step.value }} +Caller: {% if step.origin == "0xaffeaffeaffeaffeaffeaffeaffeaffeaffeaffe" %}[CREATOR]{% elif step.origin == "0xdeadbeefdeadbeefdeadbeefdeadbeefdeadbeef" %}[ATTACKER]{% else %}[SOMEGUY]{% endif %}, function: {{ step.name }}, txdata: {{ step.input }}, value: {{ step.value }} {% endif %} {% endfor %} {% endif %} diff --git a/mythril/disassembler/disassembly.py b/mythril/disassembler/disassembly.py index 595b2583..d51927a1 100644 --- a/mythril/disassembler/disassembly.py +++ b/mythril/disassembler/disassembly.py @@ -84,10 +84,8 @@ def get_function_info( # Append with missing 0s at the beginning function_hash = "0x" + instruction_list[index]["argument"][2:].rjust(8, "0") function_names = signature_database.get(function_hash) - if len(function_names) > 1: - # In this case there was an ambiguous result - function_name = "[{}] (ambiguous)".format(", ".join(function_names)) - elif len(function_names) == 1: + + if len(function_names) > 0: function_name = function_names[0] else: function_name = "_function_" + function_hash diff --git a/mythril/interfaces/cli.py b/mythril/interfaces/cli.py index fc665b8a..f18aa254 100644 --- a/mythril/interfaces/cli.py +++ b/mythril/interfaces/cli.py @@ -346,11 +346,17 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser): default=86400, help="The amount of seconds to spend on symbolic execution", ) + options.add_argument( + "--solver-timeout", + type=int, + default=10000, + help="The maximum amount of time(in milli seconds) the solver spends for queries from analysis modules", + ) options.add_argument( "--create-timeout", type=int, default=10, - help="The amount of seconds to spend on " "the initial contract creation", + help="The amount of seconds to spend on the initial contract creation", ) options.add_argument( "-l", @@ -360,8 +366,9 @@ def create_analyzer_parser(analyzer_parser: ArgumentParser): ) options.add_argument( "--no-onchain-storage-access", + "--no-onchain-access", action="store_true", - help="turns off getting the data from onchain contracts", + help="turns off getting the data from onchain contracts (both loading storage and contract code)", ) options.add_argument( @@ -552,6 +559,8 @@ def execute_command( enable_iprof=args.enable_iprof, disable_dependency_pruning=args.disable_dependency_pruning, onchain_storage_access=not args.no_onchain_storage_access, + solver_timeout=args.solver_timeout, + requires_dynld=not args.no_onchain_storage_access, ) if not disassembler.contracts: diff --git a/mythril/interfaces/old_cli.py b/mythril/interfaces/old_cli.py index 0157c2b7..0deedc85 100644 --- a/mythril/interfaces/old_cli.py +++ b/mythril/interfaces/old_cli.py @@ -213,6 +213,12 @@ def create_parser(parser: argparse.ArgumentParser) -> None: default=2, help="Maximum number of transactions issued by laser", ) + options.add_argument( + "--solver-timeout", + type=int, + default=10000, + help="The maximum amount of time(in milli seconds) the solver spends for queries from analysis modules", + ) options.add_argument( "--execution-timeout", type=int, @@ -419,6 +425,7 @@ def execute_command( enable_iprof=args.enable_iprof, disable_dependency_pruning=args.disable_dependency_pruning, onchain_storage_access=not args.no_onchain_storage_access, + solver_timeout=args.solver_timeout, ) if args.disassemble: diff --git a/mythril/laser/ethereum/state/account.py b/mythril/laser/ethereum/state/account.py index 9e83f63a..51d61d27 100644 --- a/mythril/laser/ethereum/state/account.py +++ b/mythril/laser/ethereum/state/account.py @@ -20,6 +20,26 @@ from mythril.disassembler.disassembly import Disassembly from mythril.laser.smt import symbol_factory +class StorageRegion: + def __getitem__(self, item): + raise NotImplementedError + + def __setitem__(self, key, value): + raise NotImplementedError + + +class ArrayStorageRegion(StorageRegion): + """ An ArrayStorageRegion is a storage region that leverages smt array theory to resolve expressions""" + + pass + + +class IteStorageRegion(StorageRegion): + """ An IteStorageRegion is a storage region that uses Ite statements to implement a storage""" + + pass + + class Storage: """Storage class represents the storage of an Account.""" @@ -114,7 +134,7 @@ class Storage: key = self._sanitize(key.input_) storage[key] = value - def __deepcopy__(self, memodict={}): + def __deepcopy__(self, memodict=dict()): concrete = isinstance(self._standard_storage, K) storage = Storage( concrete=concrete, address=self.address, dynamic_loader=self.dynld diff --git a/mythril/laser/smt/__init__.py b/mythril/laser/smt/__init__.py index f441948e..6ab752ce 100644 --- a/mythril/laser/smt/__init__.py +++ b/mythril/laser/smt/__init__.py @@ -1,8 +1,10 @@ -from mythril.laser.smt.bitvec import ( - BitVec, +from mythril.laser.smt.bitvec import BitVec + +from mythril.laser.smt.bitvec_helper import ( If, UGT, ULT, + ULE, Concat, Extract, URem, @@ -15,6 +17,7 @@ from mythril.laser.smt.bitvec import ( BVSubNoUnderflow, LShR, ) + from mythril.laser.smt.bitvecfunc import BitVecFunc from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not, And diff --git a/mythril/laser/smt/array.py b/mythril/laser/smt/array.py index 00107df1..9289b290 100644 --- a/mythril/laser/smt/array.py +++ b/mythril/laser/smt/array.py @@ -8,7 +8,8 @@ default values over a certain range. from typing import cast import z3 -from mythril.laser.smt.bitvec import BitVec, If +from mythril.laser.smt.bitvec import BitVec +from mythril.laser.smt.bitvec_helper import If from mythril.laser.smt.bool import Bool diff --git a/mythril/laser/smt/bitvec.py b/mythril/laser/smt/bitvec.py index df537582..b308e863 100644 --- a/mythril/laser/smt/bitvec.py +++ b/mythril/laser/smt/bitvec.py @@ -1,10 +1,11 @@ """This module provides classes for an SMT abstraction of bit vectors.""" -from typing import Union, overload, List, Set, cast, Any, Optional, Callable from operator import lshift, rshift, ne, eq +from typing import Union, Set, cast, Any, Optional, Callable + import z3 -from mythril.laser.smt.bool import Bool, And, Or +from mythril.laser.smt.bool import Bool from mythril.laser.smt.expression import Expression Annotations = Set[Any] @@ -276,276 +277,5 @@ class BitVec(Expression[z3.BitVecRef]): return self.raw.__hash__() -def _comparison_helper( - a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool -) -> Bool: - annotations = a.annotations.union(b.annotations) - if isinstance(a, BitVecFunc): - if not a.symbolic and not b.symbolic: - return Bool(operation(a.raw, b.raw), annotations=annotations) - - if ( - not isinstance(b, BitVecFunc) - or not a.func_name - or not a.input_ - or not a.func_name == b.func_name - ): - return Bool(z3.BoolVal(default_value), annotations=annotations) - - return And( - Bool(operation(a.raw, b.raw), annotations=annotations), - a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, - ) - - return Bool(operation(a.raw, b.raw), annotations) - - -def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: - raw = operation(a.raw, b.raw) - union = a.annotations.union(b.annotations) - - if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc): - return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union) - elif isinstance(a, BitVecFunc): - return BitVecFunc( - raw=raw, func_name=a.func_name, input_=a.input_, annotations=union - ) - elif isinstance(b, BitVecFunc): - return BitVecFunc( - raw=raw, func_name=b.func_name, input_=b.input_, annotations=union - ) - - return BitVec(raw, annotations=union) - - -def LShR(a: BitVec, b: BitVec): - return _arithmetic_helper(a, b, z3.LShR) - - -def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: - """Create an if-then-else expression. - - :param a: - :param b: - :param c: - :return: - """ - # TODO: Handle BitVecFunc - - if not isinstance(a, Bool): - a = Bool(z3.BoolVal(a)) - if not isinstance(b, BitVec): - b = BitVec(z3.BitVecVal(b, 256)) - if not isinstance(c, BitVec): - c = BitVec(z3.BitVecVal(c, 256)) - union = a.annotations.union(b.annotations).union(c.annotations) - return BitVec(z3.If(a.raw, b.raw, c.raw), union) - - -def UGT(a: BitVec, b: BitVec) -> Bool: - """Create an unsigned greater than expression. - - :param a: - :param b: - :return: - """ - return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) - - -def UGE(a: BitVec, b: BitVec) -> Bool: - """Create an unsigned greater or equals expression. - - :param a: - :param b: - :return: - """ - return Or(UGT(a, b), a == b) - - -def ULT(a: BitVec, b: BitVec) -> Bool: - """Create an unsigned less than expression. - - :param a: - :param b: - :return: - """ - return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) - - -def ULE(a: BitVec, b: BitVec) -> Bool: - """Create an unsigned less than expression. - - :param a: - :param b: - :return: - """ - return Or(ULT(a, b), a == b) - - -@overload -def Concat(*args: List[BitVec]) -> BitVec: ... - - -@overload -def Concat(*args: BitVec) -> BitVec: ... - - -def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: - """Create a concatenation expression. - - :param args: - :return: - """ - # The following statement is used if a list is provided as an argument to concat - if len(args) == 1 and isinstance(args[0], list): - bvs = args[0] # type: List[BitVec] - else: - bvs = cast(List[BitVec], args) - - nraw = z3.Concat([a.raw for a in bvs]) - annotations = set() # type: Annotations - bitvecfunc = False - for bv in bvs: - annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - bitvecfunc = True - - if bitvecfunc: - # Added this not so good and misleading NOTATION to help with this - str_hash = ",".join(["hashed({})".format(hash(bv)) for bv in bvs]) - input_string = "MisleadingNotationConcat({})".format(str_hash) - - return BitVecFunc( - raw=nraw, func_name="Hybrid", input_=BitVec(z3.BitVec(input_string, 256), annotations=annotations) - ) - - return BitVec(nraw, annotations) - - -def Extract(high: int, low: int, bv: BitVec) -> BitVec: - """Create an extract expression. - - :param high: - :param low: - :param bv: - :return: - """ - raw = z3.Extract(high, low, bv.raw) - if isinstance(bv, BitVecFunc): - input_string = "MisleadingNotationExtract({}, {}, hashed({}))".format(high, low, hash(bv)) - # Is there a better value to set func_name and input to in this case? - return BitVecFunc( - raw=raw, func_name="Hybrid", input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations) - ) - - return BitVec(raw, annotations=bv.annotations) - - -def URem(a: BitVec, b: BitVec) -> BitVec: - """Create an unsigned remainder expression. - - :param a: - :param b: - :return: - """ - return _arithmetic_helper(a, b, z3.URem) - - -def SRem(a: BitVec, b: BitVec) -> BitVec: - """Create a signed remainder expression. - - :param a: - :param b: - :return: - """ - return _arithmetic_helper(a, b, z3.SRem) - - -def UDiv(a: BitVec, b: BitVec) -> BitVec: - """Create an unsigned division expression. - - :param a: - :param b: - :return: - """ - return _arithmetic_helper(a, b, z3.UDiv) - - -def Sum(*args: BitVec) -> BitVec: - """Create sum expression. - - :return: - """ - raw = z3.Sum([a.raw for a in args]) - annotations = set() # type: Annotations - bitvecfuncs = [] - - for bv in args: - annotations = annotations.union(bv.annotations) - if isinstance(bv, BitVecFunc): - bitvecfuncs.append(bv) - - if len(bitvecfuncs) >= 2: - return BitVecFunc(raw=raw, func_name="Hybrid", input_=None, annotations=annotations) - elif len(bitvecfuncs) == 1: - return BitVecFunc( - raw=raw, - func_name=bitvecfuncs[0].func_name, - input_=bitvecfuncs[0].input_, - annotations=annotations, - ) - - return BitVec(raw, annotations) - - -def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: - """Creates predicate that verifies that the addition doesn't overflow. - - :param a: - :param b: - :param signed: - :return: - """ - if not isinstance(a, BitVec): - a = BitVec(z3.BitVecVal(a, 256)) - if not isinstance(b, BitVec): - b = BitVec(z3.BitVecVal(b, 256)) - return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed)) - - -def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: - """Creates predicate that verifies that the multiplication doesn't - overflow. - - :param a: - :param b: - :param signed: - :return: - """ - if not isinstance(a, BitVec): - a = BitVec(z3.BitVecVal(a, 256)) - if not isinstance(b, BitVec): - b = BitVec(z3.BitVecVal(b, 256)) - return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) - - -def BVSubNoUnderflow( - a: Union[BitVec, int], b: Union[BitVec, int], signed: bool -) -> Bool: - """Creates predicate that verifies that the subtraction doesn't overflow. - - :param a: - :param b: - :param signed: - :return: - """ - if not isinstance(a, BitVec): - a = BitVec(z3.BitVecVal(a, 256)) - if not isinstance(b, BitVec): - b = BitVec(z3.BitVecVal(b, 256)) - - return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) - - # TODO: Fix circular import issues from mythril.laser.smt.bitvecfunc import BitVecFunc diff --git a/mythril/laser/smt/bitvec_helper.py b/mythril/laser/smt/bitvec_helper.py new file mode 100644 index 00000000..8e68e0c9 --- /dev/null +++ b/mythril/laser/smt/bitvec_helper.py @@ -0,0 +1,291 @@ +from typing import Union, overload, List, Set, cast, Any, Optional, Callable +from operator import lshift, rshift, ne, eq +import z3 + +from mythril.laser.smt.bool import Bool, And, Or +from mythril.laser.smt.bitvec import BitVec +from mythril.laser.smt.bitvecfunc import BitVecFunc +from mythril.laser.smt.bitvecfunc import _arithmetic_helper as _func_arithmetic_helper +from mythril.laser.smt.bitvecfunc import _comparison_helper as _func_comparison_helper + +Annotations = Set[Any] + + +def _comparison_helper( + a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool +) -> Bool: + annotations = a.annotations.union(b.annotations) + if isinstance(a, BitVecFunc): + return _func_comparison_helper(a, b, operation, default_value, inputs_equal) + return Bool(operation(a.raw, b.raw), annotations) + + +def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec: + raw = operation(a.raw, b.raw) + union = a.annotations.union(b.annotations) + + if isinstance(a, BitVecFunc): + return _func_arithmetic_helper(a, b, operation) + elif isinstance(b, BitVecFunc): + return _func_arithmetic_helper(b, a, operation) + + return BitVec(raw, annotations=union) + + +def LShR(a: BitVec, b: BitVec): + return _arithmetic_helper(a, b, z3.LShR) + + +def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec: + """Create an if-then-else expression. + + :param a: + :param b: + :param c: + :return: + """ + # TODO: Handle BitVecFunc + + if not isinstance(a, Bool): + a = Bool(z3.BoolVal(a)) + if not isinstance(b, BitVec): + b = BitVec(z3.BitVecVal(b, 256)) + if not isinstance(c, BitVec): + c = BitVec(z3.BitVecVal(c, 256)) + union = a.annotations.union(b.annotations).union(c.annotations) + + bvf = [] # type: List[BitVecFunc] + if isinstance(a, BitVecFunc): + bvf += [a] + if isinstance(b, BitVecFunc): + bvf += [b] + if isinstance(c, BitVecFunc): + bvf += [c] + if bvf: + raw = z3.If(a.raw, b.raw, c.raw) + nested_functions = [nf for func in bvf for nf in func.nested_functions] + bvf + return BitVecFunc(raw, func_name="Hybrid", nested_functions=nested_functions) + + return BitVec(z3.If(a.raw, b.raw, c.raw), union) + + +def UGT(a: BitVec, b: BitVec) -> Bool: + """Create an unsigned greater than expression. + + :param a: + :param b: + :return: + """ + return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False) + + +def UGE(a: BitVec, b: BitVec) -> Bool: + """Create an unsigned greater or equals expression. + + :param a: + :param b: + :return: + """ + return Or(UGT(a, b), a == b) + + +def ULT(a: BitVec, b: BitVec) -> Bool: + """Create an unsigned less than expression. + + :param a: + :param b: + :return: + """ + return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False) + + +def ULE(a: BitVec, b: BitVec) -> Bool: + """Create an unsigned less than expression. + + :param a: + :param b: + :return: + """ + return Or(ULT(a, b), a == b) + + +@overload +def Concat(*args: List[BitVec]) -> BitVec: + ... + + +@overload +def Concat(*args: BitVec) -> BitVec: + ... + + +def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec: + """Create a concatenation expression. + + :param args: + :return: + """ + # The following statement is used if a list is provided as an argument to concat + if len(args) == 1 and isinstance(args[0], list): + bvs = args[0] # type: List[BitVec] + else: + bvs = cast(List[BitVec], args) + + nraw = z3.Concat([a.raw for a in bvs]) + annotations = set() # type: Annotations + + nested_functions = [] # type: List[BitVecFunc] + for bv in bvs: + annotations = annotations.union(bv.annotations) + if isinstance(bv, BitVecFunc): + nested_functions += bv.nested_functions + nested_functions += [bv] + + if nested_functions: + return BitVecFunc( + raw=nraw, + func_name="Hybrid", + input_=BitVec(z3.BitVec("", 256), annotations=annotations), + nested_functions=nested_functions, + ) + + return BitVec(nraw, annotations) + + +def Extract(high: int, low: int, bv: BitVec) -> BitVec: + """Create an extract expression. + + :param high: + :param low: + :param bv: + :return: + """ + raw = z3.Extract(high, low, bv.raw) + if isinstance(bv, BitVecFunc): + input_string = "" + # Is there a better value to set func_name and input to in this case? + return BitVecFunc( + raw=raw, + func_name="Hybrid", + input_=BitVec(z3.BitVec(input_string, 256), annotations=bv.annotations), + nested_functions=bv.nested_functions + [bv], + ) + + return BitVec(raw, annotations=bv.annotations) + + +def URem(a: BitVec, b: BitVec) -> BitVec: + """Create an unsigned remainder expression. + + :param a: + :param b: + :return: + """ + return _arithmetic_helper(a, b, z3.URem) + + +def SRem(a: BitVec, b: BitVec) -> BitVec: + """Create a signed remainder expression. + + :param a: + :param b: + :return: + """ + return _arithmetic_helper(a, b, z3.SRem) + + +def UDiv(a: BitVec, b: BitVec) -> BitVec: + """Create an unsigned division expression. + + :param a: + :param b: + :return: + """ + return _arithmetic_helper(a, b, z3.UDiv) + + +def Sum(*args: BitVec) -> BitVec: + """Create sum expression. + + :return: + """ + raw = z3.Sum([a.raw for a in args]) + annotations = set() # type: Annotations + bitvecfuncs = [] + + for bv in args: + annotations = annotations.union(bv.annotations) + if isinstance(bv, BitVecFunc): + bitvecfuncs.append(bv) + + nested_functions = [ + nf for func in bitvecfuncs for nf in func.nested_functions + ] + bitvecfuncs + + if len(bitvecfuncs) >= 2: + return BitVecFunc( + raw=raw, + func_name="Hybrid", + input_=None, + annotations=annotations, + nested_functions=nested_functions, + ) + elif len(bitvecfuncs) == 1: + return BitVecFunc( + raw=raw, + func_name=bitvecfuncs[0].func_name, + input_=bitvecfuncs[0].input_, + annotations=annotations, + nested_functions=nested_functions, + ) + + return BitVec(raw, annotations) + + +def BVAddNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: + """Creates predicate that verifies that the addition doesn't overflow. + + :param a: + :param b: + :param signed: + :return: + """ + if not isinstance(a, BitVec): + a = BitVec(z3.BitVecVal(a, 256)) + if not isinstance(b, BitVec): + b = BitVec(z3.BitVecVal(b, 256)) + return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed)) + + +def BVMulNoOverflow(a: Union[BitVec, int], b: Union[BitVec, int], signed: bool) -> Bool: + """Creates predicate that verifies that the multiplication doesn't + overflow. + + :param a: + :param b: + :param signed: + :return: + """ + if not isinstance(a, BitVec): + a = BitVec(z3.BitVecVal(a, 256)) + if not isinstance(b, BitVec): + b = BitVec(z3.BitVecVal(b, 256)) + return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed)) + + +def BVSubNoUnderflow( + a: Union[BitVec, int], b: Union[BitVec, int], signed: bool +) -> Bool: + """Creates predicate that verifies that the subtraction doesn't overflow. + + :param a: + :param b: + :param signed: + :return: + """ + if not isinstance(a, BitVec): + a = BitVec(z3.BitVecVal(a, 256)) + if not isinstance(b, BitVec): + b = BitVec(z3.BitVecVal(b, 256)) + + return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed)) diff --git a/mythril/laser/smt/bitvecfunc.py b/mythril/laser/smt/bitvecfunc.py index 2b7b9e63..c645b146 100644 --- a/mythril/laser/smt/bitvecfunc.py +++ b/mythril/laser/smt/bitvecfunc.py @@ -1,11 +1,10 @@ -from typing import Optional, Union, cast, Callable - +import operator +from itertools import product +from typing import Optional, Union, cast, Callable, List import z3 -from mythril.laser.smt.bitvec import BitVec, Bool, And, Annotations -from mythril.laser.smt.bool import Or - -import operator +from mythril.laser.smt.bitvec import BitVec, Annotations +from mythril.laser.smt.bool import Or, Bool, And def _arithmetic_helper( @@ -26,18 +25,19 @@ def _arithmetic_helper( union = a.annotations.union(b.annotations) if isinstance(b, BitVecFunc): - # TODO: Find better value to set input and name to in this case? - input_string = "MisleadingNotationop(invhash({}) {} invhash({})".format( - hash(a), operation, hash(b) - ) return BitVecFunc( raw=raw, func_name="Hybrid", - input_=BitVec(z3.BitVec(input_string, 256), annotations=union), + input_=BitVec(z3.BitVec("", 256), annotations=union), + nested_functions=a.nested_functions + b.nested_functions + [a, b], ) return BitVecFunc( - raw=raw, func_name=a.func_name, input_=a.input_, annotations=union + raw=raw, + func_name=a.func_name, + input_=a.input_, + annotations=union, + nested_functions=a.nested_functions + [a], ) @@ -62,18 +62,55 @@ def _comparison_helper( union = a.annotations.union(b.annotations) if not a.symbolic and not b.symbolic: + if operation == z3.UGT: + operation = operator.gt + if operation == z3.ULT: + operation = operator.lt return Bool(z3.BoolVal(operation(a.value, b.value)), annotations=union) - if ( not isinstance(b, BitVecFunc) or not a.func_name or not a.input_ or not a.func_name == b.func_name + or str(operation) not in ("", "") ): return Bool(z3.BoolVal(default_value), annotations=union) + condition = True + for a_nest, b_nest in product(a.nested_functions, b.nested_functions): + if a_nest.func_name != b_nest.func_name: + continue + if a_nest.func_name == "Hybrid": + continue + # a.input (eq/neq) b.input ==> a == b + if inputs_equal: + condition = z3.And( + condition, + z3.Or( + z3.Not((a_nest.input_ == b_nest.input_).raw), + (a_nest.raw == b_nest.raw), + ), + z3.Or( + z3.Not((a_nest.raw == b_nest.raw)), + (a_nest.input_ == b_nest.input_).raw, + ), + ) + else: + condition = z3.And( + condition, + z3.Or( + z3.Not((a_nest.input_ != b_nest.input_).raw), + (a_nest.raw == b_nest.raw), + ), + z3.Or( + z3.Not((a_nest.raw == b_nest.raw)), + (a_nest.input_ != b_nest.input_).raw, + ), + ) + return And( Bool(cast(z3.BoolRef, operation(a.raw, b.raw)), annotations=union), + Bool(condition) if b.nested_functions else Bool(True), a.input_ == b.input_ if inputs_equal else a.input_ != b.input_, ) @@ -87,6 +124,7 @@ class BitVecFunc(BitVec): func_name: Optional[str], input_: "BitVec" = None, annotations: Optional[Annotations] = None, + nested_functions: Optional[List["BitVecFunc"]] = None, ): """ @@ -98,6 +136,10 @@ class BitVecFunc(BitVec): self.func_name = func_name self.input_ = input_ + self.nested_functions = nested_functions or [] + self.nested_functions = list(dict.fromkeys(self.nested_functions)) + if isinstance(input_, BitVecFunc): + self.nested_functions.extend(input_.nested_functions) super().__init__(raw, annotations) def __add__(self, other: Union[int, "BitVec"]) -> "BitVecFunc": diff --git a/mythril/mythril/mythril_analyzer.py b/mythril/mythril/mythril_analyzer.py index 3bef8203..497c4c33 100644 --- a/mythril/mythril/mythril_analyzer.py +++ b/mythril/mythril/mythril_analyzer.py @@ -10,6 +10,7 @@ from mythril.support.source_support import Source from mythril.support.loader import DynLoader from mythril.analysis.symbolic import SymExecWrapper from mythril.analysis.callgraph import generate_graph +from mythril.analysis.analysis_args import analysis_args from mythril.analysis.traceexplore import get_serializable_statespace from mythril.analysis.security import fire_lasers, retrieve_callback_issues from mythril.analysis.report import Report, Issue @@ -39,6 +40,7 @@ class MythrilAnalyzer: create_timeout: Optional[int] = None, enable_iprof: bool = False, disable_dependency_pruning: bool = False, + solver_timeout: Optional[int] = None, ): """ @@ -60,6 +62,9 @@ class MythrilAnalyzer: self.enable_iprof = enable_iprof self.disable_dependency_pruning = disable_dependency_pruning + analysis_args.set_loop_bound(loop_bound) + analysis_args.set_solver_timeout(solver_timeout) + def dump_statespace(self, contract: EVMContract = None) -> str: """ Returns serializable statespace of the contract diff --git a/mythril/support/loader.py b/mythril/support/loader.py index f46feff5..2da25d7b 100644 --- a/mythril/support/loader.py +++ b/mythril/support/loader.py @@ -3,6 +3,11 @@ and dependencies.""" from mythril.disassembler.disassembly import Disassembly import logging import re +import functools +from mythril.ethereum.interface.rpc.client import EthJsonRpc +from typing import Optional + +LRU_CACHE_SIZE = 4096 log = logging.getLogger(__name__) @@ -10,7 +15,9 @@ log = logging.getLogger(__name__) class DynLoader: """The dynamic loader class.""" - def __init__(self, eth, contract_loading=True, storage_loading=True): + def __init__( + self, eth: Optional[EthJsonRpc], contract_loading=True, storage_loading=True + ): """ :param eth: @@ -18,11 +25,11 @@ class DynLoader: :param storage_loading: """ self.eth = eth - self.storage_cache = {} self.contract_loading = contract_loading self.storage_loading = storage_loading - def read_storage(self, contract_address: str, index: int): + @functools.lru_cache(LRU_CACHE_SIZE) + def read_storage(self, contract_address: str, index: int) -> str: """ :param contract_address: @@ -30,43 +37,28 @@ class DynLoader: :return: """ if not self.storage_loading: - raise Exception( + raise ValueError( "Cannot load from the storage when the storage_loading flag is false" ) + if not self.eth: + raise ValueError("Cannot load from the storage when eth is None") - try: - contract_ref = self.storage_cache[contract_address] - data = contract_ref[index] - - except KeyError: - - self.storage_cache[contract_address] = {} - - data = self.eth.eth_getStorageAt( - contract_address, position=index, block="latest" - ) - - self.storage_cache[contract_address][index] = data - - except IndexError: - - data = self.eth.eth_getStorageAt( - contract_address, position=index, block="latest" - ) - - self.storage_cache[contract_address][index] = data - - return data + return self.eth.eth_getStorageAt( + contract_address, position=index, block="latest" + ) - def dynld(self, dependency_address): + @functools.lru_cache(LRU_CACHE_SIZE) + def dynld(self, dependency_address: str) -> Optional[Disassembly]: """ :param dependency_address: :return: """ if not self.contract_loading: raise ValueError("Cannot load contract when contract_loading flag is false") + if not self.eth: + raise ValueError("Cannot load from the storage when eth is None") - log.debug("Dynld at contract " + dependency_address) + log.debug("Dynld at contract %s", dependency_address) # Ensure that dependency_address is the correct length, with 0s prepended as needed. dependency_address = ( @@ -81,7 +73,7 @@ class DynLoader: else: return None - log.debug("Dependency address: " + dependency_address) + log.debug("Dependency address: %s", dependency_address) code = self.eth.eth_getCode(dependency_address) diff --git a/tests/laser/smt/bitvecfunc_test.py b/tests/laser/smt/bitvecfunc_test.py index ea19dad1..37217c73 100644 --- a/tests/laser/smt/bitvecfunc_test.py +++ b/tests/laser/smt/bitvecfunc_test.py @@ -1,4 +1,4 @@ -from mythril.laser.smt import Solver, symbol_factory, bitvec +from mythril.laser.smt import Solver, symbol_factory, UGT, UGE, ULT, ULE import z3 import pytest @@ -42,10 +42,10 @@ def test_bitvecfunc_arithmetic(operation, expected): (operator.le, z3.sat), (operator.gt, z3.unsat), (operator.ge, z3.sat), - (bitvec.UGT, z3.unsat), - (bitvec.UGE, z3.sat), - (bitvec.ULT, z3.unsat), - (bitvec.ULE, z3.sat), + (UGT, z3.unsat), + (UGE, z3.sat), + (ULT, z3.unsat), + (ULE, z3.sat), ], ) def test_bitvecfunc_bitvecfunc_comparison(operation, expected): @@ -80,3 +80,158 @@ def test_bitvecfunc_bitvecfuncval_comparison(): # Assert assert s.check() == z3.sat assert s.model().eval(input2.raw) == 1337 + + +def test_bitvecfunc_nested_comparison(): + # arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) + + bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) + + # Act + s.add(input1 == input2) + s.add(bvf2 == bvf4) + + # Assert + assert s.check() == z3.sat + + +def test_bitvecfunc_unequal_nested_comparison(): + # arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) + + bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) + + # Act + s.add(input1 != input2) + s.add(bvf2 == bvf4) + + # Assert + assert s.check() == z3.unsat + + +def test_bitvecfunc_ext_nested_comparison(): + # arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + input3 = symbol_factory.BitVecSym("input3", 256) + input4 = symbol_factory.BitVecSym("input4", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) + + bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) + + # Act + s.add(input1 == input2) + s.add(input3 == input4) + s.add(bvf2 == bvf4) + + # Assert + assert s.check() == z3.sat + + +def test_bitvecfunc_ext_unequal_nested_comparison(): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + input3 = symbol_factory.BitVecSym("input3", 256) + input4 = symbol_factory.BitVecSym("input4", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) + + bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) + + # Act + s.add(input1 == input2) + s.add(input3 != input4) + s.add(bvf2 == bvf4) + + # Assert + assert s.check() == z3.unsat + + +def test_bitvecfunc_ext_unequal_nested_comparison_f(): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + input3 = symbol_factory.BitVecSym("input3", 256) + input4 = symbol_factory.BitVecSym("input4", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1 + input3) + + bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3 + input4) + + # Act + s.add(input1 != input2) + s.add(input3 == input4) + s.add(bvf2 == bvf4) + + # Assert + assert s.check() == z3.unsat + + +def test_bitvecfunc_find_input(): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + + # Act + s.add(input1 == symbol_factory.BitVecVal(1, 256)) + s.add(bvf1 == bvf2) + + # Assert + assert s.check() == z3.sat + assert s.model()[input2.raw] == 1 + + +def test_bitvecfunc_nested_find_input(): + # Arrange + s = Solver() + + input1 = symbol_factory.BitVecSym("input1", 256) + input2 = symbol_factory.BitVecSym("input2", 256) + + bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1) + bvf2 = symbol_factory.BitVecFuncSym("bvf2", "sha3", 256, input_=bvf1) + + bvf3 = symbol_factory.BitVecFuncSym("bvf3", "sha3", 256, input_=input2) + bvf4 = symbol_factory.BitVecFuncSym("bvf4", "sha3", 256, input_=bvf3) + + # Act + s.add(input1 == symbol_factory.BitVecVal(123, 256)) + s.add(bvf2 == bvf4) + + # Assert + assert s.check() == z3.sat + assert s.model()[input2.raw] == 123