haoosz / ViCo

Official PyTorch code for the paper: "ViCo: Detail-Preserving Visual Condition for Personalized Text-to-Image Generation"
MIT License

Image cross attention across all tokens? #7

Closed: okaris closed this issue 1 year ago

okaris commented 1 year ago

@haoosz Thank you for the amazing work and for open-sourcing the code. I have been working on implementing it in huggingface/diffusers. I believe the architecture is in place, but even with regularization and masking, my models don't converge in terms of loss, and the results are overfitted, with distorted subject appearance.

I have gone through the code and the paper several times, and there is one question I can't answer.

Is image cross attention applied to all tokens in the prompt, or is it computed only for the S* token (with vanilla attention maps used for the other tokens)?

haoosz commented 1 year ago

Thank you for your interest in our work and your efforts to implement it in diffusers. The image cross attention is applied between two sets of patch tokens (from the reference image and the denoised image); it doesn't involve any tokens from the prompt. Here is a simplified outline of the operations in one conditioning block (omitting the mask):

  1. denoised_image_cond = CrossAttention(denoised_image, cond = prompt)
  2. reference_image_cond = CrossAttention(reference_image, cond = prompt)
  3. out_visual_conditioned = ImageCrossAttention(denoised_image_cond, cond = reference_image_cond)
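To make the three steps concrete, here is a minimal NumPy sketch of one conditioning block, with the mask omitted as above. It is not the ViCo implementation (which is in PyTorch): it uses single-head attention with no normalization or residual connections, and it assumes steps 1 and 2 share the same text cross-attention weights (`text_w`), since the reference image goes through the vanilla unet. The helper names `cross_attention` and `conditioning_block` are hypothetical. Note that the prompt enters only steps 1 and 2; step 3 attends from denoised patch tokens to reference patch tokens.

```python
import numpy as np

def softmax(x, axis=-1):
    # numerically stable softmax
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def cross_attention(x, cond, w_q, w_k, w_v):
    """Single-head scaled dot-product cross attention.
    x: (N, d) query tokens; cond: (M, d_c) key/value tokens.
    Returns the (N, d) output and the (N, M) attention map."""
    q, k, v = x @ w_q, cond @ w_k, cond @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]))
    return attn @ v, attn

def conditioning_block(denoised, reference, prompt, text_w, image_w):
    # 1. text cross attention on the denoised image's patch tokens
    denoised_cond, _ = cross_attention(denoised, prompt, *text_w)
    # 2. the same (assumed shared) text cross attention on the reference image's patch tokens
    reference_cond, _ = cross_attention(reference, prompt, *text_w)
    # 3. image cross attention: denoised patches query reference patches; no prompt tokens involved
    out, image_attn = cross_attention(denoised_cond, reference_cond, *image_w)
    return out, image_attn

rng = np.random.default_rng(0)
d, d_text = 8, 6
denoised = rng.normal(size=(16, d))    # 16 patch tokens of the denoised image
reference = rng.normal(size=(16, d))   # 16 patch tokens of the reference image
prompt = rng.normal(size=(5, d_text))  # 5 prompt token embeddings
text_w = (rng.normal(size=(d, d)), rng.normal(size=(d_text, d)), rng.normal(size=(d_text, d)))
image_w = tuple(rng.normal(size=(d, d)) for _ in range(3))
out, image_attn = conditioning_block(denoised, reference, prompt, text_w, image_w)
print(out.shape, image_attn.shape)  # (16, 8) (16, 16)
```

The image attention map is patch-by-patch (16 × 16 here), which is why no prompt token, S* included, appears in step 3.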

All cross-attention operations are applied to all tokens. Only the mask is obtained from the single S* token.
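The mask could be derived from the single S* token roughly as follows. This is a hedged NumPy sketch, not the repository's code: it assumes the mask comes from the S* column of the reference image's text cross-attention map, binarizes it with a simple mean threshold (a stand-in; ViCo's actual thresholding may differ), and applies it by giving masked-out reference patches zero weight as keys in the image cross attention. `s_star_mask` and `masked_image_attention` are hypothetical names, and the attention projections are omitted.

```python
import numpy as np

def s_star_mask(ref_text_attn, s_star_index):
    """ref_text_attn: (M, T) text cross-attention map of M reference patches over T prompt tokens.
    Returns a boolean (M,) mask of the patches attributed to the S* token."""
    a = ref_text_attn[:, s_star_index]
    a = (a - a.min()) / (a.max() - a.min() + 1e-8)  # rescale to [0, 1]
    return a >= a.mean()  # stand-in threshold; always keeps at least one patch

def masked_image_attention(queries, keys, key_mask):
    """Projection-free image cross attention; masked-out reference patches get zero weight."""
    logits = queries @ keys.T / np.sqrt(queries.shape[-1])
    logits = np.where(key_mask[None, :], logits, -np.inf)
    logits = logits - logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ keys, weights

# 4 reference patches, 2 prompt tokens; S* is token index 1
ref_text_attn = np.array([[0.1, 0.9], [0.8, 0.2], [0.2, 0.8], [0.7, 0.3]])
mask = s_star_mask(ref_text_attn, s_star_index=1)
print(mask)  # [ True False  True False]

rng = np.random.default_rng(0)
out, weights = masked_image_attention(rng.normal(size=(3, 4)), rng.normal(size=(4, 4)), mask)
```

Only the S* column is read, which matches the point above: the attention itself runs over all tokens, and S* only decides which reference patches are kept.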

I hope this helps your work!

okaris commented 1 year ago

Thanks @haoosz it does help! 🙌🏻

okaris commented 1 year ago

One more question. To implement this in diffusers without disrupting the whole codebase, I am running the unet separately for the reference image condition, collecting the attention maps, and then applying them when running the unet for the denoised image condition and the visually conditioned output.

Do you see any problems with running the vanilla unet separately for the reference image?

haoosz commented 1 year ago

I think it is fine because the reference image only goes through the vanilla unet.
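The two-pass approach described in this exchange can be sketched in plain NumPy as below. Everything here is illustrative and hypothetical, not the diffusers attention-processor API or the ViCo code: `ReferenceCache`, `ViCoBlock`, and `run_unet` are made-up names, attention has no projections, and the "unet" is just a chain of blocks. One detail from haoosz's outline: what step 3 consumes from the reference pass is the reference image's text-conditioned features (`reference_image_cond`) at each block, so those are what the sketch caches; an S* attention map for the mask could be cached the same way.

```python
import numpy as np

def attend(x, cond):
    # projection-free scaled dot-product attention (illustrative only)
    logits = x @ cond.T / np.sqrt(x.shape[-1])
    logits = logits - logits.max(axis=-1, keepdims=True)
    w = np.exp(logits)
    return (w / w.sum(axis=-1, keepdims=True)) @ cond

class ReferenceCache:
    """Hypothetical per-block store shared by all blocks across the two passes."""
    def __init__(self):
        self.features = {}
        self.mode = "record"

class ViCoBlock:
    def __init__(self, name, cache):
        self.name, self.cache = name, cache

    def __call__(self, hidden, prompt):
        h = attend(hidden, prompt)  # vanilla text cross attention
        if self.cache.mode == "record":
            # pass 1 (reference image): store features, return the vanilla output unchanged
            self.cache.features[self.name] = h
            return h
        # pass 2 (denoised image): image cross attention to this block's cached reference features
        return attend(h, self.cache.features[self.name])

def run_unet(blocks, x, prompt):
    # stand-in for a UNet forward pass: blocks applied in order
    for block in blocks:
        x = block(x, prompt)
    return x

rng = np.random.default_rng(0)
cache = ReferenceCache()
blocks = [ViCoBlock(f"block{i}", cache) for i in range(3)]
prompt = rng.normal(size=(5, 8))
reference = rng.normal(size=(16, 8))
denoised = rng.normal(size=(16, 8))

run_unet(blocks, reference, prompt)       # pass 1: vanilla unet on the reference image
cache.mode = "inject"
out = run_unet(blocks, denoised, prompt)  # pass 2: denoised image with visual condition
print(out.shape)  # (16, 8)
```

Keying the cache by block name means each conditioning block only sees the reference features from its own layer, which matches the per-block outline haoosz gives, and the reference pass stays purely vanilla, as confirmed above.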