tencent-ailab / IP-Adapter

The image prompt adapter is designed to enable a pretrained text-to-image diffusion model to generate images with image prompt.
Apache License 2.0

Cross Attention in training and generation #378

Open athena913 opened 3 weeks ago

athena913 commented 3 weeks ago

Hi, thank you for releasing your code. I would like to understand where the decoupled cross-attention described in the paper is being used in the code. In the code, I only see concatenation. I would appreciate any explanation you can provide. Thank you.

1) In the training code in tutorial_train.py, in the IPAdapter class, the encoder hidden states are concatenated with the output of the image projection model. But in the architectural diagram of the paper, they are fed through a decoupled cross-attention module, which I don't see in the code below.

 def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        # project the CLIP image embedding into a short sequence of image tokens
        ip_tokens = self.image_proj_model(image_embeds)
        # append the image tokens to the text tokens along the sequence dimension
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

2) Likewise, in the generation code, the optional text prompt embeddings and the image prompt embeddings (produced by the trained projection model) are concatenated and fed into the Stable Diffusion model. I don't see any cross-attention being applied during generation.

prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)

I don't see the following ip_layers being used during generation.

 ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
 ip_layers.load_state_dict(state_dict["ip_adapter"])
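
To make my question concrete, here is my reading of the paper's decoupled cross-attention (my own sketch and notation, not code from this repo):

```python
import torch.nn.functional as F

def decoupled_cross_attention(q, k_text, v_text, k_img, v_img, scale=1.0):
    # What I expected per attention layer: text and image tokens each get
    # their own key/value projections and their own attention, and the two
    # results are summed, i.e. Z = Attn(Q, K_text, V_text) + Attn(Q, K_img, V_img)
    text_out = F.scaled_dot_product_attention(q, k_text, v_text)
    img_out = F.scaled_dot_product_attention(q, k_img, v_img)
    return text_out + scale * img_out
```

Is this what the concatenation is ultimately implementing?
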
xiaohu2015 commented 3 weeks ago

see https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/attention_processor.py#L82
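
Condensed, what that processor does is undo the concatenation and run two cross-attentions whose outputs are summed. A simplified single-head sketch (names and signatures are abridged here and differ from the real file):

```python
import torch.nn as nn
import torch.nn.functional as F

class IPAttnProcessorSketch(nn.Module):
    """Simplified stand-in for IPAttnProcessor: splits the concatenated
    [text tokens | image tokens] sequence and runs two cross-attentions."""

    def __init__(self, hidden_size, cross_attention_dim, num_tokens=4, scale=1.0):
        super().__init__()
        self.num_tokens = num_tokens  # how many ip_tokens were concatenated
        self.scale = scale            # weight of the image branch
        # extra key/value projections, trained only for the image tokens
        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)

    def forward(self, attn, hidden_states, encoder_hidden_states):
        # undo the concatenation from IPAdapter.forward: the last num_tokens
        # entries along dim=1 are the image tokens
        end = encoder_hidden_states.shape[1] - self.num_tokens
        text_tokens = encoder_hidden_states[:, :end]
        ip_tokens = encoder_hidden_states[:, end:]

        query = attn.to_q(hidden_states)
        # branch 1: the original (frozen) text cross-attention
        text_out = F.scaled_dot_product_attention(
            query, attn.to_k(text_tokens), attn.to_v(text_tokens))
        # branch 2: the new image cross-attention with its own key/value
        ip_out = F.scaled_dot_product_attention(
            query, self.to_k_ip(ip_tokens), self.to_v_ip(ip_tokens))
        # the decoupled outputs are summed
        return text_out + self.scale * ip_out
```

So the concatenation in tutorial_train.py is only how the two token sets travel through the UNet API; each attention layer separates them again.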

athena913 commented 3 weeks ago

Thank you. I understand where you have implemented the cross-attention. But is the cross-attention being used during generation? It looks like the image is embedded using the projection layer and then concatenated with the (optional) text embedding, without cross-attention being applied during generation. Is this correct?

xiaohu2015 commented 3 weeks ago

@athena913 it is used, see https://github.com/tencent-ailab/IP-Adapter/blob/main/ip_adapter/ip_adapter.py#L94
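
Before that load_state_dict call, set_ip_adapter has already replaced every cross-attention processor in the UNet with an IPAttnProcessor, so ip_layers is just a ModuleList view of modules that run on every denoising step, generation included. An abridged sketch of that wiring, assuming a loaded StableDiffusionPipeline named pipe (simplified from ip_adapter.py, not a drop-in copy):

```python
import torch
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

unet = pipe.unet
attn_procs = {}
for name in unet.attn_processors.keys():
    if name.endswith("attn1.processor"):
        # self-attention layers keep a plain processor
        attn_procs[name] = AttnProcessor()
    else:
        # cross-attention layers get the decoupled processor; the hidden
        # size depends on which UNet block the layer sits in
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        attn_procs[name] = IPAttnProcessor(
            hidden_size=hidden_size,
            cross_attention_dim=unet.config.cross_attention_dim,
            num_tokens=4,
        )
unet.set_attn_processor(attn_procs)

# checkpoint filename is illustrative
state_dict = torch.load("ip-adapter_sd15.bin", map_location="cpu")

# loading the checkpoint fills the to_k_ip / to_v_ip weights of those
# processors; every subsequent unet(...) call, including the ones inside
# the generation pipeline, splits the concatenated [text | image] tokens
# and applies the decoupled cross-attention internally
ip_layers = torch.nn.ModuleList(unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
```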