huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

The forced adoption of low_cpu_mem_usage behavior ruins models that require absolute position embedding. #33326

Closed: rangehow closed this issue 1 month ago

rangehow commented 2 months ago

Loading Hugging Face models follows this logic: passing device_map forcibly enables low_cpu_mem_usage. (A user may also enable low_cpu_mem_usage manually in from_pretrained without realizing its consequences.) https://github.com/huggingface/transformers/blob/47b096412da9cbeb9351806e9f0eb70a693b2859/src/transformers/modeling_utils.py#L3321-L3325 https://github.com/huggingface/transformers/blob/47b096412da9cbeb9351806e9f0eb70a693b2859/src/transformers/modeling_utils.py#L3837-L3838 This behavior breaks models that rely on absolute positional embeddings, such as the 'fsmt' model (e.g., the WMT19 winning models), which bridges Fairseq and Hugging Face.

Because tensors created on the meta device carry only shape and dtype and no actual data, the following initialization code has no effect: https://github.com/huggingface/transformers/blob/47b096412da9cbeb9351806e9f0eb70a693b2859/src/transformers/models/fsmt/modeling_fsmt.py#L1343-L1360

Interestingly, this behavior does not affect embed_tokens or the other weights, because they are present in the checkpoint and their values are overwritten by the state_dict: https://github.com/huggingface/transformers/blob/47b096412da9cbeb9351806e9f0eb70a693b2859/src/transformers/modeling_utils.py#L818
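To make the loading order concrete, here is a deliberately simplified, pure-Python toy model of the two paths. It is not transformers' actual loader: the function toy_load and the string "weights" are hypothetical, and None stands in for a parameter that was allocated on the meta device and never received real values. The point it shows is that only keys present in the checkpoint get meaningful values when low_cpu_mem_usage is on.

```python
def toy_load(param_names, checkpoint, init_fns, low_cpu_mem_usage):
    """Toy sketch (hypothetical, not transformers' code) of parameter loading."""
    params = {}
    for name in param_names:
        if low_cpu_mem_usage:
            # Meta-device allocation: shape only, so the module's own
            # initialization code has nothing to write into.
            params[name] = None
        else:
            # Regular path: __init__ runs and fills real memory.
            params[name] = init_fns[name]()
    # Keys present in the checkpoint overwrite whatever was allocated above.
    for name, value in checkpoint.items():
        params[name] = value
    return params


names = ["embed_tokens", "embed_positions"]
init_fns = {
    "embed_tokens": lambda: "random init",
    "embed_positions": lambda: "sinusoidal table",
}
# As in the FSMT checkpoint, the positional table is not saved.
checkpoint = {"embed_tokens": "trained weights"}

regular = toy_load(names, checkpoint, init_fns, low_cpu_mem_usage=False)
low_mem = toy_load(names, checkpoint, init_fns, low_cpu_mem_usage=True)
```

In the regular path embed_positions keeps the value computed at construction time; in the low-memory path it is left without one, while embed_tokens is fine in both because the checkpoint provides it.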

One possible workaround is to include the embed_positions weight matrix in the checkpoint converted from Fairseq, even though it can be computed by rule. Since its parameter count is negligible for today's large models, this should have no adverse effects.
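For reference, "computed by rule" means something like the following pure-Python sketch of the fairseq-style sinusoidal table, assuming fairseq's layout in which the sine half and the cosine half are concatenated (not interleaved), odd dimensions are zero-padded, and the padding_idx row is zeroed. The function name is illustrative; dim must be at least 4 so that half_dim - 1 is positive.

```python
import math


def sinusoidal_table(num_positions, dim, padding_idx=None):
    """Sketch of a fairseq-style sinusoidal table.

    Row p is [sin(p*f_0), ..., sin(p*f_{h-1}), cos(p*f_0), ..., cos(p*f_{h-1})]
    with h = dim // 2 and frequencies decaying geometrically from 1 to 1/10000.
    """
    half_dim = dim // 2
    scale = math.log(10000) / (half_dim - 1)
    freqs = [math.exp(-scale * i) for i in range(half_dim)]
    table = []
    for pos in range(num_positions):
        row = [math.sin(pos * f) for f in freqs] + [math.cos(pos * f) for f in freqs]
        if dim % 2 == 1:
            row.append(0.0)  # zero-pad odd embedding sizes
        table.append(row)
    if padding_idx is not None:
        table[padding_idx] = [0.0] * dim  # padding positions embed to zeros
    return table


table = sinusoidal_table(num_positions=8, dim=6, padding_idx=1)
```

Because the whole table is a deterministic function of (num_positions, dim, padding_idx), saving it in the checkpoint adds no new information; it only guarantees that loading does not depend on the constructor running on real memory.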

Only the following line needs to change: https://github.com/huggingface/transformers/blob/c6d2848a23aba42404784ba52e421ae7b8c68eda/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py#L250
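The idea of the workaround can be sketched as a small helper that fills in any missing positional tables before the converted state dict is written out. This is a hypothetical illustration, not the one-line change linked above: add_missing_tables, placeholder_table, and the key names (assumed to follow FSMT's model.encoder / model.decoder layout) are assumptions. In the real script the values would be tensors, make_table would implement the sinusoidal rule, and the result would be saved with torch.save.

```python
def add_missing_tables(state_dict, table_shapes, make_table):
    """Return a copy of state_dict in which every key listed in table_shapes
    that is absent gets a computed table; existing keys are left untouched."""
    patched = dict(state_dict)  # shallow copy, so the input is not modified
    for key, (num_positions, dim) in table_shapes.items():
        if key not in patched:
            patched[key] = make_table(num_positions, dim)
    return patched


def placeholder_table(num_positions, dim):
    # Hypothetical stand-in for the real sinusoidal rule.
    return [[0.0] * dim for _ in range(num_positions)]


# Assumed key names, for illustration only.
shapes = {
    "model.encoder.embed_positions.weight": (4, 2),
    "model.decoder.embed_positions.weight": (4, 2),
}
original = {"model.encoder.embed_positions.weight": "already saved"}
patched = add_missing_tables(original, shapes, placeholder_table)
```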

If needed, I can make a PR to solve this.

LysandreJik commented 2 months ago

Hey @rangehow, thanks for opening an issue! Is this a widespread issue across models or only affecting FSMT from what you have observed?

rangehow commented 2 months ago


This depends on how the model is saved. In principle, any model that uses absolute positional embeddings and does not save the full positional-embedding weights (which is common, since sinusoidal embeddings can be computed directly) will hit this issue. It is hard to say whether architectures other than FSMT have the same problem, because answering that requires a detailed understanding of every model in the library. Fixing this for all potentially affected models would be challenging, so could we first make a small adjustment to the FSMT conversion script, so that machine translation practitioners are not blocked from migrating from Fairseq to Transformers?
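One way to check whether another model is exposed is to compare the keys its state_dict expects against the keys its checkpoint actually contains; any expected key the checkpoint lacks does not get its value from saved data under low_cpu_mem_usage. The sketch below works on plain key lists; with PyTorch the inputs could be model.state_dict().keys() and the keys of the loaded checkpoint dict. The function name and the example keys are illustrative assumptions.

```python
def unsaved_keys(model_keys, checkpoint_keys, substring=None):
    """Return, sorted, the model keys that the checkpoint does not provide.

    If substring is given (e.g. "embed_positions"), keep only matching keys.
    """
    missing = set(model_keys) - set(checkpoint_keys)
    if substring is not None:
        missing = {key for key in missing if substring in key}
    return sorted(missing)


model_keys = [
    "model.encoder.embed_tokens.weight",
    "model.encoder.embed_positions.weight",
    "model.decoder.embed_positions.weight",
]
checkpoint_keys = ["model.encoder.embed_tokens.weight"]
at_risk = unsaved_keys(model_keys, checkpoint_keys)
```

A non-empty result is not proof of a bug (the model's weight-initialization hook may handle those keys), but it narrows down where to look.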

github-actions[bot] commented 1 month ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker commented 1 month ago

@rangehow A PR to update FSMT is welcome for sure, sorry for the delay! 🤗