huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

shift_tokens_right function missing for mt5 models #15771

Closed bduclaux closed 2 years ago

bduclaux commented 2 years ago

Environment info

Who can help

@patil-suraj

Information

Hello

I am trying to fine-tune an mt5-small model on a summarization task with Flax on TPU, using the examples/flax/summarization/run_summarization_flax.py script. When I run the script, I get an error because shift_tokens_right is not defined for mt5 models:

```
File "/home/prod/transformers/examples/flax/summarization/run_summarization_flax.py", line 521, in main
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
AttributeError: module 'transformers.models.mt5.modeling_flax_mt5' has no attribute 'shift_tokens_right'
```

Moreover, the current Flax summarization script has a typo at line 516:

See https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py#L516: `model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])`

("tight" should be "right"!)
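For reference, a minimal sketch of the corrected lines (the same two statements as in the script, with only the string fixed so the name in `fromlist` matches the attribute fetched by `getattr`):

```python
# Corrected dynamic import: the name in `fromlist` must match the attribute
# fetched below ("shift_tokens_right", not "shift_tokens_tight").
model_module = __import__(model.__module__, fromlist=["shift_tokens_right"])
shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
```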

I have been able to fix the issue by copying the shift_tokens_right function defined in src/transformers/models/t5/modeling_flax_t5.py into src/transformers/models/mt5/modeling_flax_mt5.py. Now the Flax summarization script works fine.
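For context, the copied function looks roughly like this (a sketch based on the Flax T5 implementation; the exact code in the repository may differ slightly):

```python
import jax.numpy as jnp

def shift_tokens_right(input_ids: jnp.ndarray, pad_token_id: int, decoder_start_token_id: int) -> jnp.ndarray:
    """Shift input ids one token to the right for teacher-forced decoding."""
    shifted_input_ids = jnp.zeros_like(input_ids)
    # Move every token one position to the right and prepend the decoder start token.
    shifted_input_ids = shifted_input_ids.at[:, 1:].set(input_ids[:, :-1])
    shifted_input_ids = shifted_input_ids.at[:, 0].set(decoder_start_token_id)
    # Replace any -100 label-masking values with the pad token.
    return jnp.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
```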

Hope you can fix the mt5 code accordingly!

patil-suraj commented 2 years ago

Good catch, thanks for opening the issue! Would you be interested in opening a PR to add this function to Flax mT5?

bduclaux commented 2 years ago

Hey Suraj,

Sure! Will do it tomorrow and keep you posted. Thanks!

bduclaux commented 2 years ago

Also noticed that the adafactor parameter is not used when instantiating the optimizer in the run_summarization_flax.py script. Will add it in my PR, so that both AdamW and Adafactor are supported.
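A minimal sketch of what that selection could look like (the optax calls are real; `training_args` and `linear_decay_lr_schedule_fn` are assumed to be the variables the script already defines, and the exact wiring is an assumption):

```python
import optax

# Sketch: pick the optimizer based on the --adafactor training argument.
# `training_args` and `linear_decay_lr_schedule_fn` are assumed to exist
# as in run_summarization_flax.py.
if training_args.adafactor:
    optimizer = optax.adafactor(learning_rate=linear_decay_lr_schedule_fn)
else:
    optimizer = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )
```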

patil-suraj commented 2 years ago

> Also noticed that the adafactor parameter is not used when instantiating the optimizer in the run_summarization_flax.py script. Will add it in my PR, so that both AdamW and Adafactor are supported.

Ahh, good catch! It would be better to open a new PR for this, since these are two different changes.

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

bduclaux commented 2 years ago

Bump - will do PR soon

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

patil-suraj commented 2 years ago

Will make a PR for this shortly :)