alibaba / Pai-Megatron-Patch

The official repo of Pai-Megatron-Patch for LLM & VLM large scale training developed by Alibaba Cloud.
Apache License 2.0

finetune qwen1.5-4B with tp=2: loading the model fails with an embedding shape mismatch. #184

Closed hudengjunai closed 4 months ago

hudengjunai commented 4 months ago

Training qwen1.5-4B with tp=2 fails with an embedding-table load error.

I started the training job as follows.

Convert the model

I converted the qwen1.5-4B model to tp=2 and pp=1 with the command:

/workspace/llm_train/Pai-Megatron-Patch/toolkits/model_checkpoints_convertor/qwen# bash hf2megatron_convertor.sh  ../../../Megatron-LM-231007/ /hf_cache/hub/Qwen1.5-4B   /hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21/ 2 1 qwen1.5 0 false
Zarr-based strategies will not be registered because of missing packages
Converting
converting embedding layer
converting transformer layers
0 min 20 sec
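Before launching training, the per-rank embedding shape in the converted checkpoint can be checked directly. A hedged sketch: the `release/mp_rank_00/model_optim_rng.pt` layout is the usual Megatron convention, and the nested key path mirrors the module nesting in the traceback further down this thread; exact keys can vary across Megatron versions.

```python
def embedding_shape(ckpt: dict) -> tuple:
    """Return the word-embedding weight shape from a loaded Megatron state dict.

    The key path (language_model -> embedding -> word_embeddings -> weight)
    mirrors the module nesting shown in the traceback in this thread; adjust
    it if your Megatron version lays the state dict out differently.
    """
    w = ckpt["model"]["language_model"]["embedding"]["word_embeddings"]["weight"]
    return tuple(w.shape)

# Usage against the converted model from this thread (path is illustrative):
#   import torch
#   ckpt = torch.load(
#       "/hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21/release/mp_rank_00/model_optim_rng.pt",
#       map_location="cpu",
#   )
#   print(embedding_shape(ckpt))  # per-rank rows should equal padded vocab / tp
```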

Start the training job

Then I started the training job with:

set -ex
export WORK_DIR=/workspace/llm_train
cd ${WORK_DIR}/Pai-Megatron-Patch/examples/qwen1_5
sh run_finetune_megatron_qwen.sh \
        qs \
        ${WORK_DIR}/Pai-Megatron-Patch \
        4B 1 1e-5 1e-6 80 81 \
        1 \
        bf16 \
        2 1 sel true true true false  \
        /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-train.json \
        /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-valid.json \
        /hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21  \
        2 \
        ${WORK_DIR}/output_patch_test 

The job startup log shows:

+ cd /workspace/llm_train/Pai-Megatron-Patch/examples/qwen1_5
+ sh run_finetune_megatron_qwen.sh qs /workspace/llm_train/Pai-Megatron-Patch 4B 1 1e-5 1e-6 80 81 1 bf16 2 1 sel true true true false /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-train.json /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-valid.json /hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21 2 /workspace/llm_train/output_patch_test /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-train.json /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-valid.json
torchrun --nproc_per_node 4 --nnodes 1 --node_rank 0 --master_addr localhost --master_port 63954 ../llama2/finetune_megatron_llama.py --load /hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21 --save /workspace/llm_train/output_patch_test/checkpoint/qs-finetune-megatron-llama-4B-lr-1e-5-ep-2-bs-1-seqlen-80-pr-bf16--do-true-tp-2-ac-sel-sp-true --train-data /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-train.json --valid-data /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-valid.json --train-data-path /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-train.json --valid-data-path /workspace/llm_train/qwen-datasets/alpaca_zh-qwen-valid.json --num-layers 40 --hidden-size 2560 --num-attention-heads 20 --seq-length 80 --max-position-embeddings 80 --ffn-hidden-size 6912 --keep-last --micro-batch-size 1 --epochs 2 --lr 1e-5 --min-lr 1e-6 --lr-decay-style cosine --weight-decay 0.1 --clip-grad 1.0 --adam-beta1 0.9 --adam-beta2 0.95 --init-method-std 0.01 --num-workers 0 --log-interval 1 --eval-interval 1000 --eval-iters 10 --save-interval 1000000 --tensorboard-queue-size 1 --dataset LLama-SFT --tensorboard-dir /workspace/llm_train/output_patch_test/tensorboard/qs-finetune-megatron-llama-4B-lr-1e-5-ep-2-bs-1-seqlen-80-pr-bf16--do-true-tp-2-ac-sel-sp-true_2024.04.17-12.11.14 --log-timers-to-tensorboard --log-batch-size-to-tensorboard --log-validation-ppl-to-tensorboard --tensor-model-parallel-size 2 --pipeline-model-parallel-size 1 --finetune --no-load-optim --no-load-rng --seed 1234 --max-padding-length 81 --extra-vocab-size 1 --patch-tokenizer-type LLamaTokenizer --swiglu --normalization RMSNorm --use-llama2-rotary-position-embeddings --position-embedding-type rope --untie-embeddings-and-output-weights --rotary-base 1000000 --rotary-scale-factor 1 --bf16 --load /hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21 --recompute-activations --use-distributed-optimizer --use-flash-attn --sequence-parallel
WARNING:torch.distributed.run:

Training error

  storage = bucket.data.storage()._untyped()
 loading release checkpoint from /hf_cache/hub/models--Qwen--Qwen1.5-4B-Chat-hf2mg21
Traceback (most recent call last):
  File "/workspace/llm_train/Pai-Megatron-Patch/examples/qwen1_5/../llama2/finetune_megatron_llama.py", line 88, in <module>
    finetune(train_valid_datasets_provider=train_valid_datasets_provider,
  File "/workspace/llm_train/Pai-Megatron-Patch/megatron_patch/finetune_utils.py", line 310, in finetune
    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
  File "/workspace/llm_train/Pai-Megatron-Patch/Megatron-LM-231007/megatron/training.py", line 382, in setup_model_and_optimizer
    args.iteration = load_checkpoint(model, optimizer, opt_param_scheduler)
  File "/workspace/llm_train/Pai-Megatron-Patch/Megatron-LM-231007/megatron/checkpointing.py", line 586, in load_checkpoint
    model[0].load_state_dict(state_dict['model'], strict=strict)
  File "/workspace/llm_train/Pai-Megatron-Patch/megatron_patch/model/llama2/gpt_model.py", line 132, in load_state_dict
    self.language_model.load_state_dict(state_dict, strict=strict)
  File "/workspace/llm_train/Pai-Megatron-Patch/megatron_patch/model/llama2/language_model.py", line 603, in load_state_dict
    self.embedding.load_state_dict(state_dict_, strict=strict)
  File "/workspace/llm_train/Pai-Megatron-Patch/megatron_patch/model/llama2/language_model.py", line 285, in load_state_dict
    self.word_embeddings.load_state_dict(state_dict_, strict=strict)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 2040, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for VocabParallelEmbedding:
        size mismatch for weight: copying a param with shape torch.Size([75968, 2560]) from checkpoint, the shape in current model is torch.Size([75822, 2560]).
hudengjunai commented 4 months ago

I have two questions:

  1. I want to finetune the qwen1.5-4B model, but Pai-Megatron-Patch/examples/qwen1_5/run_finetune_megatron_qwen.sh starts the Python job with ../llama2/finetune_megatron_llama.py. Shouldn't it use a dedicated qwen1_5/finetune_megatron_qwen1_5.py instead?
  2. I checked the converted Megatron model: with tp=2 the per-rank embedding size is 151936/2 = 75968, but the model built by the training script expects 75822:
    RuntimeError: Error(s) in loading state_dict for VocabParallelEmbedding:
        size mismatch for weight: copying a param with shape torch.Size([75968, 2560]) from checkpoint, the shape in current model is torch.Size([75822, 2560]).
lwmlyy commented 4 months ago
  1. qwen1.5 and llama2 now have the same architecture, except that qwen1.5 adds some bias terms, so they can share the same finetune script.
  2. Please set EXTRA_VOCAB_SIZE to 293 (151936 - 151643) when you run the script for qwen1.5-4B. Also refer to this issue for details.
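The size mismatch in the traceback is exactly this arithmetic: the launch command passed --extra-vocab-size 1, so the model allocates (151643 + 1) / 2 = 75822 rows per rank, while the converted checkpoint was padded to the full 151936-entry vocabulary, i.e. 151936 / 2 = 75968 rows per rank. A minimal sketch of the per-rank sizing, assuming (as the numbers in this thread do) that the padded total divides evenly by tp:

```python
def per_rank_embedding_rows(base_vocab: int, extra_vocab: int, tp: int) -> int:
    """Rows of VocabParallelEmbedding held by each tensor-parallel rank.

    Assumes the padded vocabulary divides evenly across tp ranks, which
    holds for the numbers in this thread.
    """
    total = base_vocab + extra_vocab
    assert total % tp == 0, "padded vocab must be divisible by tp"
    return total // tp

# What the launch command built (--extra-vocab-size 1):
print(per_rank_embedding_rows(151643, 1, 2))    # 75822 -> current model
# What the converted checkpoint contains (EXTRA_VOCAB_SIZE=293):
print(per_rank_embedding_rows(151643, 293, 2))  # 75968 -> checkpoint
```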
jerryli1981 commented 4 months ago

The Qwen1.5 Quick Start has been updated. Please pull the latest code and test again.