pytorch / text

Models, data loaders and abstractions for language processing, powered by PyTorch
https://pytorch.org/text
BSD 3-Clause "New" or "Revised" License

How to batch data with a similar number of tokens per batch using distributed training? #1295

Status: Open. sandthou opened this issue 3 years ago.

sandthou commented 3 years ago

My code needs two features:

  1. A bucket iterator;
  2. A similar number of tokens in every batch (which means the batch size varies from batch to batch).

I think I could implement feature 2 with a custom sampler that inherits from torch.utils.data.Sampler. However, as seen in the tutorial, the bucket iterator inherits from torch.utils.data.Dataset, and for distributed training torch.utils.data.distributed.DistributedSampler should be used. A custom sampler and DistributedSampler cannot both be passed to torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False).
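To make the conflict concrete, here is a minimal sketch (assuming a recent PyTorch; the toy list dataset and the batch size of 3 are hypothetical) showing that DataLoader rejects a batch_sampler combined with a sampler. DistributedSampler is given num_replicas and rank explicitly so the snippet runs in a single process without init_process_group.

```python
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler
from torch.utils.data.distributed import DistributedSampler

data = list(range(10))  # hypothetical toy dataset

# Stand-in for a custom batch sampler: fixed-size batches of indices.
batch_sampler = BatchSampler(SequentialSampler(data), batch_size=3, drop_last=False)
# Explicit num_replicas/rank, so no process group is required for this demo.
dist_sampler = DistributedSampler(data, num_replicas=2, rank=0)

try:
    DataLoader(data, sampler=dist_sampler, batch_sampler=batch_sampler)
    conflict = False
except ValueError as err:
    # DataLoader treats batch_sampler as mutually exclusive with
    # batch_size, shuffle, sampler and drop_last.
    conflict = True
    print(err)
```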

So how can I sample sentences into batches with a similar number of tokens for distributed training?

Thanks a lot.

parmeet commented 3 years ago

Hi, for general implementation strategies and usage questions, the PyTorch Forums (https://discuss.pytorch.org/) are the go-to place. Could you please post your question there? Thanks!

Also, please note that we have deprecated BucketIterator. You may refer to the migration tutorial for an alternative implementation strategy. Hope this helps!

sandthou commented 3 years ago

The bucket iterator my code currently uses is the bucket_dataloader from the migration tutorial, and it works well! I just wonder how to modify the tutorial code to batch sentences with a similar number of tokens. You may be right that distributed training with varying batch sizes is a general question, so I will post this on the PyTorch Forums as well.

Thanks for your reply!

parmeet commented 3 years ago

The bucket_dataloader in the migration tutorial does batch sentences with a similar number of tokens. The logic that groups sentences of similar length (number of tokens) is in the batch_sampler generator.
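For reference, the pooled-bucketing idea can be sketched in plain Python as follows. This is not the tutorial's exact code; the function name bucket_batch_sampler and the pool_factor parameter are hypothetical. It shuffles indices, sorts them by length inside large pools, and then slices fixed-size batches, which is why the number of sentences per batch stays constant while the total token count does not.

```python
import random
from typing import Iterator, List, Sequence


def bucket_batch_sampler(lengths: Sequence[int], batch_size: int,
                         pool_factor: int = 100, seed: int = 0) -> Iterator[List[int]]:
    """Yield lists of indices whose sentences have similar lengths.

    Every batch holds batch_size sentences (the last may be smaller),
    so the total number of tokens still varies between batches.
    """
    rng = random.Random(seed)
    indices = list(range(len(lengths)))
    rng.shuffle(indices)  # randomize which sentences land in the same pool
    pool_size = batch_size * pool_factor
    pooled: List[int] = []
    for start in range(0, len(indices), pool_size):
        pool = indices[start:start + pool_size]
        pooled.extend(sorted(pool, key=lambda i: lengths[i]))  # sort within the pool only
    for start in range(0, len(pooled), batch_size):
        yield pooled[start:start + batch_size]


lengths = [5, 1, 9, 3, 7, 2]  # hypothetical token counts per sentence
print(list(bucket_batch_sampler(lengths, batch_size=2, pool_factor=3)))
# [[1, 5], [3, 0], [4, 2]]
```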

sandthou commented 3 years ago

By "number of tokens" I mean the total number of tokens in a batch. However, bucket_dataloader cannot produce batches that each contain a similar total number of tokens.

To be more specific, I want to reproduce the WMT14 translation results from the paper "Attention Is All You Need", which says: "Sentence pairs were batched together by approximate sequence length. Each training batch contained a set of sentence pairs containing approximately 25000 source tokens and 25000 target tokens". As you said, bucket_dataloader can batch sentence pairs by approximate sequence length, but I am having trouble making each batch contain approximately 25000 source tokens and 25000 target tokens. This means the batch size changes with the lengths of the sentences in each batch. If I use a custom sampler to achieve that, I cannot also use DistributedSampler for distributed training.
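The token-budget batching described in the paper can be sketched in plain Python as follows. This is a hypothetical helper (token_budget_batches is not a torchtext API): it sorts pairs by length and closes a batch once adding another pair would exceed either the source or the target budget, so the number of pairs per batch varies. It counts unpadded tokens, which is an assumption; a common variant budgets the padded size (longest sequence times number of pairs) instead.

```python
from typing import Iterator, List, Sequence, Tuple


def token_budget_batches(pair_lengths: Sequence[Tuple[int, int]],
                         max_tokens: int) -> Iterator[List[int]]:
    """Group pair indices so each batch has at most max_tokens source tokens
    and at most max_tokens target tokens (unpadded counts).

    A single pair longer than max_tokens still gets its own batch.
    """
    # Sort by (source length, target length) so a batch holds pairs of similar length.
    order = sorted(range(len(pair_lengths)), key=lambda i: pair_lengths[i])
    batch: List[int] = []
    src_total = tgt_total = 0
    for i in order:
        src_len, tgt_len = pair_lengths[i]
        # Close the current batch if this pair would exceed either budget.
        if batch and (src_total + src_len > max_tokens or tgt_total + tgt_len > max_tokens):
            yield batch
            batch, src_total, tgt_total = [], 0, 0
        batch.append(i)
        src_total += src_len
        tgt_total += tgt_len
    if batch:
        yield batch


pair_lengths = [(10, 12), (3, 4), (20, 18), (4, 5), (11, 9)]  # hypothetical (src, tgt) lengths
print(list(token_budget_batches(pair_lengths, max_tokens=25)))
# [[1, 3, 0], [4], [2]]
```

For a setup like the paper's, one would pass something like max_tokens=25000, and the resulting batches should be shuffled each epoch so training does not see them in length order.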

parmeet commented 3 years ago

Thanks for the clarification. Yes, the batch_sampler above may not be well suited to solving this.

One way to deal with this is to disable DataLoader's automatic batching entirely (https://pytorch.org/docs/stable/data.html#disable-automatic-batching) and handle collation directly inside your dataset; you are then free to group samples so that each batch consists of roughly 25000 source/target tokens. You can then use DistributedSampler with this dataset and pass both to your DataLoader (make sure automatic batching is disabled, i.e. batch_size=None). Note that DistributedSampler, like all other samplers, yields indices over your dataset; the difference is that with DistributedSampler each distributed worker (process) gets its own subset of indices. Since each index now refers to a batch that has already been collated with a similar number of tokens, your DataLoader simply fetches these batches.

cc: @cpuhrsch @VitalyFedyunin
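Here is a minimal sketch of that approach, assuming the batches were computed up front (for example by a token-budget routine). PrebatchedDataset, make_loader and the toy data are hypothetical. Each dataset index is a whole batch, DistributedSampler splits the batch indices across ranks, and batch_size=None turns off automatic batching. num_replicas and rank are passed explicitly only so the sketch runs in one process; in a real job they come from the initialized process group.

```python
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler


class PrebatchedDataset(Dataset):
    """Each item is an entire batch of samples, grouped ahead of time."""

    def __init__(self, samples, batches):
        self.samples = samples  # per-sample data, here token-id lists
        self.batches = batches  # one list of sample indices per batch

    def __len__(self):
        return len(self.batches)  # the sampler sees batches, not samples

    def __getitem__(self, idx):
        # Return the whole pre-grouped batch; padding/tensor conversion could go here too.
        return [self.samples[i] for i in self.batches[idx]]


samples = [[1, 2], [3], [4, 5, 6], [7], [8, 9], [10]]  # hypothetical tokenized sentences
batches = [[0, 1], [2], [3, 4], [5]]                   # hypothetical precomputed batches
dataset = PrebatchedDataset(samples, batches)


def make_loader(rank, world_size=2, epoch=0):
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=True, seed=0)
    sampler.set_epoch(epoch)  # call every epoch so the shuffle order changes
    # batch_size=None disables automatic batching: each __getitem__ result is yielded as-is.
    loader = DataLoader(dataset, batch_size=None, sampler=sampler)
    return sampler, loader


for rank in range(2):
    sampler, loader = make_loader(rank)
    print(rank, list(sampler), list(loader))
```

If the number of batches is not divisible by the world size, DistributedSampler (with its default drop_last=False) repeats some indices so that every rank receives the same number of batches.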

sandthou commented 3 years ago

Thanks, that really helps! But I still have a question: when automatic batching is disabled, how do I batch samples with collate_fn? The PyTorch documentation says, "When automatic batching is disabled, collate_fn is called with each individual data sample, and the output is yielded from the data loader iterator." So it seems that collate_fn can only process one sample at a time?

parmeet commented 3 years ago

Yes, that's the idea. In this case we are no longer relying on DataLoader's collation. Instead, we can collate samples directly inside the dataset and return the collated batch (not the individual samples) from __getitem__. As the documentation puts it, "In certain cases, users may want to handle batching manually in dataset code, or simply load individual samples." Here we are handling batching manually in dataset code.
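A minimal sketch of collating inside __getitem__, assuming each batch is a list of (source, target) token-id lists and that index 0 is the padding token (both assumptions; CollatingBatchDataset and PAD_IDX are hypothetical names). With batch_size=None the DataLoader still calls its collate_fn, but only on one item at a time, and that item is already a padded batch, so the default collate_fn leaves the tensors unchanged.

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

PAD_IDX = 0  # assumption: token id 0 is reserved for padding


class CollatingBatchDataset(Dataset):
    """__getitem__ returns an already-collated batch of (src, tgt) pairs."""

    def __init__(self, pairs, batches):
        self.pairs = pairs      # list of (src_ids, tgt_ids) token-id lists
        self.batches = batches  # list of lists of pair indices

    def __len__(self):
        return len(self.batches)

    def __getitem__(self, idx):
        pairs = [self.pairs[i] for i in self.batches[idx]]
        src = [torch.tensor(s, dtype=torch.long) for s, _ in pairs]
        tgt = [torch.tensor(t, dtype=torch.long) for _, t in pairs]
        # Pad to the longest sequence in this batch; each tensor has shape (batch, max_len).
        return (pad_sequence(src, batch_first=True, padding_value=PAD_IDX),
                pad_sequence(tgt, batch_first=True, padding_value=PAD_IDX))


pairs = [([5, 6, 7], [8, 9]), ([10], [11, 12, 13]), ([14, 15], [16])]  # hypothetical data
dataset = CollatingBatchDataset(pairs, batches=[[0, 1], [2]])

# Automatic batching is off, so each yielded item is one full batch from __getitem__.
loader = DataLoader(dataset, batch_size=None, shuffle=False)
for src, tgt in loader:
    print(src.shape, tgt.shape)
```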

parmeet commented 3 years ago

cc: @hudeven