PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

support gradient transforms applied manually to `QNode` with non-scalar tape parameters #4225

Open dwierichs opened 1 year ago

dwierichs commented 1 year ago

Edit from original submission

This was originally reported as a bug, but since it is known behaviour it will instead be treated as a feature request. In short, the request is to make the classical Jacobian computation compatible with non-scalar tape parameters.

Expected behavior

Applying a gradient_transform manually to a QNode that has non-scalar parameters on the tape level works and returns correct Jacobians.

Actual behavior

For trainable tape-level parameters with mixed shapes, an error is raised. For trainable tape-level parameters that all share the same non-scalar shape, incorrect Jacobians may be returned.

Additional information

The problem is with qml.transforms.classical_jacobian and the contraction of the resulting classical Jacobian with the quantum Jacobian in gradient_transform.default_qnode_wrapper.

For mixed-shape parameters, the tape parameters cannot be stacked into a single array, yet qml.math.stack is used in classical_preprocessing within classical_jacobian.
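For illustration, here is a minimal sketch (independent of PennyLane internals) of the stacking step that fails in the mixed-shape case; the two parameters mimic the (3,)-shaped param1 and scalar param2 from the reproduction below.

import jax.numpy as jnp

# Two trainable tape parameters with mixed shapes, as in the reproducer below.
tape_params = [jnp.array([1.0, 2.0, 3.0]), jnp.array(0.5)]
# classical_preprocessing stacks the tape parameters into one array, which
# fails here with: ValueError: All input arrays must have the same shape.
jnp.stack(tape_params)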

For non-scalar parameters, both the reshaping of the quantum and classical Jacobians and the choice of contraction axes are incorrect, leading to errors or wrong results.
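For the consistently shaped case, the contraction would need to run over all tape-parameter axes. The following is only an illustrative sketch, not PennyLane's implementation, assuming a scalar measurement and a single trainable argument whose tape parameter has shape (3,):

import jax.numpy as jnp

# Quantum Jacobian: d<measurement>/d(tape parameter), shape (3,) for a
# scalar expectation value and a (3,)-shaped tape parameter.
qjac = jnp.ones(3)
# Classical Jacobian: d(tape parameter)/d(QNode argument), shape (3, 3)
# when the QNode argument is also (3,)-shaped.
cjac = jnp.eye(3)
# Contract over the tape-parameter axis to obtain the (3,)-shaped Jacobian
# with respect to the QNode argument.
arg_jac = jnp.tensordot(qjac, cjac, axes=[[0], [0]])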

Source code

import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)

import pennylane as qml

coeff1 = lambda p, t: p[0] * jnp.sin(t * p[1]) + p[2]
obs1 = qml.PauliX(0)
param1 = jnp.array([1., 2., 3.], dtype=jnp.float64)

coeff2 = lambda p, t: jnp.cos(p * t)
obs2 = qml.PauliX(1)
param2 = jnp.array(.5)

ham = qml.dot([coeff1, coeff2], [obs1, obs2])
dev = qml.device('default.qubit.jax', wires=2)

@qml.qnode(dev)
def qnode(param1, param2):
    qml.evolve(ham)((param1, param2), t=1.)
    return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

# Runs but returns incorrect result: Single non-scalar trainable tape parameter
grad_fn = qml.gradients.stoch_pulse_grad(qnode, argnums=0)
grad_fn(param1, param2) # Returns (3, 3, 3)-shaped tensor instead of (3,)-shaped vector

# Raises an error (see tracebacks): Mixed-shape trainable tape parameters
grad_fn = qml.gradients.stoch_pulse_grad(qnode, argnums=[0, 1])
grad_fn(param1, param2) # Raises ValueError regarding stacking.
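For reference (not part of the failing reproduction above), the expected Jacobian shapes can be obtained by differentiating the QNode directly with JAX, reusing qnode, param1 and param2 from the snippet above:

# Expected reference values: a (3,)-shaped Jacobian for param1 and a scalar
# derivative for param2.
expected = jax.jacobian(qnode, argnums=[0, 1])(param1, param2)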

Tracebacks

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[6], line 26
     23     return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))
     25 grad_fn = qml.gradients.stoch_pulse_grad(qnode, argnums=[0, 1])
---> 26 grad_fn(param1, param2)
     27 # jax.jacobian(qnode)(param1, param2)

File ~/repos/pennylane/pennylane/gradients/gradient_transform.py:627, in gradient_transform.default_qnode_wrapper.<locals>.jacobian_wrapper(*args, **kwargs)
    624 # Special case where we apply a Jax transform (jacobian e.g.) on the gradient transform and argnums are
    625 # defined on the outer transform and therefore on the args.
    626 argnum_cjac = trainable_params or argnums if interface == "jax" else None
--> 627 cjac = qml.transforms.classical_jacobian(
    628     qnode, argnum=argnum_cjac, expand_fn=self.expand_fn
    629 )(*args, **kwargs)
    631 if qml.active_return():
    632     num_measurements = len(qnode.tape.measurements)

File ~/repos/pennylane/pennylane/transforms/classical_jacobian.py:189, in classical_jacobian.<locals>.qnode_wrapper(*args, **kwargs)
    186     def _jacobian(*args, **kwargs):
    187         return jax.jacobian(classical_preprocessing, argnums=argnum)(*args, **kwargs)
--> 189     jac = _jacobian(*args, **kwargs)
    191 if qnode.interface == "tf":
    192     import tensorflow as tf

File ~/repos/pennylane/pennylane/transforms/classical_jacobian.py:187, in classical_jacobian.<locals>.qnode_wrapper.<locals>._jacobian(*args, **kwargs)
    186 def _jacobian(*args, **kwargs):
--> 187     return jax.jacobian(classical_preprocessing, argnums=argnum)(*args, **kwargs)

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/api.py:928, in jacrev.<locals>.jacfun(*args, **kwargs)
    926 tree_map(partial(_check_input_dtype_jacrev, holomorphic, allow_int), dyn_args)
    927 if not has_aux:
--> 928   y, pullback = _vjp(f_partial, *dyn_args)
    929 else:
    930   y, pullback, aux = _vjp(f_partial, *dyn_args, has_aux=True)

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/api.py:2177, in _vjp(fun, has_aux, reduce_axes, *primals)
   2175 if not has_aux:
   2176   flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
-> 2177   out_primal, out_vjp = ad.vjp(
   2178       flat_fun, primals_flat, reduce_axes=reduce_axes)
   2179   out_tree = out_tree()
   2180 else:

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:139, in vjp(traceable, primals, has_aux, reduce_axes)
    137 def vjp(traceable, primals, has_aux=False, reduce_axes=()):
    138   if not has_aux:
--> 139     out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
    140   else:
    141     out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/ad.py:128, in linearize(traceable, *primals, **kwargs)
    126 _, in_tree = tree_flatten(((primals, primals), {}))
    127 jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
--> 128 jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
    129 out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
    130 assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/profiler.py:314, in annotate_function.<locals>.wrapper(*args, **kwargs)
    311 @wraps(func)
    312 def wrapper(*args, **kwargs):
    313   with TraceAnnotation(name, **decorator_kwargs):
--> 314     return func(*args, **kwargs)
    315   return wrapper

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/interpreters/partial_eval.py:777, in trace_to_jaxpr_nounits(fun, pvals, instantiate)
    775 with core.new_main(JaxprTrace, name_stack=current_name_stack) as main:
    776   fun = trace_to_subjaxpr_nounits(fun, main, instantiate)
--> 777   jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
    778   assert not env
    779   del main, fun, env

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/linear_util.py:188, in WrappedFun.call_wrapped(self, *args, **kwargs)
    185 gen = gen_static_args = out_store = None
    187 try:
--> 188   ans = self.f(*args, **dict(self.params, **kwargs))
    189 except:
    190   # Some transformations yield from inside context managers, so we have to
    191   # interrupt them before reraising the exception. Otherwise they will only
    192   # get garbage-collected at some later time, running their cleanup tasks
    193   # only after this exception is handled, which can corrupt the global
    194   # state.
    195   while stack:

File ~/repos/pennylane/pennylane/transforms/classical_jacobian.py:149, in classical_jacobian.<locals>.classical_preprocessing(*args, **kwargs)
    147 if expand_fn is not None:
    148     tape = expand_fn(tape)
--> 149 return qml.math.stack(tape.get_parameters(trainable_only=trainable_only))

File ~/repos/pennylane/pennylane/math/multi_dispatch.py:151, in multi_dispatch.<locals>.decorator.<locals>.wrapper(*args, **kwargs)
    148 interface = interface or get_interface(*dispatch_args)
    149 kwargs["like"] = interface
--> 151 return fn(*args, **kwargs)

File ~/repos/pennylane/pennylane/math/multi_dispatch.py:488, in stack(values, axis, like)
    459 """Stack a sequence of tensors along the specified axis.
    460 
    461 .. warning::
   (...)
    485        [5.00e+00, 8.00e+00, 1.01e+02]], dtype=float32)>
    486 """
    487 values = np.coerce(values, like=like)
--> 488 return np.stack(values, axis=axis, like=like)

File ~/venvs/dev/lib/python3.10/site-packages/autoray/autoray.py:79, in do(fn, like, *args, **kwargs)
     30 """Do function named ``fn`` on ``(*args, **kwargs)``, peforming single
     31 dispatch to retrieve ``fn`` based on whichever library defines the class of
     32 the ``args[0]``, or the ``like`` keyword argument if specified.
   (...)
     76     <tf.Tensor: id=91, shape=(3, 3), dtype=float32>
     77 """
     78 backend = choose_backend(fn, *args, like=like, **kwargs)
---> 79 return get_lib_fn(backend, fn)(*args, **kwargs)

File ~/venvs/dev/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1750, in stack(arrays, axis, out, dtype)
   1748 for a in arrays:
   1749   if shape(a) != shape0:
-> 1750     raise ValueError("All input arrays must have the same shape.")
   1751   new_arrays.append(expand_dims(a, axis))
   1752 return concatenate(new_arrays, axis=axis, dtype=dtype)

ValueError: All input arrays must have the same shape.

System information

PennyLane development version (pl dev)

Existing GitHub issues

rmoyard commented 1 year ago

Yes, I guess it is because our definition of the classical Jacobian uses stacking. We should reconsider it soon!
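A rough sketch of what a stack-free classical_preprocessing could look like, reusing the QNode from the reproduction above; this is only illustrative and not PennyLane's API. Returning the trainable tape parameters as a tuple lets jax.jacobian handle mixed shapes through pytrees instead of forcing them into a single stacked array:

import jax

def classical_preprocessing(*args, **kwargs):
    # Build the tape and return its trainable parameters without stacking.
    qnode.construct(args, kwargs)
    return tuple(qnode.tape.get_parameters(trainable_only=True))

# One classical Jacobian per (tape parameter, differentiated argument) pair.
cjacs = jax.jacobian(classical_preprocessing, argnums=[0, 1])(param1, param2)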