mythril/laser/ethereum: Replace z3 with smt abstraction layer

pull/813/head
Joran Honig 6 years ago
parent 8afeb4d3ed
commit 4db5e25dc4
  1. 8
      mythril/laser/ethereum/call.py
  2. 85
      mythril/laser/ethereum/instructions.py
  3. 9
      mythril/laser/ethereum/keccak.py
  4. 2
      mythril/laser/ethereum/natives.py
  5. 54
      mythril/laser/ethereum/state/calldata.py
  6. 24
      mythril/laser/ethereum/state/memory.py
  7. 22
      mythril/laser/ethereum/svm.py
  8. 11
      mythril/laser/ethereum/taint_analysis.py
  9. 45
      mythril/laser/ethereum/util.py
  10. 6
      mythril/laser/smt/__init__.py
  11. 2
      mythril/laser/smt/array.py
  12. 38
      mythril/laser/smt/bitvec.py
  13. 29
      mythril/laser/smt/bool.py
  14. 1
      mythril/laser/smt/expression.py
  15. 4
      mythril/laser/smt/solver.py
  16. 7
      tests/laser/evm_testsuite/evm_test.py

@ -1,6 +1,6 @@
import logging
from typing import Union
from z3 import simplify, ExprRef, Extract
from mythril.laser.smt import simplify, Expression
import mythril.laser.ethereum.util as util
from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.calldata import (
@ -60,7 +60,7 @@ def get_call_parameters(
def get_callee_address(
global_state: GlobalState, dynamic_loader: DynLoader, symbolic_to_address: ExprRef
global_state: GlobalState, dynamic_loader: DynLoader, symbolic_to_address: Expression
):
"""
Gets the address of the callee
@ -146,8 +146,8 @@ def get_callee_account(
def get_call_data(
global_state: GlobalState,
memory_start: Union[int, ExprRef],
memory_size: Union[int, ExprRef],
memory_start: Union[int, Expression],
memory_size: Union[int, Expression],
):
"""
Gets call_data from the global_state

@ -6,27 +6,25 @@ from typing import Callable, List, Union
from functools import reduce
from ethereum import utils
from z3 import (
from mythril.laser.smt import (
Extract,
Expression,
UDiv,
simplify,
Concat,
ULT,
UGT,
BitVecRef,
BitVecNumRef,
Not,
BitVec,
is_true,
is_false,
is_expr,
ExprRef,
URem,
SRem,
is_true,
If,
BoolRef,
Bool,
Or,
Not,
)
from mythril.laser.smt import symbol_factory
import mythril.laser.ethereum.natives as natives
@ -84,13 +82,13 @@ class StateTransition(object):
@staticmethod
def check_gas_usage_limit(global_state: GlobalState):
global_state.mstate.check_gas()
if isinstance(global_state.current_transaction.gas_limit, BitVecRef):
try:
global_state.current_transaction.gas_limit = (
global_state.current_transaction.gas_limit.as_long()
)
except AttributeError:
if isinstance(global_state.current_transaction.gas_limit, BitVec):
value = global_state.current_transaction.gas_limit.value
if value is None:
return
global_state.current_transaction.gas_limit = (
value
)
if (
global_state.mstate.min_gas_used
>= global_state.current_transaction.gas_limit
@ -141,7 +139,7 @@ class Instruction:
op = "swap"
elif self.op_code.startswith("LOG"):
op = "log"
print(global_state.get_current_instruction())
instruction_mutator = (
getattr(self, op + "_", None)
if not post
@ -195,11 +193,11 @@ class Instruction:
def and_(self, global_state: GlobalState) -> List[GlobalState]:
stack = global_state.mstate.stack
op1, op2 = stack.pop(), stack.pop()
if type(op1) == BoolRef:
if isinstance(op1, Bool):
op1 = If(
op1, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
if type(op2) == BoolRef:
if isinstance(op2, Bool):
op2 = If(
op2, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
@ -212,12 +210,12 @@ class Instruction:
stack = global_state.mstate.stack
op1, op2 = stack.pop(), stack.pop()
if type(op1) == BoolRef:
if isinstance(op1, Bool):
op1 = If(
op1, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
if type(op2) == BoolRef:
if isinstance(op2, Bool):
op2 = If(
op2, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
@ -235,14 +233,14 @@ class Instruction:
@StateTransition()
def not_(self, global_state: GlobalState):
mstate = global_state.mstate
mstate.stack.append(TT256M1 - mstate.stack.pop())
mstate.stack.append(symbol_factory.BitVecVal(TT256M1, 256) - mstate.stack.pop())
return [global_state]
@StateTransition()
def byte_(self, global_state: GlobalState) -> List[GlobalState]:
mstate = global_state.mstate
op0, op1 = mstate.stack.pop(), mstate.stack.pop()
if not isinstance(op1, ExprRef):
if not isinstance(op1, Expression):
op1 = symbol_factory.BitVecVal(op1, 256)
try:
index = util.get_concrete_int(op0)
@ -363,7 +361,7 @@ class Instruction:
state = global_state.mstate
base, exponent = util.pop_bitvec(state), util.pop_bitvec(state)
if (type(base) != BitVecNumRef) or (type(exponent) != BitVecNumRef):
if base.symbolic or exponent.symbolic:
state.stack.append(
global_state.new_bitvec(
"(" + str(simplify(base)) + ")**(" + str(simplify(exponent)) + ")",
@ -372,7 +370,7 @@ class Instruction:
)
else:
state.stack.append(pow(base.as_long(), exponent.as_long(), 2 ** 256))
state.stack.append(pow(base.value, exponent.value, 2 ** 256))
return [global_state]
@ -409,7 +407,8 @@ class Instruction:
@StateTransition()
def gt_(self, global_state: GlobalState) -> List[GlobalState]:
state = global_state.mstate
exp = UGT(util.pop_bitvec(state), util.pop_bitvec(state))
op1, op2 = util.pop_bitvec(state), util.pop_bitvec(state)
exp = UGT(op1, op2)
state.stack.append(exp)
return [global_state]
@ -435,12 +434,12 @@ class Instruction:
op1 = state.stack.pop()
op2 = state.stack.pop()
if type(op1) == BoolRef:
if isinstance(op1, Bool):
op1 = If(
op1, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
if type(op2) == BoolRef:
if isinstance(op2, Bool):
op2 = If(
op2, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
@ -455,7 +454,7 @@ class Instruction:
state = global_state.mstate
val = state.stack.pop()
exp = val == False if type(val) == BoolRef else val == 0
exp = (val == False) if isinstance(val, Bool) else val == 0
state.stack.append(exp)
return [global_state]
@ -629,7 +628,7 @@ class Instruction:
index, length = util.get_concrete_int(op0), util.get_concrete_int(op1)
except TypeError:
# Can't access symbolic memory offsets
if is_expr(op0):
if isinstance(op0, Expression):
op0 = simplify(op0)
state.stack.append(
symbol_factory.BitVecSym("KECCAC_mem[" + str(op0) + "]", 256)
@ -960,7 +959,7 @@ class Instruction:
@staticmethod
def _sload_helper(
global_state: GlobalState, index: Union[int, ExprRef], constraints=None
global_state: GlobalState, index: Union[int, Expression], constraints=None
):
try:
data = global_state.environment.active_account.storage[index]
@ -1049,7 +1048,7 @@ class Instruction:
] = global_state.environment.active_account
global_state.environment.active_account.storage[index] = (
value if not isinstance(value, ExprRef) else simplify(value)
value if not isinstance(value, Expression) else simplify(value)
)
except KeyError:
logging.debug("Error writing to storage: Invalid index")
@ -1113,11 +1112,11 @@ class Instruction:
# False case
negated = (
simplify(Not(condition)) if type(condition) == BoolRef else condition == 0
simplify(Not(condition)) if isinstance(condition, Bool) else condition == 0
)
if (type(negated) == bool and negated) or (
type(negated) == BoolRef and not is_false(negated)
isinstance(condition, Bool) and not is_false(negated)
):
new_state = copy(global_state)
# add JUMPI gas cost
@ -1142,10 +1141,10 @@ class Instruction:
instr = disassembly.instruction_list[index]
condi = simplify(condition) if type(condition) == BoolRef else condition != 0
condi = simplify(condition) if isinstance(condition, Bool) else condition != 0
if instr["opcode"] == "JUMPDEST":
if (type(condi) == bool and condi) or (
type(condi) == BoolRef and not is_false(condi)
isinstance(condi, Bool) and not is_false(condi)
):
new_state = copy(global_state)
# add JUMPI gas cost
@ -1216,8 +1215,8 @@ class Instruction:
account_created = False
# Often the target of the suicide instruction will be symbolic
# If it isn't then well transfer the balance to the indicated contract
if isinstance(target, BitVecNumRef):
target = "0x" + hex(target.as_long())[-40:]
if isinstance(target, BitVec) and not target.symbolic:
target = "0x" + hex(target.value)[-40:]
if isinstance(target, str):
try:
global_state.world_state[
@ -1385,12 +1384,12 @@ class Instruction:
try:
memory_out_offset = (
util.get_concrete_int(memory_out_offset)
if isinstance(memory_out_offset, ExprRef)
if isinstance(memory_out_offset, Expression)
else memory_out_offset
)
memory_out_size = (
util.get_concrete_int(memory_out_size)
if isinstance(memory_out_size, ExprRef)
if isinstance(memory_out_size, Expression)
else memory_out_size
)
except TypeError:
@ -1480,12 +1479,12 @@ class Instruction:
try:
memory_out_offset = (
util.get_concrete_int(memory_out_offset)
if isinstance(memory_out_offset, ExprRef)
if isinstance(memory_out_offset, Expression)
else memory_out_offset
)
memory_out_size = (
util.get_concrete_int(memory_out_size)
if isinstance(memory_out_size, ExprRef)
if isinstance(memory_out_size, Expression)
else memory_out_size
)
except TypeError:
@ -1574,12 +1573,12 @@ class Instruction:
try:
memory_out_offset = (
util.get_concrete_int(memory_out_offset)
if isinstance(memory_out_offset, ExprRef)
if isinstance(memory_out_offset, Expression)
else memory_out_offset
)
memory_out_size = (
util.get_concrete_int(memory_out_size)
if isinstance(memory_out_size, ExprRef)
if isinstance(memory_out_size, Expression)
else memory_out_size
)
except TypeError:

@ -1,18 +1,17 @@
from z3 import ExprRef
from mythril.laser.smt import Expression
class KeccakFunctionManager:
def __init__(self):
self.keccak_expression_mapping = {}
def is_keccak(self, expression: ExprRef) -> bool:
def is_keccak(self, expression: Expression) -> bool:
return str(expression) in self.keccak_expression_mapping.keys()
def get_argument(self, expression: str) -> ExprRef:
def get_argument(self, expression: str) -> Expression:
if not self.is_keccak(expression):
raise ValueError("Expression is not a recognized keccac result")
return self.keccak_expression_mapping[str(expression)][1]
def add_keccak(self, expression: ExprRef, argument: ExprRef) -> None:
def add_keccak(self, expression: Expression, argument: Expression) -> None:
index = str(expression)
self.keccak_expression_mapping[index] = (expression, argument)

@ -10,7 +10,7 @@ from rlp.utils import ALL_BYTES
from mythril.laser.ethereum.state.calldata import BaseCalldata, ConcreteCalldata
from mythril.laser.ethereum.util import bytearray_to_int, sha3, get_concrete_int
from z3 import Concat, simplify
from mythril.laser.smt import Concat, simplify
class NativeContractException(Exception):

@ -1,23 +1,13 @@
from enum import Enum
from typing import Union, Any
from z3 import (
BitVecRef,
BitVec,
simplify,
Concat,
If,
ExprRef,
K,
Array,
BitVecSort,
Store,
is_bv,
)
from z3.z3types import Z3Exception, Model
from mythril.laser.smt import K, Array, If, simplify, Concat, Expression, BitVec
from mythril.laser.smt import symbol_factory
from mythril.laser.ethereum.util import get_concrete_int
from z3 import Model
from z3.z3types import Z3Exception
class CalldataType(Enum):
CONCRETE = 1
@ -34,7 +24,7 @@ class BaseCalldata:
self.tx_id = tx_id
@property
def calldatasize(self) -> ExprRef:
def calldatasize(self) -> Expression:
"""
:return: Calldata size for this calldata object
"""
@ -43,13 +33,13 @@ class BaseCalldata:
return symbol_factory.BitVecVal(result, 256)
return result
def get_word_at(self, offset: int) -> ExprRef:
def get_word_at(self, offset: int) -> Expression:
""" Gets word at offset"""
parts = self[offset : offset + 32]
return simplify(Concat(parts))
def __getitem__(self, item: Union[int, slice]) -> Any:
if isinstance(item, int) or isinstance(item, ExprRef):
if isinstance(item, int) or isinstance(item, Expression):
return self._load(item)
if isinstance(item, slice):
@ -60,13 +50,13 @@ class BaseCalldata:
try:
current_index = (
start
if isinstance(start, BitVecRef)
if isinstance(start, Expression)
else symbol_factory.BitVecVal(start, 256)
)
parts = []
while simplify(current_index != stop):
element = self._load(current_index)
if not isinstance(element, ExprRef):
if not isinstance(element, Expression):
element = symbol_factory.BitVecVal(element, 8)
parts.append(element)
@ -78,11 +68,11 @@ class BaseCalldata:
raise ValueError
def _load(self, item: Union[int, ExprRef]) -> Any:
def _load(self, item: Union[int, Expression]) -> Any:
raise NotImplementedError()
@property
def size(self) -> Union[ExprRef, int]:
def size(self) -> Union[Expression, int]:
""" Returns the exact size of this calldata, this is not normalized"""
raise NotImplementedError()
@ -99,12 +89,12 @@ class ConcreteCalldata(BaseCalldata):
:param calldata: The concrete calldata content
"""
self._concrete_calldata = calldata
self._calldata = K(BitVecSort(256), symbol_factory.BitVecVal(0, 8))
self._calldata = K(256, 8, 0)
for i, element in enumerate(calldata, 0):
self._calldata = Store(self._calldata, i, element)
self._calldata[i] = element
super().__init__(tx_id)
def _load(self, item: Union[int, ExprRef]) -> BitVecSort(8):
def _load(self, item: Union[int, Expression]) -> BitVec:
item = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
return simplify(self._calldata[item])
@ -126,7 +116,7 @@ class BasicConcreteCalldata(BaseCalldata):
self._calldata = calldata
super().__init__(tx_id)
def _load(self, item: Union[int, ExprRef]) -> Any:
def _load(self, item: Union[int, Expression]) -> Any:
if isinstance(item, int):
try:
return self._calldata[item]
@ -154,11 +144,11 @@ class SymbolicCalldata(BaseCalldata):
"""
self._size = BitVec(str(tx_id) + "_calldatasize", 256)
self._calldata = Array(
"{}_calldata".format(tx_id), BitVecSort(256), BitVecSort(8)
"{}_calldata".format(tx_id), 256, 8
)
super().__init__(tx_id)
def _load(self, item: Union[int, ExprRef]) -> Any:
def _load(self, item: Union[int, Expression]) -> Any:
item = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
return simplify(If(item < self._size, simplify(self._calldata[item]), 0))
@ -173,7 +163,7 @@ class SymbolicCalldata(BaseCalldata):
return result
@property
def size(self) -> ExprRef:
def size(self) -> Expression:
return self._size
@ -187,12 +177,12 @@ class BasicSymbolicCalldata(BaseCalldata):
self._size = BitVec(str(tx_id) + "_calldatasize", 256)
super().__init__(tx_id)
def _load(self, item: Union[int, ExprRef], clean=False) -> Any:
x = BitVecVal(item, 256) if isinstance(item, int) else item
def _load(self, item: Union[int, Expression], clean=False) -> Any:
x = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
symbolic_base_value = If(
x >= self._size,
BitVecVal(0, 8),
symbol_factory.BitVecVal(0, 8),
BitVec("{}_calldata_{}".format(self.tx_id, str(item)), 8),
)
return_value = symbolic_base_value
@ -214,5 +204,5 @@ class BasicSymbolicCalldata(BaseCalldata):
return result
@property
def size(self) -> ExprRef:
def size(self) -> Expression:
return self._size

@ -1,6 +1,6 @@
from typing import Union
from z3 import BitVecRef, Extract, BitVecVal, If, BoolRef, Z3Exception, simplify, Concat
from z3 import Z3Exception
from mythril.laser.smt import BitVec, symbol_factory, If, Concat, simplify, Bool, Extract
from mythril.laser.ethereum import util
@ -14,7 +14,7 @@ class Memory:
def extend(self, size):
self._memory.extend(bytearray(size))
def get_word_at(self, index: int) -> Union[int, BitVecRef]:
def get_word_at(self, index: int) -> Union[int, BitVec]:
"""
Access a word from a specified memory index
:param index: integer representing the index to access
@ -24,11 +24,11 @@ class Memory:
return util.concrete_int_from_bytes(
bytes([util.get_concrete_int(b) for b in self[index : index + 32]]), 0
)
except TypeError:
except:
result = simplify(
Concat(
[
b if isinstance(b, BitVecRef) else BitVecVal(b, 8)
b if isinstance(b, BitVec) else symbol_factory.BitVecVal(b, 8)
for b in self[index : index + 32]
]
)
@ -37,7 +37,7 @@ class Memory:
return result
def write_word_at(
self, index: int, value: Union[int, BitVecRef, bool, BoolRef]
self, index: int, value: Union[int, BitVec, bool, Bool]
) -> None:
"""
Writes a 32 byte word to memory at the specified index`
@ -57,8 +57,8 @@ class Memory:
assert len(_bytes) == 32
self[index : index + 32] = _bytes
except (Z3Exception, AttributeError): # BitVector or BoolRef
if isinstance(value, BoolRef):
value_to_write = If(value, BitVecVal(1, 256), BitVecVal(0, 256))
if isinstance(value, Bool):
value_to_write = If(value, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256))
else:
value_to_write = value
assert value_to_write.size() == 256
@ -66,7 +66,7 @@ class Memory:
for i in range(0, value_to_write.size(), 8):
self[index + 31 - (i // 8)] = Extract(i + 7, i, value_to_write)
def __getitem__(self, item: Union[int, slice]) -> Union[BitVecRef, int, list]:
def __getitem__(self, item: Union[int, slice]) -> Union[BitVec, int, list]:
if isinstance(item, slice):
start, step, stop = item.start, item.step, item.stop
if start is None:
@ -82,7 +82,7 @@ class Memory:
except IndexError:
return 0
def __setitem__(self, key: Union[int, slice], value: Union[BitVecRef, int, list]):
def __setitem__(self, key: Union[int, slice], value: Union[BitVec, int, list]):
if isinstance(key, slice):
start, step, stop = key.start, key.step, key.stop
@ -99,6 +99,6 @@ class Memory:
else:
if isinstance(value, int):
assert 0 <= value <= 0xFF
if isinstance(value, BitVecRef):
assert value.size() == 8
# if isinstance(value, BitVec):
# assert value.size() == 8
self._memory[key] = value

@ -1,29 +1,27 @@
import logging
from collections import defaultdict
from ethereum.opcodes import opcodes
from copy import copy
from datetime import datetime, timedelta
from functools import reduce
from typing import List, Tuple, Union, Callable, Dict
from mythril.analysis.security import get_detection_modules
from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType
from mythril.laser.ethereum.evm_exceptions import StackUnderflowException
from mythril.laser.ethereum.evm_exceptions import VmException
from mythril.laser.ethereum.instructions import Instruction
from mythril.laser.ethereum.state.account import Account
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.laser.ethereum.state.world_state import WorldState
from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy
from mythril.laser.ethereum.transaction import (
TransactionStartSignal,
TransactionEndSignal,
ContractCreationTransaction,
)
from mythril.laser.ethereum.evm_exceptions import StackUnderflowException
from mythril.laser.ethereum.instructions import Instruction
from mythril.laser.ethereum.cfg import NodeFlags, Node, Edge, JumpType
from mythril.laser.ethereum.strategy.basic import DepthFirstSearchStrategy
from datetime import datetime, timedelta
from copy import copy
from mythril.laser.ethereum.transaction import (
execute_contract_creation,
execute_message_call,
)
from functools import reduce
from mythril.laser.ethereum.evm_exceptions import VmException
class SVMError(Exception):

@ -1,13 +1,12 @@
import logging
import copy
from typing import Union, List, Tuple
from z3 import ExprRef
import mythril.laser.ethereum.util as helper
from mythril.laser.ethereum.cfg import JumpType, Node
from mythril.laser.ethereum.state.environment import Environment
from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.analysis.symbolic import SymExecWrapper
from mythril.laser.smt import Expression
class TaintRecord:
"""
@ -239,7 +238,7 @@ class TaintRunner:
record.stack[l], record.stack[i] = record.stack[i], record.stack[l]
@staticmethod
def mutate_mload(record: TaintRecord, op0: ExprRef) -> None:
def mutate_mload(record: TaintRecord, op0: Expression) -> None:
_ = record.stack.pop()
try:
index = helper.get_concrete_int(op0)
@ -251,7 +250,7 @@ class TaintRunner:
record.stack.append(record.memory_tainted(index))
@staticmethod
def mutate_mstore(record: TaintRecord, op0: ExprRef) -> None:
def mutate_mstore(record: TaintRecord, op0: Expression) -> None:
_, value_taint = record.stack.pop(), record.stack.pop()
try:
index = helper.get_concrete_int(op0)
@ -262,7 +261,7 @@ class TaintRunner:
record.memory[index] = value_taint
@staticmethod
def mutate_sload(record: TaintRecord, op0: ExprRef) -> None:
def mutate_sload(record: TaintRecord, op0: Expression) -> None:
_ = record.stack.pop()
try:
index = helper.get_concrete_int(op0)
@ -274,7 +273,7 @@ class TaintRunner:
record.stack.append(record.storage_tainted(index))
@staticmethod
def mutate_sstore(record: TaintRecord, op0: ExprRef) -> None:
def mutate_sstore(record: TaintRecord, op0: Expression) -> None:
_, value_taint = record.stack.pop(), record.stack.pop()
try:
index = helper.get_concrete_int(op0)

@ -1,14 +1,6 @@
import re
from z3 import (
BitVecVal,
BoolRef,
If,
simplify,
is_false,
is_true,
ExprRef,
BitVecNumRef,
)
from mythril.laser.smt import is_false, is_true, simplify, If, BitVec, Bool, Expression
from mythril.laser.smt import symbol_factory
import logging
@ -55,13 +47,13 @@ def get_trace_line(instr: Dict, state: "MachineState") -> str:
return str(instr["address"]) + " " + instr["opcode"] + "\tSTACK: " + stack
def pop_bitvec(state: "MachineState") -> BitVecVal:
def pop_bitvec(state: "MachineState") -> BitVec:
# pop one element from stack, converting boolean expressions and
# concrete Python variables to BitVecVal
item = state.stack.pop()
if type(item) == BoolRef:
if type(item) == Bool:
return If(
item, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)
)
@ -76,29 +68,24 @@ def pop_bitvec(state: "MachineState") -> BitVecVal:
return simplify(item)
def get_concrete_int(item: Union[int, ExprRef]) -> int:
def get_concrete_int(item: Union[int, Expression]) -> int:
if isinstance(item, int):
return item
elif isinstance(item, BitVecNumRef):
return item.as_long()
elif isinstance(item, BoolRef):
simplified = simplify(item)
if is_false(simplified):
return 0
elif is_true(simplified):
return 1
else:
elif isinstance(item, BitVec):
if item.symbolic:
raise TypeError("Got a symbolic BitVecRef")
return item.value
elif isinstance(item, Bool):
value = item.value
if value is None:
raise TypeError("Symbolic boolref encountered")
try:
return simplify(item).as_long()
except AttributeError:
raise TypeError("Got a symbolic BitVecRef")
return value
def concrete_int_from_bytes(concrete_bytes: bytes, start_index: int) -> int:
concrete_bytes = [
byte.as_long() if type(byte) == BitVecNumRef else byte
byte.value if isinstance(byte, BitVec) and not byte.symbolic else byte
for byte in concrete_bytes
]
integer_bytes = concrete_bytes[start_index : start_index + 32]
@ -110,7 +97,7 @@ def concrete_int_to_bytes(val):
# logging.debug("concrete_int_to_bytes " + str(val))
if type(val) == int:
return val.to_bytes(32, byteorder="big")
return (simplify(val).as_long()).to_bytes(32, byteorder="big")
return (simplify(val).value).to_bytes(32, byteorder="big")
def bytearray_to_int(arr):

@ -10,7 +10,9 @@ from mythril.laser.smt.bitvec import (
UDiv,
)
from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not
from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver
import z3
@ -80,4 +82,4 @@ class _Z3SymbolFactory(SymbolFactory):
# This is the instance that other parts of mythril should use
symbol_factory: SymbolFactory = _Z3SymbolFactory()
symbol_factory: SymbolFactory = _SmtSymbolFactory()

@ -13,7 +13,7 @@ class BaseArray:
raise ValueError(
"Instance of BaseArray, does not support getitem with slices"
)
return z3.Select(self.raw, item.raw)
return BitVec(z3.Select(self.raw, item.raw))
def __setitem__(self, key: BitVec, value: BitVec):
""" Sets an item in the array, key can be symbolic"""

@ -13,6 +13,9 @@ class BitVec(Expression):
def __init__(self, raw, annotations=None):
super().__init__(raw, annotations)
def size(self):
return self.raw.size()
@property
def symbolic(self):
""" Returns whether this symbol doesn't have a concrete value """
@ -27,13 +30,18 @@ class BitVec(Expression):
assert isinstance(self.raw, z3.BitVecNumRef)
return self.raw.as_long()
def __add__(self, other: "BV") -> "BV":
def __add__(self, other) -> "BV":
""" Create an addition expression """
if isinstance(other, int):
return BitVec(self.raw + other, annotations=self.annotations)
union = self.annotations + other.annotations
return BitVec(self.raw + other.raw, annotations=union)
def __sub__(self, other: "BV") -> "BV":
def __sub__(self, other):
""" Create a subtraction expression """
if isinstance(other, int):
return BitVec(self.raw - other, annotations=self.annotations)
union = self.annotations + other.annotations
return BitVec(self.raw - other.raw, annotations=union)
@ -70,34 +78,50 @@ class BitVec(Expression):
def __gt__(self, other: "BV") -> Bool:
""" Create a signed greater than expression """
union = self.annotations + other.annotations
return Bool(self.raw < other.raw, annotations=union)
return Bool(self.raw > other.raw, annotations=union)
def __eq__(self, other: "BV") -> Bool:
def __eq__(self, other) -> Bool:
""" Create an equality expression """
if not isinstance(other, BitVec):
return Bool(self.raw == other, annotations=self.annotations)
union = self.annotations + other.annotations
return Bool(self.raw == other.raw, annotations=union)
def __ne__(self, other) -> Bool:
""" Create an inequality expression """
if not isinstance(other, BitVec):
return Bool(self.raw != other, annotations=self.annotations)
union = self.annotations + other.annotations
return Bool(self.raw != other.raw, annotations=union)
def If(a: Bool, b: BitVec, c: BitVec):
""" Create an if-then-else expression """
union = a.annotations + b.annotations + c.annotations
return BitVec(z3.If(a, b, c), union)
return BitVec(z3.If(a.raw, b.raw, c.raw), union)
def UGT(a: BitVec, b: BitVec) -> Bool:
""" Create an unsigned greater than expression """
annotations = a.annotations + b.annotations
return Bool(z3.UGT(a, b), annotations)
return Bool(z3.UGT(a.raw, b.raw), annotations)
def ULT(a: BitVec, b: BitVec) -> Bool:
""" Create an unsigned less than expression """
annotations = a.annotations + b.annotations
return Bool(z3.ULT(a, b), annotations)
return Bool(z3.ULT(a.raw, b.raw), annotations)
def Concat(*args) -> BitVec:
""" Create a concatenation expression """
# The following statement is used if a list is provided as an argument to concat
if len(args) == 1 and isinstance(args[0], list):
args = args[0]
nraw = z3.Concat([a.raw for a in args])
annotations = []
for bv in args:

@ -26,6 +26,7 @@ class Bool(Expression):
@property
def value(self) -> Union[bool, None]:
""" Returns the concrete value of this bool if concrete, otherwise None"""
self.simplify()
if self.is_true:
return True
elif self.is_false:
@ -33,12 +34,36 @@ class Bool(Expression):
else:
return None
def __eq__(self, other):
if isinstance(other, Expression):
return Bool(self.raw == other.raw, self.annotations + other.annotations)
return Bool(self.raw == other, self.annotations)
def __ne__(self, other):
if isinstance(other, Expression):
return Bool(self.raw != other.raw, self.annotations + other.annotations)
return Bool(self.raw != other, self.annotations)
def __bool__(self):
if self.value is not None:
return self.value
else:
raise AttributeError("Can not evalutate symbolic bool value")
def Or(a: Bool, b: Bool):
union = a.annotations + b.annotations
return Bool(z3.Or(a.raw, b.raw), annotations=union)
def Not(a: Bool):
return Bool(z3.Not(a.raw), a.annotations)
def is_false(a: Bool) -> bool:
""" Returns whether the provided bool can be simplified to false"""
return is_false(a)
return z3.is_false(a.raw)
def is_true(a: Bool) -> bool:
""" Returns whether the provided bool can be simplified to true"""
return is_true(a)
return z3.is_true(a.raw)

@ -30,3 +30,4 @@ class Expression:
def simplify(expression: Expression):
""" Simplifies the expression """
expression.simplify()
return expression

@ -14,12 +14,12 @@ class Solver:
""" Sets the timeout that will be used by this solver"""
self.raw.set(timeout=timeout)
def add(self, constraints: list[Bool]) -> None:
def add(self, constraints: list) -> None:
""" Adds the constraints to this solver """
constraints = [c.raw for c in constraints]
self.raw.add(constraints)
def append(self, constraints: list[Bool]) -> None:
def append(self, constraints: list) -> None:
""" Adds the constraints to this solver """
constraints = [c.raw for c in constraints]
self.raw.add(constraints)

@ -2,6 +2,7 @@ from mythril.laser.ethereum.svm import LaserEVM
from mythril.laser.ethereum.state.account import Account
from mythril.disassembler.disassembly import Disassembly
from mythril.laser.ethereum.transaction.concolic import execute_message_call
from mythril.laser.smt import Expression, BitVec
from mythril.analysis.solver import get_model
from datetime import datetime
@ -133,10 +134,10 @@ def test_vmtest(
for index, value in details["storage"].items():
expected = int(value, 16)
actual = account.storage[int(index, 16)]
if isinstance(actual, ExprRef):
actual = model.eval(actual)
if isinstance(actual, Expression):
actual = actual.value
actual = (
1 if actual == True else 0 if actual == False else actual
1 if actual is True else 0 if actual is False else actual
) # Comparisons should be done with == than 'is' here as actual can be a BoolRef
else:
if type(actual) == bytes:

Loading…
Cancel
Save