princeton-nlp / LLM-Shearing

[ICLR 2024] Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning
https://arxiv.org/abs/2310.06694
MIT License

Pruning crash at iteration 592. #32

lippman1125 opened this issue 12 months ago (status: Open)

lippman1125 commented 12 months ago

@xiamengzhou [batch=592/3200]
Train time/batch: 591
Train time/sample: 18912
Train time/batch_in_epoch: 591
Train time/sample_in_epoch: 18912
Train time/token: 77463552
Train time/token_in_epoch: 77463552
Train metrics/train/cc_weight: 0.2292 Train metrics/train/github_weight: 0.0121
Train metrics/train/book_weight: 0.0220 Train metrics/train/stackexchange_weight: 0.0059 Train metrics/train/wiki_weight: 0.5933 Train metrics/train/arxiv_weight: 0.0038
Train metrics/train/c4-rp_weight: 0.1336
Train memory/current_allocated_mem: 14.6140 Train memory/current_active_mem: 14.6140
Train memory/current_inactive_mem: 1.9265
Train memory/current_reserved_mem: 43.4220
Train memory/peak_allocated_mem: 28.0710 Train memory/peak_active_mem: 28.0710 Train memory/peak_inactive_mem: 11.7290 Train memory/peak_reserved_mem: 43.4220 Train memory/alloc_retries: 0
Train metrics/train/expected_head_sparsity: 0.3583 Train metrics/train/target_head_sparsity: 0.3463
Train metrics/train/expected_intermediate_sparsity: 0.3196 Train metrics/train/target_intermediate_sparsity: 0.3436
Train metrics/train/expected_layer_sparsity: 0.0039 Train metrics/train/target_layer_sparsity: 0.0000 Train metrics/train/expected_hidden_sparsity: 0.4266
Train metrics/train/target_hidden_sparsity: 0.3463 Train metrics/train/expected_sparsity: 0.6188
Train metrics/train/target_sparsity: 0.5616
Train trainer/device_train_microbatch_size: 4 Train loss/train/total: 3.5578 Train loss/train/ce_loss: 2.8953
Train loss/train/lag_loss: 0.6625
Train metrics/train/LanguageCrossEntropy: 2.8953 Train metrics/train/Perplexity: 18.0886
Train metrics/train/cc_LanguageCrossEntropy: 3.0387 Train metrics/train/cc_count: 9884
Train metrics/train/github_LanguageCrossEntropy: nan Train metrics/train/github_count: 652
Train metrics/train/book_LanguageCrossEntropy: nan Train metrics/train/book_count: 712
Train metrics/train/stackexchange_LanguageCrossEntropy: nan Train metrics/train/stackexchange_count: 236
Train metrics/train/wiki_LanguageCrossEntropy: 2.7964 Train metrics/train/wiki_count: 4011
Train metrics/train/arxiv_LanguageCrossEntropy: nan Train metrics/train/arxiv_count: 267
Train metrics/train/c4-rp_LanguageCrossEntropy: 3.1243 Train metrics/train/c4-rp_count: 3182
Train throughput/batches_per_sec: 0.1329 Train throughput/samples_per_sec: 4.2523
Train throughput/device/batches_per_sec: 0.0166 Train throughput/device/samples_per_sec: 0.5315
Train throughput/tokens_per_sec: 17417.3748 Train throughput/device/tokens_per_sec: 2177.1719
Train throughput/flops_per_sec: 816440956730026.0000 Train throughput/device/flops_per_sec: 102055119591253.2500
Train time/train: 1.2715 Train time/val: 0.6538 Train time/total: 1.9253

Traceback (most recent call last):
  File "/llm-shearing//llmshearing/train.py", line 317, in <module>
    main(cfg)
  File "/llm-shearing//llmshearing/train.py", line 301, in main
    trainer.fit()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 1876, in fit
    self._train_loop()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 2018, in _train_loop
    for batch_idx, self.state.batch in enumerate(self._iter_dataloader(TrainerMode.TRAIN)):
  File "/pyenv/py310-shear/lib/python3.10/site-packages/composer/trainer/trainer.py", line 3024, in _iter_dataloader
    batch = next(dataloader_iter)
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
  File "/pyenv/py310-shear/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/llm-shearing/llmshearing/datasets/streaming_dataset.py", line 392, in __iter__
    domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
IndexError: index 552 is out of bounds for axis 0 with size 552

To reproduce: prepare the data for pruning as described in the paper, then execute the following command: /bin/bash llmshearing/scripts/prunning.sh

lippman1125 commented 12 months ago

Line 392 of llmshearing/datasets/streaming_dataset.py:

        sample_ids_per_stream = self._get_work(world, epoch, used_sample_ids)
        # Currently only supports dynamically loading data from each domain for once.
        # Issues could occur if one domain of data is used up.
        while True:
            proportion = self.proportion
            stream_id = np.random.choice(range(self.num_streams), 1, p=proportion)[0].item()
            domain_sample_id = sample_ids_per_stream[stream_id]
            domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]  # line 392: the IndexError is raised here
            self.used_num_samples_per_stream[stream_id] += 1
            yield self[domain_sample_id]

I think

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]

should be changed to

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % len(sample_ids_per_stream[stream_id])]

because the per-rank array sample_ids_per_stream[stream_id] is shorter than self.samples_per_stream[stream_id], so the modulo wraps at the wrong size; the sketch below reproduces the failure.
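A minimal sketch of the failure mode in plain numpy (the 552 comes from the traceback and the 65431 from the debug logs further down; the variable names are stand-ins for the dataset attributes):

import numpy as np

per_rank_ids = np.arange(552)   # stands in for sample_ids_per_stream[stream_id] on this rank
samples_per_stream = 65431      # stands in for self.samples_per_stream[stream_id]
used = 552                      # stands in for self.used_num_samples_per_stream[stream_id]

# Buggy indexing: the modulo wraps at the full stream size, so the index
# can run past the end of the shorter per-rank array.
try:
    _ = per_rank_ids[used % samples_per_stream]
except IndexError as err:
    print(err)  # index 552 is out of bounds for axis 0 with size 552

# Indexing that wraps at the length of the array actually being indexed:
print(per_rank_ids[used % len(per_rank_ids)])  # 0 -- wraps back to the start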

lippman1125 commented 12 months ago

This code, domain_sample_id = sample_ids_per_stream[stream_id], always returns the same sample id list.

Is len(sample_ids_per_stream[stream_id]) equal to self.samples_per_stream[stream_id] / gpu_num / worker_num?
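For what it's worth, the slice sizes in the debug output below look like an even split of each stream across the 4 ranks, padded with -1 up to a multiple of 8; a quick check against the two logged streams (the padding granularity is my inference, not confirmed against _get_work):

import math

num_ranks = 4                                          # num_ranks=4 in the logs
for total, logged in [(65431, 16360), (4394, 1104)]:   # streams 0 and 4 below
    per_rank = math.ceil(total / num_ranks)            # even split across ranks
    padded = math.ceil(per_rank / 8) * 8               # pad up to a multiple of 8
    print(per_rank, padded, logged)                    # padded matches the logged size

The trailing -1 entries visible at the end of each printed id array would be that padding.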

I also found that it is always rank 0 that takes data. Debug info as follows:

rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=200
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=642
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=201
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=1012
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=4
self.used_num_samples_per_stream[4]=16
self.samples_per_stream[4]=4394
1 domain_sample_id size=1104
@ domain_sample_id =[76176 76205 76178 ...    -1    -1    -1]
2 domain_sample_id=76208
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=202
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=717
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=203
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=987
rank=0, num_ranks=4, ranks_per_node=4
proportion=[0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15], stream_id=0
self.used_num_samples_per_stream[0]=204
self.samples_per_stream[0]=65431
1 domain_sample_id size=16360
@ domain_sample_id =[  257    95   275 ... 32631    -1    -1]
2 domain_sample_id=729
lippman1125 commented 11 months ago

@xiamengzhou Can you help me look at this issue?

xiamengzhou commented 11 months ago

Hi -- thanks for bringing this to our attention. I think you are correct! However,

Let me know if it helps!

PengWenChen commented 11 months ago

Hi @xiamengzhou, I also encountered the same issue at the same index: IndexError: index 552 is out of bounds for axis 0 with size 552.

As the comment in the code says, "Issues could occur if one domain of data is used up." However, I use the same amount of data (0.4B) and the same initial proportion as the paper. Why does data exhaustion happen, and how can I avoid exhausting the data of any stream? Thank you.
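One plausible explanation (a back-of-the-envelope estimate, not taken from the repo): the dynamic batch loading pushes some domain weights far above their initial proportions, so a small stream's per-rank slice can be drained long before the planned 3200 batches. Assuming each iterator step draws one sample id from stream i with probability proportion[i], and that the proportion list follows the same domain order as the logged weights:

import numpy as np

# Initial mix vs. the learned domain weights logged at batch 592.
initial = np.array([0.67, 0.045, 0.045, 0.02, 0.045, 0.025, 0.15])
at_592  = np.array([0.2292, 0.0121, 0.0220, 0.0059, 0.5933, 0.0038, 0.1336])

slice_len = 552   # hypothetical per-rank slice of a small stream (as in the crash)
wiki = 4          # index of the wiki stream in the weight printout

# Expected number of draws before a slice of that size is exhausted: slice_len / weight.
print(f"at the initial mix:   ~{slice_len / initial[wiki]:.0f} draws")   # ~12267
print(f"at the batch-592 mix: ~{slice_len / at_592[wiki]:.0f} draws")    # ~930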

PengWenChen commented 10 months ago

Thanks to @lippman1125's advice. I agree with the modification of domain_sample_id. I changed

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % self.samples_per_stream[stream_id]]

to

domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % len(sample_ids_per_stream[stream_id])]

And it works now.
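For completeness, a sketch of the quoted loop with that one-line change applied (an excerpt of __iter__ as quoted above, not a standalone script):

        while True:
            proportion = self.proportion
            # Draw a stream according to the current (dynamic) domain weights.
            stream_id = np.random.choice(range(self.num_streams), 1, p=proportion)[0].item()
            domain_sample_id = sample_ids_per_stream[stream_id]
            # Wrap at the length of the per-rank id slice actually being indexed,
            # not at the global stream size.
            domain_sample_id = domain_sample_id[self.used_num_samples_per_stream[stream_id] \
                                % len(sample_ids_per_stream[stream_id])]
            self.used_num_samples_per_stream[stream_id] += 1
            yield self[domain_sample_id]

Note that the modulo only wraps around and re-serves ids once a slice is used up, so an exhausted domain is recycled rather than refreshed; that matches the in-code caveat about each domain being dynamically loaded only once.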