KMnP / vpt

❄️🔥 Visual Prompt Tuning [ECCV 2022] https://arxiv.org/abs/2203.12119

Ensemble seems to degrade the performance on CIFAR-100 #7

Closed · nuaajeff closed this issue 2 years ago

nuaajeff commented 2 years ago

Hi, thanks for your great work and thorough experiments. In Fig. 15, the results show that ensembling can improve the performance of VPT. I reimplemented the ensemble method in a quick way: I trained two sets of visual prompts in parallel on CIFAR-100, and during testing I directly ensembled those prompts, i.e., I used 2x the prompts to perform inference. However, I found the ensemble actually degraded the performance: the two sets of visual prompts (6 tokens each) achieved accuracies of 80.36 and 80.85 respectively, but after ensembling them (6*2 = 12 tokens) the accuracy dropped to 74.75. That is a little strange to me, because the ensemble result is much worse than either separate result. Could you please share some ideas on this phenomenon? Thanks a lot!

By the way, the best number of prompt tokens for VPT differs across datasets, and it seems that too many prompts can sometimes degrade performance. I also noticed that 5 different sets of prompts were ensembled in the paper, which could mean up to 500 prompt tokens in total. Would that degrade the final performance?

nuaajeff commented 2 years ago

This is how I implement the function incorporate_prompt:

    def incorporate_prompt(self, x, test=False):
        # combine prompt embeddings with image-patch embeddings
        B = x.shape[0]
        # after CLS token, all before image patches
        x = self.embeddings(x)  # (batch_size, 1 + n_patches, hidden_dim)
        if not test:
            # training: build a separate input sequence for each of the two prompt sets
            x1 = torch.cat((
                    x[:, :1, :],
                    self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
                    x[:, 1:, :]
                ), dim=1)

            x2 = torch.cat((
                    x[:, :1, :],
                    self.prompt_dropout(self.prompt_proj(self.prompt_embeddings_cw).expand(B, -1, -1)),
                    x[:, 1:, :]
                ), dim=1)
            x = [x1, x2]

        else:
            # testing: concatenate both prompt sets into one sequence (the "ensemble" at inference)
            x = torch.cat((
                    x[:, :1, :],
                    self.prompt_dropout(self.prompt_proj(self.prompt_embeddings).expand(B, -1, -1)),
                    self.prompt_dropout(self.prompt_proj(self.prompt_embeddings_cw).expand(B, -1, -1)),
                    x[:, 1:, :]
                ), dim=1)
        # (batch_size, cls_token + n_prompt + n_patches, hidden_dim)

        return x

I implement forward_deep_prompt in a similar way:

    def forward_deep_prompt(self, embedding_output, test=False):
        attn_weights = []
        hidden_states = None
        weights = None
        if not test:
            B = embedding_output[0].shape[0]
        else:
            B = embedding_output.shape[0]
        num_layers = self.vit_config.transformer["num_layers"]

        if not test:
            # training: run each branch (one per prompt set) through the encoder separately
            encoded = []
            # for the first branch
            for i in range(num_layers):
                if i == 0:
                    hidden_states, weights = self.encoder.layer[i](embedding_output[0])
                else:
                    if i <= self.deep_prompt_embeddings.shape[0]:
                        deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                            self.deep_prompt_embeddings[i-1]).expand(B, -1, -1))

                        hidden_states = torch.cat((
                            hidden_states[:, :1, :],
                            deep_prompt_emb,
                            hidden_states[:, (1+self.num_tokens):, :]
                        ), dim=1)

                    hidden_states, weights = self.encoder.layer[i](hidden_states)
                #print('train1', i, hidden_states.shape)

                if self.encoder.vis:
                    attn_weights.append(weights)

            enc = self.encoder.encoder_norm(hidden_states)
            encoded.append(enc)

            # for the second branch
            for i in range(num_layers):
                if i == 0:
                    hidden_states, weights = self.encoder.layer[i](embedding_output[1])
                else:
                    if i <= self.deep_prompt_embeddings.shape[0]:
                        deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                            self.deep_prompt_embeddings_cw[i-1]).expand(B, -1, -1))

                        hidden_states = torch.cat((
                            hidden_states[:, :1, :],
                            deep_prompt_emb,
                            hidden_states[:, (1+self.num_tokens):, :]
                        ), dim=1)

                    hidden_states, weights = self.encoder.layer[i](hidden_states)

                if self.encoder.vis:
                    attn_weights.append(weights)

                #print('train2', i, hidden_states.shape)

            enc = self.encoder.encoder_norm(hidden_states)
            encoded.append(enc)

        else:
            # testing: single pass with both deep prompt sets concatenated in the sequence
            for i in range(num_layers):
                if i == 0:
                    hidden_states, weights = self.encoder.layer[i](embedding_output)
                else:
                    if i <= self.deep_prompt_embeddings.shape[0]:
                        deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                            self.deep_prompt_embeddings[i-1]).expand(B, -1, -1))
                        deep_prompt_emb_cw = self.prompt_dropout(self.prompt_proj(
                            self.deep_prompt_embeddings_cw[i-1]).expand(B, -1, -1))

                        hidden_states = torch.cat((
                            hidden_states[:, :1, :],
                            deep_prompt_emb,
                            deep_prompt_emb_cw,
                            hidden_states[:, (1+self.num_tokens*2):, :]
                        ), dim=1)

                    hidden_states, weights = self.encoder.layer[i](hidden_states)

                #print('test', i, hidden_states.shape)

                if self.encoder.vis:
                    attn_weights.append(weights)

            encoded = self.encoder.encoder_norm(hidden_states)
        return encoded, attn_weights

Two separate MLPs are used as classifiers to compute outputs and losses, and at test time I just use one of the two classifiers to perform inference with the ensembled prompts.

Tsingularity commented 2 years ago

Hi, thanks for your interest in our work!

I think you might have misunderstood what we mean by "prompt ensembling" here. The different sets of prompts are still kept separate during inference, rather than combined together as you described. But because the pre-trained backbone is shared, this can be implemented more efficiently in a "batch" manner, where each sample is the concatenation of one set of prompts and the original image. A similar idea is also explored in section 6 of the NLP paper "The Power of Scale for Parameter-Efficient Prompt Tuning".
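
For illustration, here is a minimal sketch of that batched inference idea (the helper name and shapes are assumptions for this example, not the repo's actual API): each image is replicated once per prompt set, and each replica gets a different prompt set prepended, so all prompt sets go through the shared frozen backbone in one forward pass while remaining separate sequences.

    import torch

    def build_prompt_ensemble_batch(patch_embeds, prompt_sets):
        # patch_embeds: (B, 1 + n_patches, D) -- CLS token followed by patch embeddings
        # prompt_sets:  list of K tensors, each (n_prompt, D), one per trained prompt set
        # returns:      (K * B, 1 + n_prompt + n_patches, D)
        B = patch_embeds.shape[0]
        batches = []
        for prompts in prompt_sets:
            batches.append(torch.cat((
                patch_embeds[:, :1, :],                  # CLS token
                prompts.unsqueeze(0).expand(B, -1, -1),  # one prompt set only
                patch_embeds[:, 1:, :],                  # image patches
            ), dim=1))
        # stack along the batch dimension: prompt sets never share a sequence
        return torch.cat(batches, dim=0)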

Feel free to let us know if you have more questions.

nuaajeff commented 2 years ago

@Tsingularity Thanks for your explanation. But if the different sets of prompts are kept separate, why is this method more efficient? I think the total computational cost of inference is the same as with the traditional full fine-tuning method, except that there is no need to store additional backbones.

Tsingularity commented 2 years ago

This is a good question! I think it actually depends on how you define "efficient", especially in practical scenarios. For example, suppose you fully fine-tune 5 times and then try to ensemble the resulting models. At inference time, it is pretty hard to run all 5 models in parallel. But with prompt tuning, we can finish the ensemble within a single forward pass by building a special "batch" for each example, where the data is replicated but the prompts are varied. I think section 6 and figure 2 of "The Power of Scale for Parameter-Efficient Prompt Tuning" explain this better than I can.
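
To complete the picture, a small sketch of how the predictions from such a replicated batch could be combined (again, the names and shapes are assumptions, not the repo's code): reshape the per-sequence logits back to (K, B, num_classes) and average over the K prompt sets.

    import torch

    def ensemble_predict(logits, num_prompt_sets):
        # logits: (K * B, num_classes) from one forward pass over the replicated batch,
        # ordered so the first B rows use prompt set 0, the next B use set 1, and so on
        K = num_prompt_sets
        logits = logits.view(K, -1, logits.shape[-1])  # (K, B, num_classes)
        probs = logits.softmax(dim=-1).mean(dim=0)     # average over the K prompt sets
        return probs.argmax(dim=-1)                    # (B,) ensembled predictions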

nuaajeff commented 2 years ago

> But with prompt tuning, we can finish the ensemble within a single forward pass by building a special "batch" for each example, where the data is replicated but the prompts are varied.

Thanks for your helpful discussion. I think a more accurate description might be that VPT makes an efficient implementation of ensembling easy, rather than that VPT is "more" efficient, because it is still possible to run many full models in parallel.

miss-rain commented 2 years ago

I have been trying for a long time to train on CIFAR-100, but it still does not work. Please help me run it successfully, thanks very much!

1. Downloaded the CIFAR-100 dataset (screenshot).

2. Downloaded the pretrained model and renamed it (screenshot).

3. Config file: vpt/configs/prompt/cifar100.yaml (screenshot).

   Config file: vpt/src/configs/config.py (screenshot).

4. Trained the model with /vpt/run.sh (screenshot).

5. Error log (screenshot).
Tsingularity commented 2 years ago

@miss-rain #9