Open issue · opened by agdiaz 2 months ago
For parrot-optimize, the current device-configuration lines are:
# Device configuration
if forceCPU:
    device = 'cpu'
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Suggested lines:
parser.add_argument('--mps', action='store_true',
                    help="Use Apple's MPS (Metal) backend when it is available")
# ...
# Device configuration
# is_available() also checks that this machine and macOS version can run MPS;
# is_built() only reports whether this PyTorch build includes MPS support.
has_mps = torch.backends.mps.is_available()
if forceCPU:
    device = 'cpu'
elif has_mps and args.mps:
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
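The selection order in the suggested lines (CPU override first, then MPS if requested and available, then CUDA, otherwise CPU) could also be factored into a small helper that does not need torch, which makes the priority easy to unit-test. This is a minimal sketch, not parrot's actual code: the names `select_device_name`, `build_parser` and `select_torch_device`, and the `--force-cpu` flag, are hypothetical, and `select_torch_device` assumes PyTorch 1.12 or later (where `torch.backends.mps` exists).

```python
import argparse


def select_device_name(force_cpu: bool, use_mps: bool,
                       mps_available: bool, cuda_available: bool) -> str:
    """Return the device name following the suggested priority order.

    An explicit CPU override wins; MPS is used only if requested AND
    available at runtime; otherwise CUDA if available, else CPU.
    """
    if force_cpu:
        return "cpu"
    if use_mps and mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical flag names for illustration; align them with the
    # script's existing argparse setup.
    parser = argparse.ArgumentParser(description="Device selection sketch")
    parser.add_argument("--force-cpu", action="store_true",
                        help="Run on the CPU even if a GPU is available")
    parser.add_argument("--mps", action="store_true",
                        help="Use Apple's MPS (Metal) backend if it is available")
    return parser


def select_torch_device(args: argparse.Namespace):
    """Hypothetical glue: feed PyTorch's runtime checks into the helper."""
    import torch  # imported lazily so the pure helper above works without torch
    name = select_device_name(args.force_cpu, args.mps,
                              torch.backends.mps.is_available(),
                              torch.cuda.is_available())
    return torch.device(name)
```

Keeping the decision in a pure function means the fallback behaviour (for example, `--mps` on a machine without MPS quietly falling back to CUDA or CPU) can be checked without a Mac or a GPU.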
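Dear code maintainers,

Would it be possible to include mps as a device, so that parrot can get GPU acceleration on MacBook laptops, similar to what it already gets from a CUDA device? The device-configuration snippet quoted at the top of this issue is where mps could be included.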
The corresponding lines are available at https://github.com/idptools/parrot/blob/6e09567afdc3a59d0c03f0802cf4d2fe9c973feb/scripts/parrot-train#L135C1-L142C5
The suggested lines above could help integrate MPS.
Thanks in advance!