anoma / vamp-ir

Vamp-IR is a proof-system-agnostic language for writing arithmetic circuits
https://anoma.github.io/VampIR-Book/
Apache License 2.0

VampIR hangs #79

Open lukaszcz opened 1 year ago

lukaszcz commented 1 year ago

I managed to make VampIR hang at the type inference phase with the following program (VampIR version 0.1.1):

def isZero x = {
  def xi = fresh (1 | x);
  x * (1 - xi * x) = 0;
  1 - xi * x
};

def if b (x1, x2) (y1, y2) = (b * x1 + (1 - b) * y1, b * x2 + (1 - b) * y2);

def fib_0 cont x = (1, 0);

def fib_1 cont x = {
  if (isZero x)
    (cont 0)
    (if (isZero (x - 1))
      (cont 1)
      (fib_0 (fun a { fib_0 (fun b {cont (a + b)}) (x - 2) }) (x - 1)))
};

def fib_2 cont x = {
  if (isZero x)
    (cont 0)
    (if (isZero (x - 1))
      (cont 1)
      (fib_1 (fun a { fib_1 (fun b {cont (a + b)}) (x - 2) }) (x - 1)))
};

def fib_3 cont x = {
  if (isZero x)
    (cont 0)
    (if (isZero (x - 1))
      (cont 1)
      (fib_2 (fun a { fib_2 (fun b {cont (a + b)}) (x - 2) }) (x - 1)))
};

def fib_4 cont x = {
  if (isZero x)
    (cont 0)
    (if (isZero (x - 1))
      (cont 1)
      (fib_3 (fun a { fib_3 (fun b {cont (a + b)}) (x - 2) }) (x - 1)))
};

def fib_5 cont x = {
  if (isZero x)
    (cont 0)
    (if (isZero (x - 1))
      (cont 1)
      (fib_4 (fun a { fib_4 (fun b {cont (a + b)}) (x - 2) }) (x - 1)))
};

def fib = fib_5 (fun x {(0, x)});

fib x = (0, y);
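The `if` in this program is not a branch. It is arithmetic selection: for a boolean b, b * x + (1 - b) * y equals x when b = 1 and y when b = 0. Both arguments are built as circuit values before the selection happens, so every call in both branches ends up in the circuit. The Python sketch below mirrors the pair-valued `if` over a small prime field. The modulus 101 and the name `field_if` are illustrative assumptions, not part of Vamp-IR.

```python
# Small illustrative prime (assumption); real proving backends use ~255-bit fields.
P = 101

def field_if(b, x, y, p=P):
    """Arithmetic selection b*x + (1-b)*y mod p, applied componentwise to pairs."""
    (x1, x2), (y1, y2) = x, y
    return ((b * x1 + (1 - b) * y1) % p, (b * x2 + (1 - b) * y2) % p)

# Both branch values exist before selection; only the arithmetic picks one.
then_branch = (3, 4)
else_branch = (7, 9)
print(field_if(1, then_branch, else_branch))  # (3, 4)
print(field_if(0, then_branch, else_branch))  # (7, 9)
```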

With fib_4 instead of fib_5 in the definition of fib, it errors with TrimmingDegreeTooLarge. With fib_3 it works.

I think it shouldn't hang. If it can't compile a program, it should at least give an error indicating why.

AHartNtkn commented 1 year ago

It's hanging because the circuit is too big. I looked at the sizes of the compiled circuits, and they are:
fib_1: 24 constraints
fib_2: 101 constraints
fib_3: 752 constraints
fib_4: 29,901 constraints
fib_5: >3,827,328 constraints (it ate too much memory and my computer killed the process)

The way you've designed this, the circuit size increases at a superexponential rate. It isn't hanging during type checking but during compilation of the circuit into three-address code (3AC). If an error were to be added, it would have to enforce a limit on the size of the compiled circuit.
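A quick arithmetic check on the reported sizes shows why "superexponential" fits: the growth factor between consecutive levels keeps growing instead of staying roughly constant. The fib_5 figure is only the reported lower bound, so its ratio is a lower bound as well.

```python
# Constraint counts reported above for the CPS version (fib_5 ran out of memory).
counts = {1: 24, 2: 101, 3: 752, 4: 29_901}
fib_5_lower_bound = 3_827_328

# Growth factor between consecutive levels: roughly constant for plain
# exponential growth, increasing for super-exponential growth.
ratios = {n: counts[n] / counts[n - 1] for n in range(2, 5)}
ratios[5] = fib_5_lower_bound / counts[4]  # a lower bound only

for n, r in ratios.items():
    print(f"fib_{n} / fib_{n - 1}: {r:.1f}x")
```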

lukaszcz commented 1 year ago

Okay, the output suggested it was doing type-checking.

lukaszcz commented 1 year ago

But this doesn't bode well for compiling non-trivial recursive programs into circuits. And why is it superexponential? I would expect it to be exponential (like the fib function). Continuation-passing style by itself shouldn't blow up the circuit size, should it?

lukaszcz commented 1 year ago

Why would it even be exponential? Isn't a circuit a DAG? Shouldn't two occurrences of fib_K somehow share the same subcircuit? I guess you need to duplicate them because they have different inputs/outputs? But then it becomes essentially hopeless to compile recursive programs with more than one recursive call in the function body, even ones that are not exponential themselves: this way, a simple filter function on lists results in an exponential circuit size.

lukaszcz commented 1 year ago

I guess this essentially means that the only reasonable use of recursion for Juvix programs compiled to VampIR is via fold(-like) functions.
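The filter remark can be illustrated with a toy cost model. It is a hypothetical model for illustration, not Vamp-IR's actual accounting: assume one predicate gadget per list element, and assume the circuit materialises both branches of an arithmetic if. Placing the recursive call inside both branches (`if p x then x :: filter xs else filter xs`) doubles the work at every element. Computing the recursive result once and then selecting (`rest = filter xs; if p x then x :: rest else rest`, the fold-like shape) keeps it linear. The function names are hypothetical.

```python
def branchy_cost(n: int) -> int:
    """Gadgets when the recursive call appears in both branches of the if.

    Hypothetical cost model: one predicate gadget per element, and both
    branches are materialised, so each level holds two copies of the rest.
    """
    return 0 if n == 0 else 1 + 2 * branchy_cost(n - 1)

def hoisted_cost(n: int) -> int:
    """Gadgets when the recursive result is computed once, then selected."""
    return 0 if n == 0 else 1 + hoisted_cost(n - 1)

for n in (5, 10, 20):
    print(n, branchy_cost(n), hoisted_cost(n))
```

Under these assumptions the branchy shape costs 2^n - 1 gadgets for a list of length n, while the hoisted shape costs n.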

AHartNtkn commented 1 year ago

As for why it's super-exponential: each fib_n calls fib_(n-1) twice, in addition to calling its continuation twice, and those continuations will generally include other fib calls. It's those continuations that make it super-exponential. I tracked the number of calls to isZero, and for fib_n it's a(n, 0), where a is defined by the recurrence:

a(0, k) = 0;
a(n, k) = 2 + 2 k + a(n-1, a(n-1, k))

I don't know how to solve this, but it looks super-exponential to me, like an Ackermann function.
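One way to get a handle on this recurrence: since a(0, k) = 0 and each step only adds terms that are affine in k, a(n, k) is affine in k. Writing a(n, k) = c_n + m_n * k and substituting gives m_n = 2 + m_(n-1)^2 and c_n = 2 + c_(n-1) * (1 + m_(n-1)). The coefficient m_n roughly squares at every level, which is doubly exponential growth: super-exponential, though far slower than the Ackermann function. The sketch below checks the closed form against the recurrence as written; the helper names are illustrative.

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def a(n: int, k: int) -> int:
    """isZero calls made by fib_n when its continuation makes k calls."""
    if n == 0:
        return 0  # fib_0 ignores its continuation
    return 2 + 2 * k + a(n - 1, a(n - 1, k))

def affine_coeffs(n: int) -> tuple:
    """Return (c_n, m_n) such that a(n, k) == c_n + m_n * k."""
    c, m = 0, 0  # a(0, k) = 0
    for _ in range(n):
        c, m = 2 + c * (1 + m), 2 + m * m
    return c, m

for n in range(6):
    print(n, a(n, 0), affine_coeffs(n))
```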

It's definitely the continuations that make it so large. I don't know why Vamp-IR doesn't like them, but if you use the naive implementation of fib, it produces more reasonably sized circuits.

def isZero x = {
  def xi = fresh (1 | x);
  x * (1 - xi * x) = 0;
  1 - xi * x
};

def if b x y = b * x + (1 - b) * y;

def fib_0 x = 0;

def fibi r x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (r (x - 1) + r (x - 2)))
};

def fib = iter 10 fibi fib_0;

fib x = y;

produces a large but workable circuit (I got it to compile and check with Halo2). And, of course, a genuinely good implementation creates quite a small circuit:

def fib_0 x = (0, 1);

def fibi r x = {
  if (isZero x)
    (0, 1)
    (if (isZero (x - 1))
      (1, 1)
      ((fun (x, y) {(y, x + y)}) (r (x - 1))))
};
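The pair-returning version works because each fibi level returns (F(x), F(x + 1)) and therefore needs only one recursive call. The Python model below assumes that `iter n f base` applies f to base n times, and it mirrors the snippet's logic so the values can be checked. Python short-circuits its `if`, whereas the circuit builds every branch and selects arithmetically. `body_copies` is a hypothetical count of inlined fibi bodies, comparing two recursive calls per level (the naive version) with one (the pair version).

```python
from typing import Callable, Tuple

Pair = Tuple[int, int]  # (F(x), F(x + 1))

def fib_0(x: int) -> Pair:
    return (0, 1)

def fibi(r: Callable[[int], Pair]) -> Callable[[int], Pair]:
    """One unrolling step of the pair-based fib."""
    def step(x: int) -> Pair:
        # Python short-circuits here; the circuit builds every branch and selects.
        if x == 0:
            return (0, 1)
        if x == 1:
            return (1, 1)
        a, b = r(x - 1)  # the single recursive call: (F(x-1), F(x))
        return (b, a + b)
    return step

def iterate(n: int, f, base):
    """Assumed meaning of `iter n f base`: apply f to base n times."""
    g = base
    for _ in range(n):
        g = f(g)
    return g

def body_copies(depth: int, calls_per_level: int) -> int:
    """Copies of the fibi body when each recursive call is inlined separately."""
    if depth == 0:
        return 1  # the fib_0 base case
    return 1 + calls_per_level * body_copies(depth - 1, calls_per_level)

fib = iterate(10, fibi, fib_0)
print([fib(x)[0] for x in range(11)])          # [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55]
print(body_copies(10, 2), body_copies(10, 1))  # 2047 11
```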

Optimizations are still an active area of development for Vamp-IR. Your example should be optimizable to something reasonable, but currently it isn't, so for now you'll need to do that optimization yourself.

AHartNtkn commented 1 year ago

To clarify further: each call to isZero generates a new witness, and Vamp-IR doesn't know that many of the witnesses will be the same. They are inputs to the circuit, and as far as Vamp-IR is concerned, each might be different. This limits how much sharing Vamp-IR can do for such calls. There should be a way to tell Vamp-IR that different "fresh" witnesses are the same, but it would have to be done manually. Alternatively, you can factor the witnesses generated by fresh out into arguments passed to the functions. That would likely take some work, but it would let you control how many witnesses are actually generated.
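The isZero gadget's output depends only on x, even though its witness is an unconstrained circuit input. That is why two calls on the same x could in principle share one gadget. The sketch assumes `fresh (1 | x)` evaluates to the inverse of x, or 0 when x = 0, and uses a small illustrative prime. It checks the claim exhaustively: for x != 0 the constraint x * (1 - xi * x) = 0 forces xi = 1/x and output 0, and for x = 0 any xi satisfies it with output 1. The helper names are hypothetical.

```python
P = 101  # small illustrative prime (assumption); real fields are ~255 bits

def witness(x: int, p: int = P) -> int:
    """Honest prover's value for `fresh (1 | x)`: x^-1, or 0 when x == 0 (assumed semantics)."""
    return pow(x, -1, p) if x % p else 0

def constraint_holds(x: int, xi: int, p: int = P) -> bool:
    return (x * (1 - xi * x)) % p == 0

def output(x: int, xi: int, p: int = P) -> int:
    return (1 - xi * x) % p

for x in range(P):
    # Every witness that satisfies the constraint yields the same output.
    outs = {output(x, xi) for xi in range(P) if constraint_holds(x, xi)}
    assert outs == {1 if x == 0 else 0}
    assert constraint_holds(x, witness(x))
print("isZero output is determined by x alone")
```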

lukaszcz commented 1 year ago

Is there some way to test if a field element is equal to zero without generating these witnesses?

AHartNtkn commented 1 year ago

There are no efficient alternatives that I'm aware of. You could use Lagrange interpolation, but that would result in an extraordinarily large polynomial. You could also use Fermat's little theorem, which states that a^{p-1} = 1 for any non-zero a, where p is the size of the prime field. So we could define our function as isZero(x) = 1 - x^{p-1}. But this creates a polynomial of extraordinarily high degree. Not using witnesses would result in a much larger polynomial than using them does.
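Fermat's little theorem gives a witness-free isZero, at the cost of a polynomial of degree p - 1. Here is a minimal check over a small illustrative prime (an assumption; the real field is about 255 bits, so the degree there is enormous):

```python
P = 101  # small illustrative prime

def is_zero_fermat(x: int, p: int = P) -> int:
    # x^(p-1) == 1 for nonzero x (Fermat's little theorem), and 0^(p-1) == 0.
    return (1 - pow(x, p - 1, p)) % p

print([is_zero_fermat(x) for x in range(5)])  # [1, 0, 0, 0, 0]
```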

lukaszcz commented 1 year ago

That's unfortunate, because isZero seems necessary to compile "high-level" functional programs and it would be used quite often.

AHartNtkn commented 1 year ago

Something did occur to me about that Fermat's little theorem suggestion. You can calculate large exponents by breaking them down into smaller, reusable parts. For example, 2^64 = 2^32 * 2^32, 2^32 = 2^16 * 2^16, etc., and I wondered if Vamp-IR already does this sort of optimization. So I tried this:

def isZero x = 1 - x ^ 52435875175126190479447740508185965837690552500527637822603658699938581184512;

isZero x = 1;

isZero y = 0;

and it works fine. It produces a circuit of size 776 and works properly with Plonk (it doesn't work with Halo2, since that uses a different field size). That is much larger than the witness version, but not too large, and you can use it as a drop-in replacement in your original program. However, it doesn't fix your original issue, so it looks like I was wrong about unshared witnesses being the problem. Well, they are the problem, but it's the lack of sharing that matters, not the fact that they're witnesses. It looks like the part of the compiler that handles higher-order functions will need to be looked at, since it's clearly not sharing as much as it should be able to.
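The decomposition described above is square-and-multiply: one squaring per exponent bit after the first, plus one extra multiplication per additional 1 bit. That comes to fewer than 2 * 255 multiplications for the exponent in the snippet, instead of p - 2. The sketch uses plain modular arithmetic. It does not model how Vamp-IR actually compiles `^`, and the thread does not establish that the reported circuit size comes from exactly this decomposition. `pow_by_squaring` is a hypothetical helper.

```python
def pow_by_squaring(x: int, e: int, p: int) -> tuple:
    """Left-to-right square-and-multiply.

    Returns (x**e mod p, number of multiplications used). Each squaring
    reuses the previous partial power, e.g. x^64 = x^32 * x^32.
    """
    result, mults = 1, 0
    for i, bit in enumerate(bin(e)[2:]):
        if i > 0:
            result = result * result % p  # square the running power
            mults += 1
        if bit == "1":
            if i == 0:
                result = x % p  # leading 1 bit: start from x, no multiplication
            else:
                result = result * x % p  # multiply in one more factor of x
                mults += 1
    return result, mults

# Exponent from the isZero snippet above (p - 1 for the Plonk backend's field).
E = 52435875175126190479447740508185965837690552500527637822603658699938581184512
value, mults = pow_by_squaring(5, E, E + 1)
print(E.bit_length(), mults)  # a few hundred multiplications instead of E - 1
```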

lukaszcz commented 1 year ago

If this version could be made not to blow up the circuit size exponentially, it would be fine:

def isZero x = {
  def xi = fresh (1 | x);
  x * (1 - xi * x) = 0;
  1 - xi * x
};

def if b x y = b * x + (1 - b) * y;

def fib_0 x = 0;

def fib_1 x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (fib_0 (x - 1) + fib_0 (x - 2)))
};

def fib_2 x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (fib_1 (x - 1) + fib_1 (x - 2)))
};

def fib_3 x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (fib_2 (x - 1) + fib_2 (x - 2)))
};

def fib_4 x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (fib_3 (x - 1) + fib_3 (x - 2)))
};

def fib_5 x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (fib_4 (x - 1) + fib_4 (x - 2)))
};

def fib_6 x = {
  if (isZero x)
    0
    (if (isZero (x - 1))
      1
      (fib_5 (x - 1) + fib_5 (x - 2)))
};

def fib = fib_6;

fib 5 = y;

But I'm not sure this is possible, since presumably you can't just create one circuit for fib_K and then reuse it with different inputs/outputs. I don't know exactly how composing arithmetic circuits works, but it seems that in general you need to duplicate the circuit for fib_K at each occurrence of fib_K?
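Here is a hypothetical count that makes the DAG question concrete for the fib_6 program above. In fib_j (x - o), the two calls go to fib_(j-1) on x - (o + 1) and x - (o + 2). Many of these calls therefore repeat with identical level and argument: fib_(K-1)(x - 1) and fib_(K-1)(x - 2) both call fib_(K-2)(x - 3). Inlining every call gives 2^(K+1) - 1 instances. Sharing identical calls, as a compiler doing common-subexpression elimination might, would leave only (K + 1)(K + 2) / 2 instances. This assumes such calls may be shared, which, as noted earlier in the thread, would also require treating their fresh witnesses as the same. The function names are illustrative.

```python
def inlined_instances(k: int) -> int:
    """fib_j bodies in the circuit when every call is inlined separately."""
    return 1 if k == 0 else 1 + 2 * inlined_instances(k - 1)

def shared_instances(k: int) -> int:
    """Distinct (level, offset) calls reachable from fib_k x.

    fib_j (x - o) calls fib_{j-1} on x - (o + 1) and x - (o + 2); calls with
    the same level and the same argument are counted once (hypothetical sharing).
    """
    seen = set()
    stack = [(k, 0)]
    while stack:
        j, o = stack.pop()
        if (j, o) in seen:
            continue
        seen.add((j, o))
        if j > 0:
            stack.append((j - 1, o + 1))
            stack.append((j - 1, o + 2))
    return len(seen)

for k in range(7):
    print(k, inlined_instances(k), shared_instances(k))
```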