sragss opened 9 months ago
Here's a poly A, B, C density chart for Sha256 (most frequent values and their counts, out of 2097152 evaluations each):

```
Poly Az ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1339241
0x0000000000000000000000000000000000000000000000000000000000000001: 390838
0x0000000000000000000000000000000000000000000000000000000000000002: 266101
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000: 99671
Total with only a single appearance: 1301

Poly Bz ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1480606
0x0000000000000000000000000000000000000000000000000000000000000001: 616546
Total with only a single appearance: 0

Poly uCz_E ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1863816
0x0000000000000000000000000000000000000000000000000000000000000002: 132442
0x0000000000000000000000000000000000000000000000000000000000000001: 49820
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000: 49773
Total with only a single appearance: 1301

Time elapsed is: 7.962553084s
```
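For reference, a chart like this can be produced by counting value frequencies over the evaluation vector. A minimal sketch, with `u64` standing in for field elements; the function name and report format here are illustrative, not code from Spartan2:

```rust
use std::collections::HashMap;

/// Count how often each value appears in an evaluation vector; return the
/// (value, count) pairs sorted by descending count, plus the number of
/// values appearing exactly once. (`u64` stands in for field elements.)
fn density(evals: &[u64]) -> (Vec<(u64, usize)>, usize) {
    let mut counts: HashMap<u64, usize> = HashMap::new();
    for &v in evals {
        *counts.entry(v).or_insert(0) += 1;
    }
    let mut by_freq: Vec<(u64, usize)> = counts.into_iter().collect();
    // Sort by count descending, then value ascending, for a stable report.
    by_freq.sort_by(|a, b| b.1.cmp(&a.1).then(a.0.cmp(&b.0)));
    let singles = by_freq.iter().filter(|&&(_, c)| c == 1).count();
    (by_freq, singles)
}

fn main() {
    let evals = [0u64, 0, 0, 1, 1, 2, 42];
    let (by_freq, singles) = density(&evals);
    println!("Poly Az ====================== {}", evals.len());
    for (value, count) in by_freq.iter().take(4) {
        println!("{value:#066x}: {count}");
    }
    println!("Total with only a single appearance: {singles}");
}
```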
An optimized 0/1 comb function makes it ~33% faster:
```rust
let comb_func_outer = |poly_A_comp: &G::Scalar,
                       poly_B_comp: &G::Scalar,
                       poly_C_comp: &G::Scalar,
                       poly_D_comp: &G::Scalar|
                      -> G::Scalar {
    // Goal: compute poly_A_comp * (poly_B_comp * poly_C_comp - poly_D_comp) fast.
    // poly_A we know to be uniformly random.
    // poly_B: A matrix, poly_C: B matrix, poly_D: C matrix.
    if poly_B_comp.eq(&G::Scalar::ZERO) || poly_C_comp.eq(&G::Scalar::ZERO) {
        // b * c = 0, so f = a * (-d)
        *poly_A_comp * poly_D_comp.neg()
    } else {
        let inner = if poly_B_comp.eq(&G::Scalar::ONE) {
            // skip the b * c multiplication
            *poly_C_comp - *poly_D_comp
        } else if poly_C_comp.eq(&G::Scalar::ONE) {
            *poly_B_comp - *poly_D_comp
        } else {
            *poly_B_comp * *poly_C_comp - *poly_D_comp
        };
        *poly_A_comp * inner
    }
};
```
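The same special-casing, extended with the `a == 0` short circuit discussed under 0 / 1 Checking below, can be sketched standalone. This uses `u64` arithmetic modulo a toy prime rather than Spartan2's field type; all names here are illustrative:

```rust
const P: u64 = (1u64 << 61) - 1; // toy modulus, stand-in for the real field

fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % (P as u128)) as u64 }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn neg(a: u64) -> u64 { (P - a) % P }

/// f(a, b, c, d) = a * (b * c - d) with 0/1 special-casing.
fn comb_func_fast(a: u64, b: u64, c: u64, d: u64) -> u64 {
    if a == 0 {
        return 0; // short circuit: the whole term vanishes
    }
    if b == 0 || c == 0 {
        return mul(a, neg(d)); // b * c = 0, so f = a * (-d)
    }
    let inner = if b == 1 {
        sub(c, d) // skip the b * c multiplication
    } else if c == 1 {
        sub(b, d)
    } else {
        sub(mul(b, c), d)
    };
    mul(a, inner)
}

/// Unoptimized reference.
fn comb_func(a: u64, b: u64, c: u64, d: u64) -> u64 {
    mul(a, sub(mul(b, c), d))
}

fn main() {
    for &(a, b, c, d) in &[(0, 5, 6, 7), (3, 0, 6, 7), (3, 1, 6, 7), (3, 5, 1, 7), (3, 5, 6, 7)] {
        assert_eq!(comb_func_fast(a, b, c, d), comb_func(a, b, c, d));
    }
    println!("ok");
}
```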
`Sumcheck::prove_cubic_with_additive_term` seems suboptimal. Currently it takes ~6% of Spartan e2e time: https://github.com/a16z/Spartan2/blob/uniform_r1cs_shape/src/spartan/sumcheck.rs#L251
Some ideas on optimization follow.
0 / 1 Checking
`comb_func_outer` gets passed into `Sumcheck::prove_cubic_with_additive_term` from `Snark::prove`. This combination function is `f(a, b, c, d) = a * (b * c - d)`. There are clear optimizations to be had here in the case that any of the terms are 0 / 1. Specifically, if `a` is zero we should short circuit. The rest are less relevant, but can theoretically save up to 66% of the field multiplications.

compute_eval_points_cubic
This function is parallelized over the length of the four MLEs passed in, but is missing some optimizations. The binding computation it performs is as follows:
We compute `f(r)` for `r = 0, 2, 3` (the eval at `r = 1` can be derived from the round's claimed sum). To expand this a bit: for each pair of adjacent evaluations `(low, high)` we have `f(r) = low + r * (high - low)`, so `f(0) = low`, `f(2) = low + 2 * (high - low)`, and `f(3) = low + 3 * (high - low)`. We can precompute `m = high - low`, giving `f(2) = high + m` and `f(3) = f(2) + m`. This is more efficient by a few field additions.
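A standalone sketch of this trick, again with `u64` modulo a toy prime standing in for field elements (names illustrative): precomputing `m` turns the `r = 2, 3` evals into two chained additions with no multiplications.

```rust
const P: u64 = (1u64 << 61) - 1; // toy modulus, not Spartan2's field

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % (P as u128)) as u64 }

/// Naive: f(r) = low + r * (high - low), one mul and one sub per point.
fn eval_naive(low: u64, high: u64, r: u64) -> u64 {
    add(low, mul(r, sub(high, low)))
}

/// Optimized: precompute m = high - low once, then chain additions:
/// f(2) = high + m, f(3) = f(2) + m -- no multiplications at all.
fn eval_points_023(low: u64, high: u64) -> (u64, u64, u64) {
    let m = sub(high, low);
    let f2 = add(high, m);
    let f3 = add(f2, m);
    (low, f2, f3)
}

fn main() {
    let (low, high) = (17u64, 42u64);
    let (f0, f2, f3) = eval_points_023(low, high);
    assert_eq!(f0, eval_naive(low, high, 0));
    assert_eq!(f2, eval_naive(low, high, 2));
    assert_eq!(f3, eval_naive(low, high, 3));
    println!("ok");
}
```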
Next, notice that if `high` / `low` have a high probability of being `0` / `1`, we have some interesting properties:

- `high == low` => `m = 0` => `f(2) = f(3) = high`
- `m = 0` for all four polynomials => `comb_func(f_a(2), f_b(2), f_c(2), f_d(2)) == comb_func(f_a(3), f_b(3), f_c(3), f_d(3))`
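The second property can be exploited at the round level: if all four slopes are zero on a pair, the `r = 3` inputs equal the `r = 2` inputs and the second `comb_func` call can be skipped. A toy sketch with plain `i128` arithmetic in place of field elements (names illustrative):

```rust
/// f(a, b, c, d) = a * (b * c - d), the outer sumcheck comb function.
fn comb(a: i128, b: i128, c: i128, d: i128) -> i128 {
    a * (b * c - d)
}

/// Evaluations at r = 2 and r = 3 of comb over four (low, high) pairs.
/// When every m = high - low is zero, the polys are constant on this pair,
/// so the r = 3 evaluation equals the r = 2 one and the second comb call
/// (plus the four additions feeding it) can be skipped.
fn eval_23(pairs: [(i128, i128); 4]) -> (i128, i128) {
    let m: Vec<i128> = pairs.iter().map(|(lo, hi)| hi - lo).collect();
    let f2: Vec<i128> = pairs.iter().zip(&m).map(|((_, hi), m)| hi + m).collect();
    let e2 = comb(f2[0], f2[1], f2[2], f2[3]);
    let e3 = if m.iter().all(|&mi| mi == 0) {
        e2 // shortcut: identical inputs give an identical output
    } else {
        let f3: Vec<i128> = f2.iter().zip(&m).map(|(f2, m)| f2 + m).collect();
        comb(f3[0], f3[1], f3[2], f3[3])
    };
    (e2, e3)
}

fn main() {
    // Constant pairs: shortcut fires, and e2 == e3.
    assert_eq!(eval_23([(1, 1), (1, 1), (1, 1), (0, 0)]), (1, 1));
    // Non-constant pairs: both comb calls run.
    assert_eq!(eval_23([(0, 1), (1, 2), (2, 3), (3, 4)]), (14, 42));
    println!("ok");
}
```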
There are some other combos that are likely less relevant and rarer; may be worth exploring.

Poly Binding
At the end of each round of `Sumcheck::prove_cubic_with_additive_term`, the four polynomials are bound (`bound_poly_var_top`). These can all be executed in parallel rather than serially. The `bound_poly_var_top` function itself is parallelized, but it is worth determining experimentally whether a changed parallelization shape is more efficient from a memory contention perspective (I suspect it will be).

Inline Poly Binding
The two sections of the sumcheck loop are: first, evaluate the joint polynomial `p(b, ...) = f_a(b, ...) * [f_b(b, ...) * f_c(b, ...) - f_d(b, ...)]` over the boolean hypercube; then, bind each of the multilinear polynomials `f_a / f_b / f_c / f_d` to a point `r` derived from the prior evaluation. I usually call these the eval loop and the binding loop. Interestingly, they perform much of the same work. Above (in `compute_eval_points_cubic`) I describe the eval loop algorithm. The binding loop does the same but for `f(r)` instead of `f({0, 2, 3})`. This means it may be plausible to keep `m` around to compute `low' = low + r * m`. I believe this saves exactly one field addition per step at the cost of significant RAM, but it is plausible there are some memory performance improvements when tested experimentally.
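A sketch of the fused idea (toy `u64` field again; names illustrative): the eval pass caches `m` for each pair, and the binding pass reuses it as `low' = low + r * m` instead of recomputing `high - low`.

```rust
const P: u64 = (1u64 << 61) - 1; // toy modulus, not Spartan2's field

fn add(a: u64, b: u64) -> u64 { (a + b) % P }
fn sub(a: u64, b: u64) -> u64 { (a + P - b) % P }
fn mul(a: u64, b: u64) -> u64 { ((a as u128 * b as u128) % (P as u128)) as u64 }

/// One round over a single poly: the eval pass computes m = high - low per
/// (low, high) pair and caches it (this is also where f(2), f(3) would be
/// accumulated in the real prover); the binding pass then reuses the cached
/// m to compute low' = low + r * m, instead of re-deriving high - low.
fn round_with_cached_m(poly: &[u64], r: u64) -> Vec<u64> {
    let n = poly.len() / 2;
    // "Eval loop": compute and cache the slopes.
    let ms: Vec<u64> = (0..n).map(|i| sub(poly[i + n], poly[i])).collect();
    // "Binding loop": reuse them.
    (0..n).map(|i| add(poly[i], mul(r, ms[i]))).collect()
}

/// Reference binding that recomputes high - low each time.
fn round_baseline(poly: &[u64], r: u64) -> Vec<u64> {
    let n = poly.len() / 2;
    (0..n).map(|i| add(poly[i], mul(r, sub(poly[i + n], poly[i])))).collect()
}

fn main() {
    let poly = vec![3, 1, 4, 1, 5, 9, 2, 6];
    let r = 12345;
    assert_eq!(round_with_cached_m(&poly, r), round_baseline(&poly, r));
    println!("ok");
}
```

The cached `ms` vector is the "significant RAM" cost mentioned above: one extra field element per pair, per polynomial, per round.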