XiangLi1999 / PrefixTuning

Prefix-Tuning: Optimizing Continuous Prompts for Generation

About the training speed verification #7

Closed Timothyxxx closed 2 years ago

Timothyxxx commented 2 years ago

Hi Lisa~ I reimplemented your BART code on top of the newest Hugging Face transformers, and I'd like to verify one thing: in my training runs, the speed of prefix-tuning is about 60%~70% of that of full-parameter fine-tuning, even when I use a very, very small prefix prompt module. Does that make sense? And where might the speed bottleneck be? Hoping for your reply.

XiangLi1999 commented 2 years ago

I think 60%-70% makes sense!

Great question: the speed gains in prefix-tuning come from not having to update as many parameters stored in the optimizer (i.e., fewer trainable parameters), but backprop is still required all the way down to the bottom Transformer layer. A thought experiment that explains this: if you only train the last layer of a Transformer, then both the number of trainable parameters and the number of layers you need to backprop through are reduced (you only need to backprop through one layer, since you don't care about the gradients of the earlier layers). However, if you only train the first layer of the Transformer, you need to backprop all the way down, despite having the same number of trainable parameters.

Based on the first-layer vs. last-layer analogy, let's go back to prefix-tuning. We tune prefix activations at every layer, so we need to backprop all the way back to the first layer, and backprop time is not reduced. The only saved computation is that we don't need to perform as many parameter updates. A rough sketch of this is below.
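To make the analogy concrete, here is a minimal PyTorch sketch (a toy stand-in, not the actual PrefixTuning code): a small trainable prefix injected at the input side still forces the backward pass through every frozen layer, and only the optimizer update gets cheaper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for a pretrained Transformer: a deep stack of frozen layers.
model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(12)])
for p in model.parameters():
    p.requires_grad_(False)  # frozen, like the pretrained LM in prefix-tuning

# The only trainable parameters: a tiny "prefix" injected at the input side
# (the "first layer" case from the analogy above).
prefix = nn.Parameter(torch.zeros(1, 64))
optimizer = torch.optim.AdamW([prefix], lr=1e-3)  # optimizer state covers only the prefix

x = torch.randn(8, 64)
loss = model(x + prefix).pow(2).mean()
loss.backward()   # the gradient must flow back through all 12 frozen layers to
                  # reach `prefix`, so backward time is not reduced
optimizer.step()  # ...but the parameter update touches only the tiny prefix

# If instead only the LAST layer were trainable (and everything upstream detached),
# backprop could stop after one layer -- that is the case where fewer trainable
# parameters also shortens the backward pass.
```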

Let me know if this makes sense.

Timothyxxx commented 2 years ago

Thanks a lot for your analysis! I assumed it was for the same reasons too, haha. Thanks again!

lrongzheni commented 2 years ago

What's your GPU hardware environment? Can this be trained on a single GPU? thx~ @Tianbao Xie

Timothyxxx commented 2 years ago

Of course. It depends on the model, but I think 11GB of GPU memory is enough for the E2E dataset with GPT-2.

lrongzheni commented 2 years ago

When trying to train GPT-2, the problem below is troubling me. Can you help me fix it? thx~

```
python train_e2e.py --optim_prefix yes --preseqlen 5 --epoch 5 --learning_rate 0.00005 --mode webnlg --bsz 5 --seed 101
webnlg_models/webnlgprefixtune_y_5_act_cat_b=5-e=5_d=0.0_u=no_lr=5e-05_w=0.0_s=101_r=n_m=512_o=1_o=1
python run_language_modeling.py --output_dir=webnlg_models/webnlgprefixtune_y_5_act_cat_b=5-e=5_d=0.0_u=no_lr=5e-05_w=0.0_s=101_r=n_m=512_o=1_o=1 --model_type=gpt2 --model_name_or_path=gpt2-medium --tokenizer_name=gpt2-medium --per_device_train_batch_size 5 --per_device_eval_batch_size 5 --save_steps 500000 --num_train_epochs 5 --do_train --train_data_file=/u/scr/xlisali/WebNLG/webnlg-dataset/webnlg_challenge_2017/train.json --do_eval --line_by_line --save_total_limit 1 --overwrite_output_dir --task_mode webnlg --eval_data_file=/u/scr/xlisali/WebNLG/webnlg-dataset/webnlg_challenge_2017/dev.json --tuning_mode prefixtune --logging_dir webnlg_models/runs/webnlgprefixtune_y_5_act_cat_b=5-e=5_d=0.0_u=no_lr=5e-05_w=0.0_s=101_r=n_m=512_o=1_o=1 --train_embs no --optim_prefix yes --preseqlen 5 --prefix_mode activation --format_mode cat --gradient_accumulation_steps 1 --learning_rate 5e-05 --weight_decay 0.0 --seed 101 --disable_tqdm --mid_dim 512 --init_random no --use_dropout no --prefix_dropout 0.0 --objective_mode 1 --evaluate_during_training --eval_steps 5000 --cache_dir /u/scr/xlisali/contrast_LM/transformers/examples/control/gpt2-medium-s3
/data/lirongzhen/PrefixTuning/transformers/src/transformers/__init__.py
Traceback (most recent call last):
  File "/data/lirongzhen/PrefixTuning/gpt2/run_language_modeling.py", line 1159, in <module>
    main()
  File "/data/lirongzhen/PrefixTuning/gpt2/run_language_modeling.py", line 498, in main
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
  File "/data/lirongzhen/PrefixTuning/transformers/src/transformers/hf_argparser.py", line 40, in __init__
    self._add_dataclass_arguments(dtype)
  File "/data/lirongzhen/PrefixTuning/transformers/src/transformers/hf_argparser.py", line 72, in _add_dataclass_arguments
    elif hasattr(field.type, "__origin__") and issubclass(field.type.__origin__, List):
  File "/data/anaconda3/envs/PrefixTuning/lib/python3.9/typing.py", line 847, in __subclasscheck__
    return issubclass(cls, self.__origin__)
TypeError: issubclass() arg 1 must be a class
```
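For reference, my reading of that traceback (an assumption on my side, not something the repo documents): the bundled transformers snapshot's `hf_argparser.py` calls `issubclass(field.type.__origin__, List)`, and under Python 3.9 an `Optional[...]` field reaches that check with `__origin__ == typing.Union`, which is not a class, so `issubclass()` raises. A standalone reproduction of just that failure mode (a hypothetical example, not the repo's code):

```python
# Minimal reproduction of the TypeError in the traceback above (run on Python 3.9).
# This mimics the check at line 72 of the bundled transformers' hf_argparser.py.
import dataclasses
from typing import List, Optional


@dataclasses.dataclass
class Args:
    eval_data_file: Optional[str] = None  # any Optional[...] field will do


field = dataclasses.fields(Args)[0]
print(field.type.__origin__)  # typing.Union -- a special form, not a class

try:
    issubclass(field.type.__origin__, List)  # same call shape as hf_argparser.py:72
except TypeError as err:
    print("TypeError:", err)  # issubclass() arg 1 must be a class
```

If that reading is right, running the repo under an older Python, or patching that check so the `issubclass` only runs when `__origin__` is actually a class, should avoid the crash; I haven't verified which the authors intend.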

Timothyxxx commented 2 years ago

Sorry for forgetting to close this issue, thanks again!