BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), combining the best of RNN and transformer: great performance, fast inference, VRAM savings, fast training, "infinite" ctx_len, and free sentence embedding.

Finetune RWKV5-7B: Missing key(s) in state_dict #228

Open · liuao743 opened 4 months ago

liuao743 commented 4 months ago

```bash
#!/bin/bash

BASE_NAME="./model/models--RWKV--HF_v5-Eagle-7B/snapshots/bb01ae9434eb9f4934c1ebe486eb7d3e25883d72/pytorch_model.bin"
N_LAYER="32"
N_EMBD="4096"
M_BSZ="16"   # takes 16G VRAM (reduce this to save VRAM)
LR_INIT="1e-5"
LR_FINAL="1e-5"
GRAD_CP=0    # set to 1 to save VRAM (will be slower)
EPOCH_SAVE=10

# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1
# (= 1498226207/512-1 = 2926222.06 in this case)
# use https://www.dcode.fr/prime-numbers-search

python3 train.py --load_model $BASE_NAME \
 --epoch_count 999999 --epoch_begin 0 \
 --data_file "data/train_emotion" --data_type "binidx" --my_exit_tokens 1498226207 --magic_prime 2926181 \
 --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
 --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --vocab_size 65536 \
 --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --enable_progress_bar True --ds_bucket_mb 200
```
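If the dcode.fr search is unavailable, the magic_prime rule from the comments above can also be scripted. A minimal Python sketch (brute-force trial division; not part of the repo, just reproducing the stated rule with the values from this script):

```python
# Sketch: find magic_prime = the largest prime p with p % 3 == 2 (a "3n+2 prime")
# smaller than datalen/ctxlen - 1. Values below are the ones used in the script.
def is_prime(n: int) -> bool:
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True

datalen, ctxlen = 1498226207, 512
p = int(datalen / ctxlen - 1)  # 2926222
while not (p % 3 == 2 and is_prime(p)):
    p -= 1
print(p)  # 2926181, matching --magic_prime above
```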

Error:

```
RuntimeError: Error(s) in loading state_dict for RWKV:
Missing key(s) in state_dict: "emb.weight", "blocks.0.ln1.weight", "blocks.0.ln1.bias",
"blocks.0.ln2.weight", "blocks.0.ln2.bias", "blocks.0.ln0.weight", "blocks.0.ln0.bias",
"blocks.0.att.time_mix_k", "blocks.0.att.time_mix_v", "blocks.0.att.time_mix_r",
"blocks.0.att.time_mix_g", "blocks.0.att.time_decay", "blocks.0.att.time_faaaa",
"blocks.0.att.receptance.weight", "blocks.0.att.key.weight", "blocks.0.att.value.weight",
"blocks.0.att.output.weight", "blocks.0.att.gate.weight", "blocks.0.att.ln_x.weight",
"blocks.0.att.ln_x.bias", "blocks.0.ffn.time_mix_k", "blocks.0.ffn.time_mix_r",
"blocks.0.ffn.key.weight", "blocks.0.ffn.receptance.weight", "blocks.0.ffn.value.weight",
"blocks.1.ln1.weight", "blocks.1.ln1.bias", "blocks.1.ln2.weight", "blocks.1.ln2.bias",
"blocks.1.att.time_mix_k", "blocks.1.att.time_mix_v", "blocks.1.att.time_mix_r",
"blocks.1.att.time_mix_g", "blocks.1.att.time_decay", "blocks.1.att.time_faaaa",
"blocks.1.att.receptance.weight", "blocks.1.att.key.weight", "blocks.1.att.value.weight",
"blocks.1.att.output.weight", "blocks.1.att.gate.weight", "blocks.1.att.ln_x.weight",
"blocks.1.att.ln_x.bias", "blocks.1.ffn.time_mix_k", "blocks.1.ffn.time_mix_r",
"blocks.1.ffn.key.weight", "blocks.1.ffn.receptance.weight", "blocks.1.ffn.value.weight",
"blocks.2.ln1.weight", "blocks.2.ln1.bias", "blocks.2.ln2.weight", "blocks.2.ln2.bias",
"blocks.2.att.time_mix_k"
```

BlinkDL commented 4 months ago

BASE_NAME should point to a .pth checkpoint when you are using RWKV-LM, such as:

https://huggingface.co/BlinkDL/rwkv-5-world/blob/main/RWKV-5-World-7B-v2-20240128-ctx4096.pth
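For anyone hitting the same error, a quick sanity check (a sketch, assuming PyTorch is installed and using the .pth linked above) is to inspect the checkpoint's state_dict keys: RWKV-LM's train.py expects top-level keys such as emb.weight and blocks.0.att.time_mix_k, and the HF-format pytorch_model.bin stores its weights under different key names, which is why every key is reported missing.

```python
# Sketch: inspect a checkpoint's state_dict keys to confirm its format.
# RWKV-LM expects raw keys like "emb.weight" and "blocks.0.att.time_mix_k";
# if the keys look different, the file is not an RWKV-LM .pth checkpoint.
import torch

sd = torch.load("RWKV-5-World-7B-v2-20240128-ctx4096.pth", map_location="cpu")
print(list(sd.keys())[:8])  # should start with 'emb.weight', 'blocks.0.ln1.weight', ...
```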