princeton-nlp / LLM-Shearing

[ICLR 2024] Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning
https://arxiv.org/abs/2310.06694
MIT License

Pruning crash at iteration 592. #32

Open lippman1125 opened 7 months ago

lippman1125 commented 7 months ago

@xiamengzhou [batch=592/3200]
Train time/batch: 591
Train time/sample: 18912
Train time/batch_in_epoch: 591
Train time/sample_in_epoch: 18912
Train time/token: 77463552
Train time/token_in_epoch: 77463552
Train metrics/train/cc_weight: 0.2292 Train metrics/train/github_weight: 0.0121
Train metrics/train/book_weight: 0.0220 Train metrics/train/stackexchange_weight: 0.0059 Train metrics/train/wiki_weight: 0.5933 Train metrics/train/arxiv_weight: 0.0038
Train metrics/train/c4-rp_weight: 0.1336
Train memory/current_allocated_mem: 14.6140 Train memory/current_active_mem: 14.6140
Train memory/current_inactive_mem: 1.9265
Train memory/current_reserved_mem: 43.4220
Train memory/peak_allocated_mem: 28.0710 Train memory/peak_active_mem: 28.0710 Train memory/peak_inactive_mem: 11.7290 Train memory/peak_reserved_mem: 43.4220 Train memory/alloc_retries: 0
Train metrics/train/expected_head_sparsity: 0.3583 Train metrics/train/target_head_sparsity: 0.3463
Train metrics/train/expected_intermediate_sparsity: 0.3196 Train metrics/train/target_intermediate_sparsity: 0.3436
Train metrics/train/expected_layer_sparsity: 0.0039 Train metrics/train/target_layer_sparsity: 0.0000 Train metrics/train/expected_hidden_sparsity: 0.4266
Train metrics/train/target_hidden_sparsity: 0.3463 Train metrics/train/expected_sparsity: 0.6188
Train metrics/train/target_sparsity: 0.5616
Train trainer/device_train_microbatch_size: 4 Train loss/train/total: 3.5578 Train loss/train/ce_loss: 2.8953
Train loss/train/lag_loss: 0.6625 Train metrics/train/LanguageCrossEntropy: 2.8953 Train metrics/train/Perplexity: 18.0886
Train metrics/train/cc_LanguageCrossEntropy: 3.0387 Train metrics/train/cc_count: 9884
Train metrics/train/github_LanguageCrossEntropy: nan Train metrics/train/github_count: 652
Train metrics/train/book_LanguageCrossEntropy: nan Train metrics/train/book_count: 712
Train metrics/train/stackexchange_LanguageCrossEntropy: nan Train metrics/train/stackexchange_count: 236
Train metrics/train/wiki_LanguageCrossEntropy: 2.7964 Train metrics/train/wiki_count: 4011
Train metrics/train/arxiv_LanguageCrossEntropy: nan Train metrics/train/arxiv_count: 267
Train metrics/train/c4-rp_LanguageCrossEntropy: 3.1243 Train metrics/train/c4-rp_count: 3182
Train throughput/batches_per_sec: 0.1329 Train throughput/samples_per_sec: 4.2523
Train throughput/device/batches_per_sec: 0.0166 Train throughput/device/samples_per_sec: 0.5315
Train throughput/tokens_per_sec: 17417.3748 Train throughput/device/tokens_per_sec: 2177.1719
Train throughput/flops_per_sec: 816440956730026.0000 Train throughput/device/flops_per_sec: 102055119591253.2500
Train time/train: 1.2715 Train time/val: 0.6538 Train time/total: 1.9253

Traceback (most recent call last):
  File "/llm-shearing//llmshearing/train.py", line 317, in <module>
    main(cfg)
  File "/llm-shearing//llmshearing/train.py", line 301, in main
    trainer.fit()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1876, in fit
    self._train_loop()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2018, in _train_loop
    for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 3024, in _iter_dataloader
    batch = next(dataloader_iter)
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/llm-shearing/llmshearing/datasets/streaming_dataset.py", line 392, in __iter__
    domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
IndexError: index 552 is out of bounds for axis 0 with size 552

To reproduce: prepare the data for pruning as described in the paper, then execute the following command: /bin/bash llmshearing/scripts/prunning.sh

lippman1125 commented 7 months ago

Line 392 of llmshearing/datasets/streaming_dataset.py

        sample_ids_per_stream = self._get_work(world, epoch, used_sample_ids)
        # Currently only supports dynamically loading data from each domain for once. 
        # Issues could occur if one domain of data is used up. 
        while True:
            proportion = self.proportion
            stream_id = np.random.choice(range(self.num_streams), 1, p=proportion)[0].item()
            domain_sample_id = sample_ids_per_stream[stream_id]
            domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]
            self.used_num_samples_per_stream[stream_id] += 1
            yield self[domain_sample_id]

I think

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]

should be changed to

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % len(sample_ids_per_stream[stream_id])]

because sample_ids_per_stream[stream_id] holds fewer entries than self.samples_per_stream[stream_id], so the modulo wraps too late and the index eventually goes out of bounds.
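
To see the failure mode concretely, here is a minimal sketch (made-up sizes; partition stands in for sample_ids_per_stream[stream_id]): once the draw counter reaches the length of the partition, a modulo taken with the larger total stream size no longer wraps the index, and numpy raises exactly the error seen above.

    import numpy as np

    samples_per_stream = 2208      # hypothetical total number of samples in one domain
    partition = np.arange(552)     # sample ids this rank/worker actually holds

    used = 0
    try:
        while True:
            # buggy wrap: modulo uses the total stream size, not the partition size
            sample_id = partition[used % samples_per_stream]
            used += 1
    except IndexError as e:
        print(e)  # index 552 is out of bounds for axis 0 with size 552

    # correct wrap: modulo uses the partition length, so the index always stays in range
    sample_id = partition[used % len(partition)]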

lippman1125 commented 7 months ago

This line, domain_sample_id = sample_ids_per_stream[stream_id], always returns the same sample id list.

Is the size of sample_ids_per_stream[stream_id] equal to self.samples_per_stream[stream_id] / gpu_num / worker_num?

I also found that it is always rank 0 that takes the data. Debug info as follows:

rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=200
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=642
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=201
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=1012
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=4
self.used_num_samples_per_stream[4]=16
self.samples_per_stream[4]=4394
1 domain_sample_id size=1104
@ domain_sample_id =[76176 76205 76178 ...    -1    -1    -1]
2 domain_sample_id=76208
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=202
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=717
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=203
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=987
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=204
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=729
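
For what it's worth, the partition sizes in the debug output above look roughly like samples_per_stream split across the 4 ranks plus some trailing -1 padding. A quick sanity check (a sketch; the single-worker assumption and the exact padding rule are guesses):

    import math

    num_ranks = 4  # from the debug output above
    # stream id -> (self.samples_per_stream, observed partition size)
    streams = {0: (65431, 16360), 4: (4394, 1104)}

    for sid, (total, observed) in streams.items():
        per_rank = math.ceil(total / num_ranks)
        print(f"stream {sid}: total={total}, ceil(total/num_ranks)={per_rank}, observed={observed}")
        # stream 0: 16358 vs 16360; stream 4: 1099 vs 1104
        # the small gap is presumably the trailing -1 entries visible in domain_sample_id

Either way, each partition is much smaller than the corresponding samples_per_stream, which is why the modulo in the original code never wraps in time.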
lippman1125 commented 7 months ago

@xiamengzhou Can you help me look at this issue?

xiamengzhou commented 7 months ago

Hi -- thanks for bringing this to our attention. I think you are correct! However,

Let me know if it helps!

PengWenChen commented 6 months ago

Hi @xiamengzhou, I also encounter the same issue at the same index: IndexError: index 552 is out of bounds for axis 0 with size 552.

As the comment says, "Issues could occur if one domain of data is used up." However, I use the same amount of data (0.4B) and the same initial proportion as the paper. Why does the data get exhausted? How can I avoid exhausting the data of any stream? Thank you.

PengWenChen commented 6 months ago

Thanks to @lippman1125's advice. I agree with the modification of domain_sample_id. I changed it from

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] % self.samples_per_stream[stream_id]]

to

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] % len(sample_ids_per_stream[stream_id])]

And it works now.
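
For reference, applying that change to the snippet quoted earlier gives a loop like the sketch below (not the exact upstream code; the intermediate variable is renamed to partition for readability):

        sample_ids_per_stream = self._get_work(world, epoch, used_sample_ids)
        while True:
            proportion = self.proportion
            stream_id = np.random.choice(range(self.num_streams), 1, p=proportion)[0].item()
            partition = sample_ids_per_stream[stream_id]
            # wrap around this rank/worker's partition length, not the full stream size
            domain_sample_id = partition[self.used_num_samples_per_stream[stream_id]
                                         % len(partition)]
            self.used_num_samples_per_stream[stream_id] += 1
            yield self[domain_sample_id]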