qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

Use `jax.lax`? #7

Open: Ericgig opened this issue 1 year ago

Ericgig commented 1 year ago

@quantshah I have been using the jax.numpy interface to add specialisations quickly, but some of our functions are not optimal with this interface: `.dag()` calls transpose and conj, looping over the data twice; the scale in `add` is applied in an extra loop; and so on. In the unary operations I jitted functions whose loops could be fused. Should I try to use jax.lax instead, or should I be more aggressive with jitting functions? `add` and `matmul` could also benefit from jitting because of their scale argument.
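A minimal sketch of the fusion argument, assuming plain dense complex `jax.Array` inputs rather than qutip-jax's own data layer; `dag_eager`, `dag_jit`, `dag_lax` and `add_scaled` are hypothetical names, not the package's API. Wrapping the `jnp` version in `jax.jit` lets XLA see the transpose, conjugation and scaling as one computation it may fuse, while the `lax` version spells out the same primitives explicitly:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical stand-ins for qutip-jax specialisations, operating on plain
# dense complex arrays (not the package's real data layer).

def dag_eager(a):
    # Eager jnp: transpose and conj are dispatched as two separate
    # operations, each materialising its own intermediate array.
    return jnp.conj(jnp.transpose(a))

@jax.jit
def dag_jit(a):
    # Same jnp code, but traced and compiled as a whole, so XLA is free
    # to fuse the transpose and the conjugation into one kernel.
    return jnp.conj(a.T)

@jax.jit
def dag_lax(a):
    # Equivalent formulation with explicit lax primitives.
    # Assumes a complex input: lax.conj on a real array returns a complex one.
    return lax.conj(lax.transpose(a, (1, 0)))

@jax.jit
def add_scaled(left, right, scale=1.0):
    # left + scale * right: under jit the multiplication by `scale` and the
    # addition can be fused instead of running as two passes over the data.
    return left + scale * right
```

Whether XLA actually fuses a given pair of operations depends on the backend; printing `dag_jit.lower(a).compile().as_text()` for a sample input `a` is one way to inspect what was compiled.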

quantshah commented 1 year ago

I think we can be a bit more aggressive with JIT at this point and come back to jax.lax in a later iteration, while keeping track of the places where we could use jax.lax directly. Later on, for jitting conditionals or for loops, we can use lax directly, but for now it's probably better to keep everything within jnp.
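A minimal sketch of why conditionals and loops are where `jax.lax` becomes necessary inside jitted code, again on plain dense complex arrays with hypothetical function names (`matrix_power_unrolled`, `matrix_power_lax` and `normalise` are illustrations, not qutip-jax functions). A Python `for` loop is unrolled at trace time, whereas `lax.fori_loop` compiles to a single loop; a Python `if` on a traced value fails under `jax.jit`, whereas `lax.cond` chooses the branch at run time:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnums=1)
def matrix_power_unrolled(a, n):
    # Python loop: runs during tracing, so n matmuls are unrolled into the
    # compiled graph; n must be static and each new n triggers a recompile.
    out = jnp.eye(a.shape[0], dtype=a.dtype)
    for _ in range(n):
        out = out @ a
    return out

@jax.jit
def matrix_power_lax(a, n):
    # lax.fori_loop: one compiled loop, and n may be a traced value.
    eye = jnp.eye(a.shape[0], dtype=a.dtype)
    return lax.fori_loop(0, n, lambda i, acc: acc @ a, eye)

@jax.jit
def normalise(a):
    # A Python `if norm > 0:` would raise an error here because `norm`
    # is a tracer; lax.cond selects the branch at run time instead.
    norm = jnp.linalg.norm(a)
    return lax.cond(norm > 0, lambda x: x / norm, lambda x: x, a)
```

The unrolled variant recompiles for every new `n`, while the `fori_loop` variant compiles once per input shape and dtype.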