PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

`to_catalyst` plxpr conversion function using a jaxpr interpreter #5883

Open albi3ro opened 1 week ago

albi3ro commented 1 week ago

Backlogged in favor of Catalyst #837

Context:

This PR succeeds #5771.

While that PR constructed the `Jaxpr` and `JaxprEqn` objects directly, this PR creates a "jaxpr interpreter" and then traces its execution with `jax.make_jaxpr`.
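A minimal sketch of the approach, adapted from the custom-interpreter pattern in the JAX docs rather than taken from this PR: an interpreter walks the equations of a jaxpr and re-binds each primitive, and tracing that interpreter with `jax.make_jaxpr` records a fresh jaxpr. `eval_jaxpr` and `f` are illustrative names, and the `Literal` import location varies across JAX versions, hence the fallback.

```python
import jax
import jax.numpy as jnp

try:  # newer JAX exposes Literal under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older JAX releases
    from jax.core import Literal


def eval_jaxpr(jaxpr, consts, *args):
    """Walk a jaxpr's equations and re-bind each primitive.

    When called under jax.make_jaxpr, every bind is recorded again,
    so tracing this interpreter produces a new jaxpr.
    """
    env = {}

    def read(var):
        # Literals carry their value inline; variables are looked up.
        return var.val if isinstance(var, Literal) else env[var]

    for var, val in zip(jaxpr.constvars, consts):
        env[var] = val
    for var, val in zip(jaxpr.invars, args):
        env[var] = val

    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env[var] = val

    return [read(v) for v in jaxpr.outvars]


def f(x):
    return jnp.sin(x) * 2.0 + 1.0


x = jnp.ones(3)
closed = jax.make_jaxpr(f)(x)

# Trace the interpreter itself: shapes and dtypes of the resulting jaxpr
# are inferred by tracing rather than set by hand.
retraced = jax.make_jaxpr(
    lambda *a: eval_jaxpr(closed.jaxpr, closed.consts, *a)
)(x)
print(retraced)
```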

This route has several benefits:

1) We no longer need to manage low-level details, such as the particular shapes, counts, and conversions that had to be controlled manually in the other approach.

2) It follows the standard "jaxpr interpreter" framework explained in the JAX docs. This will probably be the framework we follow for other types of jaxpr interpretation and conversion. If we reuse this pattern across our jaxpr interpretations and transformations, it will lower the number of concepts developers need to learn in order to maintain this code.

https://jax.readthedocs.io/en/latest/notebooks/Writing_custom_interpreters_in_Jax.html
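To make the pattern reusable across several interpreters, one option is a small base class whose subclasses only override how particular primitives are handled. The sketch below assumes a hypothetical `JaxprInterpreter` base class with a `handlers` table keyed by primitive name, plus a toy `SinToCos` subclass; none of these names come from PennyLane.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # newer JAX exposes Literal under jax.extend.core
    from jax.extend.core import Literal
except ImportError:  # older JAX releases
    from jax.core import Literal


class JaxprInterpreter:
    """Hypothetical base class: walk a jaxpr and re-bind each equation."""

    # Hypothetical table: primitive name -> handler(*invals, **params).
    handlers = {}

    def __call__(self, jaxpr, consts, *args):
        env = {}

        def read(var):
            return var.val if isinstance(var, Literal) else env[var]

        for var, val in zip(jaxpr.constvars, consts):
            env[var] = val
        for var, val in zip(jaxpr.invars, args):
            env[var] = val

        for eqn in jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            outvals = self.interpret_eqn(eqn, invals)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            for var, val in zip(eqn.outvars, outvals):
                env[var] = val
        return [read(v) for v in jaxpr.outvars]

    def interpret_eqn(self, eqn, invals):
        # Use a registered handler if there is one, else re-bind unchanged.
        handler = self.handlers.get(eqn.primitive.name)
        if handler is not None:
            return handler(*invals, **eqn.params)
        return eqn.primitive.bind(*invals, **eqn.params)


class SinToCos(JaxprInterpreter):
    """Toy transformation: every `sin` primitive is replaced by `cos`."""

    handlers = {"sin": lambda x, **params: lax.cos(x)}


def f(x):
    return lax.sin(x) + 1.0


x = jnp.zeros(2)
closed = jax.make_jaxpr(f)(x)
converted = jax.make_jaxpr(
    lambda *a: SinToCos()(closed.jaxpr, closed.consts, *a)
)(x)
print(converted)
```

Each new transformation then only states which primitives it changes; the environment bookkeeping and re-tracing stay in one place.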

Description of the Change:

Implements a `to_catalyst(plxpr)(*args)` function that converts plxpr-variant jaxpr into Catalyst-variant jaxpr.
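The `to_catalyst(plxpr)(*args)` call shape suggests a curried function: it takes the source jaxpr, and the returned callable traces the interpreter on concrete arguments to produce the converted jaxpr. The sketch below uses hypothetical names (`to_target`, `CONVERSIONS`, and a stand-in `target_sin` primitive) with a toy mapping from `sin`; it does not reproduce the real plxpr or Catalyst primitives.

```python
import jax
import jax.numpy as jnp
from jax import lax

try:  # newer JAX exposes these under jax.extend.core
    from jax.extend.core import Literal, Primitive
except ImportError:  # older JAX releases
    from jax.core import Literal, Primitive

# Hypothetical "target-variant" primitive standing in for a Catalyst primitive.
target_sin_p = Primitive("target_sin")
# Abstract evaluation only: the output has the same shape/dtype as the input.
target_sin_p.def_abstract_eval(lambda aval: aval)

# Hypothetical conversion table: source primitive name -> target primitive.
CONVERSIONS = {"sin": target_sin_p}


def to_target(jaxpr, consts=()):
    """Curried converter mirroring the to_catalyst(plxpr)(*args) call shape."""

    def interpret(*args):
        env = {}

        def read(var):
            return var.val if isinstance(var, Literal) else env[var]

        env.update(zip(jaxpr.constvars, consts))
        env.update(zip(jaxpr.invars, args))
        for eqn in jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            target = CONVERSIONS.get(eqn.primitive.name)
            if target is not None:
                outvals = target.bind(*invals)  # source params dropped in this toy
            else:
                outvals = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
            env.update(zip(eqn.outvars, outvals))
        return [read(v) for v in jaxpr.outvars]

    # Tracing the interpreter on the given arguments yields the converted jaxpr.
    return lambda *args: jax.make_jaxpr(interpret)(*args)


def f(x):
    return lax.sin(x) * 3.0


x = jnp.ones((2, 2))
source = jax.make_jaxpr(f)(x)
converted = to_target(source.jaxpr, source.consts)(x)
print(converted)
```

Only abstract evaluation is defined for the stand-in primitive, which is enough for `jax.make_jaxpr`; executing the converted jaxpr would additionally require an implementation or lowering rule.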

Benefits:

In the future, we will be able to use plxpr program capture in the Catalyst pipeline.

Possible Drawbacks:

Related GitHub Issues:

[sc-61537]

github-actions[bot] commented 1 week ago

Hello. You may have forgotten to update the changelog! Please edit doc/releases/changelog-dev.md with: