XiangLi1999 / PrefixTuning

Prefix-Tuning: Optimizing Continuous Prompts for Generation

How to fully train the model? #10

Open StevenTang1998 opened 3 years ago

StevenTang1998 commented 3 years ago

Hello, I want to fine-tune the prefix along with the whole BART model, so I commented out the freezing code at seq2seq/finetune.py#L95. I am not sure whether that is correct (GPU memory usage increased, so it may be working).
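(For reference, a generic sketch of what that kind of freezing code does, assuming the standard PyTorch pattern of disabling gradients on the pretrained seq2seq model:)

```python
# Generic sketch: disable gradients for every parameter of a module,
# which is what the commented-out freezing line effectively does.
def freeze_params(module):
    for param in module.parameters():
        param.requires_grad = False
```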

However, when I load the model, I find that only the prefix part was saved. So I would like to know how to train, save, and load both the prefix and the BART model.

Thank you very much!

XiangLi1999 commented 3 years ago

Hi, I think commenting out the freezing code at L95 is not enough. There are two models in the PrefixTransformer class: self.seq2seq_model is the BART model, and self.model is the trainable prefix. Only self.model is fed into the optimizer, so even if you unfreeze the parameters in self.seq2seq_model, they will not be updated or saved.

One simple solution is to change the PrefixTuning class's initialization to also hold the BART model: https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/seq2seq/prefixTuning.py#L7. For example, try setting self.base_model = model_gpt2 somewhere in PrefixTuning.__init__() (the naming is misleading: although the variable is called model_gpt2, it actually holds a BART model).
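A minimal sketch of that change, assuming a simplified constructor (the real PrefixTuning.__init__ in seq2seq/prefixTuning.py takes more arguments and inherits from a pretrained-model base class rather than plain nn.Module):

```python
import torch.nn as nn

class PrefixTuning(nn.Module):  # simplified stand-in for the repo's class
    def __init__(self, config, model_gpt2):
        super().__init__()
        # Despite the name, model_gpt2 holds the BART model in this repo.
        # Assigning it to an attribute registers it as a submodule, so its
        # parameters are included in the prefix module's .parameters() and
        # therefore reach the optimizer and the saved checkpoint.
        self.base_model = model_gpt2
        # ... the existing prefix parameters are defined here as before ...
```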

StevenTang1998 commented 3 years ago

Thanks for your answer. I will train for a few epochs and see whether it works.

StevenTang1998 commented 3 years ago

Hi, I commented out the freezing code and added self.base_model = model_gpt2 in the PrefixTuning.__init__() function. However, the result is lower than vanilla prefix-tuning. I wonder whether loading the checkpoint overwrites the previously trained model?

XiangLi1999 commented 3 years ago

I don't think it would, but you could double-check by printing the names of the loaded parameters with model.named_parameters().
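For example (a generic PyTorch sketch; model stands for whatever module the checkpoint is loaded into):

```python
# Print each parameter's name, trainability, and shape; if the BART weights
# were saved and reloaded, names from the base model (e.g. containing
# "base_model" or "seq2seq_model") should appear in this listing.
for name, param in model.named_parameters():
    print(name, param.requires_grad, tuple(param.shape))
```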

Aside from that, the cause could be the learning rate: you probably want to update the prefix parameters and the model parameters with different learning rates.
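One way to do that with a standard PyTorch optimizer is to put the two sets of parameters into separate parameter groups (a generic sketch; "base_model" follows the attribute name suggested above, model is the combined module, and the lr values are only illustrative):

```python
import torch

# Split the parameters so the prefix and the full BART model use different lrs.
prefix_params = [p for n, p in model.named_parameters() if "base_model" not in n]
bart_params   = [p for n, p in model.named_parameters() if "base_model" in n]

optimizer = torch.optim.AdamW([
    {"params": prefix_params, "lr": 5e-5},  # prefix parameters (illustrative lr)
    {"params": bart_params,   "lr": 1e-5},  # full BART parameters (illustrative lr)
])
```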

StevenTang1998 commented 3 years ago

This may not be a learning-rate problem; I have run other experiments.

I feel that just adding self.base_model = model_gpt2 might not be enough, because self.base_model is not used anywhere else in the code, so it might be saved but never used. I wonder whether the original BART, rather than the fine-tuned one, is used in the test phase.
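One way to probe this concern is to inspect the saved checkpoint directly (a sketch with placeholder paths; the "state_dict" layout is an assumption about how the pytorch-lightning trainer saves checkpoints, and "base_model" is the attribute name suggested above):

```python
import torch

# Load the checkpoint saved after training (placeholder path).
ckpt = torch.load("output_dir/checkpoint.ckpt", map_location="cpu")["state_dict"]

# 1) Were the BART weights saved at all?
bart_keys = [k for k in ckpt if "base_model" in k]
print(f"{len(bart_keys)} base-model parameters found in the checkpoint")

# 2) If they are present, the remaining question is whether the evaluation
# script loads these keys or rebuilds BART from the original pretrained
# weights; in the latter case the fine-tuned encoder/decoder would indeed
# be ignored at test time.
for k in bart_keys[:5]:
    print(k, tuple(ckpt[k].shape))
```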