Jacobiano opened this issue 9 months ago
Hi,
Thanks for reporting the issue. This issue seems to be specific to JAX. Have you tried the solution suggested in the error message?
I think it is possible to implement this in JAX in another way. I reported the issue because, given that Keras aims to support multiple backends, I found it strange that JAX cannot handle the while_loop function properly when gradients are needed.
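I think neither JAX nor Keras is at fault here. The while_loop works as expected, but it has a known limitation: JAX cannot apply reverse-mode differentiation to a while_loop (or to a fori_loop with dynamic start/stop values), because the number of iterations, and therefore the number of intermediate values that would have to be stored for the backward pass, is not known at trace time. For JAX, it is usually advisable to use scan instead, since it runs a static number of steps and, unlike while_loop, supports reverse-mode differentiation.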
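A minimal sketch of the scan-based alternative, assuming the loop can be given an upper bound on the number of iterations. The helper `bounded_while` and the toy function `halve_until_small` are hypothetical names for illustration: `lax.scan` always runs `max_steps` steps, and once the condition becomes false the carry is simply frozen with `jnp.where`, so the result matches a capped `while_loop` while still being differentiable with `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax import lax


def bounded_while(cond_fun, body_fun, init_val, max_steps):
    """Hypothetical helper: a while loop capped at max_steps, built on scan.

    max_steps must be a Python int, so the number of scan steps is static
    and reverse-mode differentiation is supported.
    """

    def step(carry, _):
        keep_going = cond_fun(carry)
        updated = body_fun(carry)
        # Once the condition is False, keep returning the old carry,
        # which mimics the early exit of lax.while_loop.
        carry = jax.tree_util.tree_map(
            lambda new, old: jnp.where(keep_going, new, old), updated, carry
        )
        return carry, None

    final_carry, _ = lax.scan(step, init_val, xs=None, length=max_steps)
    return final_carry


def halve_until_small(x):
    # Toy example: halve x while it is greater than 1.0, at most 10 times.
    return bounded_while(lambda v: v > 1.0, lambda v: v * 0.5, x, max_steps=10)


x0 = jnp.float32(8.0)
print(halve_until_small(x0))            # 8 -> 4 -> 2 -> 1, then frozen
print(jax.grad(halve_until_small)(x0))  # derivative of x / 8
```

With this input the loop halves three times, so the value should be 1.0 and the gradient 0.125. The trade-off is that scan always executes all `max_steps` iterations, even after the condition has become false, so the cap should be chosen with some care.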
I have implemented a layer that computes morphological reconstruction. When it is not part of backpropagation, it works fine with all three backends. But when gradients have to be computed, it works without problems with the TensorFlow and PyTorch backends, but not with JAX.
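For illustration only, here is a hypothetical JAX sketch of grayscale morphological reconstruction by dilation (it is not the code from the notebook): the marker is repeatedly dilated with a 3x3 square structuring element and clipped by the mask. It assumes a fixed number of iterations passed as a Python int, so `lax.fori_loop` is lowered to `scan` and gradients can flow through it. If `num_iters` is too small for the image, the result will not have fully converged.

```python
import jax
import jax.numpy as jnp
from jax import lax


def dilate3x3(img):
    # Grayscale dilation with a 3x3 square structuring element:
    # each pixel becomes the max over its 3x3 neighbourhood.
    h, w = img.shape
    padded = jnp.pad(img, 1, mode="edge")  # replicate borders
    neighbours = [padded[i:i + h, j:j + w] for i in range(3) for j in range(3)]
    return jnp.max(jnp.stack(neighbours), axis=0)


def reconstruct_by_dilation(marker, mask, num_iters=20):
    # num_iters is a Python int, so fori_loop is lowered to scan and the
    # function can be differentiated with jax.grad.
    def body(_, h):
        # Geodesic dilation: dilate, then clip by the mask.
        return jnp.minimum(dilate3x3(h), mask)

    return lax.fori_loop(0, num_iters, body, jnp.minimum(marker, mask))


mask = jnp.array([[1.0, 5.0, 5.0, 0.0, 4.0, 4.0]])
marker = jnp.array([[0.0, 5.0, 0.0, 0.0, 0.0, 0.0]])

print(reconstruct_by_dilation(marker, mask))
grads = jax.grad(lambda m: reconstruct_by_dilation(m, mask).sum())(marker)
print(grads.shape)
```

In this toy input the plateau connected to the marker is reconstructed (giving `[[1, 5, 5, 0, 0, 0]]`), while the separate peak of height 4 is removed, and `jax.grad` runs without the while_loop error.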
JAX error:
ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
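The error message points to the static-bounds workaround. The sketch below, using a hypothetical toy `body` function, reproduces the error and the fix: when the upper bound of `lax.fori_loop` is a traced value (here, a non-static argument of `jax.jit`), JAX implements the loop with `while_loop` and `jax.grad` fails; when it is a Python int, for example by marking it with `static_argnums`, `fori_loop` is lowered to `scan` and the gradient can be computed.

```python
import jax
import jax.numpy as jnp
from jax import lax


def body(i, x):
    # Hypothetical toy update; the loop counter i is not used.
    return x + jnp.sin(x)


def loop(x, n):
    # Runs body n times. Whether this uses scan or while_loop depends on
    # whether n is a concrete Python int or a traced value.
    return lax.fori_loop(0, n, body, x)


x0 = jnp.float32(0.5)

# n is traced under jit -> fori_loop uses while_loop -> grad fails.
try:
    jax.jit(jax.grad(loop))(x0, 5)
except ValueError as err:
    print("dynamic bounds:", err)

# n is static -> fori_loop is lowered to scan -> grad works.
grad_static = jax.jit(jax.grad(loop), static_argnums=1)
print("static bounds gradient:", grad_static(x0, 5))
```

Marking `n` as static means a recompilation for every new value of `n`, which is usually acceptable when the iteration count is a fixed hyperparameter of the layer.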
The code is available at https://colab.research.google.com/drive/1bWQO6TAQeN_-a0y6iY7b_jlnGzv-XRdv?usp=sharing