quantshah opened this issue 2 years ago
Hi @quantshah,

> it seems a bit strange to have something that works with the JIT off not work when we simply JIT things.

This very much comes down to how the JAX JIT interface uses `host_callback.call` at the moment (see the newer comment here).

~~This very much comes down to what JAX offers at the moment: `host_callback.call` supports `jax.jit`, but not `jax.jacobian`; `host_callback.id_tap` supports `jax.jacobian`, but not `jax.jit` in a way that we could define custom gradients (likely due to its side-effect nature, see this issue https://github.com/google/jax/issues/9172).~~ ~~Maybe there's a way of using `host_callback.id_tap` with `jax.jit` in the way we'd like to, but so far it doesn't seem to be the case.~~
Thanks @antalszava for the explanation. Feel free to leave this issue open till there is a resolution, or close it, since this is an issue with JAX and not PL.
Sure :) I might leave it open, just so that there's a point of reference if this becomes a question for others too.
Note: this issue could potentially be resolved by a refactor of the JAX JIT interface. We have this on our radar and would like to look into a resolution in the coming weeks.
Specifically, at the moment the `g` cotangent value is being passed as an input parameter here. This `g` value should instead be applied to the result of `host_callback.call`.
@antalszava just curious: if we move towards having the quantum device itself compute the VJP, would this require that the cotangent vector `g` be passed through the `host_callback`? One example would be the `adjoint` method with `lightning.qubit`.
Likely not. JAX seems to assume that `g` is applied to the residuals to yield the jacobian. Specifically for `adjoint` with `mode="forward"`, the jacobian could be "passed" to the registered backward function as a residual using the following pattern (as suggested on this discussion thread here):
```python
import jax
import jax.numpy as jnp

params = jnp.array([0.1, 0.2])

@jax.custom_vjp
def wrapped_exec(params):
    y = params ** 2, params ** 3
    # don't need to compute jacs here
    return y

def wrapped_exec_fwd(params):
    y = wrapped_exec(params)
    jacs = jnp.diag(2 * params), jnp.diag(3 * params ** 2)  # compute here
    return y, jacs  # don't need params here

def wrapped_exec_bwd(res, g):
    jac1, jac2 = res
    g1, g2 = g
    return ((g1 @ jac1) + (g2 @ jac2),)

wrapped_exec.defvjp(wrapped_exec_fwd, wrapped_exec_bwd)
jax.jacobian(wrapped_exec)(params)
```
Hi everyone, getting back to this thread as I saw that in JAX there is a possibility to implement higher-order gradients (VJPs) with `host_callback` using an outside implementation for the gradient computation (e.g., TensorFlow). See the discussion here: https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-tensorflow-function-with-reverse-mode-autodiff-support

I had a look at the implementation here: https://github.com/google/jax/blob/main/tests/host_callback_to_tf_test.py#L100 but haven't completely figured out what is happening in the custom backward pass that allows one to compute higher-order gradients with `host_callback`. It feels like somehow they are just hooking up the TensorFlow autodiff mechanism to the JAX `custom_vjp` definitions and it works all the way up to the higher-order derivative.

Just putting this out here for reference in the future, in case we look into this again and it is helpful.
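A rough sketch of what seems to be going on there, as far as I can tell (simplified to functions whose output has the same shape and dtype as the input; `call_tf` and the stack/duplicate trick below are illustrative, not the actual test code): the backward rule is itself a `call_tf`-wrapped TF function, so JAX can differentiate it again, which appears to be what gives the higher-order derivatives.

```python
import jax
import jax.numpy as jnp
from jax.experimental import host_callback as hcb
import tensorflow as tf

def call_tf(tf_fun):
    """Call a TF function f(x) -> y (same shape/dtype as x) from JAX."""

    @jax.custom_vjp
    def wrapped(x):
        # Forward pass: run the TF function on the host.
        return hcb.call(lambda a: tf_fun(tf.constant(a)).numpy(), x,
                        result_shape=x)

    def fwd(x):
        return wrapped(x), x  # keep the primal input as the residual

    def bwd(x, ct):
        # The TF VJP is exposed through call_tf again, so this backward
        # rule is itself differentiable -- repeating the construction is
        # what seems to enable higher-order derivatives.
        def tf_vjp(stacked):
            primal, cotangent = tf.unstack(stacked, num=2, axis=0)
            with tf.GradientTape() as tape:
                tape.watch(primal)
                y = tf_fun(primal)
            grad = tape.gradient(y, primal, output_gradients=cotangent)
            # Duplicate the gradient so the output shape matches the
            # (stacked) input shape that this simplified call_tf assumes.
            return tf.stack([grad, grad], axis=0)

        return (call_tf(tf_vjp)(jnp.stack([x, ct]))[0],)

    wrapped.defvjp(fwd, bwd)
    return wrapped

f = call_tf(lambda x: tf.math.sin(x))
print(jax.grad(f)(0.5))            # cos(0.5)
print(jax.grad(jax.grad(f))(0.5))  # -sin(0.5)
```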
I just stumbled on this issue. I think a way to avoid it would be to define a new JAX primitive operation, `jax.core.Primitive("qml_expval")`, defining the CPU and GPU implementations, and then also how to differentiate it. Then `qml_expval` will be perfectly equivalent to any other JAX native operation (so everything can be supported).

The main 'complication' is that I'm not sure if you can feed a `host_callback` as the implementation. I think it could be possible, but it should be tried. I'm sure you could feed a C function (because that's what it natively supports), which then trampolines back into Python code. The interface is quite stable and has not changed in the last 2 years.

PS: if there's any interest in that, I don't have the time to put in, but I can provide some guidance. I've already done that for two different packages.
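To give a flavour, a bare-bones registration (the primitive name, the classical stand-in implementation, and the analytic gradient rule below are purely illustrative) could look like this:

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import ad

# Declare the new primitive and a user-facing wrapper around bind().
qml_expval_p = core.Primitive("qml_expval")

def qml_expval(params):
    return qml_expval_p.bind(params)

# Concrete (eager) evaluation: this is where the quantum device would
# be invoked; a classical stand-in is used here.
@qml_expval_p.def_impl
def _qml_expval_impl(params):
    return jnp.cos(params)

# Abstract evaluation: output shape/dtype, needed for tracing and jit.
@qml_expval_p.def_abstract_eval
def _qml_expval_abstract_eval(params):
    return core.ShapedArray(params.shape, params.dtype)

# Differentiation rule: in PennyLane this is where parameter-shift or
# adjoint gradients would go; an analytic rule matches the stand-in.
ad.defjvp(qml_expval_p, lambda t, params: -jnp.sin(params) * t)

jax.jvp(qml_expval, (jnp.array([0.1]),), (jnp.array([1.0]),))
```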
This is interesting, there is a nice example here of how this is done all the way up to JIT-ing: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html

But the problem still remains that we cannot use XLA operations to do the expval evaluation, as the expval evaluation happens purely on the quantum device, right? So there has to be a break in the computational graph somewhere (which is what `host_callback` does). So I don't know if this would still work. Unless there is a way to use the XLA custom call (https://www.tensorflow.org/xla/custom_call) to get the value of the expval.
Hi @PhilipVinc, thank you for the suggestion! :tada:
Personally, I'm new to creating a custom JAX primitive, so any help and guidance here would definitely be appreciated. :slightly_smiling_face:

Also wondering about the question that Shahnawaz mentioned: how could introducing the new primitive help with the specific error originally reported? It would seem that the issue is specific to the invocation of `host_callback.call`. Assuming that `jax.core.Primitive` is compatible with `host_callback.call` and `jax.core.Primitive("qml_expval")` is defined, wouldn't the same issue arise with its custom gradient?
> But the problem still remains that we cannot use XLA operations to do the expval evaluation as the expval evaluation happens purely on the quantum device right? So there has to be a break in the computational graph somewhere (what host_callback does). So I don’t know if this would still work.
There are two graphs at play here. The first is the one used during function transformations, pre-`jax.jit` (which is more of a tape than a graph, but whatever). For this one, a primitive is a node in the graph. It must specify, much akin to `host_callback`, what the input and output shapes are (`primitive.def_abstract_eval`) and how the primitive transforms under passes like `vmap` (`batching.primitive_batchers`), `vjp` and `jvp` (`ad.primitive_jvps`), but it does not need to specify anything else.

The issue with `host_callback` is that, due to a bunch of issues, you cannot customise how `host_callback` transforms under `vmap`, `vjp` and `jvp`. But you can do that for your custom primitive.

Then, you must also tell JAX which XLA operation the primitive corresponds to when it compiles (`xla.backend_specific_translations["cpu"][primitive]`). XLA operations contain a reference to a C function that performs that operation. You can always call ANY C code. For example, I used this mechanism to support MPI operations inside of jax-jitted functions with `mpi4jax` (a good, self-contained file is this one here, implementing the primitive), or to support `numba` operations inside of jitted functions with `numba4jax`.

But, I guess, you could also enqueue a `host_callback` operation.
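Continuing the hypothetical `qml_expval_p` sketch from above, the extra registrations would look roughly like this (the trivial batching rule only works because the stand-in implementation is elementwise):

```python
from jax.interpreters import batching, xla

# vmap rule: how to evaluate the primitive on a batched argument.
def _qml_expval_batch(batched_args, batch_dims):
    (params,) = batched_args
    (bdim,) = batch_dims
    # The elementwise stand-in can just consume the whole batch and
    # report the batch dimension unchanged.
    return qml_expval(params), bdim

batching.primitive_batchers[qml_expval_p] = _qml_expval_batch

# jit support additionally needs a lowering to an XLA operation, e.g.
# a CustomCall into C code (this is what mpi4jax/numba4jax do):
# xla.backend_specific_translations["cpu"][qml_expval_p] = _cpu_translation
```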
> Assuming that jax.core.Primitive is compatible with host_callback.call and jax.core.Primitive("qml_expval") is defined, wouldn't the same issue arise with its custom gradient?
What issue exactly? I'm not familiar with the depth of pennylane's source, so if you have a short example, even in pseudocode, that would help clarify
> What issue exactly? I'm not familiar with the depth of pennylane's source, so if you have a short example, even in pseudocode, that would help clarify
Sure. :slightly_smiling_face: At the moment, the use of `@jax.custom_vjp` in PennyLane is not ideal because we are passing the cotangent vectors (`g`) along with the input parameters to the `host_callback.call` invocation:

```python
args = tuple(params) + (g,)
vjps = host_callback.call(
    non_diff_wrapper,
    args,
    result_shape=jax.ShapeDtypeStruct((total_params,), dtype),
)
```
Passing `g` helped with using the `qml.gradients.batch_vjp` function we have internally, a function called in the `non_diff_wrapper` function and shared across other machine learning frameworks too (including TensorFlow and PyTorch).

At the same time, passing `g` along with the other arguments creates issues because `g` may become a `BatchTrace` object when using certain transforms (e.g., `jax.jacobian`, which uses `jax.vmap`), and this seems to be the culprit for the original error:

```
NotImplementedError: JVP rule is implemented only for id_tap, not for call.
```
It would seem that there may be two components to a solution here:

1. `g` should not be passed as an argument, but rather be used to mutate the output of `host_callback.call` (as per how `custom_vjp` should work in JAX), as sketched below;
2. the internal VJP computation should be reworked (either by adjusting `qml.gradients.batch_vjp` or implementing it to be aligned with the logic here).

With those changes we should have no `BatchTrace` objects flowing through the `host_callback.call` invocation, and we should be able to implement the logic for supporting the original Hessian computation.
> At the same time, passing g along the other arguments creates issues because g may become a BatchTrace object when using certain transforms (e.g., jax.jacobian that uses jax.vmap) and this seems to be the culprit for the original error:
Are you aware of `jax.custom_batching.custom_vmap`? If you define a `custom_vmap` rule for your `custom_vjp`, you might sidestep the issue entirely.
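Something along these lines (untested; the function below is an illustrative stand-in for an execution that goes through `host_callback.call`):

```python
import jax
import jax.numpy as jnp
from jax import custom_batching

@custom_batching.custom_vmap
def host_exec(params):
    # Stand-in for an execution that goes through host_callback.call.
    return params ** 2

@host_exec.def_vmap
def host_exec_vmap_rule(axis_size, in_batched, params):
    (params_batched,) = in_batched
    # Unroll the batch into separate unbatched host executions, so no
    # BatchTracer ever reaches the host callback.
    out = jnp.stack([host_exec(params[i]) for i in range(axis_size)])
    return out, True

jax.vmap(host_exec)(jnp.arange(6.0).reshape(3, 2))
```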
Wasn't aware of it! :astonished: Will try this out, thank you. :slightly_smiling_face: I see it's functionality that's still in the works, but it should be worthwhile to try because of the `jax.jacobian` support. :+1:
Expected behavior
I was trying to compute the Hessian and saw that the JAX interface breaks down if we have the JIT on. Without JIT, it works fine. The error seems to be due to the non-availability of JVPs in the `host_callback` bridge between PL and JAX. To make it work, just remove the `@jax.jit` from the definition of the circuit.

@josh146 and I discussed this over Slack and it seems a bit strange to have something that works with the JIT off not work when we simply JIT things.
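A hypothetical minimal example of the kind of circuit involved (a stand-in to illustrate the pattern; the actual source code from the report is in the collapsed section below):

```python
import jax
import jax.numpy as jnp
import pennylane as qml

dev = qml.device("default.qubit", wires=1)

@jax.jit  # removing this line makes the Hessian computation work
@qml.qnode(dev, interface="jax")
def circuit(x):
    qml.RX(x, wires=0)
    return qml.expval(qml.PauliZ(0))

# Fails with the error below when jitted; fine without @jax.jit.
jax.hessian(circuit)(jnp.array(0.1))
```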
Actual behavior

```
can't apply forward-mode autodiff (jvp) to a custom_vjp function. JVP rule is implemented only for id_tap, not for call.
```
Additional information
No response
Source code
Tracebacks
System information
Existing GitHub issues