Lightning-AI / litdata

Transform datasets at scale. Optimize datasets for fast AI model training.
Apache License 2.0

`map` device management #15

Open ethanwharris opened 9 months ago

ethanwharris commented 9 months ago

🚀 Feature

Provide an easy or automated way to get batches and models onto the correct device with map.

Motivation

We often want to map over a bunch of GPU machines, possibly each with more than one GPU on board. Right now, deciding which device to use in each process is a little tricky: you have to compute the rank modulo the number of CUDA devices.
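The rank-modulo trick described above can be sketched like this (a minimal example; `device_for_rank` is a hypothetical helper, and in practice the device count would come from `torch.cuda.device_count()`):

```python
def device_for_rank(rank: int, num_cuda_devices: int) -> str:
    # Map a global worker rank onto one of the machine's local CUDA
    # devices, e.g. rank 5 on a 4-GPU machine lands on "cuda:1".
    return f"cuda:{rank % num_cuda_devices}"
```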

Pitch

Probably the cleanest thing would be to handle devices automatically, more like a LightningModule - maybe if you pass an nn.Module to map, we could put it on the correct device for the process and wrangle the inputs / outputs.

Alternatives

Additional context

tchaton commented 9 months ago

Hey @ethanwharris,

The `map` and `optimize` callables already support a form of device handling. You just need to add the optional `device` argument to your callable:

def fn(..., device):
    ...

map(fn, ...)
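To illustrate the pattern, here is a hedged sketch: `process_image` is a hypothetical callable, and the premise (per the comment above) is that litdata injects a device string such as "cuda:1" when the callable declares a `device` parameter, so the body never computes rank modulo device count by hand:

```python
def process_image(path, device):
    # Hypothetical body: in real use you would move a model and tensors
    # to `device` and run inference. Here we only record which device
    # the item would be handled on.
    return path, device

# A worker would invoke the callable roughly like this:
item, device = process_image("img_0001.png", "cuda:1")
```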

Does this solve your issue?

ethanwharris commented 9 months ago

@tchaton

Cool - yeah, so we can probably mark this as done. It's hard to discover that you can do this, though, so it might still be nice to just pass an nn.Module or something and have it handled automatically 😃