huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Some CausalLM models don't accept position_ids in their forward pass. #32937

Open avishaiElmakies opened 3 weeks ago

avishaiElmakies commented 3 weeks ago

Feature request

Some models' forward passes don't accept position_ids. For example, OPTModel does not take position_ids, while GPTJModel does. Most newer models do accept position_ids.

Motivation

There are two main reasons we would like all LM models to accept position_ids:

  1. To keep the API consistent across all models.
  2. position_ids are very important for using flash-attention without padding during training. If I want to pack two or more sentences into the same sequence, I need the model to handle them accordingly and treat each sentence as its own separate sequence. The flash-attention code uses position_ids to detect packed sequences and runs an appropriate function to prevent cross-example contamination; without position_ids a model can't use this feature. The code always checks whether position_ids is not None:

https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/modeling_flash_attention_utils.py#L270
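To make the packing point concrete, here is a minimal sketch (the helper name is hypothetical, not part of transformers) of how position_ids for packed sequences typically look: each packed example restarts its positions at 0, which is the kind of discontinuity the flash-attention utilities look for to tell packed rows apart from a single contiguous sequence.

```python
import torch

def packed_position_ids(seq_lengths):
    """Hypothetical helper: build position_ids for several examples
    packed into one row, with positions restarting at 0 for each
    example. A position that drops back to 0 mid-row signals a
    packing boundary."""
    return torch.cat([torch.arange(n) for n in seq_lengths]).unsqueeze(0)

# Two sentences of lengths 3 and 4 packed into one row of 7 tokens:
pos = packed_position_ids([3, 4])
# tensor([[0, 1, 2, 0, 1, 2, 3]])
```

A model whose forward pass silently ignores (or never accepts) position_ids has no way to receive this signal, so the packed examples would attend to each other.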

Your contribution

I may be able to fix this and help with a PR, but I would love a more experienced person to guide me.

ArthurZucker commented 2 weeks ago

I really don't mind having better support for this 🤗 As long as we follow the way it is done for gemma or llama

avishaiElmakies commented 2 weeks ago

OK, I might try to do the OPT one as a first attempt. I should also say that I can only help with PyTorch, as I know nothing about JAX or Keras. I will look to Gemma or Llama for inspiration.

If it works well, I might try my hand at other models.