NVIDIA / Megatron-LM

Ongoing research training transformer models at scale
https://docs.nvidia.com/megatron-core/developer-guide/latest/user-guide/index.html#quick-start

[ENHANCEMENT] BlendedDataset in extreme cases can lead to an IndexError. #705

Open ghost opened 6 months ago

ghost commented 6 months ago

⚠️ This discussion is based on an older version of Megatron, but the problem still exists in the current version.

1. Background

Megatron's data blending algorithm is implemented by the build_blending_indices function in helpers.cpp. This function is not robust: users can construct extreme inputs that break Megatron's training.

Specifically, in the __getitem__ method of the BlendableDataset class, sample_idx >= len(self.datasets[dataset_idx]) can occur because build_blending_indices places no upper bound on the growth of current_samples, so the index goes out of bounds.

Consider blending two identical datasets with extreme weights, one close to 1 and one close to 0, such as 0.999 and 0.001. build_blending_indices gives each dataset a share of the blended samples proportional to its weight, and the blended size is the sum of the dataset lengths, so the dataset with the larger weight is asked for more samples than it contains. Its sample index runs past the end of the dataset, as if it were sampled for more than its available epochs, which causes an IndexError.
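The unbounded growth of current_samples can be reproduced outside Megatron. The following is a minimal Python sketch of the greedy "largest error first" loop as I understand it from helpers.cpp; it is not the actual C++ code. The name build_blending_indices_py is hypothetical, and the lengths 1060 and 106 are the dataset sizes reported in the training output in section 2.

```python
import numpy as np


def build_blending_indices_py(weights, size):
    """Greedy blending sketch: at each step pick the dataset that is
    furthest behind its target share (weight * number of samples so far)."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()  # normalize, as BlendableDataset does
    dataset_index = np.zeros(size, dtype=np.uint8)
    dataset_sample_index = np.zeros(size, dtype=np.int64)
    current_samples = np.zeros(len(weights), dtype=np.int64)

    for i in range(size):
        step = max(float(i), 1.0)
        errors = weights * step - current_samples
        chosen = int(np.argmax(errors))
        dataset_index[i] = chosen
        dataset_sample_index[i] = current_samples[chosen]
        # Nothing here checks current_samples against the dataset length.
        current_samples[chosen] += 1

    return dataset_index, dataset_sample_index


lengths = [1060, 106]  # sizes taken from the output reported in this issue
ds_idx, smp_idx = build_blending_indices_py([0.999, 0.001], sum(lengths))

for d, n in enumerate(lengths):
    largest = int(smp_idx[ds_idx == d].max())
    print(f"dataset {d}: length {n}, largest sample index {largest}")
```

With these inputs the largest sample index requested from dataset 0 is greater than its length, because nearly all of the sum(lengths) blended samples are drawn from the dataset with weight 0.999.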

2. How to trigger this bug?

Assume that we have generated the corresponding binary files: data_text_document.bin and data_text_document.idx. The files do not need to be large; a small dataset is enough.

Use the following code to examine this dataset:

import numpy as np  # needed for np.sum, np.max and np.percentile below

from megatron.data.indexed_dataset import make_dataset

data_prefix = '/path/to/data_text_document'

indexed_dataset = make_dataset(data_prefix, 'mmap', True)

sizes = indexed_dataset.sizes
all_token_counts = np.sum(sizes)
print("document count:", len(sizes) / 1000000, "M")
print("all token count:", all_token_counts / 1000000000, "B")
print("average token count:", all_token_counts / len(sizes))
print("max token count:", np.max(sizes))
print("min token count:", np.min(sizes))
print("tp50 token count:", np.percentile(sizes, 50))
print("tp90 token count:", np.percentile(sizes, 90))
print("tp99 token count:", np.percentile(sizes, 99))
print("tp999 token count:", np.percentile(sizes, 99.9))

Output:

document count: 0.0001 M
all token count: 0.000108625 B
average token count: 1086.25
max token count: 4970
min token count: 146
tp50 token count: 793.5
tp90 token count: 2316.4000000000005
tp99 token count: 4306.7000000000035
tp999 token count: 4903.670000000007

It can be seen that this dataset has a total of 100 documents.

Now blend two copies of this dataset by setting the --train-data-path parameter in the training script as follows:

--train-data-path 0.999 /path/to/data_text_document 0.001 /path/to/data_text_document

Then add the following print statements (marked "print information") to the source code of BlendableDataset:

class BlendableDataset(torch.utils.data.Dataset):
    def __init__(self, datasets, weights):
        self.datasets = datasets
        num_datasets = len(datasets)
        assert num_datasets == len(weights)

        for i, dataset in enumerate(self.datasets):
            print_rank_0(f"dataset {i}: {len(dataset)}")  # print information

        self.size = 0
        for dataset in self.datasets:
            self.size += len(dataset)

        # Normalize weights.
        weights = np.array(weights, dtype=np.float64)
        sum_weights = np.sum(weights)
        assert sum_weights > 0.0
        weights /= sum_weights

        # Build indices.
        start_time = time.time()
        assert num_datasets < 255
        self.dataset_index = np.zeros(self.size, dtype=np.uint8)
        self.dataset_sample_index = np.zeros(self.size, dtype=np.int64)

        from megatron.data import helpers
        helpers.build_blending_indices(self.dataset_index,
                                       self.dataset_sample_index,
                                       weights, num_datasets, self.size,
                                       torch.distributed.get_rank() == 0)
        print_rank_0('> elapsed time for building blendable dataset indices: '
                     '{:.2f} (sec)'.format(time.time() - start_time))

        print_rank_0(f"dataset_index: {self.dataset_index}")  # print information
        print_rank_0(f"dataset_sample_index: {self.dataset_sample_index}")  # print information

Run the training script and observe the output:

dataset 0: 1060
dataset 1: 106
dataset_index: [0 1 0 ... 0 0 0]
dataset_sample_index: [   0    0    1 ... 1161 1162 1163]

The first dataset has only 1060 samples, but sample_idx exceeds this value in several places (1161, 1162, 1163), which inevitably causes the index to go out of bounds.

If we force __getitem__ to always return the last sample of the BlendableDataset, the bug can be triggered reliably:
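Rather than reading the printed arrays by eye, the same check can be automated. This is a small sketch with a hypothetical helper, find_overflow, that reports every dataset whose largest requested sample index is not a valid index. In the real class it would be called with self.dataset_index, self.dataset_sample_index and [len(d) for d in self.datasets]; the tiny arrays below are made up for illustration.

```python
import numpy as np


def find_overflow(dataset_index, dataset_sample_index, lengths):
    """Return {dataset_id: (largest_sample_index, dataset_length)} for every
    dataset that is indexed at or beyond its length."""
    report = {}
    for d, n in enumerate(lengths):
        mask = dataset_index == d
        if mask.any():
            largest = int(dataset_sample_index[mask].max())
            if largest >= n:
                report[d] = (largest, n)
    return report


# Toy input: dataset 0 has 2 samples but index 2 is requested from it.
toy_dataset_index = np.array([0, 1, 0, 0], dtype=np.uint8)
toy_sample_index = np.array([0, 0, 1, 2], dtype=np.int64)
overflow = find_overflow(toy_dataset_index, toy_sample_index, [2, 5])
print(overflow)
```

An empty result means every blended entry points at an existing sample.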

def __getitem__(self, idx):
    idx = self.size - 1  #  Always take the last sample of BlendableDataset
    dataset_idx = self.dataset_index[idx]
    sample_idx = self.dataset_sample_index[idx]
    return {
        "dataset_idx" : dataset_idx,
        **self.datasets[dataset_idx][sample_idx],  # "text": ndarray
    }

3. How to fix?

Fixing this bug is not difficult: take sample_idx modulo the length of the corresponding dataset.
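A minimal sketch of that idea, written as a standalone function so it can be run without Megatron. The name safe_getitem and the toy list datasets are assumptions for illustration; in BlendableDataset the same modulo would be applied to sample_idx inside __getitem__.

```python
import numpy as np


def safe_getitem(datasets, dataset_index, dataset_sample_index, idx):
    """Look up a blended sample, wrapping sample_idx into the valid range."""
    dataset_idx = int(dataset_index[idx])
    # Wrap around instead of indexing past the end of the dataset.
    sample_idx = int(dataset_sample_index[idx]) % len(datasets[dataset_idx])
    return dataset_idx, datasets[dataset_idx][sample_idx]


# Toy stand-ins for the two datasets, with the lengths from the output above.
datasets = [list(range(1060)), list(range(106))]
dataset_index = np.array([0, 1, 0], dtype=np.uint8)
dataset_sample_index = np.array([0, 0, 1163], dtype=np.int64)

result = safe_getitem(datasets, dataset_index, dataset_sample_index, 2)
print(result)
```

Note the trade-off: wrapping avoids the IndexError, but the overflowing entries revisit samples from the start of the dataset, so some samples are seen more often than others.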

github-actions[bot] commented 4 months ago

Marking as stale. No activity in 60 days.