For training reasons I had to write my own sampler. Something like:
```python
import torch.utils.data

class MySampler(torch.utils.data.Sampler):
    def __init__(self, data, batches_per_epoch, batch_size):
        # some python code
        ...

    def __iter__(self):
        # My iter method obeying specific rules
        ...
```
To create the data loader I then simply use:
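Roughly like this (a sketch: since `MySampler` takes a `batch_size` and `batches_per_epoch`, I'm assuming it yields whole batches of indices and therefore goes in through `DataLoader`'s `batch_sampler` argument; the dataset name and numbers are placeholders):

```python
from torch.utils.data import DataLoader

# With batch_sampler, DataLoader's own batch_size/shuffle/sampler/drop_last
# must be left at their defaults; the sampler itself yields the batches.
loader = DataLoader(
    train_dataset,
    batch_sampler=MySampler(train_dataset, batches_per_epoch=500, batch_size=32),
)
```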
This works great up to the point where I try to use Ray Lightning. At first I tried to use it as follows:
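(A sketch of that setup, assuming the `RayPlugin` from the `ray_lightning` package, which is named `RayStrategy` in newer releases; `model` stands in for my `LightningModule` and the worker/epoch counts are placeholders:)

```python
import pytorch_lightning as pl
from ray_lightning import RayPlugin  # RayStrategy in newer ray_lightning releases

trainer = pl.Trainer(
    max_epochs=10,
    plugins=[RayPlugin(num_workers=4, use_gpu=False)],
)
trainer.fit(model, loader)
```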
Which raised the error:
I then saw that there is a flag that disables sampler replacement: `replace_sampler_ddp`. Using this code:
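(Sketch as above, just adding the flag; `replace_sampler_ddp` is the Trainer argument name in pre-2.0 PyTorch Lightning, renamed `use_distributed_sampler` in 2.0:)

```python
trainer = pl.Trainer(
    max_epochs=10,
    plugins=[RayPlugin(num_workers=4, use_gpu=False)],
    replace_sampler_ddp=False,  # keep MySampler; don't inject a DistributedSampler
)
trainer.fit(model, loader)
```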
I no longer get that error. However, something strange seems to happen: on my local machine, each epoch takes longer when I use more workers. Why is that? What exactly are the effects of `replace_sampler_ddp=False` on distributed dataloading? I could not find clear documentation on this particular topic:

- Does every worker have its own copy of the sampler?
- If so, are there in fact more batches being computed in every epoch?
- How can I wrap my own sampler for DDP? Is there a way to instantiate the sampler such that every worker will handle different batches?
For example, if I use:
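(My guess at the kind of snippet that fits here: dividing `batches_per_epoch` across the Ray workers; `num_ray_workers` and the numbers are hypothetical.)

```python
# Hypothetical: give each Ray worker a slice of the epoch, hoping the union
# over 4 workers matches what a single worker would produce on its own.
sampler = MySampler(
    data,
    batches_per_epoch=500 // num_ray_workers,  # num_ray_workers: hypothetical
    batch_size=32,
)
```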
Will this be equivalent for, for example, 1 and 4 workers?