PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0
138 stars, 35 forks

`qjit(static_argnums=...)` fails when the marked static argument has a default value #1163

Open paul0403 opened 1 month ago

paul0403 commented 1 month ago

Context

When jit-compiling a Python function, the arguments of the compiled function lose their concrete values and are replaced by tracers: abstract placeholders that have the same type and shape as the concrete values but no specific value. The traced program, called a jaxpr, uses these abstract tracers to represent how the function's arguments are used.

The below example shows how to use Catalyst to jit-compile a function, and how to inspect the compiled jaxpr.

import catalyst
from catalyst import qjit

@qjit
def f(x, y):
  return x+y

result = f(10, 20)
print(result)
print(f.jaxpr)

result = f(1.1, 2.2)
print(result)
print(f.jaxpr)
30
{ lambda ; a:i64[] b:i64[]. let
    c:i64[] = add a b
  in (c,) }
3.3
{ lambda ; a:f64[] b:f64[]. let
    c:f64[] = add a b
  in (c,) }

Notice that in each jaxpr, the argument types (i64 and f64) match the types of the concrete arguments passed in the corresponding call. The process of converting Python code to a jaxpr is called tracing.
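To make tracing more concrete, here is a toy model in plain Python. It is a sketch for intuition only, not Catalyst's or JAX's implementation; the `Graph`, `Tracer`, and `trace` names are hypothetical. Each argument is replaced by a placeholder that records only its type, and operations on placeholders are recorded as equations instead of being computed.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Graph:
    """Records equations produced while tracing (a toy stand-in for a jaxpr)."""
    equations: List[str] = field(default_factory=list)
    counter: int = 0

    def fresh(self) -> str:
        # Hand out variable names a, b, c, ... in order.
        name = chr(ord("a") + self.counter)
        self.counter += 1
        return name


@dataclass
class Tracer:
    """An abstract value: it knows its name and dtype, but not its concrete value."""
    name: str
    dtype: str
    graph: Graph

    def __add__(self, other: "Tracer") -> "Tracer":
        # Record the operation instead of computing a number.
        out = Tracer(self.graph.fresh(), self.dtype, self.graph)
        self.graph.equations.append(f"{out.name}:{out.dtype}[] = add {self.name} {other.name}")
        return out


def dtype_of(value) -> str:
    # Map Python scalars to the dtype names used in the jaxprs above.
    return "i64" if isinstance(value, int) else "f64"


def trace(fn, *args) -> str:
    graph = Graph()
    # Only the type of each concrete argument survives; the value is discarded.
    tracers = [Tracer(graph.fresh(), dtype_of(v), graph) for v in args]
    out = fn(*tracers)
    params = " ".join(f"{t.name}:{t.dtype}[]" for t in tracers)
    body = "\n".join(f"    {eq}" for eq in graph.equations)
    return f"{{ lambda ; {params}. let\n{body}\n  in ({out.name},) }}"


print(trace(lambda x, y: x + y, 10, 20))    # same text as the i64 jaxpr above
print(trace(lambda x, y: x + y, 1.1, 2.2))  # same text as the f64 jaxpr above
```

Note that the values 10 and 20 never appear in the output; only their types do, which is exactly why a trace can be reused for any other pair of integers.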

Abstract arguments cause a problem when their concrete values are needed, for example when they are compared against other concrete values in a Python `if` statement: tracing fails, because an abstract tracer cannot be interpreted as a concrete value. See here for more details.

@qjit
def f(x, y):
  if x < 100:
    return x+y
  return 42

result = f(10, 20)
print(result)
print(f.jaxpr)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
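Why does this fail? In a toy model (hypothetical, not Catalyst's or JAX's code), comparing an abstract value can only produce another abstract value, so Python's `if` has nothing concrete to branch on. The sketch below raises from `__bool__`, mirroring the spirit of JAX's `TracerBoolConversionError`:

```python
class AbstractScalar:
    """Toy abstract value: carries a dtype but deliberately no concrete value."""

    def __init__(self, dtype: str):
        self.dtype = dtype

    def __lt__(self, other) -> "AbstractScalar":
        # Comparing an unknown value gives an unknown result: an abstract bool.
        return AbstractScalar("bool")

    def __add__(self, other) -> "AbstractScalar":
        return AbstractScalar(self.dtype)

    def __bool__(self) -> bool:
        # Python's `if` calls __bool__, which needs a concrete True/False.
        raise TypeError(f"Attempted boolean conversion of abstract value with shape {self.dtype}[]")


def f(x, y):
    if x < 100:
        return x + y
    return 42


try:
    f(AbstractScalar("i64"), AbstractScalar("i64"))
except TypeError as err:
    print(err)  # Attempted boolean conversion of abstract value with shape bool[]

print(f(10, 20))  # concrete (non-traced) arguments still work: 30
```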

To avoid this problem, some function arguments can be marked static, which means that during tracing they keep their concrete values instead of being replaced by tracers. This is done with the static_argnums keyword argument of qjit, which takes a list of the indices of the arguments to be marked static. Because the static value is baked into the traced program, each new value of a static argument produces a different jaxpr:

@qjit(static_argnums=[0])
def f(x, y):
  if x < 100:
    return x+y
  return 42

result = f(10, 20)
print(result)
print(f.jaxpr)

result = f(1000, 20)
print(result)
print(f.jaxpr)
30
{ lambda ; a:i64[]. let
    b:i64[] = add 10 a
  in (b,) }
42
{ lambda ; a:i64[]. let
  in (42,) }
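The retracing behaviour can be sketched with a toy cache. This is an assumption-laden model, not Catalyst's implementation: `ToyJit` is hypothetical, and "tracing" here just means building a specialized callable with the static values baked in. Static arguments contribute their concrete value to the cache key, while dynamic arguments contribute only their type.

```python
from typing import Any, Callable, Dict, Sequence, Tuple


class ToyJit:
    """Hypothetical toy JIT that caches one specialized callable per cache key.

    Static arguments contribute their concrete value to the key (so they must be
    hashable); all other arguments contribute only their type. Changing a static
    argument therefore forces a retrace, while changing a dynamic argument of the
    same type reuses the cached entry.
    """

    def __init__(self, fn: Callable, static_argnums: Sequence[int] = ()):
        self.fn = fn
        self.static_argnums = set(static_argnums)
        self.cache: Dict[Tuple, Callable] = {}
        self.trace_count = 0

    def _specialize(self, static_values: Dict[int, Any], n_args: int) -> Callable:
        # "Trace" once: bake the static values in; the result takes only dynamic args.
        self.trace_count += 1

        def specialized(*dynamic_args):
            remaining = iter(dynamic_args)
            full = [static_values[i] if i in static_values else next(remaining) for i in range(n_args)]
            return self.fn(*full)

        return specialized

    def __call__(self, *args):
        key = tuple(
            ("static", a) if i in self.static_argnums else ("dynamic", type(a).__name__)
            for i, a in enumerate(args)
        )
        if key not in self.cache:
            # Indexes static args by call-site position, like the check discussed below.
            static_values = {i: args[i] for i in self.static_argnums}
            self.cache[key] = self._specialize(static_values, len(args))
        dynamic_args = [a for i, a in enumerate(args) if i not in self.static_argnums]
        return self.cache[key](*dynamic_args)


def f(x, y):
    return x + y if x < 100 else 42


jf = ToyJit(f, static_argnums=[0])
print(jf(10, 20), jf(10, 25), jf(1000, 20))  # 30 35 42
print(jf.trace_count)  # 2: one "trace" for x=10, one for x=1000
```

Like the behaviour reported in this issue, this toy looks up static arguments by their position in the call, so `args[i]` would fail if a defaulted static argument were omitted.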

However, Catalyst currently fails when an argument with a default value is marked static via static_argnums:

@qjit(static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

res = f(20)
print(res)
catalyst.utils.exceptions.CompileError: argnum 1 is beyond the valid range of [0, 1).
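The error message suggests that each static index is validated against the number of arguments actually passed at the call site. Below is a hypothetical reconstruction of that kind of check; the function name is invented, and only the message format is modeled on the error above, not copied from Catalyst's source.

```python
from typing import Sequence


class CompileError(Exception):
    """Stand-in for catalyst.utils.exceptions.CompileError."""


def check_static_argnums_against_call(args: tuple, static_argnums: Sequence[int]) -> None:
    # Hypothetical: validates indices against len(args), i.e. the arguments *passed*,
    # so an omitted argument with a default value is treated as out of range.
    for argnum in static_argnums:
        if not 0 <= argnum < len(args):
            raise CompileError(f"argnum {argnum} is beyond the valid range of [0, {len(args)}).")


try:
    # f(20) passes one positional argument, so argnum 1 is rejected
    # even though f's signature declares a second parameter x=9.
    check_static_argnums_against_call((20,), [1])
except CompileError as err:
    print(err)  # argnum 1 is beyond the valid range of [0, 1).

check_static_argnums_against_call((20, 3), [1])  # passes when x is given explicitly
```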

Goal

We would like qjit's static_argnums to support marking arguments with default values, as native jax.jit already does:

from functools import partial
import jax
@partial(jax.jit, static_argnums=[1])
def f(y, x=9):
    if x < 10:
        return x + y
    return 42000

print(f(20), f(20, 3), f(20, 30000))
29 23 42000
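One possible direction is to fill in defaults before looking up static arguments, so that static_argnums index the function's declared parameters rather than only the positional arguments of the call. A minimal sketch using the standard library's `inspect.Signature.bind` and `BoundArguments.apply_defaults`; the helper `split_static_args` is hypothetical, and this is not how Catalyst or JAX actually implements it.

```python
import inspect
from typing import Any, Callable, Dict, Sequence, Tuple


def split_static_args(fn: Callable, static_argnums: Sequence[int], *args, **kwargs) -> Tuple[Dict[int, Any], Tuple]:
    """Return ({index: static value}, dynamic args), with defaults filled in."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()  # bound.args now includes omitted defaulted parameters
    full_args = bound.args
    static = {i: full_args[i] for i in static_argnums}
    dynamic = tuple(a for i, a in enumerate(full_args) if i not in static)
    return static, dynamic


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


print(split_static_args(f, [1], 20))         # ({1: 9}, (20,))
print(split_static_args(f, [1], 20, 3))      # ({1: 3}, (20,))
print(split_static_args(f, [1], 20, 30000))  # ({1: 30000}, (20,))
```

A side benefit of binding against the signature is that `f(20, x=3)` also resolves to index 1, since `x` is a positional-or-keyword parameter.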

Requirements:

Technical details

Installation help

To save time, instead of installing Catalyst from source, it is also possible to download the PyPI wheel and extract it into the frontend directory of a cloned catalyst repository (taking care to match the git tags beforehand), followed by `make frontend`. This allows modifying the Python files in place.

Alternatively, complete instructions for installing Catalyst from source can be found here, but due to the size of llvm-project, compiling can take a while (~3 hours on a personal laptop).

paul0403 commented 1 month ago

I think this is because static_argnums tracks the argument indices at the call site, not at the definition, though this is just a hunch and might be completely wrong.
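The hunch can be checked with the standard library: for `f(20)`, the call site supplies one positional argument, while the definition declares two parameters. A small illustration using `inspect.signature` (the variable names are just for illustration):

```python
import inspect


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


call_args = (20,)                          # what f(20) passes
params = inspect.signature(f).parameters   # what f declares

print(len(call_args))  # 1 -> valid call-site indices: [0, 1)
print(len(params))     # 2 -> valid definition indices: [0, 2)
print(list(params))    # ['y', 'x']; static_argnums=[1] refers to x
```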

AniketDalvi commented 1 week ago

Hi! I am Aniket, a PhD candidate at Duke interviewing for a role on the compiler team. I was given this issue as a technical challenge. Apart from the details mentioned in this issue, are there any other pointers that might help me tackle this issue? Thank you!

paul0403 commented 1 week ago

Hi @AniketDalvi , do you have a more specific question in mind? I would love to give out more pointers and resolve any confusion you may have!

AniketDalvi commented 1 week ago

Hi! So for installation - it says to download the PyPI wheel. Which version should I be downloading the wheel for? pip says there are 3 available versions - 0.1.0, 0.1.1 and 0.1.2.

More specifically, I get the following error when downloading the wheel on my linux computer:

ERROR: Cannot install pennylane-catalyst==0.1.0, pennylane-catalyst==0.1.1 and pennylane-catalyst==0.1.2 because these package versions have conflicting dependencies.

The conflict is caused by:
    pennylane-catalyst 0.1.2 depends on jaxlib==0.4.1
    pennylane-catalyst 0.1.1 depends on jaxlib==0.4.1
    pennylane-catalyst 0.1.0 depends on jaxlib==0.4.1

paul0403 commented 1 week ago

Hi! So for installation - it says to download the PyPI wheel. Which version should I be downloading the wheel for? pip says there are 3 available versions - 0.1.0, 0.1.1 and 0.1.2.

I don't think there's gonna be any difference w.r.t. this particular issue; any one of the three should be good.

cc @rauletorresc who worked on the frontend dev plug-in

erick-xanadu commented 1 week ago

Hey @AniketDalvi, I am not too sure why your pip says that there are only 3 versions available. We have version 0.8.1 in pypi and will be releasing version 0.9.0 soon.

AniketDalvi commented 1 week ago

Okay, I am just going to download the .whl file from the above link. I will then extract it in the frontend directory, as directed in the installation instructions.

AniketDalvi commented 1 week ago

Hi! Okay, so I have followed the instructions and seem to have successfully installed the repository. Is there a quick sanity-check experiment or file I can run to verify the installation?

erick-xanadu commented 1 week ago

Hi! Okay, so I have followed the instructions and seem to have successfully installed the repository. Is there a quick sanity-check experiment or file I can run to verify the installation?

You can run the tests with `pytest frontend/test/pytest -n auto`. It is possible that some debugging tests fail on your machine, as they expect a specific path, but if the vast majority pass, then it should be good.

AniketDalvi commented 1 week ago

Running pytest gave me an error stating that there was an interpreter mismatch: the interpreter is Python 3.10, while the package is compatible only with 3.12.

erick-xanadu commented 1 week ago

Running pytest gave me an error stating that there was an interpreter mismatch: the interpreter is Python 3.10, while the package is compatible only with 3.12.

When you install via pip, pip enforces this compatibility check, but when downloading a wheel manually, you need to make sure to pick the one that matches your Python version. See here for a list of wheels with different Python version compatibility.
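When picking a wheel by hand, the Python tag in the wheel's filename (for example `cp310` for CPython 3.10) has to match the interpreter. A small standard-library sketch that compares the two, assuming CPython and the standard wheel filename layout `{name}-{version}(-{build})?-{python tag}-{abi tag}-{platform tag}.whl`. The helper names and the example filename are hypothetical; for full tag matching (ABI and platform too), the `packaging.tags` module is the usual tool.

```python
import sys


def interpreter_tag() -> str:
    # Assumes CPython: e.g. "cp310" for Python 3.10, "cp312" for Python 3.12.
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def wheel_python_tags(filename: str) -> set:
    # The python tag is the third-from-last dash-separated field; it may be a
    # compressed set such as "py2.py3".
    stem = filename[: -len(".whl")]
    python_tag = stem.split("-")[-3]
    return set(python_tag.split("."))


def wheel_matches_interpreter(filename: str) -> bool:
    # Only the python tag is checked here; ABI and platform tags are ignored.
    return interpreter_tag() in wheel_python_tags(filename)


wheel = "example_pkg-1.0-cp312-cp312-manylinux_2_28_x86_64.whl"  # hypothetical filename
print(interpreter_tag(), wheel_python_tags(wheel), wheel_matches_interpreter(wheel))
```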

Maybe it would be easier to install from source? It just takes a long time to build LLVM initially.

EDIT: It also looks like the new Catalyst version 0.9.0 is now available for download :)

AniketDalvi commented 1 week ago

Understood, that makes sense. I might re-try it with a different wheel with a compatible python version. If not, I will resort to installing from source.

AniketDalvi commented 1 week ago

Okay, when trying to run `make frontend` with the new wheel, I get this error: error: command '/usr/bin/g++' failed with exit code 1. I have tried every solution Stack Overflow has to offer, but to no avail. Any thoughts on what this could be?

paul0403 commented 1 week ago

Okay, when trying to run `make frontend` with the new wheel, I get this error: error: command '/usr/bin/g++' failed with exit code 1. I have tried every solution Stack Overflow has to offer, but to no avail. Any thoughts on what this could be?

Some packages are required before building Catalyst. Maybe some of them are missing? See the build from source guide.

paul0403 commented 1 week ago

(If wheels are too complicated I recommend just building from source.)

AniketDalvi commented 1 week ago

Yup installed all the required packages, but the error persists. I am now just going to build from source instead.

paul0403 commented 1 week ago

Yup installed all the required packages, but the error persists. I am now just going to build from source instead.

To avoid all package version issues, I also recommend using a fresh virtual environment when developing, e.g.

python3 -m venv pyenv
source ./pyenv/bin/activate

after which you can `pip install` all the requirements and run `make all` to build from source.

AniketDalvi commented 1 week ago

Okay, I seem to have gotten it to work from source. Most tests pass, some are skipped, and 4 debugging tests fail (as @erick-xanadu said is expected). I am running all of this from within a conda environment on a Linux machine.

erick-xanadu commented 1 week ago

Okay, I seem to have gotten it to work from source. Most tests pass, some are skipped, and 4 debugging tests fail (as @erick-xanadu said is expected).

Yeah, a bunch are skipped, 4 failing ones.

I am running all of this from within a conda environment on a Linux machine.

Awesome! We don't normally use conda so I am happy to hear this worked for you :)

AniketDalvi commented 6 days ago

Hi! So from my initial analysis, I traced the issue down to this check that throws an exception - https://github.com/PennyLaneAI/catalyst/blob/75dc517a6e2f2583da890b9e8198241933a7aef5/frontend/catalyst/tracing/type_signatures.py#L120.

It appears that it compares the index used to specify the static argument with the number of args passed to the function. However, in the case of a default parameter, the static argument index can be greater than or equal to len(args), as the default parameter may not be passed when the function is called. Considering this, my initial proposal for a solution would be to compare the static argument index to the number of arguments in the function signature, as opposed to the number of arguments passed at the call.

Would like to get your thoughts on this!
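A minimal sketch of the proposed check, assuming static_argnums index the positional parameters of the function's signature. The helper name is hypothetical, and treating `*args` as "no upper bound" and rejecting negative indices are assumptions, not Catalyst behaviour. Out-of-range indices are still rejected, which speaks to the concern about loosening the verification.

```python
import inspect
from typing import Callable, Sequence


class CompileError(Exception):
    """Stand-in for catalyst.utils.exceptions.CompileError."""


def verify_static_argnums(fn: Callable, static_argnums: Sequence[int]) -> None:
    positional_kinds = (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
    params = list(inspect.signature(fn).parameters.values())
    # Count declared positional parameters, including those with defaults.
    n_positional = sum(p.kind in positional_kinds for p in params)
    # With *args, the upper bound is unknown until call time (assumption).
    has_var_positional = any(p.kind is inspect.Parameter.VAR_POSITIONAL for p in params)
    for argnum in static_argnums:
        if argnum < 0 or (not has_var_positional and argnum >= n_positional):
            raise CompileError(f"argnum {argnum} is beyond the valid range of [0, {n_positional}).")


def f(y, x=9):
    if x < 10:
        return x + y
    return 42000


verify_static_argnums(f, [1])  # passes: f declares 2 positional parameters

try:
    verify_static_argnums(f, [2])
except CompileError as err:
    print(err)  # argnum 2 is beyond the valid range of [0, 2).
```

Validation alone is not sufficient: when the call omits a defaulted static argument, the call path also has to supply its default value (for example via `inspect.Signature.bind` followed by `apply_defaults()`) so that a concrete value exists for that index.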

paul0403 commented 6 days ago

Hi! So from my initial analysis, I traced the issue down to this check that throws an exception - https://github.com/PennyLaneAI/catalyst/blob/75dc517a6e2f2583da890b9e8198241933a7aef5/frontend/catalyst/tracing/type_signatures.py#L120. It appears that it compares the index used to specify the static argument with the number of args passed to the function. However, in the case of a default parameter, the static argument index can be greater than or equal to len(args), as the default parameter may not be passed when the function is called. Considering this, my initial proposal for a solution would be to compare the static argument index to the number of arguments in the function signature, as opposed to the number of arguments passed at the call.

Would like to get your thoughts on this!

Hi! Usually what we do for these challenges is you can fork the catalyst repo, push your changes, and open a PR. It will be easier to review.

It's good if it turns out to be a simple fix, but one thing I'm afraid of is whether loosening the verification would let some errors through. Can you run the frontend test suite to make sure this does not happen? You can run `make test-frontend` in the root catalyst directory.

AniketDalvi commented 6 days ago

Hi! Okay, that sounds good. I am working off of my own branch. Does that work, or does it have to be a fork?

erick-xanadu commented 6 days ago

It doesn't really matter how you develop, as long as you are able to push a pull request :). I think to do that, you do need a fork. But you can always just add a new remote to your local git workspace.

git clone $pennylane/catalyst
# work on the issue
# fork $pennylane/catalyst to $yourrepo/catalyst
git remote add myrepo $yourrepo/catalyst
git push myrepo $yourbranch
# open a PR from $yourrepo/catalyst to $pennylane/catalyst