Lightning-AI / pytorch-lightning


Custom batch sampler fails to re-instantiate in `_dataloader_init_kwargs_resolve_sampler` #20272

Closed · Kami-chanw closed this issue 2 months ago

Kami-chanw commented 2 months ago

Outline & Motivation

I created a batch sampler that draws batches from an underlying batch sampler in order to build a larger batch. For example, assume we have a PyTorch `BatchSampler` that yields batches of `batch_size=3`. My custom batch sampler with `batch_size=5` samples 5 times from the underlying batch sampler and concatenates the results into one large batch of 15 indices.
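
For reference, a minimal sketch of such a "batch sampler of batches" (class name and details are hypothetical, not the actual code from the issue):

```python
from torch.utils.data import BatchSampler, SequentialSampler


class MetaBatchSampler:
    """Groups several batches from an inner batch sampler into one larger batch."""

    def __init__(self, batch_sampler, batch_size):
        self.batch_sampler = batch_sampler  # inner sampler, yields lists of indices
        self.batch_size = batch_size        # how many inner batches form one outer batch

    def __iter__(self):
        buffer = []
        for i, inner_batch in enumerate(self.batch_sampler, start=1):
            buffer.extend(inner_batch)
            if i % self.batch_size == 0:
                yield buffer
                buffer = []
        if buffer:  # emit the trailing, possibly smaller, batch
            yield buffer

    def __len__(self):
        return (len(self.batch_sampler) + self.batch_size - 1) // self.batch_size


# The inner sampler yields batches of 3 indices; the outer one groups 5 of them -> 15 indices.
inner = BatchSampler(SequentialSampler(range(60)), batch_size=3, drop_last=False)
outer = MetaBatchSampler(inner, batch_size=5)
print(next(iter(outer)))  # [0, 1, 2, ..., 14]
```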

However, in the function `_dataloader_init_kwargs_resolve_sampler`, Lightning tries to inject a plain sampler (one that yields single indices rather than batches) into my custom batch sampler when it re-instantiates it, which fails.

Pitch

I want to know what the correct approach is in my situation.

Additional context

No response

cc @justusschock @awaelchli

Kami-chanw commented 2 months ago

I solved this problem with the following approach (perhaps nobody else will run into the same issue):

  1. Pass `use_distributed_sampler=False` to the `Trainer` when instantiating it.
  2. When `trainer.world_size > 1`, replace every sampler that yields single indices (such as `RandomSampler`, `SequentialSampler`, etc.) with a `DistributedSampler`, as in the sketch below.
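
A rough sketch of this workaround, reusing the hypothetical `MetaBatchSampler` from the earlier sketch (the dataset and exact wiring are illustrative, not the code from the issue):

```python
import torch
import lightning as L
from torch.utils.data import (BatchSampler, DataLoader, DistributedSampler,
                              SequentialSampler, TensorDataset)

dataset = TensorDataset(torch.arange(60))  # placeholder dataset

# Step 1: tell Lightning not to inject/replace samplers on its own.
trainer = L.Trainer(use_distributed_sampler=False)

# Step 2: choose the index-level sampler manually, sharding only when distributed.
if trainer.world_size > 1:
    index_sampler = DistributedSampler(dataset)  # shards indices across ranks
else:
    index_sampler = SequentialSampler(dataset)

inner = BatchSampler(index_sampler, batch_size=3, drop_last=False)
loader = DataLoader(dataset, batch_sampler=MetaBatchSampler(inner, batch_size=5))
# trainer.fit(model, train_dataloaders=loader)
```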