Open yssjtu opened 3 years ago
prefix-tuning (the general method) is compatible with T5, but this code base is not. However, you would need to modify more than past_key_values: you would probably have to delve into the T5 source code and fix the issue you mentioned above (i.e., if past_key_values is present, the model only takes the last input token). Note that this tends to be the default behavior, since past_key_values is designed for caching at generation time.
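For anyone landing here, the mechanism under discussion can be sketched in a few lines: prefix-tuning supplies learned key/value vectors that are prepended to each attention layer's keys and values, which is why past_key_values is the natural injection point. A minimal single-head sketch in NumPy (all names here are illustrative, not from this repo):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attention_with_prefix(q, k, v, prefix_k=None, prefix_v=None):
    """Single-head attention; the learned prefix K/V are prepended as
    extra key/value positions, which is exactly what passing them in as
    past_key_values achieves in the real model."""
    if prefix_k is not None:
        k = np.concatenate([prefix_k, k], axis=0)  # (prefix_len + seq, d)
        v = np.concatenate([prefix_v, v], axis=0)
    scores = q @ k.T / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

seq, prefix_len, d = 4, 2, 8
rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((seq, d)) for _ in range(3))
pk, pv = (rng.standard_normal((prefix_len, d)) for _ in range(2))
out = attention_with_prefix(q, k, v, pk, pv)
print(out.shape)  # (4, 8): output length unchanged, keys/values extended
```

Note that only the prefix tensors would be trained; the rest of the model stays frozen.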
Thanks for the reply!
When I found these lines in finetune.py, I thought the T5 model was also supported.
https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/seq2seq/finetune.py#L128-L132
Actually, the issue I mentioned only exists in the T5ForConditionalGeneration module. As for T5Model, there is no restriction on using past_key_values together with decoder_input_ids. But using T5Model rather than T5ForConditionalGeneration means I would need to manually add an LM head and implement the generate() method (possibly including beam search).
Another solution might be to delete that "assert" in the T5 source code. It seems that the inner modules such as T5Block and T5Attention can handle both past_key_values and decoder_input_ids properly (though I'm not 100% sure).
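One thing to watch if that guard is removed: with full decoder_input_ids plus a prefix cache, the decoder's causal mask must still let every decoder position attend to all prefix slots while remaining causal over the real tokens. A toy illustration of the mask that training would need (a hypothetical helper, not code from T5):

```python
import numpy as np

def prefix_causal_mask(prefix_len, seq_len):
    """Attention mask for teacher-forced training with a prefix cache:
    every decoder position may attend to all prefix slots plus earlier
    (causal) real positions. 1 = attend, 0 = blocked.
    Shape: (seq_len, prefix_len + seq_len)."""
    prefix_part = np.ones((seq_len, prefix_len), dtype=int)   # prefix always visible
    causal_part = np.tril(np.ones((seq_len, seq_len), dtype=int))  # lower-triangular
    return np.concatenate([prefix_part, causal_part], axis=1)

m = prefix_causal_mask(2, 3)
print(m)
# [[1 1 1 0 0]
#  [1 1 1 1 0]
#  [1 1 1 1 1]]
```

If the inner attention modules already build their mask from the combined key length, this may be handled for you, but it is worth verifying before deleting the check.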
Hello! I have the same problem. How did you modify it in the end?
Hello! I'm trying to use PrefixTuning with the T5 model. After reading the source code in seq2seq, I gather that, generally speaking, the prefix is added to the BART model through the past_key_values parameter.
But in T5, when past_key_values is provided together with decoder_input_ids (the ground truth during training), the forward() function only uses the last token of decoder_input_ids
https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/transformers/src/transformers/modeling_t5.py#L1201-L1206
while BART uses the full decoder_input_ids.
https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/transformers/src/transformers/modeling_bart.py#L1448-L1465
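The difference between the two linked code paths can be reduced to a toy contrast (illustrative functions, not the actual Hugging Face implementations):

```python
def t5_style_inputs(ids, past=None):
    """Toy model of the T5 convention linked above: when a cache is
    passed, only the last decoder token of each row is fed forward."""
    return [row[-1:] for row in ids] if past is not None else ids

def bart_style_inputs(ids, past=None):
    """Toy model of the BART path: the full decoder_input_ids are used
    regardless of whether a cache is present."""
    return ids

ids = [[5, 6, 7]]
print(t5_style_inputs(ids, past="cache"))    # [[7]]
print(bart_style_inputs(ids, past="cache"))  # [[5, 6, 7]]
```

This is fine for generation (one new token per step) but, as described above, it silently discards the teacher-forcing targets when training with a prefix passed through past_key_values.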
However, I don't see any code handling this difference in the seq2seq folder. The only T5-specific code I can find handles input IDs or freezes embeddings.
Is PrefixTuning compatible with the T5 model? If not, could you give some advice on making it so? Thanks a lot!