a16z / Spartan2

High-speed zkSNARKs

Optimization: Sumcheck::prove_cubic_with_additive_term #4

Open sragss opened 9 months ago

sragss commented 9 months ago

Sumcheck::prove_cubic_with_additive_term seems suboptimal. Currently it takes ~6% of Spartan e2e time.

https://github.com/a16z/Spartan2/blob/uniform_r1cs_shape/src/spartan/sumcheck.rs#L251

Some ideas on optimization follow.

0 / 1 Checking

comb_func_outer gets passed into Sumcheck::prove_cubic_with_additive_term from Snark::prove. This combination function is f(a,b,c,d) = a * (b * c - d). There are clear optimizations to be had here when any of the terms are 0 / 1. Specifically, if a is zero we should short circuit and return zero, skipping both multiplications. The rest are less relevant individually but can theoretically save up to 66% of field multiplications.
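
As a minimal sketch of just the a = 0 short circuit (the fuller 0 / 1 handling appears in a later comment below; comb_func_sketch and the ff::Field bound are illustrative stand-ins, with F playing the role of G::Scalar):

    // Sketch only: short-circuit when a == 0.
    fn comb_func_sketch<F: ff::Field>(a: &F, b: &F, c: &F, d: &F) -> F {
      if bool::from(a.is_zero()) {
        // a * (b * c - d) == 0 regardless of b, c, d: both multiplications skipped.
        return F::ZERO;
      }
      *a * (*b * *c - *d)
    }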

compute_eval_points_cubic

This function is parallelized over the length of the 4 MLEs passed in, but is missing some optimizations. This is the binding function:

for i in 0..mle_evals.len() / 2:
    low = mle_evals[i];
    high = mle_evals[i + mle_evals.len() / 2];
    f(r) = low + r * (high - low);

We compute f(r) for r = 0, 2, 3 (the eval at 1 can be derived from the round's claimed sum, since f(0) + f(1) must equal it).

To expand this a bit we have:

f(0) = low + 0 * (high - low) = low
f(2) = low + 2 * (high - low) = high + high - low
f(3) = low + 3 * (high - low) = f(2) + high - low

We can precompute m = high - low.

m = high - low
f(0) = low
f(2) = high + m
f(3) = f(2) + m

This is more efficient by a few field additions.
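
As a concrete sketch of the per-index work with the precomputed m (illustrative names; F stands in for the scalar field):

    // Sketch: evaluations at r = 0, 2, 3 for one index pair, reusing m = high - low.
    fn eval_points<F: ff::Field>(evals: &[F], i: usize) -> (F, F, F) {
      let n = evals.len() / 2;
      let low = evals[i];
      let high = evals[i + n];
      let m = high - low; // one subtraction shared by both non-trivial points
      let f0 = low;       // f(0) = low
      let f2 = high + m;  // f(2) = high + (high - low)
      let f3 = f2 + m;    // f(3) = f(2) + (high - low)
      (f0, f2, f3)
    }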

Next, notice that if high / low have a high probability of being 0 / 1 we get some interesting properties: m = high - low is then -1, 0, or 1, so the binding update low + r * m collapses to low - r, low, or low + r respectively, with no field multiplication needed.
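
A sketch of that case analysis (my reading of the property above; F again stands in for the scalar field):

    // When low and high are each 0 or 1, m = high - low is -1, 0, or 1,
    // so the bound value low + r * m needs no field multiplication.
    fn bind_01<F: ff::Field>(low: F, high: F, r: F) -> F {
      if high == low {
        low      // m = 0: value unchanged
      } else if bool::from(low.is_zero()) {
        r        // low = 0, high = 1: low + r * 1 = r
      } else {
        low - r  // low = 1, high = 0: low + r * (-1) = 1 - r
      }
    }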

Poly Binding

At the end of each round of Sumcheck::prove_cubic_with_additive_term the 4 polynomials are bound (bound_poly_var_top). These can all be executed in parallel rather than serially. The bound_poly_var_top function is itself parallelized, but it's worth determining experimentally whether a different parallelization shape is more efficient from a memory-contention perspective (I suspect it will be).
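
A minimal sketch of the parallel shape, assuming rayon and the multilinear polynomial type that sumcheck.rs operates on (called MultilinearPolynomial here, with its bound_poly_var_top method):

    // Sketch: bind all four polynomials concurrently rather than one after another.
    // rayon::join takes exactly two closures, so the four calls nest.
    fn bind_all_parallel<F: ff::PrimeField>(
      poly_A: &mut MultilinearPolynomial<F>,
      poly_B: &mut MultilinearPolynomial<F>,
      poly_C: &mut MultilinearPolynomial<F>,
      poly_D: &mut MultilinearPolynomial<F>,
      r: &F,
    ) {
      rayon::join(
        || rayon::join(|| poly_A.bound_poly_var_top(r), || poly_B.bound_poly_var_top(r)),
        || rayon::join(|| poly_C.bound_poly_var_top(r), || poly_D.bound_poly_var_top(r)),
      );
    }

Whether nesting on top of the already-parallel bound_poly_var_top helps or just adds contention is exactly the experiment suggested above.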

Inline Poly Binding

The sumcheck loop has two sections: evaluate the joint polynomial p(b,...) = f_a(b,..) * [f_b(b,..) * f_c(b,..) - f_d(b,..)] over the boolean hypercube, then bind each of the multilinear polynomials f_a / f_b / f_c / f_d to a point r derived from the prior evaluations. I usually call these the eval loop and the binding loop. Interestingly, they perform much of the same work. Above (in compute_eval_points_cubic) I describe the eval loop algorithm; the binding loop does the same but for f(r) instead of f({0,2,3}). This means it may be plausible to keep m around and compute low' = low + r * m. I believe this saves exactly one field addition per step at the cost of significant RAM, but it's plausible there are some memory performance improvements when tested experimentally.
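
A rough sketch of the fused shape for a single MLE (the real eval loop combines all four polynomials through comb_func; the names and two-pass split are illustrative):

    // Eval pass: accumulate f(0), f(2), f(3) while caching m = high - low.
    fn eval_pass<F: ff::Field>(evals: &[F]) -> ((F, F, F), Vec<F>) {
      let n = evals.len() / 2;
      let (mut f0, mut f2, mut f3) = (F::ZERO, F::ZERO, F::ZERO);
      let mut ms = Vec::with_capacity(n); // the "significant RAM": n cached field elements
      for i in 0..n {
        let (low, high) = (evals[i], evals[i + n]);
        let m = high - low;
        f0 += low;          // f(0) = low
        f2 += high + m;     // f(2) = high + m
        f3 += high + m + m; // f(3) = f(2) + m
        ms.push(m);
      }
      ((f0, f2, f3), ms)
    }

    // Binding pass: once r is derived from the round polynomial, reuse the
    // cached m instead of recomputing high - low.
    fn bind_pass<F: ff::Field>(evals: &mut Vec<F>, ms: &[F], r: F) {
      let n = evals.len() / 2;
      for i in 0..n {
        evals[i] += r * ms[i]; // low' = low + r * m
      }
      evals.truncate(n);
    }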

sragss commented 9 months ago

Here's a value-density chart for polys Az / Bz / uCz_E on SHA256:

Poly Az ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1339241
0x0000000000000000000000000000000000000000000000000000000000000001: 390838
0x0000000000000000000000000000000000000000000000000000000000000002: 266101
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000: 99671
Total with only a single appearance: 1301

Poly Bz ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1480606
0x0000000000000000000000000000000000000000000000000000000000000001: 616546
Total with only a single appearance: 0

Poly uCz_E ====================== 2097152
0x0000000000000000000000000000000000000000000000000000000000000000: 1863816
0x0000000000000000000000000000000000000000000000000000000000000002: 132442
0x0000000000000000000000000000000000000000000000000000000000000001: 49820
0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000: 49773
Total with only a single appearance: 1301
Time elapsed is: 7.962553084s

sragss commented 9 months ago

An optimized 0/1 comb function makes it ~33% faster:


    let comb_func_outer = |poly_A_comp: &G::Scalar,
                           poly_B_comp: &G::Scalar,
                           poly_C_comp: &G::Scalar,
                           poly_D_comp: &G::Scalar|
     -> G::Scalar {
      // Goal: compute *poly_A_comp * (*poly_B_comp * *poly_C_comp - *poly_D_comp) fast.
      // poly_A we know to be uniformly random, so a 0/1 check on it isn't worthwhile.
      // poly_B: A matrix, poly_C: B matrix, poly_D: C matrix
      if poly_B_comp.eq(&G::Scalar::ZERO) || poly_C_comp.eq(&G::Scalar::ZERO) {
        // b * c == 0, so a * (b * c - d) = a * (-d): one multiplication instead of two.
        *poly_A_comp * poly_D_comp.neg()
      } else {
        let inner = if poly_B_comp.eq(&G::Scalar::ONE) {
          *poly_C_comp - *poly_D_comp
        } else if poly_C_comp.eq(&G::Scalar::ONE) {
          *poly_B_comp - *poly_D_comp
        } else {
          *poly_B_comp * *poly_C_comp - *poly_D_comp
        };
        *poly_A_comp * inner
      }
    };