```-rwxr-xr-x 5471 high-ctidh-20210523/distmults.py
#!/usr/bin/env python3

import costs

# input: list of coeffs of g = g_0+g_1 x+g_2 x^2+...
# input: f
# output: list of coeffs of (f+(1-f)x)g
def polytimeslinear(g,f):
    """Multiply the polynomial g by the linear factor f+(1-f)x.

    g is the coefficient list g_0, g_1, g_2, ... (lowest degree first)
    and f is a number.  Returns a new list with len(g)+1 coefficients;
    g itself is not modified.
    """
    complement = 1-f
    product = [0]*(len(g)+1)
    for degree,coeff in enumerate(g):
        # f*coeff stays at the same degree, (1-f)*coeff moves up by one
        product[degree] += f*coeff
        product[degree+1] += complement*coeff
    return product

def average(primes,batchsize,batchbound):
    """Accumulate probability-weighted operation counts over all rounds.

    The primes are split into consecutive batches: batch b consists of
    batchsize[b] consecutive entries of primes and counts as done once it
    has had at least batchbound[b] successes.  In every round batch b
    fails with probability 1/p, where p is the first prime of the batch.
    Rounds are processed until the chance that every batch is done
    exceeds 0.999999999.

    Parameters:
      primes     -- sequence of primes, ordered batch by batch
      batchsize  -- number of primes in each batch; must sum to len(primes)
      batchbound -- number of successes each batch needs; same length
                    as batchsize

    Returns a dict x of accumulated expectations with keys
      'AB'             sum over rounds of the chance the round is needed
      'elligator'      2 per situation with t == 0
      'clear2'         4 per situation with t == 0
      ('eachdac',b)    per-batch count, see case 3 below
      ('maxdac',b)     per-batch count, see case 2 and the inner loop
      ('isog',push,b)  per-batch count for push in 0, 1, 2
    How these counts are turned into costs is up to the consumer of x.

    Raises AssertionError if the lengths of the inputs are inconsistent.
    """
    B = len(batchsize)
    assert B == len(batchbound)
    assert sum(batchsize) == len(primes)

    # batchstart[b] is the index in primes of the first prime of batch b
    batchstart = []
    offset = 0
    for size in batchsize:
        batchstart.append(offset)
        offset += size

    # prfail[b] is failure probability of batch b in one round;
    # it does not change between rounds, so compute it once
    prfail = [1.0/primes[start] for start in batchstart]

    R = 0
    prsuccess = [[1] for b in range(B)]

    x = {}
    x['AB'] = 0
    x['elligator'] = 0
    x['clear2'] = 0
    for b in range(B):
        x['eachdac',b] = 0
        x['maxdac',b] = 0
        x['isog',0,b] = 0
        x['isog',1,b] = 0
        x['isog',2,b] = 0

    while True:
        # now considering status at beginning of round R
        # one AB per round; rounds numbered from 0

        # for 0 <= s <= R:
        #   prsuccess[b][s] is chance that batch b
        #   had exactly s successes in first R rounds

        prdone = [sum(prsuccess[b][batchbound[b]:]) for b in range(B)]
        # prdone[b] is chance that batch b is done in <=R rounds

        prdoneall = 1
        for b in range(B):
            prdoneall *= prdone[b]

        prAB = 1-prdoneall
        # we need round R with probability prAB

        x['AB'] += prAB

        for b in range(B):
            # case 1, chance 1-prAB: not AB; forces done[b]
            # case 2, chance 1-prdone[b]: not done[b]; forces AB
            # case 3, chance prAB-(1-prdone[b]): AB and done[b]

            # in case 3, this AB will do 2x eachdac clearing outside primes
            x['eachdac',b] += 2*(prAB-(1-prdone[b]))

            # in case 2, this AB will do 2x clearing non-selected primes
            x['maxdac',b] += 2*(1-prdone[b])*(batchsize[b]-1)

            # remaining costs depend on targetlen and our position in this AB
            # so figure out distribution of positions

            gfsmaller = [1]
            for a in range(b):
                gfsmaller = polytimeslinear(gfsmaller,prdone[a])
            # gfsmaller is generating function for
            # number of smaller primes in this AB

            gflarger = [1]
            for a in range(b+1,B):
                gflarger = polytimeslinear(gflarger,prdone[a])
            # gflarger is generating function for
            # number of larger primes in this AB

            # weight applied to every isogeny count of this batch;
            # depends only on b, so compute it outside the double loop
            isogweight = 1-1/primes[batchstart[b]]

            for numsmaller in range(len(gfsmaller)):
                for numlarger in range(len(gflarger)):
                    # chance that batch b is not done and sees exactly
                    # numsmaller smaller and numlarger larger primes
                    prsituation = (1-prdone[b])*gfsmaller[numsmaller]*gflarger[numlarger]

                    # t is the step within this AB at which batch b is handled
                    targetlen = numsmaller+1+numlarger
                    if targetlen <= 3:
                        t = numsmaller # 0 1 2
                    else:
                        # e.g. for targetlen 8 the positions (numsmaller)
                        # handled at t = 0,1,...,7 are
                        # 6 4 3 2 1 0 5 7
                        if numlarger == 0:
                            t = targetlen-1
                        elif numlarger == 1:
                            t = 0
                        elif numlarger == 2:
                            t = targetlen-2
                        else:
                            t = numlarger-2

                    if t == 0:
                        x['elligator'] += 2*prsituation
                        x['clear2'] += 4*prsituation
                        # XXX: can also do these directly from prAB

                    # our contribution to kernel points
                    # from earlier targets in AB:
                    x['maxdac',b] += t*prsituation

                    # isogenies:
                    if t == targetlen-1:
                        push = 0
                    elif t == 0:
                        push = 1
                    elif t == targetlen-2 and targetlen > 2:
                        push = 1
                    else:
                        push = 2
                    x['isog',push,b] += prsituation*isogweight

                    # final mults:
                    if t == targetlen-2 and targetlen > 2:
                        x['maxdac',b] += prsituation # P0
                    elif t < targetlen-1:
                        x['maxdac',b] += 2*prsituation # P0, P1

        if prdoneall > 0.999999999:
            return x

        # advance to the next round: every batch gets one more attempt
        R += 1
        for b in range(B):
            prsuccess[b] = polytimeslinear(prsuccess[b],prfail[b])

def test():
    """Print statistics for three hard-coded parameter sets.

    Each parameter set is a (primes, batchsize, batchbound) triple.  For
    each one, average() is evaluated and the string produced by
    costs.strstats() is printed.
    """
    import sys
    sys.setrecursionlimit(10000)

    # shared by the first two parameter sets
    shorter_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,587)
    longer_primes = (3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137,139,149,151,157,163,167,173,179,181,191,193,197,199,211,223,227,229,233,239,241,251,257,263,269,271,277,281,283,293,307,311,313,317,331,337,347,349,353,359,367,373,379,383,389,397,401,409,419,421,431,433,439,443,449,457,461,463,467,479,487,491,499,503,509,521,523,541,547,557,563,569,571,577,587,593,599,601,607,613,617,619,631,641,643,647,653,659,661,673,677,683,691,701,709,719,727,733,983)

    cases = (
        (
            shorter_primes,
            (3, 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 3, 5, 6, 3),
            (8, 11, 12, 12, 12, 12, 12, 12, 12, 11, 10, 4, 6, 9, 3),
        ),
        (
            shorter_primes,
            (3, 4, 5, 5, 6, 6, 6, 8, 7, 9, 10, 5),
            (13, 19, 20, 20, 20, 20, 20, 20, 17, 17, 17, 5),
        ),
        (
            longer_primes,
            (5, 5, 6, 6, 6, 7, 7, 8, 10, 9, 7, 8, 8, 8, 9, 9, 5, 7),
            (5, 7, 7, 7, 7, 7, 7, 7, 7, 7, 6, 6, 6, 6, 6, 6, 2, 2),
        ),
    )

    for primes,batchsize,batchbound in cases:
        stats = average(primes,batchsize,batchbound)
        print(costs.strstats(stats,'','%.6f',primes,batchsize))

# When run as a script (not imported), print the statistics for the
# built-in parameter sets.
if __name__ == '__main__':
    test()
```