PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0
2.27k stars 585 forks

[BUG] JAX JIT errors-out when using a tracer-based PRNG key #6054

Open trbromley opened 1 month ago

trbromley commented 1 month ago

Expected behavior

It should be possible to JAX-JIT a function where qml.device has its seed set to a tracer for the PRNG key.

Actual behavior

An error is raised in many cases. I'm not sure exactly when, but it appears to happen when other tracers are active, such as the param argument below.

There is also an error when using qml.qjit and lightning.qubit, which doesn't expect a JAX tracer for the seed argument.

Additional information

There is a draft PR to fix this problem, but it has gone stale, so I decided it would be more useful to track the problem in this issue.

Source code

import jax
import pennylane as qml

def circuit(key, param):
    dev = qml.device("default.qubit", wires=2, shots=10, seed=key)

    @qml.qnode(dev, interface="jax")
    def my_circuit():
        qml.RX(param, wires=0)  # no error if this line is commented-out
        qml.CNOT(wires=[0, 1])
        return qml.sample(qml.PauliZ(0))
    return my_circuit()

key = jax.random.PRNGKey(1967)
jax.jit(circuit)(key, 0.5)

Tracebacks

XlaRuntimeError: INTERNAL: Generated function failed: CpuCallback error: UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type uint32[2] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was circuit at /tmp/ipykernel_31794/2671401565.py:4 traced for jit.
------------------------------
The leaked intermediate value was created on line /tmp/ipykernel_31794/2671401565.py:15 (<module>). 
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
/tmp/ipykernel_31794/2671401565.py:15 (<module>)
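The traceback above is JAX's generic "leaked tracer" error. The sketch below illustrates that failure mode in plain JAX, without PennyLane. It is only an illustration: whether default.qubit leaks the key in exactly this way is an assumption. A jitted function stores a tracer in global state, and the stored value is then used after tracing has finished.

```python
import jax
import jax.numpy as jnp

# Global state written from inside a jitted function: the kind of side effect
# the UnexpectedTracerError message warns about.
saved_values = []

@jax.jit
def leaky_increment(x):
    y = x + 1
    saved_values.append(y)  # y is a tracer here; saving it lets it escape the trace
    return y

leaky_increment(jnp.array(1.0))

try:
    # Using the saved tracer after tracing has finished raises UnexpectedTracerError.
    saved_values[0] + 1
    leaked_error = None
except jax.errors.UnexpectedTracerError as err:
    leaked_error = err
```

The usual fix for this pattern is to return values from the transformed function instead of saving them to outside state.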

System information

Working on the dev branch of PennyLane with JAX==0.4.23.


josh146 commented 1 month ago

@trbromley do you have more details on this?

> There is also an error when using qml.qjit and lightning.qubit, which doesn't expect a JAX tracer for the seed argument.

This PR might be relevant, which adds a seed kwarg to qjit: https://github.com/PennyLaneAI/catalyst/pull/936

trbromley commented 1 month ago

> @trbromley do you have more details on this?
>
> There is also an error when using qml.qjit and lightning.qubit, which doesn't expect a JAX tracer for the seed argument.

When I run this:

import jax
import pennylane as qml

def circuit(key, param):
    dev = qml.device("lightning.qubit", wires=2, shots=10, seed=key)

    @qml.qnode(dev, interface="jax")
    def my_circuit():
        qml.RX(param, wires=0)
        qml.CNOT(wires=[0, 1])
        return qml.sample(qml.PauliZ(0))
    return my_circuit()

key = jax.random.PRNGKey(1967)

I get this error:

>>> qml.qjit(circuit)(key, 0.5)
TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=1/0)>
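This TypeError is raised by NumPy's np.random.SeedSequence, which needs concrete integers as entropy. The sketch below uses only JAX and NumPy to contrast a concrete key with a traced one. The names key_to_entropy and seed_numpy_from_traced_key are illustrative; they are not PennyLane or lightning.qubit internals.

```python
import jax
import numpy as np

def key_to_entropy(key):
    """Turn a *concrete* JAX PRNG key (a uint32[2] array) into plain Python ints."""
    return [int(v) for v in np.asarray(key).ravel()]

# Outside any trace the key holds real values, so NumPy can use it as entropy.
key = jax.random.PRNGKey(1967)
rng_a = np.random.default_rng(np.random.SeedSequence(key_to_entropy(key)))
rng_b = np.random.default_rng(np.random.SeedSequence(key_to_entropy(key)))
samples_a = rng_a.integers(0, 2, size=10)
samples_b = rng_b.integers(0, 2, size=10)

@jax.jit
def seed_numpy_from_traced_key(key):
    # Inside jit, `key` is a tracer with no concrete values at trace time,
    # so NumPy rejects it with the same kind of TypeError as above.
    np.random.SeedSequence(key)
    return key

try:
    seed_numpy_from_traced_key(key)
    traced_error = None
except TypeError as err:
    traced_error = err
```

This points to one possible workaround, assuming the key does not need to be a traced input: derive plain integers from a concrete key outside the compiled function.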

> This PR might be relevant, which adds a seed kwarg to qjit: PennyLaneAI/catalyst#936

Oh interesting, so the seed should live with qjit rather than the device? 🤔

josh146 commented 1 month ago

> Oh interesting, so the seed should live with qjit rather than the device? 🤔

The PR was motivated by a technical problem: we have flaky tests in the Catalyst codebase, but no qjit-compatible way of specifying a seed. Passing the seed through the qjit decorator is a quick way to get this support in, but we can revisit it from a user's point of view if we need to.
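A seed passed to a decorator is fixed when the function is decorated, so it is always a concrete value rather than a tracer. The sketch below shows that pattern in plain JAX. It is hypothetical: jit_with_seed is invented for illustration and is not the qjit seed keyword from the Catalyst PR.

```python
import jax

def jit_with_seed(seed):
    """Hypothetical decorator: the seed is fixed at decoration time, not passed as a traced argument."""
    if not isinstance(seed, int):
        raise TypeError(f"seed must be a concrete Python int, got {type(seed).__name__}")

    def decorator(fn):
        @jax.jit
        def wrapped(*args):
            # `seed` is a Python int captured in the closure, so this key is derived
            # from a constant baked into the compiled function, not from a traced input.
            key = jax.random.PRNGKey(seed)
            return fn(key, *args)
        return wrapped

    return decorator

@jit_with_seed(1967)
def noisy_shift(key, x):
    # Add one standard-normal sample drawn with the fixed key.
    return x + jax.random.normal(key, ())
```

Every call reuses the same key, so the results repeat from call to call. A real implementation would need its own policy for advancing the key.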

josh146 commented 1 month ago

> I get this error:

Ah yes, I guess this is to be expected, since JAX seeds are not supported by MLIR/LLVM.