Create helper functions for operations in bitvec.py

pull/901/head
Nathan 6 years ago
parent b9214bcb41
commit 0a992bb6d9
  1. 151
      mythril/laser/smt/bitvec.py
  2. 6
      tests/laser/smt/bitvecfunc_test.py

@ -1,10 +1,10 @@
"""This module provides classes for an SMT abstraction of bit vectors."""
from typing import Union, overload, List, cast, Any, Optional
from typing import Union, overload, List, cast, Any, Optional, Callable
import z3
from mythril.laser.smt.bool import Bool, And
from mythril.laser.smt.bool import Bool, And, Or
from mythril.laser.smt.expression import Expression
Annotations = List[Any]
@ -212,6 +212,48 @@ class BitVec(Expression[z3.BitVecRef]):
return Bool(cast(z3.BoolRef, self.raw != other.raw), annotations=union)
def _comparison_helper(
a: BitVec, b: BitVec, operation: Callable, default_value: bool, inputs_equal: bool
) -> Bool:
annotations = a.annotations + b.annotations
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(operation(a.raw, b.raw), annotations=annotations)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(default_value), annotations=annotations)
return And(
Bool(operation(a.raw, b.raw), annotations=annotations),
a.input_ == b.input_ if inputs_equal else a.input_ != b.input_,
)
return Bool(operation(a.raw, b.raw), annotations)
def _arithmetic_helper(a: BitVec, b: BitVec, operation: Callable) -> BitVec:
raw = operation(a.raw, b.raw)
union = a.annotations + b.annotations
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
elif isinstance(a, BitVecFunc):
return BitVecFunc(
raw=raw, func_name=a.func_name, input_=a.input_, annotations=union
)
elif isinstance(b, BitVecFunc):
return BitVecFunc(
raw=raw, func_name=b.func_name, input_=b.input_, annotations=union
)
return BitVec(raw, annotations=union)
def If(a: Union[Bool, bool], b: Union[BitVec, int], c: Union[BitVec, int]) -> BitVec:
"""Create an if-then-else expression.
@ -239,24 +281,7 @@ def UGT(a: BitVec, b: BitVec) -> Bool:
:param b:
:return:
"""
annotations = a.annotations + b.annotations
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(z3.UGT(a.raw, b.raw), annotations=annotations)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(False), annotations=annotations)
return And(
Bool(z3.UGT(a.raw, b.raw), annotations=annotations), a.input_ != b.input_
)
return Bool(z3.UGT(a.raw, b.raw), annotations)
return _comparison_helper(a, b, z3.UGT, default_value=False, inputs_equal=False)
def UGE(a: BitVec, b: BitVec) -> Bool:
@ -266,24 +291,7 @@ def UGE(a: BitVec, b: BitVec) -> Bool:
:param b:
:return:
"""
annotations = a.annotations + b.annotations
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(z3.UGE(a.raw, b.raw), annotations=annotations)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(False), annotations=annotations)
return And(
Bool(z3.UGE(a.raw, b.raw), annotations=annotations), a.input_ != b.input_
)
return Bool(z3.UGE(a.raw, b.raw), annotations)
return Or(UGT(a, b), a == b)
def ULT(a: BitVec, b: BitVec) -> Bool:
@ -293,24 +301,17 @@ def ULT(a: BitVec, b: BitVec) -> Bool:
:param b:
:return:
"""
annotations = a.annotations + b.annotations
if isinstance(a, BitVecFunc):
if not a.symbolic and not b.symbolic:
return Bool(z3.ULT(a.raw, b.raw), annotations=annotations)
return _comparison_helper(a, b, z3.ULT, default_value=False, inputs_equal=False)
if (
not isinstance(b, BitVecFunc)
or not a.func_name
or not a.input_
or not a.func_name == b.func_name
):
return Bool(z3.BoolVal(False), annotations=annotations)
return And(
Bool(z3.ULT(a.raw, b.raw), annotations=annotations), a.input_ != b.input_
)
def ULE(a: BitVec, b: BitVec) -> Bool:
"""Create an unsigned less than expression.
return Bool(z3.ULT(a.raw, b.raw), annotations)
:param a:
:param b:
:return:
"""
return Or(ULT(a, b), a == b)
@overload
@ -343,7 +344,9 @@ def Concat(*args: Union[BitVec, List[BitVec]]) -> BitVec:
if bitvecfunc:
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(raw=nraw, func_name=None, input_=None, annotations=annotations)
return BitVecFunc(
raw=nraw, func_name=None, input_=None, annotations=annotations
)
return BitVec(nraw, annotations)
@ -359,7 +362,9 @@ def Extract(high: int, low: int, bv: BitVec) -> BitVec:
raw = z3.Extract(high, low, bv.raw)
if isinstance(bv, BitVecFunc):
# Is there a better value to set func_name and input to in this case?
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=bv.annotations)
return BitVecFunc(
raw=raw, func_name=None, input_=None, annotations=bv.annotations
)
return BitVec(raw, annotations=bv.annotations)
@ -371,17 +376,7 @@ def URem(a: BitVec, b: BitVec) -> BitVec:
:param b:
:return:
"""
raw = z3.URem(a.raw, b.raw)
union = a.annotations + b.annotations
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
elif isinstance(a, BitVecFunc):
return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union)
elif isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union)
return BitVec(raw, annotations=union)
return _arithmetic_helper(a, b, z3.URem)
def SRem(a: BitVec, b: BitVec) -> BitVec:
@ -391,17 +386,7 @@ def SRem(a: BitVec, b: BitVec) -> BitVec:
:param b:
:return:
"""
raw = z3.SRem(a.raw, b.raw)
union = a.annotations + b.annotations
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
elif isinstance(a, BitVecFunc):
return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union)
elif isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union)
return BitVec(raw, annotations=union)
return _arithmetic_helper(a, b, z3.SRem)
def UDiv(a: BitVec, b: BitVec) -> BitVec:
@ -411,17 +396,7 @@ def UDiv(a: BitVec, b: BitVec) -> BitVec:
:param b:
:return:
"""
raw = z3.UDiv(a.raw, b.raw)
union = a.annotations + b.annotations
if isinstance(a, BitVecFunc) and isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=None, input_=None, annotations=union)
elif isinstance(a, BitVecFunc):
return BitVecFunc(raw=raw, func_name=a.func_name, input_=a.input_, annotations=union)
elif isinstance(b, BitVecFunc):
return BitVecFunc(raw=raw, func_name=b.func_name, input_=b.input_, annotations=union)
return BitVec(raw, annotations=union)
return _arithmetic_helper(a, b, z3.UDiv)
def Sum(*args: BitVec) -> BitVec:

@ -1,4 +1,4 @@
from mythril.laser.smt import Solver, symbol_factory
from mythril.laser.smt import Solver, symbol_factory, bitvec
import z3
import pytest
@ -42,6 +42,10 @@ def test_bitvecfunc_arithmetic(operation, expected):
(operator.le, z3.sat),
(operator.gt, z3.unsat),
(operator.ge, z3.sat),
(bitvec.UGT, z3.unsat),
(bitvec.UGE, z3.sat),
(bitvec.ULT, z3.unsat),
(bitvec.ULE, z3.sat),
],
)
def test_bitvecfunc_bitvecfunc_comparison(operation, expected):

Loading…
Cancel
Save