-
I am having trouble updating a jitted function within a class. The function is not updated, and my guess is that it is not being recompiled by jit after the update. It is worth noting that this only hap…
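The snippet is truncated, so the exact setup is unknown. One well-known way to get this symptom is sketched below. It assumes the method is jitted with `self` marked static (for example via `static_argnums=0`); the `Scaler` class and its methods are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp


class Scaler:  # hypothetical example class
    def __init__(self, factor):
        self.factor = factor

    # `self` is static, so jit looks up its cache using hash(self) and ==.
    # The default object hash is identity-based, so mutating `self.factor`
    # does not change the cache key: the old trace, with the old factor
    # baked in as a constant, is reused.
    @partial(jax.jit, static_argnums=0)
    def apply_stale(self, x):
        return x * self.factor

    # Fix: pass the mutable value as a regular (traced) argument instead.
    @partial(jax.jit, static_argnums=0)
    def _apply(self, x, factor):
        return x * factor

    def apply(self, x):
        return self._apply(x, self.factor)


s = Scaler(2.0)
x = jnp.array(1.0)
print(s.apply_stale(x))  # 2.0
s.factor = 3.0
print(s.apply_stale(x))  # still 2.0: stale cached trace
print(s.apply(x))        # 3.0
```

Passing the changing value as an argument costs nothing extra: a new value of the same shape and dtype reuses the compiled code without retracing.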
-
-
You can actually mark pytree fields as static. This would allow string values to be stored on the class while the class is still recognised as a pure-JAX pytree.
The issue is that the method is mu…
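A minimal sketch of the idea using plain `jax.tree_util`, assuming no helper library: the array is a dynamic leaf and the string travels as static auxiliary data. The `Layer` class and `forward` function are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Layer:  # hypothetical example class
    def __init__(self, weight, activation):
        self.weight = weight          # array: a pytree leaf, traced under jit
        self.activation = activation  # str: static, stored in the treedef

    def tree_flatten(self):
        children = (self.weight,)      # dynamic leaves
        aux_data = (self.activation,)  # static metadata; must be hashable
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)


@jax.jit
def forward(layer, x):
    y = x @ layer.weight
    # `activation` is an ordinary Python str at trace time, so branching on it
    # is fine; a different string triggers a retrace rather than an error.
    if layer.activation == "relu":
        return jax.nn.relu(y)
    return y


layer = Layer(jnp.eye(2), "relu")
print(forward(layer, jnp.array([-1.0, 2.0])))  # [0. 2.]
print(jax.tree_util.tree_leaves(layer))         # only the weight array
```

Because static data is part of the tree structure, it is also part of jit's cache key, so it should be small, hashable, and rarely changing.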
-
Hello,
When solving the (trivial) SDE $d y_t = -y_t\ dt + 0.2\ dW_t$, the Diffrax Euler solver is ~200x slower than a naive for loop. Am I doing something wrong? The speed difference is consistent …
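For comparison, a hand-rolled baseline is sketched below: Euler-Maruyama for this SDE, compiled once and stepped with `jax.lax.scan`. It is not Diffrax's implementation, and `euler_maruyama`, `em`, and their parameters are illustrative. Whatever the cause of the gap, JAX timings are only meaningful if the first (compiling) call is excluded and results are waited on with `.block_until_ready()`, because dispatch is asynchronous.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(key, y0, t1, n_steps, sigma=0.2):
    """Euler-Maruyama for dy = -y dt + sigma dW on [0, t1]."""
    dt = t1 / n_steps
    # Brownian increments dW_n ~ N(0, dt), sampled up front.
    dW = jnp.sqrt(dt) * jax.random.normal(key, (n_steps,))
    # Match the carry dtype to the increments so scan sees a consistent type.
    y0 = jnp.asarray(y0, dtype=dW.dtype)

    def step(y, dw):
        y_next = y - y * dt + sigma * dw  # y_{n+1} = y_n - y_n dt + sigma dW_n
        return y_next, y_next

    _, ys = jax.lax.scan(step, y0, dW)
    return ys  # values at t = dt, 2 dt, ..., t1


# n_steps fixes an array shape, so it must be static.
em = jax.jit(euler_maruyama, static_argnums=3)

key = jax.random.PRNGKey(0)
ys = em(key, 1.0, 1.0, 1000).block_until_ready()  # first call compiles
ys = em(key, 1.0, 1.0, 1000).block_until_ready()  # time this call, not the first
```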
-
Is it possible to do the following?
I would like to take a gradient through an argmin involving a closure, where the function passed to the solver contains `y`, which is not passed as an argument…
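The original example is truncated, so this is only a generic sketch: `objective`, `argmin`, and the step size are hypothetical, and plain unrolled gradient descent stands in for whatever solver is actually used. It passes `y` to the objective explicitly instead of closing over it, which is the form that solvers using implicit differentiation typically need in order to differentiate with respect to `y`. Unrolling, as here, simply differentiates through every iteration.

```python
import jax
import jax.numpy as jnp


def objective(x, y):
    # Minimised at x* = y / 2, so d x*/dy = 0.5.
    return (x - y) ** 2 + x ** 2


def argmin(y, lr=0.1, n_steps=200):
    grad_x = jax.grad(objective, argnums=0)

    def body(_, x):
        # One gradient-descent step; y enters as an explicit argument.
        return x - lr * grad_x(x, y)

    x0 = jnp.zeros_like(y)
    # Static (Python int) loop bounds let fori_loop lower to scan,
    # which supports reverse-mode differentiation.
    return jax.lax.fori_loop(0, n_steps, body, x0)


print(argmin(1.0))            # ~0.5
print(jax.grad(argmin)(1.0))  # ~0.5
```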
-
Hello, this looks like a really cool library that I would love to use in my research. However, I am not able to run the provided example code using either the PyTorch or JAX backend. Specifically, her…
-
I am interested in porting some object detection models from PyTorch to JAX. In particular, I am looking at the RCNN family of models.
After looking into some existing implementations in TF/PyTorc…
-
### Description
[JAX](https://github.com/google/jax) includes a NumPy-compatible `jax.numpy` module which has a bunch of nice features (automatic differentiation, jit compilation, vectorized mappin…
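As a small illustration of how automatic differentiation, jit compilation, and vectorized mapping (`jax.vmap`) compose on ordinary `jax.numpy` code (the `loss` function is a made-up example):

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Written with jax.numpy exactly as one would with NumPy.
    return jnp.sum((x @ w) ** 2)


grad_loss = jax.jit(jax.grad(loss))              # autodiff + JIT compilation
per_example = jax.vmap(loss, in_axes=(None, 0))  # map over rows of x, sharing w

w = jnp.array([1.0, 2.0])
x = jnp.eye(2)
print(loss(w, x))         # 5.0
print(grad_loss(w, x))    # [2. 4.]
print(per_example(w, x))  # [1. 4.]
```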
-
### System information
flax Version: 0.9.0
jax Version: 0.4.34
jaxlib Version: 0.4.34
on my laptop, running in CPU-only mode.
flax Version: 0.8.5
jax Version: 0.4.33
jaxlib Version: 0.4.33
on goo…
-
Dear JAX team,
I am building an optics package [morphine](https://github.com/benjaminpope/morphine) that uses JAX to do autodiff for telescope simulations.
I'm encountering a bug:
`Unexpecte…