junyanz / pytorch-CycleGAN-and-pix2pix

Image-to-Image Translation in PyTorch

Add support for other PyTorch device types, including MPS #1445

Open adamobeng opened 2 years ago

adamobeng commented 2 years ago

Fixes junyanz/pytorch-CycleGAN-and-pix2pix#1441

Change list

  1. Add command-line arguments --device_type and --device_ids, which let the user specify the torch backend and the device ordinals
  2. Make GPU/CUDA-specific code device-agnostic, in particular by passing around a list of torch devices rather than a list of GPU ids
  3. Keep supporting the --gpu_ids argument through some special-case logic (removing it would be cleaner, but would break backward compatibility)
  4. Add tests for the argument parsing
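The argument handling in items 1 and 3 could look roughly like the following minimal sketch. It is not the code from this PR. The names build_parser and resolve_devices, the rule that --device_type and --gpu_ids may not be combined, and the choice to return device strings (rather than torch.device objects) are all assumptions made for illustration. The legacy --gpu_ids behaviour (default 0, -1 meaning CPU) follows the repository's existing convention.

```python
import argparse


def build_parser():
    # Hypothetical parser exposing the flags named in the change list.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device_type", type=str, default=None,
                        help="torch backend, e.g. cuda, mps or cpu")
    parser.add_argument("--device_ids", type=str, default=None,
                        help="comma-separated device ordinals, e.g. 0,1")
    # Legacy flag; default None (assumption) so we can tell whether it was set.
    parser.add_argument("--gpu_ids", type=str, default=None,
                        help="legacy: comma-separated CUDA ids, -1 for CPU")
    return parser


def _parse_ids(text):
    # "0,2" -> [0, 2]; empty entries are ignored.
    return [int(part) for part in text.split(",") if part.strip()]


def resolve_devices(opt):
    """Return device strings that torch.device() accepts, e.g. 'cuda:0'."""
    if opt.device_type is not None:
        # Assumed rule: the new and legacy flags are mutually exclusive.
        if opt.gpu_ids is not None:
            raise ValueError("use --device_type/--device_ids or --gpu_ids, not both")
        if opt.device_ids is None:
            # No ordinals given: use the backend's default device.
            return [opt.device_type]
        return [f"{opt.device_type}:{i}" for i in _parse_ids(opt.device_ids)]

    # Legacy path: behave like the old --gpu_ids (default "0", -1 for CPU).
    gpu_ids = opt.gpu_ids if opt.gpu_ids is not None else "0"
    ids = [i for i in _parse_ids(gpu_ids) if i >= 0]
    if not ids:
        return ["cpu"]
    return [f"cuda:{i}" for i in ids]


if __name__ == "__main__":
    parser = build_parser()
    print(resolve_devices(parser.parse_args(["--device_type", "mps"])))   # ['mps']
    print(resolve_devices(parser.parse_args(["--gpu_ids", "0,1"])))      # ['cuda:0', 'cuda:1']
    print(resolve_devices(parser.parse_args(["--gpu_ids", "-1"])))       # ['cpu']
```

Each returned string can be passed to torch.device(...), so downstream code can move models and tensors with .to(devices[0]) instead of checking torch.cuda.is_available() or indexing GPU ids directly, which is the device-agnostic style item 2 describes.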

Testing

NB: On my setup, loading a model trained with MPS fails with "RuntimeError: don't know how to restore data location of torch.storage._UntypedStorage (tagged with mps:0)". This appears to be a known, intermittent issue.