qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

fix _eigs_jaxarray to be compatible with jit #52

Closed: rochisha0 closed this pull request 3 weeks ago.

rochisha0 commented 3 weeks ago

Description

This PR is part of an effort to make the functions in qutip.core.metrics and the entropy functions work under jax.jit. It makes _eigs_jaxarray compatible with jax.jit.
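Under jax.jit, a Python `if` cannot branch on a traced array value, so a flag that selects the eigensolver (such as Hermiticity) has to be known at trace time. A minimal sketch, assuming the flag is passed as a static argument. `eigvals_sketch` is a hypothetical name and not the qutip-jax implementation of _eigs_jaxarray:

```python
from functools import partial

import jax
import jax.numpy as jnp


# Hypothetical helper (not qutip-jax's _eigs_jaxarray). Marking `isherm`
# as static means the Python `if` is resolved once per flag value at
# trace time, instead of branching on a traced array.
@partial(jax.jit, static_argnames=("isherm",))
def eigvals_sketch(mat, isherm=True):
    if isherm:
        # Hermitian solver: real eigenvalues in ascending order.
        return jnp.linalg.eigvalsh(mat)
    # General solver: complex eigenvalues in no guaranteed order.
    # Depending on the JAX version, this may only be supported on CPU.
    return jnp.linalg.eigvals(mat)


sigma_z = jnp.array([[1.0, 0.0], [0.0, -1.0]])
print(eigvals_sketch(sigma_z))                # Hermitian path
print(eigvals_sketch(sigma_z, isherm=False))  # general path
```

Each distinct value of `isherm` triggers a separate compilation, which is the usual trade-off of static arguments.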

Result

With this change, trace_dist works with jax.jit.
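For reference, the trace distance between density matrices A and B is 0.5 * Tr|A - B|. Because A - B is Hermitian, this equals half the sum of the absolute values of its eigenvalues, so a jit-compatible Hermitian eigenvalue routine is all that is needed. A minimal pure-JAX sketch on plain arrays. `trace_dist_sketch` is a hypothetical stand-in, not QuTiP's `trace_dist`:

```python
import jax
import jax.numpy as jnp


@jax.jit
def trace_dist_sketch(a, b):
    # a - b is Hermitian for density matrices a and b, so Tr|a - b| is
    # the sum of the absolute values of its (real) eigenvalues.
    evals = jnp.linalg.eigvalsh(a - b)
    return 0.5 * jnp.sum(jnp.abs(evals))


rho0 = jnp.array([[1.0, 0.0], [0.0, 0.0]])  # |0><0|
rho1 = jnp.array([[0.0, 0.0], [0.0, 1.0]])  # |1><1|
plus = jnp.full((2, 2), 0.5)                # |+><+|
print(trace_dist_sketch(rho0, rho1))  # orthogonal pure states
print(trace_dist_sketch(rho0, plus))  # pure states with overlap 1/2
```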

rochisha0 commented 3 weeks ago

@Ericgig I have made some changes; please have a look. As far as I have tested, this approach lets us work around jax.lax.cond.
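For context, the alternative being worked around here would select the solver with jax.lax.cond on a traced flag. `lax.cond` requires both branches to return the same shape and dtype, so the real output of the Hermitian solver has to be cast to match the complex output of the general one. Both branches are also traced and compiled, even though only one runs. A sketch of that alternative, using the hypothetical name `eigvals_cond_sketch`. It is not the approach this change uses:

```python
import jax
import jax.numpy as jnp


@jax.jit
def eigvals_cond_sketch(mat, isherm):
    # Here `isherm` is a traced boolean, so a Python `if` would fail.
    # lax.cond needs both branches to produce matching shapes and dtypes.
    def hermitian_branch(m):
        # eigvalsh returns real values; cast so the dtype matches eigvals.
        return jnp.linalg.eigvalsh(m).astype(jnp.complex64)

    def general_branch(m):
        return jnp.linalg.eigvals(m).astype(jnp.complex64)

    return jax.lax.cond(isherm, hermitian_branch, general_branch, mat)


sigma_x = jnp.array([[0.0, 1.0], [1.0, 0.0]])
print(eigvals_cond_sketch(sigma_x, True))
print(eigvals_cond_sketch(sigma_x, False))
```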

rochisha0 commented 3 weeks ago

@Ericgig You are right; I missed that JAX only supports grad of eigenvalues, not eigenvectors.
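In JAX, the derivative rule for the general `eig` covers eigenvalues only, so a differentiable path should call eigenvalue-only routines. In the sketch below, `jax.grad` differentiates the largest eigenvalue of a Hermitian matrix and the sum of the general eigenvalues of a non-Hermitian one. The function names are hypothetical, and no eigenvectors are requested:

```python
import jax
import jax.numpy as jnp


def largest_eigval(m):
    # eigvalsh returns eigenvalues in ascending order; take the last one.
    return jnp.linalg.eigvalsh(m)[-1]


def sum_of_eigvals(m):
    # The general solver returns complex values; keep the output real
    # so that jax.grad accepts it. The sum equals the trace of m.
    return jnp.sum(jnp.linalg.eigvals(m).real)


h = jnp.array([[1.0, 0.0], [0.0, 3.0]])
a = jnp.array([[2.0, 1.0], [0.0, 3.0]])  # non-Hermitian, distinct eigenvalues

# d(lambda_max)/dM is the projector onto the top eigenvector, here |1><1|.
print(jax.grad(largest_eigval)(h))
# d(trace)/dM is the identity.
print(jax.grad(sum_of_eigvals)(a))
```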

Ericgig commented 3 weeks ago

Looks good. Does it work when isherm=None? That case is not included in our test cases.
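Here `isherm=None` presumably means that Hermiticity is not known in advance. Inside jax.jit, checking it from the matrix itself yields a traced boolean that cannot drive a Python `if`. One option, sketched here under that assumption, is to treat `None` like `False` and use the general solver, which also works for Hermitian matrices. The output is sorted so the order is deterministic. `eigvals_any_sketch` is hypothetical and is not the code merged in this PR:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames=("isherm",))
def eigvals_any_sketch(mat, isherm=None):
    if isherm:
        # Caller asserted Hermiticity: use the Hermitian solver.
        return jnp.linalg.eigvalsh(mat).astype(jnp.complex64)
    # isherm is False or None (unknown). A check such as
    # jnp.allclose(mat, mat.conj().T) would be a tracer here, so this
    # sketch uses the general solver, which works for any square matrix.
    evals = jnp.linalg.eigvals(mat).astype(jnp.complex64)
    # Sort by real part so the order matches the Hermitian path.
    return evals[jnp.argsort(evals.real)]


sigma_y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]], dtype=jnp.complex64)
print(eigvals_any_sketch(sigma_y))               # isherm=None
print(eigvals_any_sketch(sigma_y, isherm=True))  # Hermitian path
```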

rochisha0 commented 3 weeks ago

@Ericgig I have added the suggested changes and the required tests are now passing.

rochisha0 commented 3 weeks ago

Yes, it works. @Ericgig

Ericgig commented 3 weeks ago

Could you add or modify a test for it?
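A test for this case could check that jitted and eager results agree for every value of `isherm`, including `None`. A minimal pytest-style sketch that uses a hypothetical eigenvalue helper defined inline. A real test would call the qutip-jax function under test instead:

```python
import jax
import jax.numpy as jnp


def eigvals_any(mat, isherm=None):
    # Hypothetical stand-in for the function under test.
    if isherm:
        return jnp.linalg.eigvalsh(mat).astype(jnp.complex64)
    evals = jnp.linalg.eigvals(mat).astype(jnp.complex64)
    return evals[jnp.argsort(evals.real)]


def test_eigvals_jit_matches_eager():
    x = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
    herm = x + x.T  # real symmetric, hence Hermitian
    jitted = jax.jit(eigvals_any, static_argnames=("isherm",))
    for isherm in (True, False, None):
        expected = eigvals_any(herm, isherm=isherm)
        result = jitted(herm, isherm=isherm)
        assert jnp.allclose(result, expected, atol=1e-4)
    # The unknown-Hermiticity path should agree with the Hermitian one.
    assert jnp.allclose(jitted(herm, isherm=None), jitted(herm, isherm=True), atol=1e-4)
```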