jank324 opened this issue 1 year ago
Just tried tracking a `ParameterBeam` through a `Quadrupole` using PyTorch 2.1.1. I got

> The operator 'aten::complex.out' is not currently implemented for the MPS device.

So PyTorch is not quite ready for Cheetah to support this yet.
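A minimal way to check whether the failing operator is available is to call `torch.complex` directly on MPS tensors. This is a sketch that assumes `torch.complex` is what dispatches to `aten::complex.out`. The thread does not say where in Cheetah the call happens, and the helper name `complex_works_on` is hypothetical.

```python
import torch


def complex_works_on(device: str) -> bool:
    """Return True if torch.complex runs on `device`, False if the op is missing."""
    # float32 inputs, since MPS does not support float64
    real = torch.ones(3, device=device)
    imag = torch.zeros(3, device=device)
    try:
        torch.complex(real, imag)  # assumed to dispatch to aten::complex(.out)
    except (NotImplementedError, RuntimeError, TypeError):
        return False
    return True


if torch.backends.mps.is_available():
    print("torch.complex on MPS:", complex_works_on("mps"))
print("torch.complex on CPU:", complex_works_on("cpu"))
```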
I just ran the `optimize_speed.ipynb` example notebook on #116 with MPS and I still get

> NotImplementedError: The operator 'aten::complex.out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

So this operator is still not implemented.
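The error message suggests `PYTORCH_ENABLE_MPS_FALLBACK=1` as a temporary fix. Below is a minimal sketch of setting it from Python. It assumes the variable has to be present before PyTorch loads its MPS backend, so it is set before the first `import torch`; exporting it in the shell before launching Jupyter achieves the same. As the warning says, ops that fall back run on the CPU, so this helps correctness, not speed.

```python
import os

# Enable the CPU fallback for unimplemented MPS ops. Set it before importing
# torch, on the assumption that the backend reads it when the library loads.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

import torch  # noqa: E402  (deliberately imported after setting the variable)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)
print("MPS fallback:", os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])
```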
The operator is implemented in the nightly build of PyTorch 2.4.0, and with that build, using an MPS device in Cheetah does work. However, it seems to be about 10x slower than running Cheetah on the CPU with the stable build of PyTorch 2.2.1. Testing was done using the vectorised branch (#116) and the `optimize_speed.ipynb` example notebook.
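To reproduce a CPU vs. MPS comparison outside the notebook, a small timing harness helps. The sketch below uses a hypothetical stand-in workload (a 7x7 matrix applied to a batch of 7-component vectors, loosely mimicking linear transfer-map tracking), not Cheetah's actual tracking code, and the name `time_batched_matmul` is made up. The synchronisation calls matter: MPS and CUDA run work asynchronously, so without them the timer would mostly measure kernel launches.

```python
import time

import torch


def _sync(device: torch.device) -> None:
    # Wait for queued GPU work to finish before reading the clock.
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()


def time_batched_matmul(device_name: str, n_particles: int = 100_000, repeats: int = 20) -> float:
    """Mean seconds per call of a hypothetical stand-in tracking workload."""
    device = torch.device(device_name)
    # float32 throughout, because MPS does not support float64
    matrix = torch.rand(7, 7, device=device)
    particles = torch.rand(n_particles, 7, device=device)
    _ = particles @ matrix.T  # warm-up, excluded from the timing
    _sync(device)
    start = time.perf_counter()
    for _ in range(repeats):
        _ = particles @ matrix.T
    _sync(device)
    return (time.perf_counter() - start) / repeats


results = {"cpu": time_batched_matmul("cpu")}
if torch.backends.mps.is_available():
    results["mps"] = time_batched_matmul("mps")
print(results)
```

Small per-call workloads like this tend to favour the CPU, since launch and synchronisation overhead dominate on the GPU side; larger batches shift the balance.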
Currently, Cheetah only considers `"cpu"` and `"cuda"` as possible devices. It would be nice to also make use of `"mps"` on Apple Silicon Macs. First, we need to test that all PyTorch functions used by Cheetah are already MPS-compatible. SB3 (Stable-Baselines3), for example, doesn't use MPS yet because some of the functions it relies on are incompatible.
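A sketch of both steps: choosing a device that includes `"mps"`, and probing whether a set of ops runs on it. The helper names `select_device` and `unsupported_ops` and the list of probed ops are hypothetical. A real check would have to cover every op Cheetah actually uses, for example by running the test suite on an `"mps"` device.

```python
import torch


def select_device(prefer_mps: bool = True) -> torch.device:
    """Pick CUDA if available, then MPS (Apple Silicon), else fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if prefer_mps and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def unsupported_ops(device: torch.device) -> list[str]:
    """Run a few sample ops on `device` and return the names of those that fail."""
    x = torch.rand(4, 4, device=device)  # float32, MPS has no float64
    checks = {
        "matmul": lambda: x @ x,
        "sin/cos": lambda: torch.sin(x) + torch.cos(x),
        "complex": lambda: torch.complex(x, x),
        # Adding 4*I makes the matrix diagonally dominant, hence invertible.
        "linalg.inv": lambda: torch.linalg.inv(x + 4 * torch.eye(4, device=device)),
    }
    failed = []
    for name, op in checks.items():
        try:
            op()
        except (NotImplementedError, RuntimeError, TypeError):
            failed.append(name)
    return failed


device = select_device()
print(device, unsupported_ops(device))
```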