RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
AssertionError while finetuning RWKVv5 #216

Open Ethan-Chen-plus opened 5 months ago

Ethan-Chen-plus commented 5 months ago

While fine-tuning RWKV, I used this script (with the demo dataset; demo.bin and demo.idx are placed in ./data):


M_BSZ="16" # takes 16G VRAM (reduce this to save VRAM)
GRAD_CP=0 # set to 1 to save VRAM (will be slower)

# magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 1498226207/512-1 = 2926222.06 in this case)
# use

python --load_model "../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth" --wandb "RWKV-5-Test" --proj_dir $BASE_NAME \
 --ctx_len 512 --my_pile_stage 3 --epoch_count 999999 --epoch_begin 0 \
 --data_file "data/demo" --my_exit_tokens 1498226207 --magic_prime 2926181 \
 --num_nodes 1 --micro_bsz $M_BSZ --n_layer $N_LAYER --n_embd $N_EMBD --pre_ffn 0 --head_qk 0 \
 --lr_init $LR_INIT --lr_final $LR_FINAL --warmup_steps 10 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --my_pile_edecay 0 --data_type "binidx" --vocab_size 65536 \
 --weight_decay 0.001 --epoch_save $EPOCH_SAVE --head_size_a 64 \
 --accelerator gpu --devices 1 --precision bf16 --strategy deepspeed_stage_2 --grad_cp $GRAD_CP --ds_bucket_mb 200

I caught this error:

(rwkv) ubuntu@ip-172-31-67-197:~/MedicalGPT/rwkv/RWKV-LM/RWKV-v5$ CUDA_VISIBLE_DEVICES=2 bash
INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpw45qi2d_
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpw45qi2d_/
INFO:pytorch_lightning.utilities.rank_zero:########## work in progress ##########
[2024-01-13 12:22:27,924] [INFO] [] Setting ds_accelerator to cuda (auto detect)
# RWKV-5 BF16 on 1x1 GPU, bsz 1x1x16=16, deepspeed_stage_2 
# Data = data/demo (binidx), ProjDir = model/demo
# Epoch = 0 to 71 (will continue afterwards), save every 10 epoch
# Each "epoch" = 2520 steps, 40320 samples, 20643840 tokens
# Model = 12 n_layer, 768 n_embd, 512 ctx_len
# Adam = lr 0.0006 to 6e-05, warmup 10 steps, beta (0.9, 0.99), eps 1e-08
# Found torch 1.13.1+cu117, recommend 1.13.1+cu117 or newer
# Found deepspeed 0.12.6, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning 1.9.5, recommend 1.9.5

INFO:pytorch_lightning.utilities.rank_zero:{'load_model': '../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth', 'wandb': 'RWKV-5-Test', 'proj_dir': 'model/demo', 'data_file': 'data/demo', 'data_type': 'binidx', 'vocab_size': 65536, 'ctx_len': 512, 'micro_bsz': 16, 'n_layer': 12, 'n_embd': 768}

INFO:pytorch_lightning.utilities.rank_zero:Current vocab size = 65536 (make sure it's correct)
INFO:pytorch_lightning.utilities.rank_zero:Data has 200499 tokens.
INFO:pytorch_lightning.utilities.rank_zero:########## Pile 20b-tokenized stage 3 ##########
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/", line 248, in <module>
    train_data = MyDataset(args)
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/src/", line 56, in __init__
    assert args.magic_prime / dataset_slot > 0.99 and args.magic_prime / dataset_slot <= 1
BlinkDL commented 5 months ago

Data has 200499 tokens

Therefore set my_exit_tokens to 200499. Also note: magic_prime = the largest 3n+2 prime smaller than datalen/ctxlen-1 (= 200499/512-1 = 390.599609375 in this case); use

therefore set magic_prime = 389
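
The rule quoted above (largest prime of the form 3n+2 below datalen/ctxlen-1) can be checked with a few lines of Python. This is only a sketch: the function names are made up for illustration, and `passes_check` assumes that `dataset_slot` in the failing assertion equals `data_len // ctx_len`, which the thread does not confirm.

```python
def is_prime(n: int) -> bool:
    """Trial division; fast enough for values in the millions."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    i = 3
    while i * i <= n:
        if n % i == 0:
            return False
        i += 2
    return True


def magic_prime(data_len: int, ctx_len: int) -> int:
    """Largest prime p with p % 3 == 2 and p < data_len / ctx_len - 1."""
    bound = data_len / ctx_len - 1
    p = int(bound)
    if p >= bound:  # keep the inequality strict when bound is a whole number
        p -= 1
    while p >= 2:
        if p % 3 == 2 and is_prime(p):
            return p
        p -= 1
    raise ValueError("dataset is too small for this ctx_len")


def passes_check(prime: int, data_len: int, ctx_len: int) -> bool:
    """Mirror of the assertion in the traceback.

    ASSUMPTION: dataset_slot == data_len // ctx_len (not stated in the thread).
    """
    dataset_slot = data_len // ctx_len
    ratio = prime / dataset_slot
    return 0.99 < ratio <= 1


print(magic_prime(200499, 512))            # 389
print(passes_check(389, 200499, 512))      # True
print(passes_check(2926181, 200499, 512))  # False: value meant for the 1498226207-token dataset
```

Under that assumption, the original run fails because 2926181 was computed for a 1498226207-token dataset, while the demo data has only 200499 tokens.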

Ethan-Chen-plus commented 5 months ago

Thanks for answering. But some errors still occur:


INFO:torch.distributed.nn.jit.instantiator:Created a temporary directory at /tmp/tmpyy254i6t
INFO:torch.distributed.nn.jit.instantiator:Writing /tmp/tmpyy254i6t/
INFO:pytorch_lightning.utilities.rank_zero:########## work in progress ##########
[2024-01-16 12:35:16,735] [INFO] [] Setting ds_accelerator to cuda (auto detect)
# RWKV-5 BF16 on 1x1 GPU, bsz 1x1x16=16, deepspeed_stage_2 
# Data = data/demo (binidx), ProjDir = model/demo
# Epoch = 0 to -1 (will continue afterwards), save every 10 epoch
# Each "epoch" = 2520 steps, 40320 samples, 20643840 tokens
# Model = 12 n_layer, 768 n_embd, 512 ctx_len
# Adam = lr 0.0006 to 6e-05, warmup 10 steps, beta (0.9, 0.99), eps 1e-08
# Found torch 1.13.1+cu117, recommend 1.13.1+cu117 or newer
# Found deepspeed 0.12.6, recommend 0.7.0 (faster than newer versions)
# Found pytorch_lightning 1.9.5, recommend 1.9.5

INFO:pytorch_lightning.utilities.rank_zero:{'load_model': '../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth', 'wandb': 'RWKV-5-Test', 'proj_dir': 'model/demo', 'data_file': 'data/demo', 'data_type': 'binidx', 'vocab_size': 65536, 'ctx_len': 512, 'micro_bsz': 16, 'n_layer': 12, 'n_embd': 768}

INFO:pytorch_lightning.utilities.rank_zero:Current vocab size = 65536 (make sure it's correct)
INFO:pytorch_lightning.utilities.rank_zero:Data has 200499 tokens.
INFO:pytorch_lightning.utilities.rank_zero:########## Pile 20b-tokenized stage 3 ##########
Using /home/ubuntu/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu117/wkv5/
Building extension module wkv5...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
[1/2] /usr/bin/g++-10 -MMD -MF wkv5_op.o.d -DTORCH_EXTENSION_NAME=wkv5 -DTORCH_API_INCLUDE_EXTENSION_H -DPYBIND11_COMPILER_TYPE=\"_gcc\" -DPYBIND11_STDLIB=\"_libstdcpp\" -DPYBIND11_BUILD_ABI=\"_cxxabi1011\" -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/torch/csrc/api/include -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/TH -isystem /home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/include/THC -isystem /home/ubuntu/micromamba/envs/rwkv/include/python3.10 -D_GLIBCXX_USE_CXX11_ABI=0 -fPIC -std=c++14 -c /home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/cuda/wkv5_op.cpp -o wkv5_op.o 
[2/2] /usr/bin/g++-10 wkv5_op.o wkv5_cuda.cuda.o -shared -L/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/lib -lc10 -lc10_cuda -ltorch_cpu -ltorch_cuda_cu -ltorch_cuda_cpp -ltorch -ltorch_python -L/usr/lib64 -lcudart -o
Loading extension module wkv5...
INFO:pytorch_lightning.utilities.rank_zero:########## Loading ../../models/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth... ##########
Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/", line 284, in <module>
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/nn/modules/", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for RWKV:
        size mismatch for emb.weight: copying a param with shape torch.Size([65536, 1024]) from checkpoint, the shape in current model is torch.Size([65536, 768]).
        size mismatch for blocks.0.ln1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln2.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln2.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln0.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ln0.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.att.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_v: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_mix_g: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.att.time_decay: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.0.att.time_faaaa: copying a param with shape torch.Size([16, 64]) from checkpoint, the shape in current model is torch.Size([12, 64]).
        size mismatch for blocks.0.att.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.key.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.value.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.output.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.gate.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.att.ln_x.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.att.ln_x.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
        size mismatch for blocks.0.ffn.time_mix_k: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.ffn.time_mix_r: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
        size mismatch for blocks.0.ffn.key.weight: copying a param with shape torch.Size([3584, 1024]) from checkpoint, the shape in current model is torch.Size([2688, 768]).
        size mismatch for blocks.0.ffn.receptance.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([768, 768]).
        size mismatch for blocks.0.ffn.value.weight: copying a param with shape torch.Size([1024, 3584]) from checkpoint, the shape in current model is torch.Size([768, 2688]).
BlinkDL commented 5 months ago

for 0.4B finetuning, set: N_LAYER="24" N_EMBD="1024" LR_INIT="2e-5" LR_FINAL="2e-5" GRAD_CP="1"
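
The size-mismatch messages above already reveal the checkpoint's dimensions (for example, `emb.weight` is `[65536, 1024]`), so the right N_LAYER and N_EMBD can be read off the checkpoint instead of guessed. A minimal sketch, assuming the checkpoint is a flat mapping of parameter names to tensors; `infer_model_dims` is a hypothetical helper, and the toy `shapes` dict below fills in 24 block indices only to match the advice above.

```python
import re


def infer_model_dims(shapes, head_size=64):
    """Infer n_layer / n_embd / n_head from a {param_name: shape} mapping."""
    n_embd = shapes["emb.weight"][1]  # emb.weight is [vocab_size, n_embd]
    layer_ids = set()
    for name in shapes:
        m = re.match(r"blocks\.(\d+)\.", name)
        if m:
            layer_ids.add(int(m.group(1)))
    if not layer_ids:
        raise ValueError("no 'blocks.N.' parameters found")
    return {
        "n_layer": max(layer_ids) + 1,
        "n_embd": n_embd,
        "n_head": n_embd // head_size,  # head_size matches --head_size_a 64
    }


# Illustrative shapes: emb.weight comes from the error message above;
# the block entries are filled in to represent a 24-layer checkpoint.
shapes = {"emb.weight": (65536, 1024)}
for i in range(24):
    shapes[f"blocks.{i}.ln1.weight"] = (1024,)

print(infer_model_dims(shapes))
# {'n_layer': 24, 'n_embd': 1024, 'n_head': 16}
```

With PyTorch installed, the mapping could be built from the real file with something like `{k: tuple(v.shape) for k, v in torch.load(path, map_location="cpu").items()}`, assuming the .pth file stores a plain state dict. The inferred `n_head` of 16 agrees with the `[16, 64]` shape reported for `time_decay`.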

Ethan-Chen-plus commented 5 months ago

Thanks for helping! But I wonder why LR_INIT is set equal to LR_FINAL. Another question: if I set GRAD_CP=0, memory usage increases and I get an OOM error:
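
On the first question, the thread gives no explicit answer. One way to see the effect: any schedule that interpolates from LR_INIT to LR_FINAL collapses to a constant learning rate when the two are equal. The sketch below uses exponential interpolation purely as an assumed example; `lr_at` is a hypothetical function, not the trainer's actual code, and warmup is ignored.

```python
import math


def lr_at(progress: float, lr_init: float, lr_final: float) -> float:
    """Exponentially interpolate from lr_init (progress=0) to lr_final (progress=1).

    ASSUMPTION: illustrative schedule only; the real trainer may differ.
    """
    return lr_init * math.exp(math.log(lr_final / lr_init) * progress)


steps = [0.0, 0.25, 0.5, 0.75, 1.0]

# Pretraining-style settings from the first log: decays from 6e-4 to 6e-5.
decaying = [lr_at(p, 6e-4, 6e-5) for p in steps]

# Fine-tuning settings suggested above: the rate never changes.
constant = [lr_at(p, 2e-5, 2e-5) for p in steps]

print(decaying)
print(constant)
```

A small constant rate is a common choice when fine-tuning from a pretrained checkpoint on a small dataset, though whether that is the reasoning behind the suggestion here is not stated.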

INFO:pytorch_lightning.strategies.deepspeed:initializing deepspeed distributed: GLOBAL_RANK: 0, MEMBER: 1/1
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
INFO:pytorch_lightning.utilities.rank_zero:Enabling DeepSpeed BF16.
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]
Using /home/ubuntu/.cache/torch_extensions/py310_cu117 as PyTorch extensions root...
Detected CUDA files, patching ldflags
Emitting ninja build file /home/ubuntu/.cache/torch_extensions/py310_cu117/fused_adam/
Building extension module fused_adam...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module fused_adam...
Time to load fused_adam op: 0.06461381912231445 seconds
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:2 to store for rank: 0
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
  | Name   | Type       | Params
0 | emb    | Embedding  | 67.1 M
1 | blocks | ModuleList | 327 M 
2 | ln_out | LayerNorm  | 2.0 K 
3 | head   | Linear     | 67.1 M
461 M     Trainable params
0         Non-trainable params
461 M     Total params
1,846.886 Total estimated model params size (MB)
Epoch 0:   0%|                                           | 0/2520 [00:00<?, ?it/s]
{'zero_optimization': {'stage': 2}, 'gradient_accumulation_steps': 1, 'train_micro_batch_size_per_gpu': 16, 'gradient_clipping': 1.0, 'bf16': {'enabled': True}}

Traceback (most recent call last):
  File "/home/ubuntu/MedicalGPT/rwkv/RWKV-LM/RWKV-v5/", line 312, in <module>, data_loader)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 608, in fit
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 36, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/", line 88, in launch
    return function(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 650, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1112, in _run
    results = self._run_stage()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1191, in _run_stage
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1214, in _run_train
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/", line 267, in advance
    self._outputs =
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/", line 213, in advance
    batch_output =
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/batch/", line 88, in advance
    outputs =, kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/", line 202, in advance
    result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/", line 249, in _run_optimization
    self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/", line 370, in _optimizer_step
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1356, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/core/", line 1754, in optimizer_step
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/core/", line 169, in step
    step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/", line 280, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/", line 234, in optimizer_step
    return self.precision_plugin.optimizer_step(
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/", line 132, in optimizer_step
    closure_result = closure()
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/", line 149, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/", line 144, in closure
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/loops/optimization/", line 305, in backward_fn
    self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1494, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/strategies/", line 207, in backward
    self.precision_plugin.backward(closure_loss, self.lightning_module, optimizer, optimizer_idx, *args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/pytorch_lightning/plugins/precision/", line 118, in backward
    deepspeed_engine.backward(tensor, *args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/utils/", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/", line 1955, in backward
    self.optimizer.backward(loss, retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/zero/", line 2019, in backward
    self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/deepspeed/runtime/fp16/", line 63, in backward
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/", line 488, in backward
  File "/home/ubuntu/micromamba/envs/rwkv/lib/python3.10/site-packages/torch/autograd/", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 21.99 GiB total capacity; 20.70 GiB already allocated; 287.00 MiB free; 20.88 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
I currently have 4 A10-22G cards. How can I maximize the utilization of computing power and memory?

BlinkDL commented 5 months ago

Set --devices 4 to use all 4 GPUs.


Ethan-Chen-plus commented 5 months ago

Thanks again @BlinkDL

I have another question: I'm currently doing full fine-tuning of a 0.4B-parameter RWKV-5 model with a context length (ctx_len) of 1024, and it almost maxes out the memory on all four of my A10 GPUs. However, llama2-7b can be fully fine-tuned on four A10 cards with a context length of 4096. Is there a way to run full training of my v5 model with a context length of 4096, using model parallelism across the four GPUs?

PicoCreator commented 5 months ago

Check your "gradient checkpoint" flag: disabling it gives a speed boost at the cost of much higher VRAM usage (LLaMA setups typically have it enabled).

BlinkDL commented 5 months ago

@Ethan-Chen-plus set GRAD_CP=1