PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Function for setting and resetting global changes during compilation #913

Open erick-xanadu opened 1 month ago

erick-xanadu commented 1 month ago

There are several places where we modify JAX during compilation.

Just to list some:

```python
import jax

# Required for JAX tracer objects as PennyLane wires.
# pylint: disable=unnecessary-lambda
setattr(jax.interpreters.partial_eval.DynamicJaxprTracer, "__hash__", lambda x: id(x))
```

```python
import jax

# This flag cannot be set in ``QJIT.get_mlir()`` because values created before
# that function is called must be consistent with the JAX configuration value.
jax.config.update("jax_enable_x64", True)
```

Patchers (see `jax_extras` and `jax_transient_config`).

We also keep a global context, the `EvaluationContext`, that records whether or not we are running under JAX.

Callbacks break the assumption that, once we start tracing, we never return to the Python environment. We should have a function that saves the configuration before we trace, changes it however we need, restores the original configuration during callbacks, and reapplies our changes once we exit the callback scope.
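A minimal sketch of such a function, assuming every global change can be expressed as an attribute assignment on some object. `GlobalPatch`, `patched` and `unpatched` are hypothetical names, not Catalyst or JAX APIs, and the `Tracer`, `config` and `eval_state` objects are stand-ins for `DynamicJaxprTracer`, `jax.config` and an `EvaluationContext`-style flag. Real JAX flags such as `jax_enable_x64` would have to be read and written through `jax.config` instead of `setattr`.

```python
import types
from contextlib import contextmanager

_MISSING = object()  # sentinel: attribute was not set directly on the target


class GlobalPatch:
    """Hypothetical helper grouping global changes so they are applied/undone together.

    Not reentrant: apply() should not be called while the patch is already active.
    """

    def __init__(self, changes):
        # Each change is a (target, attribute_name, new_value) triple.
        self.changes = list(changes)
        self._saved = []
        self.active = False

    def apply(self):
        # Record the values stored directly on each target (not inherited ones),
        # then install the compile-time values.
        self._saved = [(t, name, vars(t).get(name, _MISSING)) for t, name, _ in self.changes]
        for t, name, value in self.changes:
            setattr(t, name, value)
        self.active = True

    def undo(self):
        # Restore in reverse order; delete attributes that did not exist before.
        for t, name, old in reversed(self._saved):
            if old is _MISSING:
                delattr(t, name)
            else:
                setattr(t, name, old)
        self._saved = []
        self.active = False


@contextmanager
def patched(patch):
    """Apply the patch for the duration of tracing/compilation."""
    patch.apply()
    try:
        yield
    finally:
        patch.undo()


@contextmanager
def unpatched(patch):
    """Temporarily restore the original global state, e.g. while a callback runs."""
    was_active = patch.active
    if was_active:
        patch.undo()
    try:
        yield
    finally:
        if was_active:
            patch.apply()


# Hypothetical stand-ins for DynamicJaxprTracer, jax.config and an EvaluationContext flag.
class Tracer:
    pass


config = types.SimpleNamespace(enable_x64=False)
eval_state = types.SimpleNamespace(tracing=False)

compile_patch = GlobalPatch([
    (Tracer, "__hash__", lambda x: id(x)),
    (config, "enable_x64", True),
    (eval_state, "tracing", True),
])

with patched(compile_patch):
    print(config.enable_x64, eval_state.tracing)      # prints: True True
    with unpatched(compile_patch):
        # Inside a callback the user's original global state is visible again.
        print(config.enable_x64, eval_state.tracing)  # prints: False False
    print(config.enable_x64, eval_state.tracing)      # prints: True True
print(config.enable_x64, eval_state.tracing)          # prints: False False
```

Grouping all changes in one object means the callback machinery only needs to know about a single patch to undo and reapply, rather than each individual modification.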

Note: instead of patching jax.interpreters.partial_eval.DynamicJaxprTracer to add a hash, couldn't we change the PennyLane wire utilities to check whether a wire is a jax.interpreters.partial_eval.DynamicJaxprTracer and use its id as its hash, rather than modifying JAX itself?
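A minimal sketch of that alternative, assuming the wire utilities' hashing can be routed through a single key function. `wire_key`, `unique_wires` and the stand-in tracer class are hypothetical and do not reflect PennyLane's actual `Wires` implementation; the stand-in is made unhashable on the assumption that this is why the patch is needed today.

```python
class DynamicJaxprTracerStandIn:
    """Hypothetical stand-in for jax.interpreters.partial_eval.DynamicJaxprTracer.

    Assumed unhashable, which is what the current hash patch works around.
    """

    __hash__ = None


# In real code this tuple would hold the actual JAX tracer class.
TRACER_TYPES = (DynamicJaxprTracerStandIn,)


def wire_key(wire):
    """Hypothetical helper: a hashable key for a wire label, without patching JAX."""
    if isinstance(wire, TRACER_TYPES):
        # Tracers are keyed by object identity, like the `lambda x: id(x)` patch.
        # The tag keeps a tracer's id from colliding with an integer wire label.
        return ("tracer", id(wire))
    return ("label", wire)


def unique_wires(wires):
    """Drop duplicate wires while preserving order, hashing via wire_key."""
    seen = set()
    result = []
    for wire in wires:
        key = wire_key(wire)
        if key not in seen:
            seen.add(key)
            result.append(wire)
    return result


t = DynamicJaxprTracerStandIn()
wires = unique_wires([0, t, "aux", t, 0])
print(len(wires))  # prints: 3
```

As with the current patch, identity-based keys are only meaningful while the tracer objects are alive, since Python may reuse an id after an object is garbage-collected.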

dime10 commented 1 month ago

What is the bug?

> Note: instead of patching jax.interpreters.partial_eval.DynamicJaxprTracer to add a hash, couldn't we change the PennyLane wire utilities to check whether a wire is a jax.interpreters.partial_eval.DynamicJaxprTracer and use its id as its hash, rather than modifying JAX itself?

Yes, I think that's a good idea! The capture module will probably take care of this.

erick-xanadu commented 1 month ago

> What is the bug?

The bug is that we are changing global state when we shouldn't be :sweat_smile: The concrete error is #894, but that is just a symptom.