Sllambias / yucca

Apache License 2.0

Not possible to do prediction locally #184

Closed: asbjrnmunk closed this issue 2 months ago

asbjrnmunk commented 2 months ago

The following line in YuccaNet breaks CPU support on MPS-accelerated devices:

...
def predict(self, mode, data, patch_size, overlap, sliding_window_prediction=True, mirror=False):
    data = data.to(torch.device(get_available_device()))
...

Since MPS doesn't support Conv3D, one is forced to use the CPU when working locally. However, get_available_device() still returns mps, so the data is moved to the MPS device even though the trainer was instantiated with L.Trainer(accelerator='cpu'). As a result, one gets errors such as

RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

or

RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
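The device mismatch in the second error needs an MPS machine to reproduce, but the underlying failure mode (input and parameters disagreeing) can be shown on any CPU. The sketch below is a minimal stand-in, not YuccaNet itself: a single Conv3d layer replaces the network, and float64 stands in for bfloat16 so the example runs regardless of CPU bf16 support.

```python
import torch
import torch.nn as nn

# Stand-in for a 3D network: one Conv3d layer with float64 parameters.
# float64 plays the role of bfloat16 from the issue (assumption for portability).
model = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3).double()

# The input keeps PyTorch's default float32 dtype.
data = torch.randn(1, 1, 8, 8, 8)

try:
    model(data)
    error = None
except RuntimeError as err:
    # PyTorch refuses to convolve an input whose dtype differs from the weights'.
    error = err

print(type(error).__name__)  # RuntimeError
```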

Fix: Find a way to use the accelerator chosen by Lightning, since that is where Lightning moves the data, and ensure that the data format (dtype) matches the model's. In general, we should not move the data ourselves; Lightning should handle this.
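Where a method such as predict still has to move tensors itself, one option is to read the target device and dtype off the model rather than probing the machine. A minimal sketch, assuming a plain nn.Module: `to_module_device_and_dtype` is a hypothetical helper, not part of yucca. Inside a LightningModule, the built-in `self.device` and `self.dtype` properties expose the same information, so `data.to(device=self.device, dtype=self.dtype)` would be the equivalent one-liner.

```python
import torch
import torch.nn as nn


def to_module_device_and_dtype(data: torch.Tensor, module: nn.Module) -> torch.Tensor:
    """Hypothetical helper: place `data` on the device/dtype of `module`'s parameters.

    This follows whatever accelerator the trainer put the model on, instead of
    asking the machine which devices exist.
    """
    param = next(module.parameters())
    # Only cast floating-point inputs; integer tensors (e.g. labels) keep their dtype.
    dtype = param.dtype if data.is_floating_point() else data.dtype
    return data.to(device=param.device, dtype=dtype)


# Same stand-in model as before: float64 parameters on the CPU.
model = nn.Conv3d(in_channels=1, out_channels=2, kernel_size=3).double()
data = torch.randn(1, 1, 8, 8, 8)  # float32 input

out = model(to_module_device_and_dtype(data, model))
print(out.shape, out.dtype)  # torch.Size([1, 2, 6, 6, 6]) torch.float64
```

Because the helper derives everything from the model, it stays correct whether the trainer was created with accelerator='cpu', 'mps', or 'gpu'. Letting Lightning transfer the batch, as the fix suggests, remains the simpler option when the code path allows it.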

asbjrnmunk commented 2 months ago

Fixed by #183