paul0403 opened this issue 1 month ago
I think this is because static_argnums tracks the argument indices at the call, not at the definition, though this is just a hunch and might be completely wrong.
Hi! I am Aniket, a PhD candidate at Duke interviewing for a role on the compiler team. I was given this issue as a technical challenge. Apart from the details mentioned in this issue, are there any other pointers that might help me tackle this issue? Thank you!
Hi @AniketDalvi , do you have a more specific question in mind? I would love to give out more pointers and resolve any confusion you may have!
Hi! So for installation, it says to download the PyPI wheel. Which version should I be downloading the wheel for? pip says there are 3 available versions: 0.1.0, 0.1.1 and 0.1.2.
More specifically, I get the following error when installing the wheel on my Linux computer:
```text
ERROR: Cannot install pennylane-catalyst==0.1.0, pennylane-catalyst==0.1.1 and pennylane-catalyst==0.1.2 because these package versions have conflicting dependencies.
The conflict is caused by:
pennylane-catalyst 0.1.2 depends on jaxlib==0.4.1
pennylane-catalyst 0.1.1 depends on jaxlib==0.4.1
pennylane-catalyst 0.1.0 depends on jaxlib==0.4.1
```
I don't think there's gonna be any difference w.r.t. this particular issue; any one of the three should be good.
cc @rauletorresc who worked on the frontend dev plug-in
Hey @AniketDalvi, I am not too sure why your pip says that there are only 3 versions available. We have version 0.8.1 on PyPI and will be releasing version 0.9.0 soon.
Okay, I am just going to download the .whl file from the above link. I will then extract it into the frontend directory, as directed in the installation instructions.
Hi! Okay, so I have followed the instructions and seem to have successfully installed the repository. Is there a quick sanity-check experiment or file I can run to verify the installation?
You can run the tests: pytest frontend/test/pytest -n auto. Some debugging tests may fail on your machine because they expect a specific path, but if the vast majority pass, it should be good.
Running pytest gave me an error stating that there was an interpreter mismatch: the interpreter is Python 3.10, while the package is only compatible with 3.12.
When you install via pip, pip enforces this compatibility check, but when you download a wheel manually, you need to make sure to pick the one that matches your Python version. See here for a list of wheels with different Python version compatibility.
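To check a manually downloaded wheel before extracting it, one option is to read the Python tag out of its filename, which follows the PEP 427 layout {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl. The sketch below uses only the standard library; the helper names and the example filename are hypothetical, and for production use packaging.utils.parse_wheel_filename from the packaging library is the more complete route.

```python
import re
import sys

# PEP 427 wheel filename: {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl
WHEEL_RE = re.compile(
    r"^(?P<name>[^-]+)-(?P<version>[^-]+)(-(?P<build>\d[^-]*))?"
    r"-(?P<python>[^-]+)-(?P<abi>[^-]+)-(?P<platform>[^-]+)\.whl$"
)


def wheel_python_tags(filename):
    """Return the set of Python tags (e.g. {'cp312'}) encoded in a wheel filename."""
    match = WHEEL_RE.match(filename)
    if match is None:
        raise ValueError(f"not a valid wheel filename: {filename!r}")
    # Compressed tag sets use '.' as a separator, e.g. 'py2.py3'.
    return set(match.group("python").split("."))


def matches_interpreter(filename, version_info=sys.version_info):
    """True if the wheel declares a Python tag usable by the given interpreter version."""
    major, minor = version_info[0], version_info[1]
    wanted = {f"cp{major}{minor}", f"py{major}{minor}", f"py{major}"}
    return bool(wanted & wheel_python_tags(filename))


# Hypothetical filename, for illustration only.
wheel = "pennylane_catalyst-0.9.0-cp312-cp312-manylinux_2_28_x86_64.whl"
print(wheel_python_tags(wheel))
print(matches_interpreter(wheel))  # depends on the interpreter running this script
```

With a Python 3.10 interpreter, matches_interpreter returns False for this cp312 wheel, which is the same mismatch pytest reported.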
Maybe it would be easier to install from source? It just takes a long time to build LLVM initially.
EDIT: It also looks like the new Catalyst version 0.9.0 is now available for download :)
Understood, that makes sense. I might re-try it with a different wheel with a compatible python version. If not, I will resort to installing from source.
Okay, when trying to run the make frontend command with the new wheel, I get this error: error: command '/usr/bin/g++' failed with exit code 1. I have tried all the solutions Stack Overflow has to offer, but to no avail. Any thoughts on what this could be?
There are some packages required before building Catalyst. Maybe some of them are missing? See the build-from-source guide.
(If wheels are too complicated I recommend just building from source.)
Yup installed all the required packages, but the error persists. I am now just going to build from source instead.
To avoid all package version issues, I also recommend using a fresh virtual environment when developing, e.g.

```shell
python3 -m venv pyenv
source ./pyenv/bin/activate
```

after which you can pip install all the requirements and run make all from source.
Okay, I seem to have gotten it to work from source. Most tests pass, some are skipped, and 4 debugging tests fail (which, as @erick-xanadu said, is expected). I am running all of this from within a conda environment on a Linux machine.
Yeah, a bunch are skipped, 4 failing ones.
Awesome! We don't normally use conda so I am happy to hear this worked for you :)
Hi! From my initial analysis, I traced the issue down to this check, which throws an exception: https://github.com/PennyLaneAI/catalyst/blob/75dc517a6e2f2583da890b9e8198241933a7aef5/frontend/catalyst/tracing/type_signatures.py#L120.
It appears that it compares the index used to specify the static argument against the number of args passed to the function. However, when a parameter has a default value, the static argument index can be greater than or equal to len(args), since the defaulted argument may not be passed when the function is called. With this in mind, my initial proposal is to compare the static argument index against the number of parameters in the function signature, as opposed to the number of arguments actually passed in the call.
I would like to get your thoughts on this!
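A minimal sketch of the proposed check, assuming it can be expressed with Python's inspect.signature. The function name check_static_argnums is hypothetical and this is not Catalyst's actual code in type_signatures.py; it only illustrates validating indices against the signature rather than against len(args) of one call.

```python
import inspect

POSITIONAL_KINDS = (
    inspect.Parameter.POSITIONAL_ONLY,
    inspect.Parameter.POSITIONAL_OR_KEYWORD,
)


def check_static_argnums(fn, static_argnums):
    """Hypothetical validation: compare static_argnums against fn's signature
    instead of against the positional arguments of one particular call."""
    kinds = [p.kind for p in inspect.signature(fn).parameters.values()]
    if inspect.Parameter.VAR_POSITIONAL in kinds:
        # *args accepts arbitrarily many positional arguments, so skip the check.
        return
    n_positional = sum(kind in POSITIONAL_KINDS for kind in kinds)
    for argnum in static_argnums:
        # Negative indices are simply rejected in this sketch.
        if not 0 <= argnum < n_positional:
            raise ValueError(
                f"static_argnums contains {argnum}, but {fn.__name__} "
                f"has only {n_positional} positional parameters"
            )


def circuit(x, y=0.5):
    return x * y


# Index 1 (y) is accepted even for a call like circuit(1.0), where len(args) == 1.
check_static_argnums(circuit, [1])
```

An index past the end of the signature, such as [2] for circuit, still raises ValueError, so the check is looser only for defaulted parameters, not for genuinely invalid indices.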
Hi! What we usually do for these challenges is have you fork the catalyst repo, push your changes, and open a PR; that makes it easier to review.
It's good if this turns out to be a simple fix, but one thing I'm worried about is whether loosening the verification would let some errors through. Can you run the frontend test suite to make sure this does not happen? You can run make test-frontend in the root catalyst directory.
Hi! Okay, that sounds good. I am working off of my own branch. Does that work, or does it have to be a fork?
It doesn't really matter how you develop, as long as you are able to open a pull request :). I think you do need a fork for that, but you can always just add a new remote to your local git workspace:
```shell
git clone $pennylane/catalyst
# work on the issue
# fork $pennylane/catalyst to $yourrepo/catalyst
git remote add myrepo $yourrepo/catalyst
git push myrepo $yourbranch
# open a PR from $yourrepo/catalyst to $pennylane/catalyst
```
Context
When jit-compiling a python function, the arguments of the compiled function lose their concrete values and are replaced by tracers, which at a high level means abstract variables that have the same type and shape as the concrete variable. A compiled program, called the jaxpr, uses these abstract tracers to represent how the arguments of the function are used.
The example below shows how to use Catalyst to jit-compile a function and how to inspect the compiled jaxpr.
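As a rough pure-JAX analogue of that workflow (an assumption: it uses jax.make_jaxpr rather than Catalyst's qjit, and enables jax_enable_x64 so the argument types print as i64 and f64 instead of JAX's default 32-bit types), tracing a two-argument function produces a jaxpr whose inputs are abstract values carrying only dtype and shape:

```python
import jax
import jax.numpy as jnp

# Assumption: enable 64-bit types so the traced inputs show up as i64/f64.
jax.config.update("jax_enable_x64", True)


def f(x, y):
    return x * y + 1.0


# make_jaxpr traces f with abstract tracers that keep only dtype and shape.
jaxpr = jax.make_jaxpr(f)(jnp.int64(3), jnp.float64(2.5))
print(jaxpr)

# The abstract input values, one per argument.
print([aval.dtype for aval in jaxpr.in_avals])
```

The concrete values 3 and 2.5 do not appear in the jaxpr inputs; only their types do.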
Notice that in the jaxpr, the types of the function's arguments, i64 and f64, match the types of the concrete arguments in the corresponding call. The process of converting Python to a jaxpr is called tracing.
One issue with arguments being abstract is that when their concrete value is needed, for example when they are compared against other concrete values, tracing fails, since abstract tracers cannot be interpreted as concrete values. See here for more details.
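A small pure-JAX sketch of that failure mode, using jax.jit in place of qjit (the function scale and helper traces_ok are hypothetical). A Python if on a traced argument needs a concrete boolean, so tracing raises a subclass of jax.errors.ConcretizationTypeError:

```python
import jax


def scale(x, mode):
    # A Python `if` needs a concrete value for `mode`.
    if mode > 0:
        return x * 2.0
    return x * 0.5


def traces_ok(fn, *args):
    """Return True if fn(*args) runs, False if it hits a concretization error."""
    try:
        fn(*args)
        return True
    except jax.errors.ConcretizationTypeError:
        return False


print(traces_ok(scale, 3.0, 1))           # plain Python call: mode is concrete
print(traces_ok(jax.jit(scale), 3.0, 1))  # under jit, mode becomes an abstract tracer
```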
To avoid this problem, some of the function arguments can be marked static, which essentially means that during tracing they keep their concrete values instead of being replaced by tracers. This marking is done with the static_argnums keyword argument of qjit, which takes a list of indices of the arguments to be marked static.
However, Catalyst currently does not allow arguments with default values to be marked as static_argnums:

Goal
We would like to support static_argnums in qjit for marking arguments with default values, as this is supported by native jax.jit:
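A minimal pure-JAX sketch of that pattern (the function shift and its values are hypothetical, not the issue's original example): static_argnums=1 points at a parameter with a default value, and the jitted function can be called either with or without that argument.

```python
import jax


def shift(x, y=1):
    # Python control flow on y requires y to be concrete (static).
    if y > 0:
        return x + y
    return x - y


# Index 1 refers to y, which has a default value.
shift_jit = jax.jit(shift, static_argnums=1)

print(shift_jit(1.0))      # y omitted: the default y=1 is used
print(shift_jit(1.0, -3))  # y passed explicitly and treated as static
```

When y is omitted, JAX does not treat the out-of-range static index as an error; the default stays a plain Python value inside the traced function, which is the behavior the goal asks qjit to match.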
Requirements:
jax.jit. Explicitly:

Technical details
For reasons that do not concern us here, all jaxprs produced by qjit will carry a transform_named_sequence. You can safely ignore it.
The qjit function takes in a Python function and returns a QJIT object, which is a callable. The QJIT object has a capture method that determines how a Python function is traced into a jaxpr. See frontend/catalyst/jit.py.
It should be possible to implement this functionality entirely in the capture layer, without delving into the underlying machinery of the trace_to_jaxpr methods. For example, one option is to create two versions of the function, both without any default-valued arguments and adjusted to behave correctly, and trace one or the other depending on whether the user's call supplied a value for the default-valued argument. Other options might be possible too.
There are many possible ways a user might call a function. For the purpose of this assessment, only the simplest case is required, i.e. the example above that worked with pure JAX (though bonus points if you can make more complicated patterns work).
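One of the "other options" could be sketched as follows. This is a different technique from the two-version approach described above: normalize each call against the function signature with inspect.Signature.bind and BoundArguments.apply_defaults before tracing, so every positional parameter is present and a static index always refers to an existing argument. The helper normalize_call and the function circuit are hypothetical, and this is not how Catalyst's capture method currently works.

```python
import inspect


def normalize_call(fn, args, kwargs):
    """Hypothetical capture-layer helper: bind a call to fn's signature and
    fill in every omitted default, so all positional parameters are present."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    # Positional-or-keyword parameters end up in bound.args, even if the
    # caller passed them by keyword.
    return bound.args, bound.kwargs


def circuit(x, y=1):
    return x + y


print(normalize_call(circuit, (0.5,), {}))        # ((0.5, 1), {})
print(normalize_call(circuit, (0.5,), {"y": 3}))  # ((0.5, 3), {})
```

One trade-off of this design: defaults become explicit arguments, so a default-valued parameter that is not marked static would be traced as an abstract argument instead of staying a Python constant, which the two-version approach avoids.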
Installation help
To save time, instead of installing Catalyst from source, it is also possible to download the PyPI wheel and extract it into the frontend directory of a cloned catalyst repository (taking care to match git tags beforehand), followed by make frontend. This then allows modifying the Python files in place.
Alternatively, complete instructions to install Catalyst from source can be found here, but due to the size of the llvm-project it can take a while (~3 hrs on a personal laptop) to compile.