a16z / jolt

The simplest and most extensible zkVM. Fast and fully open source from a16z crypto and friends. ⚡
https://jolt.a16zcrypto.com
MIT License

Simplified second-Spartan-sum-check (via simpler uniform R1CS product-vector MLE) #347

Open arasuarun opened 2 months ago

arasuarun commented 2 months ago

Making a note of a new optimization for Spartan, applied to Jolt's R1CS, that Justin suggested recently. It would involve only simple changes to the prover's A, B, C MLE evaluations and inner (second) sum-check, and to the verifier's MLE evaluation.

Context: the three constraint matrices A, B, C are all in block-diagonal form, with one block per RISC-V cycle consisting of ~60 constraints over ~80 variables of z. A Boolean vector representing a row of this big matrix can be split into two components as $r = r_{small} \parallel r_{big}$, where $r_{small}$ identifies the constraint (a row of the small matrix) and $r_{big}$ is the step counter. Likewise, a column can be split as $y = y_{small} \parallel y_{big}$, where $y_{small}$ identifies the variable (a column of the small matrix) and $y_{big}$ is again the step counter. Currently, the uniformity of the R1CS is exploited as follows. Let $NC$ be the number of columns in the big matrix.

$(A z)[r] = \sum\limits_{j_{small} \parallel j_{big} \in \{ 0,1 \} ^{\log(NC)}} A_{small}(r_{small}, j_{small}) \cdot eq(r_{big}, j_{big}) \cdot z(j)$,

and similarly for $(Bz)[r]$ and $(Cz)[r]$.
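The uniform structure behind this formula can be illustrated with a small Python sketch. It assumes an index layout in which the constraint (or variable) index occupies the high-order bits and the step counter the low-order bits, matching $r = r_{small} \parallel r_{big}$ and $y = y_{small} \parallel y_{big}$; this is an illustration, not necessarily how Jolt lays out its matrices, and `build_uniform` is a hypothetical helper. In this layout the big matrix is the Kronecker product $A_{small} \otimes I$ (block diagonal once rows and columns are reordered step-major), and since the MLE of the identity matrix is $eq$, this is where the $eq(r_{big}, j_{big})$ factor comes from.

```python
def build_uniform(A_small, num_steps):
    """Hypothetical helper: expand a per-cycle matrix A_small into the full uniform matrix.

    Assumed layout (illustration only):
      row index    = constraint * num_steps + step   (r = r_small || r_big)
      column index = variable   * num_steps + step   (y = y_small || y_big)
    An entry equals A_small[constraint][variable] when the two step counters match, else 0.
    """
    rows, cols = len(A_small), len(A_small[0])
    big = [[0] * (cols * num_steps) for _ in range(rows * num_steps)]
    for c in range(rows):
        for v in range(cols):
            for s in range(num_steps):
                big[c * num_steps + s][v * num_steps + s] = A_small[c][v]
    return big


if __name__ == "__main__":
    A_small = [[1, 2], [0, 3]]  # toy: 2 constraints over 2 variables
    for row in build_uniform(A_small, num_steps=2):
        print(row)
    # [1, 0, 2, 0]
    # [0, 1, 0, 2]
    # [0, 0, 3, 0]
    # [0, 0, 0, 3]
```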

To prepare to apply sum-check to compute (a random linear combination of) the above three sums, the prover first evaluates the MLEs of A, B and C at all points of the form $(r, y)$, where $y$ ranges over $\{0, 1\} ^{\log(NC)}$, as $A(r, y) = A_{small}(r_{small}, y_{small}) \cdot eq(r_{big}, y_{big})$. Once all these values are computed, the prover is ready to run the second sum-check through the function prove_spartan_quadratic.
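The precomputation can be sketched in Python as follows: `current_table` builds all $NC$ values $A(r, y) = A_{small}(r_{small}, y_{small}) \cdot eq(r_{big}, y_{big})$ for Boolean $y$, and the demo checks that $\sum_y A(r, y)\, z(y)$ matches the MLE of the materialized product $Az$ at $r$. Assumptions: a toy prime field $p = 2^{61} - 1$ (Jolt's actual field differs), the constraint/variable index in the high-order bits and the step counter in the low-order bits, and power-of-two dimensions. All helper names are hypothetical, and the sketch stops before the sum-check itself (prove_spartan_quadratic is not reproduced).

```python
import random

P = 2**61 - 1  # toy prime field, chosen for illustration only


def bits(i, n):
    """Big-endian n-bit decomposition of i (first bit = highest-order variable)."""
    return [(i >> (n - 1 - k)) & 1 for k in range(n)]


def eq(x, y):
    """Multilinear equality polynomial: prod_i (x_i*y_i + (1-x_i)*(1-y_i)) mod P."""
    out = 1
    for xi, yi in zip(x, y):
        out = out * (xi * yi + (1 - xi) * (1 - yi)) % P
    return out


def a_small_row(A_small, r_small, y_small):
    """A_small~(r_small, y_small) for a Boolean column index y_small (MLE over the rows only)."""
    n = len(r_small)
    return sum(eq(bits(c, n), r_small) * A_small[c][y_small] for c in range(len(A_small))) % P


def current_table(A_small, r_small, r_big):
    """All NC values A(r, y) = A_small~(r_small, y_small) * eq(r_big, y_big)."""
    nb = len(r_big)
    a = [a_small_row(A_small, r_small, ys) for ys in range(len(A_small[0]))]
    e = [eq(r_big, bits(yb, nb)) for yb in range(1 << nb)]
    return [ay * ey % P for ay in a for ey in e]  # index = y_small * #steps + y_big


def az_direct(A_small, z, r_small, r_big):
    """Reference: materialize Az (row = constraint || step), then take its MLE at r."""
    nb, S = len(r_big), 1 << len(r_big)
    az = [sum(A_small[c][v] * z[v * S + s] for v in range(len(A_small[0]))) % P
          for c in range(len(A_small)) for s in range(S)]
    r, n = r_small + r_big, len(r_small) + nb
    return sum(val * eq(bits(i, n), r) for i, val in enumerate(az)) % P


if __name__ == "__main__":
    rng = random.Random(0)
    A_small = [[rng.randrange(P) for _ in range(4)] for _ in range(2)]  # 2 constraints, 4 variables
    z = [rng.randrange(P) for _ in range(4 * 8)]                      # 8 steps
    r_small = [rng.randrange(P)]
    r_big = [rng.randrange(P) for _ in range(3)]
    table = current_table(A_small, r_small, r_big)
    lhs = sum(t * zj for t, zj in zip(table, z)) % P
    print(len(table), lhs == az_direct(A_small, z, r_small, r_big))  # 32 True
```

Note that the table has one entry per column of the big matrix, so its size grows with the number of steps.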

It turns out there is an even simpler way to express $(A z)[r]$. There is no need to sum over $j_{big}$ at all; we can simply write:

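$(A z)[r] = \sum\limits_{j_{small}} A_{small}(r_{small}, j_{small}) \cdot z(j_{small} \parallel r_{big})$,

where $z(\cdot)$ denotes the MLE of $z$, here evaluated at a point whose $r_{big}$ coordinates are in general not Boolean.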
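This identity follows from the earlier formula by summing out $j_{big}$ first. For each fixed Boolean $j_{small}$, the function $j_{big} \mapsto z(j_{small} \parallel j_{big})$ is multilinear, and $\sum_{j_{big}} eq(r_{big}, j_{big})\, f(j_{big}) = f(r_{big})$ for any multilinear $f$, so

$$(A z)[r] = \sum_{j_{small}} A_{small}(r_{small}, j_{small}) \sum_{j_{big}} eq(r_{big}, j_{big}) \cdot z(j_{small} \parallel j_{big}) = \sum_{j_{small}} A_{small}(r_{small}, j_{small}) \cdot z(j_{small} \parallel r_{big}).$$

The Python sketch below checks the identity numerically against a materialized $Az$. It assumes a toy prime field $p = 2^{61} - 1$, the constraint/variable index in the high-order bits, and power-of-two dimensions; the helper names are hypothetical, and $z$'s MLE is evaluated naively for clarity rather than speed.

```python
import random

P = 2**61 - 1  # toy prime field, chosen for illustration only


def bits(i, n):
    """Big-endian n-bit decomposition of i."""
    return [(i >> (n - 1 - k)) & 1 for k in range(n)]


def eq(x, y):
    """Multilinear equality polynomial, reduced mod P."""
    out = 1
    for xi, yi in zip(x, y):
        out = out * (xi * yi + (1 - xi) * (1 - yi)) % P
    return out


def mle(table, point):
    """Multilinear extension of `table` (length 2^len(point)) evaluated at `point`."""
    n = len(point)
    return sum(v * eq(bits(i, n), point) for i, v in enumerate(table)) % P


def az_simplified(A_small, z, r_small, r_big):
    """sum_{j_small} A_small~(r_small, j_small) * z~(j_small || r_big); no sum over j_big."""
    ns, nv = len(r_small), (len(A_small[0]) - 1).bit_length()
    total = 0
    for js in range(len(A_small[0])):
        a = sum(eq(bits(c, ns), r_small) * A_small[c][js] for c in range(len(A_small)))
        total += a * mle(z, bits(js, nv) + r_big)  # z's MLE at a partly non-Boolean point
    return total % P


def az_direct(A_small, z, r_small, r_big):
    """Reference: materialize Az (row = constraint || step) and take its MLE at r_small || r_big."""
    S = 1 << len(r_big)
    az = [sum(A_small[c][v] * z[v * S + s] for v in range(len(A_small[0]))) % P
          for c in range(len(A_small)) for s in range(S)]
    return mle(az, r_small + r_big)


if __name__ == "__main__":
    rng = random.Random(0)
    A_small = [[rng.randrange(P) for _ in range(4)] for _ in range(2)]  # 2 constraints, 4 variables
    z = [rng.randrange(P) for _ in range(4 * 8)]                      # 8 steps
    r_small, r_big = [rng.randrange(P)], [rng.randrange(P) for _ in range(3)]
    print(az_simplified(A_small, z, r_small, r_big) == az_direct(A_small, z, r_small, r_big))  # True
```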

This means that the second sum-check only needs to run over the variables of $y_{small}$, whose number is a constant independent of the program length. To run it, the prover just needs, for each $y_{small}$, the value $A_{small}(r_{small}, y_{small})$ (which plays the role of $A(r, y_{small} \parallel r_{big})$), and likewise for B and C, together with $z(y_{small} \parallel r_{big})$.
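A minimal sketch of the resulting inner sum-check follows, under several assumptions: a toy prime field $p = 2^{61} - 1$, an interactive protocol with random challenges (no Fiat-Shamir transcript), and a single matrix A (with B and C, the table `a` would presumably become a random linear combination of three such tables). The hypothetical `prover_tables` builds $a(y_{small}) = A_{small}(r_{small}, y_{small})$ and $w(y_{small}) = z(y_{small} \parallel r_{big})$; `second_sumcheck` proves $\sum_{y_{small}} a(y_{small}) \cdot w(y_{small})$ with degree-2 round polynomials sent as evaluations at 0, 1, 2. None of these names are Jolt's API.

```python
import random

P = 2**61 - 1  # toy prime field, chosen for illustration only


def bits(i, n):
    """Big-endian n-bit decomposition of i."""
    return [(i >> (n - 1 - k)) & 1 for k in range(n)]


def eq(x, y):
    """Multilinear equality polynomial, reduced mod P."""
    out = 1
    for xi, yi in zip(x, y):
        out = out * (xi * yi + (1 - xi) * (1 - yi)) % P
    return out


def mle(table, point):
    """Multilinear extension of `table` evaluated at `point`."""
    n = len(point)
    return sum(v * eq(bits(i, n), point) for i, v in enumerate(table)) % P


def prover_tables(A_small, z, r_small, r_big):
    """a[y_small] = A_small~(r_small, y_small); w[y_small] = z~(y_small || r_big)."""
    ns, nb, S = len(r_small), len(r_big), 1 << len(r_big)
    eq_big = [eq(r_big, bits(s, nb)) for s in range(S)]  # eq table over the step bits, built once
    cols = range(len(A_small[0]))
    a = [sum(eq(bits(c, ns), r_small) * A_small[c][v] for c in range(len(A_small))) % P for v in cols]
    w = [sum(eq_big[s] * z[v * S + s] for s in range(S)) % P for v in cols]
    return a, w


def fold(t, rho):
    """Bind the highest-order variable of the multilinear table t to rho."""
    h = len(t) // 2
    return [(t[i] + rho * (t[h + i] - t[i])) % P for i in range(h)]


def interp2(g, x):
    """Evaluate the quadratic through (0, g[0]), (1, g[1]), (2, g[2]) at x (Lagrange form)."""
    inv2 = pow(2, P - 2, P)
    return (g[0] * (x - 1) * (x - 2) * inv2 - g[1] * x * (x - 2) + g[2] * x * (x - 1) * inv2) % P


def second_sumcheck(a, w, claim, rng):
    """Interactive sum-check for claim = sum_y a(y) * w(y), with y ranging over y_small only."""
    a0, w0, rho = a, w, []
    while len(a) > 1:
        h = len(a) // 2
        # Round polynomial g(X) = sum_rest a(X, rest) * w(X, rest), degree 2, sent as g(0), g(1), g(2)
        g = [sum(a[i] * w[i] for i in range(h)) % P,
             sum(a[h + i] * w[h + i] for i in range(h)) % P,
             sum((2 * a[h + i] - a[i]) * (2 * w[h + i] - w[i]) for i in range(h)) % P]
        if (g[0] + g[1]) % P != claim:
            return False, rho
        r = rng.randrange(P)  # verifier challenge
        claim, rho = interp2(g, r), rho + [r]
        a, w = fold(a, r), fold(w, r)
    # Final check: a~(rho) = A_small~(r_small, rho) is cheap for the verifier (small matrix);
    # w~(rho) = z~(rho || r_big) would have to come from an opening of z's commitment.
    return claim == mle(a0, rho) * mle(w0, rho) % P, rho


if __name__ == "__main__":
    rng = random.Random(1)
    A_small = [[rng.randrange(P) for _ in range(4)] for _ in range(2)]  # 2 constraints, 4 variables
    r_small = [rng.randrange(P)]
    for steps_log in (2, 6):  # the number of rounds does not depend on the number of steps
        z = [rng.randrange(P) for _ in range(4 << steps_log)]
        r_big = [rng.randrange(P) for _ in range(steps_log)]
        a, w = prover_tables(A_small, z, r_small, r_big)
        ok, rho = second_sumcheck(a, w, sum(x * y for x, y in zip(a, w)) % P, rng)
        print(steps_log, ok, len(rho))  # prints "<steps_log> True 2": log2(4) = 2 rounds either way
```

The prover still touches every entry of $z$ once when building `w`, but the sum-check rounds themselves run over only $\log$ of the number of small-matrix columns.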