Lightning-AI / litgpt

Pretrain, finetune, deploy 20+ LLMs on your own data. Uses state-of-the-art techniques: flash attention, FSDP, 4-bit, LoRA, and more.
https://lightning.ai
Apache License 2.0

After some iterations of pretraining an LLM, an IndexError related to dataset chunking is raised #1377

Open MusulmonLolayev opened 2 weeks ago

MusulmonLolayev commented 2 weeks ago

After some iterations, the pretraining script suddenly raised an IndexError when resuming pretraining from a checkpoint. Here are some logs:

Epoch 1 | iter 82002 step 41001 | loss train: 1.772, val: n/a | iter time: 243.87 ms (step) remaining time: 395762 days, 18:00:45
Traceback (most recent call last):
  File "/home/user/miniconda3/envs/llama2/bin/litgpt", line 8, in <module>
    sys.exit(main())
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/__main__.py", line 143, in main
    fn(**kwargs)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 121, in setup
    main(
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 213, in main
    fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 265, in fit
    for train_data in train_iterator:
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/utils.py", line 382, in __next__
    return next(self._iterator)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/lightning/fabric/wrappers.py", line 315, in __iter__
    for item in self._dataloader:
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataloader.py", line 598, in __iter__
    for batch in super().__iter__():
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
    return self._process_data(data)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
    data.reraise()
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/_utils.py", line 722, in reraise
    raise exception
IndexError: Caught IndexError in DataLoader worker process 3.
Original Traceback (most recent call last):
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
    fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
    return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
    self.dataset_iter = iter(dataset)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataset.py", line 187, in __iter__
    self._resume(chunks_replica, intervals_replica)
  File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataset.py", line 246, in _resume
    interval = self.worker_intervals[self.chunk_index]
IndexError: list index out of range

Printing the state with print(self.worker_intervals, self.chunk_index) shows [[251476, 269466], [179572, 197595], [35858, 53769], [53769, 71771]] 4. So self.worker_intervals contains only 4 items (valid indices 0 to 3), while self.chunk_index is 4. There appears to be a TODO in the code saying "Implement elastic sampling where the number of workers, ranks can change.", so that unimplemented case seems to be what raises this error.
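The failing lookup can be reproduced outside litdata with the values printed above. This is a hypothetical, simplified model, not litdata's code. It assumes each worker holds one [start, end] interval per assigned chunk and that chunk_index is restored from the checkpoint. The helpers guarded_lookup and round_robin are illustrative names. round_robin also assumes that chunk i goes to worker i % num_workers, which may not match how litdata actually assigns chunks.

```python
from typing import List, Optional

# Values printed in the report for the failing worker (one interval per chunk).
worker_intervals: List[List[int]] = [
    [251476, 269466],
    [179572, 197595],
    [35858, 53769],
    [53769, 71771],
]
chunk_index = 4  # restored from the checkpoint

# Mirrors the failing line: interval = self.worker_intervals[self.chunk_index]
try:
    interval = worker_intervals[chunk_index]
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range


def guarded_lookup(intervals: List[List[int]], index: int) -> Optional[List[int]]:
    """Hypothetical guard: treat an index past the end as 'this worker has
    consumed all of its chunks' instead of raising IndexError."""
    if index >= len(intervals):
        return None
    return intervals[index]


def round_robin(num_chunks: int, num_workers: int, worker_id: int) -> List[int]:
    """Hypothetical assignment: chunk i goes to worker i % num_workers."""
    return [i for i in range(num_chunks) if i % num_workers == worker_id]


# With 16 chunks, worker 0 owns 6 chunks under 3 workers but only 4 under
# 4 workers, so a chunk_index of 4 saved under the first layout is out of
# range under the second.
saved_layout = round_robin(16, num_workers=3, worker_id=0)
resumed_layout = round_robin(16, num_workers=4, worker_id=0)
print(len(saved_layout), len(resumed_layout))  # 6 4
print(guarded_lookup(worker_intervals, chunk_index))  # None
```

Whether a changed layout is what happened in this report depends on whether the number of workers or devices changed between the original run and the resumed one, which the logs do not show. A guard like guarded_lookup only hides the symptom; it would not restore the correct sample position.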