desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

Add MPS device #61

Open jank324 opened 1 year ago

jank324 commented 1 year ago

Currently, Cheetah only considers "cpu" and "cuda" as possible devices. It would be nice to also make use of "mps" on Apple Silicon Macs. First, we need to test whether all PyTorch functions used by Cheetah are MPS-compatible. SB3, for example, doesn't use MPS yet because some of the functions it relies on are incompatible.
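Extending the device choice could be as simple as a small selection helper. This is a hypothetical sketch (the `pick_device` name and availability flags are assumptions, not Cheetah's actual API); in real code the flags would come from `torch.cuda.is_available()` and `torch.backends.mps.is_available()`:

```python
def pick_device(cuda_available: bool, mps_available: bool) -> str:
    # Hypothetical helper: choose the best available backend.
    # Preference order: CUDA, then MPS (Apple Silicon), then CPU.
    # In practice the two flags would be queried via
    # torch.cuda.is_available() and torch.backends.mps.is_available().
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"
```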

jank324 commented 9 months ago

Just tried tracking a ParameterBeam through a Quadrupole using torch=2.1.1. I got

The operator 'aten::complex.out' is not currently implemented for the MPS device.

So PyTorch is not quite ready for Cheetah to support this yet.

jank324 commented 6 months ago

I just ran the optimize_speed.ipynb example notebook on #116 with MPS and I still get

NotImplementedError: The operator 'aten::complex.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

So this operator is still not implemented.
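For anyone who wants to experiment anyway, the error message suggests the CPU fallback. A minimal sketch of how that could be enabled from Python (the variable is read when PyTorch initialises, so it has to be set before `torch` is imported; expect the fallback op to run slower than native MPS):

```python
import os

# Enable the CPU fallback for MPS ops that are not yet implemented.
# Must be set before `import torch`, otherwise it has no effect.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# import torch  # import only after the variable is set
```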

jank324 commented 6 months ago

The operator is now implemented in the nightly build of PyTorch 2.4.0. With that, using an MPS device in Cheetah does work. However, it seems about 10x slower than running Cheetah on CPU with the stable build of PyTorch 2.2.1.

Testing was done using the vectorised branch (#116) and the optimize_speed.ipynb example notebook.
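A timing comparison like the one above could be sketched with a small helper (hypothetical, not the notebook's actual code). Note that for fair GPU numbers the device should be synchronised before reading the clock, e.g. via `torch.mps.synchronize()`, since MPS ops are dispatched asynchronously:

```python
import time

def time_tracking(track_fn, repeats: int = 10) -> float:
    # Hypothetical timing helper: run the tracking function `repeats`
    # times and return the best (minimum) wall-clock time in seconds.
    # The minimum is less noisy than the mean for micro-benchmarks.
    # When timing MPS/CUDA, call the backend's synchronize() inside
    # track_fn so queued kernels are included in the measurement.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        track_fn()
        best = min(best, time.perf_counter() - start)
    return best
```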