SapienzaNLP / spring

SPRING is a seq2seq model for Text-to-AMR and AMR-to-Text (AAAI2021).

Error when loading checkpoint. #1

Closed: mrdrozdov closed this issue 3 years ago

mrdrozdov commented 3 years ago

    RuntimeError: Error(s) in loading state_dict for AMRBartForConditionalGeneration:
        Unexpected key(s) in state_dict: "model.encoder.embed_backreferences.weight", "model.encoder.embed_backreferences.transform.weight", "model.encoder.embed_backreferences.transform.bias", "model.decoder.embed_backreferences.weight", "model.decoder.embed_backreferences.transform.weight", "model.decoder.embed_backreferences.transform.bias".

When running:

python bin/predict_amrs.py \
    --datasets <AMR-ROOT>/data/amrs/split/test/*.txt \
    --gold-path data/tmp/amr2.0/gold.amr.txt \
    --pred-path data/tmp/amr2.0/pred.amr.txt \
    --checkpoint runs/<checkpoint>.pt \
    --beam-size 5 \
    --batch-size 500 \
    --device cuda \
    --penman-linearization --use-pointer-tokens

With the http://nlp.uniroma1.it/AMR/AMR2.parsing-1.0.tar.bz2 checkpoint (AMR2.amr-lin3.pt).

Can those keys in the checkpoint safely be ignored?

mbevila commented 3 years ago

Thank you for pointing out the issue! It seems I uploaded unpatched checkpoints.

Yeah, those parameters are not used and can be completely ignored.
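
In practice, ignoring them just means dropping the offending keys before calling `load_state_dict`. A minimal sketch of that workaround, not the repo's actual loading code; it assumes the `.pt` file stores the state dict under a `"model"` key (inferred from the `model.encoder...` key names in the error) and that `model` is an already-instantiated `AMRBartForConditionalGeneration`:

    import torch

    # Load the raw checkpoint. Assumption: the state dict sits under a
    # "model" key; fall back to the loaded object itself otherwise.
    ckpt = torch.load("runs/AMR2.amr-lin3.pt", map_location="cpu")
    state_dict = ckpt.get("model", ckpt)

    # Drop the unused backreference parameters that trigger the error.
    state_dict = {k: v for k, v in state_dict.items()
                  if "embed_backreferences" not in k}

    # model: an already-instantiated AMRBartForConditionalGeneration.
    model.load_state_dict(state_dict)

Alternatively, `model.load_state_dict(state_dict, strict=False)` tells PyTorch to skip unexpected keys entirely.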

I have added a simple script to the repo (bin/patch_legacy_checkpoint.py) that patches checkpoints so they work with the current code. Meanwhile, I'll upload patched files to the server.
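
The script in the repo is the authoritative fix, but conceptually patching just means stripping those keys and re-saving the file. A rough sketch of that idea (the checkpoint layout is again an assumption, and this is not the actual contents of bin/patch_legacy_checkpoint.py):

    import sys
    import torch

    # Sketch of the patching idea: strip unused backreference parameters
    # from a legacy checkpoint and write out a patched copy.
    path = sys.argv[1]  # e.g. runs/AMR2.amr-lin3.pt
    ckpt = torch.load(path, map_location="cpu")
    sd = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt

    # Delete the offending keys in place.
    for key in [k for k in sd if "embed_backreferences" in k]:
        del sd[key]

    # Save alongside the original rather than overwriting it.
    torch.save(ckpt, path + ".patched")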

Sorry for the inconvenience!

(Btw, I really liked DIORA/S-DIORA)

mbevila commented 3 years ago

I have uploaded new checkpoints. They don't need patching.

mrdrozdov commented 3 years ago

Thanks for the help!