a16z / jolt

The simplest and most extensible zkVM. Fast and fully open source from a16z crypto and friends. ⚡
https://jolt.a16zcrypto.com
MIT License

Accelerate R1CSShape::multiply_vec_uniform #297

Closed by sragss 1 week ago

sragss commented 4 months ago

R1CSShape::multiply_vec_uniform (here) generates the Az / Bz / Cz vectors for Spartan, where z is the full witness and A, B, C are the R1CS matrices.

Performance seems to vary greatly across machines / architectures, but this function's cost scales slightly super-linearly with cycle count.

In our instance of Spartan, A, B, C are highly uniform: they have a "sparse diagonal" construction. Because the same R1CS circuit is run for every step of the CPU, we have a concept of a "single step shape" (represented by a, b, c, d below) that repeats along the diagonal.

$$
A = \begin{bmatrix}
a & b & 0 & 0 & 0 & 0 \\
c & d & 0 & 0 & 0 & 0 \\
0 & 0 & a & b & 0 & 0 \\
0 & 0 & c & d & 0 & 0 \\
0 & 0 & 0 & 0 & a & b \\
0 & 0 & 0 & 0 & c & d
\end{bmatrix}
$$
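
To make the "single step shape" idea concrete, here is a minimal sketch of multiplying such a block-diagonal matrix by z without ever materializing the full matrix. It follows the matrix exactly as drawn above (one block of columns per step); the real witness layout in Jolt may differ. The names `Entry`, `multiply_block_diagonal`, and the toy modulus `P` are assumptions made for illustration, not Jolt's API or field.

```rust
/// Toy prime modulus standing in for the real field (assumption).
const P: u64 = 2_147_483_647;

/// One non-zero entry of the single-step shape: (row, col, value).
type Entry = (usize, usize, u64);

/// Computes M * z where M is block-diagonal with `steps` copies of the
/// `rows_per_step x cols_per_step` single-step shape on its diagonal.
/// All values are assumed to be already reduced modulo P.
fn multiply_block_diagonal(
    shape: &[Entry],
    rows_per_step: usize,
    cols_per_step: usize,
    steps: usize,
    z: &[u64],
) -> Vec<u64> {
    assert_eq!(z.len(), cols_per_step * steps);
    let mut out = vec![0u64; rows_per_step * steps];
    for step in 0..steps {
        // Each step only reads its own block of z and writes its own block of out.
        let z_block = &z[step * cols_per_step..(step + 1) * cols_per_step];
        let out_block = &mut out[step * rows_per_step..(step + 1) * rows_per_step];
        for &(row, col, val) in shape {
            out_block[row] = (out_block[row] + val * z_block[col]) % P;
        }
    }
    out
}
```

With the shape a = 1, b = 2, c = 3, d = 4, three steps, and z = [1, 2, 3, 4, 5, 6], each step multiplies the same 2x2 block by its own pair of witness values.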

sragss commented 4 months ago

Current thinking is that we can use some preprocessing on the uniform A, B, C matrices to determine what they'll end up reading and writing.

A given sparse entry's column in {A, B, C} determines which padded_trace_length segment of the witness z will be read. All entries from a given row will write to the same section of the output vecs {A, B, C}z.

Currently we parallelize across rows because it is easy.

My thinking is that we can instead group by column. Each witness segment of z would then be loaded once and reused by every entry in that column, which allows preprocessing the segment and should give better cache hits on it.
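
The two traversal orders can be sketched as follows. This is a sequential illustration under stated assumptions, not the code in the repository: the witness is assumed to be laid out as one segment of length `t` (standing in for padded_trace_length) per column of the small matrix, the output as one segment per row, and `Entry`, `multiply_by_row`, `multiply_by_column`, and the toy modulus `P` are hypothetical names.

```rust
use std::collections::BTreeMap;

/// Toy prime modulus standing in for the real field (assumption).
const P: u64 = 2_147_483_647;

/// One non-zero entry of the small matrix: (row, col, value).
type Entry = (usize, usize, u64);

/// Row-oriented: each output segment is owned by one row, which makes it
/// easy to parallelize, but every entry re-reads its witness segment.
fn multiply_by_row(shape: &[Entry], num_rows: usize, t: usize, z: &[u64]) -> Vec<u64> {
    let mut out = vec![0u64; num_rows * t];
    for (row, out_seg) in out.chunks_mut(t).enumerate() {
        for &(_, col, val) in shape.iter().filter(|e| e.0 == row) {
            let z_seg = &z[col * t..(col + 1) * t];
            for (o, &w) in out_seg.iter_mut().zip(z_seg) {
                *o = (*o + val * w) % P;
            }
        }
    }
    out
}

/// Column-grouped: each witness segment is looked up once and reused by
/// every entry that reads it; writes may now touch several output rows.
fn multiply_by_column(shape: &[Entry], num_rows: usize, t: usize, z: &[u64]) -> Vec<u64> {
    let mut by_col: BTreeMap<usize, Vec<(usize, u64)>> = BTreeMap::new();
    for &(row, col, val) in shape {
        by_col.entry(col).or_default().push((row, val));
    }
    let mut out = vec![0u64; num_rows * t];
    for (&col, users) in &by_col {
        let z_seg = &z[col * t..(col + 1) * t];
        for &(row, val) in users {
            let out_seg = &mut out[row * t..(row + 1) * t];
            for (o, &w) in out_seg.iter_mut().zip(z_seg) {
                *o = (*o + val * w) % P;
            }
        }
    }
    out
}
```

Both functions compute the same result. The difference is only which side (reads of z or writes to the output) stays local, which is the trade-off a parallel version would have to manage.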

Here are the column accesses by the different small sparse matrices {A, B, C}. Each key is a column index and each label is one non-zero entry in that column:

0: ["A", "B", "B", "B", "B", "B", "C"]
1: ["A", "B"]
2: ["A"]
3: ["C"]
4: ["A", "A"]
7: ["B", "B", "C"]
8: ["B"]
9: ["B", "B", "C"]
10: ["B", "B", "C"]
12: ["B"]
13: ["B"]
14: ["B"]
15: ["B"]
16: ["B", "B"]
17: ["A", "B"]
18: ["A", "B"]
19: ["A", "B"]
20: ["A", "B"]
21: ["A", "A"]
22: ["A", "A"]
23: ["A", "A"]
24: ["A", "A"]
25: ["A", "B", "C"]
26: ["A", "B", "C"]
27: ["A", "B", "C"]
28: ["A", "B", "B", "B", "B", "C"]
29: ["A", "A"]
30: ["A", "A"]
31: ["A", "A"]
32: ["A", "A"]
33: ["B", "B", "B", "B"]
34: ["A", "A", "A", "B"]
35: ["A", "A", "A", "B"]
36: ["A", "A", "A", "A", "A", "A", "A", "A", "B"]
37: ["A", "A", "A", "A", "A", "B"]
38: ["A", "A", "A", "B", "B"]
39: ["A", "A", "A", "B"]
40: ["A", "A", "B", "B"]
41: ["A", "A", "A", "B"]
42: ["A", "A", "B", "B", "B", "B", "B", "B", "B"]
43: ["A", "A", "A", "B"]
44: ["A", "A", "A", "B"]
45: ["A", "A", "B"]
46: ["A", "A", "B"]
47: ["A", "A", "B"]
48: ["A", "A", "B"]
49: ["A", "A", "B"]
50: ["A", "A", "B"]
51: ["A", "A", "B"]
52: ["A", "A", "B"]
53: ["A", "A", "B"]
54: ["A", "A", "B"]
55: ["A", "A", "B"]
56: ["A", "A", "B"]
57: ["A", "A", "B"]
58: ["A", "A", "B"]
59: ["A", "A", "A", "A", "A", "A", "B"]
60: ["A", "A", "A", "A", "A", "A", "B"]
61: ["A", "A", "A", "A", "A", "A", "B"]
62: ["B", "B", "C"]
63: ["A", "B", "B", "C"]
64: ["A", "B", "B", "C"]
65: ["B", "B", "C"]
66: ["B", "B", "B", "B", "C"]
67: ["A", "C"]
68: ["A", "C"]
69: ["A", "C"]
70: ["A", "C"]
71: ["A", "C"]
72: ["A", "C"]
73: ["A", "C"]
74: ["B", "C", "C"]
75: ["A", "C"]
128: ["A", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "B", "C"]

If these do not intersect with the rows too much we could get good speedups switching the direction of parallelism.
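
A table like the one above, plus a rough measure of how much columns "intersect with the rows", could be produced with a sketch along these lines. The representation of the small matrices as lists of (row, col, value) entries and the function names are assumptions for illustration only.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// One non-zero entry of a small matrix: (row, col, value).
type Entry = (usize, usize, u64);

/// For every column, lists which matrix each non-zero entry belongs to,
/// e.g. column 0 -> ["A", "B", "B"].
fn column_accesses(
    a: &[Entry],
    b: &[Entry],
    c: &[Entry],
) -> BTreeMap<usize, Vec<&'static str>> {
    let mut table: BTreeMap<usize, Vec<&'static str>> = BTreeMap::new();
    for (label, matrix) in [("A", a), ("B", b), ("C", c)] {
        for &(_, col, _) in matrix {
            table.entry(col).or_default().push(label);
        }
    }
    table
}

/// For one matrix, the set of output rows each column writes to. Columns
/// whose row sets overlap would contend for the same output segment if
/// they were processed in parallel.
fn rows_written_by_column(matrix: &[Entry]) -> BTreeMap<usize, BTreeSet<usize>> {
    let mut rows: BTreeMap<usize, BTreeSet<usize>> = BTreeMap::new();
    for &(row, col, _) in matrix {
        rows.entry(col).or_default().insert(row);
    }
    rows
}
```

If most columns write to disjoint row sets, column-wise parallelism needs little synchronization; heavily shared rows would need per-thread accumulators or a reduction step.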

Further, 50-90% of the values in the full_witness_vector that are accessed are zero. We don't pay too much for this because of mul_0_1_optimized, but we would still read and write less if we handled it more intelligently.

result.iter_mut().enumerate().for_each(|(i, out)| {
    *out += mul_0_1_optimized(&val, &full_witness_vector[witness_offset + i]);
});
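
One way to read and write less, sketched under assumptions: preprocess each witness segment once into its non-zero (index, value) pairs, then let every entry that reads the segment iterate only over those pairs. `mul_0_1` below is a hypothetical stand-in that only mimics the idea of a 0/1 shortcut; it is not Jolt's mul_0_1_optimized, and the toy modulus `P` is not the real field.

```rust
/// Toy prime modulus standing in for the real field (assumption).
const P: u64 = 2_147_483_647;

/// Hypothetical 0/1-aware multiply: skips the modular product when either
/// operand is 0 or 1. Inputs are assumed to be already reduced modulo P.
fn mul_0_1(a: u64, b: u64) -> u64 {
    match (a, b) {
        (0, _) | (_, 0) => 0,
        (1, x) | (x, 1) => x,
        _ => a * b % P,
    }
}

/// Preprocessing step: indices and values of the non-zero witness entries.
fn nonzero_entries(z_seg: &[u64]) -> Vec<(usize, u64)> {
    z_seg
        .iter()
        .copied()
        .enumerate()
        .filter(|&(_, w)| w != 0)
        .collect()
}

/// Dense accumulation, in the spirit of the loop above: touches every slot.
fn accumulate_dense(out: &mut [u64], val: u64, z_seg: &[u64]) {
    for (o, &w) in out.iter_mut().zip(z_seg) {
        *o = (*o + mul_0_1(val, w)) % P;
    }
}

/// Sparse accumulation: only touches output slots whose witness is non-zero.
fn accumulate_sparse(out: &mut [u64], val: u64, nonzero: &[(usize, u64)]) {
    for &(i, w) in nonzero {
        out[i] = (out[i] + mul_0_1(val, w)) % P;
    }
}
```

The preprocessing pays off only when a segment is reused by several entries, which is what grouping by column would provide.
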

flyq commented 3 months ago

fyi: https://github.com/microsoft/Nova/pull/232