itsakk opened 11 months ago
I believe you are right about the reason: the slow compilation is most likely caused by JAX unrolling the `for i, eval_point ...` loop.

About your code: you were on the right track with `lax.scan` (an alternative would be `lax.fori_loop`, but these two loop primitives are essentially the same). Besides using it instead of the for-loop, you should use `lax.cond` instead of if/else conditionals and `jax.numpy.take` to slice the arrays. More generally, I would suggest replacing all the Python list operations with `jax.numpy` operations. The point of all this is to remove the static, trace-time dependency on the values, which is what allows you to use `lax.scan` instead of the for-loop. There are a lot of changes to make, so good luck!
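For concreteness, here is a minimal sketch of the pattern (all names are illustrative, not your actual model; `step` stands in for whatever advances your ODE state between two evaluation points, e.g. a diffrax solve):

```python
import jax.numpy as jnp
from jax import lax

def rollout(step, y0, ts, y_true, forcing_mask):
    # step(y, t0, t1) advances the state from t0 to t1 (one ODE solve).
    # forcing_mask[i] marks the steps where the carried state should be
    # reset to the ground truth (teacher forcing).

    def scan_fn(y, inputs):
        i, force = inputs
        t0 = jnp.take(ts, i)            # traced indexing instead of ts[i]
        t1 = jnp.take(ts, i + 1)
        y_next = step(y, t0, t1)        # model prediction for this interval
        y_teacher = jnp.take(y_true, i + 1, axis=0)
        # lax.cond instead of a Python if/else: the predicate may be a
        # traced value, and both branches are traced exactly once.
        y_carry = lax.cond(force, lambda _: y_teacher, lambda _: y_next, None)
        return y_carry, y_next          # carry the (possibly reset) state,
                                        # collect the raw prediction

    idx = jnp.arange(ts.shape[0] - 1)
    _, y_preds = lax.scan(scan_fn, y0, (idx, forcing_mask))
    return y_preds                      # shape (len(ts) - 1, *y0.shape)
```

(Since both branch values are computed anyway here, `jnp.where(force, y_teacher, y_next)` would work just as well; `lax.cond` matters more when the branches do real work.)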
I strongly recommend reading through The Sharp Bits, which is one of the most useful resources when starting with JAX. If you have time, also work through the tutorials on the JAX website under the 'Getting Started' section: they contain a lot of crucial information that helps you debug this kind of problem faster. You might also find this answer interesting, since it explains in a bit more detail why the for-loop takes so long to compile.
Hello Patrick,
I recently tried to move from torch to jax/equinox (thanks a lot for your contributions to both diffrax and equinox), and I have been adapting all the code I had written in torch.
I am having trouble implementing teacher forcing for my NeuralODE. I used to do this in PyTorch, as I noticed it helps the network not diverge from the ground-truth trajectory.
The issue is that my code takes a long time to compile, so training the network is much slower than with torch. If I understand your comments on this kind of error correctly, the for-loop is probably the cause.
I have tried `jax.lax.scan`, but I can't find a proper way to use it together with teacher forcing.
Here is my code; you can assume that my network is just a NODE with an MLP:
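(Condensed to the relevant structure; `model` and the forcing interval `k` are illustrative stand-ins, not my exact code.)

```python
import jax.numpy as jnp

# ts / y_true: evaluation times and ground-truth trajectory;
# model(y, t0, t1) performs one ODE solve between consecutive eval points.
def rollout(model, y0, ts, y_true, k=5):
    y = y0
    preds = []                                # Python list ops in the trace
    for i, eval_point in enumerate(ts[:-1]):  # Python loop: fully unrolled,
        y = model(y, eval_point, ts[i + 1])   # so compile time grows with len(ts)
        preds.append(y)
        if (i + 1) % k == 0:                  # teacher forcing: reset the state
            y = y_true[i + 1]                 # to the ground truth every k steps
    return jnp.stack(preds)
```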
Any suggestions to increase efficiency are greatly appreciated. Thanks a lot!