BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), combining the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0

[WIP] wrap retnet #165

Closed · hanlinxuy closed this 7 months ago

hanlinxuy commented 1 year ago

Add a wrapper for the official RetNet implementation from https://github.com/microsoft/torchscale (a rough sketch of what such a wrapper might look like follows the TODO list). TODO:

  1. check the parameters' meaning;
  2. check output correctness.
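
For context, here is a minimal sketch (not the code in this PR) of wrapping torchscale's RetNet decoder behind the trainer's usual `idx -> logits` interface. The import paths follow the torchscale README; the class name `RetNetWrapper`, the config field mapping, the head count, and the assumption that the decoder returns a `(logits, extra)` tuple are illustrative assumptions, not verified against this PR:

```python
# Minimal sketch of a RetNet wrapper, assuming torchscale is installed from
# source. Class/argument names are illustrative; the PR's actual wrapper may differ.
import torch.nn as nn

from torchscale.architecture.config import RetNetConfig   # per the torchscale README
from torchscale.architecture.retnet import RetNetDecoder  # per the torchscale README


class RetNetWrapper(nn.Module):
    """Adapts a torchscale RetNetDecoder to the trainer's (idx) -> logits call."""

    def __init__(self, args):
        super().__init__()
        # Map the trainer's hyperparameters onto torchscale's config fields.
        # TODO (as in the PR description): verify what each parameter means.
        config = RetNetConfig(
            vocab_size=args.vocab_size,
            decoder_embed_dim=args.n_embd,
            decoder_ffn_embed_dim=4 * args.n_embd,
            decoder_layers=args.n_layer,
            decoder_retention_heads=8,  # assumed; must divide decoder_embed_dim
        )
        embed_tokens = nn.Embedding(args.vocab_size, args.n_embd)
        self.decoder = RetNetDecoder(config, embed_tokens=embed_tokens)

    def forward(self, idx):
        # idx: (batch, seq_len) token ids. torchscale decoders typically return
        # a (logits, extra) tuple; handle a bare tensor as well, just in case.
        # TODO (as in the PR description): verify output correctness vs. RWKV.
        out = self.decoder(idx)
        return out[0] if isinstance(out, tuple) else out
```
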
hanlinxuy commented 1 year ago

One should install torchscale from source (https://github.com/microsoft/torchscale) to include multiscale retention. I have tested the command below on a V100 with tf32; now I am testing bf16 on an RTX 4090:

    python3 train.py --load_model "" --wandb "" --proj_dir "out" \
      --data_file "" --data_type "dummy" --vocab_size 0 \
      --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
      --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
      --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
      --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 \
      --use_retnet
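
As a rough illustration of what the `--use_retnet` switch implies (not the PR's actual diff), the flag would be parsed alongside the existing training arguments and used to select the wrapped model instead of RWKV. The module path `src.model_retnet` and the class names below are assumptions:

```python
# Hypothetical wiring of the --use_retnet flag into the training script.
# Import paths and class names are assumptions for illustration only.
import argparse

parser = argparse.ArgumentParser()
# ... existing arguments (--n_layer, --n_embd, --ctx_len, --vocab_size, ...) ...
parser.add_argument("--use_retnet", default=False, action="store_true",
                    help="train the wrapped RetNet model instead of RWKV")
args, _ = parser.parse_known_args()

if args.use_retnet:
    from src.model_retnet import RetNetWrapper  # assumed module name
    model = RetNetWrapper(args)
else:
    from src.model import RWKV  # RWKV-LM's usual model class
    model = RWKV(args)
```
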

hanlinxuy commented 12 months ago

I did not implement some hyperparameters such as the Adam beta/eps; they remain controllable through the original args configuration.
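
In other words, the wrapper is expected to leave optimizer settings to the trainer. A minimal sketch, assuming the standard PyTorch Adam interface and the argument names used in the command above (--lr_init, --beta1, --beta2, --adam_eps); the actual trainer may use a fused or DeepSpeed optimizer instead:

```python
# Sketch only: optimizer hyperparameters keep coming from the original args,
# not from anything RetNet-specific.
import torch

def configure_optimizers(model, args):
    return torch.optim.Adam(
        model.parameters(),
        lr=args.lr_init,
        betas=(args.beta1, args.beta2),
        eps=args.adam_eps,
    )
```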