Make calldata compatible with partial abstraction, and add missing features to smt

pull/813/head
Joran Honig 6 years ago
parent 12ca4426d6
commit 4b0cb3e397
  1. 9
      mythril/analysis/modules/integer.py
  2. 3
      mythril/analysis/solver.py
  3. 4
      mythril/laser/ethereum/call.py
  4. 5
      mythril/laser/ethereum/instructions.py
  5. 1
      mythril/laser/ethereum/keccak.py
  6. 19
      mythril/laser/ethereum/state/calldata.py
  7. 20
      mythril/laser/ethereum/state/memory.py
  8. 1
      mythril/laser/ethereum/taint_analysis.py
  9. 7
      mythril/laser/smt/__init__.py
  10. 30
      mythril/laser/smt/bitvec.py
  11. 1
      mythril/laser/smt/bool.py
  12. 3
      mythril/laser/smt/expression.py
  13. 19
      mythril/laser/smt/solver.py

@ -5,7 +5,14 @@ from mythril.exceptions import UnsatError
from mythril.laser.ethereum.taint_analysis import TaintRunner from mythril.laser.ethereum.taint_analysis import TaintRunner
from mythril.analysis.modules.base import DetectionModule from mythril.analysis.modules.base import DetectionModule
from mythril.laser.smt import BVAddNoOverflow, BVSubNoUnderflow, BVMulNoOverflow, BitVec, symbol_factory, Not from mythril.laser.smt import (
BVAddNoOverflow,
BVSubNoUnderflow,
BVMulNoOverflow,
BitVec,
symbol_factory,
Not,
)
import copy import copy
import logging import logging

@ -115,7 +115,8 @@ def get_transaction_sequence(global_state, constraints):
) )
concrete_transactions[tx_id]["call_value"] = ( concrete_transactions[tx_id]["call_value"] = (
"0x%x" % model.eval(transaction.call_value.raw, model_completion=True).as_long() "0x%x"
% model.eval(transaction.call_value.raw, model_completion=True).as_long()
) )
concrete_transactions[tx_id]["caller"] = "0x" + ( concrete_transactions[tx_id]["caller"] = "0x" + (
"%x" % model.eval(transaction.caller.raw, model_completion=True).as_long() "%x" % model.eval(transaction.caller.raw, model_completion=True).as_long()

@ -60,7 +60,9 @@ def get_call_parameters(
def get_callee_address( def get_callee_address(
global_state: GlobalState, dynamic_loader: DynLoader, symbolic_to_address: Expression global_state: GlobalState,
dynamic_loader: DynLoader,
symbolic_to_address: Expression,
): ):
""" """
Gets the address of the callee Gets the address of the callee

@ -86,9 +86,7 @@ class StateTransition(object):
value = global_state.current_transaction.gas_limit.value value = global_state.current_transaction.gas_limit.value
if value is None: if value is None:
return return
global_state.current_transaction.gas_limit = ( global_state.current_transaction.gas_limit = value
value
)
if ( if (
global_state.mstate.min_gas_used global_state.mstate.min_gas_used
>= global_state.current_transaction.gas_limit >= global_state.current_transaction.gas_limit
@ -139,7 +137,6 @@ class Instruction:
op = "swap" op = "swap"
elif self.op_code.startswith("LOG"): elif self.op_code.startswith("LOG"):
op = "log" op = "log"
print(global_state.get_current_instruction())
instruction_mutator = ( instruction_mutator = (
getattr(self, op + "_", None) getattr(self, op + "_", None)
if not post if not post

@ -1,5 +1,6 @@
from mythril.laser.smt import Expression from mythril.laser.smt import Expression
class KeccakFunctionManager: class KeccakFunctionManager:
def __init__(self): def __init__(self):
self.keccak_expression_mapping = {} self.keccak_expression_mapping = {}

@ -9,6 +9,7 @@ from mythril.laser.ethereum.util import get_concrete_int
from z3 import Model from z3 import Model
from z3.z3types import Z3Exception from z3.z3types import Z3Exception
class CalldataType(Enum): class CalldataType(Enum):
CONCRETE = 1 CONCRETE = 1
SYMBOLIC = 2 SYMBOLIC = 2
@ -142,22 +143,26 @@ class SymbolicCalldata(BaseCalldata):
Initializes the SymbolicCalldata object Initializes the SymbolicCalldata object
:param tx_id: Id of the transaction that the calldata is for. :param tx_id: Id of the transaction that the calldata is for.
""" """
self._size = BitVec(str(tx_id) + "_calldatasize", 256) self._size = symbol_factory.BitVecSym(str(tx_id) + "_calldatasize", 256)
self._calldata = Array( self._calldata = Array("{}_calldata".format(tx_id), 256, 8)
"{}_calldata".format(tx_id), 256, 8
)
super().__init__(tx_id) super().__init__(tx_id)
def _load(self, item: Union[int, Expression]) -> Any: def _load(self, item: Union[int, Expression]) -> Any:
item = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item item = symbol_factory.BitVecVal(item, 256) if isinstance(item, int) else item
return simplify(If(item < self._size, simplify(self._calldata[item]), 0)) return simplify(
If(
item < self._size,
simplify(self._calldata[item]),
symbol_factory.BitVecVal(0, 8),
)
)
def concrete(self, model: Model) -> list: def concrete(self, model: Model) -> list:
concrete_length = get_concrete_int(model.eval(self.size, model_completion=True)) concrete_length = model.eval(self.size.raw, model_completion=True).as_long()
result = [] result = []
for i in range(concrete_length): for i in range(concrete_length):
value = self._load(i) value = self._load(i)
c_value = get_concrete_int(model.eval(value, model_completion=True)) c_value = model.eval(value.raw, model_completion=True).as_long()
result.append(c_value) result.append(c_value)
return result return result

@ -1,6 +1,14 @@
from typing import Union from typing import Union
from z3 import Z3Exception from z3 import Z3Exception
from mythril.laser.smt import BitVec, symbol_factory, If, Concat, simplify, Bool, Extract from mythril.laser.smt import (
BitVec,
symbol_factory,
If,
Concat,
simplify,
Bool,
Extract,
)
from mythril.laser.ethereum import util from mythril.laser.ethereum import util
@ -36,9 +44,7 @@ class Memory:
assert result.size() == 256 assert result.size() == 256
return result return result
def write_word_at( def write_word_at(self, index: int, value: Union[int, BitVec, bool, Bool]) -> None:
self, index: int, value: Union[int, BitVec, bool, Bool]
) -> None:
""" """
Writes a 32 byte word to memory at the specified index` Writes a 32 byte word to memory at the specified index`
:param index: index to write to :param index: index to write to
@ -58,7 +64,11 @@ class Memory:
self[index : index + 32] = _bytes self[index : index + 32] = _bytes
except (Z3Exception, AttributeError): # BitVector or BoolRef except (Z3Exception, AttributeError): # BitVector or BoolRef
if isinstance(value, Bool): if isinstance(value, Bool):
value_to_write = If(value, symbol_factory.BitVecVal(1, 256), symbol_factory.BitVecVal(0, 256)) value_to_write = If(
value,
symbol_factory.BitVecVal(1, 256),
symbol_factory.BitVecVal(0, 256),
)
else: else:
value_to_write = value value_to_write = value
assert value_to_write.size() == 256 assert value_to_write.size() == 256

@ -8,6 +8,7 @@ from mythril.laser.ethereum.state.global_state import GlobalState
from mythril.analysis.symbolic import SymExecWrapper from mythril.analysis.symbolic import SymExecWrapper
from mythril.laser.smt import Expression from mythril.laser.smt import Expression
class TaintRecord: class TaintRecord:
""" """
TaintRecord contains tainting information for a specific (state, node) TaintRecord contains tainting information for a specific (state, node)

@ -8,11 +8,16 @@ from mythril.laser.smt.bitvec import (
URem, URem,
SRem, SRem,
UDiv, UDiv,
UGE,
Sum,
BVAddNoOverflow,
BVMulNoOverflow,
BVSubNoUnderflow,
) )
from mythril.laser.smt.expression import Expression, simplify from mythril.laser.smt.expression import Expression, simplify
from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not from mythril.laser.smt.bool import Bool, is_true, is_false, Or, Not
from mythril.laser.smt.array import K, Array, BaseArray from mythril.laser.smt.array import K, Array, BaseArray
from mythril.laser.smt.solver import Solver from mythril.laser.smt.solver import Solver, Optimize
import z3 import z3

@ -57,6 +57,8 @@ class BitVec(Expression):
def __and__(self, other: "BV") -> "BV": def __and__(self, other: "BV") -> "BV":
""" Create an and expression """ """ Create an and expression """
if not isinstance(other, BitVec):
other = BitVec(z3.BitVecVal(other, 256))
union = self.annotations + other.annotations union = self.annotations + other.annotations
return BitVec(self.raw & other.raw, annotations=union) return BitVec(self.raw & other.raw, annotations=union)
@ -99,6 +101,12 @@ class BitVec(Expression):
def If(a: Bool, b: BitVec, c: BitVec): def If(a: Bool, b: BitVec, c: BitVec):
""" Create an if-then-else expression """ """ Create an if-then-else expression """
if not isinstance(a, Expression):
a = Bool(z3.BoolVal(a))
if not isinstance(b, Expression):
b = BitVec(z3.BitVecVal(b, 256))
if not isinstance(c, Expression):
c = BitVec(z3.BitVecVal(c, 256))
union = a.annotations + b.annotations + c.annotations union = a.annotations + b.annotations + c.annotations
return BitVec(z3.If(a.raw, b.raw, c.raw), union) return BitVec(z3.If(a.raw, b.raw, c.raw), union)
@ -109,6 +117,11 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
return Bool(z3.UGT(a.raw, b.raw), annotations) return Bool(z3.UGT(a.raw, b.raw), annotations)
def UGE(a: BitVec, b:BitVec) -> Bool:
annotations = a.annotations + b.annotations
return Bool(z3.UGE(a.raw, b.raw), annotations)
def ULT(a: BitVec, b: BitVec) -> Bool: def ULT(a: BitVec, b: BitVec) -> Bool:
""" Create an unsigned less than expression """ """ Create an unsigned less than expression """
annotations = a.annotations + b.annotations annotations = a.annotations + b.annotations
@ -159,3 +172,20 @@ def Sum(*args) -> BitVec:
for bv in args: for bv in args:
annotations += bv.annotations annotations += bv.annotations
return BitVec(nraw, annotations) return BitVec(nraw, annotations)
def BVAddNoOverflow(a: BitVec, b: BitVec, signed: bool) -> Bool:
return Bool(z3.BVAddNoOverflow(a.raw, b.raw, signed))
def BVMulNoOverflow(a: BitVec, b: BitVec, signed: bool) -> Bool:
return Bool(z3.BVMulNoOverflow(a.raw, b.raw, signed))
def BVSubNoUnderflow(a: BitVec, b: BitVec, signed: bool) -> Bool:
if not isinstance(a, Expression):
a = BitVec(z3.BitVecVal(a, 256))
if not isinstance(b, Expression):
b = BitVec(z3.BitVecVal(b, 256))
return Bool(z3.BVSubNoUnderflow(a.raw, b.raw, signed))

@ -48,6 +48,7 @@ class Bool(Expression):
if self.value is not None: if self.value is not None:
return self.value return self.value
else: else:
return False
raise AttributeError("Can not evalutate symbolic bool value") raise AttributeError("Can not evalutate symbolic bool value")
def Or(a: Bool, b: Bool): def Or(a: Bool, b: Bool):

@ -26,6 +26,9 @@ class Expression:
""" Simplifies this expression """ """ Simplifies this expression """
self.raw = z3.simplify(self.raw) self.raw = z3.simplify(self.raw)
def __repr__(self):
return self.raw.__repr__()
def simplify(expression: Expression): def simplify(expression: Expression):
""" Simplifies the expression """ """ Simplifies the expression """

@ -1,5 +1,6 @@
import z3 import z3
from mythril.laser.smt.bool import Bool from mythril.laser.smt.bool import Bool
from mythril.laser.smt.expression import Expression
class Solver: class Solver:
@ -16,11 +17,17 @@ class Solver:
def add(self, constraints: list) -> None: def add(self, constraints: list) -> None:
""" Adds the constraints to this solver """ """ Adds the constraints to this solver """
if not isinstance(constraints, list):
self.raw.add(constraints.raw)
return
constraints = [c.raw for c in constraints] constraints = [c.raw for c in constraints]
self.raw.add(constraints) self.raw.add(constraints)
def append(self, constraints: list) -> None: def append(self, constraints: list) -> None:
""" Adds the constraints to this solver """ """ Adds the constraints to this solver """
if not isinstance(constraints, list):
self.raw.append(constraints.raw)
return
constraints = [c.raw for c in constraints] constraints = [c.raw for c in constraints]
self.raw.add(constraints) self.raw.add(constraints)
@ -35,3 +42,15 @@ class Solver:
def pop(self, num) -> None: def pop(self, num) -> None:
self.raw.pop(num) self.raw.pop(num)
class Optimize(Solver):
def __init__(self):
super().__init__()
self.raw = z3.Optimize()
def minimize(self, element: Expression):
self.raw.minimize(element.raw)
def maximize(self, element: Expression):
self.raw.maximize(element.raw)

Loading…
Cancel
Save