PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[Bug] `qml.StatePrep` decomposes unnecessarily for Lightning, and the decomposition is very slow #939

Closed josh146 closed 2 weeks ago

josh146 commented 1 month ago

There are two potential issues here, both associated with qml.StatePrep.

qml.StatePrep should not be decomposed when used with Lightning.

Lightning natively supports preparing a quantum state directly from an array via qml.StatePrep. However, Catalyst decomposes qml.StatePrep using the qml.MottonenStatePreparation decomposition.

This can be seen by viewing the JAXPR of the following circuit:

import pennylane as qml
from jax import numpy as jnp

dev = qml.device("lightning.qubit", wires=4)

@qml.qjit
@qml.qnode(dev)
def f(psi):
    qml.StatePrep(psi, wires=[0, 1, 2, 3])
    return qml.expval(qml.PauliX(0))

>>> psi = jnp.ones(16) / 4
>>> f(psi)
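
For reference, the traced program can then be inspected directly (a sketch, assuming the jaxpr property that qjit-compiled functions expose once the function has been traced):

>>> print(f.jaxpr)                  # the full traced program
>>> len(str(f.jaxpr).splitlines())  # rough size of the JAXPR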

qml.StatePrep when decomposed results in extremely long compilation time.

When qml.StatePrep is decomposed, the compilation time is excessively long and the generated JAXPR is surprisingly large. In the example above, the JAXPR has over a thousand lines.

Modifying the example to take place on 10 wires,

dev = qml.device("lightning.qubit", wires=10)

@qml.qjit
@qml.qnode(dev)
def f(psi):
    qml.StatePrep(psi, wires=range(10))
    return qml.expval(qml.PauliX(0))

>>> psi = jnp.ones(2**10) / 2**5  # normalize: jnp.ones(1024) has norm 32
>>> f(psi)

this example now takes over 1.5 minutes to compile, and generates a JAXPR of size 48206.
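
A quick way to reproduce the measurement (a sketch; timings vary by machine, and the key point is that the first call triggers tracing and compilation, while later calls reuse the compiled binary):

import time

t0 = time.time()
f(psi)  # first call: traces the circuit and compiles it
print(f"compile + first run: {time.time() - t0:.1f}s")

t0 = time.time()
f(psi)  # second call: reuses the already-compiled function
print(f"cached run: {time.time() - t0:.3f}s")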

dime10 commented 1 month ago

Is this a bug, or a missing feature to be added? The compile times are indeed really bad, so we should definitely add the native state-prep op!

josh146 commented 1 month ago

I would say:

> qml.StatePrep should not be decomposed when used with Lightning.

is a missing feature, while

> qml.StatePrep when decomposed results in extremely long compilation time.

is a performance bug? I don't think this is necessarily a bug in Catalyst; it is more likely a performance issue in how the Mottonen decomposition logic is written and traced (I see hundreds of lines in the JAXPR which are just reshaping/broadcasting(?)).

dime10 commented 1 month ago

> is a performance bug? I don't think this is necessarily a bug in Catalyst; it is more likely a performance issue in how the Mottonen decomposition logic is written and traced (I see hundreds of lines in the JAXPR which are just reshaping/broadcasting(?)).

Unfortunately this seems to be the main performance bottleneck for compilation in Catalyst (and JAX, for that matter): the JAX NumPy code generates a very large number of instructions for only small computations, often with seemingly redundant code or with "non-computational" instructions far outweighing the "computational" ones. As a result, the program representation blows up. I'm not sure what a solution to this might look like, however 🤔
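
To illustrate the mechanism with a minimal sketch (unrelated to the actual Mottonen code): any Python-level loop in a decomposition is fully unrolled at trace time, so every iteration contributes its own reshape/broadcast equations to the jaxpr:

import jax
import jax.numpy as jnp

def unrolled(x):
    # the Python loop is unrolled during tracing: every iteration adds its
    # own reshape/sum/broadcast equations to the jaxpr
    for _ in range(10):
        x = jnp.reshape(x, (-1, 2)).sum(axis=1)
        x = jnp.broadcast_to(x, (2, x.shape[0])).reshape(-1)
    return x

jaxpr = jax.make_jaxpr(unrolled)(jnp.ones(16))
print(len(jaxpr.jaxpr.eqns))  # equation count grows linearly with the loop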

paul0403 commented 1 month ago

Just adding a note here: while it is still very far away in the future, the decomposition should still happen on quantum hardware, and preparing a state directly from a complex array should only be possible on simulators.

josh146 commented 1 month ago

@paul0403 yep exactly :+1: But this is less far in the future and possible right now: e.g., if you are using Catalyst with the AWS or OQC backends, the state-prep decomposition should still take place.

paul0403 commented 1 month ago

> > is a performance bug? I don't think this is necessarily a bug in Catalyst; it is more likely a performance issue in how the Mottonen decomposition logic is written and traced (I see hundreds of lines in the JAXPR which are just reshaping/broadcasting(?)).
>
> Unfortunately this seems to be the main performance bottleneck for compilation in Catalyst (and JAX, for that matter): the JAX NumPy code generates a very large number of instructions for only small computations, often with seemingly redundant code or with "non-computational" instructions far outweighing the "computational" ones. As a result, the program representation blows up.
>
> I'm not sure what a solution to this might look like, however 🤔

One way I see is to perform the Mottonen state-prep decomposition at the MLIR level as a pass, rather than decomposing in Python and tracing it. This way we circumvent all the time wasted on non-computational instructions in JAX, and since it's just another MLIR pass, the time taken by the transformation should be small. This seems like a good one to have in the quantum peephole library.

The question is: do we reimplement the entire Mottonen algorithm from scratch, or do we somehow use the existing one in PL but only query the resulting decomposition instead of tracing it? (E.g. when we see a state prep during tracing, don't actually add anything to the jaxpr, but instead "exec()" an independent circuit that calls qml.MottonenStatePreparation outside qjit, keep the result as a string or some other lightweight data type in the jaxpr, and propagate it down to MLIR for the decomposition.) I would argue the second way is cleaner and safer, since it should give the same decomposition as core PL. A rough sketch of what I mean is below.
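
(In this sketch, only qml.MottonenStatePreparation and its decomposition() method are existing PL pieces; everything else is hypothetical glue, and it assumes the state is concrete data rather than a traced argument:)

import numpy as np
import pennylane as qml

# run the PL decomposition eagerly, outside any qjit trace, on concrete data
psi = np.ones(16) / 4
ops = qml.MottonenStatePreparation(psi, wires=range(4)).decomposition()

# each op is a plain RY/RZ/CNOT with concrete parameters; the list could be
# serialized into a lightweight payload and lowered to MLIR directly, instead
# of tracing the whole algorithm through JAX
for op in ops[:5]:
    print(op)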

Note that I think this speedup feature is good to have even with #955, since compile-time decomposition would still happen for OpenQASM and OQC, and for whatever future non-simulator partners we add.

dime10 commented 1 month ago

> The question is: do we reimplement the entire Mottonen algorithm from scratch, or do we somehow use the existing one in PL but only query the resulting decomposition instead of tracing it? (E.g. when we see a state prep during tracing, don't actually add anything to the jaxpr, but instead "exec()" an independent circuit that calls qml.MottonenStatePreparation outside qjit, keep the result as a string or some other lightweight data type in the jaxpr, and propagate it down to MLIR for the decomposition.) I would argue the second way is cleaner and safer, since it should give the same decomposition as core PL.

This doesn't work if the state prep is dependent on the function arguments.

But it's true that, in general, we might want to "counteract" JAX's omnistaging by running most computations that have constant input values during tracing (i.e. constant folding at trace time), which might cut down compile time significantly. But this can vary on a case-by-case basis (sometimes doing a computation in Python might be slower than compiling it and executing it natively).
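
A minimal sketch of that idea (fold_if_constant is a hypothetical helper; the tracer check is the standard JAX one):

import numpy as np
import jax
import jax.numpy as jnp

def fold_if_constant(x):
    # hypothetical trace-time constant folding: a tracer must be staged into
    # the jaxpr, but a concrete value can be computed eagerly in Python and
    # contributes no instructions at all
    if isinstance(x, jax.core.Tracer):
        return jnp.arccos(x)  # dynamic input: stage out as usual
    return np.arccos(x)       # constant input: folded at trace time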