Lightning-AI / pytorch-lightning

Pretrain, finetune ANY AI model of ANY size on multiple GPUs, TPUs with zero code changes.
Apache License 2.0

Question answering example AI won't run #17450

Closed noahmartinwilliams closed 10 months ago

noahmartinwilliams commented 1 year ago

Bug description

The example code from here doesn't run: it fails with an assertion error reporting 2 predictions and 10784 features.

What version are you seeing the problem on?


How to reproduce the bug

```python
#! /usr/bin/python

import pytorch_lightning as pl
from transformers import AutoTokenizer

from lightning_transformers.task.nlp.question_answering import (
    QuestionAnsweringTransformer,
    SquadDataModule,
)

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")
model = QuestionAnsweringTransformer(pretrained_model_name_or_path="bert-base-uncased")

dm = SquadDataModule(batch_size=1, dataset_config_name="plain_text", max_length=384, version_2_with_negative=False, null_score_diff_threshold=0.0, doc_stride=128, n_best_size=20, max_answer_length=30, tokenizer=tokenizer)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=1)
trainer.fit(model, dm)
```
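With `max_length=384` and `doc_stride=128`, a context too long for one input is split into several overlapping features, which is why the log below reports 10784 features for the 10570 SQuAD validation examples. The sketch below counts those windows for a single context under simplified assumptions: `available` is the number of context tokens that fit beside the question (roughly `max_length` minus the question length and BERT's 3 special tokens), and `doc_stride` is treated as the overlap between consecutive windows, the way the Hugging Face `run_qa.py` example passes it to the tokenizer's `stride` argument. `count_features` is a hypothetical helper, not part of lightning-transformers.

```python
import math


def count_features(context_tokens: int, available: int, doc_stride: int) -> int:
    """Count the overlapping windows needed to cover one context (simplified model).

    Assumptions: `available` context tokens fit in one feature, and each new
    window overlaps the previous one by `doc_stride` tokens.
    """
    if available <= doc_stride:
        raise ValueError("available must be larger than doc_stride")
    if context_tokens <= available:
        return 1  # the whole context fits in a single feature
    step = available - doc_stride  # how far each extra window advances
    # First window, plus enough extra windows to cover the remaining tokens.
    return 1 + math.ceil((context_tokens - available) / step)


if __name__ == "__main__":
    max_length = 384
    question_tokens = 12  # hypothetical question length
    special_tokens = 3  # [CLS] question [SEP] context [SEP]
    available = max_length - question_tokens - special_tokens  # 369
    for ctx in (150, 369, 370, 700):
        print(ctx, count_features(ctx, available, doc_stride=128))
```

Because every feature gets its own start/end logits, the post-processing step has to see one prediction row per feature, not per example.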

Error messages and logs

/home/noah/.local/lib/python3.10/site-packages/torchvision/io/ UserWarning: Failed to load image Python extension: '/home/noah/.local/lib/python3.10/site-packages/torchvision/ undefined symbol: _ZN3c106detail23torchInternalAssertFailEPKcS2_jS2_RKSs'If you don't plan on using image functionality from ``, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForQuestionAnswering: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read
Found cached dataset squad (/home/noah/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)

  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 2/2 [00:00<00:00, 914.69it/s]
Loading cached processed dataset at /home/noah/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-ceb593e8bfc95fd5.arrow

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]
Map:   9%|▉         | 1000/10570 [00:00<00:03, 2838.61 examples/s]
Map:  19%|█▉        | 2000/10570 [00:00<00:02, 3423.89 examples/s]
Map:  28%|██▊       | 3000/10570 [00:00<00:02, 3622.99 examples/s]
Map:  38%|███▊      | 4000/10570 [00:01<00:01, 3716.14 examples/s]
Map:  47%|████▋     | 5000/10570 [00:01<00:01, 3516.12 examples/s]
Map:  57%|█████▋    | 6000/10570 [00:01<00:01, 3547.59 examples/s]
Map:  66%|██████▌   | 7000/10570 [00:01<00:00, 3582.17 examples/s]
Map:  76%|███████▌  | 8000/10570 [00:02<00:00, 3258.43 examples/s]
Map:  85%|████████▌ | 9000/10570 [00:02<00:00, 3366.40 examples/s]
Map:  95%|█████████▍| 10000/10570 [00:02<00:00, 3460.12 examples/s]
Map: 100%|██████████| 10570/10570 [00:03<00:00, 3465.60 examples/s]

/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/ FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate:
  self.metric = load_metric("squad")
/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/core/ UserWarning: You haven't specified an optimizer or lr scheduler. Defaulting to AdamW with an lr of 1e-5 and linear warmup for 10% of steps. To change this, override ``configure_optimizers`` in  TransformerModule.
Loading `train_dataloader` to estimate number of stepping batches.
/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/ PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.

  | Name   | Type                     | Params
0 | model  | BertForQuestionAnswering | 108 M 
1 | metric | SquadMetric              | 0     
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.573   Total estimated model params size (MB)
/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/ PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 16 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
Traceback (most recent call last):
  File "/home/noah/src/infinitesimalPhysicist/./", line 17, in <module>
    trainer.fit(model, dm)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 520, in fit
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 935, in _run
    results = self._run_stage()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 976, in _run_stage
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 1005, in _run_sanity_check
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/", line 177, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/", line 122, in run
    return self.on_run_end()
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/", line 244, in on_run_end
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/loops/", line 326, in _on_evaluation_epoch_end
    call._call_lightning_module_hook(trainer, hook_name)
  File "/home/noah/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/", line 142, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/", line 64, in on_validation_epoch_end
    metric_dict = self.metric.compute()
  File "/home/noah/.local/lib/python3.10/site-packages/torchmetrics/", line 532, in wrapped_func
    value = compute(*args, **kwargs)
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/", line 29, in compute
    predictions, references = self.postprocess_func(predictions=predictions)
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/", line 46, in postprocess_func
    return post_processing_function(
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/", line 179, in post_processing_function
    predictions = postprocess_qa_predictions(
  File "/home/noah/.local/lib/python3.10/site-packages/lightning_transformers/task/nlp/question_answering/datasets/squad/", line 247, in postprocess_qa_predictions
    assert len(predictions[0]) == len(features), f"Got {len(predictions[0])} predictions and {len(features)} features."
AssertionError: Got 2 predictions and 10784 features.
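One plausible reading of the numbers (an inference, not something confirmed in this thread): the traceback passes through `_run_sanity_check`, and the `Trainer` runs `num_sanity_val_steps` validation batches (default 2) before training. With `batch_size=1` that collects exactly 2 rows of logits, while post-processing compares them against every feature in the validation set. The pure-Python sketch below mimics that check; `check_predictions` and `collected_rows` are hypothetical stand-ins, and the logits are dummy values.

```python
def check_predictions(start_logits: list, num_features: int) -> None:
    # Hypothetical stand-in for the check in postprocess_qa_predictions:
    # there must be exactly one row of start logits per tokenized feature.
    assert len(start_logits) == num_features, (
        f"Got {len(start_logits)} predictions and {num_features} features."
    )


def collected_rows(num_batches: int, batch_size: int) -> int:
    # Each validation batch contributes one logits row per feature in it.
    return num_batches * batch_size


if __name__ == "__main__":
    num_features = 10784  # validation features reported in the log
    sanity_steps, batch_size = 2, 1  # default sanity steps, batch_size from the repro
    rows = collected_rows(sanity_steps, batch_size)
    dummy_logits = [[0.0] * 384 for _ in range(rows)]  # placeholder logits
    try:
        check_predictions(dummy_logits, num_features)
    except AssertionError as err:
        print(err)  # Got 2 predictions and 10784 features.
```

If this reading is right, a full validation pass would collect one row per feature, and `pl.Trainer(num_sanity_val_steps=0)` would skip the partial sanity-check evaluation. Whether lightning-transformers then works under Lightning 2.x is not established here.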


Current environment
```
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0):
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
```

More info

It also doesn't work with local datasets.

noahmartinwilliams commented 1 year ago

Just noticed something went wrong with pasting the environment info. So here it is again:

  Current environment:

* CUDA:
    - GPU:
        - NVIDIA GeForce RTX 3050 Laptop GPU
    - available:         True
    - version:           12.1
* Lightning:
    - lightning:         2.0.1.post0
    - lightning-cloud:   0.5.33
    - lightning-transformers: 0.2.5
    - lightning-utilities: 0.8.0
    - pytorch-lightning: 2.0.1.post0
    - torch:             2.0.0
    - torchdata:         0.6.0
    - torchmetrics:      0.11.4
    - torchtext:         0.15.1a0+c696895
    - torchvision:       0.15.1
* Packages:
    - absl-py:           1.4.0
    - aiodns:            3.0.0
    - aiohttp:           3.8.3
    - aiohttp-socks:     0.7.1
    - aiosignal:         1.3.1
    - alabaster:         0.7.13
    - anyio:             3.6.2
    - anytree:           2.8.0
    - appdirs:           1.4.4
    - arandr:            0.1.11
    - arrow:             1.2.3
    - async-timeout:     4.0.2
    - attrs:             22.2.0
    - autocommand:       2.2.2
    - babel:             2.11.0
    - beautifulsoup4:    4.11.2
    - bleach:            6.0.0
    - blessed:           1.20.0
    - btrfsutil:         6.2.2
    - build:             0.10.0
    - cachetools:        5.3.0
    - cchardet:          2.1.7
    - certifi:           2022.12.7
    - cffi:              1.15.1
    - chardet:           5.1.0
    - charset-normalizer: 3.1.0
    - click:             8.1.3
    - conda:             4.14.0
    - conda-package-handling: 1.8.1
    - contourpy:         1.0.7
    - croniter:          1.3.14
    - cryptography:      40.0.1
    - cupshelpers:       1.0
    - cycler:            0.11.0
    - cython:            0.29.34
    - datasets:          2.11.0
    - dateutils:         0.6.12
    - deepdiff:          6.3.0
    - defusedxml:        0.7.1
    - dill:              0.3.6
    - docutils:          0.19
    - elasticsearch:     7.9.0
    - exceptiongroup:    1.1.1
    - fake-useragent:    1.1.3
    - fastapi:           0.88.0
    - fastjsonschema:    2.16.3
    - ffmpeg-python:     0.2.0
    - filelock:          3.12.0
    - fonttools:         4.39.3
    - frozenlist:        1.3.3
    - fsspec:            2023.4.0
    - future:            0.18.2
    - geographiclib:     2.0
    - geopy:             2.3.0
    - gitdb:             4.0.10
    - gitpython:         3.1.30
    - glances: 
    - gmpy2:             2.1.5
    - google-auth:       2.16.2
    - google-auth-oauthlib: 1.0.0
    - googletransx:      2.4.2
    - grpcio:            1.53.0
    - guake:             3.9.1.dev0
    - h11:               0.14.0
    - html2text:         2020.1.16
    - html5lib:          1.1
    - huggingface-hub:   0.12.1
    - idna:              3.4
    - imagesize:         1.4.1
    - importlib-metadata: 5.0.0
    - inflect:           6.0.4
    - iniconfig:         2.0.0
    - inquirer:          3.1.3
    - installer:         0.7.0
    - iotop:             0.6
    - itsdangerous:      2.1.2
    - jaraco.context:    4.3.0
    - jaraco.functools:  3.6.0
    - jaraco.text:       3.11.1
    - jinja2:            3.1.2
    - joblib:            1.2.0
    - kiwisolver:        1.4.4
    - lensfun:           0.3.3
    - libarchive-c:      4.0
    - libfdt:            1.6.1
    - lightning:         2.0.1.post0
    - lightning-cloud:   0.5.33
    - lightning-transformers: 0.2.5
    - lightning-utilities: 0.8.0
    - lit:               15.0.7.dev0
    - lxml:              4.9.2
    - mako:              1.2.4
    - mallard-ducktype:  1.0.2
    - markdown:          3.4.3
    - markdown-it-py:    2.2.0
    - markupsafe:        2.1.2
    - matplotlib:        3.7.1
    - mdurl:             0.1.2
    - meson:             1.0.1
    - mock:              3.0.5
    - more-itertools:    9.1.0
    - mpmath:            1.3.0
    - multidict:         6.0.4
    - multiprocess:      0.70.14
    - netsnmp-python:    1.0a1
    - networkx:          3.1
    - nltk:              3.8.1
    - notmuch:           0.37
    - notmuch2:          0.37
    - nspektr:           0.4.0
    - numpy:             1.24.2
    - oauthlib:          3.2.2
    - openai-whisper:    20230124
    - openshot-qt:       3.0.0
    - ordered-set:       4.1.0
    - packaging:         23.0
    - pandas:            1.5.3
    - pbr:               5.11.1
    - pillow:            9.4.0
    - pip:               23.0.1
    - pivy:              0.6.8
    - platformdirs:      3.2.0
    - pluggy:            1.0.0
    - ply:               3.11
    - pooch:             1.7.0
    - portalocker:       2.7.0
    - protobuf:          4.21.12
    - psutil:            5.9.4
    - pyarrow:           11.0.0
    - pyasn1:            0.4.8
    - pyasn1-modules:    0.2.8
    - pyaudio:           0.2.13
    - pybind11:          2.10.4
    - pycairo:           1.23.0
    - pycares:           4.3.0
    - pycosat:           0.6.4
    - pycparser:         2.21
    - pycups:            2.0.1
    - pycurl:            7.45.2
    - pydantic:          1.10.7
    - pygments:          2.14.0
    - pygobject:         3.44.1
    - pyjwt:             2.6.0
    - pyparsing:         3.0.9
    - pyproject-hooks:   1.0.0
    - pyqt5:             5.15.9
    - pyqt5-sip:         12.12.1
    - pyqtwebengine:     5.15.6
    - pyside2:           5.15.9
    - pysocks:           1.7.1
    - pytest:            7.3.1
    - pytest-mock:       3.10.0
    - pytest-runner:     6.0.0
    - python-dateutil:   2.8.2
    - python-editor:     1.0.4
    - python-multipart:  0.0.6
    - python-socks:      2.0.3
    - pytorch-lightning: 2.0.1.post0
    - pytz:              2022.7.1
    - pyxdg:             0.28
    - pyyaml:            6.0
    - pyzmq:             25.0.2
    - qreator:           20.2.1
    - qrencode:          1.2
    - readchar:          4.0.5
    - regex:             2023.3.23
    - requests:          2.28.2
    - requests-oauthlib: 1.3.1
    - responses:         0.18.0
    - rich:              13.3.4
    - rsa:               4.9
    - ruamel.yaml:       0.17.21
    - ruamel.yaml.clib:  0.2.7
    - schedule:          1.1.0
    - scipy:             1.10.1
    - scons:             4.4.0
    - screenkey:         1.5
    - semantic-version:  2.10.0
    - sentencepiece:     0.1.98
    - setuptools:        67.6.1
    - setuptools-rust:   1.5.2
    - shiboken2:         5.15.9
    - six:               1.16.0
    - smmap:             5.0.0
    - sniffio:           1.3.0
    - snowballstemmer:   2.2.0
    - soupsieve:         2.4
    - speedtest-cli:     2.1.3
    - sphinx:            6.1.3
    - sphinxcontrib-applehelp: 1.0.4
    - sphinxcontrib-devhelp: 1.0.2
    - sphinxcontrib-htmlhelp: 2.0.1
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 1.0.3
    - sphinxcontrib-serializinghtml: 1.1.5
    - starlette:         0.22.0
    - starsessions:      1.3.0
    - sympy:             1.11.1
    - tbb:               0.2
    - tensorboard:       2.12.1
    - tensorboard-data-server: 0.8.0a0
    - tensorboard-plugin-wit: 1.8.1
    - tokenizers:        0.13.2
    - tomli:             2.0.1
    - toot:              0.28.0
    - torch:             2.0.0
    - torchdata:         0.6.0
    - torchmetrics:      0.11.4
    - torchtext:         0.15.1a0+c696895
    - torchvision:       0.15.1
    - tqdm:              4.65.0
    - traitlets:         5.9.0
    - transformers:      4.26.1
    - trove-classifiers: 2023.4.20
    - tweepy:            4.12.1
    - typing-extensions: 4.5.0
    - ujson:             5.7.0
    - urllib3:           1.26.13
    - urwid:             2.1.2
    - uvicorn:           0.21.1
    - validate-pyproject: 0.12.2.post1.dev0+g2940279.d20230328
    - vobject: 
    - wcwidth:           0.2.5
    - webencodings:      0.5.1
    - websocket-client:  1.5.1
    - websockets:        11.0.1
    - werkzeug:          2.2.3
    - wheel:             0.40.0
    - wxpython:          4.2.0
    - xxhash:            3.2.0
    - yarl:              1.8.2
    - youtube-dl:        2021.12.17
    - yt-dlp:            2023.3.4
    - zim:               0.75.1
    - zipp:              3.15.0
* System:
    - OS:                Linux
    - architecture:
        - 64bit
        - ELF
    - processor:         
    - python:            3.10.10
    - version:           #1 SMP PREEMPT_DYNAMIC Thu, 13 Apr 2023 16:59:24 +0000

awaelchli commented 1 year ago

@noahmartinwilliams Is there any indication that this is a problem with Lightning?

noahmartinwilliams commented 1 year ago

It's copied and pasted from the example code.

carmocca commented 1 year ago

Can you try using Lightning 1.7? Lightning Transformers is archived and has not been updated to run with the latest Lightning.
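A script built on lightning-transformers could check which Lightning it is running under before training. The sketch below reads the installed `pytorch-lightning` version with the standard-library `importlib.metadata.version` and compares only the major and minor numbers. The helper names `major_minor` and `is_pre_2_0`, and the assumption that anything below 2.0 uses the API lightning-transformers expects, are illustrative, not a documented compatibility range.

```python
from importlib.metadata import PackageNotFoundError, version


def major_minor(version_str: str) -> tuple:
    # "2.0.1.post0" -> (2, 0); "1.7.7" -> (1, 7).
    # Assumes the first two dot-separated fields are plain integers.
    major, minor = version_str.split(".")[:2]
    return int(major), int(minor)


def is_pre_2_0(version_str: str) -> bool:
    # Assumption: lightning-transformers targets the Lightning 1.x API.
    return major_minor(version_str) < (2, 0)


if __name__ == "__main__":
    try:
        installed = version("pytorch-lightning")
    except PackageNotFoundError:
        installed = None
    if installed is None:
        print("pytorch-lightning is not installed")
    elif is_pre_2_0(installed):
        print(f"pytorch-lightning {installed}: 1.x API")
    else:
        print(f"pytorch-lightning {installed}: 2.x; e.g. pip install 'pytorch-lightning==1.7.*'")
```

With the environment reported above (`pytorch-lightning: 2.0.1.post0`), the check would take the 2.x branch.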

noahmartinwilliams commented 1 year ago

Oh, I didn't know that. Does the latest version of lightning have transformers for question answering?

awaelchli commented 10 months ago

Closing now, since lightning_transformers is archived. We don't have a QA example like the one you showed. If there is a desire to have one in our examples, we would need to build one.