[she] she.py supports large int

pull/2/head
MITSUNARI Shigeo 6 years ago
parent d05146f5b1
commit 9bc05bf641
  1. 115
      ffi/python/she.py

@ -34,7 +34,7 @@ def init(curveType=MCL_BN254):
else: else:
raise RuntimeError("not support yet", name) raise RuntimeError("not support yet", name)
lib = cdll.LoadLibrary(libName) lib = cdll.LoadLibrary(libName)
ret = lib.sheInit(MCL_BN254, MCLBN_COMPILED_TIME_VAR) ret = lib.sheInit(curveType, MCLBN_COMPILED_TIME_VAR)
if ret != 0: if ret != 0:
raise RuntimeError("sheInit", ret) raise RuntimeError("sheInit", ret)
# custom setup for a function which returns pointer # custom setup for a function which returns pointer
@ -89,6 +89,34 @@ class CipherTextGT(Structure):
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return hexStr(self.serialize())
def _enc(CT, enc, encIntVec, neg, p, m):
c = CT()
if -0x80000000 <= m <= 0x7fffffff:
ret = enc(byref(c.v), p, m)
if ret != 0:
raise RuntimeError("enc", m)
return c
if m < 0:
minus = True
m = -m
else:
minus = False
if m >= 1 << (MCLBN_FR_UNIT_SIZE * 64):
raise RuntimeError("enc:too large m", m)
a = []
while m > 0:
a.append(m & 0xffffffff)
m >>= 32
ca = (c_uint * len(a))(*a)
ret = encIntVec(byref(c.v), p, byref(ca), sizeof(ca))
if ret != 0:
raise RuntimeError("enc:IntVec", m)
if minus:
ret = neg(byref(c.v), byref(c.v))
if ret != 0:
raise RuntimeError("enc:neg", m)
return c
class PrecomputedPublicKey(Structure): class PrecomputedPublicKey(Structure):
def __init__(self): def __init__(self):
self.p = 0 self.p = 0
@ -100,23 +128,11 @@ class PrecomputedPublicKey(Structure):
def destroy(self): def destroy(self):
lib.shePrecomputedPublicKeyDestroy(self.p) lib.shePrecomputedPublicKeyDestroy(self.p)
def encG1(self, m): def encG1(self, m):
c = CipherTextG1() return _enc(CipherTextG1, lib.shePrecomputedPublicKeyEncG1, lib.shePrecomputedPublicKeyEncIntVecG1, lib.sheNegG1, self.p, m)
ret = lib.shePrecomputedPublicKeyEncG1(byref(c.v), self.p, m)
if ret != 0:
raise RuntimeError("encG1", m)
return c
def encG2(self, m): def encG2(self, m):
c = CipherTextG2() return _enc(CipherTextG2, lib.shePrecomputedPublicKeyEncG2, lib.shePrecomputedPublicKeyEncIntVecG2, lib.sheNegG2, self.p, m)
ret = lib.shePrecomputedPublicKeyEncG2(byref(c.v), self.p, m)
if ret != 0:
raise RuntimeError("encG2", m)
return c
def encGT(self, m): def encGT(self, m):
c = CipherTextGT() return _enc(CipherTextGT, lib.shePrecomputedPublicKeyEncGT, lib.shePrecomputedPublicKeyEncIntVecGT, lib.sheNegGT, self.p, m)
ret = lib.shePrecomputedPublicKeyEncGT(byref(c.v), self.p, m)
if ret != 0:
raise RuntimeError("encGT", m)
return c
class PublicKey(Structure): class PublicKey(Structure):
_fields_ = [("v", c_ulonglong * PUB_SIZE)] _fields_ = [("v", c_ulonglong * PUB_SIZE)]
@ -129,23 +145,11 @@ class PublicKey(Structure):
def serializeToHexStr(self): def serializeToHexStr(self):
return hexStr(self.serialize()) return hexStr(self.serialize())
def encG1(self, m): def encG1(self, m):
c = CipherTextG1() return _enc(CipherTextG1, lib.sheEncG1, lib.sheEncIntVecG1, lib.sheNegG1, byref(self.v), m)
ret = lib.sheEncG1(byref(c.v), byref(self.v), m)
if ret != 0:
raise RuntimeError("encG1", m)
return c
def encG2(self, m): def encG2(self, m):
c = CipherTextG2() return _enc(CipherTextG2, lib.sheEncG2, lib.sheEncIntVecG2, lib.sheNegG2, byref(self.v), m)
ret = lib.sheEncG2(byref(c.v), byref(self.v), m)
if ret != 0:
raise RuntimeError("encG2", m)
return c
def encGT(self, m): def encGT(self, m):
c = CipherTextGT() return _enc(CipherTextGT, lib.sheEncGT, lib.sheEncIntVecGT, lib.sheNegGT, byref(self.v), m)
ret = lib.sheEncGT(byref(c.v), byref(self.v), m)
if ret != 0:
raise RuntimeError("encGT", m)
return c
def createPrecomputedPublicKey(self): def createPrecomputedPublicKey(self):
ppub = PrecomputedPublicKey() ppub = PrecomputedPublicKey()
ppub.create() ppub.create()
@ -234,15 +238,12 @@ def mul(cx, cy):
if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG2): if isinstance(cx, CipherTextG1) and isinstance(cy, CipherTextG2):
out = CipherTextGT() out = CipherTextGT()
ret = lib.sheMul(byref(out.v), byref(cx.v), byref(cy.v)) ret = lib.sheMul(byref(out.v), byref(cx.v), byref(cy.v))
elif isinstance(cx, CipherTextG1) and isinstance(cy, int): elif isinstance(cx, CipherTextG1) and (isinstance(cy, int) or isinstance(cy, long)):
out = CipherTextG1() return _enc(CipherTextG1, lib.sheMulG1, lib.sheMulIntVecG1, lib.sheNegG1, byref(cx.v), cy)
ret = lib.sheMulG1(byref(out.v), byref(cx.v), cy) elif isinstance(cx, CipherTextG2) and (isinstance(cy, int) or isinstance(cy, long)):
elif isinstance(cx, CipherTextG2) and isinstance(cy, int): return _enc(CipherTextG2, lib.sheMulG2, lib.sheMulIntVecG2, lib.sheNegG2, byref(cx.v), cy)
out = CipherTextG2() elif isinstance(cx, CipherTextGT) and (isinstance(cy, int) or isinstance(cy, long)):
ret = lib.sheMulG2(byref(out.v), byref(cx.v), cy) return _enc(CipherTextGT, lib.sheMulGT, lib.sheMulIntVecGT, lib.sheNegGT, byref(cx.v), cy)
elif isinstance(cx, CipherTextGT) and isinstance(cy, int):
out = CipherTextGT()
ret = lib.sheMulGT(byref(out.v), byref(cx.v), cy)
if ret != 0: if ret != 0:
raise RuntimeError("mul") raise RuntimeError("mul")
return out return out
@ -279,6 +280,27 @@ if __name__ == '__main__':
if sec.dec(mul(c11, 3)) != m11 * 3: print("err_mul1") if sec.dec(mul(c11, 3)) != m11 * 3: print("err_mul1")
if sec.dec(mul(c21, 7)) != m21 * 7: print("err_mul2") if sec.dec(mul(c21, 7)) != m21 * 7: print("err_mul2")
# large integer
m1 = 0x140712384712047127412964192876419276341
m2 = -m1 + 123
c1 = pub.encG1(m1)
c2 = pub.encG1(m2)
if sec.dec(add(c1, c2)) != 123: print("err-large11")
c1 = mul(pub.encG1(1), m1)
if sec.dec(add(c1, c2)) != 123: print("err-large12")
c1 = pub.encG2(m1)
c2 = pub.encG2(m2)
if sec.dec(add(c1, c2)) != 123: print("err-large21")
c1 = mul(pub.encG2(1), m1)
if sec.dec(add(c1, c2)) != 123: print("err-large22")
c1 = pub.encGT(m1)
c2 = pub.encGT(m2)
if sec.dec(add(c1, c2)) != 123: print("err-large31")
c1 = mul(pub.encGT(1), m1)
if sec.dec(add(c1, c2)) != 123: print("err-large32")
mt = -56 mt = -56
ct = pub.encGT(mt) ct = pub.encGT(mt)
if sec.dec(ct) != mt: print("err7") if sec.dec(ct) != mt: print("err7")
@ -291,6 +313,19 @@ if __name__ == '__main__':
c1 = ppub.encG1(m11) c1 = ppub.encG1(m11)
if sec.dec(c1) != m11: print("err9") if sec.dec(c1) != m11: print("err9")
# large integer for precomputedPublicKey
m1 = 0x140712384712047127412964192876419276341
m2 = -m1 + 123
c1 = ppub.encG1(m1)
c2 = ppub.encG1(m2)
if sec.dec(add(c1, c2)) != 123: print("err10")
c1 = ppub.encG2(m1)
c2 = ppub.encG2(m2)
if sec.dec(add(c1, c2)) != 123: print("err11")
c1 = ppub.encGT(m1)
c2 = ppub.encGT(m2)
if sec.dec(add(c1, c2)) != 123: print("err12")
import sys import sys
if sys.version_info.major >= 3: if sys.version_info.major >= 3:
import timeit import timeit

Loading…
Cancel
Save