alan-turing-institute / ARC-MTQE

Critical Error Detection for Machine Translation
MIT License

Load checkpoint #55

Closed radka-j closed 6 months ago

radka-j commented 6 months ago

Closes #37

Updates the train_ced.py script to check whether the config file contains:

model_path:
  path: <path>

If not, it loads the COMETKiwi checkpoint from HuggingFace as before. Otherwise, it loads the weights from the provided file.
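The branching described above could be sketched roughly as follows. This is a minimal illustration, not the code in train_ced.py: the function name resolve_checkpoint_source, the default_model argument and the returned labels are all hypothetical, and the config is assumed to have already been parsed into a plain dict.

```python
from typing import Tuple


def resolve_checkpoint_source(config: dict, default_model: str) -> Tuple[str, str]:
    """Decide where the model weights should come from (hypothetical helper).

    Returns ("file", <path>) when the config has a model_path.path entry,
    otherwise ("huggingface", default_model).
    """
    # model_path may be missing entirely, or present but left empty (None)
    model_path = config.get("model_path") or {}
    path = model_path.get("path")
    if path:
        # A path was provided: load the weights from that file
        return "file", str(path)
    # No path given: fall back to the default HuggingFace checkpoint
    return "huggingface", default_model
```

Treating a missing model_path and an empty one the same way means the existing configs would keep their current behaviour without any changes.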

Running the training script with any of the existing configs (that do not have a model_path in the config) starts training the COMETKiwi model as we've been doing until now.

I also tried running the training script with a path to my local COMET-QE checkpoints in the config. This failed because the estimator layers have different shapes, but it at least shows that the path was picked up and the weights were loaded from that file.

This still needs to be tested with a valid COMETKiwi checkpoint to make sure everything works.

joannacknight commented 6 months ago

@radka-j - this is ready for you to review. I can't add you as a reviewer again, perhaps because you created the PR?

joannacknight commented 6 months ago

Should have said: this PR includes changes to load checkpoint from a file and either make predictions or train the model. In doing so I've moved some of the functionality from train_ced.py to utils and loaders. I may not have moved everything to the 'best' place!

joannacknight commented 6 months ago

Yes! great idea. Will do that on Wednesday

joannacknight commented 6 months ago

I've done a sanity check, as you suggested, on one of the existing checkpoints with the dev data and got the same MCC as on WandB:

Experiment: en-cs_dev_train_monolingual_auth_data__en_cs10720240418_105335
Checkpoint: epoch=8-step=270.ckpt
DEV MCC: 0.47

joannacknight commented 6 months ago

Just updated the filename of the preds as it still had the .ckpt suffix in it.

Just a note that there is still more coding to do in order to turn the predictions into an evaluation - I just ran a few commands in interactive mode to get the MCC above and haven't written a script to do it yet.
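For reference, an evaluation step along these lines might look like the sketch below. It is not the script from this repo: the function names, the 0.5 threshold and the label convention (1 = critical error, 0 = no error) are assumptions, and the real prediction file format is not shown here.

```python
import math
from typing import List, Sequence


def binarise(scores: Sequence[float], threshold: float = 0.5) -> List[int]:
    """Turn model scores into 0/1 labels (the threshold is an assumption)."""
    return [1 if s >= threshold else 0 for s in scores]


def mcc(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Matthews correlation coefficient for binary labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    tn = sum(1 for t, p in pairs if t == 0 and p == 0)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    # By convention, return 0 when any row or column of the confusion matrix is empty
    return (tp * tn - fp * fn) / denom if denom else 0.0
```

With scikit-learn available, sklearn.metrics.matthews_corrcoef(y_true, y_pred) would be an alternative to the hand-written mcc function.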

joannacknight commented 6 months ago

As well as updating the config to read in the correct files, I made a small update to load_model_from_file in train_ced.py so that all_multilingual_demetr can be selected when the language pairs are all.