The original DARTS code (https://github.com/quark0/darts) has .cuda() hard-coded everywhere. During development and debugging we don't really need a GPU. For example, when adding assertions on the input shape (#7), it is much more convenient to run on a local laptop CPU.
A better way is to define
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
and then use .to(device) instead of .cuda().
However, .cuda() appears in too many places, so I am not sure whether this refactor is worthwhile.
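A minimal sketch of what the refactor could look like. It uses a hypothetical TinyNet and a hypothetical drop_path helper, not the actual DARTS modules. Models and batches move with .to(device) instead of .cuda(). Tensors created inside a module take their device from an input tensor (x.device), so the same code runs unchanged on a laptop CPU or on a GPU.

```python
import torch
import torch.nn as nn

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def drop_path(x, drop_prob):
    # Hypothetical helper: the mask is allocated on x.device
    # instead of a hard-coded CUDA tensor type.
    if drop_prob > 0.0:
        keep_prob = 1.0 - drop_prob
        mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
        x = x / keep_prob * mask
    return x


class TinyNet(nn.Module):
    # Hypothetical toy network standing in for a DARTS model.
    def __init__(self, in_channels=3, num_classes=10, drop_prob=0.2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, 16, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(16, num_classes)
        self.drop_prob = drop_prob

    def forward(self, x):
        # Input-shape check that is easy to debug on a CPU.
        assert x.dim() == 4 and x.size(1) == self.conv.in_channels, x.shape
        h = torch.relu(self.conv(x))
        if self.training:
            h = drop_path(h, self.drop_prob)
        return self.fc(self.pool(h).flatten(1))


model = TinyNet().to(device)                  # instead of model.cuda()
criterion = nn.CrossEntropyLoss().to(device)  # instead of criterion.cuda()

inputs = torch.randn(8, 3, 32, 32).to(device)         # instead of .cuda()
targets = torch.randint(0, 10, (8,), device=device)   # created directly on the device

logits = model(inputs)
loss = criterion(logits, targets)
loss.backward()
```

On a machine without CUDA, torch.cuda.is_available() returns False and everything above runs on the CPU. Passing device= when creating a tensor avoids allocating it on the CPU first and then copying it over.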