Open ethanwharris opened 9 months ago
Hey @ethanwharris,
The `map` and `optimize` callables already support some sort of device handling. You just need to add the optional `device` argument to your callable:
```python
def fn(..., device):
    ...

map(fn, ...)
```
Does this solve your issue?
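The dispatch could work roughly along these lines (a sketch only, not litdata's actual implementation; `call_with_optional_device` is a hypothetical name, and the assumption is that the optional argument is detected by inspecting the callable's signature):

```python
import inspect

def call_with_optional_device(fn, item, device):
    # Pass `device` only when the callable declares it in its
    # signature; otherwise call it with the item alone.
    # (Assumption: litdata detects the optional argument this way.)
    if "device" in inspect.signature(fn).parameters:
        return fn(item, device=device)
    return fn(item)
```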
@tchaton
Cool, yeah so we can probably mark this as done - although it's hard to discover that you can do this. It might still be nice to just pass an nn.Module or something and have it handled automatically 😃
🚀 Feature
Provide an easy or automated way to get batches + models on to the correct device with map.
Motivation
We often want to map over a bunch of GPU machines, maybe each with more than one GPU on board. Right now, deciding which device to use in each process is a little tricky: you have to take the rank modulo the number of CUDA devices.
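The rank-to-device arithmetic described above can be sketched as follows (`device_for_rank` is a hypothetical helper name, not part of any library):

```python
def device_for_rank(rank: int, num_cuda_devices: int) -> str:
    # Each process picks the CUDA device whose index is its global
    # rank modulo the number of devices on the machine, so ranks are
    # spread evenly across the local GPUs.
    return f"cuda:{rank % num_cuda_devices}"
```

For example, with 4 GPUs per machine, ranks 0-3 land on `cuda:0` through `cuda:3`, and rank 4 wraps back around to `cuda:0`.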
Pitch
Probably the cleanest thing would be to handle devices automatically, more like a LightningModule - maybe if you pass an nn.Module to `map` we could put it on the correct device for the process and wrangle the inputs / outputs.
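A rough sketch of what that automatic placement might look like (everything here is hypothetical, including the helper name; the only assumed API is the `.to()` method that `nn.Module` already provides):

```python
def prepare_module(module, rank: int, num_cuda_devices: int):
    # Hypothetical helper: derive this process's device from its rank
    # and move the module there via .to(). A map() that accepted an
    # nn.Module could call something like this before invoking the
    # user's function, and .to() the inputs / outputs the same way.
    device = f"cuda:{rank % num_cuda_devices}"
    return module.to(device), device
```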
Alternatives
Additional context