PennyLaneAI / pennylane

PennyLane is a cross-platform Python library for quantum computing, quantum machine learning, and quantum chemistry. Train a quantum computer the same way as a neural network.
https://pennylane.ai
Apache License 2.0

[Program Capture] Add pytree support to captured `qml.grad` and `qml.jacobian` #6134

Closed · dwierichs closed this 2 months ago

dwierichs commented 2 months ago

Context:

#6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.

Description of the Change: This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` with the option to turn the wrapped function into a `*flat_args -> *flat_outputs` function instead of a `*pytree_args -> *flat_outputs` function.
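As a rough illustration of the general technique (a sketch only, not PennyLane's actual `FlatFn` code; the helper name `flatten_fn` is made up), this corresponds to wrapping a pytree-accepting function so that it consumes and produces flat leaves:

```python
import jax

def flatten_fn(fn, *example_args):
    # Record the pytree structure of the inputs once, from example arguments.
    _, in_tree = jax.tree_util.tree_flatten(example_args)

    def flat_fn(*flat_args):
        # Rebuild the original pytree arguments from the flat leaves ...
        args = jax.tree_util.tree_unflatten(in_tree, flat_args)
        out = fn(*args)
        # ... and flatten the outputs again, so the caller only sees leaves.
        flat_out, _ = jax.tree_util.tree_flatten(out)
        return flat_out

    return flat_fn
```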

Benefits: Pytree support :deciduous_tree:
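For example, with program capture enabled, something along these lines should now differentiate through dictionary-valued arguments (a hedged usage sketch based on the PR description, not a test taken from this PR):

```python
import jax.numpy as jnp
import pennylane as qml

qml.capture.enable()  # captured qml.grad dispatches to jax.grad

def cost(params):
    # params is a pytree (here, a dict of scalars/arrays)
    return jnp.sin(params["x"]) * jnp.cos(params["y"])

params = {"x": jnp.array(0.3), "y": jnp.array(0.5)}
grads = qml.grad(cost)(params)
# grads is expected to mirror the input pytree structure: {"x": ..., "y": ...}
```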

Possible Drawbacks:

Related GitHub Issues:

[sc-70930] [sc-71862]

codecov[bot] commented 2 months ago

Codecov Report

All modified and coverable lines are covered by tests :white_check_mark:

Project coverage is 99.58%. Comparing base (43cac72) to head (c2c877e). Report is 1 commit behind head on master.

Additional details and impacted files

```diff
@@            Coverage Diff             @@
##           master    #6134      +/-   ##
==========================================
- Coverage   99.59%   99.58%   -0.01%
==========================================
  Files         443      443
  Lines       42231    42255      +24
==========================================
+ Hits        42058    42078      +20
- Misses        173      177       +4
==========================================
```


dwierichs commented 2 months ago

The issue seems to be that defining a primitive's JVP rule via `jax.jvp` exposes the zero tracers that are sent through the program as tangents to the custom implementation, which breaks when inspecting the `PyTreeDef`s, because the zeros are registered as a `CustomNode`. This issue existed with a previous implementation in JAX back in 2019, and it re-appears with up-to-date JAX code, even for a simple non-PennyLane example :thinking: Should I post an issue on the JAX repo for this?

It is worth noting that this issue is not introduced by this PR; it exists on master as well. It is only because this PR looks into more complex signatures that we found that partial differentiation (with `argnum` not pointing to all args) breaks.

MWE for the JAX issue:

```python
import jax
from jax.interpreters import ad

f = lambda a, b: a + b

# A new primitive whose implementation and abstract evaluation simply add two numbers
f_prim = jax.core.Primitive("f")

@f_prim.def_impl
def _(a, b):
    return a + b

@f_prim.def_abstract_eval
def _(a, b):
    return jax.core.ShapedArray(a.shape, a.dtype)

# JVP rule that just dispatches to jax.jvp of the pure-Python implementation
def custom_jvp(args, dargs):
    print("args", args)
    print("dargs", dargs)
    return jax.jvp(f, args, dargs)

ad.primitive_jvps[f_prim] = custom_jvp

def F(*args):
    return f_prim.bind(*args)

print("forward")
print(F(1., 2.))  # this works
print("grad wrt both args")
print(jax.grad(F, argnums=[0, 1])(1., 2.))  # this works
print("grad wrt one arg")
print(jax.grad(F, argnums=[0])(1., 2.))  # error
```

dwierichs commented 2 months ago

Intercepting `ad.Zero` and manually producing zero-valued tangents, like in this tutorial, worked.
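For reference, a minimal sketch of what such an interception could look like for the MWE above (an assumed illustration, not the exact code added in this PR): the custom JVP rule materializes any `ad.Zero` tangents as concrete zero arrays before handing them to `jax.jvp`.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad

# Reuses f and f_prim from the MWE above.
def custom_jvp(args, dargs):
    # ad.Zero placeholders only carry an abstract value (.aval) and trip up the
    # pytree handling inside jax.jvp, so replace them with explicit zeros of the
    # matching shape and dtype before dispatching.
    dargs = tuple(
        jnp.zeros(t.aval.shape, t.aval.dtype) if isinstance(t, ad.Zero) else t
        for t in dargs
    )
    return jax.jvp(f, args, dargs)

ad.primitive_jvps[f_prim] = custom_jvp
```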

I added a test based on the bug above.