idptools / parrot

Python package for protein sequence-based bidirectional recurrent neural networks. Generalizable to a variety of protein bioinformatics applications.
MIT License

Include MPS as accelerator for Apple ARM machines #14

Open · agdiaz opened this issue 2 months ago

agdiaz commented 2 months ago

Dear code maintainers, would it be possible to add MPS as a device option, so that parrot can use GPU acceleration on MacBook laptops the way it already does with a CUDA device?

Here is the snippet where MPS could be added:

# Device configuration
if forceCPU:
    device = 'cpu'
elif gpu_id:
    device = torch.device(f"cuda:{gpu_id}" if torch.cuda.is_available() else 'cpu')
    print(f"You've specified to run this network on cuda:{gpu_id}. Running on {device=}")
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Available at https://github.com/idptools/parrot/blob/6e09567afdc3a59d0c03f0802cf4d2fe9c973feb/scripts/parrot-train#L135C1-L142C5

This code could help to integrate MPS:

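import torch
# is_available() checks that MPS is usable at runtime; is_built() only checks the build
has_mps = torch.backends.mps.is_available()
device = "mps" if has_mps else "cuda" if torch.cuda.is_available() else "cpu"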
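A minimal sketch of the same detection order (MPS, then CUDA, then CPU), assuming PyTorch 1.12 or later, where `torch.backends.mps` exists. The helper names `pick_device_name` and `detect_device` are hypothetical and not part of parrot. The decision is kept in a pure function so it can be checked on machines without Apple or NVIDIA hardware. It uses `is_available()` rather than `is_built()`, because a PyTorch build with MPS compiled in can still lack a usable MPS device (for example on an Intel Mac or an older macOS).

```python
def pick_device_name(force_cpu: bool, mps_available: bool, cuda_available: bool) -> str:
    """Pure decision logic: MPS first, then CUDA, then CPU (hypothetical helper)."""
    if force_cpu:
        return "cpu"
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


def detect_device(force_cpu: bool = False):
    """Query PyTorch at runtime and return a torch.device (hypothetical helper)."""
    import torch  # imported here so pick_device_name works without torch installed

    # torch.backends.mps only exists in PyTorch >= 1.12; treat its absence as "no MPS".
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return torch.device(pick_device_name(force_cpu, mps_ok, torch.cuda.is_available()))
```

On an Apple silicon Mac with a recent PyTorch, `detect_device()` would be expected to return an MPS device; on other machines it falls through to CUDA or CPU.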

Thanks in advance

agdiaz commented 2 months ago

For parrot-optimize, the corresponding lines are:

# Device configuration
if forceCPU:
    device = 'cpu'
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

https://github.com/idptools/parrot/blob/6e09567afdc3a59d0c03f0802cf4d2fe9c973feb/scripts/parrot-train#L135C1-L142C5

Suggested lines:

parser.add_argument('--mps', action='store_true',
                    help="Flag which, if provided, runs the network on Apple's MPS backend when it is available")

# ...
# Device configuration
# is_available() checks for a usable MPS device, not just MPS support in the build
has_mps = torch.backends.mps.is_available()
if forceCPU:
    device = 'cpu'
elif has_mps and args.mps:
    device = torch.device('mps')
else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
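One possible refinement of the opt-in flag, sketched under assumptions: the suggested chain silently falls back to CUDA or CPU when `--mps` is passed on a machine where MPS cannot be used, so this version emits a warning in that case. The `--force-cpu` flag name, the `build_parser` and `resolve_device_name` helpers, and passing backend availability in as booleans are illustrative assumptions, not parrot's actual code.

```python
import argparse
import warnings


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    # Assumed flag name; parrot's real CPU flag may differ.
    parser.add_argument("--force-cpu", action="store_true",
                        help="Run on the CPU even if a GPU is available")
    parser.add_argument("--mps", action="store_true",
                        help="Run on Apple's MPS backend if it is available")
    return parser


def resolve_device_name(args: argparse.Namespace, mps_available: bool,
                        cuda_available: bool) -> str:
    """Mirror the suggested if/elif/else chain, but warn when --mps cannot be honoured."""
    if args.force_cpu:
        return "cpu"
    if args.mps:
        if mps_available:
            return "mps"
        warnings.warn("--mps was given but MPS is not available; falling back")
    return "cuda" if cuda_available else "cpu"
```

In the script, the booleans would come from `torch.backends.mps.is_available()` and `torch.cuda.is_available()`, and the returned string would be wrapped in `torch.device(...)`. Keeping the torch calls outside the helper lets the flag handling be tested on any machine.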