KMnP / vpt

❄️🔥 Visual Prompt Tuning [ECCV 2022] https://arxiv.org/abs/2203.12119

Will the prompts in layer L enter in the layer L+1? #36

Closed RongchangLi closed 1 year ago

RongchangLi commented 1 year ago

The forward pass is shown in the code below. It seems that the prompts from layer L enter layer L+1. If so, this appears inconsistent with Fig. 2 of the original paper.

    def forward_deep_prompt(self, embedding_output):
        attn_weights = []
        hidden_states = None
        weights = None
        B = embedding_output.shape[0]
        num_layers = self.vit_config.transformer["num_layers"]

        for i in range(num_layers):
            if i == 0:
                hidden_states, weights = self.encoder.layer[i](embedding_output)
            else:
                if i <= self.deep_prompt_embeddings.shape[0]:
                    deep_prompt_emb = self.prompt_dropout(self.prompt_proj(
                        self.deep_prompt_embeddings[i-1]).expand(B, -1, -1))

                    hidden_states = torch.cat((
                        hidden_states[:, :1, :],
                        deep_prompt_emb,
                        hidden_states[:, (1+self.num_tokens):, :]
                    ), dim=1)

                hidden_states, weights = self.encoder.layer[i](hidden_states)

            if self.encoder.vis:
                attn_weights.append(weights)

        encoded = self.encoder.encoder_norm(hidden_states)
        return encoded, attn_weights
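
To check what the `torch.cat` slicing actually passes on, here is a toy sketch that mimics it with labelled strings instead of tensors. The names `toy_layer` and `replace_prompts` are hypothetical stand-ins, not part of the repository, and the "layer" only tags tokens so their origin can be traced.

```python
# Toy stand-in for the slicing in forward_deep_prompt.
# Assumption: tokens are plain strings and a "layer" just tags each token,
# so this traces token positions only, not real attention.

def toy_layer(tokens, idx):
    # Hypothetical replacement for self.encoder.layer[idx]:
    # marks every token as having passed through layer idx.
    return [f"{t}>L{idx}" for t in tokens]

def replace_prompts(hidden, new_prompts, num_tokens):
    # Mirrors torch.cat((hidden[:, :1], deep_prompt_emb,
    #                    hidden[:, 1 + num_tokens:]), dim=1)
    # along the token dimension: keep CLS, insert new prompts, keep patches.
    return hidden[:1] + new_prompts + hidden[1 + num_tokens:]

num_tokens = 2
# Sequence layout after the embedding step: [CLS, prompts..., patches...]
hidden = ["cls", "p0_a", "p0_b", "patch1", "patch2"]

hidden = toy_layer(hidden, 0)  # output of layer 0
hidden_in_1 = replace_prompts(hidden, ["p1_a", "p1_b"], num_tokens)

print(hidden_in_1)
# ['cls>L0', 'p1_a', 'p1_b', 'patch1>L0', 'patch2>L0']
```

In this sketch the layer-0 outputs at the prompt positions are sliced out before layer 1, and only the CLS and patch tokens carry over. If the real tensors behave the same way, the prompts of layer L would influence layer L+1 only indirectly, through what the CLS and patch tokens attended to.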