huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Some CausalLM models don't accept position_ids in their forward pass. #32937

Open avishaiElmakies opened 3 months ago

avishaiElmakies commented 3 months ago

Feature request

Some models' forward pass does not accept position_ids. For example, OPTModel does not accept position_ids, while GPTJModel does. Most newer models do accept position_ids.

Motivation

There are two main reasons we would like all LM models to accept position_ids:

  1. To make the API consistent across all models.
  2. position_ids are essential for using flash-attention without padding during training. If I want to pack two or more sentences into the same sequence, I would like the model to handle them accordingly and treat each sentence as its own separate sequence. The flash-attention code uses position_ids to detect whether sequences are packed and, if so, runs an appropriate function to prevent cross-example contamination. Without position_ids, the model cannot use this feature; the code always checks whether position_ids is not None:

https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/modeling_flash_attention_utils.py#L270
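To illustrate the packing scenario, here is a minimal sketch (plain Python with hypothetical helper names, not the actual transformers implementation, which operates on torch tensors): when several sentences are packed into one row, position_ids restart at each sentence boundary, and that restart is the signal the flash-attention utilities can use to detect packing.

```python
# Hypothetical helpers illustrating packed-sequence position_ids.
# When sentences are packed into one row, position_ids restart
# from 0 at each boundary instead of running 0..total_len-1.

def packed_position_ids(lengths):
    """Build position_ids for sentences of the given lengths packed into one row."""
    ids = []
    for n in lengths:
        ids.extend(range(n))
    return ids

def looks_packed(position_ids):
    """A row is packed if positions reset (fail to increase) somewhere mid-row."""
    return any(b <= a for a, b in zip(position_ids, position_ids[1:]))

# Two sentences of lengths 3 and 4 packed into one sequence:
ids = packed_position_ids([3, 4])
print(ids)                           # [0, 1, 2, 0, 1, 2, 3]
print(looks_packed(ids))             # True  -> use the varlen code path
print(looks_packed(list(range(7))))  # False -> ordinary single sequence
```

A model whose forward pass silently drops position_ids can never hand this signal down to the attention layer, which is why threading the argument through is the prerequisite for padding-free packed training.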

Your contribution

I may be able to fix this and help with a PR, but I would love a more experienced person to guide me.

ArthurZucker commented 3 months ago

I really don't mind having better support for this 🤗 As long as we follow the way it is done for gemma or llama
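For reference, the Llama/Gemma-style handling defaults position_ids when the caller passes None. A simplified sketch (plain Python lists standing in for tensors; the real code builds a torch.arange offset by the past key/value length, so these names are assumptions):

```python
# Simplified sketch of Llama-style position_ids defaulting.
# If the caller supplies position_ids, they are used as-is
# (e.g. packed sequences); otherwise a contiguous range is
# generated, offset by the length of any cached past tokens.

def default_position_ids(seq_len, past_len=0, position_ids=None):
    """Return caller-provided position_ids, else past_len..past_len+seq_len-1."""
    if position_ids is None:
        position_ids = list(range(past_len, past_len + seq_len))
    return position_ids

print(default_position_ids(4))                          # [0, 1, 2, 3]
print(default_position_ids(2, past_len=4))              # [4, 5] (generation step)
print(default_position_ids(3, position_ids=[0, 1, 0]))  # caller wins: [0, 1, 0]
```

Adding the same pattern to models like OPT would keep the default behavior unchanged while letting packed-training callers override it.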

avishaiElmakies commented 3 months ago

OK, I might try the OPT one as a first attempt. I should also say that I can only help with PyTorch, as I know nothing about JAX or Keras. I will look to Gemma or Llama for inspiration.

If it works well, I might try my hand at other models.