PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] `jax.grad` + `jax.jit` does not work with `AmplitudeEmbedding` and finite shots #5541

Open KetpuntoG opened 2 months ago

KetpuntoG commented 2 months ago

Expected behavior

`qml.AmplitudeEmbedding` should work with `jax.jit`, `jax.grad`, and finite shots:

from pennylane import numpy as np
import pennylane as qml
import jax

# with shots=None it works
dev = qml.device("default.qubit", wires=4, shots=100)

@qml.qnode(dev)
def circuit(coeffs):
    qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])

jac_fn = jax.jacobian(circuit)
# without jax.jit it works
jac_fn = jax.jit(jac_fn)

jac = jac_fn(params)
print(jac)

Actual behavior

ValueError: need at least one array to stack

Additional information

The same issue occurs with `qml.StatePrep` and `qml.MottonenStatePreparation`.

Source code

No response

Tracebacks

No response

System information

Name: PennyLane
Version: 0.36.0.dev0
Summary: PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
Home-page: https://github.com/PennyLaneAI/pennylane
Author: 
Author-email: 
License: Apache License 2.0
Location: /usr/local/lib/python3.10/dist-packages
Requires: appdirs, autograd, autoray, cachetools, networkx, numpy, pennylane-lightning, requests, rustworkx, scipy, semantic-version, toml, typing-extensions
Required-by: PennyLane_Lightning

Platform info:           Linux-6.1.58+-x86_64-with-glibc2.35
Python version:          3.10.12
Numpy version:           1.25.2
Scipy version:           1.11.4
Installed devices:
- default.clifford (PennyLane-0.36.0.dev0)
- default.gaussian (PennyLane-0.36.0.dev0)
- default.mixed (PennyLane-0.36.0.dev0)
- default.qubit (PennyLane-0.36.0.dev0)
- default.qubit.autograd (PennyLane-0.36.0.dev0)
- default.qubit.jax (PennyLane-0.36.0.dev0)
- default.qubit.legacy (PennyLane-0.36.0.dev0)
- default.qubit.tf (PennyLane-0.36.0.dev0)
- default.qubit.torch (PennyLane-0.36.0.dev0)
- default.qutrit (PennyLane-0.36.0.dev0)
- null.qubit (PennyLane-0.36.0.dev0)
- lightning.qubit (PennyLane_Lightning-0.35.1)

Existing GitHub issues

albi3ro commented 2 months ago

The issue seems to be with taking the parameter-shift gradient of a GlobalPhase.

This gives the exact same error:

import numpy as np
import pennylane as qml
import jax

# with shots=None it works
dev = qml.device("default.qubit", wires=4, shots=100)

@qml.qnode(dev)
def circuit(phase):
    qml.GlobalPhase(phase)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

params = jax.numpy.array(0.4)

jac_fn = jax.jacobian(circuit)
# without jax.jit it works
jac_fn = jax.jit(jac_fn)

jac = jac_fn(params)
print(jac)

We potentially need a custom gradient recipe, or we could simply recognize that the gradient of an expectation value with respect to a global phase is always zero.
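A minimal numpy check (a stand-in for the device, not PennyLane code) of why that gradient is always zero: a global phase `exp(1j * phi)` cancels in every expectation value, so `<H>` does not depend on `phi` at all.

```python
# A global phase leaves every expectation value unchanged:
# <psi| e^{-i phi} H e^{i phi} |psi> = <psi| H |psi>, so d<H>/dphi = 0.
import numpy as np

rng = np.random.default_rng(0)
psi = rng.normal(size=4) + 1j * rng.normal(size=4)
psi /= np.linalg.norm(psi)  # normalized random 2-qubit state

# Hermitian observable, here Z (x) Z as in the reproducer
Z = np.diag([1.0, -1.0])
H = np.kron(Z, Z)

def expval(phi):
    phased = np.exp(1j * phi) * psi  # apply the global phase
    return np.real(phased.conj() @ H @ phased)

# The value is independent of phi, so any (finite-difference) gradient is 0
vals = [expval(phi) for phi in (0.0, 0.4, 1.3)]
assert np.allclose(vals, vals[0])
```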

Tarun-Kumar07 commented 2 months ago

Hey @albi3ro and @KetpuntoG, I was looking at issues to contribute to. Can I tackle this issue?

albi3ro commented 2 months ago

Just setting `GlobalPhase.grad_method = None` seems to work 🤞. That would basically just indicate "no differentiable parameters here". Should be a simple enough fix if you're interested.
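A hypothetical sketch (not PennyLane's actual internals) of what `grad_method = None` signals: a gradient transform can use it to decide which operations get shift rules, and `None` means the operation contributes no gradient tapes.

```python
# Hypothetical illustration only: grad_method="A" means a parameter-shift
# rule applies; grad_method=None means "skip, nothing to differentiate".
ops = [
    {"name": "RX", "grad_method": "A"},
    {"name": "GlobalPhase", "grad_method": None},
    {"name": "RY", "grad_method": "A"},
]

# Only operations with a gradient method would generate shifted tapes
shiftable = [op["name"] for op in ops if op["grad_method"] is not None]
```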

KetpuntoG commented 2 months ago

In the operator I'm working on, I differentiate with respect to the global phase parameter. Would that stop working, @albi3ro?

albi3ro commented 2 months ago

Using a controlled global phase or something?

KetpuntoG commented 2 months ago

Correct 👌 It seems strange, but it appears naturally in the formulation.

albi3ro commented 2 months ago

We could potentially update Controlled.grad_method from:

    @property
    def grad_method(self):
        return self.base.grad_method

to:

    @property
    def grad_method(self):
        return "A" if self.base.name == "GlobalPhase" else self.base.grad_method
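
Unlike a bare global phase, a controlled global phase is a genuine relative phase on the subspace where the control is |1⟩, so expectation values can depend on it and a parameter-shift rule ("A") applies. A small numpy check (a stand-in matrix calculation, not PennyLane code):

```python
# Controlled global phase on 2 qubits: identity when the control is |0>,
# exp(1j*phi) * I when the control is |1> -- a *relative* phase overall.
import numpy as np

def c_globalphase(phi):
    return np.diag([1, 1, np.exp(1j * phi), np.exp(1j * phi)])

# Put the control in superposition: |+>|0> = (|00> + |10>) / sqrt(2)
plus = np.array([1.0, 1.0]) / np.sqrt(2)
zero = np.array([1.0, 0.0])
psi = np.kron(plus, zero)

X = np.array([[0.0, 1.0], [1.0, 0.0]])
H = np.kron(X, np.eye(2))  # X on the control qubit

def expval(phi):
    phased = c_globalphase(phi) @ psi
    return np.real(phased.conj() @ H @ phased)

# <X (x) I> = cos(phi): it depends on phi, so the gradient is nonzero
assert np.isclose(expval(0.0), 1.0)
assert np.isclose(expval(np.pi), -1.0)
```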

Tarun-Kumar07 commented 2 months ago

Hey @albi3ro and @KetpuntoG, I want to work on this. Can you please assign it to me?

albi3ro commented 2 months ago

Assigned. Note that our next release is in two weeks on May 6th, so we may take it over next week to make sure we get the fix in.

Tarun-Kumar07 commented 2 months ago

Hey @albi3ro, I am unsure how to add tests for this. The test below fails because the two values are far apart:

@pytest.mark.jax
def test_jacobian_with_and_without_jit_has_same_output():
    import jax

    dev = qml.device("default.qubit", wires=4, shots=100)

    @qml.qnode(dev)
    def circuit(coeffs):
        qml.AmplitudeEmbedding(coeffs, normalize=True, wires=[0, 1])
        return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

    params = jax.numpy.array([0.4, 0.5, 0.1, 0.3])
    jac_fn = jax.jacobian(circuit)
    jac_jit_fn = jax.jit(jac_fn)

    # Array([ 1.67562983, -1.88046271, -0.48002058,  1.05993827], dtype=float64
    jac = jac_fn(params)

    # Array([ 1.77878858, -2.02651428, -0.11825829,  1.04522512], dtype=float64))
    jac_jit = jac_jit_fn(params)

    assert qml.math.allclose(jac, jac_jit)

I just set `GlobalPhase.grad_method = None`, as mentioned in this comment.

albi3ro commented 2 months ago

Sorry about not getting back to you earlier @Tarun-Kumar07 . Been a bit busy the last few days.

This looks to be a case of the shots being too low. I bumped it up to shots=5000, and then the numbers started to converge better.

So there are two options:

1. Bump up the shot count and set a seed on the device to reduce flakiness: `qml.device('default.qubit', shots=10000, seed=7890234)`
2. Use analytic mode (`shots=None`) and manually specify `diff_method="parameter-shift"`.

Potentially we can just test both.
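
The reason `shots=100` is too flaky can be seen from the statistics alone: a shot-based expectation estimate carries statistical error of order `1/sqrt(shots)`. A numpy simulation of the estimator (a stand-in for the device, with a hypothetical true expectation value of 0.3):

```python
# Statistical error of a shot-based <Z0 Z1> estimate scales ~ 1/sqrt(shots):
# ~0.1 at 100 shots but ~0.01 at 10000 shots, which is why jitted and
# non-jitted jacobians only agree well at high shot counts.
import numpy as np

rng = np.random.default_rng(7890234)
true_expval = 0.3                 # hypothetical exact expectation value
p_plus = (1 + true_expval) / 2    # probability of measuring eigenvalue +1

def estimate(shots):
    # Draw +1/-1 eigenvalue samples and average them, like a device would
    samples = rng.choice([1.0, -1.0], size=shots, p=[p_plus, 1 - p_plus])
    return samples.mean()

err_low = abs(estimate(100) - true_expval)
err_high = abs(estimate(10_000) - true_expval)
```
With the fixed seed, `err_high` stays within a few times `0.01`, while `err_low` can easily be several times larger, matching the spread seen in the failing test.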

Tarun-Kumar07 commented 2 months ago

Hey @albi3ro ,

I tried writing tests for MottonenStatePreparation and StatePrep similar to the one in the comment above. Even with a high number of shots, or in analytic mode, these tests still fail.

I have opened a draft PR (#5620); these are the only failing tests, and I'm not sure how to fix them.