Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Interleaved Mode for combined_loader #17910

Open wrmthorne opened 1 year ago

wrmthorne commented 1 year ago

Description & Motivation

Description: A mode for CombinedLoader that would evenly distribute samples across all iterables.

Motivation: In cycle-consistency training, two models are updated iteratively, each on the output of the other. The current modes for CombinedLoader do not train these models optimally:

min_size: wastes valuable data from the larger dataset
max_size: one model is trained on a static state of the other, causing it to overfit to an old model state
max_size_cycle: may cause overfitting on the smaller dataset if the dataset sizes differ significantly (see the sketch after this list)
sequential: same problem as max_size, but worse
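
For intuition, here is a plain-Python sketch (not Lightning's implementation) that mimics the documented max_size_cycle pairing with itertools.cycle, showing why samples from the smaller dataset recur within a single pass over the larger one:

from itertools import cycle

a = list(range(5))  # stands in for the larger dataset
b = list(range(3))  # stands in for the smaller dataset

# max_size_cycle: the smaller iterable restarts until the larger one is
# exhausted, so its samples are seen repeatedly within one epoch.
for x, y in zip(a, cycle(b)):
    print({'a': x, 'b': y})
# {'a': 0, 'b': 0}
# {'a': 1, 'b': 1}
# {'a': 2, 'b': 2}
# {'a': 3, 'b': 0}  <- 'b' wraps around; repeated exposure drives overfitting
# {'a': 4, 'b': 1}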

Pitch

An interleaved mode that calculates how to evenly distribute the smaller iterable's batches across the larger one. This comes with the caveat that the length of each iterable must be known, which may not always be possible.

e.g.

iterables = {'a': DataLoader(range(5), batch_size=1),
             'b': DataLoader(range(3), batch_size=1)}
combined_loader = CombinedLoader(iterables, 'interleaved')

for batch in combined_loader:
    print(batch)
# {'a': tensor([0]), 'b': tensor([0])}
# {'a': tensor([1]), 'b': None}
# {'a': tensor([2]), 'b': tensor([1])}
# {'a': tensor([3]), 'b': None}
# {'a': tensor([4]), 'b': tensor([2])}
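
To make the pitch concrete, here is a minimal sketch of one possible schedule, assuming both lengths are known up front; interleave_positions and interleave are hypothetical helpers, not Lightning API:

from torch.utils.data import DataLoader

def interleave_positions(n_large, n_small):
    # Evenly spaced indices in [0, n_large - 1] at which the smaller
    # iterable contributes a batch (endpoints included).
    if n_small <= 1:
        return {0}
    step = (n_large - 1) / (n_small - 1)
    return {round(j * step) for j in range(n_small)}

def interleave(large, small):
    # One step per batch of the larger iterable; the smaller iterable
    # fills in only at its scheduled positions, None elsewhere.
    positions = interleave_positions(len(large), len(small))
    small_iter = iter(small)
    for i, x in enumerate(large):
        y = next(small_iter) if i in positions else None
        yield {'a': x, 'b': y}

loaders = {'a': DataLoader(range(5), batch_size=1),
           'b': DataLoader(range(3), batch_size=1)}
for batch in interleave(loaders['a'], loaders['b']):
    print(batch)  # reproduces the output shown above for lengths 5 and 3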

Alternatives

No response

Additional context

I don't believe the __len__ issue would be too damaging, as PyTorch DataLoaders implement a __len__ method. Obtaining lengths would hang on infinite iterables, but for finite ones a length could otherwise be obtained with:

sum(1 for _ in iterable)
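
If it helps, a small sketch of a length probe combining the two approaches (try_len is a hypothetical helper, not Lightning API; note the fallback consumes the iterable):

def try_len(iterable):
    # Prefer the iterable's own __len__ (a DataLoader provides one
    # whenever its dataset/sampler is sized).
    try:
        return len(iterable)
    except TypeError:
        # Fallback: count by exhausting the iterable. This consumes one
        # full pass and never returns on an infinite iterable.
        return sum(1 for _ in iterable)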

cc @borda

chaitanya100100 commented 3 months ago

Is there any update on this?

wrmthorne commented 3 months ago

As I never got a response, I assumed it was either not something that was really wanted or too small a feature to spend development time and effort on. I've been too occupied with other work since filing this feature request, so I haven't gotten around to implementing it myself. If I ever do, I'll update here.