THUDM / P-tuning-v2

An optimized deep prompt tuning strategy comparable to fine-tuning across scales and tasks
Apache License 2.0

Questions about deep prompt per layer #41

Closed eunjiinkim closed 2 years ago

eunjiinkim commented 2 years ago

Hi, I have a question about deep prompts. I understand that deep prompts are implemented through past_key_values in the model. How can I see the actual prompt weights per layer? I mean, the shape of the prompt is (prefix_len, config.num_hidden_layers * 2 * config.hidden_size) without the transformation (trans), and the shape of past_key_values fed to the model is [2, batch_size, n_head, prefix_len, n_embd] for each layer. I believe the first '2' corresponds to the key and value of the attention mechanism. What I want to obtain is a [prefix_len, config.hidden_size] vector, just like the embedding vector in prompt tuning v1.

Do you have any idea for this?

Thanks : )

Xiao9905 commented 2 years ago

@eunjiinkim Hi,

Thanks for your interest in P-Tuning v2. I think the code you are looking for is here, and it is transformed into [2, batch_size, n_head, prefix_len, n_embd] here.
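
For reference, roughly what happens there (a minimal sketch with illustrative sizes, not the exact repo code):

    import torch

    batch_size, prefix_len = 8, 20
    num_layers, n_head, n_embd = 12, 12, 64
    hidden_size = n_head * n_embd

    # the trainable prefix: one long vector per prefix position
    prefix_encoder = torch.nn.Embedding(prefix_len, num_layers * 2 * hidden_size)

    prefix_tokens = torch.arange(prefix_len).unsqueeze(0).expand(batch_size, -1)
    past_key_values = prefix_encoder(prefix_tokens)               # (batch, prefix_len, num_layers*2*hidden_size)
    past_key_values = past_key_values.view(
        batch_size, prefix_len, num_layers * 2, n_head, n_embd)   # split the hidden dim into heads
    past_key_values = past_key_values.permute([2, 0, 3, 1, 4])    # (num_layers*2, batch, n_head, prefix_len, n_embd)
    past_key_values = past_key_values.split(2)                    # per layer: (2, batch, n_head, prefix_len, n_embd)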

Does this code answer your question? Feel free to ask if you have any other questions.

Xiao9905 commented 2 years ago

I think the attention mechanism is interesting with regard to P-tuning v2's behavior. I notice that for many attention heads, the prefix tokens actually do not change anything in the attention computation. For the other heads, it is a rather valuable research problem to see what the prompt tokens are actually doing in P-tuning v2.
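
If you want to inspect this yourself, one rough way (assuming a Hugging Face model run with output_attentions=True, and that the prefix occupies the first prefix_len key positions; model, inputs, and prefix_len are placeholder names):

    outputs = model(**inputs, output_attentions=True)
    # outputs.attentions: one (batch, n_head, query_len, key_len) tensor per layer
    for layer_idx, attn in enumerate(outputs.attentions):
        prefix_mass = attn[..., :prefix_len].sum(dim=-1)    # attention mass placed on the prefix tokens
        print(layer_idx, prefix_mass.mean(dim=(0, 2)))       # averaged per head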

eunjiinkim commented 2 years ago

@Xiao9905 Thanks for your comment! Actually, I'm trying to investigate the prefix tokens in terms of not only attention weights but also their relationship with the model's own embeddings. So I'd like to get vectors of shape [prefix_len, hidden_size] in order to calculate their similarity with the model parameters.

As you said, past_key_values are split into [2, batch_size, n_head, prefix_len, n_embd]. So, do you mean that I can only see past_key_values separated into key and value? Here, hidden_size is n_head*n_embd, but the '2' is the key issue in my case because the two vectors have their own weights. Should I just look at the key and value vectors separately, or can I multiply the key and value vectors together?
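
Just to make it concrete, what I eventually want to compute is something like this (prompt_vectors is the hypothetical [prefix_len, hidden_size] tensor, word_embeddings the model's input embedding matrix):

    import torch.nn.functional as F

    # prompt_vectors: [prefix_len, hidden_size], word_embeddings: [vocab_size, hidden_size]
    sim = F.cosine_similarity(
        prompt_vectors.unsqueeze(1),     # [prefix_len, 1, hidden_size]
        word_embeddings.unsqueeze(0),    # [1, vocab_size, hidden_size]
        dim=-1,
    )                                    # [prefix_len, vocab_size]
    nearest_tokens = sim.argmax(dim=-1)  # closest vocabulary token for each prefix position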

Just let me know if my question is not clear :)

Xiao9905 commented 2 years ago

@eunjiinkim Hi,

Now I understand your question: you want the prompt embedding itself, rather than its corresponding key and value, right?

In p-tuning (v1), the key and value are derived by transforming the prompt embeddings with the attention heads' linear projection matrices. However, for optimization reasons, in p-tuning v2 the keys and values are separate parameters from the very beginning; if they are not, training still converges, but in my experience it may require another set of hyper-parameters.
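
To make the difference concrete, a rough single-layer sketch (illustrative names, not code from the repo):

    import torch

    prefix_len, hidden_size, num_layers = 20, 768, 12

    # p-tuning v1 style: one prompt embedding; keys and values come from the layer's own projections
    prompt = torch.nn.Parameter(torch.randn(prefix_len, hidden_size))
    # key_layer = self.key(prompt); value_layer = self.value(prompt)   # inside the attention module

    # p-tuning v2 style: per-layer keys and values are free parameters from the very beginning,
    # so no single [prefix_len, hidden_size] embedding exists after training
    prefix_kv = torch.nn.Parameter(torch.randn(prefix_len, num_layers * 2 * hidden_size))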

The existing p-tuning v2 implementation doesn't support this. To achieve your goal, it might take some effort to modify the Hugging Face Transformers source code (e.g., for roberta, see this line) so that only a [1, batch_size, n_head, prefix_len, n_embd] tensor is passed as past_key_value and then projected into the corresponding keys and values before concatenation, for example (untested, so I am not sure this code really works):

key_layer = torch.cat([self.key(past_key_value[0]), key_layer], dim=2)      # project the prefix into keys, prepend along the sequence dim
value_layer = torch.cat([self.value(past_key_value[0]), value_layer], dim=2)  # project the prefix into values, prepend along the sequence dim

and in that case, after training converges, I think the [prefix_len, hidden_size] prefix embedding is exactly what you need.

eunjiinkim commented 2 years ago

@Xiao9905 Hi, thanks for your effort. It really helped me.

I've changed some code in the soft prompt module and in modeling_gpt2.py.

For the soft prompt, I removed the *2 (which was the problem for me):

    self.new_size = self.config.num_hidden_layers * self.config.hidden_size

and removed the split(2) on the past_key_values input:

    past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).to(self.device)
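
So the prompt side now looks roughly like this (a simplified sketch of my change, with illustrative names):

    # simplified sketch of the modified get_prompt (illustrative names)
    past_key_values = self.prefix_encoder(prefix_tokens)    # (batch, prefix_len, num_layers * hidden_size), no *2
    past_key_values = past_key_values.view(
        batch_size, self.pre_seq_len, self.n_layer, self.n_head, self.n_embd)
    past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).to(self.device)
    # one (batch, n_head, prefix_len, n_embd) tensor per layer, no key/value split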

For gpt2 modeling, I modified past_key and past_value just like you've suggested in Roberta.

        if layer_past is not None:
            # use the single prompt tensor as both past key and past value
            past_key, past_value = layer_past, layer_past
            # prepend the prompt along the sequence-length dimension
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)

I think it works and gives quite similar results to the original-size p-tuning v2 in my experiment. Thank you SO much. Now I can do what I want. :)

Xiao9905 commented 2 years ago

That's cool. Looking forward to your research!