Lightning-AI / litdata

Streamline data pipelines for AI. Process datasets across 1000s of machines, and optimize data for blazing fast model training.
Apache License 2.0

`map` device management #15

Open ethanwharris opened 4 months ago

ethanwharris commented 4 months ago

🚀 Feature

Provide an easy or automated way to get batches + models onto the correct device with `map`.

Motivation

We often want to map over a bunch of GPU machines, each possibly with more than one GPU on board. Right now, deciding which device to use in each process is a little tricky: you have to take the worker rank modulo the number of CUDA devices.
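A minimal sketch of that rank-to-device bookkeeping, to show why it's fiddly to do by hand (`pick_device` is a hypothetical helper, not part of litdata; in practice `num_devices` would come from `torch.cuda.device_count()` on the worker):

```python
def pick_device(rank: int, num_devices: int) -> str:
    # Map a worker rank to a CUDA device by taking the rank modulo
    # the number of GPUs visible on the machine.
    if num_devices == 0:
        # No GPUs available: fall back to CPU.
        return "cpu"
    return f"cuda:{rank % num_devices}"
```

So with 4 GPUs per machine, worker rank 5 lands on `cuda:1`, rank 6 on `cuda:2`, and so on.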

Pitch

Probably the cleanest solution would be to handle devices automatically, more like a LightningModule does: maybe if you pass an `nn.Module` to `map`, we could put it on the correct device for each process and wrangle the inputs / outputs.
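As a rough sketch of what that automation could look like, assuming PyTorch, a worked-out device, and a tensor batch (`run_on_device` is hypothetical, not part of litdata):

```python
import torch
import torch.nn as nn

def run_on_device(module: nn.Module, batch: torch.Tensor,
                  device: torch.device) -> torch.Tensor:
    # What `map` could do behind the scenes: move the module and the
    # batch to the worker's device, run the forward pass, and bring
    # the outputs back to CPU so they can be written out.
    module = module.to(device)
    batch = batch.to(device)
    with torch.no_grad():
        out = module(batch)
    return out.cpu()
```

The device wrangling above is exactly the boilerplate the feature would hide from the user.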

Alternatives

Additional context

tchaton commented 4 months ago

Hey @ethanwharris,

The `map` and `optimize` callables already support a form of device placement: just add an optional `device` argument to your callable.

```python
def fn(..., device):
    # `device` is filled in per worker, e.g. "cuda:0"
    ...

map(fn, ...)
```

Does this solve your issue?

ethanwharris commented 4 months ago

@tchaton

Cool, yeah, so we can probably mark this as done. That said, it's hard to discover that you can do this; it might still be nice to just pass an `nn.Module` or something and have it handled automatically 😃