o1-labs / o1js

TypeScript framework for zk-SNARKs and zkApps
https://docs.minaprotocol.com/en/zkapps/how-to-write-a-zkapp
Apache License 2.0

`Provable.Array(UInt32, 64)` fails with OOM: `RuntimeError: unreachable` #1391

Open · emlautarom1 opened this issue 5 months ago

emlautarom1 commented 5 months ago

We're trying to build a `UInt2048` library (minimal enough for RSA), and we're running into the following error:

    RuntimeError: unreachable

      at __rg_oom (wasm:/wasm/0130b18a:1:4434388)
      at __rust_alloc_error_handler (wasm:/wasm/0130b18a:1:4437226)
      at alloc::alloc::handle_alloc_error::rt_error::hf68089e509c8318c (wasm:/wasm/0130b18a:1:4438023)
      at alloc::alloc::handle_alloc_error::hd7440911893bc97b (wasm:/wasm/0130b18a:1:4438012)
      at alloc::raw_vec::RawVec<T,A>::reserve::do_reserve_and_handle::h94aa91d7deef96d5 (wasm:/wasm/0130b18a:1:4306437)
      at <o1_utils::serialization::SerdeAs as serde_with::ser::SerializeAs<T>>::serialize_as::h05f8705bb1e9b20a (wasm:/wasm/0130b18a:1:3928697)
      at kimchi::circuits::constraints::_::<impl serde::ser::Serialize for kimchi::circuits::constraints::ColumnEvaluations<F>>::serialize::h37f3183275b476c9 (wasm:/wasm/0130b18a:1:3178428)
      at kimchi::prover_index::_::<impl serde::ser::Serialize for kimchi::prover_index::ProverIndex<G,OpeningProof>>::serialize::hcb9c41293c9facf0 (wasm:/wasm/0130b18a:1:3523077)
      at caml_pasta_fp_plonk_index_encode (wasm:/wasm/0130b18a:1:3690790)
      at Object.<anonymous>.module.exports.caml_pasta_fp_plonk_index_encode (o1js/dist/node/bindings/compiled/_node_bindings/plonk_wasm.cjs:1475:14)
      at encodeProverKey (o1js/src/lib/proof-system/prover-keys.ts:106:26)
      at write_ (o1js/src/lib/proof_system.ts:644:48)
      at ../../../../../workspace_root/src/lib/snarkyjs/src/bindings/ocaml/lib/pickles_bindings.ml:404:40
      at filter_map$1 (ocaml/base/list.ml:812:14)
      at write$0 (src/lib/key_cache/key_cache.ml:150:5)
      at pk (src/lib/pickles/cache.ml:113:17)
      at force_lazy_block (ocaml/ocaml/camlinternalLazy.ml:31:18)
      at vk (ocaml/ocaml/camlinternalLazy.ml:27:5)
      at force_lazy_block (ocaml/ocaml/camlinternalLazy.ml:31:18)
      at _izk_ (ocaml/ocaml/camlinternalLazy.ml:27:5)
      at map$34 (src/lib/pickles_types/vector.ml:140:49)
      at map$34 (src/lib/pickles_types/vector.ml:140:56)
      at step_vks (src/lib/pickles/compile.ml:633:21)
      at force_lazy_block (ocaml/ocaml/camlinternalLazy.ml:31:18)
      at ../../../../gregor/.opam/4.14.0/lib/ocaml/camlinternalLazy.ml:27:5
      at with_label (src/lib/snarky/src/base/snark0.ml:1253:15)
      at with_label (src/lib/snarky/src/base/snark0.ml:1253:15)
      at main (src/lib/pickles/compile.ml:660:40)
      at ../../../../../workspace_root/src/lib/pickles/compile.ml:343:27
      at as_stateful (src/lib/snarky/src/base/snark0.ml:755:15)
      at g (src/lib/snarky/src/base/snark0.ml:753:9)
      at constraint_count (src/lib/snarky/src/base/checked_runner.ml:320:13)
      at log (src/lib/snarky_log/snarky_log.ml:28:7)
      at compile_with_wrap_main_overrid (src/lib/pickles/compile.ml:339:9)
      at compile_promise (src/lib/pickles/pickles.ml:315:5)
      at Function.Class.<computed> [as compile] (o1js/src/bindings/js/proxy.js:20:52)
      at node_modules/o1js/src/lib/proof_system.ts:659:28
      at withThreadPool (o1js/src/bindings/js/node/node-backend.js:55:20)
      at prettifyStacktracePromise (o1js/src/lib/errors.ts:137:12)
      at compileProgram (o1js/src/lib/proof_system.ts:653:5)
      at Function.compile (o1js/src/lib/zkapp.ts:685:9)
      at src/BigInt.test.ts:102:5

I was able to trace it through the stack to third-party sources: the `kimchi` prover index serialization (`caml_pasta_fp_plonk_index_encode`) in the Rust/WASM bindings.

I don't know exactly how serialization is handled across these bridges, but it seems that at some point it tries to allocate a buffer larger than the WASM heap allows, triggering the panic.

If we use `Provable.Array(UInt32, 48)` instead (note that this does not work for our use case), then no panic is triggered.

A possible workaround (we believe) is to use two `Provable.Array(UInt32, 32)` arrays instead, but this makes indexing operations much more awkward: `array[i] = x` and `y = array[i]` become non-trivial (we could not figure out how to implement them).

emlautarom1 commented 5 months ago

The full code that is causing the issue:

```ts
import { Provable, Struct, UInt32, UInt64 } from 'o1js';

const Words64 = Provable.Array(UInt32, 64);

export class UInt2048 extends Struct({ words: Words64 }) {

  static zero() {
    return new UInt2048({ words: Array(64).fill(UInt32.zero) });
  }

  mul(other: UInt2048): UInt2048 {
    let result: UInt2048 = UInt2048.zero();

    for (let j = 0; j < 64; j++) {
      let carry = UInt64.zero;
      for (let i = 0; i + j < 64; i++) {
        // Perform the multiplication in UInt64 to ensure that the result always fits (no overflow here)
        let product: UInt64 = this.words[i].toUInt64()
          .mul(other.words[j].toUInt64())
          // Add the previous result for this word index
          .add(result.words[i + j].toUInt64())
          // Lastly, add the previous carry
          .add(carry);

        let { quotient: highBits, rest: lowBits } =
          product.divMod(UInt64.from("4294967296" /* 2^32 */));
        // Keep only the value that fits in a UInt32 (the low bits)
        result.words[i + j] = lowBits.toUInt32();
        // Extract the carry from the product by keeping the bits that could not fit in a UInt32
        // (the high bits). This carry will be used in the next iteration
        carry = highBits;
      }
    }

    return result;
  }
}
```
mitschabaude commented 5 months ago

So what does your zkProgram / contract look like? How many of those bigint multiplications are you performing?

I'd also be interested in the number of constraints. Print `contract.analyzeMethods()`.

Your code looks quite constraint-inefficient. Any reason you didn't start from the RSA example I wrote that's in an open PR?

emlautarom1 commented 5 months ago

> So what does your zkProgram / contract look like?

We are performing a single multiplication:

```ts
import { method, SmartContract } from 'o1js';
// UInt2048 is the Struct defined above

export class TestContract extends SmartContract {
  @method mul(a: UInt2048, b: UInt2048): UInt2048 {
    return a.mul(b);
  }
}
```

The test code:

xit("multiplies", async () => {
    await localDeploy();

    // Assume a valid `fromHexString` that exists only on JS-land
    let a = UInt2048.fromHexString("0xFFFFFFFFAAAAAAAA");
    let b = UInt2048.fromHexString("0xEEEEEEEEBBBBBBBB");

    let res!: UInt2048;
    let retrieve = await Mina.transaction(user_Account, () => {
      res = zkApp.mul(a, b);
    });
    await retrieve.prove();
    await retrieve.sign([user_Key]).send();

    expect(res.words[0].toBigint()).toBe(BigInt("0x2D82D82E"));
    expect(res.words[1].toBigint()).toBe(BigInt("0xCCCCCCCD"));
    expect(res.words[2].toBigint()).toBe(BigInt("0x6C16C16A"));
    expect(res.words[3].toBigint()).toBe(BigInt("0xEEEEEEEE"));
    for (let i = 4; i < res.words.length; i++) {
      const word = res.words[i];
      expect(word.toBigint()).toBe(0n);
    }
  });

> Print `contract.analyzeMethods()`

The output gets truncated by `console.log`, and the same goes for `JSON.stringify`. Any suggestions on how to share this information?

> Your code looks quite constraint-inefficient. Any reason you didn't start from the RSA example I wrote that's in an open PR?

We are exploring different options given that the code in the PR does not work (due to some issues with arithmetic operations). Here, we tried to use `UInt32` as a primitive type and implement the arithmetic operations just like you would in a common programming language like C. Still, I would appreciate any suggestions on how to improve this particular implementation.

emlautarom1 commented 5 months ago

I ended up dumping the JSON to a file and it's over 29 MB. Here is the zipped version: analysis.zip

mitschabaude commented 5 months ago

> I ended up dumping the JSON to a file and it's over 29 MB. Here is the zipped version: analysis.zip

Thanks, I only wanted to know the rows (number of constraints) :D which is `"rows": 5412`

> We are exploring different options given that the code in the PR does not work (due to some issues with arithmetic operations).

Makes sense! Happy to help get the PR to work

mitschabaude commented 5 months ago

> Still, I would appreciate any suggestions on how to improve this particular implementation.

I would in general use the builtin `Field` type for low-level operations, because it won't generate any "unexpected" constraints. You should get to a point where you know exactly what constraints your code will generate -- only then will you be able to optimize it.

For example, with `Field`, an addition or multiplication translates pretty directly to PLONK generic gates.

For `UInt32`, however, every addition or multiplication comes with additional constraints that prove that the output again fits in 32 bits (only that way can we justify returning a `UInt32`). You don't want those extra checks, because you can control precisely how many bits your outputs will have, and you should try to use as much of the available ~254 bits of a field element as you can.
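
To make the difference concrete, here is a minimal sketch that counts the rows generated by a single multiplication in each type. It uses `Provable.constraintSystem` to inspect the circuit; exact row counts depend on the o1js version, so treat the numbers as illustrative:

```ts
import { Field, Provable, UInt32 } from 'o1js';

// Rows for a single Field multiplication: essentially one generic gate.
let fieldCs = Provable.constraintSystem(() => {
  let a = Provable.witness(Field, () => Field(3));
  let b = Provable.witness(Field, () => Field(5));
  a.mul(b);
});

// Rows for a single UInt32 multiplication: the arithmetic itself plus the
// range checks proving that the result fits in 32 bits again.
let uint32Cs = Provable.constraintSystem(() => {
  let a = Provable.witness(UInt32, () => UInt32.from(3));
  let b = Provable.witness(UInt32, () => UInt32.from(5));
  a.mul(b);
});

console.log('Field mul rows:', fieldCs.rows);
console.log('UInt32 mul rows:', uint32Cs.rows);
```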

emlautarom1 commented 5 months ago

So I went ahead and replaced all usages of `UIntX` with `Field`, but in practice I'm treating each limb as if it were a `UInt64`, so multiplying two `Field`s in this context cannot overflow (the product of two 64-bit values is < 2^128, well below the ~254-bit field modulus). I was also able to reduce the array to 32 `Field`s, which should help with memory usage (32 × `UInt64` = `UInt2048`).

Now I'm facing an issue with this line, given that `Field` does not implement `divMod`:

```ts
let { quotient: highBits, rest: lowBits } = product.divMod(UInt64.from("4294967296" /* 2^32 */));
```

I wanted to change the code to the following, which should be equivalent:

```ts
let highBits = product.div(Field.from(18446744073709551616n /* 2^64 */));
let lowBits = product.sub(highBits.mul(Field.from(18446744073709551616n /* 2^64 */)));
```

This does not produce the same result; specifically, `highBits` is different (printing it as a `BigInt` shows extremely large values). Could this be due to the modular nature of the underlying `Field`, and if so, how can we avoid this issue?

Another option is to use `Field.toBits` as follows:

```ts
let bits = product.toBits(128);
let lowBits = Field.fromBits(bits.slice(0, 64));
let highBits = Field.fromBits(bits.slice(64, 128));
```

This works in JS-land, but inside a ZK program it takes more than 10 minutes to perform a single computation. The number of rows for this approach is 99592, compared to 254 for the "divMod equivalent" approach, which makes me think it's very inefficient.

mitschabaude commented 5 months ago

@emlautarom1 fundamentally, to prove a split into low and high bits, you do the following:

  1. witness the low and high parts
  2. prove that the low and high parts are at most 64 bits
  3. prove that low + high * 2^64 = original

Number 3 should be easy. Number 2 is a 64-bit range check; you can use `Gadgets.rangeCheck64()`. For number 1, use `Provable.witness(Field, () => {...})` to introduce witnesses that depend on your earlier computation (this is the equivalent of `<--` in circom).

This is probably pretty much how `UInt64.divMod` is already implemented. We also have `Gadgets.divMod32()`, which does the same for 32 bits. I suggest you think about the problem, study those implementations, and then create your own version tailored to your use case.
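
For illustration, here is a minimal sketch of those three steps for a 64-bit split (`splitHighLow64` is a hypothetical helper, not o1js's actual `divMod` implementation, and it assumes the input is known to fit in 128 bits):

```ts
import { Field, Gadgets, Provable } from 'o1js';

// Hypothetical helper: split `value` (assumed < 2^128) into 64-bit high/low limbs.
function splitHighLow64(value: Field): { high: Field; low: Field } {
  // 1. witness the low and high parts (computed out-of-circuit)
  let low = Provable.witness(Field, () =>
    Field(value.toBigInt() & ((1n << 64n) - 1n))
  );
  let high = Provable.witness(Field, () => Field(value.toBigInt() >> 64n));

  // 2. prove that both parts are at most 64 bits
  Gadgets.rangeCheck64(low);
  Gadgets.rangeCheck64(high);

  // 3. prove that low + high * 2^64 equals the original value
  low.add(high.mul(Field(1n << 64n))).assertEquals(value);

  return { high, low };
}
```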

> I was also able to reduce the array to 32 `Field`s, which should help with memory usage (32 × `UInt64` = `UInt2048`).

Just a minor point, "should help with memory usage" isn't the correct way to think about this 😅 You're writing a circuit, not a normal program -- reducing the number of limbs is good, because it reduces constraints, not because it reduces memory access :)

mitschabaude commented 5 months ago

PS: `toBits()` and `fromBits()` use a constraint for each individual bit to prove that it's a boolean; that's why they are so costly. Luckily, we have range checks for larger bit sizes that use lookup tables and are much more efficient.
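
A rough way to see the gap, using the same `Provable.constraintSystem` trick as before (again, exact row counts depend on the o1js version):

```ts
import { Field, Gadgets, Provable } from 'o1js';

// Bit decomposition: roughly one boolean constraint per bit.
let bitwise = Provable.constraintSystem(() => {
  let x = Provable.witness(Field, () => Field(42));
  x.toBits(64);
});

// Lookup-based range check: a handful of rows for the same 64-bit bound.
let lookup = Provable.constraintSystem(() => {
  let x = Provable.witness(Field, () => Field(42));
  Gadgets.rangeCheck64(x);
});

console.log('toBits(64) rows:', bitwise.rows);
console.log('rangeCheck64 rows:', lookup.rows);
```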

mitschabaude commented 5 months ago

> Could this be due to the modular nature of the underlying `Field`, and if so, how can we avoid this issue?

And yeah, as you guessed, `Field.div()` computes a modular inverse, so it's not useful for you.
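
A quick way to see this outside the circuit (the exact value depends on the Pasta field modulus p that o1js's `Field` uses):

```ts
import { Field } from 'o1js';

// Field.div multiplies by the modular inverse, so 1 / 2 is the field
// element (p + 1) / 2 -- an astronomically large number, not 0.
let half = Field(1).div(Field(2));
console.log(half.toBigInt()); // ~2^253, nowhere near an integer quotient of 0
```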

emlautarom1 commented 5 months ago

> Just a minor point, "should help with memory usage" isn't the correct way to think about this 😅 You're writing a circuit, not a normal program -- reducing the number of limbs is good, because it reduces constraints, not because it reduces memory access :)

Thanks for the clarification, but I was referring to the allocation error triggered during serialization on the Rust side: I expected that reducing the number of elements in the array would prevent it from happening.