from mythril . laser . smt import Solver , symbol_factory , bitvec
import z3
import pytest
import operator
@pytest.mark.parametrize(
    "operation,expected",
    [
        (operator.add, z3.unsat),
        # NOTE(review): under plain modular bit-vector semantics
        # ``bvf - x == y - bvf`` with ``x != y`` is satisfiable
        # (e.g. bvf = 1, x = 0, y = 2 gives 1 == 1). Expecting unsat here
        # presumably relies on how mythril handles ``BitVec - BitVecFunc``
        # operand order -- confirm this expectation is intended.
        (operator.sub, z3.unsat),
        (operator.and_, z3.sat),
        (operator.or_, z3.sat),
        (operator.xor, z3.unsat),
    ],
)
def test_bitvecfunc_arithmetic(operation, expected):
    """Check whether ``operation(bvf, x) == operation(y, bvf)`` can hold for x != y.

    ``bvf`` is a symbolic 256-bit sha3 result over the constant 8-bit input 1.
    The parametrize table gives, per binary operator, the expected solver
    verdict for the swapped-operand equation under the constraint ``x != y``.
    """
    solver = Solver()

    hashed = symbol_factory.BitVecFuncSym(
        "bvf", "sha3", 256, input_=symbol_factory.BitVecVal(1, 8)
    )
    x = symbol_factory.BitVecSym("x", 256)
    y = symbol_factory.BitVecSym("y", 256)

    solver.add(x != y)
    left = operation(hashed, x)
    right = operation(y, hashed)
    solver.add(left == right)

    assert solver.check() == expected
@pytest.mark.parametrize(
    "operation,expected",
    [
        (operator.eq, z3.sat),
        (operator.ne, z3.unsat),
        (operator.lt, z3.unsat),
        (operator.le, z3.sat),
        (operator.gt, z3.unsat),
        (operator.ge, z3.sat),
        (bitvec.UGT, z3.unsat),
        (bitvec.UGE, z3.sat),
        (bitvec.ULT, z3.unsat),
        (bitvec.ULE, z3.sat),
    ],
)
def test_bitvecfunc_bitvecfunc_comparison(operation, expected):
    """Compare two symbolic sha3 results whose inputs are constrained equal.

    The table expects equal inputs to force equal outputs. Therefore only the
    relations that admit equality (==, <=, >=, UGE, ULE) are satisfiable, and
    the strict or inequality relations are unsat.
    """
    solver = Solver()

    first_input = symbol_factory.BitVecSym("input1", 256)
    second_input = symbol_factory.BitVecSym("input2", 256)
    first_hash = symbol_factory.BitVecFuncSym(
        "bvf1", "sha3", 256, input_=first_input
    )
    second_hash = symbol_factory.BitVecFuncSym(
        "bvf2", "sha3", 256, input_=second_input
    )

    # Constraint order is irrelevant to the solver verdict.
    solver.add(first_input == second_input)
    solver.add(operation(first_hash, second_hash))

    assert solver.check() == expected
def test_bitvecfunc_bitvecfuncval_comparison():
    """Equate a symbolic sha3 result with a concrete sha3 result.

    ``bvf2`` is a concrete sha3 value (12345678910) recorded for input 1337.
    Asserting ``bvf1 == bvf2`` is expected to be satisfiable. The test then
    checks that the model assigns the *symbolic* ``input1`` the matching
    input value 1337.
    """
    # Arrange
    s = Solver()
    input1 = symbol_factory.BitVecSym("input1", 256)
    input2 = symbol_factory.BitVecVal(1337, 256)
    bvf1 = symbol_factory.BitVecFuncSym("bvf1", "sha3", 256, input_=input1)
    bvf2 = symbol_factory.BitVecFuncVal(12345678910, "sha3", 256, input_=input2)

    # Act
    s.add(bvf1 == bvf2)

    # Assert
    assert s.check() == z3.sat
    # input2 is the constant 1337, so evaluating it in any model always yields
    # 1337 and verifies nothing. The equality must constrain the symbolic input1.
    assert s.model().eval(input1.raw) == 1337
def test_bitvecfunc_nested_comparison():
    """sha3(sha3(input1)) == sha3(sha3(input2)) is satisfiable when input1 == input2."""
    solver = Solver()

    def sha3(name, arg):
        # Symbolic 256-bit sha3 application named ``name`` over ``arg``.
        return symbol_factory.BitVecFuncSym(name, "sha3", 256, input_=arg)

    left_input = symbol_factory.BitVecSym("input1", 256)
    right_input = symbol_factory.BitVecSym("input2", 256)
    left_outer = sha3("bvf2", sha3("bvf1", left_input))
    right_outer = sha3("bvf4", sha3("bvf3", right_input))

    solver.add(left_input == right_input)
    solver.add(left_outer == right_outer)

    assert solver.check() == z3.sat
def test_bitvecfunc_unequal_nested_comparison():
    """sha3(sha3(input1)) == sha3(sha3(input2)) is unsat when input1 != input2.

    In other words, the test expects nested hashes of distinct inputs never to
    be modelled as colliding.
    """
    solver = Solver()

    def sha3(name, arg):
        # Symbolic 256-bit sha3 application named ``name`` over ``arg``.
        return symbol_factory.BitVecFuncSym(name, "sha3", 256, input_=arg)

    left_input = symbol_factory.BitVecSym("input1", 256)
    right_input = symbol_factory.BitVecSym("input2", 256)
    left_outer = sha3("bvf2", sha3("bvf1", left_input))
    right_outer = sha3("bvf4", sha3("bvf3", right_input))

    solver.add(left_input != right_input)
    solver.add(left_outer == right_outer)

    assert solver.check() == z3.unsat
def test_bitvecfunc_ext_nested_comparison():
    """Check sha3(sha3(input1) + input3) == sha3(sha3(input2) + input4) is sat.

    Both pairs of inputs are constrained equal (input1 == input2 and
    input3 == input4), so the two outer hashes are expected to be able to match.
    """
    solver = Solver()

    def sha3(name, arg):
        # Symbolic 256-bit sha3 application named ``name`` over ``arg``.
        return symbol_factory.BitVecFuncSym(name, "sha3", 256, input_=arg)

    inner_a, inner_b, offset_a, offset_b = (
        symbol_factory.BitVecSym(label, 256)
        for label in ("input1", "input2", "input3", "input4")
    )
    outer_a = sha3("bvf2", sha3("bvf1", inner_a) + offset_a)
    outer_b = sha3("bvf4", sha3("bvf3", inner_b) + offset_b)

    solver.add(inner_a == inner_b)
    solver.add(offset_a == offset_b)
    solver.add(outer_a == outer_b)

    assert solver.check() == z3.sat
def test_bitvecfunc_ext_unequal_nested_comparison():
    """Check sha3(sha3(input1) + input3) == sha3(sha3(input2) + input4) is unsat.

    The inner inputs are equal (input1 == input2) but the added offsets differ
    (input3 != input4), so the outer hash inputs differ. The test expects the
    outer hashes therefore cannot match.
    """
    solver = Solver()

    def sha3(name, arg):
        # Symbolic 256-bit sha3 application named ``name`` over ``arg``.
        return symbol_factory.BitVecFuncSym(name, "sha3", 256, input_=arg)

    inner_a, inner_b, offset_a, offset_b = (
        symbol_factory.BitVecSym(label, 256)
        for label in ("input1", "input2", "input3", "input4")
    )
    outer_a = sha3("bvf2", sha3("bvf1", inner_a) + offset_a)
    outer_b = sha3("bvf4", sha3("bvf3", inner_b) + offset_b)

    solver.add(inner_a == inner_b)
    solver.add(offset_a != offset_b)
    solver.add(outer_a == outer_b)

    assert solver.check() == z3.unsat