[she] add deserialize()

pull/2/head
MITSUNARI Shigeo 6 years ago
parent 8d54609260
commit 696330be5d
  1. 80
      ffi/python/she.py

@ -20,7 +20,7 @@ GT_CIPHER_SIZE = GT_SIZE * 4
MCLBN_COMPILED_TIME_VAR = (MCLBN_FR_UNIT_SIZE * 10) + MCLBN_FP_UNIT_SIZE MCLBN_COMPILED_TIME_VAR = (MCLBN_FR_UNIT_SIZE * 10) + MCLBN_FP_UNIT_SIZE
Buffer = c_ubyte * 1536 Buffer = c_ubyte * 2304
lib = None lib = None
def init(curveType=BN254): def init(curveType=BN254):
@ -38,6 +38,8 @@ def init(curveType=BN254):
ret = lib.sheInit(curveType, MCLBN_COMPILED_TIME_VAR) ret = lib.sheInit(curveType, MCLBN_COMPILED_TIME_VAR)
if ret != 0: if ret != 0:
raise RuntimeError("sheInit", ret) raise RuntimeError("sheInit", ret)
lib.mclBn_verifyOrderG1(0)
lib.mclBn_verifyOrderG2(0)
# custom setup for a function which returns pointer # custom setup for a function which returns pointer
lib.shePrecomputedPublicKeyCreate.restype = c_void_p lib.shePrecomputedPublicKeyCreate.restype = c_void_p
@ -49,44 +51,47 @@ def setRangeForDLP(hashSize):
def setTryNum(tryNum): def setTryNum(tryNum):
lib.sheSetTryNum(tryNum) lib.sheSetTryNum(tryNum)
def hexStr(v): def _hexStr(v):
s = "" s = ""
for x in v: for x in v:
s += format(x, '02x') s += format(x, '02x')
return s return s
class CipherTextG1(Structure): def _serialize(self, f):
_fields_ = [("v", c_ulonglong * G1_CIPHER_SIZE)]
def serialize(self):
buf = Buffer() buf = Buffer()
ret = lib.sheCipherTextG1Serialize(byref(buf), len(buf), byref(self.v)) ret = f(byref(buf), len(buf), byref(self.v))
if ret == 0: if ret == 0:
raise RuntimeError("serialize") raise RuntimeError("serialize")
return buf[0:ret] return buf[0:ret]
def _deserialize(cstr, f, buf):
x = cstr()
ca = (c_ubyte * len(buf))(*buf)
ret = f(byref(x.v), byref(ca), len(buf))
if ret == 0:
raise RuntimeError("deserialize")
return x
class CipherTextG1(Structure):
_fields_ = [("v", c_ulonglong * G1_CIPHER_SIZE)]
def serialize(self):
return _serialize(self, lib.sheCipherTextG1Serialize)
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return _hexStr(self.serialize())
class CipherTextG2(Structure): class CipherTextG2(Structure):
_fields_ = [("v", c_ulonglong * G2_CIPHER_SIZE)] _fields_ = [("v", c_ulonglong * G2_CIPHER_SIZE)]
def serialize(self): def serialize(self):
buf = Buffer() return _serialize(self, lib.sheCipherTextG2Serialize)
ret = lib.sheCipherTextG2Serialize(byref(buf), len(buf), byref(self.v))
if ret == 0:
raise RuntimeError("serialize")
return buf[0:ret]
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return _hexStr(self.serialize())
class CipherTextGT(Structure): class CipherTextGT(Structure):
_fields_ = [("v", c_ulonglong * GT_CIPHER_SIZE)] _fields_ = [("v", c_ulonglong * GT_CIPHER_SIZE)]
def serialize(self): def serialize(self):
buf = Buffer() return _serialize(self, lib.sheCipherTextGTSerialize)
ret = lib.sheCipherTextGTSerialize(byref(buf), len(buf), byref(self.v))
if ret == 0:
raise RuntimeError("serialize")
return buf[0:ret]
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return _hexStr(self.serialize())
def _enc(CT, enc, encIntVec, neg, p, m): def _enc(CT, enc, encIntVec, neg, p, m):
c = CT() c = CT()
@ -136,13 +141,9 @@ class PrecomputedPublicKey(Structure):
class PublicKey(Structure): class PublicKey(Structure):
_fields_ = [("v", c_ulonglong * PUB_SIZE)] _fields_ = [("v", c_ulonglong * PUB_SIZE)]
def serialize(self): def serialize(self):
buf = Buffer() return _serialize(self, lib.shePublicKeySerialize)
ret = lib.shePublicKeySerialize(byref(buf), len(buf), byref(self.v))
if ret == 0:
raise RuntimeError("serialize")
return buf[0:ret]
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return _hexStr(self.serialize())
def encG1(self, m): def encG1(self, m):
return _enc(CipherTextG1, lib.sheEncG1, lib.sheEncIntVecG1, lib.sheNegG1, byref(self.v), m) return _enc(CipherTextG1, lib.sheEncG1, lib.sheEncIntVecG1, lib.sheNegG1, byref(self.v), m)
def encG2(self, m): def encG2(self, m):
@ -164,13 +165,9 @@ class SecretKey(Structure):
if ret != 0: if ret != 0:
raise RuntimeError("setByCSPRNG", ret) raise RuntimeError("setByCSPRNG", ret)
def serialize(self): def serialize(self):
buf = Buffer() return _serialize(self, lib.sheSecretKeySerialize)
ret = lib.sheSecretKeySerialize(byref(buf), len(buf), byref(self.v))
if ret == 0:
raise RuntimeError("serialize")
return buf[0:ret]
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return _hexStr(self.serialize())
def getPulicKey(self): def getPulicKey(self):
pub = PublicKey() pub = PublicKey()
lib.sheGetPublicKey(byref(pub.v), byref(self.v)) lib.sheGetPublicKey(byref(pub.v), byref(self.v))
@ -255,6 +252,21 @@ def mul(cx, cy):
raise RuntimeError("mul") raise RuntimeError("mul")
return out return out
def deserializeToSecretKey(buf):
return _deserialize(SecretKey, lib.sheSecretKeyDeserialize, buf)
def deserializeToPublicKey(buf):
return _deserialize(PublicKey, lib.shePublicKeyDeserialize, buf)
def deserializeToCipherTextG1(buf):
return _deserialize(CipherTextG1, lib.sheCipherTextG1Deserialize, buf)
def deserializeToCipherTextG2(buf):
return _deserialize(CipherTextG2, lib.sheCipherTextG2Deserialize, buf)
def deserializeToCipherTextGT(buf):
return _deserialize(CipherTextGT, lib.sheCipherTextGTDeserialize, buf)
if __name__ == '__main__': if __name__ == '__main__':
init(BLS12_381) init(BLS12_381)
sec = SecretKey() sec = SecretKey()
@ -262,6 +274,8 @@ if __name__ == '__main__':
print("sec=", sec.serializeToHexStr()) print("sec=", sec.serializeToHexStr())
pub = sec.getPulicKey() pub = sec.getPulicKey()
print("pub=", pub.serializeToHexStr()) print("pub=", pub.serializeToHexStr())
if sec.serialize() != deserializeToSecretKey(sec.serialize()).serialize(): print("err-ser1")
if pub.serialize() != deserializeToPublicKey(pub.serialize()).serialize(): print("err-ser2")
m11 = 1 m11 = 1
m12 = 5 m12 = 5
@ -287,6 +301,9 @@ if __name__ == '__main__':
if sec.dec(mul(c11, 3)) != m11 * 3: print("err_mul1") if sec.dec(mul(c11, 3)) != m11 * 3: print("err_mul1")
if sec.dec(mul(c21, 7)) != m21 * 7: print("err_mul2") if sec.dec(mul(c21, 7)) != m21 * 7: print("err_mul2")
if c11.serialize() != deserializeToCipherTextG1(c11.serialize()).serialize(): print("err-ser3")
if c21.serialize() != deserializeToCipherTextG2(c21.serialize()).serialize(): print("err-ser3")
# large integer # large integer
m1 = 0x140712384712047127412964192876419276341 m1 = 0x140712384712047127412964192876419276341
m2 = -m1 + 123 m2 = -m1 + 123
@ -307,6 +324,7 @@ if __name__ == '__main__':
if sec.dec(add(c1, c2)) != 123: print("err-large31") if sec.dec(add(c1, c2)) != 123: print("err-large31")
c1 = mul(pub.encGT(1), m1) c1 = mul(pub.encGT(1), m1)
if sec.dec(add(c1, c2)) != 123: print("err-large32") if sec.dec(add(c1, c2)) != 123: print("err-large32")
if c1.serialize() != deserializeToCipherTextGT(c1.serialize()).serialize(): print("err-ser4")
mt = -56 mt = -56
ct = pub.encGT(mt) ct = pub.encGT(mt)

Loading…
Cancel
Save