GallagherCommaJack opened 1 year ago
Thanks for the feature request, @GallagherCommaJack. This is definitely on our radar. Out of curiosity (and to serve as test cases), do you have examples of models in jax that use scan that you want to port over to torch?
I'm not the issue author, but one very useful test case is a sequential model (like a custom RNN). For example, this page implements a recurrent model using scan.
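As a rough illustration of why an RNN maps onto scan, here is a minimal sketch in plain Python (not taken from the linked page). It assumes a scalar Elman-style cell with hypothetical fixed weights `w_h` and `w_x`. The cell is written as a step function `(carry, x) -> (new_carry, output)`, which is the shape `jax.lax.scan` expects. `run_rnn` then threads the hidden state through the sequence with an explicit loop.

```python
import math
from typing import List, Tuple


def rnn_cell(h: float, x: float, w_h: float = 0.5, w_x: float = 1.0) -> Tuple[float, float]:
    """One step of a hypothetical scalar Elman RNN.

    Returns (new_carry, output); here the output is just the new hidden state.
    """
    h_new = math.tanh(w_h * h + w_x * x)
    return h_new, h_new


def run_rnn(xs: List[float], h0: float = 0.0) -> Tuple[float, List[float]]:
    """Thread the hidden state through the sequence, one element per step."""
    h = h0
    ys: List[float] = []
    for x in xs:
        h, y = rnn_cell(h, x)  # the carry (h) flows into the next step
        ys.append(y)           # per-step outputs are collected in order
    return h, ys


if __name__ == "__main__":
    h_final, outputs = run_rnn([1.0, 0.0, -1.0])
    print(h_final, outputs)
```

Because the cell is a pure function of `(carry, x)`, a scan primitive could replace the Python loop without changing the cell itself.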
+1 for this. I find it extremely useful for implementing rollouts in RL and RNNs.
It would be really nice to be able to, e.g., take models implemented in jax with jax.lax.scan and port them over to torch without having to unroll scans over modules.
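For reference, the semantics of `jax.lax.scan` can be written as a plain Python loop, and this loop is roughly what porting currently means unrolling by hand. The sketch below is a simplified reference, not the JAX implementation. It assumes `f(carry, x)` returns `(new_carry, y)`. Unlike the real primitive, it returns the per-step outputs as a Python list instead of stacking them into an array, and it omits options such as `reverse` and `unroll`.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def scan(
    f: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Optional[Sequence[Any]] = None,
    length: Optional[int] = None,
) -> Tuple[Any, List[Any]]:
    """Simplified reference semantics of a scan: fold a carry over xs, collecting outputs."""
    if xs is None:
        if length is None:
            raise ValueError("either xs or length must be provided")
        xs = [None] * length  # scan over "nothing" for a fixed number of steps
    carry = init
    ys: List[Any] = []
    for x in xs:
        carry, y = f(carry, x)  # each step sees the previous carry and the current input
        ys.append(y)
    return carry, ys


if __name__ == "__main__":
    # Running sum: the carry is the total so far, and each output is the prefix sum.
    total, prefix = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3, 4])
    print(total, prefix)
```

Here `total` is `10` and `prefix` is `[1, 3, 6, 10]`. A native scan in torch would let the step function stay as-is while the framework handles the looping.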