XiangLi1999 / PrefixTuning

Prefix-Tuning: Optimizing Continuous Prompts for Generation

Understanding the Seq2Seq Encoder-Decoder Prefix Implementation #31

Open rajaswa opened 2 years ago

rajaswa commented 2 years ago

Hi @XiangLi1999, thank you for open-sourcing this amazing work! I have been trying to understand your seq2seq implementation: https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/seq2seq/prefixTuning.py#L7

I was wondering if you could help me with a few doubts I have about it:

  1. What does the mode_para attribute mean? https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/seq2seq/prefixTuning.py#L94
  2. What is the difference between the multiple methods for getting the prompt prefixes? Which ones are used in the paper?
  3. How is the prefix on the encoder side implemented? I saw that the use_encoder_prefix attribute is only used in one of the prompt methods: def get_prompt_p5 https://github.com/XiangLi1999/PrefixTuning/blob/6519d30e69b15a180f23e2cd41b766d3f62b8e82/seq2seq/prefixTuning.py#L436
  4. What does the use_cross_prefix attribute do?
  5. How is the encoder-side prefix attended to by the decoder, given that we are using past_key_values to feed the prefixes?
  6. How are the past_key_values fed to the model? As I understand it, they should contain the key-value pairs for all the preceding tokens on the decoding side. How is the encoder-side prefix included in the past_key_values? (A rough sketch of what I mean is below this list.) https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartForConditionalGeneration.forward.past_key_values
  7. Where is the inference code for the seq2seq model implemented (if one needs to deploy/serve it)? In that case we would need token-by-token decoding, right?
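
For context on questions 5 and 6: here is roughly how I picture a prefix being passed through past_key_values for a decoder-only model like GPT-2 (a minimal sketch with random tensors standing in for the learned prefix key/values; not code from this repo):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

model = GPT2LMHeadModel.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

bsz, prefix_len = 1, 5
n_layers, n_heads = model.config.n_layer, model.config.n_head
head_dim = model.config.n_embd // n_heads

# One (key, value) pair per layer, each [bsz, n_heads, prefix_len, head_dim];
# random here, but learned parameters in prefix-tuning.
past_key_values = tuple(
    (torch.randn(bsz, n_heads, prefix_len, head_dim),
     torch.randn(bsz, n_heads, prefix_len, head_dim))
    for _ in range(n_layers)
)

inputs = tokenizer("Hello world", return_tensors="pt")
# The attention mask has to cover the prefix positions as well.
attention_mask = torch.cat(
    [torch.ones(bsz, prefix_len, dtype=torch.long), inputs["attention_mask"]], dim=1
)
out = model(input_ids=inputs["input_ids"],
            attention_mask=attention_mask,
            past_key_values=past_key_values)
print(out.logits.shape)  # [bsz, input_len, vocab_size]

My questions 5 and 6 are about whether and how the encoder-side prefix fits into this same mechanism for BART/T5.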
JaniceXiong commented 2 years ago

3. How is the prefix on the encoder side implemented? I saw that the use_encoder_prefix attribute is only used in one of the prompt methods: def get_prompt_p5

I have the same question🙋‍. Any explanation will be helpful :)

Timothyxxx commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

yahoo17 commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

Could you please give a link to your implementation?

Timothyxxx commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

Could you please give a link to your implementation?

Sure, we will release it with our work at the end of this year.

rajaswa commented 2 years ago

@Timothyxxx if I understand correctly, get_prompt_p5 produces past_key_values of length = prefix length, whereas the model expects past_key_values of length = sequence length (see the Hugging Face documentation).

Does this not throw an error for you? I am trying to replicate get_prompt_p5, and am facing this error:

  File "/home/local/anaconda3/envs/test_env/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 506, in forward
    scores += position_bias
RuntimeError: The size of tensor a (25) must match the size of tensor b (512) at non-singleton dimension 3

Here my prefix length is 25 and my input sequence length is 512.

Timothyxxx commented 2 years ago

@Timothyxxx if I understand correctly, get_prompt_p5 produces past_key_values of length = prefix length, whereas the model expects past_key_values of length = sequence length (see the Hugging Face documentation).

Does this not throw an error for you? I am trying to replicate get_prompt_p5, and am facing this error:

  File "/home/local/anaconda3/envs/test_env/lib/python3.8/site-packages/transformers/models/t5/modeling_t5.py", line 506, in forward
    scores += position_bias
RuntimeError: The size of tensor a (25) must match the size of tensor b (512) at non-singleton dimension 3

Here my prefix length is 25 and my input sequence length is 512.

Indeed, I faced exactly the same problem when I started. It is caused by the shape checks in Hugging Face's T5 implementation. My solution was not to use past_key_values to pass the prompt; instead, I made my own copy of modeling_t5.py and added a separate argument (mine is called past_prefix_prompt, for example) that does this job.
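
To illustrate the idea (this is not our actual code; attention_with_prefix, its arguments, and the shapes below are just a toy sketch): inside each attention layer the learned prefix key/value states are concatenated in front of the real ones, and the relative position bias is padded so the shapes line up, which is exactly the 25-vs-512 mismatch above.

import torch
import torch.nn.functional as F

def attention_with_prefix(query, key, value, prefix_key, prefix_value, position_bias):
    # query/key/value:          [bsz, n_heads, seq_len, head_dim]
    # prefix_key/prefix_value:  [bsz, n_heads, prefix_len, head_dim] (learned)
    # position_bias:            [bsz, n_heads, seq_len, seq_len]
    prefix_len = prefix_key.size(2)
    key = torch.cat([prefix_key, key], dim=2)
    value = torch.cat([prefix_value, value], dim=2)
    # Pad the bias with zeros on the left so it also covers the prefix slots.
    position_bias = F.pad(position_bias, (prefix_len, 0))
    scores = torch.matmul(query, key.transpose(-1, -2)) + position_bias
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)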

Timothyxxx commented 2 years ago

The thing is, we can use past_key_values to pass the prefix prompt, but not always, since that is not really what it is for. Personally, I really recommend adding a separate parameter and letting it do this job on its own.

JaniceXiong commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

Hi! Did you implement the relevant-word initialization for the seq2seq model (BART), as in Figure 5 of the paper? I described the problem in issue 32 but didn't get any reply from the authors.

Timothyxxx commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

Hi! Did you implement the relevant-word initialization for the seq2seq model (BART), as in Figure 5 of the paper? I described the problem in issue 32 but didn't get any reply from the authors.

No, I haven't tried that initialization. But I think you may just need to initialize the prompt input with embeddings taken from BART's encoding.
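
Something like this minimal sketch is what I have in mind (my own illustration, untested; prefix_len and init_text are placeholders):

import torch
import torch.nn as nn
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

prefix_len = 200
init_text = "summarize"  # any task-relevant word(s)

with torch.no_grad():
    ids = tokenizer(init_text, add_prefix_space=True, return_tensors="pt")["input_ids"][0]
    word_emb = model.get_input_embeddings()(ids)             # [n_tokens, d_model]
    # Repeat the real-word embeddings to fill the whole prefix length.
    reps = (prefix_len + word_emb.size(0) - 1) // word_emb.size(0)
    init = word_emb.repeat(reps, 1)[:prefix_len]              # [prefix_len, d_model]

# Trainable prefix embedding table, initialized from real-word embeddings
# instead of randomly; it would then go through the usual reparameterization
# to produce the per-layer key/value prefixes.
prefix_embedding = nn.Parameter(init.clone())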

JaniceXiong commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

Hi! Did you implement the relevant-word initialization for the seq2seq model (BART), as in Figure 5 of the paper? I described the problem in issue 32 but didn't get any reply from the authors.

No, I haven't tried that initialization. But I think you may just need to initialize the prompt input with embeddings taken from BART's encoding.

Thanks for your kind reply! Do you mean that I need to use BART to encode sequences like "summarize" or "table-to-text:" and use the resulting "self" past_key_values to initialize the prompt? https://github.com/XiangLi1999/PrefixTuning/blob/48dbaff350752d130501e1ee63a8e0c20c62868f/gpt2/train_control.py#L131-L144

But I still have a question. If I initialize the prompt like this, the prefix length (prelen) is only 1 or 6, which is much shorter than the 200 used in the random-initialization setting. I also get very low performance and the training loss is unstable. I don't know what the problem is:

  • the prelen is too short, so if I keep the LM frozen, there are fewer trainable parameters.
  • the "self" past_key_values only act on the decoder's masked self-attention, whereas random initialization provides "self", "encoder", and "encoder_decoder" past_key_values that act on different parts of the model.
  • my implementation is wrong.

shallow_word = "summarize"
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
init_shallow_word = tokenizer([shallow_word], add_prefix_space=True)['input_ids']
init_shallow_word = torch.LongTensor(init_shallow_word)

model = BartForConditionalGeneration.from_pretrained(
                "facebook/bart-large",
                from_tf=False,
                config=config,
                cache_dir=None
            )
with torch.no_grad():
    output = model(init_shallow_word, return_dict=True, use_cache=True)

output = output.past_key_values
init_val = []
for item in output:
    init_val.append(item[0].unsqueeze(0)) # key, [1, 1, num_heads, sequence_length, dim_head]
    init_val.append(item[1].unsqueeze(0)) # val
init_val = torch.cat(init_val, dim=0)

self.control_trans = nn.Parameter(init_val) # trainable
temp = self.control_trans.expand(-1, bsz, -1, -1, -1)
past_key_values = temp.split(2)

Ant0082 commented 2 years ago

I think she may be too busy to clean up her code... We implemented her idea based on her get_prompt_p5 method and it works properly, so I think it is the final and intended version.

Hi! Did you implement the relevant-word initialization for the seq2seq model (BART), as in Figure 5 of the paper? I described the problem in issue 32 but didn't get any reply from the authors.

No, I haven't tried that initialization. But I think you may just need to initialize the prompt input with embeddings taken from BART's encoding.

Thanks for your kind reply! Do you mean that I need to use BART to encode sequences like "summarize" or "table-to-text:" and use the resulting "self" past_key_values to initialize the prompt?

https://github.com/XiangLi1999/PrefixTuning/blob/48dbaff350752d130501e1ee63a8e0c20c62868f/gpt2/train_control.py#L131-L144

But I still have a question. If I initialize the prompt like this, the prefix length (prelen) is only 1 or 6, which is much shorter than the 200 used in the random-initialization setting. I also get very low performance and the training loss is unstable. I don't know what the problem is:

  • the prelen is too short, so if I keep the LM frozen, there are fewer trainable parameters.
  • the "self" past_key_values only act on the decoder's masked self-attention, whereas random initialization provides "self", "encoder", and "encoder_decoder" past_key_values that act on different parts of the model.
  • my implementation is wrong.
import torch
import torch.nn as nn
from transformers import BartConfig, BartTokenizer, BartForConditionalGeneration

shallow_word = "summarize"
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
init_shallow_word = tokenizer([shallow_word], add_prefix_space=True)['input_ids']
init_shallow_word = torch.LongTensor(init_shallow_word)

config = BartConfig.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained(
                "facebook/bart-large",
                from_tf=False,
                config=config,
                cache_dir=None
            )
# One forward pass with use_cache=True so the model returns past_key_values.
with torch.no_grad():
    output = model(init_shallow_word, return_dict=True, use_cache=True)

# For BART, each past_key_values entry holds (self-attn key, self-attn value,
# cross-attn key, cross-attn value); item[0] and item[1] are the decoder
# self-attention key/value states.
output = output.past_key_values
init_val = []
for item in output:
    init_val.append(item[0].unsqueeze(0)) # key, [1, 1, num_heads, sequence_length, dim_head]
    init_val.append(item[1].unsqueeze(0)) # val
init_val = torch.cat(init_val, dim=0)

# This runs inside the prefix module (hence `self`); bsz is the batch size.
self.control_trans = nn.Parameter(init_val) # trainable
temp = self.control_trans.expand(-1, bsz, -1, -1, -1)
past_key_values = temp.split(2)  # one (key, value) pair per decoder layer
  1. If prelen is too small, the results will be poor: the smaller prelen is, the fewer parameters there are to train. If you want better performance, you need to switch to a larger language model.
  2. From reading the source code: "self" is used in the encoder's self-attention phase, "encoder" (which I think should really be named "decoder") is used in the decoder's self-attention phase, and "encoder_decoder" is used in the decoder's cross-attention phase.
  3. It seems that your code implementation is fine. The paper says that training without reparameterization is very unstable and that P_theta cannot be trained directly, even with the "shallow word" initialization approach (see the sketch below).
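
For reference, the reparameterization described in the paper (and, as far as I can tell, used in get_prompt_p5) amounts to a small embedding table P'_theta plus an MLP that produces the per-layer key/value prefixes. A rough sketch with made-up sizes:

import torch
import torch.nn as nn

prefix_len, d_model, mid_dim = 200, 1024, 800
n_layers, n_heads = 12, 16
head_dim = d_model // n_heads

prefix_tokens = nn.Embedding(prefix_len, d_model)             # P'_theta
mlp = nn.Sequential(                                           # MLP_theta
    nn.Linear(d_model, mid_dim),
    nn.Tanh(),
    nn.Linear(mid_dim, n_layers * 2 * d_model),
)

bsz = 4
idx = torch.arange(prefix_len).unsqueeze(0).expand(bsz, -1)    # [bsz, prefix_len]
prefix = mlp(prefix_tokens(idx))                               # [bsz, prefix_len, n_layers*2*d_model]
prefix = prefix.view(bsz, prefix_len, n_layers * 2, n_heads, head_dim)
prefix = prefix.permute(2, 0, 3, 1, 4)                         # [n_layers*2, bsz, n_heads, prefix_len, head_dim]
past_key_values = prefix.split(2)                              # one (key, value) pair per layer

Training this small embedding table plus MLP (with the LM frozen), instead of the key/value prefixes directly, is what the paper reports to be much more stable.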