santiagxf closed this issue 2 years ago
The code you posted creates an untrained 7-way classification model. The weights for the transformer are going to be the same as Huggingface's, but the final classification layer (which takes the final hidden representation from the transformer and boils it down to a 7-way classifier) is randomly initialized every time you run that code. In fact, you'll find that even running this code twice doesn't produce the same result, because your classifier is randomly initialized every time.
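As a rough illustration of the point above, here is a dependency-free sketch in plain Python. `LinearHead` is a hypothetical stand-in for a classification layer, not the AllenNLP or Huggingface class, and the sizes and standard deviation are made up. It only shows that two freshly created heads give different logits for the same encoder output.

```python
import random


class LinearHead:
    """Toy stand-in for a classification layer: logits = W @ h + b."""

    def __init__(self, hidden_size, num_labels, rng):
        # Weights are drawn when the head is created, so every new head differs.
        self.weight = [
            [rng.gauss(0.0, 0.02) for _ in range(hidden_size)]
            for _ in range(num_labels)
        ]
        self.bias = [0.0] * num_labels

    def __call__(self, hidden):
        return [
            sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(self.weight, self.bias)
        ]


# Seeded only so this example is repeatable; each head still gets its own draw.
rng = random.Random(0)

# Pretend encoder output: identical for both heads, like shared pretrained weights.
hidden = [0.1 * i for i in range(8)]

head_a = LinearHead(hidden_size=8, num_labels=7, rng=rng)
head_b = LinearHead(hidden_size=8, num_labels=7, rng=rng)

logits_a = head_a(hidden)
logits_b = head_b(hidden)

print("same logits:", logits_a == logits_b)
```

The encoder output is shared, so any difference in the logits comes from the head alone.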
Hi @dirkgr! Thanks for the reply! I got your point. What is the correct way of loading the classifier layer too?
Where would you like your 7-way classifier to come from? I don't think there is one that comes with BERT.
The model is based on BERT, but it does contain a classification layer at the end. In `transformers`, I can load the model using `BertForSequenceClassification.from_pretrained(model_path)`, if this helps.
When I run `sc = transformers.AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")`, I see the following warning:
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
This tells you that Huggingface also randomly initializes the final classifier. We can maybe figure out how to make sure you get the same random weights into your AllenNLP classifier, but what would be the point? They are random either way, and need to be trained.
Also, Huggingface by default initializes a 2-way classifier, not 7-way.
@dirkgr, this is not a vanilla BERT model. As I mentioned, the model is based on BERT (emphasis on the "based on"), but it is not `bert-base-uncased`, for instance. It is based on it, meaning it has the same overall architecture. The model has been fine-tuned to perform a different downstream task. My model does have a classification layer at the end, and it works as expected if I run it.
I think you are right that I'm missing the initialization of the classifier on `AllenNLP`. Maybe I'm missing something like `model.load(...)` instead of `BasicClassifier`. But I'm not sure how I can load that from the model I already have. Does that help?
I don't think there is a built-in way to translate the classification head from Huggingface to AllenNLP. But it might work if you just do something like this:
allennlp_model._classification_layer.weight = huggingface_model.classifier.weight
allennlp_model._classification_layer.bias = huggingface_model.classifier.bias
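Since a default head may be 2-way while the fine-tuned one is 7-way, it can help to check shapes before copying. The sketch below is plain Python with a hypothetical `Head` container and `copy_head` helper (neither is part of AllenNLP or Huggingface, and the hidden size of 4 is made up). It only illustrates failing loudly on a mismatch instead of copying a head of the wrong size.

```python
class Head:
    """Hypothetical container for a classification layer's parameters."""

    def __init__(self, num_labels, hidden_size, fill=0.0):
        self.weight = [[fill] * hidden_size for _ in range(num_labels)]
        self.bias = [fill] * num_labels

    @property
    def shape(self):
        # (number of labels, hidden size)
        return (len(self.weight), len(self.weight[0]))


def copy_head(source, target):
    """Copy weight and bias from source to target, refusing mismatched shapes."""
    if source.shape != target.shape:
        raise ValueError(
            f"head shape mismatch: source {source.shape} vs target {target.shape}"
        )
    # Copy the values so the two heads do not share the same lists.
    target.weight = [row[:] for row in source.weight]
    target.bias = list(source.bias)


trained = Head(num_labels=7, hidden_size=4, fill=0.5)  # stands in for the fine-tuned head
fresh = Head(num_labels=7, hidden_size=4)              # stands in for the new, untrained head
copy_head(trained, fresh)

two_way = Head(num_labels=2, hidden_size=4)            # stands in for a default 2-way head
try:
    copy_head(two_way, fresh)
except ValueError as err:
    mismatch_message = str(err)
    print(mismatch_message)
```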
Great tip @dirkgr! I managed to make it work with your suggestion:
from transformers import BertForSequenceClassification
model = BasicClassifier(vocab=transformer_vocab, text_field_embedder=token_embedder, seq2vec_encoder=transformer_encoder, dropout=0.1, num_labels=7)
classifier = BertForSequenceClassification.from_pretrained(model_name)
model._classification_layer.weight = classifier.classifier.weight
model._classification_layer.bias = classifier.classifier.bias
_ = model.eval()
Mystery solved! Quick question, maybe you know. The amount of code required to import a transformer in AllenNLP is significant. When I asked about this issue on Stack Overflow, one person mentioned that there were easier ways to accomplish this (but he never mentioned which one and didn't reply back). Do you know if this is the "right" way to import a transformer into AllenNLP?
I'm mainly using `AllenNLP` for their interpret package; that's why development of the models is done on `transformers`.
The interpret module looks for `*Embedder` classes I think, so you probably need to do it this way. That said, maybe you could write your dataset reader to produce a `TransformerTextField` instead, and then use a `PassThroughTokenEmbedder`. You would only need the token embedder so the interpret module has something to latch on to. It doesn't actually perform any function.
I'm honestly not sure if that would work. Maybe there is an easier way to tell the interpret module where to look, so you don't need the `PassThroughTokenEmbedder` at all.
Thanks for the hint! I will close the issue since we already got it solved. Thanks for the insights and helping me to get it done!
@santiagxf Could you post the solution you got working for the interpret portion?
@tanmaylaud I wrote a blog post about it: Model interpretability — Making your model confesses: Saliency maps
Description
When loading a model created with HuggingFace (the `transformers` library) into AllenNLP, the model works, but its predictions are different from what I get when running the model with `transformers`. Both models are running in evaluation mode.
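One way to make "predictions are different" concrete is to feed the same input to both models and compare the raw logits within a tolerance. The sketch below is plain Python. The helper names, the tolerance, and all the logit values are made up for illustration and are not taken from the actual models. With PyTorch tensors, `torch.allclose` would play a similar role.

```python
def max_abs_diff(logits_a, logits_b):
    """Largest element-wise gap between two logit vectors of equal length."""
    if len(logits_a) != len(logits_b):
        raise ValueError("logit vectors have different lengths")
    return max(abs(a - b) for a, b in zip(logits_a, logits_b))


def logits_match(logits_a, logits_b, tol=1e-4):
    """True when every logit agrees within tol (tol is an arbitrary choice)."""
    return max_abs_diff(logits_a, logits_b) <= tol


# Made-up 7-way logits for one input sentence.
reference = [2.1, -0.3, 0.7, -1.2, 0.0, 0.4, -0.9]     # e.g. from the original model
reloaded_ok = [2.1, -0.3, 0.7, -1.2, 0.0, 0.4, -0.9]   # a faithful reload
reloaded_bad = [0.2, 0.1, -0.4, 0.9, 0.3, -0.6, 0.5]   # a reload with a random head

print("faithful reload matches:", logits_match(reference, reloaded_ok))
print("random head matches:", logits_match(reference, reloaded_bad))
```

Comparing logits rather than only the predicted label catches partial mismatches that happen to leave the top class unchanged.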
Details
I have a custom classification model, based on a BERT model and trained using the `transformers` library. The model classifies text into 7 different categories. It is persisted in a directory using:
I'm trying to load this persisted model using the `allennlp` library for further analysis. However, when running the model inside the `allennlp` framework, its predictions are very different from the ones I get when I run it using `transformers`, which leads me to think some part of the loading was not done correctly. There are no errors during inference; the predictions just don't match.
There is little documentation about how to load an existing model (persisted in a path, for instance), so I'm wondering if someone has faced the same situation before. There is just one example of how to do QA classification with RoBERTa, but I couldn't extrapolate from it to what I'm looking for. This is how I'm loading the trained model:
I also had to implement my own DatasetReader as follows:
It is instantiated as follows:
To run the model and test out if it works I'm doing the following:
Environment
OS: Ubuntu 20.04 LTS
Python version: 3.8
Output of pip freeze
:``` absl-py==1.0.0 alabaster==0.7.12 albumentations==0.1.12 allennlp==2.9.0 altair==4.2.0 appdirs==1.4.4 argon2-cffi==21.3.0 argon2-cffi-bindings==21.2.0 arviz==0.11.4 astor==0.8.1 astropy==4.3.1 astunparse==1.6.3 atari-py==0.2.9 atomicwrites==1.4.0 attrs==21.4.0 audioread==2.1.9 autograd==1.3 Babel==2.9.1 backcall==0.2.0 backports.csv==1.0.7 base58==2.1.1 beautifulsoup4==4.6.3 bleach==4.1.0 blis==0.4.1 bokeh==2.3.3 boto3==1.21.9 botocore==1.24.9 Bottleneck==1.3.2 branca==0.4.2 bs4==0.0.1 CacheControl==0.12.10 cached-path==1.0.2 cached-property==1.5.2 cachetools==4.2.4 catalogue==1.0.0 certifi==2021.10.8 cffi==1.15.0 cftime==1.5.2 chardet==3.0.4 charset-normalizer==2.0.12 checklist==0.0.11 cheroot==8.6.0 CherryPy==18.6.1 click==7.1.2 cloudpickle==1.3.0 cmake==3.12.0 cmdstanpy==0.9.5 colorcet==3.0.0 colorlover==0.3.0 community==1.0.0b1 contextlib2==0.5.5 convertdate==2.4.0 coverage==3.7.1 coveralls==0.5 crcmod==1.7 cryptography==36.0.1 cufflinks==0.17.3 cupy-cuda111==9.4.0 cvxopt==1.2.7 cvxpy==1.0.31 cycler==0.11.0 cymem==2.0.6 Cython==0.29.28 daft==0.0.4 dask==2.12.0 datascience==0.10.6 debugpy==1.0.0 decorator==4.4.2 defusedxml==0.7.1 descartes==1.1.0 dill==0.3.4 distributed==1.25.3 dlib @ file:///dlib-19.18.0-cp37-cp37m-linux_x86_64.whl dm-tree==0.1.6 docker-pycreds==0.4.0 docopt==0.6.2 docutils==0.17.1 dopamine-rl==1.0.5 earthengine-api==0.1.299 easydict==1.9 ecos==2.0.10 editdistance==0.5.3 en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz entrypoints==0.4 ephem==4.1.3 et-xmlfile==1.1.0 fa2==0.3.5 fairscale==0.4.5 fastai==1.0.61 fastdtw==0.3.4 fastprogress==1.0.2 fastrlock==0.8 fbprophet==0.7.1 feather-format==0.4.1 feedparser==6.0.8 filelock==3.4.2 firebase-admin==4.4.0 fix-yahoo-finance==0.0.22 Flask==1.1.4 flatbuffers==2.0 folium==0.8.3 future==0.16.0 gast==0.5.3 GDAL==2.2.2 gdown==4.2.1 gensim==3.6.0 geographiclib==1.52 geopy==1.17.0 gin-config==0.5.0 gitdb==4.0.9 GitPython==3.1.27 
glob2==0.7 google==2.0.3 google-api-core==1.26.3 google-api-python-client==1.12.10 google-auth==1.35.0 google-auth-httplib2==0.0.4 google-auth-oauthlib==0.4.6 google-cloud-bigquery==1.21.0 google-cloud-bigquery-storage==1.1.0 google-cloud-core==1.7.2 google-cloud-datastore==1.8.0 google-cloud-firestore==1.7.0 google-cloud-language==1.2.0 google-cloud-storage==1.40.0 google-cloud-translate==1.5.0 google-colab @ file:///colabtools/dist/google-colab-1.0.0.tar.gz google-crc32c==1.3.0 google-pasta==0.2.0 google-resumable-media==1.3.3 googleapis-common-protos==1.54.0 googledrivedownloader==0.4 graphviz==0.10.1 greenlet==1.1.2 grpcio==1.44.0 gspread==3.4.2 gspread-dataframe==3.0.8 gym==0.17.3 h5py==3.1.0 HeapDict==1.0.1 hijri-converter==2.2.3 holidays==0.10.5.2 holoviews==1.14.8 html5lib==1.0.1 httpimport==0.5.18 httplib2==0.17.4 httplib2shim==0.0.3 huggingface-hub==0.2.1 humanize==0.5.1 hyperopt==0.1.2 ideep4py==2.0.0.post3 idna==2.10 imageio==2.4.1 imagesize==1.3.0 imbalanced-learn==0.8.1 imblearn==0.0 imgaug==0.2.9 importlib-metadata==4.11.1 importlib-resources==5.4.0 imutils==0.5.4 inflect==2.1.0 iniconfig==1.1.1 intel-openmp==2022.0.2 intervaltree==2.1.0 ipykernel==4.10.1 ipython==5.5.0 ipython-genutils==0.2.0 ipython-sql==0.3.9 ipywidgets==7.6.5 iso-639==0.4.5 itsdangerous==1.1.0 jaraco.classes==3.2.1 jaraco.collections==3.5.1 jaraco.context==4.1.1 jaraco.functools==3.5.0 jaraco.text==3.7.0 jax==0.3.1 jaxlib @ https://storage.googleapis.com/jax-releases/cuda11/jaxlib-0.3.0+cuda11.cudnn805-cp37-none-manylinux2010_x86_64.whl jedi==0.18.1 jieba==0.42.1 Jinja2==2.11.3 jmespath==0.10.0 joblib==1.1.0 jpeg4py==0.1.4 jsonnet==0.18.0 jsonschema==4.3.3 jupyter==1.0.0 jupyter-client==5.3.5 jupyter-console==5.2.0 jupyter-core==4.9.2 jupyterlab-pygments==0.1.2 jupyterlab-widgets==1.0.2 kaggle==1.5.12 kapre==0.3.7 keras==2.8.0 Keras-Preprocessing==1.1.2 keras-vis==0.4.1 kiwisolver==1.3.2 korean-lunar-calendar==0.2.1 libclang==13.0.0 librosa==0.8.1 lightgbm==2.2.3 llvmlite==0.34.0 
lmdb==0.99 LunarCalendar==0.0.9 lxml==4.2.6 Markdown==3.3.6 MarkupSafe==2.0.1 matplotlib==3.2.2 matplotlib-inline==0.1.3 matplotlib-venn==0.11.6 missingno==0.5.0 mistune==0.8.4 mizani==0.6.0 mkl==2019.0 mlxtend==0.14.0 more-itertools==8.12.0 moviepy==0.2.3.5 mpmath==1.2.1 msgpack==1.0.3 multiprocess==0.70.12.2 multitasking==0.0.10 munch==2.5.0 murmurhash==1.0.6 music21==5.5.0 natsort==5.5.0 nbclient==0.5.11 nbconvert==5.6.1 nbformat==5.1.3 nest-asyncio==1.5.4 netCDF4==1.5.8 networkx==2.6.3 nibabel==3.0.2 nltk==3.2.5 notebook==5.3.1 numba==0.51.2 numexpr==2.8.1 numpy==1.21.5 nvidia-ml-py3==7.352.0 oauth2client==4.1.3 oauthlib==3.2.0 okgrade==0.4.3 opencv-contrib-python==4.1.2.30 opencv-python==4.1.2.30 openpyxl==3.0.9 opt-einsum==3.3.0 osqp==0.6.2.post0 packaging==21.3 palettable==3.3.0 pandas==1.3.5 pandas-datareader==0.9.0 pandas-gbq==0.13.3 pandas-profiling==1.4.1 pandocfilters==1.5.0 panel==0.12.1 param==1.12.0 parso==0.8.3 pathlib==1.0.1 pathtools==0.1.2 patsy==0.5.2 patternfork-nosql==3.6 pdfminer.six==20211012 pep517==0.12.0 pexpect==4.8.0 pickleshare==0.7.5 Pillow==7.1.2 pip-tools==6.2.0 plac==1.1.3 plotly==5.5.0 plotnine==0.6.0 pluggy==0.7.1 pooch==1.6.0 portend==3.1.0 portpicker==1.3.9 prefetch-generator==1.0.1 preshed==3.0.6 prettytable==3.1.1 progressbar2==3.38.0 prometheus-client==0.13.1 promise==2.3 prompt-toolkit==1.0.18 protobuf==3.17.3 psutil==5.4.8 psycopg2==2.7.6.1 ptyprocess==0.7.0 py==1.11.0 pyarrow==6.0.1 pyasn1==0.4.8 pyasn1-modules==0.2.8 pycocotools==2.0.4 pycparser==2.21 pyct==0.4.8 pydata-google-auth==1.3.0 pydot==1.3.0 pydot-ng==2.0.0 pydotplus==2.0.2 PyDrive==1.3.1 pyemd==0.5.1 pyerfa==2.0.0.1 pyglet==1.5.0 Pygments==2.6.1 pygobject==3.26.1 pymc3==3.11.4 PyMeeus==0.5.11 pymongo==4.0.1 pymystem3==0.2.0 PyOpenGL==3.1.5 pyparsing==3.0.7 pyrsistent==0.18.1 pysndfile==1.3.8 PySocks==1.7.1 pystan==2.19.1.1 pytest==3.6.4 python-apt==0.0.0 python-chess==0.23.11 python-dateutil==2.8.2 python-docx==0.8.11 python-louvain==0.16 python-slugify==6.0.1 
python-utils==3.1.0 pytz==2018.9 pyviz-comms==2.1.0 PyWavelets==1.2.0 PyYAML==6.0 pyzmq==22.3.0 qdldl==0.1.5.post0 qtconsole==5.2.2 QtPy==2.0.1 regex==2019.12.20 requests==2.23.0 requests-oauthlib==1.3.1 resampy==0.2.2 rpy2==3.4.5 rsa==4.8 s3transfer==0.5.2 sacremoses==0.0.47 scikit-image==0.18.3 scikit-learn==1.0.2 scipy==1.4.1 screen-resolution-extra==0.0.0 scs==3.1.0 seaborn==0.11.2 semver==2.13.0 Send2Trash==1.8.0 sentencepiece==0.1.96 sentry-sdk==1.5.6 setuptools-git==1.2 sgmllib3k==1.0.0 Shapely==1.8.1.post1 shortuuid==1.0.8 simplegeneric==0.8.1 six==1.15.0 sklearn==0.0 sklearn-pandas==1.8.0 smart-open==5.2.1 smmap==5.0.0 snowballstemmer==2.2.0 sortedcontainers==2.4.0 SoundFile==0.10.3.post1 spacy==2.2.4 Sphinx==1.8.6 sphinxcontrib-serializinghtml==1.1.5 sphinxcontrib-websupport==1.2.4 SQLAlchemy==1.4.31 sqlparse==0.4.2 srsly==1.0.5 statsmodels==0.10.2 sympy==1.7.1 tables==3.7.0 tabulate==0.8.9 tblib==1.7.0 tempora==5.0.1 tenacity==8.0.1 tensorboard==2.8.0 tensorboard-data-server==0.6.1 tensorboard-plugin-wit==1.8.1 tensorboardX==2.5 tensorflow @ file:///tensorflow-2.8.0-cp37-cp37m-linux_x86_64.whl tensorflow-datasets==4.0.1 tensorflow-estimator==2.8.0 tensorflow-gcs-config==2.8.0 tensorflow-hub==0.12.0 tensorflow-io-gcs-filesystem==0.24.0 tensorflow-metadata==1.6.0 tensorflow-probability==0.16.0 termcolor==1.1.0 terminado==0.13.1 testpath==0.5.0 text-unidecode==1.3 textblob==0.15.3 Theano-PyMC==1.1.2 thinc==7.4.0 threadpoolctl==3.1.0 tifffile==2021.11.2 tokenizers==0.10.3 tomli==2.0.1 toolz==0.11.2 torch==1.9.0 torchaudio @ https://download.pytorch.org/whl/cu111/torchaudio-0.10.0%2Bcu111-cp37-cp37m-linux_x86_64.whl torchsummary==1.5.1 torchtext==0.11.0 torchvision==0.10.0 tornado==5.1.1 tqdm==4.62.3 traitlets==5.1.1 transformers==4.12.3 tweepy==3.10.0 typeguard==2.7.1 typing-extensions==3.10.0.2 tzlocal==1.5.1 uritemplate==3.0.1 urllib3==1.25.11 vega-datasets==0.9.0 wandb==0.12.10 wasabi==0.9.0 wcwidth==0.2.5 webencodings==0.5.1 Werkzeug==1.0.1 
widgetsnbextension==3.5.2 wordcloud==1.5.0 wrapt==1.13.3 xarray==0.18.2 xgboost==0.90 xkit==0.0.0 xlrd==1.1.0 xlwt==1.3.0 yaspin==2.1.0 zc.lockfile==2.0 zict==2.0.0 zipp==3.7.0 ```
Steps to reproduce
Follow the description of the issue.
Example source:
``` ## Provided above. ```