Closed nuaajeff closed 2 years ago
This is how I implement the function incorporate_prompt:
def incorporate_prompt(self, x, test=False):
    # Combine prompt embeddings with image-patch embeddings.
    B = x.shape[0]
    x = self.embeddings(x)  # (batch_size, 1 + n_patches, hidden_dim)
    if not test:
        # Insert each prompt set after the CLS token, before the image patches.
        x1 = torch.cat((
            x[:, :1, :],
            self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
            x[:, 1:, :]
        ), dim=1)
        x2 = torch.cat((
            x[:, :1, :],
            self.prompt_dropout(self.prompt_proj(self.prompt_embeddings_cw).expand(B, -1, -1)),
            x[:, 1:, :]
        ), dim=1)
        x = [x1, x2]
    else:
        # At test time, concatenate both prompt sets into one sequence.
        x = torch.cat((
            x[:, :1, :],
            self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
            self.prompt_dropout(self.prompt_proj(self.prompt_embeddings_cw).expand(B, -1, -1)),
            x[:, 1:, :]
        ), dim=1)
    # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)
    return x
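As a quick sanity check of the sequence lengths this produces (the concrete numbers below, 6 prompt tokens per set and a 14x14 patch grid, are hypothetical illustrations, not values from the code above):

```python
# Hypothetical numbers: 6 prompt tokens per set, 14 * 14 = 196 image patches.
n_prompt, n_patches = 6, 14 * 14

# Training: each branch is [CLS] + one prompt set + patches.
train_seq_len = 1 + n_prompt + n_patches      # 203 tokens

# Testing: [CLS] + both prompt sets + patches.
test_seq_len = 1 + 2 * n_prompt + n_patches   # 209 tokens
```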
I implement forward_deep_prompt in a similar way:
def forward_deep_prompt(self, embedding_output, test=False):
    attn_weights = []
    hidden_states = None
    weights = None
    if not test:
        B = embedding_output[0].shape[0]
    else:
        B = embedding_output.shape[0]
    num_layers = self.vit_config.transformer["num_layers"]
    if not test:
        encoded = []
        # First branch: prompts from self.deep_prompt_embeddings.
        for i in range(num_layers):
            if i == 0:
                hidden_states, weights = self.encoder.layer[i](embedding_output[0])
            else:
                if i <= self.deep_prompt_embeddings.shape[0]:
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings[i - 1]).expand(B, -1, -1))
                    # Replace the previous layer's prompt tokens with this layer's.
                    hidden_states = torch.cat((
                        hidden_states[:, :1, :],
                        deep_prompt_emb,
                        hidden_states[:, (1 + self.num_tokens):, :]
                    ), dim=1)
                hidden_states, weights = self.encoder.layer[i](hidden_states)
            if self.encoder.vis:
                attn_weights.append(weights)
        enc = self.encoder.encoder_norm(hidden_states)
        encoded.append(enc)
        # Second branch: prompts from self.deep_prompt_embeddings_cw.
        for i in range(num_layers):
            if i == 0:
                hidden_states, weights = self.encoder.layer[i](embedding_output[1])
            else:
                if i <= self.deep_prompt_embeddings.shape[0]:
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings_cw[i - 1]).expand(B, -1, -1))
                    hidden_states = torch.cat((
                        hidden_states[:, :1, :],
                        deep_prompt_emb,
                        hidden_states[:, (1 + self.num_tokens):, :]
                    ), dim=1)
                hidden_states, weights = self.encoder.layer[i](hidden_states)
            if self.encoder.vis:
                attn_weights.append(weights)
        enc = self.encoder.encoder_norm(hidden_states)
        encoded.append(enc)
    else:
        # Test: both prompt sets ride in the same sequence.
        for i in range(num_layers):
            if i == 0:
                hidden_states, weights = self.encoder.layer[i](embedding_output)
            else:
                if i <= self.deep_prompt_embeddings.shape[0]:
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings[i - 1]).expand(B, -1, -1))
                    deep_prompt_emb_cw = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings_cw[i - 1]).expand(B, -1, -1))
                    hidden_states = torch.cat((
                        hidden_states[:, :1, :],
                        deep_prompt_emb,
                        deep_prompt_emb_cw,
                        hidden_states[:, (1 + self.num_tokens * 2):, :]
                    ), dim=1)
                hidden_states, weights = self.encoder.layer[i](hidden_states)
            if self.encoder.vis:
                attn_weights.append(weights)
        encoded = self.encoder.encoder_norm(hidden_states)
    return encoded, attn_weights
Two separate MLPs are used as classifiers to calculate outputs and losses, and at test time I just use one of the two classifiers to perform inference for the ensemble of prompts.
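(As described above, only one of the two heads is used at test time. For comparison, a common alternative is to run both heads and average their probabilities; the sketch below is a hypothetical, framework-free illustration of that idea, not code from this repo, and all names in it are made up.)

```python
import math

def softmax(logits):
    """Numerically stable softmax over a plain list of floats."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def ensemble_predict(logits_a, logits_b):
    """Average the two heads' class probabilities, then take the argmax."""
    pa, pb = softmax(logits_a), softmax(logits_b)
    avg = [(x + y) / 2 for x, y in zip(pa, pb)]
    return max(range(len(avg)), key=avg.__getitem__)

# Head A mildly prefers class 0, head B strongly prefers class 1:
ensemble_predict([2.0, 1.0], [0.0, 3.0])  # -> 1
```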
Hi, thanks for your interest in our work!
I think you might have misunderstood what we mean by "prompt ensembling" here. The different sets of prompts are still kept separate during inference rather than combined together as you described. But because the pre-trained backbone is shared, this can be implemented more efficiently in a "batch" manner, where each sample is the concatenation of one set of prompts and the original image. A similar idea is also explored in section 6 of the NLP paper "The Power of Scale for Parameter-Efficient Prompt Tuning".
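A minimal, framework-free sketch of that "batch" manner (plain Python lists stand in for token tensors; every name here is hypothetical):

```python
def build_prompt_ensemble_batch(cls_token, patch_tokens, prompt_sets):
    """One row per trained prompt set: [CLS] + prompts + image patches.

    The image tokens are replicated for each prompt set, so one forward
    pass through the shared backbone scores every set, while the prompt
    sets themselves never mix within a row.
    """
    return [[cls_token] + list(prompts) + list(patch_tokens)
            for prompts in prompt_sets]

rows = build_prompt_ensemble_batch(
    "CLS", ["img1", "img2"],
    [["pA1", "pA2"], ["pB1", "pB2"]])
# rows[0] == ["CLS", "pA1", "pA2", "img1", "img2"]
# rows[1] == ["CLS", "pB1", "pB2", "img1", "img2"]
```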
Feel free to let us know if you have more questions.
@Tsingularity Thanks for your explanation. But if the different sets of prompts are kept separate, why is this method more efficient? I think the total computational cost of inference is the same as with the traditional full fine-tuning method, except that there is no need to store another backbone.
This is a good question! I think it actually depends on how you define "efficient", especially in practical scenarios. For example, suppose you fully fine-tune 5 times and then try to ensemble the resulting models: during inference, it is pretty hard to run all 5 models in parallel. But for prompt tuning, we can finish within a single forward pass by making a special "batch" for each example, where the data is replicated but the prompts vary. I think section 6 and figure 2 of "The Power of Scale for Parameter-Efficient Prompt Tuning" explain this better than I can.
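Once that per-example batch has been run, the per-prompt-set predictions still have to be combined; the NLP paper cited above uses simple majority voting in its section 6, which can be sketched in a few lines (stdlib only; names are hypothetical):

```python
from collections import Counter

def majority_vote(per_prompt_predictions):
    """Ensemble by majority vote over the labels predicted for the same
    example with each trained prompt set."""
    return Counter(per_prompt_predictions).most_common(1)[0][0]

# Five prompt sets vote on one example's label:
majority_vote([3, 3, 7, 3, 1])  # -> 3
```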
Thanks for the helpful discussion. I think a more precise description is that VPT makes an efficient implementation easy in the ensemble setting, rather than that VPT is "more" efficient, because it is still possible to run many full models in parallel.
I have tried for a long time to train on CIFAR-100, but it still does not work. Please help me run it successfully, thanks very much!
1. Downloaded the CIFAR-100 dataset.
2. Downloaded the pretrained model and renamed it.
3. Config files: vpt/configs/prompt/cifar100.yaml and vpt/src/configs/config.py
4. Trained the model with /vpt/run.sh
5. Error log
@miss-rain #9
Hi, thanks for your great work and thorough experiments. I found that in Fig. 15, the results show that ensembling can improve the performance of VPT. I reimplemented the ensemble method in a quick way: I trained two sets of visual prompts in parallel on the CIFAR-100 dataset, and during testing I directly ensembled those prompts, i.e., used 2x the prompts to perform inference. However, I found that the ensemble actually degraded performance: the two sets of visual prompts (6 tokens each) achieved accuracies of 80.36 and 80.85 respectively, but after ensembling them (6*2 = 12 tokens), the accuracy was 74.75. That is a little strange to me, because the ensemble result was much worse than the separate results. Could you please share some ideas on this phenomenon? Thanks a lot!
By the way, the best number of tokens for VPT differs across datasets, and it seems that too many prompts can sometimes degrade performance. I also noticed that 5 different sets of prompts were ensembled in the paper, which means the total number of prompts could be up to 500. Will that degrade the final performance?