|
|
|
@ -0,0 +1,369 @@ |
|
|
|
|
import os |
|
|
|
|
import platform |
|
|
|
|
from ctypes import * |
|
|
|
|
|
|
|
|
|
BN254 = 0 |
|
|
|
|
BLS12_381 = 5 |
|
|
|
|
MCLBN_FR_UNIT_SIZE = 4 |
|
|
|
|
MCLBN_FP_UNIT_SIZE = 6 |
|
|
|
|
|
|
|
|
|
FR_SIZE = MCLBN_FR_UNIT_SIZE |
|
|
|
|
G1_SIZE = MCLBN_FP_UNIT_SIZE * 3 |
|
|
|
|
G2_SIZE = MCLBN_FP_UNIT_SIZE * 6 |
|
|
|
|
GT_SIZE = MCLBN_FP_UNIT_SIZE * 12 |
|
|
|
|
|
|
|
|
|
SEC_SIZE = FR_SIZE * 2 |
|
|
|
|
PUB_SIZE = G1_SIZE + G2_SIZE |
|
|
|
|
G1_CIPHER_SIZE = G1_SIZE * 2 |
|
|
|
|
G2_CIPHER_SIZE = G2_SIZE * 2 |
|
|
|
|
GT_CIPHER_SIZE = GT_SIZE * 4 |
|
|
|
|
|
|
|
|
|
MCLBN_COMPILED_TIME_VAR = (MCLBN_FR_UNIT_SIZE * 10) + MCLBN_FP_UNIT_SIZE |
|
|
|
|
|
|
|
|
|
Buffer = c_ubyte * 2304 |
|
|
|
|
lib = None |
|
|
|
|
|
|
|
|
|
def init(curveType=BN254): |
|
|
|
|
global lib |
|
|
|
|
name = platform.system() |
|
|
|
|
if name == 'Linux': |
|
|
|
|
libName = 'libmclshe384_256.so' |
|
|
|
|
elif name == 'Darwin': |
|
|
|
|
libName = 'libmclshe384_256.dylib' |
|
|
|
|
elif name == 'Windows': |
|
|
|
|
libName = 'mclshe384_256.dll' |
|
|
|
|
else: |
|
|
|
|
raise RuntimeError("not support yet", name) |
|
|
|
|
lib = cdll.LoadLibrary(libName) |
|
|
|
|
ret = lib.sheInit(curveType, MCLBN_COMPILED_TIME_VAR) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("sheInit", ret) |
|
|
|
|
lib.mclBn_verifyOrderG1(0) |
|
|
|
|
lib.mclBn_verifyOrderG2(0) |
|
|
|
|
# custom setup for a function which returns pointer |
|
|
|
|
lib.shePrecomputedPublicKeyCreate.restype = c_void_p |
|
|
|
|
|
|
|
|
|
def setRangeForDLP(hashSize): |
|
|
|
|
ret = lib.sheSetRangeForDLP(hashSize) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("setRangeForDLP", ret) |
|
|
|
|
|
|
|
|
|
def setTryNum(tryNum): |
|
|
|
|
lib.sheSetTryNum(tryNum) |
|
|
|
|
|
|
|
|
|
def _hexStr(v): |
|
|
|
|
s = "" |
|
|
|
|
for x in v: |
|
|
|
|
s += format(x, '02x') |
|
|
|
|
return s |
|
|
|
|
|
|
|
|
|
def _serialize(self, f): |
|
|
|
|
buf = Buffer() |
|
|
|
|
ret = f(byref(buf), len(buf), byref(self.v)) |
|
|
|
|
if ret == 0: |
|
|
|
|
raise RuntimeError("serialize") |
|
|
|
|
return buf[0:ret] |
|
|
|
|
|
|
|
|
|
def _deserialize(cstr, f, buf): |
|
|
|
|
x = cstr() |
|
|
|
|
ca = (c_ubyte * len(buf))(*buf) |
|
|
|
|
ret = f(byref(x.v), byref(ca), len(buf)) |
|
|
|
|
if ret == 0: |
|
|
|
|
raise RuntimeError("deserialize") |
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
class CipherTextG1(Structure): |
|
|
|
|
_fields_ = [("v", c_ulonglong * G1_CIPHER_SIZE)] |
|
|
|
|
def serialize(self): |
|
|
|
|
return _serialize(self, lib.sheCipherTextG1Serialize) |
|
|
|
|
def serializeToHexStr(self): |
|
|
|
|
return _hexStr(self.serialize()) |
|
|
|
|
|
|
|
|
|
class CipherTextG2(Structure): |
|
|
|
|
_fields_ = [("v", c_ulonglong * G2_CIPHER_SIZE)] |
|
|
|
|
def serialize(self): |
|
|
|
|
return _serialize(self, lib.sheCipherTextG2Serialize) |
|
|
|
|
def serializeToHexStr(self): |
|
|
|
|
return _hexStr(self.serialize()) |
|
|
|
|
|
|
|
|
|
class CipherTextGT(Structure): |
|
|
|
|
_fields_ = [("v", c_ulonglong * GT_CIPHER_SIZE)] |
|
|
|
|
def serialize(self): |
|
|
|
|
return _serialize(self, lib.sheCipherTextGTSerialize) |
|
|
|
|
def serializeToHexStr(self): |
|
|
|
|
return _hexStr(self.serialize()) |
|
|
|
|
|
|
|
|
|
def _enc(CT, enc, encIntVec, neg, p, m): |
|
|
|
|
c = CT() |
|
|
|
|
if -0x80000000 <= m <= 0x7fffffff: |
|
|
|
|
ret = enc(byref(c.v), p, m) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("enc", m) |
|
|
|
|
return c |
|
|
|
|
if m < 0: |
|
|
|
|
minus = True |
|
|
|
|
m = -m |
|
|
|
|
else: |
|
|
|
|
minus = False |
|
|
|
|
if m >= 1 << (MCLBN_FR_UNIT_SIZE * 64): |
|
|
|
|
raise RuntimeError("enc:too large m", m) |
|
|
|
|
a = [] |
|
|
|
|
while m > 0: |
|
|
|
|
a.append(m & 0xffffffff) |
|
|
|
|
m >>= 32 |
|
|
|
|
ca = (c_uint * len(a))(*a) |
|
|
|
|
ret = encIntVec(byref(c.v), p, byref(ca), sizeof(ca)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("enc:IntVec", m) |
|
|
|
|
if minus: |
|
|
|
|
ret = neg(byref(c.v), byref(c.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("enc:neg", m) |
|
|
|
|
return c |
|
|
|
|
|
|
|
|
|
class PrecomputedPublicKey(Structure): |
|
|
|
|
def __init__(self): |
|
|
|
|
self.p = 0 |
|
|
|
|
def create(self): |
|
|
|
|
if not self.p: |
|
|
|
|
self.p = c_void_p(lib.shePrecomputedPublicKeyCreate()) |
|
|
|
|
if self.p == 0: |
|
|
|
|
raise RuntimeError("PrecomputedPublicKey::create") |
|
|
|
|
def destroy(self): |
|
|
|
|
lib.shePrecomputedPublicKeyDestroy(self.p) |
|
|
|
|
def encG1(self, m): |
|
|
|
|
return _enc(CipherTextG1, lib.shePrecomputedPublicKeyEncG1, lib.shePrecomputedPublicKeyEncIntVecG1, lib.sheNegG1, self.p, m) |
|
|
|
|
def encG2(self, m): |
|
|
|
|
return _enc(CipherTextG2, lib.shePrecomputedPublicKeyEncG2, lib.shePrecomputedPublicKeyEncIntVecG2, lib.sheNegG2, self.p, m) |
|
|
|
|
def encGT(self, m): |
|
|
|
|
return _enc(CipherTextGT, lib.shePrecomputedPublicKeyEncGT, lib.shePrecomputedPublicKeyEncIntVecGT, lib.sheNegGT, self.p, m) |
|
|
|
|
|
|
|
|
|
class PublicKey(Structure): |
|
|
|
|
_fields_ = [("v", c_ulonglong * PUB_SIZE)] |
|
|
|
|
def serialize(self): |
|
|
|
|
return _serialize(self, lib.shePublicKeySerialize) |
|
|
|
|
def serializeToHexStr(self): |
|
|
|
|
return _hexStr(self.serialize()) |
|
|
|
|
def encG1(self, m): |
|
|
|
|
return _enc(CipherTextG1, lib.sheEncG1, lib.sheEncIntVecG1, lib.sheNegG1, byref(self.v), m) |
|
|
|
|
def encG2(self, m): |
|
|
|
|
return _enc(CipherTextG2, lib.sheEncG2, lib.sheEncIntVecG2, lib.sheNegG2, byref(self.v), m) |
|
|
|
|
def encGT(self, m): |
|
|
|
|
return _enc(CipherTextGT, lib.sheEncGT, lib.sheEncIntVecGT, lib.sheNegGT, byref(self.v), m) |
|
|
|
|
def createPrecomputedPublicKey(self): |
|
|
|
|
ppub = PrecomputedPublicKey() |
|
|
|
|
ppub.create() |
|
|
|
|
ret = lib.shePrecomputedPublicKeyInit(ppub.p, byref(self.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("createPrecomputedPublicKey") |
|
|
|
|
return ppub |
|
|
|
|
|
|
|
|
|
class SecretKey(Structure): |
|
|
|
|
_fields_ = [("v", c_ulonglong * SEC_SIZE)] |
|
|
|
|
def setByCSPRNG(self): |
|
|
|
|
ret = lib.sheSecretKeySetByCSPRNG(byref(self.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("setByCSPRNG", ret) |
|
|
|
|
def serialize(self): |
|
|
|
|
return _serialize(self, lib.sheSecretKeySerialize) |
|
|
|
|
def serializeToHexStr(self): |
|
|
|
|
return _hexStr(self.serialize()) |
|
|
|
|
def getPulicKey(self): |
|
|
|
|
pub = PublicKey() |
|
|
|
|
lib.sheGetPublicKey(byref(pub.v), byref(self.v)) |
|
|
|
|
return pub |
|
|
|
|
def dec(self, c): |
|
|
|
|
m = c_longlong() |
|
|
|
|
if isinstance(c, CipherTextG1): |
|
|
|
|
ret = lib.sheDecG1(byref(m), byref(self.v), byref(c.v)) |
|
|
|
|
elif isinstance(c, CipherTextG2): |
|
|
|
|
ret = lib.sheDecG2(byref(m), byref(self.v), byref(c.v)) |
|
|
|
|
elif isinstance(c, CipherTextGT): |
|
|
|
|
ret = lib.sheDecGT(byref(m), byref(self.v), byref(c.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("dec") |
|
|
|
|
return m.value |
|
|
|
|
def isZero(self, c): |
|
|
|
|
if isinstance(c, CipherTextG1): |
|
|
|
|
return lib.sheIsZeroG1(byref(self.v), byref(c.v)) == 1 |
|
|
|
|
elif isinstance(c, CipherTextG2): |
|
|
|
|
return lib.sheIsZeroG2(byref(self.v), byref(c.v)) == 1 |
|
|
|
|
elif isinstance(c, CipherTextGT): |
|
|
|
|
return lib.sheIsZeroGT(byref(self.v), byref(c.v)) == 1 |
|
|
|
|
raise RuntimeError("dec") |
|
|
|
|
|
|
|
|
|
def neg(c): |
|
|
|
|
ret = -1 |
|
|
|
|
if isinstance(c, CipherTextG1): |
|
|
|
|
out = CipherTextG1() |
|
|
|
|
ret = lib.sheNegG1(byref(out.v), byref(c.v)) |
|
|
|
|
elif isinstance(c, CipherTextG2): |
|
|
|
|
out = CipherTextG2() |
|
|
|
|
ret = lib.sheNegG2(byref(out.v), byref(c.v)) |
|
|
|
|
elif isinstance(c, CipherTextGT): |
|
|
|
|
out = CipherTextGT() |
|
|
|
|
ret = lib.sheNegGT(byref(out.v), byref(c.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("neg") |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def add(cx, cy): |
|
|
|
|
ret = -1 |
|
|
|
|
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG1): |
|
|
|
|
out = CipherTextG1() |
|
|
|
|
ret = lib.sheAddG1(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
elif isinstance(cx, CipherTextG2) and isinstance(cy, CipherTextG2): |
|
|
|
|
out = CipherTextG2() |
|
|
|
|
ret = lib.sheAddG2(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
elif isinstance(cx, CipherTextGT) and isinstance(cy, CipherTextGT): |
|
|
|
|
out = CipherTextGT() |
|
|
|
|
ret = lib.sheAddGT(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("add") |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def sub(cx, cy): |
|
|
|
|
ret = -1 |
|
|
|
|
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG1): |
|
|
|
|
out = CipherTextG1() |
|
|
|
|
ret = lib.sheSubG1(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
elif isinstance(cx, CipherTextG2) and isinstance(cy, CipherTextG2): |
|
|
|
|
out = CipherTextG2() |
|
|
|
|
ret = lib.sheSubG2(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
elif isinstance(cx, CipherTextGT) and isinstance(cy, CipherTextGT): |
|
|
|
|
out = CipherTextGT() |
|
|
|
|
ret = lib.sheSubGT(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("sub") |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def mul(cx, cy): |
|
|
|
|
ret = -1 |
|
|
|
|
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG2): |
|
|
|
|
out = CipherTextGT() |
|
|
|
|
ret = lib.sheMul(byref(out.v), byref(cx.v), byref(cy.v)) |
|
|
|
|
elif isinstance(cx, CipherTextG1) and (isinstance(cy, int) or isinstance(cy, long)): |
|
|
|
|
return _enc(CipherTextG1, lib.sheMulG1, lib.sheMulIntVecG1, lib.sheNegG1, byref(cx.v), cy) |
|
|
|
|
elif isinstance(cx, CipherTextG2) and (isinstance(cy, int) or isinstance(cy, long)): |
|
|
|
|
return _enc(CipherTextG2, lib.sheMulG2, lib.sheMulIntVecG2, lib.sheNegG2, byref(cx.v), cy) |
|
|
|
|
elif isinstance(cx, CipherTextGT) and (isinstance(cy, int) or isinstance(cy, long)): |
|
|
|
|
return _enc(CipherTextGT, lib.sheMulGT, lib.sheMulIntVecGT, lib.sheNegGT, byref(cx.v), cy) |
|
|
|
|
if ret != 0: |
|
|
|
|
raise RuntimeError("mul") |
|
|
|
|
return out |
|
|
|
|
|
|
|
|
|
def deserializeToSecretKey(buf): |
|
|
|
|
return _deserialize(SecretKey, lib.sheSecretKeyDeserialize, buf) |
|
|
|
|
|
|
|
|
|
def deserializeToPublicKey(buf): |
|
|
|
|
return _deserialize(PublicKey, lib.shePublicKeyDeserialize, buf) |
|
|
|
|
|
|
|
|
|
def deserializeToCipherTextG1(buf): |
|
|
|
|
return _deserialize(CipherTextG1, lib.sheCipherTextG1Deserialize, buf) |
|
|
|
|
|
|
|
|
|
def deserializeToCipherTextG2(buf): |
|
|
|
|
return _deserialize(CipherTextG2, lib.sheCipherTextG2Deserialize, buf) |
|
|
|
|
|
|
|
|
|
def deserializeToCipherTextGT(buf): |
|
|
|
|
return _deserialize(CipherTextGT, lib.sheCipherTextGTDeserialize, buf) |
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
|
init(BLS12_381) |
|
|
|
|
sec = SecretKey() |
|
|
|
|
sec.setByCSPRNG() |
|
|
|
|
print("sec=", sec.serializeToHexStr()) |
|
|
|
|
pub = sec.getPulicKey() |
|
|
|
|
print("pub=", pub.serializeToHexStr()) |
|
|
|
|
if sec.serialize() != deserializeToSecretKey(sec.serialize()).serialize(): print("err-ser1") |
|
|
|
|
if pub.serialize() != deserializeToPublicKey(pub.serialize()).serialize(): print("err-ser2") |
|
|
|
|
|
|
|
|
|
m11 = 1 |
|
|
|
|
m12 = 5 |
|
|
|
|
m21 = 3 |
|
|
|
|
m22 = -4 |
|
|
|
|
c11 = pub.encG1(m11) |
|
|
|
|
c12 = pub.encG1(m12) |
|
|
|
|
# dec(enc) for G1 |
|
|
|
|
if sec.dec(c11) != m11: print("err1") |
|
|
|
|
|
|
|
|
|
# add/sub for G1 |
|
|
|
|
if sec.dec(add(c11, c12)) != m11 + m12: print("err2") |
|
|
|
|
if sec.dec(sub(c11, c12)) != m11 - m12: print("err3") |
|
|
|
|
|
|
|
|
|
# add/sub for G2 |
|
|
|
|
c21 = pub.encG2(m21) |
|
|
|
|
c22 = pub.encG2(m22) |
|
|
|
|
if sec.dec(c21) != m21: print("err4") |
|
|
|
|
if sec.dec(add(c21, c22)) != m21 + m22: print("err5") |
|
|
|
|
if sec.dec(sub(c21, c22)) != m21 - m22: print("err6") |
|
|
|
|
|
|
|
|
|
# mul const for G1/G2 |
|
|
|
|
if sec.dec(mul(c11, 3)) != m11 * 3: print("err_mul1") |
|
|
|
|
if sec.dec(mul(c21, 7)) != m21 * 7: print("err_mul2") |
|
|
|
|
|
|
|
|
|
if c11.serialize() != deserializeToCipherTextG1(c11.serialize()).serialize(): print("err-ser3") |
|
|
|
|
if c21.serialize() != deserializeToCipherTextG2(c21.serialize()).serialize(): print("err-ser3") |
|
|
|
|
|
|
|
|
|
# large integer |
|
|
|
|
m1 = 0x140712384712047127412964192876419276341 |
|
|
|
|
m2 = -m1 + 123 |
|
|
|
|
c1 = pub.encG1(m1) |
|
|
|
|
c2 = pub.encG1(m2) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large11") |
|
|
|
|
c1 = mul(pub.encG1(1), m1) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large12") |
|
|
|
|
|
|
|
|
|
c1 = pub.encG2(m1) |
|
|
|
|
c2 = pub.encG2(m2) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large21") |
|
|
|
|
c1 = mul(pub.encG2(1), m1) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large22") |
|
|
|
|
|
|
|
|
|
c1 = pub.encGT(m1) |
|
|
|
|
c2 = pub.encGT(m2) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large31") |
|
|
|
|
c1 = mul(pub.encGT(1), m1) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err-large32") |
|
|
|
|
if c1.serialize() != deserializeToCipherTextGT(c1.serialize()).serialize(): print("err-ser4") |
|
|
|
|
|
|
|
|
|
mt = -56 |
|
|
|
|
ct = pub.encGT(mt) |
|
|
|
|
if sec.dec(ct) != mt: print("err7") |
|
|
|
|
|
|
|
|
|
# mul G1 and G2 |
|
|
|
|
if sec.dec(mul(c11, c21)) != m11 * m21: print("err8") |
|
|
|
|
|
|
|
|
|
if not sec.isZero(pub.encG1(0)): print("err-zero11") |
|
|
|
|
if sec.isZero(pub.encG1(3)): print("err-zero12") |
|
|
|
|
if not sec.isZero(pub.encG2(0)): print("err-zero21") |
|
|
|
|
if sec.isZero(pub.encG2(3)): print("err-zero22") |
|
|
|
|
if not sec.isZero(pub.encGT(0)): print("err-zero31") |
|
|
|
|
if sec.isZero(pub.encGT(3)): print("err-zero32") |
|
|
|
|
|
|
|
|
|
# use precomputedPublicKey for performance |
|
|
|
|
ppub = pub.createPrecomputedPublicKey() |
|
|
|
|
c1 = ppub.encG1(m11) |
|
|
|
|
if sec.dec(c1) != m11: print("err9") |
|
|
|
|
|
|
|
|
|
# large integer for precomputedPublicKey |
|
|
|
|
m1 = 0x140712384712047127412964192876419276341 |
|
|
|
|
m2 = -m1 + 123 |
|
|
|
|
c1 = ppub.encG1(m1) |
|
|
|
|
c2 = ppub.encG1(m2) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err10") |
|
|
|
|
c1 = ppub.encG2(m1) |
|
|
|
|
c2 = ppub.encG2(m2) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err11") |
|
|
|
|
c1 = ppub.encGT(m1) |
|
|
|
|
c2 = ppub.encGT(m2) |
|
|
|
|
if sec.dec(add(c1, c2)) != 123: print("err12") |
|
|
|
|
|
|
|
|
|
import sys |
|
|
|
|
if sys.version_info.major >= 3: |
|
|
|
|
import timeit |
|
|
|
|
N = 100000 |
|
|
|
|
print(str(timeit.timeit("pub.encG1(12)", number=N, globals=globals()) / float(N) * 1e3) + "msec") |
|
|
|
|
print(str(timeit.timeit("ppub.encG1(12)", number=N, globals=globals()) / float(N) * 1e3) + "msec") |
|
|
|
|
|
|
|
|
|
ppub.destroy() # necessary to avoid memory leak |
|
|
|
|
|