PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

`jax.jit` compilation extremely slow when grouping qubit-wise-commuting observables #3552

Open Qottmann opened 1 year ago

Qottmann commented 1 year ago

I noticed that the following code snippet takes a very long time to compile when qubit-wise-commuting (QWC) grouping is enabled for the Hamiltonian in the expectation value computation. I aborted the execution after 10 minutes.

The problem (or at least the main one) seems to be setting grouping_type="qwc" in the Hamiltonian. Without it, compilation finishes after roughly 40 seconds on my laptop.

Something in that pipeline must cause serious problems. I leave this issue as a reminder for us to be aware of this.

import pennylane as qml
import pennylane.numpy as np
import jax
import jax.numpy as jnp
import datetime

symbols = ["H", "O", "H"]
coordinates = np.array([-0.0399, -0.0038, 0.0, 1.5780, 0.8540, 0.0, 2.7909, -0.5159, 0.0])

basis_set = "sto-3g"
H, n_wires = qml.qchem.molecular_hamiltonian(
    symbols,
    coordinates,
    charge=0,
    mult=1,
    basis=basis_set,
    active_electrons=4,
    active_orbitals=4,
    mapping="bravyi_kitaev",
    method="pyscf",
)

coeffs, obs = H.coeffs, H.ops
H_obj = qml.Hamiltonian(jnp.array(coeffs), obs, grouping_type="qwc")

singles, doubles = qml.qchem.excitations(electrons=4, orbitals=n_wires)
hf = qml.qchem.hf_state(4, n_wires)
default_qubit = qml.device("default.qubit", wires=range(n_wires))

@jax.jit
@qml.qnode(default_qubit, interface="jax")
def qnode_gate(theta):
    qml.AllSinglesDoubles(
        weights=theta,
        wires=range(n_wires),
        hf_state=hf,
        singles=singles,
        doubles=doubles,
    )
    return qml.expval(H_obj)

params = np.random.rand(26, requires_grad=True)
params = jnp.array(params)

value_and_grad = jax.jit(jax.value_and_grad(qnode_gate, argnums=0))
time0 = datetime.datetime.now()
_ = value_and_grad(params)
time1 = datetime.datetime.now()
print(f"grad and val compilation time: {time1 - time0}")

rmoyard commented 1 year ago

From my tests: default.qubit.jax does not work well when jitting the gradient with backprop if the batch execution contains a large number of tapes. More exploration is needed to understand why compiling the gradient with batch execution is so slow.

albi3ro commented 1 year ago

This was actually something I just had to deal with for AWS recently too.

Basically, if the grouping indices are known for a Hamiltonian, the device will expand it with hamiltonian_expand, whether or not the device has finite shots: https://github.com/PennyLaneAI/pennylane/blob/359130c783b48fea04bac38c5602df2cc3aa401d/pennylane/_device.py#L746

One way to turn this off is by setting the optional device property dev.use_grouping=False.

The idea was that if a user went through the effort of computing the grouping indices, they wanted to use them.
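For readers unfamiliar with what the grouping indices encode, here is a minimal, library-free sketch of qubit-wise-commuting grouping. It assumes observables are written as plain Pauli strings (e.g. "XIZ") and uses a simple first-fit strategy. The helper names `qubit_wise_commute` and `group_qwc` are hypothetical and are not PennyLane functions; PennyLane's own grouping may use different heuristics and may produce different groups.

```python
def qubit_wise_commute(p, q):
    """Two Pauli strings qubit-wise commute if, on every qubit,
    the letters are equal or at least one of them is the identity."""
    return all(a == b or a == "I" or b == "I" for a, b in zip(p, q))


def group_qwc(pauli_words):
    """First-fit grouping (an illustrative assumption, not PennyLane's algorithm).

    Returns a list of groups, each a list of indices into pauli_words,
    such that all words within a group pairwise qubit-wise commute."""
    groups = []
    for idx, word in enumerate(pauli_words):
        for group in groups:
            if all(qubit_wise_commute(word, pauli_words[j]) for j in group):
                group.append(idx)
                break
        else:
            # No existing group is compatible, so start a new one.
            groups.append([idx])
    return groups


words = ["ZI", "IZ", "ZZ", "XI", "IX", "XX"]
grouping_indices = group_qwc(words)
print(grouping_indices)  # [[0, 1, 2], [3, 4, 5]]
```

Each group can be measured with a single circuit, so six terms here need only two executions. With grouping enabled, every group becomes its own tape in the batch, which is the part that gets expanded by the device.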

josh146 commented 1 year ago

> The idea was that if a user went through the effort of computing the grouping indices, they wanted to use them.

Part of the problem was that we had no UI for:

If we have solutions for these, we can remove this assumption in the logic.

trbromley commented 3 months ago

Testing this after a few years - things are still a bit slow :thinking:

But now, it takes my laptop ~3 minutes to compile with both grouping_type=None and grouping_type="qwc". I'm tempted to say that this issue can be resolved, since previously the process had to be aborted after more than 10 minutes.

albi3ro commented 3 months ago

Note that @EmilianoG-byte is currently working on speeding up this performance bottleneck. Hopefully we have substantially better numbers in the next couple of days 🎉

trbromley commented 3 months ago

Awesome! Is it this PR: https://github.com/PennyLaneAI/pennylane/pull/6043

EmilianoG-byte commented 3 months ago

Hi! You can see from the graphs in my PR's description that the qwc commutation pipeline could indeed use a lot of improvement, since it was orders of magnitude slower than the other two grouping types. As you can see from the graph, we expect this to change by taking advantage of the symplectic representation :)
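As a rough illustration of what a symplectic representation looks like (the PR's actual implementation is not shown in this thread, so this is only a sketch under stated assumptions): each Pauli string on n qubits maps to two bit vectors x and z, with X = (1, 0), Z = (0, 1), Y = (1, 1) and I = (0, 0). Commutation checks then reduce to binary arithmetic. The function names below are hypothetical, not PennyLane API.

```python
# Per-letter (x, z) bits of the symplectic representation.
_BITS = {"I": (0, 0), "X": (1, 0), "Y": (1, 1), "Z": (0, 1)}


def to_symplectic(word):
    """Return the (x, z) bit lists for a Pauli string such as "XYZI"."""
    x = [_BITS[p][0] for p in word]
    z = [_BITS[p][1] for p in word]
    return x, z


def per_qubit_products(a, b):
    """Symplectic product evaluated separately on each qubit (mod 2).

    An entry of 1 means the two single-qubit Paulis anticommute there."""
    (xa, za), (xb, zb) = to_symplectic(a), to_symplectic(b)
    return [(xa[i] * zb[i] + za[i] * xb[i]) % 2 for i in range(len(a))]


def commute(a, b):
    """Full commutation: an even number of anticommuting positions."""
    return sum(per_qubit_products(a, b)) % 2 == 0


def qubit_wise_commute(a, b):
    """Qubit-wise commutation: no anticommuting position at all."""
    return not any(per_qubit_products(a, b))


print(commute("XX", "ZZ"), qubit_wise_commute("XX", "ZZ"))  # True False
print(commute("XI", "IZ"), qubit_wise_commute("XI", "IZ"))  # True True
```

Because the checks are pure bit operations, they can be applied to a whole matrix of observables at once with array arithmetic, which is presumably where the expected speed-up over pairwise operator comparisons comes from.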