PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Drawing the circuit with catalyst.measure fails #492

Open ankit27kh opened 7 months ago

ankit27kh commented 7 months ago

Issue description

Description of the issue: circuit drawing fails when using a mid-circuit measurement (catalyst.measure).

Source code and tracebacks

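import pennylane as qml
from catalyst import qjit, measure

@qjit
@qml.qnode(qml.device("lightning.qubit", wires=2, shots=100))
def circuit(theta):
    qml.Hadamard(wires=0)
    qml.RX(theta, wires=1)
    qml.CNOT(wires=[0, 1])
    m = measure(wires=0)
    return qml.expval(qml.PauliZ(wires=1)), m

# Drawing the compiled QNode raises the CompileError shown below
print(qml.draw(circuit)(1.2))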

Error: catalyst.utils.exceptions.CompileError: catalyst.measure can only be used from within @qjit.

josh146 commented 7 months ago

Hi @ankit27kh, unfortunately this is a known error when drawing QNodes with Catalyst mid-circuit measurements, and something we are working to fix! In the meantime, swapping to the PennyLane mid-circuit measurement qml.measure() will allow you to draw the circuit (without the @qjit decorator).

ankit27kh commented 7 months ago

Hi, @josh146. Yes, I found this in the sharp bits later; I had missed it earlier. Can you also elaborate a bit on this part [screenshot not reproduced]? I don't understand what is being said. Also, I am currently trying to switch my use case from JAX-JIT to Catalyst, but I'm having some issues. One of them is incompatibility with jax.vmap. From my quick testing:

f1 = jax.vmap(catalyst.qjit(cost), in_axes=[None, 0]) # device is lightning.qubit, None shots
f2 = jax.jit(jax.vmap(cost, in_axes=[None, 0])) # device is default.qubit, None shots, jax interface

f2 is much faster to execute than f1 (although it takes longer to compile). f1 took 200 times longer to run for the simple example. Is there any way to speed this up?

dime10 commented 7 months ago

Hi @ankit27kh, the section you are quoting attempts to explain that catalyst.measure has an immediate effect on the state vector of a simulator when the operation is being processed. Specifically, it will draw a random sample according to the probability distribution of the given qubit, and then project the quantum state onto the post-measurement state consistent with that randomly selected outcome. So far this is not surprising, as it is consistent with how quantum measurements are explained in textbooks. The tricky bit is when the mid-circuit measurement is combined with a PennyLane measurement process like qml.expval or qml.probs. Consider the following example:

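import pennylane as qml
import catalyst
@catalyst.qjit
@qml.qnode(qml.device("lightning.qubit", wires=1))
def circuit():
    qml.Hadamard(0)
    m = catalyst.measure(0)
    return qml.expval(qml.PauliZ(0))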

In Catalyst, each time you run this circuit the value of m will be randomly chosen to be 0 or 1 (with 50% probability). If m=0, then the state is collapsed to (1., 0.) and the expectation value computed subsequently will be +1. If m=1, then the state is collapsed to (0., 1.) and the expectation value at the end will be -1. Thus the final measurement processes are no longer deterministic but instead reflect the quantum state for a particular measurement outcome.

In PennyLane, on the other hand, the simulator averages over the possible measurement outcomes, and qml.expval(qml.PauliZ(0)) would produce 0 instead (halfway between +1 and -1).
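The contrast between the two behaviours for this Hadamard-then-measure circuit can be reproduced in a few lines of plain NumPy. This is a toy model that assumes a single qubit and noiseless analytic simulation. The function names are hypothetical, and neither function is Catalyst or PennyLane code.

```python
import numpy as np

Z = np.diag([1.0, -1.0])                  # Pauli-Z observable
plus = np.array([1.0, 1.0]) / np.sqrt(2)  # state after qml.Hadamard(0) on |0>


def single_run_expval(rng):
    """Catalyst-style (as described above): sample m, collapse, then take <Z>."""
    m = int(rng.random() < abs(plus[1]) ** 2)  # P(m = 1) = 0.5
    collapsed = np.eye(2)[m]                    # (1, 0) if m == 0, (0, 1) if m == 1
    return m, float(collapsed @ Z @ collapsed)


def averaged_expval():
    """PennyLane-style: weight each branch's <Z> by its probability and sum."""
    total = 0.0
    for m in (0, 1):
        branch = np.eye(2)[m]
        total += abs(plus[m]) ** 2 * (branch @ Z @ branch)
    return float(total)


rng = np.random.default_rng(seed=1)
print([single_run_expval(rng) for _ in range(4)])  # each entry is (0, 1.0) or (1, -1.0)
print(averaged_expval())                            # approximately 0.0
```

Averaging many calls of `single_run_expval` recovers the PennyLane value, but each individual call only ever returns +1 or -1.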

The behaviour Catalyst exhibits for measure is best suited to algorithms that deal with measurement values directly or that can be thought of as "single execution algorithms" (without expectation values), such as Shor's algorithm. It also allows you to run arbitrary functions on measurement results or to condition classical code on the measurement value, because the actual value of the measurement is immediately available during the simulation.
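As a rough illustration of "conditioning classical code on the measurement value", here is a plain-NumPy model of an active reset: measure a qubit, and if the outcome is 1, flip it back to |0>. In real Catalyst code this branching would live inside the qjit-compiled function (for example via `catalyst.cond`). This sketch only models the idea, and `rx_state` and `measure_and_reset` are hypothetical helpers.

```python
import numpy as np

X = np.array([[0, 1], [1, 0]], dtype=complex)  # Pauli-X


def rx_state(theta):
    """RX(theta)|0> = cos(theta/2)|0> - i sin(theta/2)|1>."""
    return np.array([np.cos(theta / 2), -1j * np.sin(theta / 2)])


def measure_and_reset(state, rng):
    """Measure a single qubit, then branch on the outcome in ordinary Python."""
    m = int(rng.random() < abs(state[1]) ** 2)
    # Post-measurement state: keep only the measured amplitude, renormalized.
    collapsed = np.zeros(2, dtype=complex)
    collapsed[m] = state[m] / abs(state[m])
    if m == 1:  # classical control flow conditioned on the measured value
        collapsed = X @ collapsed
    return m, collapsed


rng = np.random.default_rng(seed=7)
m, final = measure_and_reset(rx_state(1.2), rng)
print(m, np.round(np.abs(final) ** 2, 6))  # final state is |0> whichever m was sampled
```

The branch only makes sense because `m` is a concrete value at that point of the simulation, which is the property described above.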

josh146 commented 7 months ago

Nice explanation @dime10 :)

To cover your other question @ankit27kh, we are actually in the process of implementing catalyst.vmap in #497, so will take this feedback into account! Once done, we should have a catalyst-native approach you can use.

ankit27kh commented 7 months ago

> In Catalyst, each time you run this circuit the value of m will be randomly chosen to be 0 or 1 (with 50% probability). If m=0, then the state is collapsed to (1., 0.) and the expectation value computed subsequently will be +1. If m=1, then the state is collapsed to (0., 1.) and the expectation value at the end will be -1. Thus the final measurement processes are no longer deterministic but instead reflect the quantum state for a particular measurement outcome.

Hi @dime10, when you say 'each time you run this circuit', do you mean regardless of shots? I measured samples from the circuit you provided, and they always have the same value for all shots. But aren't shots supposed to be separate executions? I would expect to get a different result for all shots. Mid-circuit measurement and possible postselection have other uses too, other than just 'single execution algorithms'. This will become a handicap in the current implementation.

dime10 commented 7 months ago

> Hi @dime10, when you say 'each time you run this circuit', do you mean regardless of shots?

Ah apologies, I should have been more specific. What I meant was "each time you call the compiled function".

When running on simulators, "shots" are merely simulated by drawing samples from the probability distribution of the final statevector, because repeating the circuit simulation for each shot would take a very long time. But I understand your concern: the behaviour of the shots feature should match what one would get on hardware.
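To see why every shot came back with the same value, here is a plain-NumPy model of "simulate the circuit once, then sample the final statevector" for a Hadamard followed by a mid-circuit measurement. It is a simplified stand-in under that assumption, not Catalyst's simulator, and the function name is hypothetical.

```python
import numpy as np


def run_once_then_sample(shots, rng):
    """Simulate Hadamard + mid-circuit measurement ONCE, then draw every shot
    from that single post-measurement statevector."""
    plus = np.array([1.0, 1.0]) / np.sqrt(2)
    m = int(rng.random() < abs(plus[1]) ** 2)  # a single mid-circuit outcome
    final = np.eye(2)[m]                        # collapsed state (1, 0) or (0, 1)
    probs = np.abs(final) ** 2
    return rng.choice(2, size=shots, p=probs)   # all shots agree with m


samples = run_once_then_sample(100, np.random.default_rng(seed=3))
print(np.unique(samples))  # a single value: every shot saw the same collapsed state
```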

Right now, it is best to think of the samples function as just that: drawing samples from the current quantum state's probability distribution. We are working hard on a better approach that makes the interaction of mid-circuit measurements with shots as intuitive as it is on hardware, while still being performant.

In the meantime, if you want to force the repeated execution of the circuit, you can modify your function to use a for loop around the QNode and collect shot results that way (the QNode should now return a single shot result only). Let me know if you want to know more about this and I can provide some examples.
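The statistics of that loop can be sketched in plain NumPy. Each iteration re-runs the whole circuit and returns one single-shot result, so the mid-circuit outcome is redrawn per shot. In real Catalyst code the loop would presumably wrap the QNode inside the @qjit function, as suggested above. Here `single_shot` is a hypothetical stand-in for such a single-shot QNode, and `repeated_execution` is likewise illustrative.

```python
import numpy as np


def single_shot(rng):
    """Hypothetical single-shot circuit: Hadamard, mid-circuit measurement,
    then one Z-basis sample of the collapsed state."""
    plus = np.array([1.0, 1.0]) / np.sqrt(2)
    m = int(rng.random() < abs(plus[1]) ** 2)
    final = np.eye(2)[m]
    sample = int(rng.choice(2, p=np.abs(final) ** 2))
    return m, sample


def repeated_execution(shots, seed=0):
    """Re-run the whole circuit for every shot so that each shot gets its own
    mid-circuit outcome, as it would on hardware."""
    rng = np.random.default_rng(seed)
    return np.array([single_shot(rng) for _ in range(shots)])


results = repeated_execution(1000)
print(results[:, 0].mean())  # fraction of shots with m == 1, close to 0.5
```

The trade-off is the one mentioned above: the circuit is simulated `shots` times instead of once.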

> Mid-circuit measurement and possible post-selection have other uses too, other than just 'single execution algorithms'. This will become a handicap in the current implementation.

Out of curiosity, do you have some specific use cases or examples in mind? These can be really helpful in guiding development.

dime10 commented 7 months ago

Hi @ankit27kh, could you share the cost function that you measured here?

> From my quick testing:
>
> f1 = jax.vmap(catalyst.qjit(cost), in_axes=[None, 0]) # device is lightning.qubit, None shots
> f2 = jax.jit(jax.vmap(cost, in_axes=[None, 0])) # device is default.qubit, None shots, jax interface
>
> f2 is much faster to execute than f1 (it takes longer to compile though). f1 took 200 times more to run for the simple example. Any way to speed this up?

ankit27kh commented 7 months ago

Hi @dime10,

> Out of curiosity, do you have some specific use cases or examples in mind?

I was thinking of variational QML circuits, where you measure some qubits for your output, postselected on the remaining qubits. The output from the measured qubits will depend on the probability distribution. But in the current implementation, if the distribution simply collapses to one outcome, it won't be possible to get the correct outputs for training the circuit.
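For this kind of training loop, the quantity needed is the conditional distribution P(output | ancilla = 0), not a single collapsed outcome. The sketch below computes that conditional distribution directly from a statevector in plain NumPy. The helper name and the toy state are hypothetical, and this is neither a Catalyst nor a PennyLane API.

```python
import numpy as np


def postselected_probs(state, n_wires, post_wire, post_value, out_wire):
    """Distribution of `out_wire` conditioned on `post_wire` being measured as
    `post_value` (wire 0 is the most significant bit).
    Returns (conditional_probs, acceptance_probability)."""
    joint = np.abs(state.reshape([2] * n_wires)) ** 2
    # Keep only the branch consistent with the postselected outcome.
    kept = np.take(joint, post_value, axis=post_wire)
    acceptance = kept.sum()
    # Removing post_wire shifts later axes down by one.
    out_axis = out_wire if out_wire < post_wire else out_wire - 1
    other_axes = tuple(a for a in range(kept.ndim) if a != out_axis)
    marginal = kept.sum(axis=other_axes) if other_axes else kept
    return marginal / acceptance, acceptance


# Toy 2-qubit state: wire 0 = output qubit, wire 1 = postselected ancilla.
state = np.array([0.6, 0.0, 0.48, 0.64])  # amplitudes of |00>, |01>, |10>, |11>
probs, acceptance = postselected_probs(state, 2, post_wire=1, post_value=0, out_wire=0)
print(probs, acceptance)  # roughly [0.61 0.39] and 0.59
```

`acceptance` is the fraction of shots that would survive postselection. Estimating `probs` from shots therefore requires each shot to be an independent execution.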

> could you share the cost function that you measured here?

It was simply this circuit:

import pennylane as qml

def cost(params, x):
    for l in range(3):
        qml.AngleEmbedding(x, wires=range(4))
        qml.StronglyEntanglingLayers(params[l], wires=range(4))
    return qml.expval(qml.PauliZ(0))
josh146 commented 7 months ago

Thanks @ankit27kh!