ray-project / ray_lightning

Pytorch Lightning Distributed Accelerators using Ray
Apache License 2.0

What happens with custom samplers? #252

Open AugustoPeres opened 1 year ago

AugustoPeres commented 1 year ago

For training reasons I had to write my own sampler. Something like:

class MySampler(torch.utils.data.Sampler):
    def __init__(self, data, batches_per_epoch, batch_size):
        # some python code
    def __iter__(self):
        # My iter method obeying specific rules

To create the data loader I then simply use:

sampler = MySampler(data, batches_per_epoch, batch_size)
dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)
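For reference, here is a minimal runnable sketch of a custom batch sampler along these lines. The class and dataset names are made up for illustration; the one detail worth noting is the `drop_last` attribute, which Lightning's sampler-replacement code reads and which a bare batch sampler does not have, matching the AttributeError below.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Sampler


class RandomBatchSampler(Sampler):
    """Hypothetical batch sampler: a fixed number of random batches per epoch."""

    def __init__(self, data_len, batches_per_epoch, batch_size):
        self.data_len = data_len
        self.batches_per_epoch = batches_per_epoch
        self.batch_size = batch_size
        # Lightning's sampler handling expects this attribute on samplers;
        # omitting it is one way to hit "object has no attribute 'drop_last'".
        self.drop_last = True

    def __iter__(self):
        for _ in range(self.batches_per_epoch):
            # Each batch is a list of dataset indices.
            yield torch.randint(self.data_len, (self.batch_size,)).tolist()

    def __len__(self):
        return self.batches_per_epoch


class ToyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, i):
        return torch.tensor(float(i))


dataset = ToyDataset()
sampler = RandomBatchSampler(len(dataset), batches_per_epoch=5, batch_size=8)
loader = DataLoader(dataset, batch_sampler=sampler)
batches = list(loader)
```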

This works great up to the point where I try to use ray_lightning. At first I tried:

plugin = RayStrategy(num_workers=num_workers,
                     num_cpus_per_worker=num_cpus_per_worker,
                     use_gpu=use_gpu)
trainer = pl.Trainer(max_epochs=max_epochs,
                     strategy=plugin,
                     logger=False)

Which raised the error:

AttributeError: 'SeqMatchSeqSampler' object has no attribute 'drop_last'

I then saw that there is a flag, replace_sampler_ddp, that disables sampler replacement. Using this code:

plugin = RayStrategy(num_workers=num_workers,
                     num_cpus_per_worker=num_cpus_per_worker,
                     use_gpu=use_gpu)
trainer = pl.Trainer(max_epochs=max_epochs,
                     strategy=plugin,
                     logger=False,
                     replace_sampler_ddp=False)

I no longer get an error. However, something strange seems to happen: on my local machine, each epoch takes longer when I use more workers. Why is that? What exactly are the effects of replace_sampler_ddp=False on distributed data loading?
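One plausible explanation: with replace_sampler_ddp=False, Lightning leaves the DataLoader exactly as constructed, so every Ray worker builds the same sampler and iterates the full batches_per_epoch batches. Nothing shards the batches across ranks, so total work per epoch scales with the number of workers while gradient synchronization adds overhead. A sketch of the kind of sharding a distributed sampler would normally do (ShardedBatchSampler is a hypothetical wrapper, not part of ray_lightning):

```python
class ShardedBatchSampler:
    """Hypothetical wrapper: round-robin split of a batch sampler's batches
    across ranks, so each worker sees a disjoint 1/num_replicas slice."""

    def __init__(self, base_sampler, num_replicas, rank):
        self.base_sampler = base_sampler
        self.num_replicas = num_replicas
        self.rank = rank
        # Preserve the attribute Lightning's sampler handling looks for.
        self.drop_last = getattr(base_sampler, "drop_last", False)

    def __iter__(self):
        for i, batch in enumerate(self.base_sampler):
            if i % self.num_replicas == self.rank:
                yield batch

    def __len__(self):
        return len(self.base_sampler) // self.num_replicas


# Demo with a plain list of index batches standing in for MySampler's output:
batches = [[0, 1], [2, 3], [4, 5], [6, 7]]
rank0 = list(ShardedBatchSampler(batches, num_replicas=2, rank=0))
rank1 = list(ShardedBatchSampler(batches, num_replicas=2, rank=1))
```

Without something like this, each worker re-reads all the data, which is consistent with epochs getting slower as workers are added.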

I could not find clear documentation on this particular topic:

For example if I use:

sampler = MySampler(data, int(batches_per_epoch/num_ray_workers), batch_size)

Will this be equivalent for, say, 1 and 4 workers?
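As a rough sanity check on that scaling (assuming DDP semantics, where each optimizer step consumes one batch per worker): dividing batches_per_epoch by num_workers keeps the total samples seen per epoch constant, but it changes the number of optimizer steps and the effective batch size, so the two runs are not equivalent step-for-step. The helper below is hypothetical, purely to make the arithmetic concrete:

```python
def epoch_stats(batches_per_epoch, batch_size, num_workers):
    """Hypothetical helper: per-epoch accounting when each worker runs
    int(batches_per_epoch / num_workers) batches under DDP."""
    per_worker = batches_per_epoch // num_workers
    steps = per_worker                          # one optimizer step per synchronized batch
    samples = per_worker * num_workers * batch_size
    effective_batch = batch_size * num_workers  # gradients averaged across workers
    return steps, samples, effective_batch
```

So with batches_per_epoch=100 and batch_size=8, one worker gives 100 steps of batch size 8, while four workers give 25 steps of effective batch size 32: same 800 samples per epoch, different optimization dynamics (you may need to adjust the learning rate accordingly).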