noir-lang / noir

Noir is a domain specific language for zero knowledge proofs
https://noir-lang.org
Apache License 2.0

Code gen nested arrays with a flat memory structure for ACIR runtimes #6231

Open vezenovm opened 1 month ago

vezenovm commented 1 month ago

Problem

Our nested arrays have a recursive structure. When working with dynamic indices, this can lead to a lot of extra array operations that would otherwise be unnecessary if we had a flat memory representing the nested array.

Happy Case

Goal: Treat nested arrays as a single flat memory. We want to trade extra array get/set operations for extra arithmetic that computes the index into the flat memory.

For example:

  v10 = array_get [[Field 1, [Field 2], [Field 3]]], index v3
  v12 = array_get v10, index u32 0
  v14 = array_get v10, index u32 1
  v16 = array_get v10, index u32 2

Here the first array_get is unnecessary. For the second array_get, and likewise for the other follow-ups, we should be able to index the outer array directly:

  v14 = array_get [[Field 1, [Field 2], [Field 3]]], index (v3 * 3) + 1

In this read-only case, the extra arithmetic needed to compute the index may not save much compared with the single array_get it removes. When writing into a nested array, however, the benefits can be much greater.
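To make the index arithmetic concrete, here is a minimal Rust sketch. It assumes every inner array has the same length and stands in `u64` for `Field`. The helper names `flat_index` and `flatten` are hypothetical and are not part of the Noir compiler. The sketch shows that reading `nested[i][j]` returns the same value as reading `flat[i * inner_len + j]` from the row-major flattening. With `inner_len = 3` and `j = 1`, this is the `(v3 * 3) + 1` index above.

```rust
/// Hypothetical helper: the flat (row-major) index of `nested[outer][inner]`
/// when every inner array has length `inner_len`.
fn flat_index(outer: usize, inner_len: usize, inner: usize) -> usize {
    outer * inner_len + inner
}

/// Flattens equal-length inner arrays into one contiguous buffer, row by row.
/// `u64` stands in for `Field` in this sketch.
fn flatten(nested: &[Vec<u64>]) -> Vec<u64> {
    nested.iter().flat_map(|row| row.iter().copied()).collect()
}
```

The flat read replaces two array_get operations (outer, then inner) with one array_get plus a multiply and an add. That is why the gain for a single read is modest.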

Take this code:

struct Bar {
    inner: [Field; 3],
}
struct Foo {
    a: Field,
    b: [Field; 3],
    bar: Bar,
}
struct FooParent {
    array: [Field; 3],
    foos: [Foo; 1],
}
fn main(mut x: [Foo; 1], y: pub Field) {
    let foo_parent_one = FooParent { array: [0, 1, 2], foos: x };
    let foo_parent_two = FooParent { array: [3, 4, 5], foos: x };
    let mut foo_parents = [foo_parent_one, foo_parent_two];

    if y == 3 {
        foo_parents[y - 2].foos[y - 2].b[y - 1] = 5000;
    } else {
        foo_parents[y - 2].foos[y - 2].b[y - 1] = 1000;
    }
    assert(foo_parents[1].foos[1].b[2] == 5000);
}

This is the SSA for the `then` branch of the if (block b2):

  b2():
    inc_rc [[Field 0, Field 1, Field 2], v0, [Field 3, Field 4, Field 5], v0]
    v34 = sub v1, Field 2
    v35 = cast v34 as u32
    v36 = mul v35, u32 2
    v37 = add v36, u32 1
    v38 = array_get [[Field 0, Field 1, Field 2], v0, [Field 3, Field 4, Field 5], v0], index v37
    inc_rc v38
    v39 = mul v35, u32 3
    v40 = add v39, u32 1
    v41 = array_get v38, index v40
    inc_rc v41
    v42 = add v39, u32 2
    v43 = array_get v38, index v42
    v44 = sub v1, Field 1
    v45 = cast v44 as u32
    v47 = array_set v41, index v45, value Field 5000
    v48 = array_set mut v38, index v40, value v47
    v49 = add v40, u32 1
    v50 = array_set v48, index v49, value v43
    v51 = array_set mut [[Field 0, Field 1, Field 2], v0, [Field 3, Field 4, Field 5], v0], index v37, value v50
    store v51 at v11
    jmp b3()

Much of this fetching and setting is unnecessary. For example, v43 is read from index v42 of v38 and then written back, unchanged, at index v49 of the same array (v49 = v40 + 1 = v39 + 2 = v42). We should be able to replace all of this with a single array_set inside b2, at the cost of some extra arithmetic on the dynamic index.
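The following minimal Rust sketch contrasts the two write strategies. It models arrays as `Vec<u64>` with copy-on-write semantics, where `u64` stands in for `Field`. The function names are hypothetical.

- The nested version mirrors the shape of the SSA above: it reads the inner array, sets one element, and writes the inner array back into the outer one.
- The flat version performs a single set at a computed index.

Both produce the same final contents.

```rust
/// Nested write, mirroring the SSA shape: read the inner array (array_get),
/// update it (array_set), then store it back into the outer array (array_set).
fn nested_set(outer: &[Vec<u64>], i: usize, j: usize, value: u64) -> Vec<Vec<u64>> {
    let mut inner = outer[i].clone(); // copy of the sub-array
    inner[j] = value; // set on the sub-array
    let mut new_outer = outer.to_vec();
    new_outer[i] = inner; // set the sub-array back into the outer array
    new_outer
}

/// Flat write: a single set at index `i * inner_len + j`.
fn flat_set(flat: &[u64], inner_len: usize, i: usize, j: usize, value: u64) -> Vec<u64> {
    let mut new_flat = flat.to_vec();
    new_flat[i * inner_len + j] = value;
    new_flat
}
```

Each additional nesting level adds another get/set pair to the nested version. The flat version still needs only one set, plus one multiply-add per level.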

Workaround

Yes

Workaround Description

Flatten the memory manually. This is not a reasonable workaround to expect of developers, so the compiler should perform this flattening instead.

Here is the example above manually flattened:

fn main(mut x: [Foo; 1], y: pub Field) {
    let mut foo_parents = [0; 20];
    foo_parents = [
        0, 1, 2, x[0].a, x[0].b[0], x[0].b[1], x[0].b[2], x[0].bar.inner[0], x[0].bar.inner[1], x[0].bar.inner[2],
        3, 4, 5, x[0].a, x[0].b[0], x[0].b[1], x[0].b[2], x[0].bar.inner[0], x[0].bar.inner[1], x[0].bar.inner[2]
    ];
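    // Each FooParent occupies 10 slots: 3 for `array` + 7 for its single Foo (a, b[3], bar.inner[3])
    let index_foo_parent = ((y - 2) * 10);
    // Add 3 to skip `array` in FooParent
    let index_foos = index_foo_parent + 3;
    // Add 1 to skip `a` in Foo, plus 2 to reach `b[y - 1]` (hardcoded for the y == 3 case)
    let index_foo_b = index_foos + 3;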
    if y == 3 {
        foo_parents[index_foo_b] = 5000;
    } else {
        foo_parents[index_foo_b] = 1000;
    }
    assert(foo_parents[16] == 5000);
}
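The offsets computed by hand above (10 slots per `FooParent`, 3 to skip `array`) follow mechanically from the types. Below is a minimal Rust sketch of how a compiler pass could derive them. It assumes a simplified, hypothetical type model in which every leaf is one `Field` slot. `Ty`, `Step`, and `flat_offset` are illustrative names, not the Noir compiler's actual types.

```rust
/// Hypothetical, simplified type model: every leaf `Field` is one slot.
enum Ty {
    Field,
    Array(Box<Ty>, u32),
    Struct(Vec<Ty>),
}

impl Ty {
    /// Number of slots this type occupies once flattened.
    fn flat_size(&self) -> u32 {
        match self {
            Ty::Field => 1,
            Ty::Array(elem, len) => elem.flat_size() * len,
            Ty::Struct(fields) => fields.iter().map(Ty::flat_size).sum(),
        }
    }
}

/// One step of an access path: an array index or a struct member position.
enum Step {
    Index(u32),
    Member(usize),
}

/// Flat offset of `path` within a value of type `ty`, or `None` if the path
/// is out of bounds or does not match the type.
fn flat_offset(ty: &Ty, path: &[Step]) -> Option<u32> {
    let mut offset = 0u32;
    let mut cur = ty;
    for step in path {
        match (cur, step) {
            (Ty::Array(elem, len), Step::Index(i)) if i < len => {
                // Skip `i` whole elements.
                offset += i * elem.flat_size();
                cur = &**elem;
            }
            (Ty::Struct(fields), Step::Member(m)) if *m < fields.len() => {
                // Skip all members declared before member `m`.
                offset += fields[..*m].iter().map(Ty::flat_size).sum::<u32>();
                cur = &fields[*m];
            }
            _ => return None,
        }
    }
    Some(offset)
}
```

For the path `foo_parents[1].foos[0].b[2]`, this returns 16, the index used in the assert above. Element 0 is the only valid `foos` index here, since `foos: [Foo; 1]`. With a dynamic index, the same walk would emit `index * elem.flat_size()` as SSA arithmetic instead of folding it into a constant.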

This main compiles to 229 Brillig opcodes, versus 383 Brillig opcodes for the original code snippet.

Here is the SSA of the flat main:

After Array Set Optimizations:
brillig fn main f0 {
  b0(v0: [Field, [Field; 3], [Field; 3]; 1], v1: Field):
    v2 = allocate
    v4 = array_get v0, index u32 0
    v6 = array_get v0, index u32 1
    v8 = array_get v0, index u32 2
    v9 = array_get v6, index u32 0
    v10 = array_get v6, index u32 1
    v11 = array_get v6, index u32 2
    v12 = array_get v8, index u32 0
    v13 = array_get v8, index u32 1
    v14 = array_get v8, index u32 2
    inc_rc [Field 0, Field 1, Field 2, v4, v9, v10, v11, v12, v13, v14, Field 3, Field 4, Field 5, v4, v9, v10, v11, v12, v13, v14]
    store [Field 0, Field 1, Field 2, v4, v9, v10, v11, v12, v13, v14, Field 3, Field 4, Field 5, v4, v9, v10, v11, v12, v13, v14] at v2
    v22 = sub v1, Field 2
    v24 = mul v22, Field 10
    v25 = add v24, Field 3
    v26 = add v25, Field 3
    v27 = eq v1, Field 3
    jmpif v27 then: b2, else: b1
  b2():
    inc_rc [Field 0, Field 1, Field 2, v4, v9, v10, v11, v12, v13, v14, Field 3, Field 4, Field 5, v4, v9, v10, v11, v12, v13, v14]
    v31 = cast v26 as u32
    v33 = array_set mut [Field 0, Field 1, Field 2, v4, v9, v10, v11, v12, v13, v14, Field 3, Field 4, Field 5, v4, v9, v10, v11, v12, v13, v14], index v31, value Field 5000
    store v33 at v2
    jmp b3()

There are many more array get operations for fetching from the dynamic main inputs. However, b2 has been greatly simplified, and we removed many unnecessary array sets (which are much more expensive than array gets). Technically, if v0 were also treated as flat memory, we would need only 7 array gets rather than 9, since we would no longer have to fetch the intermediate arrays v6 and v8.

Additional Context

Additional updates, such as treating the block param as an array constant rather than a single value ID, would also let us remove all the array gets that precede the store into v2. This change is slightly different from flattening the memory itself and can be done in follow-up work.

Project Impact

Nice-to-have

Blocker Context

This does not truly block anything, but would be a very good optimization to have.

Would you like to submit a PR for this Issue?

None

Support Needs

No response

jfecher commented 1 month ago

We could possibly perform this lowering in ssa-gen as well whenever we see a nested array type, translating it into a larger unnested array. I think the main difficulty would be when users do access and pass around the sub-arrays. We'd probably have to iterate through each index at that point to make a new array. Then if they later set the sub-array we'd have to iterate and set each index as well. I don't expect that to be a common case though.

vezenovm commented 1 month ago

> We could possibly perform this lowering in ssa-gen as well whenever we see a nested array type, translating it into a larger unnested array.

Yeah I was thinking to do this during ssa gen in the manner you laid out.

> I think the main difficulty would be when users do access and pass around the sub-arrays. We'd probably have to iterate through each index at that point to make a new array. Then if they later set the sub-array we'd have to iterate and set each index as well. I don't expect that to be a common case though.

I agree. This case will most likely just require making a new array for the sub-array.
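Below is a rough sketch of the sub-array handling discussed in this thread. It is written in Rust and assumes the nested array has already been lowered to a flat `Vec<u64>`, with `u64` standing in for `Field`. The names `get_subarray` and `set_subarray` are hypothetical.

- Reading a sub-array copies each of its slots into a new array.
- Writing it back sets each slot individually.

Both therefore cost one operation per element, which is acceptable if, as noted above, this case is uncommon.

```rust
/// Builds a new array from `len` consecutive slots starting at `offset`
/// (one array_get per element in the lowered code).
fn get_subarray(flat: &[u64], offset: usize, len: usize) -> Vec<u64> {
    (0..len).map(|k| flat[offset + k]).collect()
}

/// Writes `sub` back into `flat` starting at `offset`
/// (one array_set per element in the lowered code).
fn set_subarray(flat: &mut [u64], offset: usize, sub: &[u64]) {
    for (k, &value) in sub.iter().enumerate() {
        flat[offset + k] = value;
    }
}
```

In the 20-slot layout from the manual workaround, `foo_parents[1].foos[0].b` would be the 3 slots starting at offset 14.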