rtqichen / torchdiffeq

Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.
MIT License

Set default dtype #256

Closed: varunagrawal closed this 2 weeks ago

varunagrawal commented 2 weeks ago

This PR sets the default dtype for MPS-based accelerators (i.e., Apple silicon Macs such as the M1), so the library can be used on them.

Without this, users either get an unsupported-dtype error or have to set the dtype manually. I figured torchdiffeq could handle the latter directly.
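A minimal sketch of the kind of device-dependent default described above, assuming the choice is keyed on the device-type string PyTorch reports (`tensor.device.type`, which is `"mps"` on Apple silicon). The helper `default_dtype_for` and the set `FLOAT64_UNSUPPORTED` are hypothetical names for illustration, not torchdiffeq's API or the code in this PR; dtypes are returned as plain strings so the sketch runs without PyTorch installed.

```python
# Hypothetical helper (not part of torchdiffeq): choose a floating-point dtype
# name from a device-type string such as "cpu", "cuda", or "mps".

# Assumption: device types whose backend lacks float64 support.
# The MPS backend does not support float64, so it falls back to float32.
FLOAT64_UNSUPPORTED = {"mps"}


def default_dtype_for(device_type: str) -> str:
    """Return "float32" on backends without float64, else "float64"."""
    if device_type in FLOAT64_UNSUPPORTED:
        return "float32"
    return "float64"


if __name__ == "__main__":
    for dev in ("cpu", "cuda", "mps"):
        print(dev, default_dtype_for(dev))
```

Keeping the unsupported backends in a set means other float64-less devices could be added in one place rather than through extra branches.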

varunagrawal commented 2 weeks ago

I am closing this because the tests fail due to the difference in floating-point precision. :( Hopefully PyTorch adds float64 support for MPS upstream soon.
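To illustrate why falling back to float32 can break tests written with float64 accuracy in mind, here is a small stdlib-only sketch. It is not taken from torchdiffeq's test suite: it rounds a double to float32 via `struct` and compares the round-off against two tolerances that are hypothetical values chosen for illustration.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (an IEEE 754 double) to the nearest float32 and back."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


reference = 0.1                      # value as computed in float64
single = to_float32(reference)       # same value stored in float32
abs_err = abs(single - reference)    # round-off introduced by float32

tight_tol = 1e-10  # hypothetical tolerance suited to float64 results
loose_tol = 1e-6   # hypothetical tolerance suited to float32 results

print(f"float32 round-off for 0.1: {abs_err:.2e}")
print("within tight tolerance:", abs_err <= tight_tol)
print("within loose tolerance:", abs_err <= loose_tol)
```

For 0.1 the float32 round-off is about 1.5e-9. That fails the tight tolerance but passes the loose one, which is the kind of mismatch that appears when float64-calibrated checks run on float32 data.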