huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

"inverted" form required for 4D masking not defined / 4D attention masks breaks with transformers >=4.40 #32195

Open jpgard opened 1 month ago

jpgard commented 1 month ago

System Info

transformers==4.43

Who can help?

@ArthurZucker @poeda

Reproduction

  1. Take a working fine-tuning pipeline that uses custom 4D attention masks in transformers 4.40 and fine-tunes a Llama 3 model (a minimal sketch of such a pipeline follows this list).
  2. Run that same pipeline with transformers 4.41 (or the most recent version, 4.43).
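
Below is a minimal sketch of the kind of pipeline meant in step 1, not my actual training code; the checkpoint, input text, and mask construction are all illustrative, and the mask is a plain binary (0/1) 4D mask of shape (batch, 1, seq_len, seq_len):

```python
# Minimal sketch (illustrative checkpoint/inputs): forward pass + loss with a
# custom binary 4D attention mask on a Llama-style model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Meta-Llama-3-8B"  # illustrative; any Llama-style checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.bfloat16)

inputs = tokenizer("a short training example", return_tensors="pt")
seq_len = inputs["input_ids"].shape[-1]

# Custom 4D mask, shape (batch, 1, seq_len, seq_len): 1 = attend, 0 = masked.
# Here it is just a causal mask; in our pipeline it encodes custom structure.
mask_4d = torch.tril(torch.ones(1, 1, seq_len, seq_len, dtype=model.dtype))

# Runs under transformers==4.40 (where the library converted the mask internally);
# under >=4.41 it raises a ValueError because this mask's max value is 1, not 0.
outputs = model(input_ids=inputs["input_ids"], attention_mask=mask_4d,
                labels=inputs["input_ids"])
outputs.loss.backward()  # a fine-tuning step would continue from here
```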

Expected behavior

I expect the behavior of 4D attention masking to stay consistent from 4.40 to 4.43. However, I understand that 4D masking was a new feature, and perhaps some changes were necessary to make it work with the rest of the framework.

First, thanks again for the implementation of 4D masking. This is really useful to my work and was critical for us in developing and releasing our recent work on TabuLa.

It seems that a breaking change may have been introduced to masking, specifically in this PR, where masks are no longer "negated" for the user. After this change, masks that previously worked (before the PR) appear to need to be "negated" in order to work; otherwise a ValueError is raised here when fine-tuning a Llama model.

However, it's not clear to me what "negation" actually means; it doesn't appear to be documented anywhere. Furthermore, it seems easy to construct an attention mask that would pass this block (i.e., one with a max value of zero) but that is incorrect in other ways. There is some negation logic here, but it won't work for a typical binary attention mask: doing 1.0 - attention_mask simply flips the mask, so any entries that were previously zero become 1, triggering the same ValueError as above.
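
For concreteness, here is a tiny sketch (shapes and values are just illustrative) of why the naive flip doesn't help:

```python
import torch

# Binary 4D mask: 1 = attend, 0 = masked (a causal mask as a stand-in example).
binary = torch.tril(torch.ones(1, 1, 4, 4))

# Naive "negation": previously-masked entries become 1.0 instead of a large
# negative number, so the result still has a nonzero max value.
flipped = 1.0 - binary

print(binary.max().item())   # 1.0 -> would fail a "max value of zero" check
print(flipped.max().item())  # 1.0 -> fails the same check
```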

So, in this issue I have the following question: what is a "negated" mask, and how can I get from a "standard" binary attention mask used elsewhere in the transformers library to a "negated" attention mask that works with the new 4D attention masking scheme?

I would also suggest documenting what the expected "negated" 4D mask format is.

Happy to contribute to this if someone can provide answers to these questions -- again, this is a terrific capability to have in the library and I am super grateful to the team for the work on it!

jpgard commented 1 month ago

Linking https://github.com/huggingface/transformers/pull/27539 in case readers of that PR encounter this issue.

ArthurZucker commented 1 month ago

Hey! Sorry for the frustration this caused; we realized that maintaining both paths for the mask was a bit cumbersome.

What is a "negated" mask, and how can I get from a "standard" binary attention mask used elsewhere in the transformers library to a "negated" attention mask that works with the new 4D attention masking scheme?

The negated mask has 0 where you want to pay attention, and torch.finfo(model.dtype).min (so a very large negative value) where you do not want to pay attention.

The idea is that we don't apply anything to it; this way, if it does not work, it's not "our fault". Creating a 4D attention mask can be challenging, which is why we'd rather leave it to the user!
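
A rough sketch of one way to build such a mask from a binary 0/1 mask (invert_4d_mask is just an illustrative name, not a library helper):

```python
import torch

def invert_4d_mask(binary_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Turn a binary 4D mask (1 = attend, 0 = masked) into the negated/additive
    form: 0.0 where attention is allowed, torch.finfo(dtype).min elsewhere."""
    inverted = torch.zeros(binary_mask.shape, dtype=dtype)
    inverted.masked_fill_(binary_mask == 0, torch.finfo(dtype).min)
    return inverted

# Example: a causal binary mask for a length-4 sequence, batch size 1.
binary = torch.tril(torch.ones(1, 1, 4, 4))
negated = invert_4d_mask(binary, torch.bfloat16)
assert negated.max() == 0  # satisfies the check mentioned above
```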

jpgard commented 1 month ago

Thanks for the clarification.

Just to make sure I understand: is this a breaking change?

By that I mean: if I take a model that was trained with transformers==4.40, with the old masking scheme, and then use that model with transformers==4.43, with the new masking scheme, should I expect the same behavior (i.e. exact same predictions/logits from the model for the same inputs)?

ArthurZucker commented 1 month ago

You should expect the same behaviour, for sure. You can't run the same training script, but if you pass a 2D mask you should see no differences.
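
One way to sanity-check this (checkpoint and tolerance are illustrative; with no padding, a negated causal 4D mask should match the default 2D path):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "meta-llama/Meta-Llama-3-8B"  # illustrative
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float32).eval()

inputs = tokenizer("a short test sequence", return_tensors="pt")
seq_len = inputs["input_ids"].shape[-1]

with torch.no_grad():
    # Default path: 2D all-ones mask (no padding), causal mask built internally.
    logits_2d = model(**inputs).logits

    # Equivalent negated 4D causal mask: 0 where allowed, finfo(dtype).min elsewhere.
    allowed = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask_4d = torch.zeros(1, 1, seq_len, seq_len).masked_fill(
        ~allowed, torch.finfo(model.dtype).min
    )
    logits_4d = model(input_ids=inputs["input_ids"], attention_mask=mask_4d).logits

print(torch.allclose(logits_2d, logits_4d, atol=1e-4))  # expect True
```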

github-actions[bot] commented 1 week ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.