qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

Initial JAX data layer definition #1

Closed: Ericgig closed this 1 year ago

Ericgig commented 1 year ago

@quantshah, just getting the ball rolling.

The README, setup files, etc. have been copied from qutip-tensorflow. Since I'm not familiar with starting a new repo, I skipped that initial PR.

quantshah commented 1 year ago

Awesome Eric. Let's do this.

quantshah commented 1 year ago

I tried the simplest use case for JIT: a test that calls the JaxArray.trace function under jit. It does not work, because jit only accepts valid JAX types as arguments. We have to register JaxArray as a PyTree so that JAX knows how to pack and unpack a JaxArray (see strategy 3: https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods). This was the most critical issue for me when I previously tried to make QuTiP and JAX work together.

Simply registering the JaxArray class as a PyTree makes it work (see the new test). We should keep testing JIT compilation as we develop, because that is where all the power comes from.
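A minimal sketch of the PyTree registration idea, assuming a hypothetical `ToyJaxArray` that stands in for qutip-jax's `JaxArray` (the real class has more state and methods). The wrapped JAX array is the only traced leaf; the shape is treated as static metadata. `jax.tree_util.register_pytree_node_class` is the registration helper described in the JAX docs; everything else here is illustrative.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyJaxArray:
    """Hypothetical stand-in for JaxArray: the JAX array is stored in ``_jxa``."""

    def __init__(self, jxa, shape=None):
        self._jxa = jnp.asarray(jxa)
        self.shape = tuple(shape) if shape is not None else self._jxa.shape

    def tree_flatten(self):
        # Children are the leaves JAX traces; aux data is static and hashable.
        return (self._jxa,), self.shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, since JAX may pass tracers
        # (or other placeholder objects) as children.
        obj = object.__new__(cls)
        obj._jxa = children[0]
        obj.shape = aux_data
        return obj

    def trace(self):
        return jnp.trace(self._jxa)


@jax.jit
def jitted_trace(arr):
    # Only works because ToyJaxArray is registered as a PyTree; an
    # unregistered class is rejected by jit as not a valid JAX type.
    return arr.trace()


@jax.jit
def scaled(arr, factor):
    # Returning a ToyJaxArray from jit also works: JAX flattens the output
    # and rebuilds it with tree_unflatten.
    return ToyJaxArray(factor * arr._jxa, arr.shape)
```

Keeping the shape in the aux data means it stays a plain Python tuple inside jitted code, so shape checks do not depend on traced values.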

quantshah commented 1 year ago

I am also switching off the workflow to build documentation for now.

quantshah commented 1 year ago

Basic conversion to and from JAX works, and we can create a Qobj. I also added the ._jxa attribute to JaxArray, analogous to ._tf in qutip-tensorflow.
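A rough sketch of what "conversion to and from JAX" means at the data-layer level, assuming a hypothetical `ToyJaxArray` with a `_jxa` attribute and plain NumPy arrays standing in for QuTiP's Dense type. The conversion function names are hypothetical; registering them with QuTiP's data-layer conversion machinery is not shown.

```python
import numpy as np
import jax.numpy as jnp


class ToyJaxArray:
    """Hypothetical minimal data-layer class; the JAX array lives in ``_jxa``."""

    def __init__(self, data, shape=None):
        self._jxa = jnp.asarray(data)
        self.shape = tuple(shape) if shape is not None else self._jxa.shape


def numpy_to_jax(matrix):
    # Hypothetical Dense -> JAX conversion. complex64 is used because JAX
    # disables 64-bit types unless jax_enable_x64 is set.
    return ToyJaxArray(jnp.asarray(matrix, dtype=jnp.complex64), matrix.shape)


def jax_to_numpy(arr):
    # Hypothetical JAX -> Dense conversion: copy the JAX array back to NumPy.
    return np.asarray(arr._jxa)
```

A round trip of `sigmax`'s matrix should return the same values, with the dtype narrowed to `complex64` under JAX's default settings.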

Should we merge this and add the remaining functionality afterwards? @Ericgig

quantshah commented 1 year ago

Something weird is happening, though: sigmax().to('jax') does not work, even though sigmax().to('jax').data runs fine. @AGaliciaMartinez, any idea why? Am I using the data layer incorrectly?

Ericgig commented 1 year ago

@quantshah This patch fixes the error you had. Let's add some more tests and merge.

Ericgig commented 1 year ago

I added tests for using JaxArray through the Qobj interface and for converting to and from other types. The mathematical operations are not covered yet, but I think we can add those tests when we create the matching dispatched functions.

I think we can merge this.

quantshah commented 1 year ago

Merging. This is really cool: I was already able to use this for a simple variational optimization algorithm. However, there is a small catch: JAX may not be able to trace through QuTiP functions. For example, when I used the qutip.tensor function, JAX ran the code and produced output, but automatic differentiation did not work. After this is merged, I will post a use case and a test that highlight this, and we can see how to deal with it.

Ericgig commented 1 year ago

Right now, no operations have been added, so a JaxArray is always converted to Dense whenever an operation is applied to it, which moves the data out of JAX and breaks tracing. tensor needs a JaxArray specialisation of the kron product. Once we add that, it should work.
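A small sketch of why the kron specialisation matters for autodiff, assuming plain JAX arrays instead of QuTiP objects. `kron_jax` plays the role of a hypothetical JaxArray specialisation, and `kron_via_numpy` mimics the current fallback (convert to a NumPy-backed type, compute, convert back). Both give the same forward result, matching the observation that the code runs, but only the first can be differentiated.

```python
import numpy as np
import jax
import jax.numpy as jnp


def kron_jax(a, b):
    # Hypothetical JAX specialisation of the tensor (Kronecker) product:
    # staying inside jax.numpy keeps the operation traceable.
    return jnp.kron(a, b)


def kron_via_numpy(a, b):
    # Mimics the fallback path: leave JAX (like converting to Dense),
    # compute with NumPy, then convert back. np.asarray fails on tracers.
    return jnp.asarray(np.kron(np.asarray(a), np.asarray(b)))


def expectation(theta, kron):
    # <psi| Z (x) Z |psi> for psi = (cos t |0> + sin t |1>) (x) |0>,
    # which equals cos(2 t).
    sz = jnp.array([[1.0, 0.0], [0.0, -1.0]])
    zz = kron(sz, sz)
    q0 = jnp.array([jnp.cos(theta), jnp.sin(theta)])
    q1 = jnp.array([1.0, 0.0])
    psi = kron(q0, q1)
    return psi @ zz @ psi


# Forward pass: both versions run and agree.
value_jax = expectation(0.3, kron_jax)
value_numpy = expectation(0.3, kron_via_numpy)

# Gradient: only the JAX-native kron can be traced by jax.grad.
grad_jax = jax.grad(lambda t: expectation(t, kron_jax))(0.3)
```

Calling `jax.grad` on the `kron_via_numpy` version raises `jax.errors.TracerArrayConversionError`, which is the kind of failure a missing specialisation produces.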