RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.
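To make the "RNN that trains like a GPT" claim concrete, here is a simplified, single-channel sketch of the WKV mixing used in RWKV-4-style models, written two ways. The first is a per-position form, where each output depends only on the inputs, so all positions can be computed at once as in GPT-style training. The second is a recurrent form that carries a two-number state, as in RNN inference. The function names, and treating `w` as a positive decay rate and `u` as a bonus for the current token, are assumptions for illustration. The real kernels add numerical-stability tricks and their own parameterization (in RWKV-4 the decay is derived from `time_decay` and `u` plays the role of `time_first`), and later versions use a multi-head, matrix-valued state instead.

```python
import math
from typing import List


def wkv_parallel(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """Per-position form: output t depends only on the inputs, so every t can be computed independently."""
    out = []
    for t in range(len(k)):
        # Past token i < t is weighted by exp(-(t - 1 - i) * w + k[i]): older tokens decay more.
        weights = [math.exp(-(t - 1 - i) * w + k[i]) for i in range(t)]
        num = sum(wi * vi for wi, vi in zip(weights, v))  # zip stops after the t past values
        den = sum(weights)
        # The current token gets the bonus u instead of any decay.
        cur = math.exp(u + k[t])
        out.append((num + cur * v[t]) / (den + cur))
    return out


def wkv_recurrent(w: float, u: float, k: List[float], v: List[float]) -> List[float]:
    """RNN form: carry a running numerator a and denominator b from step to step."""
    a, b = 0.0, 0.0
    out = []
    for kt, vt in zip(k, v):
        cur = math.exp(u + kt)
        out.append((a + cur * vt) / (b + cur))
        # Decay the state by exp(-w), then absorb the current token.
        a = math.exp(-w) * a + math.exp(kt) * vt
        b = math.exp(-w) * b + math.exp(kt)
    return out


k = [0.1, -0.2, 0.4, 0.0]
v = [1.0, 2.0, -1.0, 0.5]
par = wkv_parallel(0.5, 0.3, k, v)
rec = wkv_recurrent(0.5, 0.3, k, v)
print(all(abs(p - r) < 1e-9 for p, r in zip(par, rec)))  # True
```

The per-position form is written as an O(T²) loop for clarity; what matters is that no output depends on an earlier output. The recurrent form only needs the state `(a, b)` per channel at inference time, which is why memory does not grow with context length.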
RuntimeError: Error(s) in loading state_dict for RWKV:
Missing key(s) in state_dict: "blocks.0.att.time_mix_g", "blocks.0.att.time_faaaa", "blocks.0.att.gate.weight", "blocks.1.att.time_mix_g", "blocks.1.att.time_faaaa", "blocks.1.att.gate.weight", "blocks.2.att.time_mix_g", "blocks.2.att.time_faaaa", "blocks.2.att.gate.weight", "blocks.3.att.time_mix_g", "blocks.3.att.time_faaaa", "blocks.3.att.gate.weight", "blocks.4.att.time_mix_g", "blocks.4.att.time_faaaa", "blocks.4.att.gate.weight", "blocks.5.att.time_mix_g", "blocks.5.att.time_faaaa", "blocks.5.att.gate.weight", "blocks.6.att.time_mix_g", "blocks.6.att.time_faaaa", "blocks.6.att.gate.weight", "blocks.7.att.time_mix_g", "blocks.7.att.time_faaaa", "blocks.7.att.gate.weight", "blocks.8.att.time_mix_g", "blocks.8.att.time_faaaa", "blocks.8.att.gate.weight", "blocks.9.att.time_mix_g", "blocks.9.att.time_faaaa", "blocks.9.att.gate.weight", "blocks.10.att.time_mix_g", "blocks.10.att.time_faaaa", "blocks.10.att.gate.weight", "blocks.11.att.time_mix_g", "blocks.11.att.time_faaaa", "blocks.11.att.gate.weight".
Unexpected key(s) in state_dict: "blocks.0.att.time_first", "blocks.1.att.time_first", "blocks.2.att.time_first", "blocks.3.att.time_first", "blocks.4.att.time_first", "blocks.5.att.time_first", "blocks.6.att.time_first", "blocks.7.att.time_first", "blocks.8.att.time_first", "blocks.9.att.time_first", "blocks.10.att.time_first", "blocks.11.att.time_first".
size mismatch for blocks.0.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.1.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.2.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.3.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.4.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.5.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.6.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.7.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.8.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.9.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.10.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.11.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
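The two FFN widths in the dimension mismatch can be checked with a line of arithmetic. Assuming the model has 12 heads of size 64 (read off the `torch.Size([12, 64])` shape of `time_decay` in the log above) and that `n_embd = n_head * head_size`, the embedding size is 768, and the ×4 and ×3.5 multipliers give exactly 3072 and 2688. `ffn_width` is a hypothetical helper for illustration, not a function from the RWKV codebase.

```python
def ffn_width(n_embd: int, multiplier: float) -> int:
    """Hypothetical helper: FFN hidden size as a multiple of the embedding size."""
    width = n_embd * multiplier
    if width != int(width):
        raise ValueError(f"{n_embd} x {multiplier} is not a whole number")
    return int(width)


n_head, head_size = 12, 64        # assumed from time_decay's torch.Size([12, 64])
n_embd = n_head * head_size       # 768

print(n_embd, ffn_width(n_embd, 4), ffn_width(n_embd, 3.5))  # 768 3072 2688
```

So setting `--dim_ffn` to 3072 lines the FFN up with a ×4 checkpoint, but it cannot touch the attention differences (`time_first` vs `time_faaaa`, `time_mix_g`, `gate.weight`, and the `time_decay` shape), which come from the attention layer itself.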
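The error showed a dimension mismatch between 3072 and 2688. Looking closer, one is emb × 3.5 and the other is emb × 4, so I manually changed --dim_ffn to 3072, but got: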
RuntimeError: Error(s) in loading state_dict for RWKV:
Missing key(s) in state_dict: "blocks.0.att.time_mix_g", "blocks.0.att.time_faaaa", "blocks.0.att.gate.weight", "blocks.1.att.time_mix_g", "blocks.1.att.time_faaaa", "blocks.1.att.gate.weight", "blocks.2.att.time_mix_g", "blocks.2.att.time_faaaa", "blocks.2.att.gate.weight", "blocks.3.att.time_mix_g", "blocks.3.att.time_faaaa", "blocks.3.att.gate.weight", "blocks.4.att.time_mix_g", "blocks.4.att.time_faaaa", "blocks.4.att.gate.weight", "blocks.5.att.time_mix_g", "blocks.5.att.time_faaaa", "blocks.5.att.gate.weight", "blocks.6.att.time_mix_g", "blocks.6.att.time_faaaa", "blocks.6.att.gate.weight", "blocks.7.att.time_mix_g", "blocks.7.att.time_faaaa", "blocks.7.att.gate.weight", "blocks.8.att.time_mix_g", "blocks.8.att.time_faaaa", "blocks.8.att.gate.weight", "blocks.9.att.time_mix_g", "blocks.9.att.time_faaaa", "blocks.9.att.gate.weight", "blocks.10.att.time_mix_g", "blocks.10.att.time_faaaa", "blocks.10.att.gate.weight", "blocks.11.att.time_mix_g", "blocks.11.att.time_faaaa", "blocks.11.att.gate.weight".
Unexpected key(s) in state_dict: "blocks.0.att.time_first", "blocks.1.att.time_first", "blocks.2.att.time_first", "blocks.3.att.time_first", "blocks.4.att.time_first", "blocks.5.att.time_first", "blocks.6.att.time_first", "blocks.7.att.time_first", "blocks.8.att.time_first", "blocks.9.att.time_first", "blocks.10.att.time_first", "blocks.11.att.time_first".
size mismatch for blocks.0.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.1.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.2.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.3.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.4.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.5.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.6.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.7.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.8.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.9.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.10.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
size mismatch for blocks.11.att.time_decay: copying a param with shape torch.Size([12]) from checkpoint, the shape in current model is torch.Size([12, 64]).
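Because every layer repeats the same few problems, it can help to diff the checkpoint against the model once and collapse the layer index. Below is a minimal sketch in plain Python that works on dicts mapping parameter names to shapes; `diff_state_dicts` and `per_layer_names` are hypothetical helpers, and the sample data reuses key names and shapes from the log for two layers only (shapes the log does not show are left as `None`).

```python
import re
from typing import Dict, List, Optional, Tuple

Shape = Optional[Tuple[int, ...]]  # None = shape not known (placeholder)


def diff_state_dicts(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> dict:
    """Split differences into the three groups load_state_dict reports."""
    missing = sorted(k for k in model if k not in ckpt)     # expected by the model, absent from the checkpoint
    unexpected = sorted(k for k in ckpt if k not in model)  # in the checkpoint, unknown to the model
    size_mismatch = sorted(
        (k, ckpt[k], model[k]) for k in ckpt.keys() & model.keys() if ckpt[k] != model[k]
    )
    return {"missing": missing, "unexpected": unexpected, "size_mismatch": size_mismatch}


def per_layer_names(keys: List[str]) -> List[str]:
    """Drop the 'blocks.<i>.' prefix so a problem repeated in every layer shows up once."""
    return sorted({re.sub(r"^blocks\.\d+\.", "", k) for k in keys})


# Illustrative data for two layers, using key names and shapes from the log above.
ckpt: Dict[str, Shape] = {}
model: Dict[str, Shape] = {}
for i in range(2):
    ckpt[f"blocks.{i}.att.time_first"] = None  # shape not shown in the log
    ckpt[f"blocks.{i}.att.time_decay"] = (12,)
    model[f"blocks.{i}.att.time_decay"] = (12, 64)
    for name in ("att.time_mix_g", "att.time_faaaa", "att.gate.weight"):
        model[f"blocks.{i}.{name}"] = None  # shape not shown in the log

report = diff_state_dicts(ckpt, model)
print(per_layer_names(report["missing"]))     # ['att.gate.weight', 'att.time_faaaa', 'att.time_mix_g']
print(per_layer_names(report["unexpected"]))  # ['att.time_first']
print(per_layer_names([k for k, _, _ in report["size_mismatch"]]))  # ['att.time_decay']
```

With a real checkpoint, the two dicts could be built with `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}` and `{k: tuple(v.shape) for k, v in model.state_dict().items()}`, assuming the file holds a plain state dict. Collapsed this way, the error reduces to three missing attention parameters, one unexpected one and one shape change, which suggests the checkpoint was saved from a different RWKV attention variant than the one the training script builds, something no `--dim_ffn` value can reconcile.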