Open muellerzr opened 1 month ago
This sounds cool! I'd like to try to contribute this, I'll start working on it this week 😄
Hi @muellerzr, I should hopefully have a PR up for this pretty soon! Just to clarify the goal here, is there a reason to write a new dataloader as opposed to adding support for custom types to the existing dataloader in Accelerate?
I assume the intent of this is to sidestep the `TypeError` that you get if you try to build a torch `DataLoader` around a custom iterable, which also causes problems if you `prepare` a `DataLoader` through Accelerate, e.g., in PyTorch:
```python
from torch.utils.data import DataLoader

class MyIterable:
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)

list(DataLoader(MyIterable(data=[1, 2, 3, 4]), batch_size=2))
```

which throws `TypeError: object of type 'MyIterable' has no len()`, but should be able to be treated like `DataLoader([1, 2, 3, 4], batch_size=2)` instead of throwing, because it yields things that can be cast to torch tensors. Is that correct?
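For comparison, a quick illustrative snippet (not from the issue) showing why the plain-list version works: a list has `__len__` and yields integers that the default collate function can stack into tensors.

```python
from torch.utils.data import DataLoader

# A plain list already satisfies the map-style dataset contract
# (__len__ + __getitem__), so default sampling and batching work.
batches = list(DataLoader([1, 2, 3, 4], batch_size=2))
# batches is two tensors: tensor([1, 2]) and tensor([3, 4])
```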
I think that can be done relatively easily, e.g., by adding a really thin wrapper to encapsulate the custom type as an `IterableDataset`, but I assume that could also be used to extend support to the existing dataloader types without a new flag 🤔 Just want to make sure I'm on the same page, if a new option + data loader is added, about how device placement & options like `dispatch_batches` should behave, e.g., if `dispatch_batches=True` and `custom_types=True`, since those options currently determine which dataloader is used! Whereas if support were added to the existing data loaders, that wouldn't really be an issue.
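A minimal sketch of the thin-wrapper idea mentioned above, assuming we just forward `__iter__`; `IterableWrapper` is a hypothetical name for illustration, not an Accelerate class:

```python
from torch.utils.data import DataLoader, IterableDataset

class IterableWrapper(IterableDataset):
    """Hypothetical thin wrapper: makes any iterable look like an
    IterableDataset so DataLoader never asks for __len__."""
    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

class MyIterable:
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        return iter(self.data)

# No TypeError now: DataLoader treats the wrapper as iterable-style.
batches = list(DataLoader(IterableWrapper(MyIterable([1, 2, 3, 4])), batch_size=2))
```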
@alex-jw-brooks the idea behind this is indeed as you say :)

A flag would be better, and do note that realistically `dispatch_batches` or `split_batches` shouldn't do anything: this is full user control of the dataloader, and we simply move it to the device. The user's custom iterable should determine how data is split, etc., and Accelerate will move each batch to the device when drawn.
Make a simplistic version of the `DispatchDataLoader` which allows users to easily pass in any `Iterable` type of object and call its `__iter__`. The likely way to make sure `Accelerator` knows what's up during `prepare` is to add an option to the `DataLoaderConfiguration` for `custom_classes`, which accepts types for us to do `isinstance()` on.
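A rough sketch of how that `isinstance()` routing might look; the dataclass below is a simplified stand-in for `DataLoaderConfiguration`, and `should_use_custom_loader` is a hypothetical helper, not part of Accelerate's API:

```python
from dataclasses import dataclass, field

@dataclass
class DataLoaderConfiguration:
    # Simplified stand-in: types the user wants treated as
    # "custom" iterables during prepare().
    custom_classes: tuple = field(default_factory=tuple)

def should_use_custom_loader(obj, config):
    # prepare() could route objects matching custom_classes to the
    # simple pass-through dataloader instead of the usual wrapping.
    # isinstance with an empty tuple is always False, so the default
    # config changes nothing.
    return isinstance(obj, tuple(config.custom_classes))

class MyIterable:
    pass

cfg = DataLoaderConfiguration(custom_classes=(MyIterable,))
```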