bjascob / amrlib

A python library that makes AMR parsing, generation and visualization simple.
MIT License

Training error: unexpected keyword argument '_internal_call' #52

Closed plandes closed 1 year ago

plandes commented 1 year ago

I'm training a new corpus from the checkpoint of the pretrained xfm_base model using amrlib.models.parse_xfm.Trainer programmatically. However, a keyword argument is being passed that the installed version of the HuggingFace transformers library does not accept (see the stack trace below).

I am using transformers version 4.11.3.

  File "/work/python/lib/python3.9/site-packages/amrlib/models/parse_xfm/trainer.py", line 64, in train
    trainer.train(resume_from_checkpoint=self.training_args.resume_from_checkpoint)
  File "/work/python/lib/python3.9/site-packages/transformers/trainer.py", line 1391, in train
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/work/python/lib/python3.9/site-packages/transformers/trainer.py", line 1495, in _maybe_log_save_evaluate
    self._save_checkpoint(model, trial, metrics=metrics)
  File "/work/python/lib/python3.9/site-packages/amrlib/models/parse_xfm/amr_trainer.py", line 107, in _save_checkpoint
    self.save_model(new_chk_fpath, _internal_call=True)
TypeError: save_model() got an unexpected keyword argument '_internal_call'

Proposed change

I imagine this kwarg is needed for an earlier or later version of transformers. If that's the case, branching on the transformers version would probably be the right choice. For clarity, here is the diff that fixed the issue for me:

diff --git a/amrlib/models/parse_xfm/amr_trainer.py b/amrlib/models/parse_xfm/amr_trainer.py
index d8a9347..1dc055e 100644
--- a/amrlib/models/parse_xfm/amr_trainer.py
+++ b/amrlib/models/parse_xfm/amr_trainer.py
@@ -104,7 +104,7 @@ class AMRTrainer(HFTrainer):
             self.state.best_metric           = smatch_val
             self.state.best_model_checkpoint = new_chk_fpath
             # Save the model and trainer state
-            self.save_model(new_chk_fpath, _internal_call=True)
+            self.save_model(new_chk_fpath)
             self.state.save_to_json(os.path.join(new_chk_fpath, TRAINER_STATE_NAME))
         else:
             print(f"Checkpoint not saved. Smatch score lower than best.")

I'm happy to create a pull request for this if you tell me whether you want the version check and what it should be.

PS: There might also be a way to use Python introspection to get the allowed kwargs for the method.
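A minimal sketch of that introspection idea, using `inspect.signature` to check whether `save_model` accepts `_internal_call` before passing it. `OldTrainer` here is a hypothetical stand-in for a pre-4.16 transformers Trainer, not real library code:

```python
import inspect

def accepted_kwargs(func):
    """Return the set of parameter names that `func` accepts."""
    return set(inspect.signature(func).parameters)

class OldTrainer:
    """Stand-in for a transformers Trainer older than v4.16.0,
    whose save_model() does not accept _internal_call."""
    def save_model(self, output_dir=None):
        pass

trainer = OldTrainer()
kwargs = {}
if "_internal_call" in accepted_kwargs(trainer.save_model):
    kwargs["_internal_call"] = True
trainer.save_model("checkpoint-dir", **kwargs)  # works with either signature
```

This avoids hard-coding a version number, at the cost of relying on an implementation detail of the method's signature.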

bjascob commented 1 year ago

You need to upgrade your transformers version; this param is needed for later versions of the transformers lib. I'll have to go back through the transformers history to see exactly when it was added, but the parse_xfm code landed in February and transformers 4.11.3 dates from October 2021, so it must have been introduced in a release between those dates.

I'll update amrlib's requirements to reflect the need for a later version of the transformers lib.

bjascob commented 1 year ago

The _internal_call parameter showed up in transformers v4.16.0. save_model must be called with _internal_call=True to keep the code from attempting to push the model to the HF hub on every save during training.
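If one did want the version branch proposed above (this is a sketch, not amrlib's actual fix), it could look like the following; `version_tuple` is a hypothetical helper, and real code would read `transformers.__version__` instead of the hard-coded string:

```python
def version_tuple(v):
    """Parse a version string like "4.16.0" into a comparable tuple.
    Stops at the first non-numeric component (e.g. the "dev0" in "4.16.0.dev0")."""
    parts = []
    for piece in v.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)

# In real code this would be transformers.__version__.
TRANSFORMERS_VERSION = "4.16.0"

def save_checkpoint(trainer, path):
    if version_tuple(TRANSFORMERS_VERSION) >= (4, 16):
        # _internal_call=True suppresses the HF-hub push on each save.
        trainer.save_model(path, _internal_call=True)
    else:
        trainer.save_model(path)
```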

plandes commented 1 year ago

@bjascob Thank you for such a quick response and for looking into this issue. This makes my work much easier.

Regarding your requirements.txt commit: may I recommend you change the requirements to:

transformers~=4.16.0

To nail it down to a 4.16 release? Then at least pip will whine if the dependencies are not copacetic when installing new packages, and it protects this code from later releases, since this does appear to be a specific implementation detail rather than a class-contract issue.
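For reference, `~=4.16.0` is PEP 440 "compatible release" notation; as I read it, it is equivalent to:

```
transformers>=4.16.0,<4.17.0
```

so pip would accept any 4.16.x patch release but reject 4.17 and later.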

bjascob commented 1 year ago

My goal is to keep it up to date with the latest releases, not pinned to a specific one. Most people use this for inference, which works fine across a wide range of sub-library versions. I'll consider adding some notes to the readme or install instructions on the "recommended" versions to avoid potential issues.

plandes commented 1 year ago

Sure, but if you at least bump the requirement to 4.16 or above, pip will flag the problem for others in my situation who already have an older version installed.

Thanks again.