allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Loading a HuggingFace model into AllenNLP gives different predictions #5582

Closed santiagxf closed 2 years ago

santiagxf commented 2 years ago


Description

When loading a model created with HuggingFace (the transformers library) into AllenNLP, the model works, but the predictions differ from what I get when running the model with transformers directly. Both models are running in evaluation mode.

Details

I have a custom classification model, trained with the transformers library and based on a BERT model, that classifies text into 7 different categories. It is persisted in a directory using:

trainer.save_model(model_name)         # trainer is a transformers.Trainer
tokenizer.save_pretrained(model_name)  # model_name is the output directory

I'm trying to load this persisted model using the allennlp library for further analysis. However, when running the model inside the AllenNLP framework, the predictions are very different from the ones I get when I run it using transformers, which leads me to think that some part of the loading was not done correctly. There are no errors during inference; the predictions simply don't match.

There is little documentation about how to load an existing model (persisted in a path, for instance), so I'm wondering if someone has faced the same situation before. There is just one example of QA classification with RoBERTa, but I couldn't extrapolate it to what I'm looking for. This is how I'm loading the trained model:

from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.data.token_indexers import PretrainedTransformerIndexer
from allennlp.data.tokenizers import PretrainedTransformerTokenizer
from allennlp.models import BasicClassifier
from allennlp.modules.seq2vec_encoders import BertPooler
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder

transformer_vocab = Vocabulary.from_pretrained_transformer(model_name)
transformer_tokenizer = PretrainedTransformerTokenizer(model_name)
transformer_encoder = BertPooler(model_name)

params = Params(
    {
     "token_embedders": {
        "tokens": {
          "type": "pretrained_transformer",
          "model_name": model_name,
        }
      }
    }
)
token_embedder = BasicTextFieldEmbedder.from_params(vocab=transformer_vocab, params=params)
token_indexer = PretrainedTransformerIndexer(model_name)

transformer_model = BasicClassifier(vocab=transformer_vocab,
                                    text_field_embedder=token_embedder, 
                                    seq2vec_encoder=transformer_encoder, 
                                    dropout=0.1, 
                                    num_labels=7)

transformer_model.eval()
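Side note: as far as I can tell, the Params indirection above isn't strictly required; the embedder can also be built directly (a minimal sketch, which I believe is equivalent to the from_params call, assuming the "pretrained_transformer" type resolves to PretrainedTransformerEmbedder):

from allennlp.modules.token_embedders import PretrainedTransformerEmbedder

# Same "tokens" -> pretrained_transformer mapping, built without Params.
token_embedder = BasicTextFieldEmbedder(
    {"tokens": PretrainedTransformerEmbedder(model_name)}
)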

I also had to implement my own DatasetReader as follows:

from typing import Dict, Optional

from allennlp.data import DatasetReader, Instance
from allennlp.data.fields import Field, LabelField, TextField
from allennlp.data.token_indexers import TokenIndexer
from allennlp.data.tokenizers import Tokenizer

class ClassificationTransformerReader(DatasetReader):
    def __init__(
        self,
        tokenizer: Tokenizer,
        token_indexer: TokenIndexer,
        max_tokens: int,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.tokenizer = tokenizer
        self.token_indexers: Dict[str, TokenIndexer] = { "tokens": token_indexer }
        self.max_tokens = max_tokens

    def text_to_instance(self, text: str, label: Optional[str] = None) -> Instance:
        tokens = self.tokenizer.tokenize(text)
        if self.max_tokens:
            tokens = tokens[: self.max_tokens]

        inputs = TextField(tokens, self.token_indexers)
        fields: Dict[str, Field] = { "tokens": inputs }

        if label:
            fields["label"] = LabelField(label)

        return Instance(fields)

It is instantiated as follows:

dataset_reader = ClassificationTransformerReader(tokenizer=transformer_tokenizer,
                                                 token_indexer=token_indexer,
                                                 max_tokens=400)

To run the model and check that it works, I'm doing the following:

from allennlp.data import Batch
from allennlp.nn import util

instance = dataset_reader.text_to_instance("some sample text here")
dataset = Batch([instance])
dataset.index_instances(transformer_vocab)
model_input = util.move_to_device(dataset.as_tensor_dict(), 
                                  transformer_model._get_prediction_device())

outputs = transformer_model.make_output_human_readable(transformer_model(**model_input))
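
For reference, this is the transformers side of the comparison (a sketch of what I'm running; the sample text and model_name are the same as above):

import torch
from transformers import AutoTokenizer, BertForSequenceClassification

hf_tokenizer = AutoTokenizer.from_pretrained(model_name)
hf_model = BertForSequenceClassification.from_pretrained(model_name).eval()

with torch.no_grad():
    encoded = hf_tokenizer("some sample text here", return_tensors="pt")
    hf_probs = torch.softmax(hf_model(**encoded).logits, dim=-1)

print(hf_probs)          # probabilities from transformers
print(outputs["probs"])  # probabilities from AllenNLP -- these don't match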


Environment

OS: Ubuntu 20.04 LTS

Python version: 3.8

Output of pip freeze:

``` absl-py==1.0.0 alabaster==0.7.12 albumentations==0.1.12 allennlp==2.9.0 altair==4.2.0 appdirs==1.4.4 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arviz==0.11.4 astor==0.8.1 astropy==4.3.1 astunparse==1.6.3 atari-py==0.2.9 atomicwrites==1.4.0 attrs==21.4.0 audioread==2.1.9 autograd==1.3 Babel==2.9.1 backcall==0.2.0 backports.csv==1.0.7 base58==2.1.1 beautifulsoup4==4.6.3 bleach==4.1.0 blis==0.4.1 bokeh==2.3.3 boto3==1.21.9 botocore==1.24.9 Bottleneck==1.3.2 branca==0.4.2 bs4==0.0.1 CacheControl==0.12.10 cached-path==1.0.2 cached-property==1.5.2 cachetools==4.2.4 catalogue==1.0.0 certifi==2021.10.8 cffi==1.15.0 cftime==1.5.2 chardet==3.0.4 charset-normalizer==2.0.12 checklist==0.0.11 cheroot==8.6.0 CherryPy==18.6.1 click==7.1.2 cloudpickle==1.3.0 cmake==3.12.0 cmdstanpy==0.9.5 colorcet==3.0.0 colorlover==0.3.0 community==1.0.0b1 contextlib2==0.5.5 convertdate==2.4.0 coverage==3.7.1 coveralls==0.5 crcmod==1.7 cryptography==36.0.1 cufflinks==0.17.3 cupy-cuda111==9.4.0 cvxopt==1.2.7 cvxpy==1.0.31 cycler==0.11.0 cymem==2.0.6 Cython==0.29.28 daft==0.0.4 dask==2.12.0 datascience==0.10.6 debugpy==1.0.0 decorator==4.4.2 defusedxml==0.7.1 descartes==1.1.0 dill==0.3.4 distributed==1.25.3 dlib @ file:///dlib-19.18.0-cp37-cp37m-linux_x86_64.whl dm-tree==0.1.6 docker-pycreds==0.4.0 docopt==0.6.2 docutils==0.17.1 dopamine-rl==1.0.5 earthengine-api==0.1.299 easydict==1.9 ecos==2.0.10 editdistance==0.5.3 en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz entrypoints==0.4 ephem==4.1.3 et-xmlfile==1.1.0 fa2==0.3.5 fairscale==0.4.5 fastai==1.0.61 fastdtw==0.3.4 fastprogress==1.0.2 fastrlock==0.8 fbprophet==0.7.1 feather-format==0.4.1 feedparser==6.0.8 filelock==3.4.2 firebase-admin==4.4.0 fix-yahoo-finance==0.0.22 Flask==1.1.4 flatbuffers==2.0 folium==0.8.3 future==0.16.0 gast==0.5.3 GDAL==2.2.2 gdown==4.2.1 gensim==3.6.0 geographiclib==1.52 geopy==1.17.0 gin-config==0.5.0 gitdb==4.0.9 GitPython==3.1.27 glob2==0.7 google==2.0.3 google-api-core==1.26.3 google-api-python-client==1.12.10 google-auth==1.35.0 google-auth-httplib2==0.0.4 google-auth-oauthlib==0.4.6 google-cloud-bigquery==1.21.0 google-cloud-bigquery-storage==1.1.0 google-cloud-core==1.7.2 google-cloud-datastore==1.8.0 google-cloud-firestore==1.7.0 google-cloud-language==1.2.0 google-cloud-storage==1.40.0 google-cloud-translate==1.5.0 google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==1.3.3 googleapis-common-protos==1.54.0 googledrivedownloader==0.4 graphviz==0.10.1 greenlet==1.1.2 grpcio==1.44.0 gspread==3.4.2 gspread-dataframe==3.0.8 gym==0.17.3 h5py==3.1.0 HeapDict==1.0.1 hijri-converter==2.2.3 holidays==0.10.5.2 holoviews==1.14.8 html5lib==1.0.1 httpimport==0.5.18 httplib2==0.17.4 httplib2shim==0.0.3 huggingface-hub==0.2.1 humanize==0.5.1 hyperopt==0.1.2 ideep4py==2.0.0.post3 idna==2.10 imageio==2.4.1 imagesize==1.3.0 imbalanced-learn==0.8.1 imblearn==0.0 imgaug==0.2.9 importlib-metadata==4.11.1 importlib-resources==5.4.0 imutils==0.5.4 inflect==2.1.0 iniconfig==1.1.1 intel-openmp==2022.0.2 intervaltree==2.1.0 ipykernel==4.10.1 ipython==5.5.0 ipython-genutils==0.2.0 ipython-sql==0.3.9 ipywidgets==7.6.5 iso-639==0.4.5 itsdangerous==1.1.0 jaraco.classes==3.2.1 jaraco.collections==3.5.1 jaraco.context==4.1.1 jaraco.functools==3.5.0 jaraco.text==3.7.0 jax==0.3.1 jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.0+cuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl 
jedi==0.18.1 jieba==0.42.1 Jinja2==2.11.3 jmespath==0.10.0 joblib==1.1.0 jpeg4py==0.1.4 jsonnet==0.18.0 jsonschema==4.3.3 jupyter==1.0.0 jupyter-client==5.3.5 jupyter-console==5.2.0 jupyter-core==4.9.2 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.2 kaggle==1.5.12 kapre==0.3.7 keras==2.8.0 Keras-Preprocessing==1.1.2 keras-vis==0.4.1 kiwisolver==1.3.2 korean-lunar-calendar==0.2.1 libclang==13.0.0 librosa==0.8.1 lightgbm==2.2.3 llvmlite==0.34.0 lmdb==0.99 LunarCalendar==0.0.9 lxml==4.2.6 Markdown==3.3.6 MarkupSafe==2.0.1 matplotlib==3.2.2 matplotlib-inline==0.1.3 matplotlib-venn==0.11.6 missingno==0.5.0 mistune==0.8.4 mizani==0.6.0 mkl==2019.0 mlxtend==0.14.0 more-itertools==8.12.0 moviepy==0.2.3.5 mpmath==1.2.1 msgpack==1.0.3 multiprocess==0.70.12.2 multitasking==0.0.10 munch==2.5.0 murmurhash==1.0.6 music21==5.5.0 natsort==5.5.0 nbclient==0.5.11 nbconvert==5.6.1 nbformat==5.1.3 nest-asyncio==1.5.4 netCDF4==1.5.8 networkx==2.6.3 nibabel==3.0.2 nltk==3.2.5 notebook==5.3.1 numba==0.51.2 numexpr==2.8.1 numpy==1.21.5 nvidia-ml-py3==7.352.0 oauth2client==4.1.3 oauthlib==3.2.0 okgrade==0.4.3 opencv-contrib-python==4.1.2.30 opencv-python==4.1.2.30 openpyxl==3.0.9 opt-einsum==3.3.0 osqp==0.6.2.post0 packaging==21.3 palettable==3.3.0 pandas==1.3.5 pandas-datareader==0.9.0 pandas-gbq==0.13.3 pandas-profiling==1.4.1 pandocfilters==1.5.0 panel==0.12.1 param==1.12.0 parso==0.8.3 pathlib==1.0.1 pathtools==0.1.2 patsy==0.5.2 patternfork-nosql==3.6 pdfminer.six==20211012 pep517==0.12.0 pexpect==4.8.0 pickleshare==0.7.5 Pillow==7.1.2 pip-tools==6.2.0 plac==1.1.3 plotly==5.5.0 plotnine==0.6.0 pluggy==0.7.1 pooch==1.6.0 portend==3.1.0 portpicker==1.3.9 prefetch-generator==1.0.1 preshed==3.0.6 prettytable==3.1.1 progressbar2==3.38.0 prometheus-client==0.13.1 promise==2.3 prompt-toolkit==1.0.18 protobuf==3.17.3 psutil==5.4.8 psycopg2==2.7.6.1 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.1 pyasn1==0.4.8 pyasn1-modules==0.2.8 pycocotools==2.0.4 pycparser==2.21 pyct==0.4.8 pydata-google-auth==1.3.0 pydot==1.3.0 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 pyemd==0.5.1 pyerfa==2.0.0.1 pyglet==1.5.0 Pygments==2.6.1 pygobject==3.26.1 pymc3==3.11.4 PyMeeus==0.5.11 pymongo==4.0.1 pymystem3==0.2.0 PyOpenGL==3.1.5 pyparsing==3.0.7 pyrsistent==0.18.1 pysndfile==1.3.8 PySocks==1.7.1 pystan==2.19.1.1 pytest==3.6.4 python-apt==0.0.0 python-chess==0.23.11 python-dateutil==2.8.2 python-docx==0.8.11 python-louvain==0.16 python-slugify==6.0.1 python-utils==3.1.0 pytz==2018.9 pyviz-comms==2.1.0 PyWavelets==1.2.0 PyYAML==6.0 pyzmq==22.3.0 qdldl==0.1.5.post0 qtconsole==5.2.2 QtPy==2.0.1 regex==2019.12.20 requests==2.23.0 requests-oauthlib==1.3.1 resampy==0.2.2 rpy2==3.4.5 rsa==4.8 s3transfer==0.5.2 sacremoses==0.0.47 scikit-image==0.18.3 scikit-learn==1.0.2 scipy==1.4.1 screen-resolution-extra==0.0.0 scs==3.1.0 seaborn==0.11.2 semver==2.13.0 Send2Trash==1.8.0 sentencepiece==0.1.96 sentry-sdk==1.5.6 setuptools-git==1.2 sgmllib3k==1.0.0 Shapely==1.8.1.post1 shortuuid==1.0.8 simplegeneric==0.8.1 six==1.15.0 sklearn==0.0 sklearn-pandas==1.8.0 smart-open==5.2.1 smmap==5.0.0 snowballstemmer==2.2.0 sortedcontainers==2.4.0 SoundFile==0.10.3.post1 spacy==2.2.4 Sphinx==1.8.6 sphinxcontrib-serializinghtml==1.1.5 sphinxcontrib-websupport==1.2.4 SQLAlchemy==1.4.31 sqlparse==0.4.2 srsly==1.0.5 statsmodels==0.10.2 sympy==1.7.1 tables==3.7.0 tabulate==0.8.9 tblib==1.7.0 tempora==5.0.1 tenacity==8.0.1 tensorboard==2.8.0 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorboardX==2.5 tensorflow @ 
file:///tensorflow-2.8.0-cp37-cp37m-linux_x86_64.whl tensorflow-datasets==4.0.1 tensorflow-estimator==2.8.0 tensorflow-gcs-config==2.8.0 tensorflow-hub==0.12.0 tensorflow-io-gcs-filesystem==0.24.0 tensorflow-metadata==1.6.0 tensorflow-probability==0.16.0 termcolor==1.1.0 terminado==0.13.1 testpath==0.5.0 text-unidecode==1.3 textblob==0.15.3 Theano-PyMC==1.1.2 thinc==7.4.0 threadpoolctl==3.1.0 tifffile==2021.11.2 tokenizers==0.10.3 tomli==2.0.1 toolz==0.11.2 torch==1.9.0 torchaudio @ https://download.pytorch.org/whl/cu111/torchaudio-0.10.0%2Bcu111-cp37-cp37m-linux_x86_64.whl torchsummary==1.5.1 torchtext==0.11.0 torchvision==0.10.0 tornado==5.1.1 tqdm==4.62.3 traitlets==5.1.1 transformers==4.12.3 tweepy==3.10.0 typeguard==2.7.1 typing-extensions==3.10.0.2 tzlocal==1.5.1 uritemplate==3.0.1 urllib3==1.25.11 vega-datasets==0.9.0 wandb==0.12.10 wasabi==0.9.0 wcwidth==0.2.5 webencodings==0.5.1 Werkzeug==1.0.1 widgetsnbextension==3.5.2 wordcloud==1.5.0 wrapt==1.13.3 xarray==0.18.2 xgboost==0.90 xkit==0.0.0 xlrd==1.1.0 xlwt==1.3.0 yaspin==2.1.0 zc.lockfile==2.0 zict==2.0.0 zipp==3.7.0 ```

Steps to reproduce

Follow the description of the issue.

Example source: provided above.

dirkgr commented 2 years ago

The code you posted creates an untrained 7-way classification model. The weights for the transformer are going to be the same as Huggingface's, but the final classification layer (which takes the final hidden representation from the transformer and boils it down to a 7-way decision) is randomly initialized every time you run that code. In fact, you'll find that even running this code twice doesn't produce the same result, because your classifier is randomly initialized every time.
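
You can check this yourself (a sketch; build_model() is a stand-in for the construction code you posted):

import torch

w1 = build_model()._classification_layer.weight
w2 = build_model()._classification_layer.weight

# Two fresh constructions give two different random classifier weights.
print(torch.allclose(w1, w2))  # False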

santiagxf commented 2 years ago

Hi @dirkgr! Thanks for the reply! I get your point. What is the correct way of loading the classifier layer too?

dirkgr commented 2 years ago

Where would you like your 7-way classifier to come from? I don't think there is one that comes with BERT.

santiagxf commented 2 years ago

The model is based on BERT, but it does contain a classification layer at the end. In transformers, I can load the model using BertForSequenceClassification.from_pretrained(model_path), if that helps.

dirkgr commented 2 years ago

When I run sc = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased"), I see the following warning:

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

This tells you that Huggingface also randomly initializes the final classifier. We can maybe figure out how to make sure you get the same random weights into your AllenNLP classifier, but what would be the point? They are random either way, and need to be trained.

Also, Huggingface by default initializes a 2-way classifier, not 7-way.
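
If you do want a 7-way head on the Huggingface side, you have to ask for it explicitly (sketch):

import transformers

# num_labels sets the size of the randomly initialized classification head.
sc = transformers.AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=7
)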

santiagxf commented 2 years ago

@dirkgr, this is not a vanilla BERT model. As I mentioned, the model is based on BERT (emphasis on based on), but it is not bert-base-uncased, for instance; it shares the same architecture. The model has been fine-tuned to perform a different downstream task. My model does have a classification layer at the end, and it works as expected when I run it with transformers.

I think you are right that I'm missing the initialization of the classifier on the AllenNLP side. Maybe I'm missing something like model.load(...) instead of constructing a BasicClassifier, but I'm not sure how I can load that from the model I already have. Does that help?

dirkgr commented 2 years ago

I don't think there is a built-in way to translate the classification head from Huggingface to AllenNLP. But it might work if you just do something like this:

allennlp_model._classification_layer.weight = huggingface_model.classifier.weight
allennlp_model._classification_layer.bias = huggingface_model.classifier.bias
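
Assignment works here because both sides are torch.nn.Parameter objects. An equivalent, slightly more defensive variant is an in-place copy:

import torch

# In-place copy keeps the original Parameter objects (and anything
# referencing them) intact.
with torch.no_grad():
    allennlp_model._classification_layer.weight.copy_(huggingface_model.classifier.weight)
    allennlp_model._classification_layer.bias.copy_(huggingface_model.classifier.bias)
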
santiagxf commented 2 years ago

Great tip @dirkgr! I managed to make it work with your suggestion:

from transformers import BertForSequenceClassification

model = BasicClassifier(vocab=transformer_vocab,
                        text_field_embedder=token_embedder,
                        seq2vec_encoder=transformer_encoder,
                        dropout=0.1,
                        num_labels=7)

# Copy the fine-tuned classification head from the transformers model.
classifier = BertForSequenceClassification.from_pretrained(model_name)
model._classification_layer.weight = classifier.classifier.weight
model._classification_layer.bias = classifier.classifier.bias
_ = model.eval()
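
As a sanity check, I re-ran the inference snippet from above with this model and compared against transformers directly (a sketch; it assumes the underlying Huggingface tokenizer is reachable at transformer_tokenizer.tokenizer):

import torch

with torch.no_grad():
    encoded = transformer_tokenizer.tokenizer("some sample text here",
                                              return_tensors="pt")
    hf_logits = classifier(**encoded).logits

# With the copied head, the two sets of logits now agree.
print(torch.allclose(hf_logits, outputs["logits"], atol=1e-5))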

Mystery solved! Quick question, maybe you know: the amount of code required to import a transformer into AllenNLP is significant. When I asked about this issue on Stack Overflow, someone mentioned that there were easier ways to accomplish this, but never said which one and didn't reply back. Do you know if this is the "right" way to import a transformer into AllenNLP?

I'm mainly using AllenNLP for its interpret package; that's why development of the models is done in transformers.

dirkgr commented 2 years ago

The interpret module looks for *Embedder classes, I think, so you probably need to do it this way. That said, maybe you could write your dataset reader to produce a TransformerTextField instead, and then use a PassThroughTokenEmbedder. You would only need the token embedder so that the interpret module has something to latch on to; it wouldn't actually perform any function.

I'm honestly not sure if that would work. Maybe there is an easier way to tell the interpret module where to look, so you don't need the PassThroughTokenEmbedder at all.
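
Roughly what I have in mind, as an untested sketch (hf_tokenizer is the raw Huggingface tokenizer, and hidden_dim=768 assumes a BERT-base-sized model):

from allennlp.data.fields import TransformerTextField
from allennlp.modules.token_embedders import PassThroughTokenEmbedder

# In the dataset reader: hand the transformer inputs through unchanged.
encoded = hf_tokenizer("some sample text here", return_tensors="pt")
field = TransformerTextField(**{k: v.squeeze(0) for k, v in encoded.items()})

# In the model: an embedder that returns its input as-is, purely so the
# interpret module has an *Embedder to latch on to.
embedder = PassThroughTokenEmbedder(hidden_dim=768)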

santiagxf commented 2 years ago

Thanks for the hint! I will close the issue since we already got it solved. Thanks for the insights and for helping me get this done!

tanmaylaud commented 2 years ago

@santiagxf Could you post the solution you got working for the interpret portion?

santiagxf commented 2 years ago

@tanmaylaud I wrote a blog post about it: Model interpretability — Making your model confesses: Saliency maps