Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

Question answering example AI won't run #17450

Closed noahmartinwilliams closed 10 months ago

noahmartinwilliams commented 1 year ago

Bug description

Running the example code from here fails: training crashes during the validation sanity check with an assertion complaining about 2 predictions and 10784 features.

What version are you seeing the problem on?

2.0+

How to reproduce the bug

#! /usr/bin/python

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.question_answering import (
    QuestionAnsweringTransformer,
    SquadDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
model = QuestionAnsweringTransformer(pretrained_model_name_or_path="bert-base-uncased")

dm = SquadDataModule(batch_size=1, dataset_config_name="plain_text", max_length=384, version_2_with_negative=False, null_score_diff_threshold=0.0, doc_stride=128, n_best_size=20, max_answer_length=30, tokenizer=tokenizer)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
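The SQuAD validation split has 10570 examples (the `Map` progress bar in the log), but the assertion counts 10784 features. With `max_length=384` and `doc_stride=128`, contexts too long for one window are typically split into several overlapping features, so the feature count exceeds the example count. The following is a simplified model of that split, assuming the usual Hugging Face preprocessing in which `doc_stride` is the number of tokens shared by consecutive chunks. It ignores the question tokens and special tokens, and `count_windows` is a hypothetical helper, not part of lightning-transformers.

```python
import math


def count_windows(num_context_tokens: int, window: int, overlap: int) -> int:
    """Number of overlapping chunks needed to cover one context.

    Simplified model (an assumption): `window` is how many context tokens fit
    in one feature, and consecutive chunks share `overlap` tokens, which is
    the role doc_stride plays in typical Hugging Face QA preprocessing.
    """
    if window <= overlap:
        raise ValueError("window must be larger than overlap")
    if num_context_tokens <= window:
        return 1  # short contexts produce a single feature
    step = window - overlap  # how far each new chunk advances
    return 1 + math.ceil((num_context_tokens - window) / step)


if __name__ == "__main__":
    # Hypothetical context lengths; a 300-token window is an assumed stand-in
    # for max_length=384 minus the question and special tokens.
    for n in (250, 301, 600):
        print(n, count_windows(n, window=300, overlap=128))
    # prints "250 1", "301 2", "600 3"
```

Each example whose context overflows one window contributes more than one feature, which is consistent with 10570 examples becoming 10784 features.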

Error messages and logs

/home/noah/.local/lib/python3.10/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: '/home/noah/.local/lib/python3.10/site-packages/torchvision/image.so: undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
  warn(
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Found cached dataset squad (/home/noah/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)

100%|██████████| 2/2 [00:00<00:00, 914.69it/s]
Loading cached processed dataset at /home/noah/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-ceb593e8bfc95fd5.arrow

Map: 100%|██████████| 10570/10570 [00:03<00:00, 3465.60 examples/s]

/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/metric.py:9: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate
  self.metric = load_metric("squad")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/core/model.py:85: UserWarning: You haven't specified an optimizer or lr scheduler. Defaulting to AdamW with an lr of 1e-5 and linear warmup for 10% of steps. To change this, override ``configure_optimizers`` in  TransformerModule.
  rank_zero_warn(
Loading `train_dataloader` to estimate number of stepping batches.
/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

  | Name   | Type                     | Params
----------------------------------------------------
0 | model  | BertForQuestionAnswering | 108 M 
1 | metric | SquadMetric              | 0     
----------------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.573   Total estimated model params size (MB)
/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Traceback (most recent call last):
  File "/home/noah/src/infinitesimalPhysicist/./trainer.py", line 17, in <module>
    trainer.fit(model, dm)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 122, in run
    return self.on_run_end()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 244, in on_run_end
    self._on_evaluation_epoch_end()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 326, in _on_evaluation_epoch_end
    call._call_lightning_module_hook(trainer, hook_name)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/model.py", line 64, in on_validation_epoch_end
    metric_dict = self.metric.compute()
  File "/home/noah/.local/lib/python3.10/site-packages/torchmetrics/metric.py", line 532, in wrapped_func
    value = compute(*args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/metric.py", line 29, in compute
    predictions, references = self.postprocess_func(predictions=predictions)
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/data.py", line 46, in postprocess_func
    return post_processing_function(
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/processing.py", line 179, in post_processing_function
    predictions = postprocess_qa_predictions(
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/processing.py", line 247, in postprocess_qa_predictions
    assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
AssertionError: Got 2 predictions and 10784 features.
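One plausible reading of this assertion, offered as a hypothesis rather than a diagnosis: the traceback shows the failure happens inside `_run_sanity_check`, and Lightning's `Trainer` runs only `num_sanity_val_steps` validation batches there (2 by default). With `batch_size=1`, the SQuAD metric would have accumulated just 2 rows of logits, while the post-processing compares them against all 10784 validation features. This sketch reproduces that arithmetic; the helper names are hypothetical and only mirror the length check in `processing.py`.

```python
# Hypothetical model of the length check that fails in
# postprocess_qa_predictions (processing.py, line 247).
DEFAULT_NUM_SANITY_VAL_STEPS = 2  # Lightning Trainer default


def rows_seen_in_sanity_check(
    batch_size: int, num_sanity_val_steps: int = DEFAULT_NUM_SANITY_VAL_STEPS
) -> int:
    """Feature rows the metric would accumulate if only the sanity check ran."""
    return batch_size * num_sanity_val_steps


def check_alignment(num_predictions: int, num_features: int) -> None:
    """Raise the same message as the assertion in the traceback on a mismatch."""
    if num_predictions != num_features:
        raise AssertionError(
            f"Got {num_predictions} predictions and {num_features} features."
        )


if __name__ == "__main__":
    seen = rows_seen_in_sanity_check(batch_size=1)
    try:
        check_alignment(seen, 10784)
    except AssertionError as err:
        print(err)  # Got 2 predictions and 10784 features.
```

If this reading is right, passing `num_sanity_val_steps=0` to `pl.Trainer` would skip the partial pass, though it may not address the broader incompatibility between lightning-transformers and Lightning 2.x raised later in the thread.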

Environment


More info

It also doesn't work with local datasets.

noahmartinwilliams commented 1 year ago

I just noticed that something went wrong when pasting the environment info, so here it is again:

<details>
  <summary>Current environment</summary>

* CUDA:
    - GPU:
        - NVIDIA GeForce RTX 3050 Laptop GPU
    - available:         True
    - version:           12.1
* Lightning:
    - lightning:         2.0.1.post0
    - lightning-cloud:   0.5.33
    - lightning-transformers: 0.2.5
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.1.post0
    - torch:             2.0.0
    - torchdata:         0.6.0
    - torchmetrics:      0.11.4
    - torchtext:         0.15.1a0+c696895
    - torchvision:       0.15.1
* Packages:
    - absl-py:           1.4.0
    - aiodns:            3.0.0
    - aiohttp:           3.8.3
    - aiohttp-socks:     0.7.1
    - aiosignal:         1.3.1
    - alabaster:         0.7.13
    - anyio:             3.6.2
    - anytree:           2.8.0
    - appdirs:           1.4.4
    - arandr:            0.1.11
    - arrow:             1.2.3
    - async-timeout:     4.0.2
    - attrs:             22.2.0
    - autocommand:       2.2.2
    - babel:             2.11.0
    - beautifulsoup4:    4.11.2
    - bleach:            6.0.0
    - blessed:           1.20.0
    - btrfsutil:         6.2.2
    - build:             0.10.0
    - cachetools:        5.3.0
    - cchardet:          2.1.7
    - certifi:           2022.12.7
    - cffi:              1.15.1
    - chardet:           5.1.0
    - charset-normalizer: 3.1.0
    - click:             8.1.3
    - conda:             4.14.0
    - conda-package-handling: 1.8.1
    - contourpy:         1.0.7
    - croniter:          1.3.14
    - cryptography:      40.0.1
    - cupshelpers:       1.0
    - cycler:            0.11.0
    - cython:            0.29.34
    - datasets:          2.11.0
    - dateutils:         0.6.12
    - deepdiff:          6.3.0
    - defusedxml:        0.7.1
    - dill:              0.3.6
    - docutils:          0.19
    - elasticsearch:     7.9.0
    - exceptiongroup:    1.1.1
    - fake-useragent:    1.1.3
    - fastapi:           0.88.0
    - fastjsonschema:    2.16.3
    - ffmpeg-python:     0.2.0
    - filelock:          3.12.0
    - fonttools:         4.39.3
    - frozenlist:        1.3.3
    - fsspec:            2023.4.0
    - future:            0.18.2
    - geographiclib:     2.0
    - geopy:             2.3.0
    - gitdb:             4.0.10
    - gitpython:         3.1.30
    - glances:           3.3.1.1
    - gmpy2:             2.1.5
    - google-auth:       2.16.2
    - google-auth-oauthlib: 1.0.0
    - googletransx:      2.4.2
    - grpcio:            1.53.0
    - guake:             3.9.1.dev0
    - h11:               0.14.0
    - html2text:         2020.1.16
    - html5lib:          1.1
    - huggingface-hub:   0.12.1
    - idna:              3.4
    - imagesize:         1.4.1
    - importlib-metadata: 5.0.0
    - inflect:           6.0.4
    - iniconfig:         2.0.0
    - inquirer:          3.1.3
    - installer:         0.7.0
    - iotop:             0.6
    - itsdangerous:      2.1.2
    - jaraco.context:    4.3.0
    - jaraco.functools:  3.6.0
    - jaraco.text:       3.11.1
    - jinja2:            3.1.2
    - joblib:            1.2.0
    - kiwisolver:        1.4.4
    - lensfun:           0.3.3
    - libarchive-c:      4.0
    - libfdt:            1.6.1
    - lightning:         2.0.1.post0
    - lightning-cloud:   0.5.33
    - lightning-transformers: 0.2.5
    - lightning-utilities: 0.8.0
    - lit:               15.0.7.dev0
    - lxml:              4.9.2
    - mako:              1.2.4
    - mallard-ducktype:  1.0.2
    - markdown:          3.4.3
    - markdown-it-py:    2.2.0
    - markupsafe:        2.1.2
    - matplotlib:        3.7.1
    - mdurl:             0.1.2
    - meson:             1.0.1
    - mock:              3.0.5
    - more-itertools:    9.1.0
    - mpmath:            1.3.0
    - multidict:         6.0.4
    - multiprocess:      0.70.14
    - netsnmp-python:    1.0a1
    - networkx:          3.1
    - nltk:              3.8.1
    - notmuch:           0.37
    - notmuch2:          0.37
    - nspektr:           0.4.0
    - numpy:             1.24.2
    - oauthlib:          3.2.2
    - openai-whisper:    20230124
    - openshot-qt:       3.0.0
    - ordered-set:       4.1.0
    - packaging:         23.0
    - pandas:            1.5.3
    - pbr:               5.11.1
    - pillow:            9.4.0
    - pip:               23.0.1
    - pivy:              0.6.8
    - platformdirs:      3.2.0
    - pluggy:            1.0.0
    - ply:               3.11
    - pooch:             1.7.0
    - portalocker:       2.7.0
    - protobuf:          4.21.12
    - psutil:            5.9.4
    - pyarrow:           11.0.0
    - pyasn1:            0.4.8
    - pyasn1-modules:    0.2.8
    - pyaudio:           0.2.13
    - pybind11:          2.10.4
    - pycairo:           1.23.0
    - pycares:           4.3.0
    - pycosat:           0.6.4
    - pycparser:         2.21
    - pycups:            2.0.1
    - pycurl:            7.45.2
    - pydantic:          1.10.7
    - pygments:          2.14.0
    - pygobject:         3.44.1
    - pyjwt:             2.6.0
    - pyparsing:         3.0.9
    - pyproject-hooks:   1.0.0
    - pyqt5:             5.15.9
    - pyqt5-sip:         12.12.1
    - pyqtwebengine:     5.15.6
    - pyside2:           5.15.9
    - pysocks:           1.7.1
    - pytest:            7.3.1
    - pytest-mock:       3.10.0
    - pytest-runner:     6.0.0
    - python-dateutil:   2.8.2
    - python-editor:     1.0.4
    - python-multipart:  0.0.6
    - python-socks:      2.0.3
    - pytorch-lightning: 2.0.1.post0
    - pytz:              2022.7.1
    - pyxdg:             0.28
    - pyyaml:            6.0
    - pyzmq:             25.0.2
    - qreator:           20.2.1
    - qrencode:          1.2
    - readchar:          4.0.5
    - regex:             2023.3.23
    - requests:          2.28.2
    - requests-oauthlib: 1.3.1
    - responses:         0.18.0
    - rich:              13.3.4
    - rsa:               4.9
    - ruamel.yaml:       0.17.21
    - ruamel.yaml.clib:  0.2.7
    - schedule:          1.1.0
    - scipy:             1.10.1
    - scons:             4.4.0
    - screenkey:         1.5
    - semantic-version:  2.10.0
    - sentencepiece:     0.1.98
    - setuptools:        67.6.1
    - setuptools-rust:   1.5.2
    - shiboken2:         5.15.9
    - six:               1.16.0
    - smmap:             5.0.0
    - sniffio:           1.3.0
    - snowballstemmer:   2.2.0
    - soupsieve:         2.4
    - speedtest-cli:     2.1.3
    - sphinx:            6.1.3
    - sphinxcontrib-applehelp: 1.0.4
    - sphinxcontrib-devhelp: 1.0.2
    - sphinxcontrib-htmlhelp: 2.0.1
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.3
    - sphinxcontrib-serializinghtml: 1.1.5
    - starlette:         0.22.0
    - starsessions:      1.3.0
    - sympy:             1.11.1
    - tbb:               0.2
    - tensorboard:       2.12.1
    - tensorboard-data-server: 0.8.0a0
    - tensorboard-plugin-wit: 1.8.1
    - tokenizers:        0.13.2
    - tomli:             2.0.1
    - toot:              0.28.0
    - torch:             2.0.0
    - torchdata:         0.6.0
    - torchmetrics:      0.11.4
    - torchtext:         0.15.1a0+c696895
    - torchvision:       0.15.1
    - tqdm:              4.65.0
    - traitlets:         5.9.0
    - transformers:      4.26.1
    - trove-classifiers: 2023.4.20
    - tweepy:            4.12.1
    - typing-extensions: 4.5.0
    - ujson:             5.7.0
    - urllib3:           1.26.13
    - urwid:             2.1.2
    - uvicorn:           0.21.1
    - validate-pyproject: 0.12.2.post1.dev0+g2940279.d20230328
    - vobject:           0.9.6.1
    - wcwidth:           0.2.5
    - webencodings:      0.5.1
    - websocket-client:  1.5.1
    - websockets:        11.0.1
    - werkzeug:          2.2.3
    - wheel:             0.40.0
    - wxpython:          4.2.0
    - xxhash:            3.2.0
    - yarl:              1.8.2
    - youtube-dl:        2021.12.17
    - yt-dlp:            2023.3.4
    - zim:               0.75.1
    - zipp:              3.15.0
* System:
    - OS:                Linux
    - architecture:
        - 64bit
        - ELF
    - processor:         
    - python:            3.10.10
    - version:           #1 SMP PREEMPT_DYNAMIC Thu, 13 Apr 2023 16:59:24 +0000

</details>
awaelchli commented 1 year ago

@noahmartinwilliams Is there any indication that this is a problem with Lightning?

noahmartinwilliams commented 1 year ago

It's copied and pasted from the example code.

carmocca commented 1 year ago

Can you try using Lightning 1.7? Lightning Transformers is archived and has not been updated to run with the latest Lightning.
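Following this suggestion, a small guard can fail fast before importing `lightning_transformers` on a release it was never updated for. This is a sketch under assumptions: it treats any 1.x `pytorch-lightning` release as acceptable (the thread only names 1.7, and the exact range supported by lightning-transformers 0.2.5 is not stated here), and the helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version
from typing import Optional, Tuple


def parse_major_minor(ver: str) -> Tuple[int, int]:
    """Extract (major, minor) from strings like '2.0.1.post0' or '1.7.7'."""
    match = re.match(r"(\d+)\.(\d+)", ver)
    if match is None:
        raise ValueError(f"Unrecognized version string: {ver!r}")
    return int(match.group(1)), int(match.group(2))


def is_pre_2_0(ver: str) -> bool:
    """Assumption: lightning-transformers only works with 1.x Lightning."""
    major, _ = parse_major_minor(ver)
    return major < 2


def installed_lightning_version() -> Optional[str]:
    """Return the installed pytorch-lightning version, or None if it is absent."""
    try:
        return version("pytorch-lightning")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    installed = installed_lightning_version()
    if installed is None:
        print("pytorch-lightning is not installed")
    elif not is_pre_2_0(installed):
        print(f"pytorch-lightning {installed} is newer than the suggested 1.x line")
    else:
        print(f"pytorch-lightning {installed} is a 1.x release")
```

With the environment reported above (`pytorch-lightning: 2.0.1.post0`), `is_pre_2_0` returns `False`, matching the mismatch described in this comment.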

noahmartinwilliams commented 1 year ago

Oh, I didn't know that. Does the latest version of Lightning have transformers for question answering?

awaelchli commented 10 months ago

Closing now, since lightning_transformers is archived. We don't have a QA example like the one you showed. If there is a desire to have one in our examples, we need to build one.