```-rwxr-xr-x 5846 high-ctidh-20210523/greedy#!/usr/bin/env python3

# sample usage: ./greedy.py 512 256 3 0 2
# meaning of the arguments, in order:
#   512: CSIDH-512 prime
#   256: >=2^256 keys
#   3:   B=3
#   0:   force the 0 largest primes to be skipped
#   2:   try to use 2 cores

from multiprocessing import Pool

from memoized import memoized
import distmults
import costs
import sys

def printstatus(prefix,cost,N0,m0,numprimes1):
    """Print one status line: prefix, cost (two decimals), batch sizes, bounds.

    When numprimes1 is nonzero, an extra trailing batch is shown: its size
    numprimes1 is appended to N0 and a bound of 0 is appended to m0.  Both
    tuples are printed with all spaces removed.
    """
    sizes,bounds = N0,m0
    if numprimes1 != 0:
        sizes = N0+(numprimes1,)
        bounds = m0+(0,)
    compact = [str(vec).replace(' ','') for vec in (sizes,bounds)]
    print('%s %.2f %s %s' % (prefix,cost,compact[0],compact[1]))

def costfunction(primes0,primes1,N0,m0):
    """Return costs.mults for the configuration (N0, m0) over primes0+primes1.

    If primes1 is non-empty it is treated as one extra trailing batch of size
    len(primes1) with bound 0.  The resulting vectors are passed to
    distmults.average, whose result is then fed to costs.mults.
    """
    primes = primes0+primes1
    extra = len(primes1)
    if extra > 0:
        N = N0+(extra,)
        m = m0+(0,)
    else:
        N = N0
        m = m0
    x = distmults.average(primes,N,m)
    return costs.mults(x,primes,N)

@memoized
def batchkeys(x,y):
    """Return the coefficient of t^x in (1+t)^x * (1+2t)^y.

    The product is expanded one linear factor at a time on a coefficient
    list (lowest degree first), using exact integer arithmetic.
    """
    coeffs = [1]
    for repeat,scale in ((x,1),(y,2)):
        for _ in range(repeat):
            # multiply by (1 + scale*t): new[k] = old[k] + scale*old[k-1]
            coeffs = [cur+scale*prev for cur,prev in zip(coeffs+[0],[0]+coeffs)]
    return coeffs[x]

@memoized
def keys(N,m):
    """Return the product of batchkeys(s,b) over the paired entries of N and m."""
    factors = [batchkeys(size,bound) for size,bound in zip(N,m)]
    product = 1
    for factor in factors:
        product = product*factor
    return product

# neighboring_intvec; search upwards in non-b directions
def searchup(minkeyspace,primes0,primes1,N0,m0,cost,b,best):
    """Recursively raise entries of m0 (except index b) until the key space suffices.

    best is a (cost, m) pair; a branch is abandoned as soon as its cost is
    not strictly below best[0].  If keys(N0,m0) already reaches minkeyspace
    the pair (cost, m0) is returned.  Otherwise every index other than b is
    incremented by one in turn and the search recurses, threading the best
    pair found so far through the calls.  Returns the final best pair.
    """
    # prune: this candidate cannot beat what we already have
    if cost >= best[0]:
        return best
    # large enough key space at a strictly lower cost: new best
    if keys(N0,m0) >= minkeyspace:
        return cost,m0

    for direction in range(len(N0)):
        if direction != b:
            raised = list(m0)
            raised[direction] += 1
            raised = tuple(raised)
            raisedcost = costfunction(primes0,primes1,N0,raised)
            best = searchup(minkeyspace,primes0,primes1,N0,raised,raisedcost,b,best)
    return best

def optimizem(minkeyspace,primes0,primes1,N0,m0=None):
    """Greedy local search for a cheap bound vector m0 with batch sizes N0 fixed.

    Parameters:
      minkeyspace: required lower bound on keys(N0,m0).
      primes0, primes1: prime tuples forwarded to costfunction/printstatus.
      N0: batch sizes.
      m0: optional starting bounds.  If omitted, N0 is converted to a tuple,
          checked to sum to len(primes0), and the search starts from the
          smallest uniform vector (z,...,z), z >= 1, whose key space reaches
          minkeyspace.  If given, its first entry is raised until the key
          space reaches minkeyspace.

    Each round prints a 'searching' status line, then for every index b with
    m0[b] > 0 lowers m0[b] by one and lets searchup raise the other entries;
    the cheapest result replaces (cost, m0).  The loop ends when a round
    leaves (cost, m0) unchanged.

    Returns the pair (cost, m0).
    """
    B0 = len(N0)

    # identity test: None is a singleton, so compare with `is`, not `==`
    if m0 is None:
        N0 = tuple(N0)
        assert sum(N0) == len(primes0)

        z = 1
        while True:
            m0 = tuple([z]*B0)
            if keys(N0,m0) >= minkeyspace: break
            z += 1
    else:
        # caller-supplied start may be too small for these batch sizes
        while keys(N0,m0) < minkeyspace:
            m0 = list(m0)
            m0[0] += 1
            m0 = tuple(m0)

    cost = costfunction(primes0,primes1,N0,m0)

    while True:
        printstatus('searching',cost,N0,m0,len(primes1))
        sys.stdout.flush()

        best = cost,m0
        for b in range(B0):
            if m0[b] == 0: continue
            newm = list(m0)
            newm[b] -= 1
            newm = tuple(newm)
            newcost = costfunction(primes0,primes1,N0,newm)
            best = searchup(minkeyspace,primes0,primes1,N0,newm,newcost,b,best)
        # no neighbor improved on the current point: local optimum
        if best == (cost,m0): break
        cost,m0 = best

    return cost,m0

def optimizeNm(minkeyspace,primes0,primes1,B,parallelism=1):
    """Greedy local search over batch sizes, re-optimizing the bounds each step.

    primes0 is split as evenly as possible into B0 batches, where B0 is B-1
    if primes1 is non-empty and B otherwise, and optimizem supplies the
    initial cost and bounds.  Each round builds every variant that moves one
    prime from a batch b of size > 1 to a different batch c, evaluates all of
    them with optimizem in a Pool of `parallelism` processes (seeding each
    with the current bounds), and adopts the first variant with the lowest
    cost if that cost is strictly below the current one.  The search stops
    when a round brings no change.

    Returns (cost, N0, m0).
    """
    B0 = B if len(primes1) == 0 else B-1
    total = len(primes0)
    N0 = tuple(total//B0+(j<total%B0) for j in range(B0))
    cost,m0 = optimizem(minkeyspace,primes0,primes1,N0)

    while True:
        current = cost,N0,m0
        best = current
        variants = [
            (
                minkeyspace,primes0,primes1,
                tuple(size-1 if i == b else size+1 if i == c else size
                      for i,size in enumerate(N0)),
                m0,
            )
            for b in range(B0) if N0[b] > 1
            for c in range(B0) if c != b
        ]
        with Pool(parallelism) as p:
            results = p.starmap(optimizem,variants,chunksize=1)
            for idx,(newcost,newm) in enumerate(results):
                if newcost < best[0]:
                    best = newcost,variants[idx][3],newm
        if best == current: break
        cost,N0,m0 = best

    return cost,N0,m0

def doit():
    """Parse the command line, run optimizeNm, and print the 'output' line.

    Positional arguments (all optional, read from sys.argv):
      1: prime set, one of '512', '1024', '2048' (default '512')
      2: base-2 logarithm of the minimum key space (default 256)
      3: B (default 3); must satisfy 1 <= B <= number of primes
      4: number of trailing primes split off into primes1 (default 0);
         must be smaller than the number of primes, and requires B >= 2
         when positive
      5: number of worker processes for the pool (default 1)
    """
    # searchup recurses once per increment of a bound, so allow deep recursion
    sys.setrecursionlimit(10000)

    p = {}
    p['512'] = (3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,587)
    p['1024'] = (3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401,409,419,421,431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,617,619,631,641,643,647,653,659,661,673,677,683,691,701,709,719,727,733,983)
    p['2048'] = (3,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401,409,419,421,431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,617,619,631,641,643,647,653,659,661,673,677,683,691,701,709,719,727,733,739,743,751,757,761,769,773,787,797,809,811,821,823,827,829,839,853,857,859,863,877,881,883,887,907,911,919,929,937,941,947,953,967,971,977,983,991,997,1009,1013,1019,1021,1031,1033,1039,1049,1051,1061,1063,1069,1087,1091,1093,1097,1103,1109,1117,1123,1129,1151,1153,1163,1171,1181,1187,1193,1201,1213,1217,1223,1229,1231,1237,1249,1259,1277,1279,1283,1289,1291,1297,1301,1303,1307,1319,1321,1327,1361,1367,1373,1381,1399,1409,1423,1427,1429,1433,1439,1447,1451,1453,1459,3413)

    primes = p['512']
    if len(sys.argv) > 1:
        primes = p[sys.argv[1]]

    minkeyspace = 2**256
    if len(sys.argv) > 2:
        exponent = float(sys.argv[2])
        # 2**float raises OverflowError once the exponent reaches 1024, so
        # integral exponents use exact integer arithmetic; fractional
        # exponents keep the floating-point power.
        if exponent.is_integer():
            minkeyspace = 2**int(exponent)
        else:
            minkeyspace = 2**exponent

    B = 3
    if len(sys.argv) > 3:
        B = int(sys.argv[3])
    assert B >= 1
    assert B <= len(primes)

    numprimes1 = 0
    if len(sys.argv) > 4:
        numprimes1 = int(sys.argv[4])
    assert 0 <= numprimes1
    # a larger value would make the slices below wrap around (negative
    # index) or leave primes0 empty
    assert numprimes1 < len(primes)
    if numprimes1 > 0: assert B >= 2

    primes0 = primes[:len(primes)-numprimes1]
    primes1 = primes[len(primes)-numprimes1:]

    parallelism = 1
    if len(sys.argv) > 5:
        parallelism = int(sys.argv[5])

    cost,N0,m0 = optimizeNm(minkeyspace,primes0,primes1,B,parallelism)
    printstatus('output',cost,N0,m0,len(primes1))

# Run the search only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    doit()
```