princeton-nlp / LLM-Shearing

[ICLR 2024] Sheared LLaMA: Accelerating Language Model Pre-training via Structured Pruning
https://arxiv.org/abs/2310.06694
MIT License

LanguageCrossEntropy logs nan when running bash pruning.sh #9

Closed Longyichen closed 11 months ago

Longyichen commented 11 months ago

When I ran the pruning experiment, I only configured the dataset and made no other changes. The metrics do not seem to be updated, and the log repeatedly prints the loss as nan, as follows:

[metric][batch=0]: time/epoch: 2365 
[metric][batch=0]: metrics/train/LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Perplexity: nan 
[metric][batch=0]: metrics/train/ArXiv_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/ArXiv_count: 0 
[metric][batch=0]: metrics/train/Books_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Books_count: 0 
[metric][batch=0]: metrics/train/Wikipedia_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Wikipedia_count: 0 
[metric][batch=0]: time/epoch: 2366 
[metric][batch=0]: metrics/train/LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Perplexity: nan 
[metric][batch=0]: metrics/train/ArXiv_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/ArXiv_count: 0 
[metric][batch=0]: metrics/train/Books_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Books_count: 0 
[metric][batch=0]: metrics/train/Wikipedia_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Wikipedia_count: 0 
.... repeat

I set pdb breakpoints in the metric's update function and in ComposerLlama's update_metric, but these breakpoints were never hit. The input data seems intact: I tested the train loader and trainer.eval and everything is normal. However, the problem reliably occurs in trainer.fit.

The pruning settings:

# learning setup
lr=1e-4 # learning rate for the main parameters
max_duration=3200ba # 0.42B tokens
save_interval=3200ba # save in the end
t_warmup=320ba # 10% learning rate warmup 

# dynamic loading setup
dynamic=True
# set_names=[cc,github,book,stackexchange,wiki,arxiv,c4-rp] # domain names
set_names=[ArXiv,Books,Wikipedia] # domain names
# proportion=[0.67,0.045,0.045,0.02,0.045,0.025,0.15] # initial proportion of RP, make sure that the sum(proportion) = 1
proportion=[0.4,0.3,0.3] 
# doremi: update weights with exponential descent
# constant: keep the weights constant
update_type=doremi 
if [[ $to_model == 1.3b ]]; then
    # target_loss=[1.9643,0.7459,2.1393,1.6117,1.7590,1.4449,2.1251] # 1.3b predicted loss from scaling law
    target_loss=[1.4449,2.1393,1.7590]
else
    # target_loss=[1.8712,0.6883,2.0325,1.5353,1.6297,1.3560,2.0328] # 2.7b predicted loss from scaling law
    target_loss=[1.3560,2.0325,1.6297]
fi
eval_split_name=eval_merge # eval on all domains
eval_target_model=false # evaluate on the current model, not the target model, otherwise the loss will be inaccurate
eval_interval=50ba # eval every 50 batches and update the loading proportion

# pruning setup
lag_lr=1.0 # learning rate for the l0_module
lagr_warmup=640ba # 20% sparsity warmup
if [[ $to_model == 1.3b ]]; then
    target_d_model=2048; target_n_heads=16; target_n_layers=24; target_intermediate_size=5504
elif [[ $to_model == 3b ]]; then
    target_d_model=2560; target_n_heads=20; target_n_layers=32; target_intermediate_size=6912
fi

composer $TRAIN_SCRIPT \
    $config_file \
    run_name=${run_name} \
    data_local=${data_local} \
    eval_loader.dataset.split=${eval_split_name} \
    global_train_batch_size=${global_train_batch_size} \
    device_train_microbatch_size=${device_train_microbatch_size} \
    device_eval_batch_size=${device_eval_batch_size} \
    max_seq_len=${max_seq_len} \
    max_duration=${max_duration} \
    eval_first=false \
    scheduler.t_warmup=${t_warmup} \
    save_folder=${save_dir} \
    loggers.wandb.init_kwargs.dir=${wandb_dir} \
    eval_interval=${eval_interval} \
    save_interval=${save_interval} \
    optimizer.lr=${lr} \
    optimizer.lag_lr=${lag_lr} \
    model.l0_module.lagrangian_warmup_steps=${lagr_warmup} \
    model.l0_module.pruning_modules='[head,intermediate,layer,hidden]' \
    model.l0_module.eval_target_model=${eval_target_model} \
    model.l0_module.target_model.d_model=${target_d_model} \
    model.l0_module.target_model.n_heads=${target_n_heads} \
    model.l0_module.target_model.n_layers=${target_n_layers} \
    model.l0_module.target_model.intermediate_size=${target_intermediate_size} \
    callbacks.data_loading.dynamic=${dynamic} \
    callbacks.data_loading.set_names=${set_names} \
    callbacks.data_loading.proportion=${proportion} \
    callbacks.data_loading.update_type=${update_type} \
    callbacks.data_loading.target_loss=${target_loss} \
    train_loader.num_workers=0 \
    train_loader.prefetch_factor=null \
    train_loader.persistent_workers=false \
    autoresume=false
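
For context on update_type=doremi above: at every eval_interval the data-loading callback re-weights the loading proportion toward domains whose loss is still above the target. Below is a minimal sketch of such an exponential update; the function name, eta, and the exact clipping are assumptions for illustration, not the repo's code.

import numpy as np

def doremi_update(proportion, domain_losses, target_losses, eta=1.0):
    # Upweight domains whose current eval loss is still above the
    # reference (target) loss, then renormalize so the weights sum to 1.
    excess = np.maximum(np.array(domain_losses) - np.array(target_losses), 0.0)
    new_w = np.array(proportion) * np.exp(eta * excess)
    return (new_w / new_w.sum()).tolist()

# With the three domains configured above:
print(doremi_update([0.4, 0.3, 0.3],
                    domain_losses=[1.60, 2.30, 1.90],
                    target_losses=[1.4449, 2.1393, 1.7590]))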
Longyichen commented 11 months ago

As the warning messages say, it may be caused by compute being called before update ever runs.

[stderr]: ******************************
[metric][batch=0]: time/epoch: 0 
[stderr]: /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:42: UserWarning: The ``compute`` method of metric LanguageCrossEntropy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
[stderr]:   warnings.warn(*args, **kwargs)  # noqa: B028
[stderr]: /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:42: UserWarning: The ``compute`` method of metric LanguagePerplexity was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
[stderr]:   warnings.warn(*args, **kwargs)  # noqa: B028
[stderr]: /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:42: UserWarning: The ``compute`` method of metric DomainLanguageCrossEntropy was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
[stderr]:   warnings.warn(*args, **kwargs)  # noqa: B028
[stderr]: /root/miniconda3/envs/shearing/lib/python3.10/site-packages/torchmetrics/utilities/prints.py:42: UserWarning: The ``compute`` method of metric DomainCount was called before the ``update`` method which may lead to errors, as metric states have not yet been updated.
[stderr]:   warnings.warn(*args, **kwargs)  # noqa: B028
[metric][batch=0]: metrics/train/LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Perplexity: nan 
[metric][batch=0]: metrics/train/ArXiv_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/ArXiv_count: 0 
[metric][batch=0]: metrics/train/Books_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Books_count: 0 
[metric][batch=0]: metrics/train/Wikipedia_LanguageCrossEntropy: nan 
[metric][batch=0]: metrics/train/Wikipedia_count: 0 
xiamengzhou commented 11 months ago

It could be that the dataset you've created does not include the set entry. The count for each domain and the corresponding cross-entropy (CE) values are computed from each data point's set name. Could you check your mds files and see if that is the case?
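
One quick way to check this is to inspect each domain's index.json. This is a sketch assuming the MosaicML streaming MDS layout, where every domain directory contains an index.json listing each shard's columns; the root path here is illustrative.

import json
import os

root = "mds_redpajama/for_prune"  # hypothetical root; use your data dir
for domain in ["ArXiv", "Books", "Wikipedia"]:
    with open(os.path.join(root, domain, "index.json")) as f:
        index = json.load(f)
    # Every MDS shard records its column names; the collator needs "set".
    for shard in index["shards"]:
        assert "set" in shard["column_names"], (
            f"{domain}: shard missing 'set' column: {shard['column_names']}")
    print(domain, "OK:", index["shards"][0]["column_names"])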

Longyichen commented 11 months ago

> It could be that the dataset you've created does not include the set entry. The count for each domain and the corresponding cross-entropy (CE) values are computed from each data point's set name. Could you check your mds files and see if that is the case?

Hi, I listed the names in the mds directory. I actually changed the set names to the file names in the dataset:

ls mds_redpajama/for_prune/
ArXiv/  Books/  eval_merge/  train_small/  Wikipedia/
xiamengzhou commented 11 months ago

Do you have a set entry for each data point in your mds files for each domain? And does it correspond to the set_names you pass to the script?

This line here collects each data point's set entry and uses it as part of the input to the model.
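
To make that concrete, here is a minimal sketch of the mapping involved; the variable names mirror the debugging snippet below, not the repo's exact code.

import torch

set_names = ["ArXiv", "Books", "Wikipedia"]
set_name_to_id = {name: i for i, name in enumerate(set_names)}

examples = [{"set": "Books"}, {"set": "Wikipedia"}]
batch_set = torch.tensor([set_name_to_id[ex["set"]] for ex in examples])
print(batch_set)  # tensor([1, 2])

# If the stored 'set' values don't match set_names exactly (e.g. 'books'
# vs 'Books'), this lookup fails or the domain is never counted,
# depending on how mismatches are handled -- consistent with the
# per-domain counts of 0 and nan metrics above.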

Longyichen commented 11 months ago

> Do you have a set entry for each data point in your mds files for each domain? And does it correspond to the set_names you pass to the script?
>
> This line here collects each data point's set entry and uses it as part of the input to the model.

I looked at the index file of each domain, and it does include set, for example:

[screenshot: index.json entries showing the set field]

I injected print statements directly at the line of code you pointed to, and the output is as follows:

print('batch: ', batch)
print('examples: ', [example["set"] for example in examples])
batch["set"] = torch.tensor(
    [self.set_name_to_id[example["set"]] for example in examples])
print('batch["set"]', batch["set"])
---
batch:  batch:  batch: {'input_ids': tensor([[  372, 29889,  1334,  ...,   540,  1497, 29889],
        [  372, 29889,  1334,  ...,   540,  1497, 29889],
        [  372, 29889,  1334,  ...,   540,  1497, 29889],
        ...,
        [ 3431, 29889, 29871,  ..., 29953, 29900, 29892],
        [ 3431, 29889, 29871,  ..., 29953, 29900, 29892],
        [ 3431, 29889, 29871,  ..., 29953, 29900, 29892]]), 'labels': tensor([[  372, 29889,  1334,  ...,   540,  1497, 29889],
        [  372, 29889,  1334,  ...,   540,  1497, 29889],
        [  372, 29889,  1334,  ...,   540,  1497, 29889],
        ...,
        [ 3431, 29889, 29871,  ..., 29953, 29900, 29892],
        [ 3431, 29889, 29871,  ..., 29953, 29900, 29892],
        [ 3431, 29889, 29871,  ..., 29953, 29900, 29892]])} 
examples:  ['Books', 'Books', 'Books', 'Wikipedia', 'Wikipedia', 'Wikipedia', 'Wikipedia', 'Wikipedia']
batch["set"] tensor([1, 1, 1, 2, 2, 2, 2, 2])

This output shows that each example is mapped to the domain its data belongs to. Is this result normal?

Longyichen commented 11 months ago

I found that aligning the names of all my folders with the author's folder names solved the problem. Issue closed.
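
In other words, the set values stored in the mds data, the domain folder names, and the set_names passed to the script must all match exactly, including case. A quick consistency check, assuming the mosaicml-streaming package and the paths from the ls output above:

from streaming import StreamingDataset  # mosaicml-streaming

for name in ["ArXiv", "Books", "Wikipedia"]:
    # Read one sample per domain and compare its stored set name
    # against the folder name / set_names entry.
    ds = StreamingDataset(local=f"mds_redpajama/for_prune/{name}", shuffle=False)
    stored = ds[0]["set"]
    assert stored == name, f"folder {name!r} stores set={stored!r}"
    print(name, "OK")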

YanxiZSQ commented 11 months ago

> I found that aligning the names of all my folders with the author's folder names solved the problem. Issue closed.

I have the same issue. How do I fix it?