pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

support scan #1036

Open GallagherCommaJack opened 1 year ago

GallagherCommaJack commented 1 year ago

It would be really nice to be able to, e.g., take models implemented in JAX with `jax.lax.scan` and port them over to torch without having to unroll scans over modules.
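For context, the semantics of `jax.lax.scan` can be sketched in plain Python. This is a simplified version of the reference definition in the JAX docs: it ignores pytrees, the `length`/`reverse` arguments, and stacking outputs into an array. Without a native scan, a torch port ends up running this loop step by step, which is the unrolling the request wants to avoid.

```python
def scan(f, init, xs):
    """Simplified sketch of jax.lax.scan semantics.

    f maps (carry, x) -> (new_carry, y). Returns the final carry and the
    list of per-step outputs ys (JAX would stack these into one array).
    """
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


# Example: a running sum, where the carry is the total so far.
def cumsum_step(total, x):
    total = total + x
    return total, total


final, partial_sums = scan(cumsum_step, 0, [1, 2, 3, 4])
print(final, partial_sums)  # 10 [1, 3, 6, 10]
```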

zou3519 commented 1 year ago

Thanks for the feature request, @GallagherCommaJack. This is definitely on our radar. Out of curiosity (and to serve as test cases), do you have examples of models in JAX that use scan that you want to port over to torch?

hbenazha commented 1 year ago

I'm not the author, but one very useful test case is a sequential model, such as a custom RNN. For example, this page implements a recurrent model using scan.
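As an illustration of why an RNN maps naturally onto scan: the hidden state is the carry, and each step consumes one input. Below is a minimal sketch using a hypothetical scalar cell (`make_rnn_cell` and the weights `w_h`, `w_x`, `b` are made up for the example, not taken from the linked page), with the same simplified scan as a plain Python loop.

```python
import math


def scan(f, init, xs):
    # Simplified scan: thread a carry through f and collect per-step outputs.
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def make_rnn_cell(w_h, w_x, b):
    # Hypothetical scalar "vanilla RNN" cell:
    # h_t = tanh(w_h * h_{t-1} + w_x * x_t + b)
    def cell(h, x):
        h_new = math.tanh(w_h * h + w_x * x + b)
        return h_new, h_new  # carry the hidden state and also emit it
    return cell


cell = make_rnn_cell(w_h=0.5, w_x=1.0, b=0.0)
h_final, hs = scan(cell, 0.0, [1.0, 0.0, -1.0])  # initial hidden state 0.0
```

In JAX the same `cell` could be passed to `jax.lax.scan` (with `jnp.tanh` in place of `math.tanh`) so the loop is compiled once rather than unrolled.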

subho406 commented 1 year ago

+1 for this. I find it extremely useful for implementing rollouts in RL and RNNs.