During training, every batch goes through uncollate (which converts a dict of batched tensors into a list of per-sample dicts).
This loops over |dataset| * |keys| elements,
and is not parallelized with multiprocessing, so it is a major bottleneck.
This is actually not needed, because we can easily gather batches without looping over individual samples:
for k in keys:
    torch.cat([batch[k] for batch in batches], dim=0)
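For illustration, here is a minimal self-contained sketch of this batch-level gathering (the merge_batches helper and the toy batches are my own names, not part of the library):

import torch

def merge_batches(batches):
    # Concatenate a list of batch dicts into a single dict of tensors,
    # touching each key once per batch instead of once per sample.
    keys = batches[0].keys()
    return {k: torch.cat([b[k] for b in batches], dim=0) for k in keys}

# Example: two batches of 4 samples, each a 10-dimensional vector.
batches = [{"feat": torch.randn(4, 10)}, {"feat": torch.randn(4, 10)}]
merged = merge_batches(batches)
print(merged["feat"].shape)  # torch.Size([8, 10])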
Attaching the profiler output from my code (50K samples, each sample is a dictionary of ~10 vectors):