DiffEqML / torchdyn

A PyTorch library entirely dedicated to neural differential equations, implicit models and related numerical methods
https://torchdyn.org
Apache License 2.0

Parallelizing Neural ODEs with Functorch #87

Closed: pharringtonp19 closed this issue 3 years ago

pharringtonp19 commented 3 years ago

I tend to work with relatively small datasets, and one nice advantage I have found in working with JAX is that I can train an ensemble of networks on a single GPU using Will Whitney's approach.

I see that functorch now makes this possible in PyTorch as well.

Question: would it be feasible to use functorch's implementation to train an ensemble of neural ODEs on a single GPU?
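The ensemble setup asked about here can be sketched roughly as follows. This is a minimal sketch assuming PyTorch 2.0 or later, where functorch's transforms ship as `torch.func`. It does not use torchdyn's API: `VectorField`, `rk4_solve`, and `member_loss` are hypothetical helpers, and a fixed-step RK4 loop stands in for a proper ODE solver. `stack_module_state` stacks the weights of eight independently initialized vector fields, and `vmap` evaluates all eight ODE flows (and, through autograd, their gradients) in one batched call, in the spirit of the JAX ensembling approach mentioned above.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)


class VectorField(nn.Module):
    """Hypothetical autonomous vector field f(z), parameterized by a small MLP."""

    def __init__(self, dim: int = 2, hidden: int = 32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, z):
        return self.net(z)


def rk4_solve(f, z0, t_end=1.0, steps=10):
    """Fixed-step RK4 from t=0 to t=t_end; returns only the final state z(t_end)."""
    h = t_end / steps
    z = z0
    for _ in range(steps):
        k1 = f(z)
        k2 = f(z + 0.5 * h * k1)
        k3 = f(z + 0.5 * h * k2)
        k4 = f(z + h * k3)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
    return z


num_models, dim = 8, 2
models = [VectorField(dim) for _ in range(num_models)]

# Stack every member's weights along a new leading "ensemble" dimension.
# The stacked tensors are fresh leaf tensors that can be optimized directly.
params, buffers = stack_module_state(models)

# Weight-less template module on the meta device; functional_call swaps in real weights.
base = copy.deepcopy(models[0]).to("meta")


def member_loss(p, b, x, y):
    """MSE of one ensemble member's ODE flow z(1) against the targets y."""
    vector_field = lambda z: functional_call(base, (p, b), (z,))
    return ((rk4_solve(vector_field, x) - y) ** 2).mean()


# Map over the ensemble dimension of the weights; every member sees the same batch.
ensemble_loss = vmap(member_loss, in_dims=(0, 0, None, None))

x = torch.randn(64, dim)
y = torch.sin(x)  # toy regression target

opt = torch.optim.Adam(list(params.values()), lr=1e-2)
with torch.no_grad():
    initial_losses = ensemble_loss(params, buffers, x, y)

for step in range(50):
    opt.zero_grad()
    losses = ensemble_loss(params, buffers, x, y)  # shape: (num_models,)
    # Members share no weights, so summing the losses keeps their gradients independent.
    losses.sum().backward()
    opt.step()

with torch.no_grad():
    final_losses = ensemble_loss(params, buffers, x, y)
```

Because each member's weights occupy their own slice of the stacked tensors, a single optimizer step updates all eight members at once without mixing them. Swapping in a different solver or a torchdyn model would require checking that every operation it uses is supported under `vmap`.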

Zymrael commented 3 years ago

It should definitely be possible. We have been considering functorch as a dependency of torchdyn, since vmap and other vectorized operations are needed to implement some variants of more advanced methods, such as FwS MSLs. Perhaps a simple guide on how to do this could be useful down the line.
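To illustrate the kind of vectorization the reply refers to, here is a hypothetical sketch of integrating multiple shooting segments in parallel with `vmap`. It is not torchdyn's MSL implementation: the vector field `f`, the helper `rk4_segment`, and the update rule are assumptions chosen for illustration. The horizon [0, 2] is split into four segments, each segment's start state is a shooting parameter `B[k]`, and one vectorized call integrates all segments at once. The shooting parameters are then refined with the simplest fixed-point update (each segment starts where the previous one ended), rather than the Newton-type updates commonly used in multiple shooting.

```python
import torch
from torch.func import vmap

# Hypothetical time-dependent vector field: dz/dt = A z + sin(t)
A = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])


def f(t, z):
    return z @ A.T + torch.sin(t)


def rk4_segment(z0, t0, t1, steps=25):
    """Integrate dz/dt = f(t, z) from t0 to t1 with fixed-step RK4; returns z(t1)."""
    h = (t1 - t0) / steps
    t, z = t0, z0
    for _ in range(steps):
        k1 = f(t, z)
        k2 = f(t + 0.5 * h, z + 0.5 * h * k1)
        k3 = f(t + 0.5 * h, z + 0.5 * h * k2)
        k4 = f(t + h, z + h * k3)
        z = z + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return z


K = 4  # number of shooting segments
t_grid = torch.linspace(0.0, 2.0, K + 1)
t0s, t1s = t_grid[:-1], t_grid[1:]
z_init = torch.tensor([1.0, 0.0])

# One batched call integrates all K segments, each over its own time interval.
solve_segments = vmap(rk4_segment, in_dims=(0, 0, 0))

# Shooting parameters: a guessed start state per segment (naive guess: z_init everywhere).
B = z_init.repeat(K, 1)

for sweep in range(K):
    ends = solve_segments(B, t0s, t1s)  # (K, 2): end state of every segment
    # Fixed-point update: each segment restarts where the previous segment ended.
    B = torch.cat([z_init[None], ends[:-1]], dim=0)

ends = solve_segments(B, t0s, t1s)
defects = ends[:-1] - B[1:]  # continuity violations between consecutive segments

# Sequential reference: integrate the segments one after another.
z_seq = z_init
for k in range(K):
    z_seq = rk4_segment(z_seq, t0s[k], t1s[k])
```

With K segments, this naive update needs at most K sweeps to match the sequential solution, so by itself it saves no work; the point is only that each sweep integrates every segment in a single batched call.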