huggingface / diffusers

auraflow: attention mask is not passed into the sdpa function #8886

Closed bghira closed 1 month ago

bghira commented 1 month ago

Describe the bug

        # note: the processor's attention_mask is never forwarded, so attn_mask defaults to None
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, scale=attn.scale, is_causal=False
        )

Reproduction

when we calculate the hidden states inside the AuraFlow attention class, we are not passing the attention mask into the F.scaled_dot_product_attention (SDPA) call

this leads to non-zero attention scores on the padded positions of the input. when training with long maximum sequence lengths, the model is unnecessarily perturbed by this; loss can be as high as 2.0! it is about as bad as reparameterising the model.
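
a quick toy sketch (made-up shapes, not diffusers code) showing how unmasked padding leaks into the attention weights:

import torch

# single head, sequence length 4; pretend the last two key positions are padding
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)

scores = (q @ k.transpose(-2, -1)) / 8 ** 0.5
probs = scores.softmax(dim=-1)

# without a mask, softmax assigns non-zero weight to the padded keys
print(probs[0, 0, 0])  # all four weights > 0, including padded positions 2 and 3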

so that's an issue for another day, but we should at least accept an optional attention mask in the transformer's __call__ method and pass it through to the attention class, similar to how deepfloyd accepts one as an input to the unet's __call__ method.
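
for reference, a minimal sketch of what threading the mask through to SDPA could look like (hypothetical shapes and variable names; the real processor would also forward scale=attn.scale):

import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 8, 256, 64
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)

# (batch, seq_len) padding mask from the tokenizer: True = real token, False = padding
attention_mask = torch.ones(batch, seq_len, dtype=torch.bool)
attention_mask[:, 200:] = False  # pretend the tail is padding

# sdpa expects a broadcastable (batch, 1, 1, seq_len) mask; True = may attend
attn_mask = attention_mask[:, None, None, :]

hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
)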

Logs

No response

System Info

diffusers git

Who can help?

@sayakpaul @yiyixuxu (and @DN6 since your tag is related to SD3)

sayakpaul commented 1 month ago

I don't see it being handled here or here, which are our points of reference as the original implementation.

So, if this becomes an issue for training, feel free to open a PR for it and we can take it from there.

bghira commented 1 month ago

yes, it's being resolved there too now - this issue just wasn't identified early enough. it definitely harms training for AuraFlow as well as SD3.

sayakpaul commented 1 month ago

Cool. Please ping us here when it's resolved upstream and we will get to it.

bghira commented 1 month ago

there is not much interest in resolving it upstream, since it's not really a problem for long prompts (with little padding, there is little attention to leak onto padded positions).