Unbabel / COMET

A Neural Framework for MT Evaluation
https://unbabel.github.io/COMET/html/index.html
Apache License 2.0

Not compatible with recent transformers? #169

Closed: BramVanroy closed this issue 9 months ago

BramVanroy commented 9 months ago

🐛 Bug

The following code fails with a TypeError:

import comet

if __name__ == '__main__':
    scorer = comet.load_from_checkpoint(comet.download_model("Unbabel/wmt22-comet-da"))

    data = [{"src": "This is a testing phase.", "ref": "Dit is een test.", "mt": "Dees zen test gevalle"}]

    scorer.predict(data)

Error trace:

  File "C:\Users\bramv\AppData\Roaming\JetBrains\PyCharm2023.2\scratches\scratch_6.py", line 8, in <module>
    scorer.predict(data)
  File "F:\python\mateo-demo\.venv\lib\site-packages\comet\models\base.py", line 627, in predict
    predictions = trainer.predict(
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 865, in predict
    return call._call_and_handle_interrupt(
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\trainer\call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 904, in _predict_impl
    results = self._run(model, ckpt_path=ckpt_path)
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 990, in _run
    results = self._run_stage()
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 1031, in _run_stage
    return self.predict_loop.run()
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\loops\utilities.py", line 181, in _decorator
    return loop_run(self, *args, **kwargs)
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\loops\prediction_loop.py", line 122, in run
    self._predict_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\loops\prediction_loop.py", line 250, in _predict_step
    predictions = call._call_strategy_hook(trainer, "predict_step", *step_args)
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\trainer\call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "F:\python\mateo-demo\.venv\lib\site-packages\pytorch_lightning\strategies\strategy.py", line 429, in predict_step
    return self.lightning_module.predict_step(*args, **kwargs)
  File "F:\python\mateo-demo\.venv\lib\site-packages\comet\models\base.py", line 430, in predict_step
    model_outputs = Prediction(scores=self(**batch).score)
  File "F:\python\mateo-demo\.venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "F:\python\mateo-demo\.venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "F:\python\mateo-demo\.venv\lib\site-packages\comet\models\regression\regression_metric.py", line 273, in forward
    return self.estimate(src_sentemb, mt_sentemb, ref_sentemb)
  File "F:\python\mateo-demo\.venv\lib\site-packages\comet\models\regression\regression_metric.py", line 245, in estimate
    return Prediction(score=self.estimator(embedded_sequences).view(-1))
  File "F:\python\mateo-demo\.venv\lib\site-packages\transformers\utils\generic.py", line 327, in __init__
    raise TypeError(
TypeError: comet.models.utils.Prediction is not a dataclasss. This is a subclass of ModelOutput and so must use the @dataclass decorator.

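It seems that COMET's Prediction class subclasses ModelOutput, but transformers now expects every ModelOutput subclass to use the @dataclass decorator, as the error message says. I tried adding the @dataclass decorator, but that did not completely solve the issue either.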
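To make the failure mode in the traceback concrete, here is a minimal stdlib-only mimic of the kind of check the error message describes. This is not transformers' actual ModelOutput code; OutputBase, PlainPrediction and DecoratedPrediction are hypothetical names, and the check is an assumption based only on the error text ("This is a subclass of ModelOutput and so must use the @dataclass decorator").

```python
from dataclasses import dataclass, is_dataclass
from typing import Any, Optional


class OutputBase:
    """Hypothetical stand-in for transformers' ModelOutput (not the real class)."""

    def __init__(self, *args, **kwargs):
        # Mirror the reported check: any subclass of the base must be a dataclass.
        if type(self) is not OutputBase and not is_dataclass(self):
            raise TypeError(
                f"{type(self).__name__} is not a dataclass. "
                "Subclasses of OutputBase must use the @dataclass decorator."
            )


class PlainPrediction(OutputBase):
    """Mimics the failing pattern: a subclass without @dataclass."""


@dataclass
class DecoratedPrediction(OutputBase):
    """Decorated subclass: @dataclass generates its own __init__."""

    score: Optional[Any] = None


try:
    PlainPrediction(score=0.5)
except TypeError as err:
    print(err)  # the undecorated subclass is rejected

print(DecoratedPrediction(score=0.5).score)
```

Note that the decorated subclass never reaches the check at all, because the dataclass-generated __init__ replaces the inherited OutputBase.__init__.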

Environment

Using comet 2.1.0 and transformers 4.34.0

It seems that this is the change that broke it: https://github.com/huggingface/transformers/pull/25638. Downgrading to transformers 4.33.3 works.

ricardorei commented 9 months ago

Hey @BramVanroy, I am not sure I understand the error, because the latest version adds the @dataclass decorator to the Prediction class.

ricardorei commented 9 months ago

Are you sure you are using 2.1.0?

BramVanroy commented 9 months ago

Ah, that change is part of the 2.1.1 release; I was still on 2.1.0. It works on 2.1.1!
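If you need to guard against this combination in your own code, a minimal stdlib-only sketch of a version check might look like the following. The rule encodes only what this thread reports (comet 2.1.0 fails with transformers 4.34.0, transformers 4.33.3 works, comet 2.1.1 works). The function names are hypothetical, and treating every transformers release from 4.34.0 onward as requiring comet 2.1.1 is an extrapolation, not something tested here.

```python
from importlib import metadata
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn 'X.Y.Z' into a tuple of ints; pre-release suffixes are not handled."""
    return tuple(int(part) for part in text.split(".")[:3])


def comet_transformers_compatible(comet_version: str, transformers_version: str) -> bool:
    """Hypothetical rule derived from this thread (an assumption, not an official matrix):
    - comet >= 2.1.1 is treated as compatible (the fix shipped there),
    - otherwise transformers must be older than 4.34.0."""
    if parse_version(comet_version) >= (2, 1, 1):
        return True
    return parse_version(transformers_version) < (4, 34, 0)


def check_installed() -> bool:
    """Apply the rule to the installed packages ("unbabel-comet" is COMET's PyPI name).
    Raises importlib.metadata.PackageNotFoundError if either package is missing."""
    return comet_transformers_compatible(
        metadata.version("unbabel-comet"), metadata.version("transformers")
    )
```

For example, comet_transformers_compatible("2.1.0", "4.34.0") returns False, matching the failure reported above, while upgrading COMET to "2.1.1" makes it return True.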