Since MPS doesn't support Conv3D, one is forced to use the CPU when working locally. However, get_available_device() still returns mps, even when the trainer was instantiated with L.Trainer(accelerator='cpu'). As a result, the input can end up on a different device, or in a different dtype, than the model weights, which produces errors such as
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
or
RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
Fix: Use the accelerator chosen by Lightning, since this is where the data is moved to, and make sure the data format (dtype) matches the model. In general, we should not move the data ourselves; Lightning should do this.
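One possible shape for this fix, as a minimal sketch rather than the actual YuccaNet change: read the device and dtype from the module's own parameters instead of probing the hardware. The helper name predict_on_module_device is hypothetical. Inside a LightningModule, self.device and self.dtype expose the same information, so the parameter lookup below is only there to keep the example free of a Lightning dependency.

```python
import torch
from torch import nn


def predict_on_module_device(module: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Run `module` on `x` after matching x to the module's device and dtype.

    Hypothetical helper: it assumes the module has at least one parameter
    and that all parameters share one device and dtype.
    """
    param = next(module.parameters())
    # Follow the module (placed by the trainer), not the hardware probe.
    x = x.to(device=param.device, dtype=param.dtype)
    with torch.no_grad():
        return module(x)


if __name__ == "__main__":
    # A Conv3d layer kept on the CPU, as with accelerator='cpu'.
    model = nn.Conv3d(1, 2, kernel_size=3, padding=1)
    volume = torch.randn(1, 1, 4, 4, 4)
    out = predict_on_module_device(model, volume)
    print(out.shape, out.device, out.dtype)
```

With this pattern the input always lands wherever the weights already are, so a trainer created with accelerator='cpu' keeps everything on the CPU even on a machine where MPS is available.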
Affected code: a line in YuccaNet breaks CPU support on MPS-accelerated devices.