PetrochukM / PyTorch-NLP

Basic Utilities for PyTorch Natural Language Processing (NLP)
https://pytorchnlp.readthedocs.io
BSD 3-Clause "New" or "Revised" License

MaxTokensBatchSampler #114

Open salvacarrion opened 3 years ago

salvacarrion commented 3 years ago

Is there any straightforward way to specify the maximum number of tokens per batch in a sampler (e.g., BucketBatchSampler)?

Reducing the amount of padding per batch is critical for performance, and the BucketBatchSampler class does an excellent job of it. However, in my opinion, in NLP tasks the concept of batch_size comes second to the number of tokens per batch: the former is an optimization choice, while the latter is a hard constraint for avoiding out-of-memory errors. (For instance, I can train with a batch size of 128 at max_length 100, but not at max_length 512.)
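
To make the memory point concrete, a back-of-the-envelope sketch (the numbers are illustrative only): with a fixed batch_size, the padded token count, which is what actually drives memory, grows linearly with max_length:

# Illustrative only: padded tokens per batch = batch_size * max_length
batch_size = 128
for max_length in (100, 512):
    print(f"max_length={max_length}: {batch_size * max_length} padded tokens")
# max_length=100 -> 12800 tokens; max_length=512 -> 65536 tokens (~5x the memory)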

salvacarrion commented 3 years ago

I have written this piece of code. It is not a clean solution, but it works.

import random
from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler, SubsetRandomSampler
from torchnlp.utils import identity

class MaxTokensBatchSampler(BatchSampler):

    def __init__(self,
                 sampler,
                 batch_size,
                 max_tokens,
                 drop_last,
                 sort_key=identity,
                 bucket_size_multiplier=100,
                 shuffle=True):
        super().__init__(sampler, batch_size, drop_last)
        self.max_tokens = max_tokens
        self.sort_key = sort_key
        self.bucket_size_multiplier = bucket_size_multiplier
        self.shuffle = shuffle

        # Not a clean solution: all batches are precomputed here, so their
        # composition is fixed across epochs (only their order is reshuffled
        # in __iter__). Note that drop_last is effectively unused.
        self.bucket_batches = []
        self._build_buckets()

    def __iter__(self):
        # Iterate over buckets
        for batches, batch_sizes in self.bucket_batches:
            # Shuffle bucket-batch order
            batches = SubsetRandomSampler(batches) if self.shuffle else batches
            for batch in batches:
                if self.shuffle:  # Shuffle inner batch
                    random.shuffle(batch)
                yield batch  # A batch of sample indices: [sent1_idx, sent2_idx, ...]

    def __len__(self):
        return sum(len(batches) for batches, _ in self.bucket_batches)

    def _build_buckets(self):
        # Randomize samples
        tmp_sampler = RandomSampler(self.sampler) if self.shuffle else self.sampler

        # Split samples into large buckets of up to batch_size * bucket_size_multiplier
        bucket_size = min(self.batch_size * self.bucket_size_multiplier, len(self.sampler))
        tmp_sampler = BatchSampler(tmp_sampler, bucket_size, drop_last=False)

        # Sort each bucket by sample length (via sort_key)
        self.bucket_batches = []
        for bucket in tmp_sampler:
            bucket_sorted = sorted([(i, self.sort_key(i)) for i in bucket], key=lambda x: x[1])

            # Greedily pack the sorted samples into batches of at most max_tokens.
            # Note: this counts raw (unpadded) token lengths; a padded batch
            # actually occupies len(batch) * max_length tokens, but the sorting
            # above keeps lengths within a batch similar.
            batches = []
            batch_sizes = []

            last_batch = []
            last_batch_size = 0
            for sample_i, length_i in bucket_sorted:
                if (last_batch_size + length_i) <= self.max_tokens:
                    last_batch.append(sample_i)
                    last_batch_size += length_i
                else:
                    # Close the current batch (the guard avoids emitting an empty
                    # batch when a single sample alone exceeds max_tokens)
                    if last_batch:
                        batches.append(last_batch)
                        batch_sizes.append(last_batch_size)

                    # Start a new batch with the current sample
                    last_batch = [sample_i]
                    last_batch_size = length_i

            # Add the last (possibly partial) batch
            if last_batch:
                batches.append(last_batch)
                batch_sizes.append(last_batch_size)

            # Add bucket batches
            self.bucket_batches.append((batches, batch_sizes))

It works as follows (see the sketch after this list):

  1. Randomize all samples/sentences: [0, 1, 2, 3, ..., n] => [6, 12, 60, ..., 31]
  2. Split the samples into buckets => [6, 12, 60, ...], [92, 1, 52, ..., 24], [95, 234, 33, ..., 31]
  3. Sort each bucket by sentence length
  4. Greedily group the sorted samples into batches whose total token count stays within max_tokens
  5. Shuffle the batch order within each bucket: bucket1 (batch1, batch2, ..., batchN => batch5, batch10, ..., batch3), bucket2...
  6. Shuffle the sentences within each batch: batch1: [23, 51, 12, ...] => [51, 12, 23, ...]
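
As a quick sanity check, here is a minimal sketch with made-up data (the lengths list and the parameter values are hypothetical) that verifies every emitted batch respects the token budget:

from torch.utils.data.sampler import SequentialSampler

# Hypothetical toy data: a pseudo-random "token length" per sample
lengths = [(i * 7) % 20 + 1 for i in range(50)]

sampler = MaxTokensBatchSampler(SequentialSampler(lengths),
                                batch_size=8, max_tokens=40, drop_last=False,
                                sort_key=lambda i: lengths[i])

for batch in sampler:
    total = sum(lengths[i] for i in batch)
    print(batch, "->", total, "tokens")
    assert total <= 40  # the max_tokens constraint holds for every batch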

You can call it like this:

train_sampler = MaxTokensBatchSampler(SequentialSampler(train_ds), shuffle=True, batch_size=BATCH_SIZE,
                                      max_tokens=MAX_TOKENS, drop_last=False,
                                      sort_key=lambda i: len(train_ds.datasets.iloc[i]["src"].split()))
val_sampler = MaxTokensBatchSampler(SequentialSampler(val_ds), shuffle=False, batch_size=BATCH_SIZE,
                                    max_tokens=MAX_TOKENS, drop_last=False,
                                    sort_key=lambda i: len(val_ds.datasets.iloc[i]["src"].split()))

train_ds and val_ds are instances of a torch Dataset subclass (class TranslationDataset(Dataset):).
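
For completeness, a batch sampler like this plugs into DataLoader through the batch_sampler argument (a sketch; pad_collate is a hypothetical collate function, since variable-length batches need padding before they can be stacked into tensors):

from torch.utils.data import DataLoader

def pad_collate(examples):
    # Hypothetical: pad each field (e.g. "src"/"trg" token ids) to the longest
    # sequence in this batch and stack the results into tensors.
    ...

train_loader = DataLoader(train_ds, batch_sampler=train_sampler, collate_fn=pad_collate)
val_loader = DataLoader(val_ds, batch_sampler=val_sampler, collate_fn=pad_collate)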