After some iterations, the pretraining script suddenly raised an IndexError when resuming pretraining from a checkpoint. Here are some logs:
Epoch 1 | iter 82002 step 41001 | loss train: 1.772, val: n/a | iter time: 243.87 ms (step) remaining time: 395762 days, 18:00:45
Traceback (most recent call last):
File "/home/user/miniconda3/envs/llama2/bin/litgpt", line 8, in <module>
sys.exit(main())
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/__main__.py", line 143, in main
fn(**kwargs)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 121, in setup
main(
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 213, in main
fit(fabric, devices, state, train_dataloader, val_dataloader, out_dir, tokenizer_dir, train, eval)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/pretrain.py", line 265, in fit
for train_data in train_iterator:
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litgpt/utils.py", line 382, in __next__
return next(self._iterator)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/lightning/fabric/wrappers.py", line 315, in __iter__
for item in self._dataloader:
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataloader.py", line 598, in __iter__
for batch in super().__iter__():
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
data = self._next_data()
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1346, in _next_data
return self._process_data(data)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1372, in _process_data
data.reraise()
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/_utils.py", line 722, in reraise
raise exception
IndexError: Caught IndexError in DataLoader worker process 3.
Original Traceback (most recent call last):
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 252, in _worker_loop
fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
self.dataset_iter = iter(dataset)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataset.py", line 187, in __iter__
self._resume(chunks_replica, intervals_replica)
File "/home/user/miniconda3/envs/llama2/lib/python3.9/site-packages/litdata/streaming/dataset.py", line 246, in _resume
interval = self.worker_intervals[self.chunk_index]
IndexError: list index out of range
Adding print(self.worker_intervals, self.chunk_index) before the failing line prints [[251476, 269466], [179572, 197595], [35858, 53769], [53769, 71771]] 4, i.e. self.worker_intervals contains only 4 items while self.chunk_index is 4, so the lookup is out of range. There is a TODO in litdata that says "Implement elastic sampling where the number of workers, ranks can change.", which suggests the resume path does not yet handle a worker/rank configuration that differs from the one recorded in the checkpoint, and that is why it raises this error.
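For illustration, here is a minimal sketch (not litdata's actual code) of how a restored chunk_index can point past the end of the per-worker interval list on resume; the values mirror the ones printed above, and the function name resume_worker is hypothetical.

def resume_worker(worker_intervals, chunk_index):
    # Mirrors `interval = self.worker_intervals[self.chunk_index]` in
    # litdata/streaming/dataset.py::_resume. If the checkpoint's chunk_index
    # refers to a chunk that is not in this worker's interval list (e.g. the
    # chunks were re-split across a different number of workers/ranks),
    # indexing raises IndexError.
    return worker_intervals[chunk_index]

# Values observed in the failing worker: 4 intervals, but chunk_index == 4.
worker_intervals = [[251476, 269466], [179572, 197595], [35858, 53769], [53769, 71771]]
chunk_index = 4

try:
    resume_worker(worker_intervals, chunk_index)
except IndexError:
    print("IndexError: list index out of range (chunk_index >= len(worker_intervals))")

Under that assumption, the saved dataset state only lines up with the same worker/rank layout that produced it, which is consistent with the TODO about elastic sampling quoted above.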