dwierichs closed this 2 months ago
All modified and coverable lines are covered by tests :white_check_mark:
Project coverage is 99.58%. Comparing base (43cac72) to head (c2c877e). Report is 1 commit behind head on master.
The issue seems to be that defining a primitive's JVP rule via `jax.jvp` again exposes the symbolic zeros (`ad.Zero`) that JAX sends through the program as tangents to the custom implementation. This breaks when `jax.jvp` compares the `PyTreeDef`s of primals and tangents, because zeros are registered as a `CustomNode`. This issue existed with a previous implementation in JAX back in 2019, and the old issue reappears with up-to-date JAX code, even for a simple non-PennyLane example :thinking: Should I post an issue on the JAX repo for this?
It is worth noting that this issue is not introduced in this PR but exists on master as well. It only surfaced because this PR looks into more complex signatures, which revealed that partial differentiation (with `argnum` not pointing to all args) breaks.
MWE for the JAX issue:
```python
import jax
from jax.interpreters import ad

f = lambda a, b: a + b

f_prim = jax.core.Primitive("f")

@f_prim.def_impl
def _(a, b):
    return a + b

@f_prim.def_abstract_eval
def _(a, b):
    return jax.core.ShapedArray(a.shape, a.dtype)

def custom_jvp(args, dargs):
    print("args", args)
    print("dargs", dargs)
    return jax.jvp(f, args, dargs)

ad.primitive_jvps[f_prim] = custom_jvp

def F(*args):
    return f_prim.bind(*args)

print("forward")
print(F(1., 2.))  # this works
print("grad wrt both args")
print(jax.grad(F, argnums=[0, 1])(1., 2.))  # this works
print("grad wrt one arg")
print(jax.grad(F, argnums=[0])(1., 2.))  # error: mismatched PyTreeDefs in jax.jvp
```
Intercepting `ad.Zero` and manually producing zero-valued tangents, as in this tutorial, worked.
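A minimal sketch of that workaround applied to the MWE, assuming the same JAX-internal APIs (`ad.primitive_jvps`, `ad.Zero`), which are not stable public APIs and may move between releases. The JVP rule swaps each symbolic zero for `jnp.zeros_like` of the matching primal before calling `jax.jvp`; the `jax.extend.core` import fallback is an assumption to cover newer JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad

try:  # newer JAX versions expose Primitive under jax.extend.core
    from jax.extend.core import Primitive
except ImportError:  # older versions, as used in the MWE
    from jax.core import Primitive

f = lambda a, b: a + b

f_prim = Primitive("f")
f_prim.def_impl(f)
f_prim.def_abstract_eval(lambda a, b: jax.core.ShapedArray(a.shape, a.dtype))


def custom_jvp(args, dargs):
    # Replace symbolic zeros (passed for non-differentiated arguments) with
    # concrete zero arrays shaped like the matching primal, so that the
    # primal and tangent pytrees given to jax.jvp have the same structure.
    dargs = tuple(
        jnp.zeros_like(x) if isinstance(dx, ad.Zero) else dx
        for x, dx in zip(args, dargs)
    )
    # Use tuples for both, since a list vs. tuple would also mismatch.
    return jax.jvp(f, tuple(args), dargs)


ad.primitive_jvps[f_prim] = custom_jvp


def F(*args):
    return f_prim.bind(*args)


x, y = jnp.array(1.0), jnp.array(2.0)
grad_both = jax.grad(F, argnums=[0, 1])(x, y)
grad_one = jax.grad(F, argnums=[0])(x, y)  # no longer hits the treedef error
```

Instantiating zeros costs a few extra dense operations, but keeps the rule simple; a rule that handles `ad.Zero` natively could skip that work.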
I added a test based on the bug above.
Context: #6120 and #6127 add support to capture `qml.grad` and `qml.jacobian` in plxpr. Once captured, they dispatch to `jax.grad` and `jax.jacobian`.
Description of the Change: This PR adds support for pytree inputs and outputs of the differentiated functions, similar to #6081. For this, it extends the internal class `FlatFn` with the extra functionality to turn the wrapper into a `*flat_args -> *flat_outputs` function instead of a `*pytree_args -> *flat_outputs` function.
Benefits: Pytree support :deciduous_tree:
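A stand-alone sketch of the flattening idea, using only `jax.tree_util`. `FlatArgsFn` and `func` are hypothetical names for illustration; this is not PennyLane's `FlatFn`, whose actual interface may differ. The wrapper stores the input treedef, rebuilds the pytree arguments from flat leaves, and records the output treedef, so that every leaf becomes its own positional argument:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten


class FlatArgsFn:
    """Hypothetical wrapper turning f(*pytree_args) into *flat_args -> *flat_outputs."""

    def __init__(self, f, in_tree):
        self.f = f
        self.in_tree = in_tree  # treedef of the original argument tuple
        self.out_tree = None    # filled in after the first call

    def __call__(self, *flat_args):
        # Rebuild the pytree arguments from the flat leaves.
        args = tree_unflatten(self.in_tree, flat_args)
        out = self.f(*args)
        # Flatten the output and remember its structure for later unflattening.
        flat_out, self.out_tree = tree_flatten(out)
        return flat_out


def func(x, params):
    # Pytree input (a dict) and pytree output (a tuple).
    return params["w"] * x, jnp.sin(params["b"])


args = (2.0, {"w": 3.0, "b": 0.5})
flat_args, in_tree = tree_flatten(args)  # leaves: [x, params["b"], params["w"]]
flat_fn = FlatArgsFn(func, in_tree)

out = tree_unflatten(flat_fn.out_tree, flat_fn(*flat_args)) if False else None
flat_out = flat_fn(*flat_args)
out = tree_unflatten(flat_fn.out_tree, flat_out)

# With flat positional arguments, argnums can address individual leaves.
grads = jax.grad(lambda *fa: flat_fn(*fa)[0], argnums=(0, 1, 2))(*flat_args)
```

Since JAX sorts dict keys when flattening, the leaves arrive in the order `x`, `params["b"]`, `params["w"]`, so `grads` holds the derivatives of `w * x` with respect to each of them in that order.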
Possible Drawbacks:
Related GitHub Issues:
[sc-70930] [sc-71862]