quantshah opened this issue 1 year ago.
`Qobj` is designed to work with multiple, mixed data types, and it's not clear that JAX can trace the mathematical operations inside a dispatched function. The dispatcher itself is Cython-compiled.
We could try to register `Qobj` as a pytree node when importing qutip-jax.
Giving `Qobj` a way to extract the specialisations and skip the dispatcher will probably also be required.
We should be able to make it work with `.data`.
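A minimal sketch of the pytree idea, using a hypothetical `MiniQobj` class (not QuTiP's `Qobj`) that holds a JAX array as its only traced leaf and keeps `dims` as static auxiliary data. It relies only on `jax.tree_util.register_pytree_node_class`; how the real `Qobj` and its Cython dispatcher would be adapted is not shown.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class MiniQobj:
    """Hypothetical stand-in for Qobj: a JAX array plus static dims."""

    def __init__(self, data, dims):
        self.data = data  # dynamic leaf: traced by jit/grad
        self.dims = dims  # static metadata: must be hashable

    def tree_flatten(self):
        # Return (children, aux_data): arrays are children, dims is static.
        return (self.data,), self.dims

    @classmethod
    def tree_unflatten(cls, dims, children):
        return cls(children[0], dims)

    def __matmul__(self, other):
        # Result keeps the row dims of self and the column dims of other.
        return MiniQobj(self.data @ other.data, (self.dims[0], other.dims[1]))


@jax.jit
def apply(op, state):
    # Both arguments are pytrees, so jit can trace through them
    # and return a MiniQobj directly.
    return op @ state


sigmax = MiniQobj(jnp.array([[0.0, 1.0], [1.0, 0.0]]), ((2,), (2,)))
ket0 = MiniQobj(jnp.array([[1.0], [0.0]]), ((2,), (1,)))
out = apply(sigmax, ket0)
print(type(out).__name__, out.dims, out.data.ravel())
```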
Thank you @Ericgig and @quantshah for making this repo and pushing this forward!! This is amazing!
I just want to add a bit of my understanding of JAX JIT. It seems to me that JIT (and `jax.grad`, etc.) works as long as the function to be jitted only takes inputs and returns outputs that JAX supports, e.g. `jnp.array`s, and the input-dependent part only uses JAX-supported operations. Any other complexity in constructing the function doesn't matter, as long as JAX can trace through the Python code and translate it to XLA at run time. So in principle, the current implementation should be sufficient as long as we are a little careful with the functions we jit.
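The following JIT example works for me with the current master branch:

```python
import qutip
import qutip_jax  # registers the "jax" data layer with QuTiP
import jax


@jax.jit
def fun(a):
    M = qutip.sigmay().to("jax")
    N = qutip.sigmaz().to("jax")
    # Return the raw JAX array rather than the Qobj, so jit gets a valid output type.
    return (a * M.conj() * N + N).data._jxa


fun(3.)
```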
This looks sufficient to me: the input is just a number and the output is a JAX array.
Maybe instead of making the whole `Qobj` compatible with JAX JIT, we just need an additional wrapper that converts the final JAX array back to a `Qobj`?
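A minimal sketch of such an output wrapper in pure JAX. `ArrayBox` is a hypothetical stand-in for `Qobj` (a plain class that `jit` cannot return), and `jit_then_wrap` is a hypothetical helper, not part of qutip-jax. Only the array-valued core is traced; the conversion runs afterwards in ordinary Python.

```python
import functools

import jax
import jax.numpy as jnp


class ArrayBox:
    """Hypothetical non-pytree wrapper standing in for Qobj."""

    def __init__(self, array):
        self.array = array


def jit_then_wrap(wrap):
    """Hypothetical helper: jit the array-only function, then apply `wrap` outside the trace."""
    def decorator(fun):
        jitted = jax.jit(fun)

        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            return wrap(jitted(*args, **kwargs))

        return wrapper

    return decorator


@jit_then_wrap(ArrayBox)
def scaled_pauli_product(a):
    sy = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])
    sz = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    # Same algebra as the Qobj example, written on raw arrays (@ is the matrix product).
    return a * sy.conj() @ sz + sz


result = scaled_pauli_product(3.0)
print(type(result).__name__, result.array)
```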
We need to define what is sufficient.
Is `Qobj` as output only enough?
In the example, it fails if we don't manually convert to `jax`. Do we want this to be automatic?
Do we want only operations to be supported, or should it also work with `ptrace` or `eigenenergies`?
> Is `Qobj` as output only enough?

No. This is just an example showing that even if you cannot directly `jit` a function that returns a `Qobj`, you can still define a wrapper to make it compatible fairly easily. I'm not sure how difficult it is to represent the `Qobj` class as a pytree. If it is too much work, maybe we can define a `qutip_jax.jit` that automatically converts all `Qobj` inputs to JAX arrays and converts the output (if it is a JAX array or a pytree of them) back to `Qobj`?
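The proposed `qutip_jax.jit` could look roughly like the sketch below. Everything in it is hypothetical: `Box` stands in for `Qobj`, and `box_jit` is not an existing qutip-jax API. Inputs are unwrapped with `jax.tree_util.tree_map` before entering the jitted function (which is written against raw arrays), and every array leaf of the output pytree is wrapped again afterwards.

```python
import functools

import jax
import jax.numpy as jnp


class Box:
    """Hypothetical non-pytree container standing in for Qobj."""

    def __init__(self, array):
        self.array = jnp.asarray(array)


def _is_box(x):
    return isinstance(x, Box)


def box_jit(fun):
    """Hypothetical `qutip_jax.jit`-style decorator (not an existing API)."""
    @jax.jit
    def array_fun(*args):
        # Inside the trace, every Box has already been replaced by its array.
        return fun(*args)

    @functools.wraps(fun)
    def wrapper(*args):
        # Unwrap Box inputs, even inside nested tuples/lists/dicts.
        raw = jax.tree_util.tree_map(
            lambda x: x.array if _is_box(x) else x, args, is_leaf=_is_box
        )
        out = array_fun(*raw)
        # Wrap every array leaf of the output pytree back into a Box.
        return jax.tree_util.tree_map(Box, out)

    return wrapper


@box_jit
def expect(op, state):
    return jnp.vdot(state, op @ state).real


sz = Box([[1.0, 0.0], [0.0, -1.0]])
psi = Box([0.6, 0.8])
res = expect(sz, psi)
print(type(res).__name__, res.array)
```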
Because of how JAX is implemented, you cannot jit arbitrary functions, only pure ones. It is very likely that we can never make `jax.jit` work with `qutip.mesolve`. We can only jit a custom integrator and return the result as a `Qobj`.
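As a sketch of the "custom integrator" route: a pure, array-only time stepper can be jitted end to end, and only its final array would then be turned into a `Qobj`. `rk4_evolve` is a hypothetical name, and fixed-step RK4 for the Schrödinger equation (with ħ = 1) is chosen only for brevity; it is not how QuTiP's solvers work.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


@partial(jax.jit, static_argnames="n_steps")
def rk4_evolve(H, psi0, t_final, n_steps):
    """Hypothetical fixed-step RK4 for d|psi>/dt = -i H |psi> (hbar = 1)."""
    dt = t_final / n_steps
    psi = psi0.astype(jnp.complex64)  # keep the loop carry complex

    def rhs(state):
        return -1j * (H @ state)

    def step(_, state):
        k1 = rhs(state)
        k2 = rhs(state + 0.5 * dt * k1)
        k3 = rhs(state + 0.5 * dt * k2)
        k4 = rhs(state + dt * k3)
        return state + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4)

    return jax.lax.fori_loop(0, n_steps, step, psi)


H = jnp.array([[0.0, 1.0], [1.0, 0.0]])  # sigma_x
psi0 = jnp.array([1.0, 0.0])
psi_t = rk4_evolve(H, psi0, jnp.pi / 2, n_steps=200)
# Compare with the exact propagator exp(-i H t) |psi0>.
exact = expm(-1j * H * (jnp.pi / 2)) @ psi0
print(psi_t, exact)
```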
> In the example, it fails if we don't manually convert to `jax`. Do we want this to be automatic?

Yes, that would be great. It should be feasible with some global setting, e.g. a default `dtype`?
> Do we want only operations to be supported, or should it also work with `ptrace` or `eigenenergies`?

Yes, it should also work with other `Qobj` operations like `tensor` and `ptrace`. But that should be the same as adding specialisations for addition and multiplication, no? In principle, most of them should work by replacing `np` with `jnp`. The only caveat is that `jnp.array` does not support NumPy-style in-place assignment; one has to use JAX's own syntax.
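For example, a NumPy-style construction of the annihilation operator that fills the superdiagonal in place has to be rewritten with JAX's functional `.at[...].set(...)` update. Both helpers below are hypothetical illustrations, not QuTiP's `destroy`.

```python
import jax.numpy as jnp
import numpy as np


def destroy_np(n):
    """NumPy style: allocate, then fill the superdiagonal in place."""
    op = np.zeros((n, n), dtype=np.complex64)
    idx = np.arange(n - 1)
    op[idx, idx + 1] = np.sqrt(idx + 1)
    return op


def destroy_jnp(n):
    """JAX style: `.at[...].set(...)` returns a new array instead of mutating."""
    idx = jnp.arange(n - 1)
    op = jnp.zeros((n, n), dtype=jnp.complex64)
    return op.at[idx, idx + 1].set(jnp.sqrt(idx + 1.0).astype(jnp.complex64))


a_jax = jnp.zeros((2, 2))
try:
    a_jax[0, 1] = 1.0  # JAX arrays are immutable, so this raises TypeError
except TypeError as err:
    print("TypeError:", err)

print(destroy_jnp(3))
```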
`tensor` will just need new specialisations, but `ptrace` will not be that simple. You cannot branch on input values with `jit` by default, so `ptrace`'s `sel` will cause issues. `eigenstates` returns a pair of eigenvalues and a list of `Qobj`, so I don't see how that could work...
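One common workaround is to make `sel` (and the subsystem dimensions) static arguments, so that JAX resolves the branching at trace time and recompiles once per distinct selection. A minimal sketch on dense density matrices; `ptrace_dm` is a hypothetical helper, not QuTiP's implementation, and it assumes `dims` and `sel` are passed as hashable tuples.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("dims", "sel"))
def ptrace_dm(rho, dims, sel):
    """Hypothetical partial trace keeping the subsystems listed in `sel`.

    `dims` and `sel` must be hashable tuples: they decide the reshape and
    which axes are traced out, so jit needs them at trace time.
    """
    n = len(dims)
    keep = tuple(sorted(set(sel)))
    letters = "abcdefghijklmnopqrstuvwxyz"
    rows = list(letters[:n])
    cols = list(letters[n:2 * n])
    for i in range(n):
        if i not in keep:
            cols[i] = rows[i]  # repeated index -> trace over subsystem i
    out = "".join(rows[i] for i in keep) + "".join(cols[i] for i in keep)
    spec = "".join(rows) + "".join(cols) + "->" + out
    d_keep = 1
    for i in keep:
        d_keep *= dims[i]
    # (D, D) -> (d0, ..., d_{n-1}, d0, ..., d_{n-1}), contract, reshape back.
    return jnp.einsum(spec, rho.reshape(dims + dims)).reshape(d_keep, d_keep)


rho_a = jnp.array([[0.7, 0.2], [0.2, 0.3]])
rho_b = jnp.array([[0.5, 0.5], [0.5, 0.5]])
rho = jnp.kron(rho_a, rho_b)
print(ptrace_dm(rho, (2, 2), (0,)))  # should recover rho_a
```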
Solvers and integrators cannot be jitted, nor does it make any sense to try. But we need to think about getting `grad` working with solvers.
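A minimal sketch of what "`grad` through a solver" can mean when the time stepping itself is written in JAX: the population of |1> after driving a qubit is differentiated with respect to the drive strength. `excited_population` is hypothetical, and piecewise propagation with `jax.scipy.linalg.expm` stands in for a real solver. For H = (ω/2)σx and t = 1 the analytic result is P1 = sin²(ω/2), so dP1/dω = ½ sin ω, which gives something to check against.

```python
import math

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

SX = jnp.array([[0.0, 1.0], [1.0, 0.0]], dtype=jnp.complex64)


def excited_population(omega, t=1.0, n_steps=10):
    """Hypothetical: P(|1>) after evolving |0> under H = omega/2 * sigma_x."""
    H = 0.5 * omega * SX
    u_step = expm(-1j * H * (t / n_steps))  # exact one-step propagator
    psi0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)

    def step(psi, _):
        return u_step @ psi, None

    psi, _ = jax.lax.scan(step, psi0, None, length=n_steps)
    return jnp.abs(psi[1]) ** 2


grad_fn = jax.jit(jax.grad(excited_population))
g = grad_fn(1.0)
print(float(g), 0.5 * math.sin(1.0))
```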
Thanks Boxi. Yes, I have been using JIT by shuttling back and forth between QuTiP `Qobj` and JAX arrays, and it all works with the extra step of `.to('jax')`. The `JaxArray` type registered as a pytree works with JIT, and I was hoping we could do something similar for all `Qobj` when `qutip_jax` is imported, to get rid of the extra `.to('jax')`.
Several things will break when we use JAX and JIT, as Eric pointed out. Even something as simple as `displace()` will not work, but there are workarounds. I am wondering if we can somehow mark and change all functions that are not natively JIT-able and suggest a JIT-able `qutip_jax` version, e.g. `qutip_jax.ptrace()` if the user calls `ptrace` and also wants to JIT things.
`mesolve` cannot be JITted directly, of course, but again it's possible to write a custom JAX solver that is JIT-able. I will try to post some examples this week.
(Sent by email in reply to Eric Giguère's comment of Fri, 7 Oct 2022 at 18:12.)
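One possible workaround for `displace()` is to build the operator directly from JAX arrays in a truncated Fock space whose size `n` is a static argument, so only `alpha` is traced. `destroy` and `displace` below are hypothetical JAX re-implementations for illustration, not QuTiP's functions.

```python
import math
from functools import partial

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm


def destroy(n):
    """Hypothetical truncated annihilation operator (sqrt(k) on the superdiagonal)."""
    return jnp.diag(jnp.sqrt(jnp.arange(1.0, n)), k=1).astype(jnp.complex64)


@partial(jax.jit, static_argnames="n")
def displace(alpha, n):
    """Hypothetical D(alpha) = expm(alpha a^dag - conj(alpha) a) in n Fock levels."""
    a = destroy(n)
    return expm(alpha * a.conj().T - jnp.conj(alpha) * a)


alpha = 0.5 + 0.0j
D = displace(alpha, 30)
# <0|D(alpha)|0> should match exp(-|alpha|^2 / 2) for a large enough cutoff.
print(complex(D[0, 0]), math.exp(-abs(alpha) ** 2 / 2))
```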
On second thought, maybe we should just have good documentation with a section on "sharp edges", like the JAX documentation, on writing functions that are pure and only take JAX data as inputs and outputs. As the simplest use case, I was thinking of "learning unitaries" with gradient descent: you send in the angles that define the unitary and get out some measure of fidelity as a number. That is probably enough for all the JITing we want to do. Let me post a notebook and we can discuss more.
(Sent by email as a follow-up to Shahnawaz Ahmed's message of Sun, 9 Oct 2022 at 18:45.)
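The "learning unitaries" use case could look roughly like this: three angles define a single-qubit unitary, the output is a single fidelity number, and a jitted gradient step improves it. The ZYZ Euler parametrisation, the Hadamard target, the learning rate and all function names are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative target gate (Hadamard) and step size; both are assumptions.
TARGET = jnp.array([[1.0, 1.0], [1.0, -1.0]], dtype=jnp.complex64) / jnp.sqrt(2.0)
LEARNING_RATE = 0.2


def rz(angle):
    return jnp.diag(jnp.exp(jnp.array([-0.5j, 0.5j]) * angle))


def ry(angle):
    c, s = jnp.cos(angle / 2), jnp.sin(angle / 2)
    return jnp.array([[c, -s], [s, c]]).astype(jnp.complex64)


def unitary(params):
    """ZYZ Euler parametrisation U = RZ(phi) RY(theta) RZ(lam)."""
    phi, theta, lam = params
    return rz(phi) @ ry(theta) @ rz(lam)


def fidelity(params):
    """Global-phase-insensitive gate fidelity |Tr(V^dag U)|^2 / d^2, d = 2."""
    overlap = jnp.trace(TARGET.conj().T @ unitary(params))
    return jnp.real(overlap * jnp.conj(overlap)) / 4.0


@jax.jit
def gd_step(params):
    """One gradient-descent step on the loss 1 - fidelity."""
    grads = jax.grad(lambda p: 1.0 - fidelity(p))(params)
    return params - LEARNING_RATE * grads


params = jnp.array([0.3, 1.0, 2.5])
f_init = fidelity(params)
for _ in range(200):
    params = gd_step(params)
f_final = fidelity(params)
print(f"fidelity: {float(f_init):.4f} -> {float(f_final):.4f}")
```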
The benefit of using JAX is the ability to JIT compile. With the current setup, it is not clear what the best way is to make JAX recognize QuTiP objects as valid inputs, since JIT only works with pure JAX arrays. There are workarounds, e.g., https://github.com/google/jax/blob/cc13fd1e5892a08f5360db933d4dfd64c0fc66eb/jax/experimental/lapax.py#L164. The alternative is to use `data._jxa` directly instead of passing around quantum objects, as: