Suggestions on more rigorous testing are welcome!
Fixes junyanz/pytorch-CycleGAN-and-pix2pix#1441
Change list
- Add `--device_type` and `--device_ids` arguments, which allow the torch backend and device ordinals to be specified
- Retain the `--gpu_ids` argument with some special logic (removing it would be cleaner, but not backwards compatible)

Testing
```shell
python train.py --dataroot ./datasets/maps --name maps --model pix2pix --direction AtoB --device_type mps
```
The results seem reasonable.

NB: On my specific setup, loading a model trained with MPS fails with `RuntimeError: don't know how to restore data location of torch.storage._UntypedStorage (tagged with mps:0)`, but this appears to be a known, intermittent issue.
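For reviewers, here is a rough sketch of how the `--device_type` / `--device_ids` options and the legacy `--gpu_ids` argument from the change list could be resolved into device strings. This is not the code in this PR. The helper names `resolve_devices` and `build_parser`, the defaults, the `choices` list, and the rule that an explicitly passed `--gpu_ids` overrides the new flags (with `-1` meaning CPU, as in the existing option) are all assumptions made for illustration.

```python
import argparse


def resolve_devices(device_type, device_ids, gpu_ids=None):
    """Hypothetical helper: turn CLI options into torch-style device strings.

    Assumption: if the legacy --gpu_ids value was passed explicitly, it wins
    and implies CUDA; negative ids (e.g. "-1") mean "run on CPU".
    """
    if gpu_ids is not None:
        ids = [int(i) for i in gpu_ids.split(",") if i.strip() != ""]
        ids = [i for i in ids if i >= 0]  # drop the "-1" CPU sentinel
        if not ids:
            return ["cpu"]
        return [f"cuda:{i}" for i in ids]

    if device_type == "cpu":
        return ["cpu"]  # CPU has no meaningful ordinals

    ids = [int(i) for i in device_ids.split(",") if i.strip() != ""]
    # Fall back to ordinal 0 if no ids were given.
    return [f"{device_type}:{i}" for i in ids] or [f"{device_type}:0"]


def build_parser():
    """Hypothetical parser; defaults and choices are illustrative only."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--device_type", default="cuda", choices=["cpu", "cuda", "mps"])
    parser.add_argument("--device_ids", default="0")
    parser.add_argument("--gpu_ids", default=None)  # None = not passed explicitly
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(["--device_type", "mps"])
    print(resolve_devices(args.device_type, args.device_ids, args.gpu_ids))
```

Each resulting string (e.g. `mps:0`, `cuda:1`) is in the form that `torch.device` accepts. Defaulting `--gpu_ids` to `None` is one way to tell an explicit legacy value apart from "not given".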