allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

T5 Module Self Attention Overwrites Dropout Configuration #5683

Closed MSLars closed 1 year ago

MSLars commented 2 years ago


Description

When I create a T5Module and configure a custom dropout rate like this

T5Module.from_pretrained_module(
    model_name,
    beam_search=beam_search,
    ddp_accelerator=self.ddp_accelerator,
    auto_config_kwargs={"dropout_rate": 0.0},  # custom rate that does not reach the stacks
    checkpoint_wrapper=checkpoint_wrapper,
    output_attentions=self.interpretation,
    weights_path=weights_path,
    tie_word_embeddings=tie_word_embeddings,
    label_smoothing=label_smoothing,
)

the value is replaced by the default configuration inside T5EncoderStack. The decoder (T5DecoderStack) behaves the same way.

It does not seem to be possible to set a dropout rate other than the default of 0.1.
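Until the stacks forward the configured rate, one possible workaround is to overwrite the rate after the model has been built. This is only a sketch: it assumes the affected layers are torch.nn.Dropout submodules, which has not been checked against AllenNLP's T5 code, and any rate stored as a plain float (applied through torch.nn.functional.dropout) would not be changed. override_dropout is a hypothetical helper, not part of AllenNLP.

```python
from torch import nn


def override_dropout(model: nn.Module, p: float) -> int:
    """Set the drop probability of every nn.Dropout submodule of ``model`` to ``p``.

    Workaround sketch: only nn.Dropout modules are touched. A rate kept as a
    plain float and passed to torch.nn.functional.dropout is NOT changed.
    Returns the number of nn.Dropout modules that were updated.
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"dropout probability must be in [0, 1], got {p}")
    updated = 0
    for module in model.modules():  # includes model itself and all descendants
        if isinstance(module, nn.Dropout):
            # nn.Dropout.forward() reads self.p, so the new rate applies on the next call.
            module.p = p
            updated += 1
    return updated
```

For example, calling override_dropout(t5, 0.0) on the module returned by T5Module.from_pretrained_module would set every nn.Dropout layer to 0.0 (t5 is a placeholder name). Because nn.Dropout reads self.p in forward(), the new rate is used by subsequent forward passes without rebuilding the model.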


Related issues or possible duplicates

Proposed Solution

This can be fixed by passing the dropout value on to T5LayerSelfAttention in the T5Block call:

block = T5Block(
    attention=T5LayerSelfAttention(
        self_attention=block_self_attention.construct(
            is_decoder=False, has_relative_attention_bias=(i == 0)
        ),
        dropout=dropout,  # proposed: forward the stack's dropout rate
    ),
    # ... remaining T5Block arguments unchanged
)

If this is indeed the desired behavior, I can prepare a pull request.
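The underlying pattern can be reduced to a few lines of plain Python: the stack builder accepts a dropout argument but constructs each attention layer without it, so the layer's own default wins. The names below (SelfAttentionLayer, build_encoder_buggy, build_encoder_fixed) are hypothetical simplifications of T5LayerSelfAttention and T5EncoderStack, not AllenNLP's API.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class SelfAttentionLayer:
    """Hypothetical stand-in for T5LayerSelfAttention; its rate defaults to 0.1."""

    dropout: float = 0.1


def build_encoder_buggy(num_blocks: int, dropout: float = 0.1) -> List[SelfAttentionLayer]:
    # Bug pattern: `dropout` is accepted but never forwarded, so every layer
    # silently falls back to its own default of 0.1.
    return [SelfAttentionLayer() for _ in range(num_blocks)]


def build_encoder_fixed(num_blocks: int, dropout: float = 0.1) -> List[SelfAttentionLayer]:
    # Fix pattern: pass the stack-level rate on to each layer it constructs.
    return [SelfAttentionLayer(dropout=dropout) for _ in range(num_blocks)]


print([layer.dropout for layer in build_encoder_buggy(2, dropout=0.0)])  # [0.1, 0.1]
print([layer.dropout for layer in build_encoder_fixed(2, dropout=0.0)])  # [0.0, 0.0]
```

With dropout=0.0, the buggy builder still produces layers with 0.1, while the fixed one forwards 0.0, which mirrors the change proposed for the T5Block call.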

MSLars commented 1 year ago

Environment

Python 3.9, Ubuntu 10.04

pip freeze:

absl-py==1.0.0
aiohttp==3.8.1
aiosignal==1.2.0
alembic==1.8.0
allennlp==2.8.0
allennlp-models==2.8.0
allennlp-optuna==0.1.7
altair==4.2.0
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
astor==0.8.1
asttokens==2.0.5
astunparse==1.6.3
async-timeout==4.0.2
atomicwrites @ file:///home/conda/feedstock_root/build_artifacts/atomicwrites_1588182545583/work
attrs==21.4.0
autopage==0.5.1
backcall==0.2.0
backports.csv==1.0.7
base58==2.1.1
beautifulsoup4==4.10.0
black==21.12b0
bleach==4.1.0
blinker==1.4
blis==0.7.5
boto3==1.20.37
botocore==1.23.37
bottle==0.12.21
cached-path==0.3.2
cachetools==4.2.4
catalogue==2.0.6
certifi==2021.10.8
cffi==1.15.0
chardet==4.0.0
charset-normalizer==2.0.10
checklist==0.0.11
cheroot==8.6.0
CherryPy==18.6.1
click==8.0.3
cliff==3.10.1
cmaes==0.8.2
cmd2==2.4.1
colorlog==6.6.0
configparser==5.2.0
conllu==4.4.1
cryptography==36.0.1
cycler @ file:///home/conda/feedstock_root/build_artifacts/cycler_1635519461629/work
cymem==2.0.6
datasets==1.17.0
de-core-news-sm @ https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.1.0/de_core_news_sm-3.1.0-py3-none-any.whl
debugpy==1.5.1
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.4
docker-pycreds==0.4.0
en-core-web-md @ https://github.com/explosion/spacy-models/releases/download/en_core_web_md-3.1.0/en_core_web_md-3.1.0-py3-none-any.whl
entrypoints==0.3
executing==0.8.2
fairscale==0.4.0
feedparser==6.0.8
filelock==3.3.2
flatbuffers==2.0
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1651017733844/work
frozenlist==1.2.0
fsspec==2022.1.0
ftfy==6.0.3
future==0.18.2
gast==0.4.0
gitdb==4.0.9
GitPython==3.1.26
google-api-core==2.4.0
google-auth==2.3.3
google-auth-oauthlib==0.4.6
google-cloud-core==2.2.1
google-cloud-storage==1.44.0
google-crc32c==1.3.0
google-pasta==0.2.0
google-resumable-media==2.1.0
googleapis-common-protos==1.54.0
greenlet==1.1.2
grpcio==1.43.0
h5py==3.6.0
huggingface-hub==0.1.2
idna==3.3
importlib-metadata==4.10.1
iniconfig==1.1.1
ipykernel==6.7.0
ipython==8.0.0
ipython-genutils==0.2.0
ipywidgets==7.6.5
iso-639==0.4.5
jaraco.classes==3.2.1
jaraco.collections==3.5.1
jaraco.functools==3.5.0
jaraco.text==3.6.0
jedi==0.18.1
Jinja2==3.0.3
jmespath==0.10.0
joblib==1.1.0
Js2Py @ file:///home/conda/feedstock_root/build_artifacts/js2py_1617952197358/work
jsonnet==0.18.0
jsonschema==4.4.0
jupyter==1.0.0
jupyter-client==7.1.1
jupyter-console==6.4.0
jupyter-core==4.9.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.2
keras==2.7.0
Keras-Preprocessing==1.1.2
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/kiwisolver_1648854392795/work
lark @ file:///home/conda/feedstock_root/build_artifacts/lark_1636976757037/work
libclang==12.0.0
lmdb==1.3.0
lxml==4.7.1
Mako==1.2.0
Markdown==3.3.6
MarkupSafe==2.0.1
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1639359646028/work
matplotlib-inline==0.1.3
mistune==0.8.4
more-itertools==8.12.0
multidict==5.2.0
multiprocess==0.70.12.2
munch==2.5.0
munkres==1.1.4
murmurhash==1.0.6
mypy-extensions==0.4.3
nbclient==0.5.10
nbconvert==6.4.0
nbformat==5.1.3
nest-asyncio==1.5.4
nltk==3.6.3
notebook==6.4.7
numpy @ file:///home/conda/feedstock_root/build_artifacts/numpy_1651020388495/work
oauthlib==3.1.1
olefile @ file:///home/conda/feedstock_root/build_artifacts/olefile_1602866521163/work
opt-einsum==3.3.0
optuna==2.10.1
optuna-dashboard==0.7.1
overrides==3.1.0
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1637239678211/work
pandas==1.4.2
pandocfilters==1.5.0
parso==0.8.3
pathspec==0.9.0
pathtools==0.1.2
pathy==0.6.1
patsy @ file:///home/conda/feedstock_root/build_artifacts/patsy_1632667180946/work
patternfork-nosql==3.6
pbr==5.9.0
pdfminer.six==20211012
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.0.0
platformdirs==2.4.1
pluggy==1.0.0
portend==3.1.0
preshed==3.0.6
prettytable==3.3.0
prometheus-client==0.12.0
promise==2.3
prompt-toolkit==3.0.24
protobuf==3.19.3
psutil==5.9.0
ptyprocess==0.7.0
pure-eval==0.2.1
py==1.11.0
py-rouge==1.1
pyarrow==6.0.1
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.21
pydantic==1.8.2
pydeck==0.7.1
Pygments==2.11.2
pyjsparser==2.7.1
Pympler==1.0.1
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/pyparsing_1649603503565/work
pyperclip==1.8.2
pyrsistent==0.18.1
pytest==6.2.5
pytest-lazy-fixture==0.6.3
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/python-dateutil_1626286286081/work
python-docx==0.8.11
pytz @ file:///home/conda/feedstock_root/build_artifacts/pytz_1647961439546/work
pytz-deprecation-shim @ file:///home/conda/feedstock_root/build_artifacts/pytz-deprecation-shim_1637074462379/work
PyYAML==6.0
pyzmq==22.3.0
qtconsole==5.2.2
QtPy==2.0.0
regex @ file:///home/conda/feedstock_root/build_artifacts/regex_1642551759955/work
requests==2.27.1
requests-oauthlib==1.3.0
rsa==4.8
s3transfer==0.5.0
sacremoses==0.0.47
scikit-learn==1.0.2
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1644357253444/work
seaborn @ file:///home/conda/feedstock_root/build_artifacts/seaborn-split_1629095986539/work
Send2Trash==1.8.0
sentencepiece==0.1.96
sentry-sdk==1.5.2
sgmllib3k==1.0.0
shortuuid==1.0.8
six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
smart-open==5.2.1
smmap==5.0.0
soupsieve==2.3.1
spacy==3.1.4
spacy-legacy==3.0.8
SQLAlchemy==1.4.39
sqlitedict==1.7.0
srsly==2.4.2
stack-data==0.1.4
statsmodels @ file:///home/conda/feedstock_root/build_artifacts/statsmodels_1644535581977/work
stevedore==3.5.0
streamlit==1.4.0
subprocess32==3.5.4
tempora==5.0.0
tensorboard==2.7.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorboardX==2.4.1
tensorflow-estimator==2.7.0
tensorflow-io-gcs-filesystem==0.23.1
termcolor==1.1.0
terminado==0.12.1
termplotlib==0.3.9
testpath==0.5.0
thinc==8.0.13
threadpoolctl==3.0.0
tokenizers==0.10.3
toml==0.10.2
tomli==1.2.3
toolz==0.11.2
torch==1.10.1
torchvision==0.11.2
tornado==6.1
tqdm==4.62.3
traitlets==5.1.1
transformers==4.12.5
typer==0.4.0
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1638334978229/work
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1635187324762/work
tzlocal @ file:///home/conda/feedstock_root/build_artifacts/tzlocal_1637088138782/work
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1649111919389/work
urllib3==1.26.8
validators==0.18.2
wandb==0.12.9
wasabi==0.9.0
watchdog==2.1.6
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.2
widgetsnbextension==3.5.2
word2number==1.1
wrapt==1.13.3
xxhash==2.0.2
yarl==1.7.2
yaspin==2.1.0
zc.lockfile==2.0
zipp==3.7.0

Steps to Reproduce

Create a model that takes a T5Module as an argument and set the dropout rate to 0. Then inspect the dropout or dropout_rate of every submodule that was created.
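One way to run this check is to walk model.named_modules() and record every dropout setting found. A sketch, assuming dropout appears either as an nn.Dropout submodule or as a float attribute named dropout or dropout_rate; collect_dropout_rates is a hypothetical helper, and rates stored under other names would be missed.

```python
from typing import Dict

from torch import nn


def collect_dropout_rates(model: nn.Module) -> Dict[str, float]:
    """Map a qualified name to each dropout rate found in ``model``.

    Covers nn.Dropout submodules and (an assumption) float attributes named
    ``dropout`` or ``dropout_rate``; other naming schemes are not detected.
    """
    rates: Dict[str, float] = {}
    for name, module in model.named_modules():  # the root module has name ""
        if isinstance(module, nn.Dropout):
            rates[name or "<root>"] = module.p
        for attr in ("dropout", "dropout_rate"):
            value = getattr(module, attr, None)  # nn.Module raises AttributeError when missing
            if isinstance(value, float):
                rates[f"{name}.{attr}" if name else attr] = value
    return rates
```

After building the model with a dropout rate of 0, any entry in collect_dropout_rates(model) that is not 0.0 points to a layer where the configured value was not forwarded.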

github-actions[bot] commented 1 year ago

This issue is being closed due to lack of activity. If you think it still needs to be addressed, please comment on this thread 👇