You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
369 lines
11 KiB
369 lines
11 KiB
import os
|
|
import platform
|
|
from ctypes import *
|
|
|
|
BN254 = 0
|
|
BLS12_381 = 5
|
|
MCLBN_FR_UNIT_SIZE = 4
|
|
MCLBN_FP_UNIT_SIZE = 6
|
|
|
|
FR_SIZE = MCLBN_FR_UNIT_SIZE
|
|
G1_SIZE = MCLBN_FP_UNIT_SIZE * 3
|
|
G2_SIZE = MCLBN_FP_UNIT_SIZE * 6
|
|
GT_SIZE = MCLBN_FP_UNIT_SIZE * 12
|
|
|
|
SEC_SIZE = FR_SIZE * 2
|
|
PUB_SIZE = G1_SIZE + G2_SIZE
|
|
G1_CIPHER_SIZE = G1_SIZE * 2
|
|
G2_CIPHER_SIZE = G2_SIZE * 2
|
|
GT_CIPHER_SIZE = GT_SIZE * 4
|
|
|
|
MCLBN_COMPILED_TIME_VAR = (MCLBN_FR_UNIT_SIZE * 10) + MCLBN_FP_UNIT_SIZE
|
|
|
|
Buffer = c_ubyte * 2304
|
|
lib = None
|
|
|
|
def init(curveType=BN254):
|
|
global lib
|
|
name = platform.system()
|
|
if name == 'Linux':
|
|
libName = 'libmclshe384_256.so'
|
|
elif name == 'Darwin':
|
|
libName = 'libmclshe384_256.dylib'
|
|
elif name == 'Windows':
|
|
libName = 'mclshe384_256.dll'
|
|
else:
|
|
raise RuntimeError("not support yet", name)
|
|
lib = cdll.LoadLibrary(libName)
|
|
ret = lib.sheInit(curveType, MCLBN_COMPILED_TIME_VAR)
|
|
if ret != 0:
|
|
raise RuntimeError("sheInit", ret)
|
|
lib.mclBn_verifyOrderG1(0)
|
|
lib.mclBn_verifyOrderG2(0)
|
|
# custom setup for a function which returns pointer
|
|
lib.shePrecomputedPublicKeyCreate.restype = c_void_p
|
|
|
|
def setRangeForDLP(hashSize):
|
|
ret = lib.sheSetRangeForDLP(hashSize)
|
|
if ret != 0:
|
|
raise RuntimeError("setRangeForDLP", ret)
|
|
|
|
def setTryNum(tryNum):
|
|
lib.sheSetTryNum(tryNum)
|
|
|
|
def _hexStr(v):
|
|
s = ""
|
|
for x in v:
|
|
s += format(x, '02x')
|
|
return s
|
|
|
|
def _serialize(self, f):
|
|
buf = Buffer()
|
|
ret = f(byref(buf), len(buf), byref(self.v))
|
|
if ret == 0:
|
|
raise RuntimeError("serialize")
|
|
return buf[0:ret]
|
|
|
|
def _deserialize(cstr, f, buf):
|
|
x = cstr()
|
|
ca = (c_ubyte * len(buf))(*buf)
|
|
ret = f(byref(x.v), byref(ca), len(buf))
|
|
if ret == 0:
|
|
raise RuntimeError("deserialize")
|
|
return x
|
|
|
|
class CipherTextG1(Structure):
|
|
_fields_ = [("v", c_ulonglong * G1_CIPHER_SIZE)]
|
|
def serialize(self):
|
|
return _serialize(self, lib.sheCipherTextG1Serialize)
|
|
def serializeToHexStr(self):
|
|
return _hexStr(self.serialize())
|
|
|
|
class CipherTextG2(Structure):
|
|
_fields_ = [("v", c_ulonglong * G2_CIPHER_SIZE)]
|
|
def serialize(self):
|
|
return _serialize(self, lib.sheCipherTextG2Serialize)
|
|
def serializeToHexStr(self):
|
|
return _hexStr(self.serialize())
|
|
|
|
class CipherTextGT(Structure):
|
|
_fields_ = [("v", c_ulonglong * GT_CIPHER_SIZE)]
|
|
def serialize(self):
|
|
return _serialize(self, lib.sheCipherTextGTSerialize)
|
|
def serializeToHexStr(self):
|
|
return _hexStr(self.serialize())
|
|
|
|
def _enc(CT, enc, encIntVec, neg, p, m):
|
|
c = CT()
|
|
if -0x80000000 <= m <= 0x7fffffff:
|
|
ret = enc(byref(c.v), p, m)
|
|
if ret != 0:
|
|
raise RuntimeError("enc", m)
|
|
return c
|
|
if m < 0:
|
|
minus = True
|
|
m = -m
|
|
else:
|
|
minus = False
|
|
if m >= 1 << (MCLBN_FR_UNIT_SIZE * 64):
|
|
raise RuntimeError("enc:too large m", m)
|
|
a = []
|
|
while m > 0:
|
|
a.append(m & 0xffffffff)
|
|
m >>= 32
|
|
ca = (c_uint * len(a))(*a)
|
|
ret = encIntVec(byref(c.v), p, byref(ca), sizeof(ca))
|
|
if ret != 0:
|
|
raise RuntimeError("enc:IntVec", m)
|
|
if minus:
|
|
ret = neg(byref(c.v), byref(c.v))
|
|
if ret != 0:
|
|
raise RuntimeError("enc:neg", m)
|
|
return c
|
|
|
|
class PrecomputedPublicKey(Structure):
|
|
def __init__(self):
|
|
self.p = 0
|
|
def create(self):
|
|
if not self.p:
|
|
self.p = c_void_p(lib.shePrecomputedPublicKeyCreate())
|
|
if self.p == 0:
|
|
raise RuntimeError("PrecomputedPublicKey::create")
|
|
def destroy(self):
|
|
lib.shePrecomputedPublicKeyDestroy(self.p)
|
|
def encG1(self, m):
|
|
return _enc(CipherTextG1, lib.shePrecomputedPublicKeyEncG1, lib.shePrecomputedPublicKeyEncIntVecG1, lib.sheNegG1, self.p, m)
|
|
def encG2(self, m):
|
|
return _enc(CipherTextG2, lib.shePrecomputedPublicKeyEncG2, lib.shePrecomputedPublicKeyEncIntVecG2, lib.sheNegG2, self.p, m)
|
|
def encGT(self, m):
|
|
return _enc(CipherTextGT, lib.shePrecomputedPublicKeyEncGT, lib.shePrecomputedPublicKeyEncIntVecGT, lib.sheNegGT, self.p, m)
|
|
|
|
class PublicKey(Structure):
|
|
_fields_ = [("v", c_ulonglong * PUB_SIZE)]
|
|
def serialize(self):
|
|
return _serialize(self, lib.shePublicKeySerialize)
|
|
def serializeToHexStr(self):
|
|
return _hexStr(self.serialize())
|
|
def encG1(self, m):
|
|
return _enc(CipherTextG1, lib.sheEncG1, lib.sheEncIntVecG1, lib.sheNegG1, byref(self.v), m)
|
|
def encG2(self, m):
|
|
return _enc(CipherTextG2, lib.sheEncG2, lib.sheEncIntVecG2, lib.sheNegG2, byref(self.v), m)
|
|
def encGT(self, m):
|
|
return _enc(CipherTextGT, lib.sheEncGT, lib.sheEncIntVecGT, lib.sheNegGT, byref(self.v), m)
|
|
def createPrecomputedPublicKey(self):
|
|
ppub = PrecomputedPublicKey()
|
|
ppub.create()
|
|
ret = lib.shePrecomputedPublicKeyInit(ppub.p, byref(self.v))
|
|
if ret != 0:
|
|
raise RuntimeError("createPrecomputedPublicKey")
|
|
return ppub
|
|
|
|
class SecretKey(Structure):
|
|
_fields_ = [("v", c_ulonglong * SEC_SIZE)]
|
|
def setByCSPRNG(self):
|
|
ret = lib.sheSecretKeySetByCSPRNG(byref(self.v))
|
|
if ret != 0:
|
|
raise RuntimeError("setByCSPRNG", ret)
|
|
def serialize(self):
|
|
return _serialize(self, lib.sheSecretKeySerialize)
|
|
def serializeToHexStr(self):
|
|
return _hexStr(self.serialize())
|
|
def getPulicKey(self):
|
|
pub = PublicKey()
|
|
lib.sheGetPublicKey(byref(pub.v), byref(self.v))
|
|
return pub
|
|
def dec(self, c):
|
|
m = c_longlong()
|
|
if isinstance(c, CipherTextG1):
|
|
ret = lib.sheDecG1(byref(m), byref(self.v), byref(c.v))
|
|
elif isinstance(c, CipherTextG2):
|
|
ret = lib.sheDecG2(byref(m), byref(self.v), byref(c.v))
|
|
elif isinstance(c, CipherTextGT):
|
|
ret = lib.sheDecGT(byref(m), byref(self.v), byref(c.v))
|
|
if ret != 0:
|
|
raise RuntimeError("dec")
|
|
return m.value
|
|
def isZero(self, c):
|
|
if isinstance(c, CipherTextG1):
|
|
return lib.sheIsZeroG1(byref(self.v), byref(c.v)) == 1
|
|
elif isinstance(c, CipherTextG2):
|
|
return lib.sheIsZeroG2(byref(self.v), byref(c.v)) == 1
|
|
elif isinstance(c, CipherTextGT):
|
|
return lib.sheIsZeroGT(byref(self.v), byref(c.v)) == 1
|
|
raise RuntimeError("dec")
|
|
|
|
def neg(c):
|
|
ret = -1
|
|
if isinstance(c, CipherTextG1):
|
|
out = CipherTextG1()
|
|
ret = lib.sheNegG1(byref(out.v), byref(c.v))
|
|
elif isinstance(c, CipherTextG2):
|
|
out = CipherTextG2()
|
|
ret = lib.sheNegG2(byref(out.v), byref(c.v))
|
|
elif isinstance(c, CipherTextGT):
|
|
out = CipherTextGT()
|
|
ret = lib.sheNegGT(byref(out.v), byref(c.v))
|
|
if ret != 0:
|
|
raise RuntimeError("neg")
|
|
return out
|
|
|
|
def add(cx, cy):
|
|
ret = -1
|
|
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG1):
|
|
out = CipherTextG1()
|
|
ret = lib.sheAddG1(byref(out.v), byref(cx.v), byref(cy.v))
|
|
elif isinstance(cx, CipherTextG2) and isinstance(cy, CipherTextG2):
|
|
out = CipherTextG2()
|
|
ret = lib.sheAddG2(byref(out.v), byref(cx.v), byref(cy.v))
|
|
elif isinstance(cx, CipherTextGT) and isinstance(cy, CipherTextGT):
|
|
out = CipherTextGT()
|
|
ret = lib.sheAddGT(byref(out.v), byref(cx.v), byref(cy.v))
|
|
if ret != 0:
|
|
raise RuntimeError("add")
|
|
return out
|
|
|
|
def sub(cx, cy):
|
|
ret = -1
|
|
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG1):
|
|
out = CipherTextG1()
|
|
ret = lib.sheSubG1(byref(out.v), byref(cx.v), byref(cy.v))
|
|
elif isinstance(cx, CipherTextG2) and isinstance(cy, CipherTextG2):
|
|
out = CipherTextG2()
|
|
ret = lib.sheSubG2(byref(out.v), byref(cx.v), byref(cy.v))
|
|
elif isinstance(cx, CipherTextGT) and isinstance(cy, CipherTextGT):
|
|
out = CipherTextGT()
|
|
ret = lib.sheSubGT(byref(out.v), byref(cx.v), byref(cy.v))
|
|
if ret != 0:
|
|
raise RuntimeError("sub")
|
|
return out
|
|
|
|
def mul(cx, cy):
|
|
ret = -1
|
|
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG2):
|
|
out = CipherTextGT()
|
|
ret = lib.sheMul(byref(out.v), byref(cx.v), byref(cy.v))
|
|
elif isinstance(cx, CipherTextG1) and (isinstance(cy, int) or isinstance(cy, long)):
|
|
return _enc(CipherTextG1, lib.sheMulG1, lib.sheMulIntVecG1, lib.sheNegG1, byref(cx.v), cy)
|
|
elif isinstance(cx, CipherTextG2) and (isinstance(cy, int) or isinstance(cy, long)):
|
|
return _enc(CipherTextG2, lib.sheMulG2, lib.sheMulIntVecG2, lib.sheNegG2, byref(cx.v), cy)
|
|
elif isinstance(cx, CipherTextGT) and (isinstance(cy, int) or isinstance(cy, long)):
|
|
return _enc(CipherTextGT, lib.sheMulGT, lib.sheMulIntVecGT, lib.sheNegGT, byref(cx.v), cy)
|
|
if ret != 0:
|
|
raise RuntimeError("mul")
|
|
return out
|
|
|
|
def deserializeToSecretKey(buf):
|
|
return _deserialize(SecretKey, lib.sheSecretKeyDeserialize, buf)
|
|
|
|
def deserializeToPublicKey(buf):
|
|
return _deserialize(PublicKey, lib.shePublicKeyDeserialize, buf)
|
|
|
|
def deserializeToCipherTextG1(buf):
|
|
return _deserialize(CipherTextG1, lib.sheCipherTextG1Deserialize, buf)
|
|
|
|
def deserializeToCipherTextG2(buf):
|
|
return _deserialize(CipherTextG2, lib.sheCipherTextG2Deserialize, buf)
|
|
|
|
def deserializeToCipherTextGT(buf):
|
|
return _deserialize(CipherTextGT, lib.sheCipherTextGTDeserialize, buf)
|
|
|
|
if __name__ == '__main__':
|
|
init(BLS12_381)
|
|
sec = SecretKey()
|
|
sec.setByCSPRNG()
|
|
print("sec=", sec.serializeToHexStr())
|
|
pub = sec.getPulicKey()
|
|
print("pub=", pub.serializeToHexStr())
|
|
if sec.serialize() != deserializeToSecretKey(sec.serialize()).serialize(): print("err-ser1")
|
|
if pub.serialize() != deserializeToPublicKey(pub.serialize()).serialize(): print("err-ser2")
|
|
|
|
m11 = 1
|
|
m12 = 5
|
|
m21 = 3
|
|
m22 = -4
|
|
c11 = pub.encG1(m11)
|
|
c12 = pub.encG1(m12)
|
|
# dec(enc) for G1
|
|
if sec.dec(c11) != m11: print("err1")
|
|
|
|
# add/sub for G1
|
|
if sec.dec(add(c11, c12)) != m11 + m12: print("err2")
|
|
if sec.dec(sub(c11, c12)) != m11 - m12: print("err3")
|
|
|
|
# add/sub for G2
|
|
c21 = pub.encG2(m21)
|
|
c22 = pub.encG2(m22)
|
|
if sec.dec(c21) != m21: print("err4")
|
|
if sec.dec(add(c21, c22)) != m21 + m22: print("err5")
|
|
if sec.dec(sub(c21, c22)) != m21 - m22: print("err6")
|
|
|
|
# mul const for G1/G2
|
|
if sec.dec(mul(c11, 3)) != m11 * 3: print("err_mul1")
|
|
if sec.dec(mul(c21, 7)) != m21 * 7: print("err_mul2")
|
|
|
|
if c11.serialize() != deserializeToCipherTextG1(c11.serialize()).serialize(): print("err-ser3")
|
|
if c21.serialize() != deserializeToCipherTextG2(c21.serialize()).serialize(): print("err-ser3")
|
|
|
|
# large integer
|
|
m1 = 0x140712384712047127412964192876419276341
|
|
m2 = -m1 + 123
|
|
c1 = pub.encG1(m1)
|
|
c2 = pub.encG1(m2)
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large11")
|
|
c1 = mul(pub.encG1(1), m1)
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large12")
|
|
|
|
c1 = pub.encG2(m1)
|
|
c2 = pub.encG2(m2)
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large21")
|
|
c1 = mul(pub.encG2(1), m1)
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large22")
|
|
|
|
c1 = pub.encGT(m1)
|
|
c2 = pub.encGT(m2)
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large31")
|
|
c1 = mul(pub.encGT(1), m1)
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large32")
|
|
if c1.serialize() != deserializeToCipherTextGT(c1.serialize()).serialize(): print("err-ser4")
|
|
|
|
mt = -56
|
|
ct = pub.encGT(mt)
|
|
if sec.dec(ct) != mt: print("err7")
|
|
|
|
# mul G1 and G2
|
|
if sec.dec(mul(c11, c21)) != m11 * m21: print("err8")
|
|
|
|
if not sec.isZero(pub.encG1(0)): print("err-zero11")
|
|
if sec.isZero(pub.encG1(3)): print("err-zero12")
|
|
if not sec.isZero(pub.encG2(0)): print("err-zero21")
|
|
if sec.isZero(pub.encG2(3)): print("err-zero22")
|
|
if not sec.isZero(pub.encGT(0)): print("err-zero31")
|
|
if sec.isZero(pub.encGT(3)): print("err-zero32")
|
|
|
|
# use precomputedPublicKey for performance
|
|
ppub = pub.createPrecomputedPublicKey()
|
|
c1 = ppub.encG1(m11)
|
|
if sec.dec(c1) != m11: print("err9")
|
|
|
|
# large integer for precomputedPublicKey
|
|
m1 = 0x140712384712047127412964192876419276341
|
|
m2 = -m1 + 123
|
|
c1 = ppub.encG1(m1)
|
|
c2 = ppub.encG1(m2)
|
|
if sec.dec(add(c1, c2)) != 123: print("err10")
|
|
c1 = ppub.encG2(m1)
|
|
c2 = ppub.encG2(m2)
|
|
if sec.dec(add(c1, c2)) != 123: print("err11")
|
|
c1 = ppub.encGT(m1)
|
|
c2 = ppub.encGT(m2)
|
|
if sec.dec(add(c1, c2)) != 123: print("err12")
|
|
|
|
import sys
|
|
if sys.version_info.major >= 3:
|
|
import timeit
|
|
N = 100000
|
|
print(str(timeit.timeit("pub.encG1(12)", number=N, globals=globals()) / float(N) * 1e3) + "msec")
|
|
print(str(timeit.timeit("ppub.encG1(12)", number=N, globals=globals()) / float(N) * 1e3) + "msec")
|
|
|
|
ppub.destroy() # necessary to avoid memory leak
|
|
|
|
|