sled-group / InfEdit

[CVPR 2024] Official implementation of CVPR 2024 paper: "Inversion-Free Image Editing with Natural Language"
https://sled-group.github.io/InfEdit/

Inconsistent implementation with description in the paper #22

Closed fkcptlst closed 5 months ago

fkcptlst commented 5 months ago

regarding mutual self attention control

Hello, I noticed an inconsistency between the implementation and the description in the paper.

https://github.com/sled-group/InfEdit/blob/d9f6c1bcffdb25c163a4a8b4fc80467f08b2add5/app_infedit.py#L244-L253

This piece of code corresponds to Section 4.2, paragraph 1 of the paper. However, it replaces q_tgt, k_tgt with q_src, k_src, instead of replacing k_tgt, v_tgt as described in the paper.
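
To make the difference concrete, here is a minimal sketch of the two variants as I read them (the function names and tensor layout are mine, not the repo's):

```python
import torch.nn.functional as F

# q_*, k_*, v_* are per-head projections of shape (batch * heads, tokens, dim);
# "src" is the source/consistent branch, "tgt" is the target/editing branch.

def masa_as_in_code(q_src, k_src, v_src, q_tgt, k_tgt, v_tgt):
    # My reading of the linked lines: the target branch reuses the source
    # queries and keys, but keeps its own values.
    return F.scaled_dot_product_attention(q_src, k_src, v_tgt)

def masa_as_in_paper(q_src, k_src, v_src, q_tgt, k_tgt, v_tgt):
    # Section 4.2 as written: the target branch keeps its own queries, while
    # keys and values are taken from the source branch.
    return F.scaled_dot_product_attention(q_tgt, k_src, v_src)
```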

Which one should I follow, the code or the paper?

fkcptlst commented 5 months ago

The part in the paper:

[screenshot of the relevant paragraph from the paper]

fkcptlst commented 5 months ago

regarding cross attention control

Additionally, the CrossEdit step described in Algorithm 3 of the paper is inconsistent with the implementation as well.

[screenshot of Algorithm 3 (CrossEdit) from the paper]

Algorithm 3 uses $M^{lay}$ to edit $M^{tgt}$. The implementation is as follows:

https://github.com/sled-group/InfEdit/blob/d9f6c1bcffdb25c163a4a8b4fc80467f08b2add5/app_infedit.py#L255-L268

To my understanding, attn_base, attn_repalce, attn_masa correspond to $M^{src}$, $M^{tgt}$, $M^{lay}$ respectively.

As line 263 shows, the code applies CrossEdit with $M^{src}$ instead of the $M^{lay}$ described in the paper.
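
For clarity, here is a minimal sketch of the replacement pattern in question (names and shapes are mine; this is only the swap being discussed, not the full algorithm):

```python
def cross_edit(attn_tgt, attn_ref, word_mask):
    # Prompt-to-Prompt style replacement: at text positions selected by
    # word_mask (shape: (text_tokens,)), overwrite the target cross-attention
    # map with the reference map; elsewhere keep the target map.
    # attn_*: (heads, image_tokens, text_tokens)
    word_mask = word_mask.to(attn_tgt.dtype)
    return attn_ref * word_mask + attn_tgt * (1 - word_mask)

# Algorithm 3 in the paper edits M^tgt with M^lay (attn_masa), while the
# linked line appears to pass M^src (attn_base) instead:
#   edited = cross_edit(attn_repalce, attn_masa, word_mask)  # paper
#   edited = cross_edit(attn_repalce, attn_base, word_mask)  # code as I read it
```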

I'm trying to dive into this work, but I'm puzzled by these inconsistencies. Which one should I stick to, the code or the paper?

Looking forward to your reply.

fkcptlst commented 5 months ago

regarding local attention blends

The local attention blends are implemented differently as well.

[screenshot of the local attention blend equations from the paper]

I'm assuming that $(m^{tgt} - m^{src})$ and $(1 - m^{tgt} + m^{src})$ are arithmetic operations rather than logical ones.

Below are the actual implementations.

https://github.com/sled-group/InfEdit/blob/d9f6c1bcffdb25c163a4a8b4fc80467f08b2add5/app_infedit.py#L65-L93

  1. arithmetic vs logical: Note that torch.where is used here to perform the masking. This is equivalent to a logical operation rather than the arithmetic operations described in the paper. Since there is no guarantee that $m^{tgt}$ and $m^{src}$ do not overlap, I believe the implementation and the description are not consistent (see the toy example after this list).
  2. masks and blending: Please correct me if I'm wrong. In my understanding, alpha_e highlights tokens in the target prompt that also appear in the target blend prompt, while alpha_m highlights tokens in the target prompt that are also present in the source prompt. I'm quite confused by the implementation, since it blends the source x_s, the mutual x_m and the target x_t, but only the blending of the source x_s and the target x_t is mentioned in the paper.
  3. implementation of the mapper function: The implementation of the mapper function is drastically different from the one in Prompt-to-Prompt. As I mentioned earlier in issue #21, I don't understand why the cosine similarity search is necessary, and consequently why the similarity search should alter the value of alpha_e (alpha_e[max_t] = 1). I could not find relevant information in the paper. Could you elaborate a little more on the logic of the mapper implementation?
  4. temperature: There is a temperature in the local blend that controls thresh_e. Could you please explain a little about the design of that as well?
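
Here is the toy example referenced in point 1, contrasting the arithmetic blend as I read it from the paper with a torch.where-style logical blend (the mask values and variable names are made up for illustration):

```python
import torch

# Toy binary masks over four latent positions; position 1 is covered by both
# masks, position 0 only by m_src.
m_src = torch.tensor([1., 1., 0., 0.])
m_tgt = torch.tensor([0., 1., 1., 0.])
x_src = torch.zeros(4)   # stand-in for the source latent
x_tgt = torch.ones(4)    # stand-in for the target latent

# Arithmetic reading of the paper's blend.
blend_arith = (m_tgt - m_src) * x_tgt + (1 - m_tgt + m_src) * x_src

# Logical reading, which is what a torch.where-based mask effectively does:
# take x_tgt only where m_tgt is on and m_src is off, otherwise keep x_src.
keep_tgt = (m_tgt > 0.5) & (m_src <= 0.5)
blend_logical = torch.where(keep_tgt, x_tgt, x_src)

print(blend_arith)    # tensor([-1., 0., 1., 0.])
print(blend_logical)  # tensor([ 0., 0., 1., 0.])
```

The two readings diverge wherever the arithmetic weights leave {0, 1} (here, the position covered only by m_src), which is why I think the distinction matters.
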
SihanXU commented 5 months ago

Hi fkcptlst, thanks for your issue and for carefully reviewing our code. For the first and second questions, I think these are small bugs in our implementation. Replacing both q_tgt and k_tgt leads to a KV mismatch, as mentioned in https://arxiv.org/pdf/2403.02332.pdf Fig. 5, and hurts performance. We should also use M_lay as the attention map. For the attention blending, we did make some updates, and we will reflect them in a future version of the paper. As for the mapper function, it only exists so that users of the Gradio demo don't have to manually choose between the replace and refine modes mentioned in Prompt-to-Prompt; it has no relation to the method in our paper. As for the temperature, I think it is a feature left over from the development process, possibly because a previous project was not cleaned up. Sorry for the confusion caused.

fkcptlst commented 5 months ago

Hi, thanks for your reply! Your work is really impressive by the way.

I'm working on a project that's based on what you've done, so I need to make sure I get some things right in my implementation.

  1. So, for the first question, I should replace k_tgt and v_tgt (like the paper says) in my refactored code, right?
  2. Regarding the second question, I should use $M^{lay}$ (also from the paper) in my refactored code, correct?
  3. About the temperature, would it be recommended to just remove it?

I'm still confused about the return values of the mapper. Here's my understanding and how it's actually implemented.

How I understand they should be:

  1. mapper: mapper[i] = j means tgt[i] = src[j]
  2. alphas: alphas[i] = 1 means tgt[i] has a matching token in src
  3. m: unused
  4. alpha_e: alpha_e[i] = 1 iff there exists j s.t. tgt_blend[j] = tgt[i] (tokens in both tgt and tgt_blend)
  5. alpha_m: alpha_m[i] = 1 iff there exists j s.t. src[mapper[i]] = tgt[i] = src_blend[j] (tokens in src, tgt and src_blend)

How it's actually implemented:

  1. mapper: mapper[i] = j means tgt[i] = src[j], **or j is found by a search based on embedding cosine similarity**
  2. alphas: alphas[i] = 1 means tgt[i] has a matching token in src
  3. m: a clone of mapper without the search based on embedding cosine similarity; unused in later code.
  4. alpha_e: alpha_e[i] = 1 means: exists j s.t. tgt_blend[j] = tgt[i], **or src[mapper[i]] = tgt_blend[j]**
  5. alpha_m: alpha_m[i] = 1 means: exists j s.t. src[mapper[i]] = tgt[i] = local_blend[j]

The parts that confuse me are highlighted in bold. I don't understand why the embedding cosine similarity search is necessary.

h6kplus commented 5 months ago

For the difference between mapper and m: it is generally irrelevant in most cases, since the embedding cosine similarity was introduced for cases where the source prompt and the target prompt differ word-by-word but are semantically similar. For alpha_e, it should be that alpha_e[i] = 1 means there exists j s.t. tgt_blend[j] = tgt[i]; the latter part is only there in case some words were mapped using embedding cosine similarity. If you use source and target prompts like "a photo of a dog" and "a photo of a cat", it makes no difference, but if you use the prompt "a picture of a cat", the embedding cosine similarity may automatically match "photo" and "picture" in the mapper.
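
Roughly, the fallback behaves like the simplified sketch below (the function name, threshold, and inputs here are illustrative, not the exact code in app_infedit.py):

```python
import torch
import torch.nn.functional as F

def build_mapper(src_tokens, tgt_tokens, src_emb, tgt_emb, sim_threshold=0.8):
    # Simplified token mapper: exact string matches first, then an embedding
    # cosine-similarity fallback for tokens that differ word-by-word but are
    # semantically close (e.g. "photo" vs "picture").
    # src_emb / tgt_emb: (num_tokens, dim) token embeddings for each prompt.
    mapper = torch.full((len(tgt_tokens),), -1, dtype=torch.long)
    alphas = torch.zeros(len(tgt_tokens))
    for i, tok in enumerate(tgt_tokens):
        if tok in src_tokens:
            mapper[i] = src_tokens.index(tok)   # exact match
            alphas[i] = 1.0
        else:
            # cosine similarity between this target token's embedding and
            # every source token embedding
            sims = F.cosine_similarity(tgt_emb[i].unsqueeze(0), src_emb, dim=-1)
            j = int(sims.argmax())
            if sims[j] > sim_threshold:
                mapper[i] = j                    # semantic match
                alphas[i] = 1.0
    return mapper, alphas
```

For example, with a source prompt "a picture of a cat" and a target prompt "a photo of a cat", "photo" has no exact match in the source, so the cosine-similarity branch maps it to "picture"; with "a photo of a dog" vs "a photo of a cat", every mapped word matches exactly and the fallback never fires.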

fkcptlst commented 5 months ago

I get it now. Thanks for the example!