huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

Does attention masking actually work? #1890

Closed. Birch-san closed this issue 1 year ago.

Birch-san commented 1 year ago

I tried passing in an attention_mask for use in a stable-diffusion UNet, but it doesn't actually get passed down as far as CrossAttention#forward.

I tried fixing it to pass the param down, but it blows up on a tensor size mismatch, because self-attention and cross-attention have different masking requirements.

I made my own implementation of cross-attention masking a few weeks ago (before the refactor) but never upstreamed it, mainly because I didn't understand whether I'd done it correctly (I reused the lucidrains implementation that CompVis used):
https://github.com/huggingface/diffusers/commit/cbb4c02e32ca74b7ac539857c91aa06bd4c9338c
EDIT: rebased implementation to show how it would fit in with the existing attention masking and the refactored attention:
https://github.com/Birch-san/diffusers/commit/e3a93e9d80a6b4e5122e5b9d02ad4ee60c7d1354

I explicitly named the parameter as a cross attention mask, because a self-attention mask has entirely different requirements.

in terms of wider API design, I wonder whether it should be an attention map (i.e. so you can use it to increase/decrease attention scores for certain token embeds), but for now I'm mostly interested in the mask aspect: waifu-diffusion makes use of "multiple CLIP embeddings stitched together", so attention masking is useful to avoid attending to padding token embeddings, which would be biased towards conveying the high-level semantics of the final CLIP segment only.
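
to illustrate what I mean by the mask aspect, here's a toy sketch (plain PyTorch, not diffusers code; all shapes and PAD positions are made up) of keeping PAD key positions out of cross-attention by pushing their scores to -inf before the softmax:

import torch

q = torch.randn(1, 4096, 320)             # image-token queries
k = torch.randn(1, 154, 320)              # two stitched 77-token CLIP segments as keys
keep = torch.ones(1, 154, dtype=torch.bool)
keep[:, 60:77] = False                    # PAD tokens of the first segment (illustrative)
keep[:, 140:] = False                     # PAD tokens of the second segment (illustrative)

scores = (q @ k.transpose(-1, -2)) / 320 ** 0.5
scores = scores.masked_fill(~keep[:, None, :], float("-inf"))
probs = scores.softmax(dim=-1)            # PAD columns receive exactly zero attention weight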

@patrickvonplaten

patrickvonplaten commented 1 year ago

That's a good point! Also cc @williamberman and @patil-suraj here.

Stable Diffusion never used an attention mask during training; however, some other models that make use of UNet2DCondition do use it (such as unCLIP). With more and more people fine-tuning Stable Diffusion, we could actually allow attention_mask to also be used in Stable Diffusion.

@patil-suraj @williamberman wdyt?

Birch-san commented 1 year ago

I see two masking use-cases:

* self-attention mask (if you're training on a batch of images with mixed aspect ratios: you can tell it not to attend to padding pixels)

* cross-attention mask (don't attend to PAD token embeddings in the CLIP text condition)

waifu-diffusion already wants to use cross-attention masks for fine-tuning stable-diffusion (in fact it will begin a big training run in a few days).

but what should the API be? attention_mask could refer to self-attention or cross-attention, and it'd be catastrophic to pass it to both.
is cross_attention_kwargs a useful piece in this puzzle? are we supposed to do something like cross_attention_kwargs = { 'attention_mask': my_cool_mask }?

side-note: why is the class named CrossAttention? it's super confusing, given that it implements self-attention too. would MultiHeadAttention be a better name?
side-note 2: is there a reason why you didn't use PyTorch's native torch.nn.MultiheadAttention? you can swap diffusers' CrossAttention class for the native implementation pretty easily.
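
on side-note 2, for illustration: a quick sketch (plain PyTorch, not how diffusers is actually wired up; the dims are just SD-ish numbers) showing that the native module already supports per-key masking via key_padding_mask, which is exactly the "don't attend to PAD token embeddings" case:

import torch
from torch import nn

# embed_dim matches a low-res SD attention block; kdim/vdim match CLIP text embeddings
mha = nn.MultiheadAttention(embed_dim=320, num_heads=8, kdim=768, vdim=768, batch_first=True)
img_tokens = torch.randn(2, 4096, 320)   # latent-pixel queries
txt_tokens = torch.randn(2, 77, 768)     # CLIP text embeddings as keys/values
pad_mask = torch.zeros(2, 77, dtype=torch.bool)
pad_mask[:, 60:] = True                  # True = ignore this key position (PAD tokens)

out, _ = mha(img_tokens, txt_tokens, txt_tokens, key_padding_mask=pad_mask)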

patrickvonplaten commented 1 year ago

Hey @Birch-san,

Thanks for your thoughtful answer here. I see the need to pass customized attention_masks. For this, IMO the user should set up a customized attention processor class, as was merged in https://github.com/huggingface/diffusers/pull/1639. Now, since we need different attention processors for the self- and cross-attention layers, we need to leverage some new functionality that will be added as part of this PR (hope to get it merged this week).

As you can see in the comment, it allows setting processors per layer depending on the weight name, which should make it possible to set attention processors only on the self- or only on the cross-attention layers. Does this make sense?
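
Roughly, I'd imagine the usage looking something like the sketch below (just a sketch assuming the attn_processors / set_attn_processor interface from that work; the "attn2" name check and the placeholder processor class are illustrative, not a final API):

from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor  # moved to diffusers.models.attention_processor in later versions


class MaskedCrossAttnProcessor(CrossAttnProcessor):
    """Placeholder for a user-defined processor that applies a cross-attention mask."""


unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

processors = {}
for name in unet.attn_processors.keys():
    # in stable diffusion's transformer blocks, "attn1" is self-attention and
    # "attn2" is cross-attention, so the weight/module name tells the two apart
    processors[name] = MaskedCrossAttnProcessor() if "attn2" in name else CrossAttnProcessor()
unet.set_attn_processor(processors)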

patrickvonplaten commented 1 year ago

Regarding the naming, yes, I think you're right here - we should give it a better name. Would you like to open a separate issue for this, since it would be quite actionable? :-)

Lime-Cakes commented 1 year ago

On the topic of masking when training with padded pixels: conv layers don't support masks. Would the padded values be zeros passed into the conv, and then masked only at attention? Would the model still function correctly if the convs learn the letterbox edges while attention ignores them?

patrickvonplaten commented 1 year ago

Good point @Lime-Cakes!

Actually, for me the attention mask also only really makes sense for cross-attention.

Lime-Cakes commented 1 year ago

(quoting @Birch-san's earlier comment above about the two masking use-cases and the cross_attention_kwargs API question)

Do you have any update on the results of training with the cross-attention mask? Does it work?

bonlime commented 1 year ago

want to give a +1 to the importance of having cross-attention masks. for SD, having too many padding tokens affects the image generation (this is probably one of the reasons people found longer prompts to work better: they avoid too much attention being drawn to the PAD tokens). I don't see why padding/not padding should change the output image even a tiny bit, so proper support for masks is required.

upd: it seems that currently proper masks can be supported by writing a custom CrossAttnProcessor and passing cross_attention_kwargs = { 'attention_mask_': my_cool_mask } (note the non-default key name). using the default name attention_mask raises errors.

patil-suraj commented 1 year ago

It indeed makes sense to pass an attention mask to cross-attention. But the stable diffusion model is trained without masks, and to keep the implementation 1:1 with the original we don't pass an attention mask. That said, it makes sense to allow passing masks optionally.

Lime-Cakes commented 1 year ago

Having the option to use it would be great! It should be off by default, to produce the same training/inference results as the original implementation.

A lot of people are experimenting with fine-tuning with more tokens, meaning more padding is used. A few test runs suggest that feeding too many padded tokens into the unet during training causes lower quality. Personally, I think masking at the unet's cross-attention should be a solution to this issue.

bonlime commented 1 year ago

I ended up using the following implementation, which works nicely. maybe this could make it upstream? Currently this can only be implemented for the default CrossAttention processor, because xformers does not support passing a custom attention bias (but I think they are working on it).

import torch

# imports assume the diffusers API from around the time of this thread
# (CrossAttention was later renamed/moved to diffusers.models.attention_processor.Attention)
from diffusers.models.cross_attention import CrossAttention


class FixedCrossAttnProcessor:
    """Copy-paste from HF diffusers, but with support for passing `encoder_attention_mask`,
    which avoids giving attention to padded tokens."""
    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        encoder_attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if encoder_attention_mask is not None and encoder_hidden_states is not None:
            # broadcast the per-token text mask across every image-token query:
            # B x 77 -> B x 4096 x 77, then repeat once per attention head.
            # note: attn.get_attention_scores adds this tensor to the raw scores,
            # so masked positions should carry a large negative bias (e.g. -10000)
            # rather than a 0/1 value.
            attention_mask = encoder_attention_mask.unsqueeze(1).repeat(1, hidden_states.size(1), 1)
            attention_mask = attention_mask.repeat_interleave(attn.heads, dim=0)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.cross_attention_norm:
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

and then call like this:

noise_pred = unet(sample, timestep, encoder_hidden_states, cross_attention_kwargs=dict(encoder_attention_mask=mask))
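
for completeness, a rough end-to-end sketch of how I'd wire this up (assumptions: the set_attn_processor API, the CompVis/stable-diffusion-v1-4 checkpoint, and the variable names below are just placeholders for illustration). since get_attention_scores adds the mask to the raw scores, I convert the tokenizer's 0/1 mask into an additive bias first:

import torch
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

repo = "CompVis/stable-diffusion-v1-4"
tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")

# registering the processor on every attention layer is fine here: the mask branch
# only triggers when encoder_hidden_states is present, i.e. in cross-attention
unet.set_attn_processor(FixedCrossAttnProcessor())

tok = tokenizer(["a photo of a cat"], padding="max_length", max_length=77,
                truncation=True, return_tensors="pt")
encoder_hidden_states = text_encoder(tok.input_ids).last_hidden_state

# 0 for real tokens, -10000 for PAD positions (an additive bias, not a 0/1 mask)
mask = (1 - tok.attention_mask.float()) * -10000.0

sample = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([10])
noise_pred = unet(sample, timestep, encoder_hidden_states,
                  cross_attention_kwargs=dict(encoder_attention_mask=mask)).sample
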
github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

lijuncheng16 commented 1 year ago

@Birch-san In our experiment, we find that the attention mask results in inferior generation compared to the unmasked version (image attached). Here's our potential explanation: https://github.com/audiojourney/audiojourney.github.io/blob/main/neurIPS_2023_appendix_v1.3.pdf

Ga-Lee commented 3 days ago

(quoting @bonlime's earlier comment above about padding tokens affecting generation and the custom CrossAttnProcessor workaround)

Hi, glad to find your answer! I'm a newcomer to SD, and I've recently also been thinking about the effect of PAD tokens during inference. Could you please tell me where I can find the observation or discussion about "for SD, having too many padding tokens affects the image generation"? Any papers or blog posts would be much appreciated. Thank you for the guidance!