Closed hanlinxuy closed 7 months ago
One should install torchscale from source (https://github.com/microsoft/torchscale) to include multiscale retention. I have tested the command below on a V100 with tf32; I am now testing bf16 on an RTX 4090:

```shell
python3 train.py --load_model "" --wandb "" --proj_dir "out" \
  --data_file "" --data_type "dummy" --vocab_size 0 \
  --ctx_len 128 --epoch_steps 1000 --epoch_count 20 --epoch_begin 0 --epoch_save 10 \
  --micro_bsz 16 --n_layer 12 --n_embd 768 --pre_ffn 0 --head_qk 0 \
  --lr_init 6e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 \
  --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0 \
  --use_retnet
```
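For context, the recurrent form of (single-head) retention that multiscale retention builds on can be sketched in a few lines of plain Python. This is an illustrative sketch only, not the torchscale implementation; the function name and list-of-lists representation are my own:

```python
# Illustrative sketch of single-head retention in recurrent form
# (not the torchscale code): the state decays by gamma each step,
# accumulates the outer product k_n^T v_n, and the output is q_n @ S_n.

def retention_recurrent(qs, ks, vs, gamma):
    """qs, ks, vs: sequences of vectors (lists of floats), one per token.
    Returns the list of output vectors o_n = q_n @ S_n."""
    d = len(ks[0])        # key/query dimension
    dv = len(vs[0])       # value dimension
    S = [[0.0] * dv for _ in range(d)]  # running state, shape d x dv
    outs = []
    for q, k, v in zip(qs, ks, vs):
        # S_n = gamma * S_{n-1} + k_n^T v_n
        for i in range(d):
            for j in range(dv):
                S[i][j] = gamma * S[i][j] + k[i] * v[j]
        # o_n = q_n @ S_n, equivalent to sum_{m<=n} gamma^(n-m) (q_n . k_m) v_m
        outs.append([sum(q[i] * S[i][j] for i in range(d)) for j in range(dv)])
    return outs
```

Multiscale retention runs one such recurrence per head, each head with its own decay gamma, which is why the torchscale module is needed here rather than a plain attention layer.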
Some hyperparameters, such as the Adam beta/eps values, are not implemented in the wrapper yet; they remain controllable through the original args configuration.
Add a wrapper for the official RetNet implementation from https://github.com/microsoft/torchscale. TODO: