Closed Birch-san closed 1 year ago
That's a good point! Also cc @williamberman and @patil-suraj here.
Stable Diffusion never used an attention mask when training; however, some other models that make use of UNet2DCondition use it (such as UnCLIP). With more and more people fine-tuning Stable Diffusion, we could actually allow attention_mask to also be used in Stable Diffusion.
@patil-suraj @williamberman wdyt?
I see two masking use-cases:
* self-attention mask (if you're training on a batch of images with mixed aspect ratios: you can tell it not to attend to padding pixels)
* cross-attention mask (don't attend to PAD token embeddings in CLIP text condition)
waifu-diffusion already wants to use cross-attention masks for fine-tuning stable-diffusion (in fact will begin a big training run in a few days).
but, what should the API be? attention_mask could refer to self-attention or cross-attention. but it'd be catastrophic to pass it to both.
is cross_attention_kwargs a useful piece in this puzzle? are we supposed to do something like cross_attention_kwargs = { 'attention_mask': my_cool_mask }?
side-note: why is the class named CrossAttention? it's super confusing, given that it implements self-attention too. would MultiHeadAttention be a better name?
side-note 2: is there a reason why you didn't use PyTorch's native torch.nn.MultiheadAttention? you can swap diffusers' CrossAttention class for the native implementation pretty easily.
Hey @Birch-san,
Thanks for your thoughtful answer here. I see the need to pass customized attention_masks. For this, IMO the user should set up a customized attention processor class, as was merged in: https://github.com/huggingface/diffusers/pull/1639 . Now, since we need different attention processors for the self- and cross-attention layers, we need to leverage some new functionality that will be added as part of this PR (hope to get it merged this week).
As you can see in the comment, it allows setting processors depending on the weight name, which should make it possible to set attention processors only on the self- or cross-attention layers. Does this make sense?
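A minimal sketch of the name-based selection described above, in plain Python so it runs without any dependencies. It assumes the convention that `attn1` in a layer name marks a self-attention layer and `attn2` a cross-attention layer; the helper `pick_processors`, the names in `example_names`, and the string placeholders standing in for processor objects are all hypothetical and not taken from the PR.

```python
def pick_processors(layer_names, self_attn_processor, cross_attn_processor):
    """Map each attention layer name to a processor, based on the name alone."""
    chosen = {}
    for name in layer_names:
        if ".attn2." in name:
            # Assumed convention: attn2 is the cross-attention layer.
            chosen[name] = cross_attn_processor
        else:
            # Everything else (attn1) is treated as self-attention.
            chosen[name] = self_attn_processor
    return chosen


# Illustrative layer names; a real model would supply its own.
example_names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor",
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn1.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
]

# Strings stand in for processor instances to keep the sketch dependency-free.
processors = pick_processors(example_names, "default", "masked")
```

With a real UNet, the names would presumably come from the keys of `unet.attn_processors`, and the resulting dict would be passed to `unet.set_attn_processor(...)`; check both against your diffusers version before relying on them.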
Regarding the naming, yes I think you're right here - we should give it a better name. Would you like to open a new separate issue for this, as that issue would be quite actionable? :-)
On the topic of masking when training with padded pixels: conv doesn't allow a mask. Would the padded values be zero when passed into conv, and then masked at attention? Would the model still function correctly if conv learns the letterbox edges while attention ignores them?
Good point @Lime-Cakes!
Actually for me the attention mask only really makes sense for cross attention as well
I see two masking use-cases:
* self-attention mask (if you're training on a batch of images with mixed aspect ratios: you can tell it not to attend to padding pixels)
* cross-attention mask (don't attend to PAD token embeddings in CLIP text condition)
waifu-diffusion already wants to use cross-attention masks for fine-tuning stable-diffusion (in fact will begin a big training run in a few days).
but, what should the API be? attention_mask could refer to self-attention or cross-attention. but it'd be catastrophic to pass it to both.
is cross_attention_kwargs a useful piece in this puzzle? are we supposed to do something like cross_attention_kwargs = { 'attention_mask': my_cool_mask }?
side-note: why is the class named CrossAttention? it's super confusing, given that it implements self-attention too. would MultiHeadAttention be a better name?
side-note 2: is there a reason why you didn't use PyTorch's native torch.nn.MultiheadAttention? you can swap diffusers' CrossAttention class for the native implementation pretty easily.
Do you have any update on the results of the training with the cross-attention mask? Does it work?
want to give +1 to the importance of having cross-attention masks. for SD, having too many padding tokens affects the image generation (this is probably one of the reasons people found longer prompts to work better: they avoid too much attention being drawn to the PAD token). I don't see why padding/not padding should change the output image even a tiny bit, so proper support for masks is required
upd. it seems that currently proper masks can be supported by writing a custom CrossAttnProcessor and passing cross_attention_kwargs = { 'attention_mask_': my_cool_mask }. using the default name attention_mask raises errors
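To make the "padding should not change the output" argument concrete, here is a small dependency-free sketch of single-query scaled dot-product attention. The vectors are made-up toy values, and `attend` and `padded` are hypothetical helpers, not diffusers functions. With PAD keys masked out by a `-inf` bias, the output is the same no matter how many PAD entries are appended; without the mask, it drifts toward the PAD value as padding grows.

```python
import math


def attend(query, keys, values, keep=None):
    """Single-query scaled dot-product attention over plain lists of floats.

    keep[i] == 0 marks key i as padding; it receives a -inf additive bias.
    """
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * sum(q * k for q, k in zip(query, key)) for key in keys]
    if keep is not None:
        scores = [s if m else -math.inf for s, m in zip(scores, keep)]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]


# Toy data: two "real" tokens plus a PAD token repeated n_pad times.
query = [1.0, 0.5]
real_keys = [[0.2, 0.1], [0.9, -0.3]]
real_values = [[1.0, 0.0], [0.0, 1.0]]
pad_key, pad_value = [0.7, 0.7], [5.0, 5.0]


def padded(n_pad, masked):
    keys = real_keys + [pad_key] * n_pad
    values = real_values + [pad_value] * n_pad
    keep = [1] * len(real_keys) + [0] * n_pad if masked else None
    return attend(query, keys, values, keep)


no_pad = attend(query, real_keys, real_values)
masked_few = padded(3, masked=True)
masked_many = padded(75, masked=True)
unmasked_few = padded(3, masked=False)
unmasked_many = padded(75, masked=False)
```

Here `masked_few` and `masked_many` both match `no_pad`, while `unmasked_few` and `unmasked_many` differ from each other. This only illustrates the arithmetic; it says nothing about how a model trained without masks reacts to them.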
It indeed makes sense to pass an attention mask to cross attention. But the stable diffusion model is trained without the masks, and to keep the implem 1:1 with the original implem we don't pass an attention mask. But it makes sense to allow passing masks optionally.
Having the option to use it would be great! It should be off by default, to produce the same training/inference result as the original implementation.
A lot of people are experimenting with fine-tuning with more tokens, meaning more padding is used. A few test runs suggest that feeding too many padded tokens into the unet for training causes lower quality. Personally, I think masking at unet cross attention should be a solution to this issue.
I ended up using the implementation below, which works nicely. maybe this could make it upstream? Currently this can only be implemented for the default CrossAttention, because xformers does not support passing a custom attention bias (but I think they are working on it).
    import torch
    # Import path assumed from the diffusers release current at the time of this thread; it may differ in other versions.
    from diffusers.models.cross_attention import CrossAttention


    class FixedCrossAttnProcessor:
        """Copy-paste from HF diffusers, but with support for passing `encoder_attention_mask` which avoids giving
        attention to padded tokens"""

        def __call__(
            self,
            attn: CrossAttention,
            hidden_states,
            encoder_hidden_states=None,
            attention_mask=None,
            encoder_attention_mask=None,
        ):
            batch_size, sequence_length, _ = hidden_states.shape
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            query = attn.to_q(hidden_states)

            # For cross-attention, replace the mask with one built from the text-token mask.
            if encoder_attention_mask is not None and encoder_hidden_states is not None:
                # B x 77 -> B x 4096 x 77
                attention_mask = encoder_attention_mask.unsqueeze(1).repeat(1, hidden_states.size(1), 1)
                # one copy per head along the batch dimension
                attention_mask = attention_mask.repeat_interleave(attn.heads, dim=0)

            if encoder_hidden_states is None:
                # self-attention: keys and values come from the hidden states themselves
                encoder_hidden_states = hidden_states
            elif attn.cross_attention_norm:
                encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

            query = attn.head_to_batch_dim(query)
            key = attn.head_to_batch_dim(key)
            value = attn.head_to_batch_dim(value)

            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = attn.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)
            return hidden_states
and then call like this:

    noise_pred = unet(sample, timestep, encoder_hidden_states, cross_attention_kwargs=dict(encoder_attention_mask=mask))
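One detail the processor above leaves implicit is what `mask` should contain. The sketch below assumes (this is an assumption about diffusers internals, not something confirmed in this thread) that the mask is added to the attention scores as a bias, so a tokenizer-style 0/1 mask would first need converting to `0.0` / `-inf`. It also mimics the per-head expansion with plain lists. `to_additive_bias` and `repeat_per_head` are hypothetical helpers written for illustration.

```python
import math


def to_additive_bias(keep_mask):
    """Turn rows of 1 (real token) / 0 (PAD) into rows of 0.0 / -inf."""
    return [[0.0 if keep else -math.inf for keep in row] for row in keep_mask]


def repeat_per_head(rows, heads):
    """Mimic repeat_interleave(heads, dim=0): [b0, b1] -> [b0, b0, b1, b1]."""
    return [row for row in rows for _ in range(heads)]


keep_mask = [
    [1, 1, 1, 0, 0],  # prompt 0: three real tokens, two PAD
    [1, 1, 1, 1, 1],  # prompt 1: no padding
]
heads = 2

bias = repeat_per_head(to_additive_bias(keep_mask), heads)

# Row i of the expanded bias belongs to batch element i // heads.
owner = [i // heads for i in range(len(bias))]
```

The interleaved order matters: if the heads were tiled instead (`[b0, b1, b0, b1]`), rows would be paired with the wrong prompt whenever the batch holds prompts of different lengths.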
@Birch-san In our experiments, we find that the attention mask results in inferior generation compared to the unmasked version, and here's our potential explanation: https://github.com/audiojourney/audiojourney.github.io/blob/main/neurIPS_2023_appendix_v1.3.pdf
want to give +1 to the importance of having cross-attention masks. for SD, having too many padding tokens affects the image generation (this is probably one of the reasons people found longer prompts to work better: they avoid too much attention being drawn to the PAD token). I don't see why padding/not padding should change the output image even a tiny bit, so proper support for masks is required
upd. it seems that currently proper masks can be supported by writing a custom CrossAttnProcessor and passing cross_attention_kwargs = { 'attention_mask_': my_cool_mask }. using the default name attention_mask raises errors
Hi, glad to find your answer! I am new to SD, and I have recently also been thinking about the effect of pad tokens during inference. Could you please tell me where I can find the observation or discussion about "for SD, having too many padding tokens affects the image generation"? Any papers or blogs would be much appreciated! Thank you for the guidance!
I tried passing in an attention_mask, for use in a stable-diffusion Unet, but it doesn't actually get passed down as deep as CrossAttention#forward.
I tried fixing it to pass the param down, but it blows up on tensor size mismatch, because self-attention and cross-attention have different masking requirements.
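The size mismatch can be seen from shape arithmetic alone. The sketch below uses illustrative numbers: 4096 latent positions and 77 text tokens (matching the `B x 4096 x 77` comment elsewhere in this thread), plus an assumed batch size of 2 and 8 heads. `attention_scores_shape` and `mask_fits` are hypothetical helpers, not library functions.

```python
def attention_scores_shape(batch, heads, query_len, key_len):
    """Shape of the attention score tensor when heads are folded into the batch."""
    return (batch * heads, query_len, key_len)


def mask_fits(mask_key_len, scores_shape):
    """A mask over keys is only usable if its length equals the number of keys."""
    return mask_key_len == scores_shape[-1]


batch, heads = 2, 8   # assumed values for illustration
pixels = 64 * 64      # 4096 latent positions
tokens = 77           # text tokens in the conditioning

# Self-attention: queries and keys are both pixels.
self_shape = attention_scores_shape(batch, heads, pixels, pixels)
# Cross-attention: queries are pixels, keys are text tokens.
cross_shape = attention_scores_shape(batch, heads, pixels, tokens)

text_mask_fits_cross = mask_fits(tokens, cross_shape)
text_mask_fits_self = mask_fits(tokens, self_shape)
```

A mask over the 77 text tokens lines up with the last dimension of the cross-attention scores but not with the self-attention scores, which is why a single `attention_mask` argument forwarded to both layer types cannot work.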
I made my own implementation of cross-attention masking a few weeks ago (before the refactor) but never upstreamed it. mainly because I didn't understand whether I'd done it correctly (I re-used the lucidrains implementation that CompVis used):
https://github.com/huggingface/diffusers/commit/cbb4c02e32ca74b7ac539857c91aa06bd4c9338c
EDIT: rebased implementation to show how it would fit in with the existing attention masking and the refactored attention:
https://github.com/Birch-san/diffusers/commit/e3a93e9d80a6b4e5122e5b9d02ad4ee60c7d1354
I explicitly named the parameter as a cross attention mask, because a self-attention mask has entirely different requirements.
in terms of wider API design, I wonder whether it should be an attention map (i.e. so you can use it to increase/decrease attention scores for certain token embeds). but for now I'm mostly interested in the mask aspect, because waifu-diffusion makes use of "multiple CLIP embeddings stitched together", so attention masking is useful to avoid attending to padding token embeddings, which would be biased towards conveying the high-level semantics of the final CLIP segment only.
@patrickvonplaten
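As a sketch of the "attention map" idea above: if the extra tensor is added to the scores before the softmax, a bias of `log(w)` multiplies a token's unnormalized attention by `w`, and `-inf` reduces to ordinary masking. The scores and weights below are made-up toy values, and this is plain softmax arithmetic rather than any diffusers API.

```python
import math


def softmax(scores):
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


scores = [0.3, 1.2, -0.5, 0.8]   # toy pre-softmax scores for four tokens
weights = [1.0, 2.0, 1.0, 0.0]   # hypothetical: double token 1, drop token 3

# log(w) scales the unnormalized probability by w; w == 0 becomes a hard mask.
bias = [math.log(w) if w > 0 else -math.inf for w in weights]

plain = softmax(scores)
biased = softmax([s + b for s, b in zip(scores, bias)])
```

After biasing, token 3 receives zero attention, and the ratio of token 1 to token 0 is twice what it was in `plain`. A boolean mask is therefore the special case where every weight is 0 or 1.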