-rwxr-xr-x 4505 high-ctidh-20210523/costs.py
#!/usr/bin/env python3
from memoized import memoized
import chain
import costisog
# Unit costs; every other cost in this file is a sum of these.
# NOTE(review): the names suggest M = one field multiplication and
# S = one field squaring -- confirm against the implementation being modelled.
M = 1
S = 1
# Composite costs built from M and S.
# NOTE(review): presumably x-only doubling / differential addition -- confirm.
x2 = 2*S
x2DBL = 4*M # but save 1M if affine
xDBL = x2DBL + x2
xADD = 4*M + 2*S
# clear2 is charged exactly like one xDBL.
clear2 = xDBL
def daccost(n):
    """Return the cost charged for a chain of length n.

    The result is a fixed part (one xDBL plus one xADD) plus n further
    xADD, expressed in the same units as the module-level cost constants.
    """
    fixed_part = xDBL + xADD
    return fixed_part + n*xADD
@memoized
def bigprime(primes):
    """Return 4 times the product of all entries of primes, minus 1."""
    product = 4
    for ell in primes:
        product = product * ell
    return product - 1
@memoized
def inv(primes):
    """Return the cost of the chain2 chain for exponent p-2, p = bigprime(primes).

    Entry 0 of chain.cost2(...) is weighted by M and entry 1 by S.
    NOTE(review): presumably this models a field inversion computed as a
    (p-2)-th power -- confirm.
    """
    exponent = bigprime(primes) - 2
    counts = chain.cost2(chain.chain2(exponent))
    return counts[0]*M + counts[1]*S
@memoized
def div(primes):
    """Return inv(primes) plus one extra M."""
    inversion_cost = inv(primes)
    return inversion_cost + M
@memoized
def sqrt(primes):
    """Return the cost of the chain2 chain for exponent (p+1)//4, plus one S.

    Here p = bigprime(primes).  Entry 0 of chain.cost2(...) is weighted by M;
    entry 1, incremented by one, is weighted by S.
    NOTE(review): presumably the extra S pays for checking the candidate
    root -- confirm.
    """
    exponent = (bigprime(primes) + 1)//4
    counts = chain.cost2(chain.chain2(exponent))
    return counts[0]*M + (counts[1] + 1)*S
@memoized
def elligator(primes):
    """Return the cost charged per 'elligator' unit: 2S + 5M + sqrt(primes)."""
    overhead = 2*S + 5*M
    return overhead + sqrt(primes)
def dac_search(target, r0, r1, r2, chain, chainlen, best, bestlen):
    """Recursively look for a chain reaching target that is shorter than bestlen.

    The state is the triple (r0, r1, r2).  One step replaces it either by
    (r0, r2, r0+r2), recorded as a 1 bit, or by (r1, r2, r1+r2), recorded
    as a 0 bit; the bits are appended to the integer `chain`, and
    `chainlen` counts the steps taken so far.

    Returns the pair (chain, chainlen) of the chain that was found, or the
    incoming (best, bestlen) unchanged when this subtree holds nothing
    shorter than bestlen.
    """
    # This test has to come first: it also guarantees that the shift
    # count computed below is not negative.
    if chainlen >= bestlen:
        return best, bestlen

    steps_left = bestlen - 1 - chainlen
    # Prune: already past the target, or r2 shifted left once per step
    # still allowed stays below the target.
    if r2 > target or (r2 << steps_left) < target:
        return best, bestlen

    if r2 == target:
        return chain, chainlen

    shifted = 2*chain
    deeper = chainlen + 1
    # The branch recorded as bit 1 (keeping r0) is explored before the
    # branch recorded as bit 0 (keeping r1); each call sees the best
    # result found so far.
    for kept, bit in ((r0, 1), (r1, 0)):
        best, bestlen = dac_search(target, kept, r2, kept + r2,
                                   shifted + bit, deeper, best, bestlen)
    return best, bestlen
def dac(target):
    """Return (chain, chainlen) for a shortest chain reaching target.

    The search starts from the triple (1, 2, 3) and is repeated with a
    growing length bound until dac_search() reports a chain.  `chain` is
    the integer whose bits record the choices made and `chainlen` is the
    number of steps; for target 3 the result is (0, 0).

    Raises:
        ValueError: if target is below 3.  The search starts at 3 and
            only grows, so such a target can never be reached and the
            loop below would otherwise never terminate.
    """
    if target < 3:
        raise ValueError('dac target must be at least 3 (got %r)' % (target,))
    best = None
    bestlen = 0
    # dac_search() only accepts chains strictly shorter than bestlen, so
    # raise the bound by one until something is found.
    while best is None:
        bestlen += 1
        best, bestlen = dac_search(target, 1, 2, 3, 0, 0, best, bestlen)
    return best, bestlen
@memoized
def daclen(primes):
    """Return a list with the chain length from dac() for each entry of primes."""
    return [dac(ell)[1] for ell in primes]
@memoized
def batchstart(batchsize):
    """Return, for each batch, the sum of the sizes of all earlier batches.

    Entry b is the index at which batch b begins when the batches are
    laid out back to back.
    """
    starts = []
    offset = 0
    for size in batchsize:
        starts.append(offset)
        offset += size
    return starts
@memoized
def batchstop(batchsize):
    """Return, for each batch, the sum of the sizes up to and including it.

    Entry b is the index one past the end of batch b when the batches are
    laid out back to back.
    """
    stops = []
    offset = 0
    for size in batchsize:
        offset += size
        stops.append(offset)
    return stops
@memoized
def maxdaclen(primes, batchsize):
    """Return, for each batch, the largest daclen(primes) entry inside it.

    Batch b covers indices batchstart(batchsize)[b] up to, but not
    including, batchstop(batchsize)[b].
    """
    spans = zip(batchstart(batchsize), batchstop(batchsize))
    return [max(daclen(primes)[j] for j in range(lo, hi))
            for lo, hi in spans]
@memoized
def maxdac(primes, batchsize, b):
    """Return daccost() of the longest chain length occurring in batch b."""
    # The local is deliberately not called M: that name is the
    # module-level cost constant.
    longest = maxdaclen(primes, batchsize)
    return daccost(longest[b])
@memoized
def eachdac(primes, batchsize, b):
    """Return the sum of daccost() over the chain lengths of batch b.

    Batch b covers indices batchstart(batchsize)[b] up to, but not
    including, batchstop(batchsize)[b] of daclen(primes).
    """
    lengths = daclen(primes)
    lo = batchstart(batchsize)[b]
    hi = batchstop(batchsize)[b]
    return sum(daccost(lengths[j]) for j in range(lo, hi))
@memoized
def bsgs(primes, batchsize, b):
    """Return entry 1 of costisog.optimize() for the first prime of batch b.

    The prime used is the one at index batchstart(batchsize)[b]; the
    second argument passed to costisog.optimize is the constant 1.
    """
    first_index = batchstart(batchsize)[b]
    optimized = costisog.optimize(primes[first_index], 1)
    return optimized[1]
@memoized
def isog(push, primes, batchsize, b):
    """Return costisog.isog() for the last prime of batch b.

    The pair of parameters comes from bsgs() for the same batch, the
    prime is the one at index batchstop(batchsize)[b]-1, and push is
    forwarded unchanged.
    """
    first, second = bsgs(primes, batchsize, b)
    last_index = batchstop(batchsize)[b] - 1
    return costisog.isog(primes[last_index], push, (first, second))
def mults(x, primes, batchsize):
    """Return the total cost described by the counts in x.

    x is indexed by 'elligator' and 'clear2' and, for every batch b, by
    ('maxdac', b), ('eachdac', b) and ('isog', push, b) for push in
    0, 1, 2.  Each count is multiplied by the matching unit cost and
    added to one unconditional div(primes).
    """
    total = div(primes)
    total += x['elligator']*elligator(primes)
    total += x['clear2']*clear2
    for b in range(len(batchsize)):
        total += x['maxdac', b]*maxdac(primes, batchsize, b)
        total += x['eachdac', b]*eachdac(primes, batchsize, b)
        for push in range(3):
            total += x['isog', push, b]*isog(push, primes, batchsize, b)
    return total
def strstats(x, prefix, format, primes, batchsize):
    """Return a one-line summary of the counts in x, starting with prefix.

    Every number is rendered with the %-style template `format`.  The
    line lists the total from mults(), then 'AB', 'elligator' and
    'clear2', then the per-batch values of isog0, isog1, isog2, maxdac
    and eachdac, each group separated by single spaces.
    """
    batches = range(len(batchsize))

    def per_batch(*key):
        # Render x[key..., b] for every batch b, joined by single spaces.
        return ' '.join(format % x[key + (b,)] for b in batches)

    pieces = [
        prefix,
        'mults %s ' % (format % mults(x, primes, batchsize)),
        'AB %s ' % (format % x['AB']),
        'elligator %s ' % (format % x['elligator']),
        'clear2 %s ' % (format % x['clear2']),
        'isog0 %s ' % per_batch('isog', 0),
        'isog1 %s ' % per_batch('isog', 1),
        'isog2 %s ' % per_batch('isog', 2),
        'maxdac %s ' % per_batch('maxdac'),
        'eachdac %s ' % per_batch('eachdac'),
    ]
    return ''.join(pieces)
def test():
    """Check isog() against recorded costs for one fixed parameter set."""
    primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,587)
    batchsize = (3, 4, 5, 5, 6, 6, 6, 8, 7, 9, 10, 5)
    # Listed alongside the parameter set; not read by the checks below.
    batchbound = (13, 19, 20, 20, 20, 20, 20, 20, 17, 17, 17, 5)
    # Row `push` holds the expected isog() cost of every batch.
    expected = (
        [34, 82, 170, 250, 368, 412, 532, 653, 720, 875, 1034, 1860, ],
        [48, 120, 252, 372, 546, 643, 830, 1031, 1138, 1385, 1644, 2898, ],
        [62, 158, 334, 494, 724, 874, 1128, 1409, 1556, 1895, 2254, 3936, ],
    )
    batches = range(len(batchsize))
    # Compute the full table first, then compare it row by row.
    C = [[isog(push, primes, batchsize, b) for b in batches]
         for push in range(3)]
    for push, want in enumerate(expected):
        assert C[push] == want
# Run the self-check in test() when this file is executed as a script.
if __name__ == '__main__':
    test()