PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[BUG] Nested `jax.vmap` does not work #3492

Open PhilipVinc opened 1 year ago

PhilipVinc commented 1 year ago

Following #3452, `jax.vmap` should be composable, but it still isn't. Minimal reproduction:

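import jax
import jax.numpy as jnp
import numpy as np
import pennylane as qml

# Enable 64-bit floats. `jax.config.update` avoids `from jax.config import config`,
# which was removed in newer JAX releases.
jax.config.update("jax_enable_x64", True)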

phys_qubits = 2
n_configs = 12
pars_q = np.random.rand(n_configs, 2)  # shape (12, 2)
pars_q_r = pars_q.reshape(4, 3, 2)     # same values, two batch axes

dev = qml.device("default.qubit", wires=tuple(range(phys_qubits)), shots=None)

def minimal_circ(params):
    @qml.qnode(dev, interface="jax-jit", diff_method="parameter-shift", cache=None)
    def _measure_operator():
        # The last axis of params holds the two rotation angles
        qml.RY(params[..., 0], wires=0)
        qml.RY(params[..., 1], wires=1)

        op = qml.Hamiltonian([1.0], [qml.PauliZ(0) @ qml.PauliZ(1)])
        return qml.expval(op)
    res = _measure_operator()
    return res

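# works: one batch axis, handled by PennyLane's parameter broadcasting
minimal_circ(pars_q)
# works: vmap over axis 0, each call sees params of shape (2,)
jax.vmap(minimal_circ)(pars_q)
# does not work: vmap over axis 0, each call still sees a batch of shape (3, 2)
jax.vmap(minimal_circ)(pars_q_r)
# does not work: nested vmap over both batch axes
jax.vmap(jax.vmap(minimal_circ))(pars_q_r)

CatalinaAlbornoz commented 1 year ago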

Thank you for reporting this bug, @PhilipVinc.

rmoyard commented 1 year ago

Hi @PhilipVinc, I think the problem here is that PennyLane's parameter broadcasting does not support multiple broadcast dimensions. If you remove the broadcasting inside the QNode and use:

jax.vmap(jax.vmap(minimal_circ))(pars_q_r)

it works. I think this is a good feature request for us, thank you!
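
Until multiple broadcast dimensions are supported, one possible workaround is to collapse all batch axes into a single one, make one batched call, and reshape the result. The sketch below only illustrates that reshaping pattern and assumes that the single-batch-axis call (`minimal_circ(pars_q)` above) works as reported. `apply_flattened` and `zz_expectation` are hypothetical names, not PennyLane API. `zz_expectation` is a NumPy stand-in for the QNode: for `RY(a)` on wire 0 and `RY(b)` on wire 1 applied to |00>, the expectation value of Z0 Z1 is cos(a) * cos(b), so the same function can also serve as a reference when checking results.

```python
import numpy as np


def zz_expectation(params):
    """Analytic stand-in for the QNode (hypothetical helper, not PennyLane code).

    Accepts params of shape (2,) or (N, 2) and rejects more than one batch
    axis, mimicking the single-axis broadcasting limitation discussed above.
    """
    params = np.asarray(params)
    if params.ndim > 2:
        raise ValueError("only a single batch axis is supported")
    # <Z0 Z1> for RY(a) on wire 0 and RY(b) on wire 1 acting on |00>
    return np.cos(params[..., 0]) * np.cos(params[..., 1])


def apply_flattened(fn, params):
    """Collapse all leading batch axes into one, call fn once, restore the axes."""
    params = np.asarray(params)
    batch_shape = params.shape[:-1]              # e.g. (4, 3)
    flat = params.reshape(-1, params.shape[-1])  # e.g. (12, 2)
    return np.asarray(fn(flat)).reshape(batch_shape)


pars_q_r = np.random.rand(4, 3, 2)
res = apply_flattened(zz_expectation, pars_q_r)
print(res.shape)  # (4, 3)
```

With the code from the report, the corresponding call would be `apply_flattened(minimal_circ, pars_q_r)`. That call has not been tested here, and it does not replace proper support for nested `jax.vmap`.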