awslabs / mlm-scoring

Python library & examples for Masked Language Model Scoring (ACL 2020)
https://www.aclweb.org/anthology/2020.acl-main.240/
Apache License 2.0

PyTorch models #2

Closed (gerardb7 closed 3 years ago)

gerardb7 commented 3 years ago

Hi, it seems that support for PyTorch models is currently limited to BERT and XLM. Would it be possible to add support for lighter models, e.g. DistilBERT or ALBERT? Do you think using these models would significantly hurt the scorers' performance?

Thanks!

JulianSlzr commented 3 years ago

Thanks for the suggestion! I updated to Transformers 3.3.1 and added DistilBERT & ALBERT. The main thing is defining a *BERTForMaskedLMOptimized class for speed. You can follow my example to add support for other MLMs. Pull requests welcome 🙂.
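
For anyone adding another MLM, it may help to recall what the scorer asks of a model. MLM scoring, as described in the ACL 2020 paper, masks each token in turn and sums the log-probability the model assigns to the original token (a pseudo-log-likelihood). That means one masked copy of the sentence per token, which is why a fast forward pass matters. The following is a minimal sketch of that loop only, not the library's implementation: `token_log_prob` and `uniform_log_prob` are hypothetical stand-ins for a real model, and the `[MASK]` string is just a placeholder.

```python
import math
from typing import Callable, List, Sequence

MASK = "[MASK]"  # placeholder mask symbol (assumption for this sketch)


def pseudo_log_likelihood(
    tokens: Sequence[str],
    token_log_prob: Callable[[List[str], int, str], float],
) -> float:
    """Sum of log P(token_i | all other tokens), masking one position at a time."""
    total = 0.0
    for i, target in enumerate(tokens):
        masked = list(tokens)
        masked[i] = MASK  # hide only position i
        total += token_log_prob(masked, i, target)
    return total


def uniform_log_prob(masked: List[str], position: int, target: str) -> float:
    """Hypothetical toy model: every token in a 4-word vocabulary is equally likely."""
    return -math.log(4)


pll = pseudo_log_likelihood(["the", "cat", "sat"], uniform_log_prob)
```

With a real model, `token_log_prob` would run the masked sentence through the MLM and read off the log-softmax at the masked position; the masked copies can be batched together.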

I also ran both on BLiMP. Although ALBERT improves over RoBERTa on downstream tasks like GLUE, its BLiMP score is only on par with BERT's. Likewise, DistilBERT performs similarly to BERT on GLUE but is much worse on BLiMP (78% vs. 84%).
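
For context, BLiMP accuracy is the fraction of minimal pairs in which the model scores the acceptable sentence higher than the unacceptable one. Here is a rough sketch of that computation; `minimal_pair_accuracy` is a hypothetical helper and the scores in `toy_scores` are made up for illustration, not output from any model.

```python
from typing import Callable, Sequence, Tuple


def minimal_pair_accuracy(
    pairs: Sequence[Tuple[str, str]],
    score: Callable[[str], float],
) -> float:
    """Fraction of (acceptable, unacceptable) pairs where the acceptable sentence scores higher."""
    correct = sum(1 for good, bad in pairs if score(good) > score(bad))
    return correct / len(pairs)


# Made-up sentence scores standing in for a real scorer (higher = more acceptable).
toy_scores = {
    "The cats sleep.": -10.0,
    "The cats sleeps.": -14.0,
    "Many girls insulted themselves.": -21.0,
    "Many girls insulted herself.": -20.5,  # toy scorer gets this pair wrong
}

pairs = [
    ("The cats sleep.", "The cats sleeps."),
    ("Many girls insulted themselves.", "Many girls insulted herself."),
]

accuracy = minimal_pair_accuracy(pairs, toy_scores.__getitem__)
```

In practice `score` would be the sentence-level score produced by the MLM scorer.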

Possible takeaways:

# distilbert-base-cased
anaphor_agreement:  0.983
argument_structure:  0.7857777777777778
binding:  0.7335714285714285
control_raising:  0.7788
determiner:  0.970375
ellipsis:  0.915
filler_gap:  0.7464285714285716
irregular_forms:  0.9555
island_effects:  0.54925
npi_licensing:  0.7901428571428571
quantifiers:  0.5895
subject_verb_agreement:  0.8965000000000001
overall:  0.782955223880597

# albert-xxlarge-v2
anaphor_agreement:  0.956
argument_structure:  0.8375555555555555
binding:  0.7912857142857143
control_raising:  0.865
determiner:  0.9395
ellipsis:  0.8735
filler_gap:  0.8188571428571427
irregular_forms:  0.9255
island_effects:  0.74975
npi_licensing:  0.9115714285714285
quantifiers:  0.6739999999999999
subject_verb_agreement:  0.8808333333333334
overall:  0.8435820895522389
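
A note on reading these numbers: the "overall" figure does not look like the plain mean of the 12 category scores (that would be about 0.808 for DistilBERT). It is consistent with an average over BLiMP's 67 individual paradigms, i.e. each category weighted by how many paradigms it contains. The sketch below checks this for the DistilBERT results; the per-category paradigm counts are an assumption inferred from the denominators of the scores above, not something taken from the repository.

```python
# Per-category accuracies for distilbert-base-cased, copied from the results above.
distilbert = {
    "anaphor_agreement": 0.983,
    "argument_structure": 0.7857777777777778,
    "binding": 0.7335714285714285,
    "control_raising": 0.7788,
    "determiner": 0.970375,
    "ellipsis": 0.915,
    "filler_gap": 0.7464285714285716,
    "irregular_forms": 0.9555,
    "island_effects": 0.54925,
    "npi_licensing": 0.7901428571428571,
    "quantifiers": 0.5895,
    "subject_verb_agreement": 0.8965000000000001,
}

# Assumed number of BLiMP paradigms per category (sums to 67).
paradigm_counts = {
    "anaphor_agreement": 2,
    "argument_structure": 9,
    "binding": 7,
    "control_raising": 5,
    "determiner": 8,
    "ellipsis": 2,
    "filler_gap": 7,
    "irregular_forms": 2,
    "island_effects": 8,
    "npi_licensing": 7,
    "quantifiers": 4,
    "subject_verb_agreement": 6,
}

total_paradigms = sum(paradigm_counts.values())

# Paradigm-weighted average: matches the reported "overall" line.
weighted_overall = (
    sum(distilbert[c] * paradigm_counts[c] for c in distilbert) / total_paradigms
)

# Unweighted mean of the 12 category scores, for comparison.
category_mean = sum(distilbert.values()) / len(distilbert)
```

The gap between the two averages comes mostly from large, low-scoring categories such as island_effects and binding carrying more weight than small, high-scoring ones such as anaphor_agreement.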

gerardb7 commented 3 years ago

Grand job, thanks a lot!

Ago3 commented 3 years ago

Hi,

I'm extending the framework to include another PyTorch model. When using MLMScorerPT, we don't need to pass a vocab, do we? I couldn't find any function where it is actually used.

Thank you!

PS: Very cool work :)