qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

JIT compilation using Qobj #6

Open quantshah opened 1 year ago

quantshah commented 1 year ago

The benefit of using JAX is the ability to JIT compile. With the setup right now, it is not clear what the best way is to make JAX recognize QuTiP objects as valid inputs, since jax.jit only accepts JAX arrays (or pytrees of them) as inputs and outputs. There are workarounds, e.g., https://github.com/google/jax/blob/cc13fd1e5892a08f5360db933d4dfd64c0fc66eb/jax/experimental/lapax.py#L164. The alternative is to use data._jxa instead of passing quantum objects around, as in the first of these two tests:

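import jax
import qutip_jax  # importing registers the "jax" data layer with QuTiP
from qutip import sigmax, sigmay

def test_jit_from_jxa():
    """Test JIT of add using the _jxa array"""

    @jax.jit
    def func():
        return sigmax().to("jax").data._jxa + sigmay().to("jax").data._jxa

    # jax.Array replaces jax.interpreters.xla.DeviceArray in recent JAX releases
    assert isinstance(func(), jax.Array)

def test_jit_from_qobj():
    """Test JIT of add directly using Qobj"""

    @jax.jit
    def func():
        # the case this issue is about: func returns a Qobj, not an array
        return sigmax() + sigmay()

    assert isinstance(func(), jax.Array)

Ericgig commented 1 year ago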

Qobj is designed to work with multiple data types mixed together, and it's not clear that JAX can find the mathematical operation inside a dispatched function. The dispatcher itself is Cython-compiled.

We could try to register the Qobj as a pytree node when importing qutip-jax. Also giving Qobj a way to extract the specializations and skip the dispatcher will probably be required.

We should be able to set it to work with .data.
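
For reference, a minimal sketch of what registering a class as a pytree node looks like. `Op` is a hypothetical stand-in for Qobj (an array plus static `dims`), not the real class, and nothing here touches the QuTiP dispatcher; it only shows the `jax.tree_util.register_pytree_node` mechanics.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node


class Op:
    """Hypothetical stand-in for a Qobj: a JAX array plus static metadata."""

    def __init__(self, data, dims):
        self.data = data  # array leaf, traced by JAX
        self.dims = dims  # static metadata, must be hashable

    def __add__(self, other):
        return Op(self.data + other.data, self.dims)


def _flatten(op):
    # children are traced; the second element is static auxiliary data
    return (op.data,), op.dims


def _unflatten(dims, children):
    return Op(children[0], dims)


register_pytree_node(Op, _flatten, _unflatten)


@jax.jit
def add(a, b):
    # inputs and output are Op instances, not bare arrays
    return a + b


sx = Op(jnp.array([[0.0, 1.0], [1.0, 0.0]]), ((2,), (2,)))
sz = Op(jnp.array([[1.0, 0.0], [0.0, -1.0]]), ((2,), (2,)))
out = add(sx, sz)
```

Once the class is registered, jax.jit flattens it on the way in and rebuilds it on the way out, so the jitted function can both take and return the wrapper type.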

BoxiLi commented 1 year ago

Thank you @Ericgig and @quantshah for making this repo and pushing this forward!! This is amazing!

I just wanna add a small bit of my understanding of JAX JIT. It seems to me that JIT (and jax.grad etc.) works as long as the function-to-be-jitted only takes inputs and gives outputs that are supported by JAX, e.g. jnp arrays, and the input-dependent part only uses JAX-supported operations. Any other complication in generating this function doesn't matter as long as the tracer can go through the Python code and translate it to XLA code at run time. So in principle, the current implementation should be sufficient as long as we are a little bit careful with the function we jit.

The following JIT example works fine for me with the current master branch.

import qutip
import qutip_jax
import jax

@jax.jit
def fun(a):
    M = qutip.sigmay().to("jax")
    N = qutip.sigmaz().to("jax")
    return (a * M.conj() * N + N).data._jxa

fun(3.)

This looks sufficient to me. The input is just numbers and the output is a JAX array.

Maybe instead of making the whole Qobj compatible with JAX JIT, we just need to add an additional wrapper that transfers the final JAX array back to a Qobj?

Ericgig commented 1 year ago

We need to define what is sufficient.

Is having Qobj only as output enough? In the example, it fails if we don't manually convert to jax. Do we want this to be automatic? Do we want only operations to be supported, or should it also work with ptrace or eigenenergies?

BoxiLi commented 1 year ago

Is having Qobj only as output enough?

No. This is just an example showing that even if you cannot directly jit a function that returns a Qobj, you can still define a wrapper to make it compatible fairly easily. I'm not sure how difficult it is to represent the Qobj class as a pytree. If it is too much work, maybe we can define a qutip_jax.jit that automatically converts all the Qobj inputs to JAX arrays and converts the output (if it is a JAX array or a pytree of them) back to Qobj?
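
A rough sketch of such a wrapper, to make the idea concrete. `Box` and `boxed_jit` are hypothetical names made up for illustration (`Box` stands in for a Qobj that holds a JAX array); this is not an existing qutip_jax API, and it only handles positional arguments and a single array output.

```python
import functools

import jax
import jax.numpy as jnp


class Box:
    """Hypothetical stand-in for a Qobj holding a JAX array in `.arr`."""

    def __init__(self, arr):
        self.arr = jnp.asarray(arr)


def boxed_jit(fn):
    """Hypothetical helper: jit `fn` on raw arrays and box the result."""
    jitted = jax.jit(fn)

    @functools.wraps(fn)
    def wrapper(*args):
        # unwrap Box inputs so that jax.jit only ever sees arrays or scalars
        raw = [a.arr if isinstance(a, Box) else a for a in args]
        return Box(jitted(*raw))

    return wrapper


@boxed_jit
def scaled_product(a, m, n):
    # inside the jitted body everything is a plain JAX array
    return (a * m) @ n + n


sy = Box([[0, -1j], [1j, 0]])
sz = Box([[1, 0], [0, -1]])
res = scaled_product(3.0, sy, sz)  # a Box, computed by a jitted function
```

The conversion happens outside the traced function, so the user-facing call takes and returns the wrapper type while JAX never has to know about it.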

Given how JAX is implemented, you cannot jit arbitrary functions, only pure ones. It is very likely that we can never make jax.jit work with qutip.mesolve. We can only jit a customized integrator and return the result as a Qobj.

In the example, it fails if we don't manually convert to jax. Do we want this to be automatic?

Yes, that would be great. It should be feasible with some global settings? E.g. with a default dtype.

Do we want only operations to be supported, or should it also work with ptrace or eigenenergies?

Yes, it should also work with other Qobj operations like tensor and ptrace. But that should be just the same as adding specialisations like addition and multiplication, no? In principle, most of them should work by replacing np with jnp. The only caveat is that JAX arrays are immutable and do not support NumPy-style in-place assignment. One has to use JAX's own syntax, x.at[idx].set(value).
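
A minimal illustration of that caveat, using only `jax.numpy`. The array names are arbitrary.

```python
import jax.numpy as jnp

x = jnp.zeros((2, 2))

# x[0, 1] = 1.0 would raise TypeError, because JAX arrays are immutable.
y = x.at[0, 1].set(1.0)  # returns a new array; x is left unchanged
z = y.at[1, 0].add(2.0)  # functional counterpart of y[1, 0] += 2.0
```

Any specialisation that currently fills a NumPy buffer element by element would need this kind of rewrite.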

Ericgig commented 1 year ago

tensor will just be a matter of adding a new specialisation, but ptrace will not. You cannot branch on an input value with jit by default, so ptrace's sel will cause issues. eigenstates returns a pair of eigenvalues and a list of Qobj, so I don't see how that could work...
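
One common way around the branching problem is to mark the selector as static, so the Python branch is resolved at trace time. Below is a sketch for a two-subsystem partial trace; `ptrace_bipartite`, `dims` and `keep` are hypothetical names for illustration and not QuTiP's ptrace signature. The trade-off is that each distinct value of a static argument triggers a recompilation.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("dims", "keep"))
def ptrace_bipartite(rho, dims, keep):
    """Partial trace of a density matrix on two subsystems.

    `dims` = (dA, dB) and `keep` (0 or 1) are static, so the Python `if`
    below sees concrete values instead of traced ones.
    """
    dA, dB = dims
    r = rho.reshape(dA, dB, dA, dB)
    if keep == 0:
        return jnp.einsum("ajbj->ab", r)  # trace out subsystem B
    return jnp.einsum("iaib->ab", r)      # trace out subsystem A


rho_a = jnp.array([[0.25, 0.0], [0.0, 0.75]])
rho_b = jnp.array([[0.6, 0.2], [0.2, 0.4]])
rho = jnp.kron(rho_a, rho_b)

red_a = ptrace_bipartite(rho, dims=(2, 2), keep=0)
red_b = ptrace_bipartite(rho, dims=(2, 2), keep=1)
```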

The solver and integrator cannot be jitted, nor does it make any sense to try. But we need to think about getting grad working with solvers.

quantshah commented 1 year ago

Thanks Boxi. Yes I have been using JIT by shuttling to and fro QuTiP Qobj and it all works with the extra step of .to('jax'). The JaxArray type registered as a PyTree works for JIT and I was hoping we could do something like that for all Qobj when qutip_jax is imported to get rid of the extra .to('jax').

Several things will break when we use JAX and JIT, as Eric pointed out. Even something as simple as displace() will not work, but there are workarounds. I am wondering if we can somehow mark all functions that are not natively JITable and suggest a JITable qutip_jax version, e.g., qutip_jax.ptrace() if the user calls ptrace and also wants to JIT things.

mesolve cannot be JITted directly, of course, but there again it's possible to have a custom solver written in JAX that is JITable. I will try to post some examples this week.
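
As a placeholder until then, here is a minimal sketch of what a JITable custom solver could look like: a fixed-step RK4 integrator for the Schrödinger equation written with `jax.lax.scan`. The Hamiltonian (a resonantly driven qubit), the step count and the function names are all assumptions chosen for illustration; this is not how mesolve works internally.

```python
import jax
import jax.numpy as jnp

SX = jnp.array([[0, 1], [1, 0]], dtype=jnp.complex64)
SZ = jnp.array([[1, 0], [0, -1]], dtype=jnp.complex64)


def rhs(psi, omega):
    # Schrodinger equation d(psi)/dt = -i H psi with H = (omega / 2) * sigma_x
    return -1j * (0.5 * omega * SX) @ psi


@jax.jit
def final_z(omega):
    """Evolve |0> up to t = 1 with fixed-step RK4 and return <sigma_z>."""
    n_steps = 200
    dt = 1.0 / n_steps

    def step(psi, _):
        k1 = rhs(psi, omega)
        k2 = rhs(psi + 0.5 * dt * k1, omega)
        k3 = rhs(psi + 0.5 * dt * k2, omega)
        k4 = rhs(psi + dt * k3, omega)
        return psi + dt / 6.0 * (k1 + 2 * k2 + 2 * k3 + k4), None

    psi0 = jnp.array([1.0, 0.0], dtype=jnp.complex64)
    psi, _ = jax.lax.scan(step, psi0, None, length=n_steps)
    return jnp.real(jnp.vdot(psi, SZ @ psi))


z = final_z(jnp.pi)                 # a pi pulse should flip the qubit
dz = jax.grad(final_z)(jnp.pi / 2)  # the gradient flows through the integrator
```

For this Hamiltonian the exact answer is cos(omega), so `z` should be close to -1 and `dz` close to -1. The function takes a number and returns a number, which is what makes both jit and grad apply directly.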

quantshah commented 1 year ago

On second thought, maybe we should just have good documentation and a section about "sharp edges", like the one in the JAX documentation, on writing functions that are pure and only take JAX data as inputs and outputs. In the simplest use case, I was thinking of "learning unitaries" using gradient descent. You send in angles that define the unitary and get out some measure of fidelity as a number. Probably that's enough for all the JITing we want to do. Let me post a notebook and we can discuss more.
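
A toy version of that use case in plain JAX, to show the shape of the function being jitted: one angle in, one number out. The single-parameter rotation, the target, the learning rate and the iteration count are all arbitrary choices for illustration, not taken from the notebook mentioned above.

```python
import jax
import jax.numpy as jnp


def ry(theta):
    # rotation about the y axis by angle theta (a real 2x2 unitary)
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]])


TARGET = ry(jnp.pi / 2)  # the unitary we want to learn


@jax.jit
def infidelity(theta):
    # 1 - |Tr(U_target^dagger U(theta))|^2 / d^2, with d = 2
    overlap = jnp.trace(TARGET.conj().T @ ry(theta))
    return 1.0 - jnp.abs(overlap) ** 2 / 4.0


grad_fn = jax.jit(jax.grad(infidelity))

theta = 0.1
for _ in range(200):
    # plain gradient descent on the angle
    theta = theta - 0.5 * grad_fn(theta)
```

After the loop, `theta` should sit close to pi / 2. Everything QuTiP-specific would live inside `ry` and `infidelity`, while the jitted boundary only sees numbers.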