Closed · bghira closed this issue 1 month ago
Yes, it's being resolved there too now; this issue wasn't identified early enough. It definitely harms training for AuraFlow as well as SD3.
Cool. Please ping us here when it's resolved upstream and we will get to it.
There is not much interest in resolving it since it's not really a problem for long prompts.
Describe the bug
Reproduction
When we calculate the hidden states inside the AuraFlow attention class, we are not passing the attention mask into the SDPA function.
This leads to non-zero attention scores on the padded positions of the input. When training on long sequence lengths, the model is then unnecessarily perturbed by the change; the loss can be as high as 2.0! It is about as bad as reparameterising the model.
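For concreteness, a minimal sketch of the difference. The tensor shapes and the `attention_mask` name here are assumptions for illustration, not the actual diffusers processor code:

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: batch=1, heads=1, seq_len=4, head_dim=8,
# where the last two key positions are padding.
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Boolean mask: True = attend to this position, False = padded.
attention_mask = torch.tensor([[[[True, True, False, False]]]])

# Current behaviour (sketch): the mask is never forwarded, so the
# softmax still assigns non-zero weight to the padded positions.
out_unmasked = F.scaled_dot_product_attention(q, k, v)

# Fixed behaviour (sketch): forward the mask so padded keys are ignored.
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
```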
So that's an issue for another day, but we should at least make the attention mask an optional argument to the transformer's `__call__` method that then passes it through to the attention class, similar to how DeepFloyd handles it as an input to the UNet's `__call__` method. A rough sketch of that threading is below.
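As an illustration only, here is a minimal sketch of how the mask could be accepted at the top level and threaded down. The class names, shapes, and `attention_mask` parameter are assumptions made for this example; they do not mirror the actual diffusers source:

```python
from typing import Optional

import torch
import torch.nn.functional as F


class SketchAttention(torch.nn.Module):
    # Stand-in for the AuraFlow attention class (sketch only).
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        q = k = v = hidden_states  # placeholder for the real projections
        # Forward the mask into SDPA so padded key positions are ignored.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)


class SketchTransformer(torch.nn.Module):
    # Stand-in for the AuraFlow transformer (sketch only).
    def __init__(self) -> None:
        super().__init__()
        self.attn = SketchAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,  # optional; defaults to None
    ) -> torch.Tensor:
        # nn.Module.__call__ dispatches here, so callers can pass the mask
        # at the top level and have it threaded down to the attention class.
        return self.attn(hidden_states, attention_mask=attention_mask)


# Usage: mask out the last (padded) position of a length-4 sequence.
model = SketchTransformer()
x = torch.randn(1, 1, 4, 8)  # (batch, heads, seq_len, head_dim) for simplicity
mask = torch.tensor([[[[True, True, True, False]]]])
out = model(x, attention_mask=mask)
```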
Logs
No response
System Info
diffusers git
Who can help?
@sayakpaul @yiyixuxu (and @DN6 since your tag is related to SD3)