allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

Error in interpreter when the input sequence is more than 1d #5300

Open zegnog opened 3 years ago

zegnog commented 3 years ago

Description

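In integrated_gradient.py, the gradient is normalized on the assumption that the input has shape (batch, seq_len, emb_size). Some inputs contain more than one sequence, however, which gives the gradient an extra dimension, e.g. (batch, num_seqs, seq_len, emb_size). In that case the normalization raises the following error: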

Python traceback:

```
 in (.0)
      2     embedding_grad = numpy.sum(grad[0], axis=1)
      3     norm = numpy.linalg.norm(embedding_grad, ord=1)
----> 4     normalized_grad = [math.fabs(e) / norm for e in embedding_grad]
      5     return normalized_grad

TypeError: only size-1 arrays can be converted to Python scalars
```
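The cause can be shown in isolation. The following is a minimal sketch, assuming a gradient of shape (1, 2, 13, 100), the same shape used in the reproduction steps of this issue. Summing over `axis=1` then leaves a 2-D array, so each element produced by iterating over it is itself a 1-D array, and `math.fabs` only accepts scalars. The variable names `row` and `raised` are illustrative only.

```python
import math

import numpy

# Assumed layout: (batch, num_seqs, seq_len, emb_size)
grad = numpy.random.rand(1, 2, 13, 100)

# grad[0] has shape (2, 13, 100); summing over axis=1 collapses seq_len,
# not the embedding dimension, and leaves a 2-D array of shape (2, 100).
embedding_grad = numpy.sum(grad[0], axis=1)

# Iterating over a 2-D array yields 1-D rows rather than scalars.
row = embedding_grad[0]

try:
    math.fabs(row)  # math.fabs needs a scalar (or a size-1 array)
    raised = False
except TypeError:
    raised = True

print(embedding_grad.shape, row.shape, raised)
```

With a 3-D gradient the same sum yields a 1-D array, so every element is a scalar and the list comprehension works as intended.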

Environment

OS: MacOS

Python version: 3.8.3

Output of pip freeze:

``` absl-py==0.11.0 allennlp==2.5.0 allennlp-models==2.5.0 altair==4.1.0 appnope @ file:///opt/concourse/worker/volumes/live/0291c9e1-4b15-459f-623e-2770f55be269/volume/appnope_1594338395037/work argon2-cffi==20.1.0 astor==0.8.1 astroid @ file:///opt/concourse/worker/volumes/live/21fd14a9-2a7e-484b-7394-5a9912cdcf80/volume/astroid_1592498459180/work astunparse==1.6.3 async-generator==1.10 attrs==20.2.0 autopep8 @ file:///tmp/build/80754af9/autopep8_1596578164842/work backcall==0.2.0 backports.csv==1.0.7 base58==2.0.1 beautifulsoup4==4.9.3 bidict==0.21.2 bleach==3.2.1 blinker==1.4 blis==0.4.1 bokeh @ file:///opt/concourse/worker/volumes/live/00176cca-b1c3-4fe7-4da0-2ea50fa27756/volume/bokeh_1599056223916/work boto3==1.15.16 botocore==1.18.16 brotlipy==0.7.0 cachetools==4.2.0 catalogue==1.0.0 certifi==2020.12.5 cffi @ file:///opt/concourse/worker/volumes/live/b9607b09-b777-4ff7-53dc-287727eb8574/volume/cffi_1600699191154/work chardet==4.0.0 checklist==0.0.11 cheroot==8.5.2 CherryPy==18.6.0 click==7.1.2 cloudpickle @ file:///tmp/build/80754af9/cloudpickle_1598884132938/work conda==4.9.2 conda-package-handling==1.7.0+0.g7c4a471.dirty configparser==5.0.2 conllu==4.4 cryptography @ file:///opt/concourse/worker/volumes/live/aeb63a26-659e-4edb-5405-74ba8e0c76f2/volume/cryptography_1601046839724/work cycler==0.10.0 cymem==2.0.3 cytoolz==0.11.0 dask @ file:///tmp/build/80754af9/dask-core_1600699676946/work dataclasses==0.6 decorator==4.4.2 defusedxml==0.6.0 dill==0.3.3 distributed @ file:///opt/concourse/worker/volumes/live/88b24afb-2456-43bf-56b4-38902ff2143a/volume/distributed_1601910790507/work docker-pycreds==0.4.0 email-reply-parser==0.5.12 en-core-web-md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-2.3.1/en_core_web_md-2.3.1.tar.gz en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz entrypoints==0.3 et-xmlfile==1.0.1 feedparser==6.0.8 filelock==3.0.12 flake8 @ 
file:///tmp/build/80754af9/flake8_1601911421857/work Flask==1.1.2 Flask-SocketIO==5.0.1 flatbuffers==1.12 fsspec @ file:///tmp/build/80754af9/fsspec_1597944003862/work ftfy==5.8 future==0.18.2 gast==0.3.3 gensim==3.8.3 gitdb==4.0.5 GitPython==3.1.11 google-api-core==1.30.0 google-auth==1.32.1 google-auth-oauthlib==0.4.2 google-cloud-core==1.7.1 google-cloud-storage==1.38.0 google-crc32c==1.1.2 google-pasta==0.2.0 google-resumable-media==1.3.1 googleapis-common-protos==1.53.0 greenlet==1.0.0 grpcio==1.34.0 h5py==2.10.0 HeapDict==1.0.1 huggingface-hub==0.0.13 idna==2.10 importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1602276842396/work iniconfig==1.0.1 ipaddress==1.0.23 ipykernel @ file:///opt/concourse/worker/volumes/live/88f541d3-5a27-498f-7391-f2e50ca36560/volume/ipykernel_1596206680118/work/dist/ipykernel-5.3.4-py3-none-any.whl ipython @ file:///opt/concourse/worker/volumes/live/90527372-a5f7-4d48-68c2-a99de2ba31c7/volume/ipython_1599056199422/work ipython-genutils==0.2.0 ipywidgets==7.5.1 iso-639==0.4.5 isort @ file:///opt/concourse/worker/volumes/live/97586196-dd0e-4a3f-4d06-330d40d66119/volume/isort_1601490198153/work itsdangerous==1.1.0 jaraco.classes==3.2.1 jaraco.collections==3.3.0 jaraco.functools==3.3.0 jaraco.text==3.5.0 jdcal==1.4.1 jedi @ file:///opt/concourse/worker/volumes/live/c63c70ef-654a-4d26-6b78-37aef096f225/volume/jedi_1596490811751/work Jinja2==2.11.2 jmespath==0.10.0 joblib==0.16.0 jsonnet==0.16.0 jsonpickle==1.4.1 jsonschema==3.2.0 jupyter==1.0.0 jupyter-client @ file:///tmp/build/80754af9/jupyter_client_1601311786391/work jupyter-console==6.4.0 jupyter-core==4.6.3 jupyterlab-pygments==0.1.2 Keras-Preprocessing==1.1.2 kiwisolver==1.2.0 langdetect==1.0.7 lazy-object-proxy==1.4.3 lmdb==1.1.1 locket==0.2.0 lxml==4.6.3 mail-parser==3.12.0 Markdown==3.3.3 MarkupSafe @ file:///opt/concourse/worker/volumes/live/cb778296-98db-45ad-411e-6f726e102dc3/volume/markupsafe_1594371638608/work matplotlib==3.3.2 mccabe==0.6.1 
mistune==0.8.4 mittens==0.2 mkl-fft==1.2.0 mkl-random==1.1.1 mkl-service==2.3.0 more-itertools==8.7.0 msgpack==1.0.0 munch==2.5.0 murmurhash==1.0.2 nbclient==0.5.1 nbconvert==6.0.7 nbformat==5.0.8 neovim==0.3.1 nest-asyncio==1.4.3 nltk==3.5 notebook==6.1.5 numpy==1.19.4 oauthlib==3.1.0 olefile==0.46 openpyxl @ file:///tmp/build/80754af9/openpyxl_1610651698508/work opt-einsum==3.3.0 overrides==3.1.0 packaging==20.9 pandas==1.2.2 pandocfilters==1.4.3 parso==0.7.0 partd==1.1.0 pathtools==0.1.2 patternfork-nosql==3.6 pdfminer.six==20201018 pexpect @ file:///opt/concourse/worker/volumes/live/8701bb20-ad87-46c7-5108-30c178cf97e5/volume/pexpect_1594383388344/work pickleshare @ file:///opt/concourse/worker/volumes/live/93ec39d8-05bb-4f84-7efc-98735bc39b70/volume/pickleshare_1594384101884/work Pillow @ file:///opt/concourse/worker/volumes/live/be1e8a56-c4be-4ffe-4fa7-5a0e9c460b1a/volume/pillow_1594307312933/work plac==1.1.3 pluggy==0.13.1 portend==2.7.1 preshed==3.0.2 prettyprint==0.1.5 prometheus-client==0.9.0 promise==2.3 prompt-toolkit @ file:///tmp/build/80754af9/prompt-toolkit_1598885458782/work protobuf==3.14.0 psutil @ file:///opt/concourse/worker/volumes/live/ff72f822-991c-4030-4f3a-8c41d3ac4e4f/volume/psutil_1598370232375/work ptyprocess==0.6.0 py==1.9.0 py-rouge==1.1 pyarrow==2.0.0 pyasn1==0.4.8 pyasn1-modules==0.2.8 pycodestyle==2.6.0 pycosat==0.6.3 pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work pydeck==0.5.0 pyeasyga==0.3.1 pyflakes==2.2.0 Pygments @ file:///tmp/build/80754af9/pygments_1600458456400/work pylint @ file:///opt/concourse/worker/volumes/live/ed0164b6-bcc7-4f6b-7dd4-ad89660b5dcb/volume/pylint_1598624018129/work pynvim==0.4.2 pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1594392929924/work pyparsing==2.4.7 pyrsistent==0.17.3 PySocks @ file:///opt/concourse/worker/volumes/live/85a5b906-0e08-41d9-6f59-084cee4e9492/volume/pysocks_1594394636991/work pyswarms==1.2.0 pytest==6.1.1 python-crfsuite==0.9.7 python-dateutil==2.8.1 
python-docx==0.8.11 python-engineio==4.0.0 python-socketio==5.0.4 pytz==2020.1 PyYAML==5.3.1 pyzmq==19.0.2 qtconsole==5.1.1 QtPy==1.9.0 QuantLib==1.20 regex==2020.9.27 requests==2.25.1 requests-oauthlib==1.3.0 rsa==4.6 ruamel-yaml==0.15.87 s3transfer==0.3.3 sacremoses==0.0.43 scikit-learn==0.23.2 scipy==1.5.2 seaborn==0.11.1 Send2Trash==1.5.0 sentencepiece==0.1.91 sentry-sdk==1.1.0 sgmllib3k==1.0.0 shortuuid==1.0.1 simplejson==3.17.0 six @ file:///opt/concourse/worker/volumes/live/5b31cb27-1e37-4ca5-6e9f-86246eb206d2/volume/six_1605205320872/work sklearn-crfsuite==0.3.6 smart-open==4.0.1 smmap==3.0.4 sortedcontainers==2.2.2 soupsieve==2.2.1 spacy==2.3.2 srsly==1.0.2 streamlit==0.83.0 subprocess32==3.5.4 tabulate==0.8.7 tblib @ file:///tmp/build/80754af9/tblib_1597928476713/work tempora==4.1.1 tensorboard==2.4.0 tensorboard-plugin-wit==1.7.0 tensorboardX==2.1 tensorflow-estimator==2.4.0 termcolor==1.1.0 terminado==0.9.1 testpath==0.4.4 textblob==0.15.3 thinc==7.4.1 threadpoolctl==2.1.0 tokenizers==0.10.1 toml @ file:///tmp/build/80754af9/toml_1592853716807/work toolz @ file:///tmp/build/80754af9/toolz_1601054250827/work torch==1.7.1 torchtext==0.7.0 torchvision==0.8.2 tornado==6.0.4 tqdm @ file:///tmp/build/80754af9/tqdm_1601923723745/work traitlets @ file:///tmp/build/80754af9/traitlets_1600712679583/work transformers==4.3.3 typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1598376058250/work tzlocal==2.1 urllib3==1.25.11 validators==0.18.2 wandb==0.10.33 wasabi==0.8.0 watchdog==1.0.2 wcwidth @ file:///tmp/build/80754af9/wcwidth_1593447189090/work webencodings==0.5.1 Werkzeug==1.0.1 widgetsnbextension==3.5.1 word2number==1.1 wrapt==1.12.1 xlrd @ file:///tmp/build/80754af9/xlrd_1608072521494/work zc.lockfile==2.0 zict==2.0.0 zipp @ file:///tmp/build/80754af9/zipp_1604001098328/work ```

Steps to reproduce

Run the following code. `old_norm_grad` is the current implementation and `new_norm_grad` is the proposed replacement.

Example source:

```python
# To add a new cell, type '# %%'
# To add a new markdown cell, type '# %% [markdown]'
# %%
import numpy
import math

# %%
def old_norm_grad(grad: numpy.ndarray):
    embedding_grad = numpy.sum(grad[0], axis=1)
    norm = numpy.linalg.norm(embedding_grad, ord=1)
    normalized_grad = [math.fabs(e) / norm for e in embedding_grad]
    return normalized_grad

# %%
def new_norm_grad(grad: numpy.ndarray):
    embedding_grad = numpy.sum(grad[0], axis=-1)
    norm = numpy.linalg.norm(embedding_grad, ord=1, keepdims=True)
    normalized_grad = embedding_grad / norm
    return normalized_grad

# %%
test_grad = numpy.random.rand(1,13,100)
numpy.testing.assert_array_equal(old_norm_grad(test_grad), new_norm_grad(test_grad))
# No error raised

# %%
test_grad = numpy.random.rand(1,2,13,100)
numpy.testing.assert_array_equal(old_norm_grad(test_grad), new_norm_grad(test_grad))
# raised TypeError: only size-1 arrays can be converted to Python scalars
# %%
```
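Two details of the proposed function may be worth checking before a PR. It drops the absolute value that the current code applies through `math.fabs`, and for a 2-D input `numpy.linalg.norm(..., ord=1)` computes the matrix 1-norm (the maximum absolute column sum) rather than an element-wise L1 norm. The sketch below is one possible variant, not the code from the linked commit and not part of AllenNLP. The name `l1_normalize_grad` is hypothetical, and normalizing each sequence separately along the last axis is an assumption about the intended behavior.

```python
import numpy


def l1_normalize_grad(grad: numpy.ndarray) -> numpy.ndarray:
    """Hypothetical helper: per-sequence L1-normalized saliency scores.

    Assumes grad has shape (batch, ..., seq_len, emb_size) and, like the
    functions above, only looks at the first item in the batch.
    """
    # Collapse the embedding dimension, wherever the sequence dimensions sit.
    embedding_grad = numpy.sum(grad[0], axis=-1)
    # Keep the absolute value, matching math.fabs in the current code.
    abs_grad = numpy.abs(embedding_grad)
    # L1 norm over the token axis; keepdims lets the division broadcast.
    norm = abs_grad.sum(axis=-1, keepdims=True)
    return abs_grad / norm


# randn produces negative values too, which exercises the absolute value.
grad_3d = numpy.random.randn(1, 13, 100)
grad_4d = numpy.random.randn(1, 2, 13, 100)

scores_3d = l1_normalize_grad(grad_3d)  # shape (13,)
scores_4d = l1_normalize_grad(grad_4d)  # shape (2, 13)
```

For a 3-D gradient this computes the same values as `old_norm_grad`. For a 4-D gradient each sequence's scores sum to 1. If the scores should instead be normalized jointly across all sequences, the sum would need to run over every axis of `abs_grad`. An all-zero gradient would cause a division by zero, which this sketch does not guard against.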

I have created a commit that addresses this issue and can open a PR at any time if that would be useful: https://github.com/zegnog/allennlp/commit/2415cb96d424a71e8ecf6bf1b6e1755eb35de624

epwalsh commented 3 years ago

Hi @zegnog, yes please submit a PR when you get a chance.

github-actions[bot] commented 3 years ago

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇