pharringtonp19 closed this issue 3 years ago.
It should definitely be possible. We were considering functorch as a dependency of torchdyn, since vmap and other vectorized operations are needed to implement some variants of fancier methods, such as FwS MSLs. Perhaps a simple guide on how to do this could be useful down the line.
I tend to work with relatively small datasets, and one nice advantage I have found in working with JAX is that I can train an ensemble of networks on a single GPU using Will Whitney's approach.
I see that functorch now makes this possible in PyTorch.
Question: Would it be feasible to use functorch's implementation to train an ensemble of neural ODEs on a single GPU?
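A minimal sketch of what such a guide might contain, assuming PyTorch 2.0 or later, where functorch's API is available as `torch.func` (including `stack_module_state`, `functional_call`, and `vmap`). It follows the same idea as the JAX approach: stack the parameters of all ensemble members along a leading axis and `vmap` a single-model forward pass over that axis. The `VectorField` module, the fixed-step `odeint_euler` solver, and the rotation toy target are hypothetical stand-ins chosen for illustration. They are not torchdyn's API, and a real setup would use a proper ODE solver.

```python
# Sketch: train an ensemble of small neural ODEs in one vectorized pass with
# torch.func (functorch's successor in PyTorch >= 2.0). The Euler solver and
# the rotation target are illustrative stand-ins, not torchdyn APIs.
import copy
import math

import torch
from torch import nn
from torch.func import functional_call, stack_module_state, vmap

torch.manual_seed(0)


class VectorField(nn.Module):
    """Hypothetical MLP vector field f(z) for the autonomous ODE dz/dt = f(z)."""

    def __init__(self, dim=2, hidden=32):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, dim))

    def forward(self, z):
        return self.net(z)


def odeint_euler(f, z0, t1=1.0, steps=20):
    """Fixed-step explicit Euler from t=0 to t1; a placeholder for a real solver."""
    dt = t1 / steps
    z = z0
    for _ in range(steps):
        z = z + dt * f(z)
    return z


n_models, dim = 8, 2
models = [VectorField(dim) for _ in range(n_models)]

# Stack every member's parameters along a new leading axis of size n_models.
# The stacked tensors are fresh leaf tensors, so an optimizer can update them directly.
params, buffers = stack_module_state(models)

# A weightless "skeleton" on the meta device; functional_call supplies the weights.
base = copy.deepcopy(models[0]).to("meta")


def predict(p, b, z0):
    """Integrate one ensemble member, given its (unstacked) parameters and buffers."""
    f = lambda z: functional_call(base, (p, b), (z,))
    return odeint_euler(f, z0)


# Map over the leading (model) axis of params/buffers; all members share the inputs.
ensemble = vmap(predict, in_dims=(0, 0, None))

# Toy data: the time-1 flow of dz/dt = [[0, -1], [1, 0]] z is a rotation by 1 rad.
z0 = torch.randn(256, dim)
c, s = math.cos(1.0), math.sin(1.0)
target = z0 @ torch.tensor([[c, -s], [s, c]]).T

# Sanity check: before training, the vmapped ensemble matches a Python loop over models.
with torch.no_grad():
    initial_match = torch.allclose(
        ensemble(params, buffers, z0),
        torch.stack([odeint_euler(m, z0) for m in models]),
        atol=1e-5,
    )

opt = torch.optim.Adam(list(params.values()), lr=1e-2)
for step in range(200):
    pred = ensemble(params, buffers, z0)                  # (n_models, 256, dim)
    per_model = ((pred - target) ** 2).mean(dim=(1, 2))  # (n_models,)
    if step == 0:
        first_loss = per_model.detach().clone()
    opt.zero_grad()
    per_model.sum().backward()
    opt.step()

print("initial per-model loss:", first_loss)
print("final per-model loss:  ", per_model.detach())
```

Summing the per-model losses before `backward()` is safe here because each member owns a disjoint slice of the stacked parameters, so each slice receives only its own model's gradient, and Adam's elementwise updates keep the members independent.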