noir-lang / noir

Noir is a domain specific language for zero knowledge proofs
https://noir-lang.org
Apache License 2.0

SSA optimization is ineffective when an unconstrained function has loops #4535

Closed: sirasistant closed this issue 2 months ago

sirasistant commented 8 months ago

Problem

This code:

struct EnumEmulation {
    a: Option<Field>,
    b: Option<Field>,
    c: Option<Field>,
}

unconstrained fn main() -> pub Field {
    let mut emulated_enum = EnumEmulation { a: Option::some(1), b: Option::none(), c: Option::none() };

    for _ in 0..1 {
        assert_eq(emulated_enum.a.unwrap(), 1);
    }

    emulated_enum.a = Option::some(2);
    emulated_enum.a.unwrap()
}

is functionally the same as

struct EnumEmulation {
    a: Option<Field>,
    b: Option<Field>,
    c: Option<Field>,
}

unconstrained fn main() -> pub Field {
    let mut emulated_enum = EnumEmulation { a: Option::some(1), b: Option::none(), c: Option::none() };

    assert_eq(emulated_enum.a.unwrap(), 1);

    emulated_enum.a = Option::some(2);
    emulated_enum.a.unwrap()
}

However, their optimized SSAs are completely different:

After Dead Instruction Elimination:
brillig fn main f0 {
  b0():
    v111 = allocate
    store u1 1 at v111
    v112 = allocate
    store Field 1 at v112
    v113 = allocate
    store u1 0 at v113
    v114 = allocate
    store Field 0 at v114
    v115 = allocate
    store u1 0 at v115
    v116 = allocate
    store Field 0 at v116
    jmp b1(u64 0)
  b1(v13: u64):
    v117 = eq v13, u64 0
    jmpif v117 then: b2, else: b3
  b2():
    v124 = load v111
    v125 = load v112
    constrain v124 == u1 1
    constrain v125 == Field 1
    v131 = add v13, u64 1
    jmp b1(v131)
  b3():
    v120 = load v113
    v121 = load v114
    v122 = load v115
    v123 = load v116
    store u1 1 at v111
    store Field 2 at v112
    store v120 at v113
    store v121 at v114
    store v122 at v115
    store v123 at v116
    return Field 2
}

vs

After Dead Instruction Elimination:
brillig fn main f0 {
  b0():
    return Field 2
}

I think mem2reg is failing to promote to reg values that are used across multiple blocks

Happy Case

Ideally, the two programs should generate the same optimized SSA

Project Impact

Nice to have

Impact Context

This generates a big blowup in bytecode size in aztec's public functions. See https://github.com/AztecProtocol/aztec-packages/pull/5153

Workaround

Yes

Workaround Description

We can reduce the number of unused references by using traits instead of emulating enums via options. This applies to the context enum, which we could transform into a context trait.

jfecher commented 8 months ago

I think mem2reg is failing to promote to reg values that are used across multiple blocks

mem2reg is one of the few passes that does work well across multiple blocks (note the constant return Field 2 at the end). What is missing here is loop unrolling: we skip it entirely in brillig, since loop bounds there can depend on runtime values.

The presence of the loop here is preventing the rest of the code from being removed.

jfecher commented 8 months ago

@sirasistant how common are small loops like this? I'm also not certain how much unrolling them will help since they will use non-constant values in the loop which means instructions from each iteration will still remain in the code. This example benefits greatly from loop unrolling mostly because there are no inputs to main so the loop can be removed entirely.

If a loop is normally 5 iterations which contain 4 instructions each, unrolling even that small loop to get 20 instructions is more questionable.

sirasistant commented 8 months ago

Oh, I completely made up the size of the loop. It was just to show a functionally equal example, so I used a loop that runs only once. We found this issue when making a change that caused almost all public functions to contain at least one loop. As you said, the problem is that the presence of the loop prevents simplifications that do apply when the loop is not there, and I thought mem2reg was the pass that simplified the code for deletion in one case and not in the other. I'm going to dig a bit deeper to see where the difference is.

jfecher commented 8 months ago

@sirasistant if you're just looking at why the loop isn't being unrolled that'd be the loop unrolling pass which is currently completely skipped for unconstrained code.

sirasistant commented 8 months ago

I know, it's fine for it not to be unrolled (even when the bounds are known at compile time, unrolled loops can lead to bigger bytecode). I meant: which optimization steps are not taking effect due to the presence of the loop itself?

sirasistant commented 8 months ago

I'm looking at the SSA and it seems that mem2reg is not identifying that the loads are known.

With loop:

After Inlining:
brillig fn main f0 {
  b0():
    v7 = allocate
    store u1 1 at v7
    v8 = allocate
    store Field 1 at v8
    v9 = allocate
    store u1 0 at v9
    v10 = allocate
    store Field 0 at v10
    v11 = allocate
    store u1 0 at v11
    v12 = allocate
    store Field 0 at v12
    jmp b1(u64 0)
  b1(v13: u64):
    v15 = eq v13, u64 0
    jmpif v15 then: b2, else: b3
  b2():
    v31 = load v7
    v32 = load v8
    v33 = load v9
    v34 = load v10
    v35 = load v11
    v36 = load v12
    constrain v31 == u1 1
    v38 = eq v32, Field 1
    constrain v32 == Field 1
    v40 = add v13, u64 1
    jmp b1(v40)
  b3():
    v16 = load v7
    v17 = load v8
    v18 = load v9
    v19 = load v10
    v20 = load v11
    v21 = load v12
    store u1 1 at v7
    store Field 2 at v8
    store v18 at v9
    store v19 at v10
    store v20 at v11
    store v21 at v12
    v24 = load v7
    v25 = load v8
    v26 = load v9
    v27 = load v10
    v28 = load v11
    v29 = load v12
    constrain v24 == u1 1
    return v25
}

After Mem2Reg:
brillig fn main f0 {
  b0():
    v41 = allocate
    store u1 1 at v41
    v42 = allocate
    store Field 1 at v42
    v43 = allocate
    store u1 0 at v43
    v44 = allocate
    store Field 0 at v44
    v45 = allocate
    store u1 0 at v45
    v46 = allocate
    store Field 0 at v46
    jmp b1(u64 0)
  b1(v13: u64):
    v47 = eq v13, u64 0
    jmpif v47 then: b2, else: b3
  b2():
    v60 = load v41
    v61 = load v42
    v62 = load v43
    v63 = load v44
    v64 = load v45
    v65 = load v46
    constrain v60 == u1 1
    v66 = eq v61, Field 1
    constrain v61 == Field 1
    v67 = add v13, u64 1
    jmp b1(v67)
  b3():
    v48 = load v41
    v49 = load v42
    v50 = load v43
    v51 = load v44
    v52 = load v45
    v53 = load v46
    store u1 1 at v41
    store Field 2 at v42
    store v50 at v43
    store v51 at v44
    store v52 at v45
    store v53 at v46
    return Field 2
}

Without loop:


After Inlining:
brillig fn main f0 {
  b0():
    v7 = allocate
    store u1 1 at v7
    v8 = allocate
    store Field 1 at v8
    v9 = allocate
    store u1 0 at v9
    v10 = allocate
    store Field 0 at v10
    v11 = allocate
    store u1 0 at v11
    v12 = allocate
    store Field 0 at v12
    v13 = load v7
    v14 = load v8
    v15 = load v9
    v16 = load v10
    v17 = load v11
    v18 = load v12
    constrain v13 == u1 1
    v20 = eq v14, Field 1
    constrain v14 == Field 1
    v21 = load v7
    v22 = load v8
    v23 = load v9
    v24 = load v10
    v25 = load v11
    v26 = load v12
    store u1 1 at v7
    store Field 2 at v8
    store v23 at v9
    store v24 at v10
    store v25 at v11
    store v26 at v12
    v29 = load v7
    v30 = load v8
    v31 = load v9
    v32 = load v10
    v33 = load v11
    v34 = load v12
    constrain v29 == u1 1
    return v30
}

After Mem2Reg:
brillig fn main f0 {
  b0():
    v36 = allocate
    v37 = allocate
    v38 = allocate
    v39 = allocate
    v40 = allocate
    v41 = allocate
    return Field 2
}

But the presence of the loop doesn't change the fact that the loaded values are known in both cases.

sirasistant commented 8 months ago

I guess the ideal SSA mem2reg could output in this case, given that it's not its job to remove loops with empty bodies, is something like this:


After Mem2Reg:
brillig fn main f0 {
  b0():
    v36 = allocate
    v37 = allocate
    v38 = allocate
    v39 = allocate
    v40 = allocate
    v41 = allocate
    jmp b1(u64 0)
  b1(v13: u64):
    v47 = eq v13, u64 0
    jmpif v47 then: b2, else: b3
  b2():
    v67 = add v13, u64 1
    jmp b1(v67)
  b3():
    return Field 2
}

jfecher commented 7 months ago

I've looked into this a bit further and it is indeed an issue with the mem2reg pass. The issue is with the core algorithm used though. Loops are just inherently tricky. The algorithm sees the blocks as follows:

b0():   ; predecessors = []
b1():   ; predecessors = [b0, b2]
b2():   ; predecessors = [b1]
b3():   ; predecessors = [b1]

The algorithm works as follows:

  1. Analyze block b0, we find known stored values for v7 through v12.
  2. Analyze block b1. The starting value of each reference is the union of that reference's value at the end of each predecessor block (!). For b0 these are known, but we haven't visited b2 yet because of the loop, so we know nothing about it and have to conservatively assume each value is unknown.
  3. Analyze block b2. The starting value of each reference is the union of each previous block, which for b1 we determined was unknown since we haven't visited b2 yet. So b2 also starts out unknown. Since these two blocks loop the choice would have been the same if we had chosen to analyze b2 first.
  4. Analyze block b3. By this point we don't know the value of any reference. The only stores we figure out are the ones in the same block since the values become known after that point.

Takeaways: fixing this requires changing the algorithm itself to handle loops better. One possible option is to add an entirely separate pass that goes through each block beforehand and finds where each store instruction is, so that we know ahead of time whether (e.g.) b2 may mutate a reference while we're visiting b1 during the main pass. The main downside is that we'd essentially be doubling the runtime of this pass.
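
To make the steps above concrete, here is a minimal sketch in Python. It is not the compiler's actual Rust code. The block names and predecessor lists follow the b0..b3 example, but the instruction tuples, the reduced set of references, and the helpers merge, loop_body and analyze are hypothetical. It models only constant stores and loads, ignores aliasing and calls, and assumes each loop header is visited before its latch.

```python
# Toy model of the example: ("store", ref, value) or ("load", ref).
BLOCKS = {
    "b0": [("store", "v7", 1), ("store", "v8", 1), ("store", "v9", 0)],
    "b1": [],
    "b2": [("load", "v7"), ("load", "v8")],
    "b3": [("load", "v9"), ("store", "v7", 1), ("store", "v8", 2), ("load", "v8")],
}
PREDS = {"b0": [], "b1": ["b0", "b2"], "b2": ["b1"], "b3": ["b1"]}
ORDER = ["b0", "b1", "b2", "b3"]


def merge(states):
    """Keep a reference only if every given predecessor state agrees on its value."""
    if not states:
        return {}
    first, *rest = states
    return {ref: val for ref, val in first.items()
            if all(ref in s and s[ref] == val for s in rest)}


def loop_body(preds, header, latch):
    """Blocks of the natural loop formed by the back edge latch -> header."""
    body, work = {header}, [latch]
    while work:
        block = work.pop()
        if block not in body:
            body.add(block)
            work.extend(preds[block])
    return body


def analyze(blocks, preds, order, use_prepass):
    end_state = {}    # block -> {reference: known value} at the end of the block
    known_loads = {}  # (block, instruction index) -> value the load must produce
    for name in order:
        visited = [end_state[p] for p in preds[name] if p in end_state]
        unvisited = [p for p in preds[name] if p not in end_state]
        state = merge(visited)
        if unvisited and not use_prepass:
            # Single pass: an unvisited predecessor forces everything to unknown.
            state = {}
        elif unvisited:
            # Pre-pass idea: only forget references that the loop body may store to.
            clobbered = set()
            for latch in unvisited:
                for block in loop_body(preds, name, latch):
                    clobbered |= {ins[1] for ins in blocks[block] if ins[0] == "store"}
            state = {ref: val for ref, val in state.items() if ref not in clobbered}
        for index, ins in enumerate(blocks[name]):
            if ins[0] == "store":
                state[ins[1]] = ins[2]
            elif ins[0] == "load" and ins[1] in state:
                known_loads[(name, index)] = state[ins[1]]
        end_state[name] = state
    return known_loads
```

With use_prepass=False, only the last load in b3 is resolved (it follows a store in the same block), which mirrors the constant return Field 2 in the dumps above. With use_prepass=True, the loads in b2 and the first load in b3 become known as well, because no block in the loop stores to any reference.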

I'll also note that identifying all pointer aliases correctly has been proven to be undecidable (Landi, 1992) so while we can improve this pass, we'll never be able to perfectly remove every reference.

sirasistant commented 7 months ago

This might be related: poseidon is a bit unusable right now in brillig.

This code:

unconstrained fn main() -> pub Field {
    dep::std::hash::poseidon2::Poseidon2::hash([0; 2], 2)
}

Generates this huge SSA after optimization:

After Dead Instruction Elimination:
brillig fn main f0 {
  b0():
    v1170 = allocate
    v1171 = allocate
    v1172 = allocate
    v1173 = allocate
    store [Field 0, Field 0, Field 0] at v1170
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v1171
    store u32 0 at v1172
    store u1 0 at v1173
    v1174 = allocate
    store [Field 0, Field 0, Field 0] at v1174
    v1175 = allocate
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v1175
    v1176 = allocate
    store u32 0 at v1176
    v1177 = allocate
    store u1 0 at v1177
    jmp b2(u64 0)
  b2(v32: u64):
    v1178 = lt v32, u64 2
    jmpif v1178 then: b3, else: b4
  b3():
    v1305 = lt v32, u64 2
    jmpif v1305 then: b29, else: b30
  b29():
    v1307 = array_get [Field 0, Field 0], index v32
    v1308 = load v1174
    inc_rc v1308
    store v1308 at v1174
    v1309 = load v1175
    inc_rc v1309
    store v1309 at v1175
    v1311 = load v1177
    v1312 = not v1311
    v1313 = load v1176
    v1315 = eq v1313, u32 3
    v1316 = mul v1312, v1315
    jmpif v1316 then: b31, else: b32
  b31():
    inc_rc v1308
    store v1308 at v1174
    inc_rc v1309
    store v1309 at v1175
    jmp b39(u32 0)
  b39(v275: u32):
    v1340 = lt v275, u32 3
    jmpif v1340 then: b40, else: b41
  b40():
    v1390 = load v1176
    v1392 = lt v275, v1390
    v1393 = not v1392
    jmpif v1393 then: b48, else: b49
  b48():
    v1395 = load v1174
    v1396 = load v1175
    v1397 = load v1176
    v1398 = load v1177
    v1399 = cast v275 as u64
    v1400 = array_set v1395, index v1399, value Field 0
    store v1400 at v1174
    store v1396 at v1175
    store v1397 at v1176
    store v1398 at v1177
    jmp b49()
  b49():
    v1394 = add v275, u32 1
    jmp b39(v1394)
  b41():
    jmp b42(u32 0)
  b42(v277: u32):
    v1341 = lt v277, u32 3
    jmpif v1341 then: b43, else: b44
  b43():
    v1369 = load v1174
    v1370 = load v1175
    v1371 = load v1176
    v1372 = load v1177
    v1374 = load v1175
    v1377 = cast v277 as u64
    v1378 = array_get v1374, index v1377
    v1379 = load v1174
    v1383 = array_get v1379, index v1377
    v1384 = add v1378, v1383
    v1385 = array_set v1370, index v1377, value v1384
    store v1369 at v1174
    store v1385 at v1175
    store v1371 at v1176
    store v1372 at v1177
    v1387 = add v277, u32 1
    jmp b42(v1387)
  b44():
    v1342 = load v1174
    v1344 = load v1176
    v1345 = load v1177
    v1347 = load v1175
    v1350 = call poseidon2_permutation(v1347, u32 4)
    inc_rc v1350
    store v1342 at v1174
    store v1350 at v1175
    store v1344 at v1176
    store v1345 at v1177
    v1351 = allocate
    store [Field 0, Field 0, Field 0] at v1351
    jmp b45(u32 0)
  b45(v290: u32):
    v1352 = lt v290, u32 3
    jmpif v1352 then: b46, else: b47
  b46():
    v1359 = load v1351
    v1361 = load v1175
    v1364 = cast v290 as u64
    v1365 = array_get v1361, index v1364
    v1366 = array_set v1359, index v1364, value v1365
    store v1366 at v1351
    v1368 = add v290, u32 1
    jmp b45(v1368)
  b47():
    v1354 = load v1174
    dec_rc v1354
    v1355 = load v1175
    dec_rc v1355
    v1357 = load v1177
    v1358 = array_set v1354, index u64 0, value v1307
    store v1358 at v1174
    store v1355 at v1175
    store u32 1 at v1176
    store v1357 at v1177
    jmp b38()
  b38():
    v1327 = load v1174
    dec_rc v1327
    store v1327 at v1174
    v1328 = load v1175
    dec_rc v1328
    store v1328 at v1175
    jmp b30()
  b30():
    v1306 = add v32, u64 1
    jmp b2(v1306)
  b32():
    v1318 = load v1177
    v1319 = not v1318
    v1320 = load v1176
    v1322 = eq v1320, u32 3
    v1323 = not v1322
    v1324 = mul v1319, v1323
    jmpif v1324 then: b33, else: b34
  b33():
    v1332 = load v1176
    v1333 = load v1177
    v1334 = load v1176
    v1336 = cast v1334 as u64
    v1337 = array_set v1308, index v1336, value v1307
    v1339 = add v1332, u32 1
    range_check v1339 to 32 bits
    store v1337 at v1174
    store v1309 at v1175
    store v1339 at v1176
    store v1333 at v1177
    jmp b37()
  b37():
    jmp b38()
  b34():
    v1326 = load v1177
    jmpif v1326 then: b35, else: b36
  b35():
    v1331 = array_set v1308, index u64 0, value v1307
    store v1331 at v1174
    store v1309 at v1175
    store u32 1 at v1176
    store u1 0 at v1177
    jmp b36()
  b36():
    jmp b37()
  b4():
    v1179 = load v1174
    inc_rc v1179
    store v1179 at v1174
    v1180 = load v1175
    inc_rc v1180
    store v1180 at v1175
    v1182 = load v1177
    v1183 = load v1176
    v1185 = eq v1183, u32 0
    v1186 = mul v1182, v1185
    jmpif v1186 then: b6, else: b7
  b6():
    store v1179 at v1174
    store v1180 at v1175
    store u32 0 at v1176
    store u1 0 at v1177
    jmp b7()
  b7():
    v1188 = load v1177
    v1189 = not v1188
    jmpif v1189 then: b8, else: b9
  b8():
    inc_rc v1179
    store v1179 at v1174
    inc_rc v1180
    store v1180 at v1175
    jmp b15(u32 0)
  b15(v108: u32):
    v1228 = lt v108, u32 3
    jmpif v1228 then: b16, else: b17
  b16():
    v1291 = load v1176
    v1293 = lt v108, v1291
    v1294 = not v1293
    jmpif v1294 then: b24, else: b25
  b24():
    v1296 = load v1174
    v1297 = load v1175
    v1298 = load v1176
    v1299 = load v1177
    v1300 = cast v108 as u64
    v1301 = array_set v1296, index v1300, value Field 0
    store v1301 at v1174
    store v1297 at v1175
    store v1298 at v1176
    store v1299 at v1177
    jmp b25()
  b25():
    v1295 = add v108, u32 1
    jmp b15(v1295)
  b17():
    jmp b18(u32 0)
  b18(v110: u32):
    v1229 = lt v110, u32 3
    jmpif v1229 then: b19, else: b20
  b19():
    v1270 = load v1174
    v1271 = load v1175
    v1272 = load v1176
    v1273 = load v1177
    v1275 = load v1175
    v1278 = cast v110 as u64
    v1279 = array_get v1275, index v1278
    v1280 = load v1174
    v1284 = array_get v1280, index v1278
    v1285 = add v1279, v1284
    v1286 = array_set v1271, index v1278, value v1285
    store v1270 at v1174
    store v1286 at v1175
    store v1272 at v1176
    store v1273 at v1177
    v1288 = add v110, u32 1
    jmp b18(v1288)
  b20():
    v1230 = load v1174
    v1232 = load v1176
    v1233 = load v1177
    v1235 = load v1175
    v1238 = call poseidon2_permutation(v1235, u32 4)
    inc_rc v1238
    store v1230 at v1174
    store v1238 at v1175
    store v1232 at v1176
    store v1233 at v1177
    v1239 = allocate
    store [Field 0, Field 0, Field 0] at v1239
    jmp b21(u32 0)
  b21(v125: u32):
    v1240 = lt v125, u32 3
    jmpif v1240 then: b22, else: b23
  b22():
    v1260 = load v1239
    v1262 = load v1175
    v1265 = cast v125 as u64
    v1266 = array_get v1262, index v1265
    v1267 = array_set v1260, index v1265, value v1266
    store v1267 at v1239
    v1269 = add v125, u32 1
    jmp b21(v1269)
  b23():
    v1241 = load v1239
    v1242 = load v1174
    dec_rc v1242
    v1243 = load v1175
    dec_rc v1243
    inc_rc v1241
    v1244 = load v1176
    store v1242 at v1174
    store v1243 at v1175
    store v1244 at v1176
    store u1 1 at v1177
    jmp b26(u32 0)
  b26(v180: u32):
    v1246 = lt v180, u32 3
    jmpif v1246 then: b27, else: b28
  b27():
    v1251 = load v1174
    v1252 = load v1175
    v1253 = load v1176
    v1254 = load v1177
    v1255 = cast v180 as u64
    v1256 = array_get v1241, index v1255
    v1257 = array_set v1251, index v1255, value v1256
    store v1257 at v1174
    store v1252 at v1175
    store v1253 at v1176
    store v1254 at v1177
    v1259 = add v180, u32 1
    jmp b26(v1259)
  b28():
    v1247 = load v1174
    v1248 = load v1175
    v1250 = load v1177
    store v1247 at v1174
    store v1248 at v1175
    store u32 3 at v1176
    store v1250 at v1177
    jmp b9()
  b9():
    v1190 = load v1174
    v1194 = array_get v1190, index u64 0
    jmp b10(u32 1)
  b10(v58: u32):
    v1195 = lt v58, u32 3
    jmpif v1195 then: b11, else: b12
  b11():
    v1210 = load v1176
    v1212 = lt v58, v1210
    jmpif v1212 then: b13, else: b14
  b13():
    v1214 = load v1174
    v1215 = load v1175
    v1216 = load v1176
    v1217 = load v1177
    v1218 = sub v58, u32 1
    range_check v1218 to 32 bits
    v1219 = load v1174
    v1223 = cast v58 as u64
    v1224 = array_get v1219, index v1223
    v1225 = cast v1218 as u64
    v1226 = array_set v1214, index v1225, value v1224
    store v1226 at v1174
    store v1215 at v1175
    store v1216 at v1176
    store v1217 at v1177
    jmp b14()
  b14():
    v1213 = add v58, u32 1
    jmp b10(v1213)
  b12():
    v1196 = load v1174
    v1197 = load v1175
    v1199 = load v1177
    v1202 = load v1176
    v1204 = sub v1202, u32 1
    range_check v1204 to 32 bits
    v1205 = cast v1204 as u64
    v1206 = array_set v1196, index v1205, value Field 0
    store v1204 at v1176
    store v1199 at v1177
    dec_rc v1206
    store v1206 at v1174
    dec_rc v1197
    store v1197 at v1175
    return v1194
}

Whereas on acir:

After Dead Instruction Elimination:
acir fn main f0 {
  b0():
    enable_side_effects u1 0
    enable_side_effects u1 0
    enable_side_effects u1 1
    enable_side_effects u1 0
    enable_side_effects u1 0
    enable_side_effects u1 1
    enable_side_effects u1 1
    enable_side_effects u1 1
    v4039 = call poseidon2_permutation([Field 0, Field 0, Field 0, Field 2⁶⁵], u32 4)
    inc_rc v4039
    v4041 = array_get v4039, index u64 0
    dec_rc v4039
    enable_side_effects u1 1
    enable_side_effects u1 1
    return v4041
}

jfecher commented 7 months ago

@sirasistant I haven't looked into that SSA in-depth but what makes you think it is underoptimized? Comparing brillig to acir generation isn't the best comparison since Acir will always have all loops unrolled and if blocks flattened. With no blocks remaining many optimizations like mem2reg become trivial. So it's generally expected that Acir will be much simpler and smaller especially if loop bounds are small or if it uses all constant values.

The only thing that jumps out to me from your example are some loops with small bounds (2 & 3). We could look into unrolling loops in brillig if they have small, known bounds.

sirasistant commented 7 months ago

It's just that the SSA seemed quite large after optimization. I only saw a reduction from ~500 instructions after inlining to ~300 after the whole optimization pipeline, with all values in the program being constant. We are going to switch from pedersen to poseidon in aztec, so let's see if the faster execution of poseidon2 compensates for the more complex SSA generated (pedersen hash is a single opcode).

sirasistant commented 7 months ago

Interesting, I tried forcing loop unrolling in brillig to see what the SSA would look like with unrolled loops, and it seems it didn't help much in reducing the instruction count:


After Dead Instruction Elimination:
brillig fn main f0 {
  b0():
    v2020 = allocate
    v2021 = allocate
    v2022 = allocate
    v2023 = allocate
    store [Field 0, Field 0, Field 0] at v2020
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v2021
    store u32 0 at v2022
    store u1 0 at v2023
    v2024 = allocate
    v2025 = allocate
    v2026 = allocate
    store u32 0 at v2026
    v2027 = allocate
    store u1 0 at v2027
    store [Field 0, Field 0, Field 0] at v2024
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v2025
    jmpif u1 0 then: b55, else: b56
  b55():
    store [Field 0, Field 0, Field 0] at v2024
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v2025
    jmpif u1 1 then: b120, else: b121
  b120():
    store [Field 0, Field 0, Field 0] at v2024
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v2025
    store u32 0 at v2026
    store u1 0 at v2027
    jmp b121()
  b121():
    jmpif u1 1 then: b125, else: b126
  b125():
    v2219 = load v2024
    v2220 = load v2025
    v2221 = array_set v2219, index u64 1, value Field 0
    store v2221 at v2024
    store v2220 at v2025
    store u32 0 at v2026
    store u1 0 at v2027
    jmp b126()
  b126():
    jmpif u1 1 then: b130, else: b131
  b130():
    v2216 = load v2024
    v2217 = load v2025
    v2218 = array_set v2216, index u64 2, value Field 0
    store v2218 at v2024
    store v2217 at v2025
    store u32 0 at v2026
    store u1 0 at v2027
    jmp b131()
  b131():
    v2192 = load v2024
    v2193 = load v2025
    v2195 = load v2025
    v2196 = array_get v2195, index u64 0
    v2197 = load v2024
    v2199 = array_get v2197, index u64 0
    v2200 = add v2196, v2199
    v2201 = array_set v2193, index u64 0, value v2200
    v2202 = array_get v2201, index u64 1
    v2203 = array_get v2192, index u64 1
    v2204 = add v2202, v2203
    v2205 = array_set v2201, index u64 1, value v2204
    v2206 = array_get v2205, index u64 2
    v2207 = array_get v2192, index u64 2
    v2208 = add v2206, v2207
    v2209 = array_set v2205, index u64 2, value v2208
    v2210 = call poseidon2_permutation(v2209, u32 4)
    inc_rc v2210
    v2211 = allocate
    v2212 = array_get v2210, index u64 0
    v2213 = array_get v2210, index u64 1
    v2214 = array_get v2210, index u64 2
    store [v2212, v2213, v2214] at v2211
    dec_rc v2192
    dec_rc v2210
    v2215 = array_set v2192, index u64 0, value Field 0
    store v2215 at v2024
    store v2210 at v2025
    store u32 1 at v2026
    store u1 0 at v2027
    jmp b62()
  b62():
    v2028 = load v2024
    dec_rc v2028
    v2029 = load v2025
    dec_rc v2029
    inc_rc v2028
    store v2028 at v2024
    inc_rc v2029
    store v2029 at v2025
    v2031 = load v2026
    v2032 = eq v2031, u32 3
    jmpif v2032 then: b79, else: b80
  b79():
    inc_rc v2028
    store v2028 at v2024
    inc_rc v2029
    store v2029 at v2025
    v2145 = load v2026
    v2146 = lt u32 0, v2145
    v2147 = not v2146
    jmpif v2147 then: b103, else: b104
  b103():
    v2186 = load v2026
    v2187 = array_set v2028, index u64 0, value Field 0
    store v2187 at v2024
    store v2029 at v2025
    store v2186 at v2026
    store u1 0 at v2027
    jmp b104()
  b104():
    v2149 = load v2026
    v2150 = lt u32 1, v2149
    v2151 = not v2150
    jmpif v2151 then: b108, else: b109
  b108():
    v2183 = load v2024
    v2184 = load v2026
    v2185 = array_set v2183, index u64 1, value Field 0
    store v2185 at v2024
    store v2029 at v2025
    store v2184 at v2026
    store u1 0 at v2027
    jmp b109()
  b109():
    v2153 = load v2026
    v2154 = lt u32 2, v2153
    v2155 = not v2154
    jmpif v2155 then: b113, else: b114
  b113():
    v2180 = load v2024
    v2181 = load v2026
    v2182 = array_set v2180, index u64 2, value Field 0
    store v2182 at v2024
    store v2029 at v2025
    store v2181 at v2026
    store u1 0 at v2027
    jmp b114()
  b114():
    v2156 = load v2024
    v2160 = array_get v2029, index u64 0
    v2161 = load v2024
    v2163 = array_get v2161, index u64 0
    v2164 = add v2160, v2163
    v2165 = array_set v2029, index u64 0, value v2164
    v2166 = array_get v2165, index u64 1
    v2167 = array_get v2156, index u64 1
    v2168 = add v2166, v2167
    v2169 = array_set v2165, index u64 1, value v2168
    v2170 = array_get v2169, index u64 2
    v2171 = array_get v2156, index u64 2
    v2172 = add v2170, v2171
    v2173 = array_set v2169, index u64 2, value v2172
    v2174 = call poseidon2_permutation(v2173, u32 4)
    inc_rc v2174
    v2175 = allocate
    v2176 = array_get v2174, index u64 0
    v2177 = array_get v2174, index u64 1
    v2178 = array_get v2174, index u64 2
    store [v2176, v2177, v2178] at v2175
    dec_rc v2156
    dec_rc v2174
    v2179 = array_set v2156, index u64 0, value Field 0
    store v2179 at v2024
    store v2174 at v2025
    store u32 1 at v2026
    store u1 0 at v2027
    jmp b86()
  b86():
    v2038 = load v2024
    dec_rc v2038
    v2039 = load v2025
    dec_rc v2039
    inc_rc v2038
    store v2038 at v2024
    inc_rc v2039
    store v2039 at v2025
    jmpif u1 0 then: b6, else: b7
  b6():
    store v2038 at v2024
    store v2039 at v2025
    store u32 0 at v2026
    store u1 0 at v2027
    jmp b7()
  b7():
    jmpif u1 1 then: b8, else: b9
  b8():
    inc_rc v2038
    store v2038 at v2024
    inc_rc v2039
    store v2039 at v2025
    v2091 = load v2026
    v2092 = lt u32 0, v2091
    v2093 = not v2092
    jmpif v2093 then: b137, else: b138
  b137():
    v2134 = load v2026
    v2135 = array_set v2038, index u64 0, value Field 0
    store v2135 at v2024
    store v2039 at v2025
    store v2134 at v2026
    store u1 0 at v2027
    jmp b138()
  b138():
    v2095 = load v2026
    v2096 = lt u32 1, v2095
    v2097 = not v2096
    jmpif v2097 then: b142, else: b143
  b142():
    v2131 = load v2024
    v2132 = load v2026
    v2133 = array_set v2131, index u64 1, value Field 0
    store v2133 at v2024
    store v2039 at v2025
    store v2132 at v2026
    store u1 0 at v2027
    jmp b143()
  b143():
    v2099 = load v2026
    v2100 = lt u32 2, v2099
    v2101 = not v2100
    jmpif v2101 then: b147, else: b148
  b147():
    v2128 = load v2024
    v2129 = load v2026
    v2130 = array_set v2128, index u64 2, value Field 0
    store v2130 at v2024
    store v2039 at v2025
    store v2129 at v2026
    store u1 0 at v2027
    jmp b148()
  b148():
    v2102 = load v2024
    v2106 = array_get v2039, index u64 0
    v2107 = load v2024
    v2109 = array_get v2107, index u64 0
    v2110 = add v2106, v2109
    v2111 = array_set v2039, index u64 0, value v2110
    v2112 = array_get v2111, index u64 1
    v2113 = array_get v2102, index u64 1
    v2114 = add v2112, v2113
    v2115 = array_set v2111, index u64 1, value v2114
    v2116 = array_get v2115, index u64 2
    v2117 = array_get v2102, index u64 2
    v2118 = add v2116, v2117
    v2119 = array_set v2115, index u64 2, value v2118
    v2120 = call poseidon2_permutation(v2119, u32 4)
    inc_rc v2120
    v2121 = allocate
    v2122 = array_get v2120, index u64 0
    v2123 = array_get v2120, index u64 1
    v2124 = array_get v2120, index u64 2
    store [v2122, v2123, v2124] at v2121
    dec_rc v2102
    dec_rc v2120
    v2125 = array_set v2102, index u64 0, value v2122
    v2126 = array_set v2125, index u64 1, value v2123
    v2127 = array_set v2126, index u64 2, value v2124
    store v2127 at v2024
    store v2120 at v2025
    store u32 3 at v2026
    store u1 1 at v2027
    jmp b9()
  b9():
    v2044 = load v2024
    v2048 = array_get v2044, index u64 0
    v2051 = load v2026
    v2053 = lt u32 1, v2051
    jmpif v2053 then: b154, else: b155
  b154():
    v2081 = load v2024
    v2082 = load v2025
    v2083 = load v2026
    v2084 = load v2027
    v2085 = load v2024
    v2089 = array_get v2085, index u64 1
    v2090 = array_set v2081, index u64 0, value v2089
    store v2090 at v2024
    store v2082 at v2025
    store v2083 at v2026
    store v2084 at v2027
    jmp b155()
  b155():
    v2056 = load v2026
    v2058 = lt u32 2, v2056
    jmpif v2058 then: b159, else: b160
  b159():
    v2071 = load v2024
    v2072 = load v2025
    v2073 = load v2026
    v2074 = load v2027
    v2075 = load v2024
    v2079 = array_get v2075, index u64 2
    v2080 = array_set v2071, index u64 1, value v2079
    store v2080 at v2024
    store v2072 at v2025
    store v2073 at v2026
    store v2074 at v2027
    jmp b160()
  b160():
    v2059 = load v2024
    v2060 = load v2025
    v2062 = load v2027
    v2065 = load v2026
    v2067 = sub v2065, u32 1
    range_check v2067 to 32 bits
    v2068 = cast v2067 as u64
    v2069 = array_set v2059, index v2068, value Field 0
    store v2067 at v2026
    store v2062 at v2027
    dec_rc v2069
    store v2069 at v2024
    dec_rc v2060
    store v2060 at v2025
    return v2048
  b80():
    v2034 = load v2026
    v2035 = eq v2034, u32 3
    v2036 = not v2035
    jmpif v2036 then: b81, else: b82
  b81():
    v2139 = load v2026
    v2140 = load v2026
    v2141 = cast v2140 as u64
    v2142 = array_set v2028, index v2141, value Field 0
    v2144 = add v2139, u32 1
    range_check v2144 to 32 bits
    store v2142 at v2024
    store v2029 at v2025
    store v2144 at v2026
    store u1 0 at v2027
    jmp b85()
  b85():
    jmp b86()
  b82():
    jmpif u1 0 then: b83, else: b84
  b83():
    v2138 = array_set v2028, index u64 0, value Field 0
    store v2138 at v2024
    store v2029 at v2025
    store u32 1 at v2026
    store u1 0 at v2027
    jmp b84()
  b84():
    jmp b85()
  b56():
    jmpif u1 1 then: b57, else: b58
  b57():
    store [Field 0, Field 0, Field 0] at v2024
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v2025
    store u32 1 at v2026
    store u1 0 at v2027
    jmp b61()
  b61():
    jmp b62()
  b58():
    jmpif u1 0 then: b59, else: b60
  b59():
    store [Field 0, Field 0, Field 0] at v2024
    store [Field 0, Field 0, Field 0, Field 2⁶⁵] at v2025
    store u32 1 at v2026
    store u1 0 at v2027
    jmp b60()
  b60():
    jmp b61()
}

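It seems that, to take advantage of unrolled loops in brillig, we could be missing an optimization step for non-flattened control flow that would:

  • switch jmpif with a constant value to jump
  • prune blocks that are never targeted
  • merge blocks a & b where b's only predecessor is a and a's only successor is b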

jfecher commented 7 months ago

It seems that, to take advantage of unrolled loops in brillig, we could be missing an optimization step for non-flattened control flow that would:

  • switch jmpif with a constant value to jump
  • prune blocks that are never targeted
  • merge blocks a & b where b's only predecessor is a and a's only successor is b

The last point should be done already by the simplify_cfg pass which does currently apply to brillig code as well.

  • prune blocks that are never targeted

The second point should be handled automatically by the function printer, since blocks are never actually removed in general. So any block that is still printed must still be reachable.

  • switch jmpif with a constant value to jump

This is one area where brillig opts can improve I think. We used to do this automatically in more passes but it was removed for bugs (IIRC) and we just rely on a check during flattening currently. We can try adding this to another pass (simplify_cfg? A different pass?)
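As a rough illustration of the first two bullets, here is a minimal Rust sketch on a toy CFG. The `Terminator` enum and the `fold_constant_jmpifs` and `reachable_blocks` helpers are hypothetical names made up for this example, not the compiler's actual types or passes. It only shows how folding a constant `jmpif` into a `jmp` leaves blocks such as b58, b59 and b60 in the dump above unreachable.

```rust
use std::collections::{BTreeMap, BTreeSet};

/// Toy terminator for a basic block. Hypothetical: this is not the
/// compiler's real terminator type, just enough to show the idea.
#[derive(Debug, Clone, PartialEq)]
enum Terminator {
    Jmp(usize),
    /// `cond` is `Some(_)` when the condition is a compile-time constant.
    JmpIf { cond: Option<bool>, then_block: usize, else_block: usize },
    Return,
}

/// Rewrite every `jmpif` whose condition is a known constant into a `jmp`.
fn fold_constant_jmpifs(cfg: &mut BTreeMap<usize, Terminator>) {
    for term in cfg.values_mut() {
        if let Terminator::JmpIf { cond: Some(c), then_block, else_block } = *term {
            *term = Terminator::Jmp(if c { then_block } else { else_block });
        }
    }
}

/// Collect the blocks reachable from `entry` by following terminators.
/// Anything not in the returned set is a candidate for pruning.
fn reachable_blocks(cfg: &BTreeMap<usize, Terminator>, entry: usize) -> BTreeSet<usize> {
    let mut seen = BTreeSet::new();
    let mut stack = vec![entry];
    while let Some(block) = stack.pop() {
        if !seen.insert(block) {
            continue; // already visited
        }
        match cfg.get(&block) {
            Some(Terminator::Jmp(target)) => stack.push(*target),
            Some(Terminator::JmpIf { then_block, else_block, .. }) => {
                stack.push(*then_block);
                stack.push(*else_block);
            }
            Some(Terminator::Return) | None => {}
        }
    }
    seen
}
```

Modelling `jmpif u1 1` as `cond: Some(true)` and `jmpif u1 0` as `cond: Some(false)`, running the fold first and the reachability walk second is what makes the dead branches disappear; without the fold, every block is still reachable through some `jmpif` edge.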

TomAFrench commented 6 months ago

struct EnumEmulation {
    a: Option<Field>,
    b: Option<Field>,
    c: Option<Field>,
}

unconstrained fn main() -> pub Field {
    let mut emulated_enum = EnumEmulation { a: Option::some(1), b: Option::none(), c: Option::none() };

    for _ in 0..1 {
        assert_eq(emulated_enum.a.unwrap(), 1);
    }

    emulated_enum.a = Option::some(2);
    emulated_enum.a.unwrap()
}

On current master, this program now compiles down to

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    v50 = allocate
    store u1 1 at v50
    v51 = allocate
    store Field 1 at v51
    v52 = allocate
    store u1 0 at v52
    v53 = allocate
    store Field 0 at v53
    v54 = allocate
    store u1 0 at v54
    v55 = allocate
    store Field 0 at v55
    jmp b1(u64 0)
  b1(v10: u64):
    v56 = eq v10, u64 0
    jmpif v56 then: b2, else: b3
  b2():
    v63 = load v50
    v64 = load v51
    constrain v63 == u1 1
    constrain v64 == Field 1
    v70 = add v10, u64 1
    jmp b1(v70)
  b3():
    v59 = load v52
    v60 = load v53
    v61 = load v54
    v62 = load v55
    store u1 1 at v50
    store Field 2 at v51
    store v59 at v52
    store v60 at v53
    store v61 at v54
    store v62 at v55
    return Field 2
}

TomAFrench commented 6 months ago

Ah, misread and thought this was improved (but still suboptimal); looks like it's unchanged :sweat_smile:

sirasistant commented 6 months ago

Yeah, I think for this to improve, mem2reg needs to make more than one pass over the code to gather information about loops ):

TomAFrench commented 5 months ago

An extreme example is regression_4709 which resolves acc to a constant whereas in brillig we need to calculate this at runtime.

https://github.com/noir-lang/noir/pull/5128

sirasistant commented 4 months ago

Another extreme example: 15% of the runtime in this test is spent evaluating std::compat::is_bn254(), which should be known at compile time (but contains a loop).

TomAFrench commented 4 months ago

We're currently not inlining any functions in brillig no matter how small afaik. This along with the fact that we don't perform any optimizations during brillig-gen means that these checks make their way into runtime code whereas they should be optimized out. As well as std::compat::is_bn254() we're going to be preserving all of the bytecode for the other branch which we don't need (e.g. in lt for fields).

@sirasistant It might be worth looking into a pass which automatically inlines brillig functions which contain a number of instructions below a certain limit (maybe 5-10 or so?) to get some of these benefits of constant folding without increasing the bytecode size. There's also some overhead associated with performing a brillig function call, so that could be used to make sure inlining always reduces the bytecode size.

sirasistant commented 4 months ago

looking into a pass which automatically inlines brillig functions

Do you mean unroll? We are always inlining in brillig currently, unless the function is part of a recursive loop.

sirasistant commented 4 months ago

For this function:

unconstrained fn main() {
    assert(std::compat::is_bn254());
}

This is the final SSA after all optimizations:

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    v57 = allocate
    store u1 1 at v57
    jmp b1(u32 0)
  b1(v2: u32):
    v58 = lt v2, u32 2⁵
    jmpif v58 then: b2, else: b3
  b2():
    v60 = load v57
    v61 = lt v2, u32 2⁵
    constrain v61 == u1 1 '"Index out of bounds"'
    v62 = array_get [u8 2⁴×3, u8 100, u8 78, u8 114, u8 225, u8 49, u8 2⁴×10, u8 41, u8 184, u8 2⁴×5, u8 69, u8 182, u8 129, u8 129, u8 88, u8 93, u8 40, u8 51, u8 232, u8 72, u8 121, u8 185, u8 2⁴×7, u8 145, u8 67, u8 225, u8 245, u8 147, u8 2⁴×15, u8 0, u8 0, u8 1], index v2
    v63 = array_get [u8 2⁴×3, u8 100, u8 78, u8 114, u8 225, u8 49, u8 2⁴×10, u8 41, u8 184, u8 2⁴×5, u8 69, u8 182, u8 129, u8 129, u8 88, u8 93, u8 40, u8 51, u8 232, u8 72, u8 121, u8 185, u8 2⁴×7, u8 145, u8 67, u8 225, u8 245, u8 147, u8 2⁴×15, u8 0, u8 0, u8 1], index v2
    v64 = eq v62, v63
    v65 = mul v60, v64
    store v65 at v57
    v66 = add v2, u32 1
    jmp b1(v66)
  b3():
    v59 = load v57
    constrain v59 == u1 1
    return 
}

TomAFrench commented 4 months ago

Ah, my bad. You're right, I misread the inlining pass when checking this.

That said, we should be able to unroll loops in any brillig function without arguments or foreign calls safely.

jfecher commented 4 months ago

That said, we should be able to unroll loops in any brillig function without arguments or foreign calls safely.

We definitely can try. I think the main concern before was just increasing code size. Unsure what a reasonable limit is but I'm sure we can think of some arbitrary measure.

TomAFrench commented 4 months ago

This shouldn't increase code size: in the case where there are no runtime arguments, foreign call results, etc., SSA optimizations should be able to simplify the function down into a constant return value.

TomAFrench commented 4 months ago

To benefit from this, however, I think we'll need to run the functions that can be simplified down to a constant value through the entire SSA pipeline first, before we can inline them into other functions.

jfecher commented 2 months ago

Was looking at this a bit today in standup and talking with @vezenovm and we noticed the brillig can be improved a bit if we manually make a read only copy before the loop:

struct EnumEmulation {
    a: Option<Field>,
    b: Option<Field>,
    c: Option<Field>,
}

unconstrained fn main() -> pub Field {
    let mut emulated_enum = EnumEmulation { a: Option::some(1), b: Option::none(), c: Option::none() };

    let my_copy = emulated_enum;
    for _ in 0..1 {
        assert_eq(my_copy.a.unwrap(), 1);
    }

    emulated_enum.a = Option::some(2);
    emulated_enum.a.unwrap()
}

With this the new SSA would be:

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    v36 = allocate
    store u1 1 at v36
    v37 = allocate
    store Field 1 at v37
    v38 = allocate
    store u1 0 at v38
    v39 = allocate
    store Field 0 at v39
    v40 = allocate
    store u1 0 at v40
    v41 = allocate
    store Field 0 at v41
    jmp b1(u32 0)
  b1(v10: u32):
    v42 = eq v10, u32 0
    jmpif v42 then: b2, else: b3
  b2():
    v49 = add v10, u32 1
    jmp b1(v49)
  b3():
    v45 = load v38
    v46 = load v39
    v47 = load v40
    v48 = load v41
    store u1 1 at v36
    store Field 2 at v37
    store v45 at v38
    store v46 at v39
    store v47 at v40
    store v48 at v41
    return Field 2
}

Which is 4 fewer instructions within the loop compared to Tom's example. I could see this being important especially for larger loops. It should also be a fairly easy optimization for users to apply even if it's difficult for the compiler to identify this automatically.

Another thing we noticed is these instructions at the end:

v45 = load v38
...
store v45 at v38

Here we're loading a value only to immediately store it again. In mem2reg, when we load from a reference whose value is unknown, that value stays unknown afterward, but this doesn't need to be the case! After the load we know that the reference holds v45, so its value is now known. This would help us remove any subsequent loads from v38 (although there are none here).

Another check we could add, though, is for store v45 at v38, i.e. storing a value in a reference when we know that the reference already holds that value. Since we now know that v38 holds v45 internally, we can remove this store entirely. Then, we can remove v45 = load v38 via dead instruction elimination, along with potentially v38 = allocate and store u1 0 at v38 as well. The only hiccup to removing all of them is that the last load would be removed during DIE, but the last store can only be removed during mem2reg, which happens beforehand. That store staying would prevent the allocate from being removed as well.
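To make the two mem2reg ideas above concrete, here is a minimal Rust sketch over a single straight-line block. The `Instr` enum and `simplify_block` are hypothetical names for a toy IR, not the real mem2reg pass, and the sketch assumes that no two addresses alias and that values are plain numeric ids. It records the result of a load as the known value of its reference, drops later loads of that reference, and drops stores that write back the value the reference already holds.

```rust
use std::collections::HashMap;

/// Toy instructions over numeric value ids. Hypothetical, for illustration only.
#[derive(Debug, Clone, PartialEq)]
enum Instr {
    /// `result = load address`
    Load { result: u32, address: u32 },
    /// `store value at address`
    Store { value: u32, address: u32 },
}

/// Returns the instructions that are kept, plus a map from each removed
/// load result to the value that replaces it.
/// Assumption: addresses never alias each other.
fn simplify_block(block: &[Instr]) -> (Vec<Instr>, HashMap<u32, u32>) {
    let mut known: HashMap<u32, u32> = HashMap::new(); // address -> value it holds
    let mut replaced: HashMap<u32, u32> = HashMap::new(); // removed load -> value
    let mut kept = Vec::new();

    for instr in block {
        match *instr {
            Instr::Load { result, address } => match known.get(&address).copied() {
                // The reference's value is already known: reuse it.
                Some(value) => {
                    replaced.insert(result, value);
                }
                // Unknown before the load, but known afterwards.
                None => {
                    known.insert(address, result);
                    kept.push(Instr::Load { result, address });
                }
            },
            Instr::Store { value, address } => {
                // Resolve the stored value through any removed loads.
                let value = replaced.get(&value).copied().unwrap_or(value);
                if known.get(&address).copied() == Some(value) {
                    continue; // storing what is already there: redundant
                }
                known.insert(address, value);
                kept.push(Instr::Store { value, address });
            }
        }
    }
    (kept, replaced)
}
```

On `v45 = load v38` followed by `store v45 at v38`, this keeps the load and drops the store. The load is then unused, which is exactly the phase-ordering hiccup described above: removing it is left to a later dead instruction elimination pass.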

With all these optimizations the (theoretical) SSA would be:

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    v36 = allocate
    store u1 1 at v36
    v37 = allocate
    store Field 1 at v37
    jmp b1(u32 0)
  b1(v10: u32):
    v42 = eq v10, u32 0
    jmpif v42 then: b2, else: b3
  b2():
    v49 = add v10, u32 1
    jmp b1(v49)
  b3():
    store u1 1 at v36
    store Field 2 at v37
    return Field 2
}

If we can identify that v36 and v37 are local to the current function, never loaded from, and not returned, then we can remove them as well:
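A minimal Rust sketch of that last check, again on a toy IR with hypothetical names (`Instr`, `remove_write_only_allocations`): an allocation is treated as removable only if its reference is never loaded from, never stored as a value into another reference, and never returned. Real code would also have to account for references passed to calls or through block parameters, which this sketch ignores.

```rust
use std::collections::HashSet;

/// Toy instructions over numeric value ids. Hypothetical, for illustration only.
#[derive(Debug, Clone, PartialEq)]
enum Instr {
    /// `result = allocate`
    Allocate { result: u32 },
    /// `result = load address`
    Load { result: u32, address: u32 },
    /// `store value at address`
    Store { value: u32, address: u32 },
    /// `return value`
    Return { value: u32 },
}

/// Drop allocations that are only ever written to, together with their stores.
fn remove_write_only_allocations(instrs: &[Instr]) -> Vec<Instr> {
    // Start with every locally allocated reference as a candidate.
    let mut candidates: HashSet<u32> = instrs
        .iter()
        .filter_map(|instr| match instr {
            Instr::Allocate { result } => Some(*result),
            _ => None,
        })
        .collect();

    // Any read or escape of the reference disqualifies it.
    for instr in instrs {
        match instr {
            Instr::Load { address, .. } => {
                candidates.remove(address);
            }
            Instr::Store { value, .. } => {
                candidates.remove(value);
            }
            Instr::Return { value } => {
                candidates.remove(value);
            }
            Instr::Allocate { .. } => {}
        }
    }

    // Keep everything except the write-only allocations and stores to them.
    instrs
        .iter()
        .filter(|instr| match instr {
            Instr::Allocate { result } => !candidates.contains(result),
            Instr::Store { address, .. } => !candidates.contains(address),
            _ => true,
        })
        .cloned()
        .collect()
}
```

Giving the constants their own ids, the v36/v37 allocations and all four stores from the theoretical SSA above are dropped and only the return remains, while an allocation that is loaded from anywhere is left untouched.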

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    jmp b1(u32 0)
  b1(v10: u32):
    v42 = eq v10, u32 0
    jmpif v42 then: b2, else: b3
  b2():
    v49 = add v10, u32 1
    jmp b1(v49)
  b3():
    return Field 2
}

sirasistant commented 2 months ago

That sounds great! If we reach this point:

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    v36 = allocate
    store u1 1 at v36
    v37 = allocate
    store Field 1 at v37
    jmp b1(u32 0)
  b1(v10: u32):
    v42 = eq v10, u32 0
    jmpif v42 then: b2, else: b3
  b2():
    v49 = add v10, u32 1
    jmp b1(v49)
  b3():
    store u1 1 at v36
    store Field 2 at v37
    return Field 2
}

we can probably update the unroller to delete empty loops like this one

vezenovm commented 2 months ago

I get to this point in #5865 if we do a copy outside the loop and read the copy in the loop:

After Array Set Optimizations:
brillig fn main f0 {
  b0():
    jmp b1(u32 0)
  b1(v10: u32):
    v42 = eq v10, u32 0
    jmpif v42 then: b2, else: b3
  b2():
    v49 = add v10, u32 1
    jmp b1(v49)
  b3():
    return Field 2
}

Just catching an edge case for arrays w/ inc_rc, but the optimization looks promising.