f90 / Wave-U-Net-Pytorch

Improved Wave-U-Net implemented in Pytorch
MIT License
293 stars 62 forks

Support for Apple Metal (MPS) backend #16

Open Archie3d opened 1 year ago

Archie3d commented 1 year ago

Please add support for the MPS backend, as is already done for CUDA:

if torch.backends.mps.is_available():
    mps = torch.device("mps")
    model = model_utils.DataParallel(model)
    model.to(mps)

# ... and so on...
x = x.to(mps)
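
For anyone who wants to try this locally, the request above could be sketched roughly as follows. This is only a minimal illustration, not the repository's code: `pick_device_name` and `pick_device` are hypothetical helpers, and the small `Conv1d` is a stand-in for the actual Wave-U-Net model. It prefers CUDA, then MPS, then CPU. The `DataParallel` wrapper is left out on the assumption that it is aimed at multi-GPU CUDA setups and is not needed for a single MPS device.

```python
import torch


def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: choose a backend name, preferring CUDA, then MPS, then CPU."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def pick_device() -> torch.device:
    """Hypothetical helper: query PyTorch for the available backends."""
    # torch.backends.mps is missing on older PyTorch builds, so look it up defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_device_name(torch.cuda.is_available(), mps_ok))


device = pick_device()

# Stand-in for the real model (assumption: any nn.Module is moved the same way).
model = torch.nn.Conv1d(1, 1, kernel_size=3, padding=1).to(device)

# Inputs must live on the same device as the model.
x = torch.randn(1, 1, 16).to(device)
y = model(x)
```

Using a single `device` variable instead of separate `cuda` and `mps` branches means every later `.to(device)` call works unchanged on whichever backend was selected.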

reinerterig commented 10 months ago

Has anyone looked into implementing MPS for this yet?