Ericgig opened 1 year ago
I think we can be a bit more aggressive with JIT at this point and come back to `jax.lax` in a later iteration, while keeping track of the places where `jax.lax` could be used directly. For JIT-compiling conditionals or for loops later on, we can use `lax` directly, but for now it's probably better to keep everything within `jnp`.
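A minimal sketch of the `jnp` versus `lax` distinction, assuming plain `jax.numpy` arrays. The function names `scale_if_jnp`, `scale_if_lax` and `repeated_matmul` are hypothetical and exist only for illustration. `jnp.where` keeps a conditional inside `jnp` (both branches are computed, then one is selected), while `jax.lax.cond` and `jax.lax.fori_loop` are the `lax` primitives one would switch to later for conditionals and loops.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper: scale a matrix only when a flag is set.
# Staying within jnp: jnp.where computes both branches, then selects one.
@jax.jit
def scale_if_jnp(mat, factor, flag):
    return jnp.where(flag, mat * factor, mat)

# Same logic with jax.lax.cond: it lowers to an XLA conditional,
# so under jit only the selected branch is executed.
@jax.jit
def scale_if_lax(mat, factor, flag):
    return jax.lax.cond(flag, lambda m: m * factor, lambda m: m, mat)

# Hypothetical helper computing mat**n. jax.lax.fori_loop keeps the loop
# inside one compiled program instead of unrolling it in Python at trace time.
@jax.jit
def repeated_matmul(mat, n):
    identity = jnp.eye(mat.shape[0], dtype=mat.dtype)
    return jax.lax.fori_loop(0, n, lambda i, acc: acc @ mat, identity)
```

Both conditional versions give the same result; the choice mostly matters for performance when one branch is expensive.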
@quantshah I have been using the `jax.numpy` interface to quickly add specialisations, but some of our functions are not optimal with this interface: `.dag()` calls `transpose` and `conj`, looping over the data twice, and the scale in `add` is applied in an extra loop, etc. In unary operations I jitted functions that had loops that could be fused. Should I try to use `jax.lax` instead, or should I be more aggressive with jitting functions? `add` and `matmul` could also benefit from it because of the `scale` parameter.
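A minimal sketch of the "more aggressive jitting" option, assuming the operations act on raw `jnp` arrays. The standalone functions `dag`, `add` and `matmul` below are hypothetical stand-ins that only mirror the scale argument mentioned above; they are not the registered data-layer specialisations. The idea is that once each function is wrapped in `jax.jit`, the separate `jnp` calls are traced into one program, so XLA is free to fuse them rather than looping over the data once per call.

```python
import jax
import jax.numpy as jnp

@jax.jit
def dag(mat):
    # Two separate jnp calls (conj, then transpose); under jit, XLA may
    # fuse them into a single pass over the data.
    return jnp.conj(mat).T

@jax.jit
def add(left, right, scale=1.0):
    # Scaling and addition are traced together, so they can be fused
    # instead of running as two separate elementwise loops.
    return left + scale * right

@jax.jit
def matmul(left, right, scale=1.0):
    # The scalar multiply is part of the same compiled program as the
    # matrix product rather than a second eager operation.
    return scale * (left @ right)
```

Fusion only happens within one compiled program: if `dag` and `add` are jitted separately and called one after the other from Python, they remain two programs. Whether XLA actually fused the operations can be checked by inspecting the output of `jax.jit(dag).lower(a).compile().as_text()`.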