The software doesn't implement attention masking correctly. Simply zeroing out the text embeddings does not stop the model from using those positions as registers; it only skews the input distribution away from the one typically seen during sampling.
https://github.com/ostris/ai-toolkit/commit/94529293008684d8c90ebd6255c04052d2b71d52
If you want attention masking, the mask has to be passed into scaled dot product attention itself, and it must also be applied during inference:
https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
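A minimal sketch of what this looks like, assuming a boolean padding mask over the sequence (the `masked_attention` helper and its shapes are illustrative, not the actual transformer code):

```python
import torch
import torch.nn.functional as F

def masked_attention(q, k, v, token_mask):
    # q, k, v: (batch, heads, seq_len, head_dim)
    # token_mask: (batch, seq_len) bool; True = real token, False = padding.
    # Broadcast to (batch, 1, 1, kv_len) so every query head sees the same
    # key mask. With a boolean attn_mask, SDPA fills False positions with
    # -inf before the softmax, so masked tokens are excluded from attention
    # entirely, instead of merely having zeroed embeddings that the model
    # can still attend to and repurpose as registers.
    attn_mask = token_mask[:, None, None, :]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
```

Because the masked positions contribute nothing to the softmax, their embedding values become irrelevant to the output, which is the behavior zeroing alone cannot guarantee.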
Please refer to the SimpleTuner codebase, where I have been maintaining this feature:
https://github.com/bghira/SimpleTuner/blob/main/helpers/models/flux/transformer.py