pytorch / examples

A set of examples around pytorch in Vision, Text, Reinforcement Learning, etc.
https://pytorch.org/examples
BSD 3-Clause "New" or "Revised" License

Add `--device` argument to run examples on a specific device #1288

Status: Open. shink opened this pull request 2 months ago.

shink commented 2 months ago

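Currently, the examples can only run on an accelerator through `--cuda` or `--mps`; otherwise they fall back to CPU. This change adds a `--device` argument so that the examples can run on a specific device, e.g., `--device cpu`. This will benefit all device manufacturers, since any device type can then be selected without adding a dedicated flag for it.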
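A minimal sketch of how such a flag might be declared with `argparse`, assuming an example script that already has `--cuda` and `--mps` switches. `build_parser` is a hypothetical helper for illustration, not code from this PR. The `cpu` default is what keeps runs without any flag behaving as before.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags discussed in this PR."""
    parser = argparse.ArgumentParser(description="PyTorch example")
    # Existing accelerator switches (names taken from the PR description).
    parser.add_argument("--cuda", action="store_true", help="run on CUDA")
    parser.add_argument("--mps", action="store_true", help="run on Apple MPS")
    # Proposed flag: a device string to hand to torch.device, e.g. "cpu".
    # Defaulting to "cpu" keeps the old behavior when no flag is passed.
    parser.add_argument("--device", type=str, default="cpu",
                        help='device to run on, e.g. "cpu" (default: cpu)')
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args()
    print(args.cuda, args.mps, args.device)
```

Usage would then look like `python main.py --device cpu`. Note that `argparse` only passes the string through; whether the device type actually exists is only checked later, when the example calls `torch.device` and moves tensors to it.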

cc: @Yikun @FFFrog @hipudding @jgong5 @EikanWang

netlify[bot] commented 2 months ago

Deploy Preview for pytorch-examples-preview canceled.

Latest commit: 2846161e5b0eee8d55b7e4a63d30380e802fb12b
Latest deploy log: https://app.netlify.com/sites/pytorch-examples-preview/deploys/66f11872584dd60008bf36a1

shink commented 2 months ago

Could someone please review this change? If it makes sense to you, I will proceed.

msaroufim commented 2 months ago

This change would be easier to merge if we started testing M1 in CI.

Do you have any experience doing this kind of stuff?

Basically, you can copy https://github.com/pytorch/examples/blob/main/.github/workflows/main_python.yml, change this line (https://github.com/pytorch/examples/blob/main/.github/workflows/main_python.yml#L16) to macos-latest so that it supports M1, and then make sure you have a test file where you always pass in the M1 device.
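A rough sketch of such a test file, under stated assumptions: the helper names `ci_device`, `example_command`, and `run_example`, the `--epochs` flag, and the example path are all hypothetical, and the script assumes the `--device` flag proposed in this PR. It relies on `macos-latest` runners being Apple Silicon machines at the time of this thread, where PyTorch exposes the M1 GPU as the `mps` device.

```python
import platform
import subprocess
import sys
from typing import List


def ci_device() -> str:
    """Hypothetical helper: the device a CI job should force.

    On an Apple Silicon macOS runner this returns "mps" (the PyTorch
    backend for M1 GPUs); on any other machine it returns "cpu".
    """
    if platform.system() == "Darwin" and platform.machine() == "arm64":
        return "mps"
    return "cpu"


def example_command(script: str, device: str) -> List[str]:
    """Build the command line for a short smoke run of one example.

    Assumes the example accepts the proposed --device flag and an
    --epochs flag (both hypothetical for any given example).
    """
    return [sys.executable, script, "--device", device, "--epochs", "1"]


def run_example(script: str) -> None:
    """Run one example on the CI device; raises if the example fails."""
    # check=True raises CalledProcessError on a non-zero exit code,
    # which makes the CI step fail.
    subprocess.run(example_command(script, ci_device()), check=True)


if __name__ == "__main__":
    # Print instead of running so the sketch is safe to execute anywhere;
    # a real test would call run_example on each example script.
    print(example_command("path/to/example/main.py", ci_device()))
```

Keeping the device choice in one function means the macOS job always exercises `mps`, while the same file still runs on Linux runners with `cpu`.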

shink commented 2 months ago

@msaroufim Thanks for your review! The following code snippet is the core of this change. `args.device` defaults to `cpu`, so the change is backward compatible: with no flags, the examples still run on CPU as before. As you can see, `--device` has lower priority than `--cuda` and `--mps`.

I tested this change on my out-of-tree backend device, and the result was OK.

```diff
 if args.cuda:
     device = torch.device("cuda")
 elif args.mps:
     device = torch.device("mps")
 else:
-    device = torch.device("cpu")
+    device = torch.device(args.device)
```
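One way to make this precedence easy to check in CI, sketched under the assumption that the selection logic is pulled out of the example into a plain function. `resolve_device_name` is a hypothetical name, not part of the PR: it returns the device string, and the example would then call `torch.device(...)` on it. Because it never imports `torch`, the precedence can be tested on any machine.

```python
from types import SimpleNamespace


def resolve_device_name(args) -> str:
    """Hypothetical helper: return the device string for torch.device().

    Same precedence as the diff: --cuda first, then --mps, then --device.
    """
    if args.cuda:
        return "cuda"
    if args.mps:
        return "mps"
    return args.device


if __name__ == "__main__":
    # Stand-in for the argparse namespace; in an example script this
    # would be followed by: device = torch.device(resolve_device_name(args))
    args = SimpleNamespace(cuda=False, mps=True, device="cpu")
    print(resolve_device_name(args))  # mps
```

With this split, passing both `--mps` and `--device cpu` still selects `mps`, matching the lower priority of `--device` described in the comment.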

> This change would be easier to merge if we started testing M1 in CI

Ah, yes! So should I test this change on MPS and add a workflow in this PR?