ostris / ai-toolkit

Various AI scripts. Mostly Stable Diffusion stuff.
MIT License

Misinterpretation of attention masking #192

Open AmericanPresidentJimmyCarter opened 3 hours ago

AmericanPresidentJimmyCarter commented 3 hours ago

The software doesn't implement attention masking correctly. Simply zeroing out the text embedding does not stop the model from using those tokens as registers; it only skews the input distribution away from the one typically seen during sampling.

https://github.com/ostris/ai-toolkit/commit/94529293008684d8c90ebd6255c04052d2b71d52

If you want to use attention masking, the mask needs to be passed into scaled dot product attention, and the same masking needs to be applied during inference too:

https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
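For reference, this is roughly what passing the mask into scaled dot product attention looks like. This is a minimal sketch, not the ai-toolkit or SimpleTuner code; the joint text+image sequence layout, shapes, and per-sample prompt lengths are assumptions for illustration only.

```python
# Sketch: masking padded text tokens in F.scaled_dot_product_attention.
# Layout assumption: text tokens come first, image tokens follow (as in Flux-style joint attention).
import torch
import torch.nn.functional as F

batch, heads, txt_len, img_len, head_dim = 2, 8, 512, 1024, 64
seq_len = txt_len + img_len

q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Boolean mask: True = may be attended to (real text tokens and all image tokens),
# False = padded text tokens that should be masked out.
text_lengths = torch.tensor([77, 300])  # assumed prompt lengths per sample
keep = torch.zeros(batch, seq_len, dtype=torch.bool)
for i, n in enumerate(text_lengths):
    keep[i, :n] = True          # unpadded text tokens
keep[:, txt_len:] = True        # image tokens are never masked

# Broadcast to (batch, 1, 1, seq_len): every query uses the same key mask.
attn_mask = keep[:, None, None, :]

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

The same `attn_mask` would have to be constructed and passed in the inference attention processors as well, otherwise the model is trained and sampled under different attention distributions.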

Please refer to the SimpleTuner codebase, where I have been maintaining this feature:

https://github.com/bghira/SimpleTuner/blob/main/helpers/models/flux/transformer.py

KohakuBlueleaf commented 3 hours ago

Agree with this. At the very least, it should provide an option to switch between different masking types.