XiangLi1999 / PrefixTuning

Prefix-Tuning: Optimizing Continuous Prompts for Generation

Applying PrefixTuning with T5ForConditionalGeneration model #15

Open yssjtu opened 3 years ago

yssjtu commented 3 years ago

Hello! I'm trying to use PrefixTuning with the T5 model. After reading the source code in seq2seq, I gather that, generally speaking, the prefix is injected into the BART model through the `past_key_values` parameter.
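Concretely, my understanding is that the prefix amounts to a set of extra key/value tensors prepended to every attention layer, roughly like this (a rough sketch only; the dimension names are mine, and the exact nesting of `past_key_values` differs across transformers versions):

```python
import torch

# Toy dimensions -- these names are mine, not from the repo.
batch_size, num_layers, num_heads, head_dim, prefix_len = 4, 12, 16, 64, 10

# One (key, value) pair per layer, each of shape
# (batch_size, num_heads, prefix_len, head_dim). In prefix-tuning these
# tensors come from a small trainable network rather than from real
# tokens, so every attention layer can attend to the prefix positions.
past_key_values = tuple(
    (torch.randn(batch_size, num_heads, prefix_len, head_dim),
     torch.randn(batch_size, num_heads, prefix_len, head_dim))
    for _ in range(num_layers)
)

# Then, schematically:
# outputs = model(input_ids=..., decoder_input_ids=...,
#                 past_key_values=past_key_values)
```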

But in T5, when `past_key_values` is provided together with `decoder_input_ids` (the ground truth during training), the forward() function only uses the last token of `decoder_input_ids`:

https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/transformers/src/transformers/modeling_t5.py#L1201-L1206

whereas BART uses the full `decoder_input_ids`:

https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/transformers/src/transformers/modeling_bart.py#L1448-L1465
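For reference, the T5-side check I'm referring to looks roughly like this (paraphrased from that file, so treat it as approximate rather than an exact copy):

```python
# Inside T5ForConditionalGeneration.forward (paraphrased):
if past_key_values is not None:
    assert labels is None, "Decoder should not use cached key value states when training."
    if decoder_input_ids is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]
```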

However, I don't see any code handling this difference in the seq2seq folder. The only T5-specific code I can find deals with input ids or with freezing embeddings.

Is PrefixTuning compatible with the T5 model? If not, could you give some advice on how to make it work? Thanks a lot!

XiangLi1999 commented 3 years ago

Prefix-tuning (the general method) is compatible with T5, but this code base is not. Moreover, you would need to modify more than `past_key_values`: you will probably have to delve into the T5 source code and fix the issue you mention above (i.e., when `past_key_values` is present, only the last input token is used). Note that this truncation is the default behavior, since `past_key_values` is designed for caching at generation time.
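For example, one minimal (untested) way to do this is to gate the truncation so it only fires during incremental decoding; the `use_prefix` flag below is an invented name that you would have to thread through from your prefix-tuning wrapper:

```python
# Sketch of a possible patch inside T5ForConditionalGeneration.forward.
# `use_prefix` is hypothetical: True when past_key_values carries a
# trained prefix rather than a generation-time cache.
if past_key_values is not None and not use_prefix:
    # Original caching path: during generation only the newest token is
    # fed, because all earlier positions are already cached.
    if decoder_input_ids is not None:
        decoder_input_ids = decoder_input_ids[:, -1:]
```

You would also need to extend the decoder attention mask so that the real tokens are allowed to attend to the prefix positions.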

yssjtu commented 3 years ago

Thanks for the reply!
When I found this code in finetune.py, I thought the T5 model was also supported:

https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/seq2seq/finetune.py#L128-L132

Actually, the issue I mentioned only exists in the `T5ForConditionalGeneration` module. `T5Model` has no restriction on passing `past_key_values` together with `decoder_input_ids`. But using `T5Model` rather than `T5ForConditionalGeneration` means I would need to add an LM head manually and implement the generate() method myself (possibly including beam search). Another option is to delete that assert in the T5 source code: the inner modules like `T5Block` and `T5Attention` seem able to handle `past_key_values` and `decoder_input_ids` together properly (though I'm not 100% sure).
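For the first option, the wrapper I have in mind would look something like this (an untested sketch; the class name is mine, and argument names may differ across transformers versions):

```python
import torch.nn as nn
from transformers import T5Model

class T5WithLMHead(nn.Module):
    """Untested sketch: T5Model plus a manual LM head, so that
    past_key_values and decoder_input_ids can be passed together."""

    def __init__(self, model_name="t5-base"):
        super().__init__()
        self.t5 = T5Model.from_pretrained(model_name)
        self.lm_head = nn.Linear(
            self.t5.config.d_model, self.t5.config.vocab_size, bias=False
        )
        # T5 ties the LM head weights to the shared embedding matrix.
        self.lm_head.weight = self.t5.shared.weight

    def forward(self, input_ids, decoder_input_ids,
                past_key_values=None, labels=None):
        sequence_output = self.t5(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
        )[0]
        # T5ForConditionalGeneration rescales the decoder output before
        # the (tied) LM head, so do the same here.
        logits = self.lm_head(sequence_output * self.t5.config.d_model ** -0.5)
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss(ignore_index=-100)(
                logits.view(-1, logits.size(-1)), labels.view(-1)
            )
        return loss, logits
```

Of course, generate() (and beam search) would still have to be reimplemented on top of this, which is the part I'd rather avoid.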

fade-color commented 2 years ago


Hello! I have the same problem as you. How did you modify it in the end?