Ericgig closed this 1 year ago.
Awesome Eric. Let's do this.
I tried the simplest use case for JIT: calling the JaxArray.trace function in a test. It does not work as-is, because JIT accepts only JAX types. We have to register JaxArray as a PyTree so that JAX knows how to pack and unpack a JaxArray (see strategy 3: https://jax.readthedocs.io/en/latest/faq.html#how-to-use-jit-with-methods). This was the most critical issue I ran into when I previously tried to make QuTiP and JAX work together.
It works once the JaxArray class is registered as a PyTree (see the new test). We should keep testing JIT compilation as we develop, because that is where all the power comes from.
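A minimal sketch of that registration, for reference. It assumes a hypothetical ToyJaxArray that, like JaxArray, keeps its jnp array in a _jxa attribute and treats its shape as static metadata. It uses jax.tree_util.register_pytree_node_class (strategy 3 in the linked FAQ). The real JaxArray in this PR may flatten differently.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyJaxArray:
    """Hypothetical stand-in for JaxArray: wraps a jnp array in `_jxa`."""

    def __init__(self, data, shape=None):
        self._jxa = jnp.asarray(data)
        self.shape = tuple(shape) if shape is not None else self._jxa.shape

    def tree_flatten(self):
        # Children are the traced leaves; aux_data is static and must be hashable.
        return (self._jxa,), self.shape

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__ so JAX can rebuild the object from tracers or
        # other placeholder leaves without running any conversion logic.
        obj = object.__new__(cls)
        obj._jxa = children[0]
        obj.shape = aux_data
        return obj

    def trace(self):
        return jnp.trace(self._jxa)


@jax.jit
def traced_trace(arr):
    # Inside jit, arr is a ToyJaxArray whose `_jxa` leaf is a tracer.
    return arr.trace()


x = ToyJaxArray(jnp.array([[1.0, 2.0], [3.0, 4.0]]))
print(traced_trace(x))
```

Without the decorator, jax.jit rejects the object because it is not a valid JAX type, which is the failure described above. Keeping the shape in aux_data means a call with a different shape produces a new tree structure and therefore a retrace.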
I am also switching off the workflow to build documentation for now.
Basic conversion to and from JAX works, and we can create a Qobj. I also added a ._jxa attribute to JaxArray, analogous to ._tf in qutip-tensorflow.
Should we merge this and then add the rest of the functionality? @Ericgig
Something weird is happening, though: sigmax().to('jax') does not work, even though sigmax().to('jax').data runs fine. @AGaliciaMartinez, any idea why? Am I using the data layer incorrectly?
@quantshah The patch fixes the error you had. Let's add some more tests and merge.
I added tests for using JaxArray with the Qobj interface and for converting to and from other types. The mathematical operations are not covered yet, but I think we can add those tests when we create the matching dispatched functions.
I think we can merge this.
Merging. This is really cool: I was already able to use it for a simple variational optimization algorithm. There is a small catch, however: JAX may not be able to trace through some QuTiP functions. For example, when I used the qutip.tensor function, JAX ran the code and produced output, but automatic differentiation did not work. After this is merged, I will post a use case and a test that highlight the problem, and we can see how to deal with it.
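To illustrate the kind of variational loop mentioned here, a minimal sketch in plain jax.numpy rather than QuTiP. It assumes a hypothetical one-parameter ansatz Ry(theta)|0> = [cos(theta/2), sin(theta/2)] and minimizes the expectation value of sigma_x, which for this ansatz equals sin(theta), using jax.grad and plain gradient descent.

```python
import jax
import jax.numpy as jnp

# Pauli X as a plain jnp array (stands in for sigmax() converted to JAX).
SIGMA_X = jnp.array([[0.0, 1.0], [1.0, 0.0]])


def ansatz(theta):
    """Hypothetical one-parameter ansatz: Ry(theta) applied to |0>."""
    return jnp.stack([jnp.cos(theta / 2), jnp.sin(theta / 2)])


def energy(theta):
    # <psi(theta)| sigma_x |psi(theta)>, equal to sin(theta) for this ansatz.
    psi = ansatz(theta)
    return jnp.vdot(psi, SIGMA_X @ psi)


# Compile the gradient once; every optimization step reuses it.
grad_energy = jax.jit(jax.grad(energy))


def optimize(theta0=0.1, lr=0.1, steps=200):
    theta = jnp.asarray(theta0)
    for _ in range(steps):
        theta = theta - lr * grad_energy(theta)
    return theta


theta_opt = optimize()
print(float(theta_opt), float(energy(theta_opt)))  # close to -pi/2 and -1
```

This only works because every step stays inside jax.numpy; any step that leaves JAX breaks the gradient, which is the catch described above.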
Right now, no operations have been added, so the data is always converted to Dense whenever an operation is applied to it. tensor needs the specialization for the kron product; once we add that, it should work.
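A minimal sketch of why a Dense fallback can break autodiff while a kron specialization fixes it. It assumes the fallback goes through NumPy arrays; kron_via_dense, kron_jax, make_loss and is_differentiable are hypothetical names, not the qutip or qutip-jax dispatcher. The loss is the expectation value of Z⊗Z on a two-qubit product state, which equals cos(theta)^2.

```python
import numpy as np
import jax
import jax.numpy as jnp


def kron_via_dense(a, b):
    """Hypothetical fallback: converts to NumPy before computing,
    so JAX tracers cannot pass through it."""
    return jnp.asarray(np.kron(np.asarray(a), np.asarray(b)))


def kron_jax(a, b):
    """Hypothetical JAX specialization: stays in jax.numpy, so it is traceable."""
    return jnp.kron(a, b)


def make_loss(kron):
    sigma_z = jnp.array([[1.0, 0.0], [0.0, -1.0]])

    def loss(theta):
        # theta-dependent single-qubit state, then a two-qubit product state.
        psi1 = jnp.stack([jnp.cos(theta / 2), jnp.sin(theta / 2)])
        psi = kron(psi1, psi1)
        op = kron(sigma_z, sigma_z)
        return jnp.vdot(psi, op @ psi)  # equals cos(theta)**2

    return loss


def is_differentiable(kron, theta=0.3):
    try:
        jax.grad(make_loss(kron))(theta)
        return True
    except jax.errors.TracerArrayConversionError:
        # np.asarray was called on a tracer inside the NumPy fallback.
        return False


print(is_differentiable(kron_via_dense), is_differentiable(kron_jax))
print(float(jax.grad(make_loss(kron_jax))(0.3)))  # equals -sin(0.6)
```

Both versions give the same forward value, which matches the observation that the code ran and produced output while differentiation failed.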
@quantshah, just getting the ball rolling.
The README, setup, etc. have been copied from qutip-tensorflow, but since I'm not familiar with starting a new repo, I skipped that initial PR.