BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0

How to train rwkv-5-0.1b: weight loading error #197

Open enbiwudi opened 8 months ago

enbiwudi commented 8 months ago

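Loading reports a dimension mismatch between 3072 and 2688. I noticed that one is the embedding size times 3.5 and the other is the embedding size times 4, so I manually changed --dim_ffn to 3072, but then I get: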

RuntimeError: Error(s) in loading state_dict for RWKV:
Missing key(s) in state_dict: "blocks.0.att.time_mix_g", "blocks.0.att.time_faaaa", "blocks.0.att.gate.weight", "blocks.1.att.time_mix_g", "blocks.1.att.time_faaaa", "blocks.1.att.gate.weight", "blocks.2.att.time_mix_g", "blocks.2.att.time_faaaa", "blocks.2.att.gate.weight", "blocks.3.att.time_mix_g", "blocks.3.att.time_faaaa", "blocks.3.att.gate.weight", "blocks.4.att.time_mix_g", "blocks.4.att.time_faaaa", "blocks.4.att.gate.weight", "blocks.5.att.time_mix_g", "blocks.5.att.time_faaaa", "blocks.5.att.gate.weight", "blocks.6.att.time_mix_g", "blocks.6.att.time_faaaa", "blocks.6.att.gate.weight", "blocks.7.att.time_mix_g", "blocks.7.att.time_faaaa", "blocks.7.att.gate.weight", "blocks.8.att.time_mix_g", "blocks.8.att.time_faaaa", "blocks.8.att.gate.weight", "blocks.9.att.time_mix_g", "blocks.9.att.time_faaaa", "blocks.9.att.gate.weight", "blocks.10.att.time_mix_g", "blocks.10.att.time_faaaa", "blocks.10.att.gate.weight", "blocks.11.att.time_mix_g", "blocks.11.att.time_faaaa", "blocks.11.att.gate.weight".
Unexpected key(s) in state_dict: "blocks.0.att.time_first", "blocks.1.att.time_first", "blocks.2.att.time_first", "blocks.3.att.time_first", "blocks.4.att.time_first", "blocks.5.att.time_first", "blocks.6.att.time_first", "blocks.7.att.time_first", "blocks.8.att.time_first", "blocks.9.att.time_first", "blocks.10.att.time_first", "blocks.11.att.time_first".
size mismatch for blocks.0.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.1.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.2.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.3.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.4.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.5.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.6.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.7.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.8.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.9.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.10.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.11.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
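
For anyone hitting a similar error, the three categories in the message (missing keys, unexpected keys, size mismatches) can be reproduced by comparing parameter names and shapes before loading. The following is only a minimal sketch in plain Python, not code from RWKV-LM: `diff_state_dicts` is a hypothetical helper, and the shapes marked as placeholders below do not appear in the error and are assumptions for illustration.

```python
def diff_state_dicts(ckpt_shapes, model_shapes):
    """Compare two {parameter_name: shape} mappings.

    Returns (missing, unexpected, mismatched), mirroring the three
    categories in the load_state_dict error message.
    """
    # Keys the model defines but the checkpoint does not provide.
    missing = sorted(k for k in model_shapes if k not in ckpt_shapes)
    # Keys the checkpoint provides but the model does not define.
    unexpected = sorted(k for k in ckpt_shapes if k not in model_shapes)
    # Keys present on both sides whose shapes differ.
    mismatched = {
        k: (ckpt_shapes[k], model_shapes[k])
        for k in ckpt_shapes
        if k in model_shapes and ckpt_shapes[k] != model_shapes[k]
    }
    return missing, unexpected, mismatched


# Block 0 only. The time_decay shapes are transcribed from the error above;
# every other shape is a placeholder (assumption), since the error omits it.
PLACEHOLDER = ()
ckpt = {
    "blocks.0.att.time_first": PLACEHOLDER,
    "blocks.0.att.time_decay": (12,),
}
model = {
    "blocks.0.att.time_mix_g": PLACEHOLDER,
    "blocks.0.att.time_faaaa": PLACEHOLDER,
    "blocks.0.att.gate.weight": PLACEHOLDER,
    "blocks.0.att.time_decay": (12, 64),
}

missing, unexpected, mismatched = diff_state_dicts(ckpt, model)
print("missing:", missing)
print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

With real PyTorch objects, the two mappings could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}` from the loaded checkpoint and from `model.state_dict()`.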

BlinkDL commented 8 months ago

The current 0.1b model is somewhat older: use v4neo with --my_testing "r2" for it. The 0.4b and larger models are standard RWKV v5.2.
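
One rough way to tell which kind of checkpoint you have before picking the training code is to look at its parameter names. The sketch below is a heuristic based only on the key names visible in this thread, not an official check: `guess_checkpoint_layout` and its return labels are hypothetical, and it assumes that `att.time_faaaa` marks the layout the current training code expects while `att.time_first` marks the older layout found in the 0.1b checkpoint.

```python
def guess_checkpoint_layout(keys):
    """Guess the checkpoint layout from parameter names (heuristic only)."""
    keys = list(keys)
    # The current model in this thread expects att.time_faaaa parameters.
    if any(k.endswith(".att.time_faaaa") for k in keys):
        return "current"
    # The 0.1b checkpoint in this thread carries att.time_first instead.
    if any(k.endswith(".att.time_first") for k in keys):
        return "older"
    return "unknown"


# Key names taken from the error message in this issue.
old_keys = ["blocks.0.att.time_first", "blocks.0.att.time_decay"]
new_keys = ["blocks.0.att.time_faaaa", "blocks.0.att.time_mix_g",
            "blocks.0.att.gate.weight", "blocks.0.att.time_decay"]

print(guess_checkpoint_layout(old_keys))
print(guess_checkpoint_layout(new_keys))
```

If the result is "older", the advice above applies: load the checkpoint with the v4neo code and --my_testing "r2" rather than the v5.2 code.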