Closed bduclaux closed 2 years ago
Good catch, thanks for opening the issue! Would you be interested in opening a PR to add this function to Flax mT5?
Hey Suraj,
Sure! Will do it tomorrow and keep you posted. Thanks!
I also noticed that the adafactor parameter is not used to instantiate the optimizer in the run_summarization_flax.py script.
I will add it in my PR, to support both AdamW and Adafactor.
Ahh, good catch! It would be better to open a new PR for this, since these are two different changes.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
Bump - will do PR soon
Will make a PR for this shortly :)
Environment info
transformers version: 4.17.0.dev0
Who can help
@patil-suraj
Information
Hello
I am trying to fine-tune an mt5-small model with Flax on TPU for a summarization task, using the examples/flax/summarization/run_summarization_flax.py script. When I run the script, I get an error saying that shift_tokens_right is not defined for mT5 models:
```
File "/home/prod/transformers/examples/flax/summarization/run_summarization_flax.py", line 521, in main
    shift_tokens_right_fn = getattr(model_module, "shift_tokens_right")
AttributeError: module 'transformers.models.mt5.modeling_flax_mt5' has no attribute 'shift_tokens_right'
```
Moreover, the current Flax summarization script has a typo on line 516.
See https://github.com/huggingface/transformers/blob/master/examples/flax/summarization/run_summarization_flax.py#L516 :
model_module = __import__(model.__module__, fromlist=["shift_tokens_tight"])
("tight" should be "right"!)
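For illustration, here is a minimal sketch of the same "look up a function on the model's module" pattern, written with importlib.import_module so that no fromlist argument is needed. The helper name get_module_function is hypothetical and is not part of the script; the standard-library math module stands in for the model module so the snippet runs without transformers installed.

```python
import importlib


def get_module_function(module_name: str, function_name: str):
    """Import `module_name` and return its attribute `function_name`.

    Hypothetical helper, not part of run_summarization_flax.py.
    """
    module = importlib.import_module(module_name)
    try:
        return getattr(module, function_name)
    except AttributeError as err:
        # Re-raise with a message that names both the module and the function.
        raise AttributeError(
            f"Module '{module_name}' does not define '{function_name}'."
        ) from err


# Stand-in usage with a standard-library module:
sqrt_fn = get_module_function("math", "sqrt")
print(sqrt_fn(9.0))

# In the script, the module name would come from model.__module__, e.g.
# get_module_function(model.__module__, "shift_tokens_right")
```

With a lookup like this, a module that lacks the function still fails, but the error message states which module is missing which function.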
I have been able to fix the issue by copying the function shift_tokens_right defined in
src/transformers/models/t5/modeling_flax_t5.py
into the file src/transformers/models/mt5/modeling_flax_mt5.py.
Now the Flax summarization script works fine. I hope you can fix the mT5 code accordingly!
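For readers unfamiliar with the function, the following is a rough sketch of what a shift_tokens_right helper typically does for a T5-style decoder: it shifts the label ids one position to the right, puts the decoder start token first, and replaces the -100 loss-masking value with the pad token. It is written with NumPy as an illustration under those assumptions and is not copied from the library source; the token ids in the example are made up.

```python
import numpy as np


def shift_tokens_right(
    input_ids: np.ndarray, pad_token_id: int, decoder_start_token_id: int
) -> np.ndarray:
    """Shift input ids one token to the right (illustrative sketch)."""
    shifted_input_ids = np.zeros_like(input_ids)
    # Every position takes the token that was one step to its left.
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    # The first position becomes the decoder start token.
    shifted_input_ids[:, 0] = decoder_start_token_id
    # Labels masked with -100 (ignored by the loss) become padding.
    shifted_input_ids = np.where(
        shifted_input_ids == -100, pad_token_id, shifted_input_ids
    )
    return shifted_input_ids


# Made-up label ids; -100 marks positions ignored by the loss.
labels = np.array([[5, 6, 7], [8, -100, -100]])
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2)
print(decoder_input_ids)
```

The summarization script uses the result as decoder_input_ids, which is why the function must be reachable from the model's module.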