mosaicml / llm-foundry

LLM training code for Databricks foundation models
https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm
Apache License 2.0

Nicer error message for undefined symbol #1339

Closed · dakinggg closed 3 months ago

dakinggg commented 3 months ago

Adds a nicer error message for the most common case of a broken flash-attn install.
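
For context, the guard this PR adds in llmfoundry/__init__.py (visible in the traceback below) follows roughly this pattern. This is a minimal sketch; the handling of unrelated ImportErrors is an assumption, not copied from the PR:

try:
    # A flash-attn wheel built against a different PyTorch ABI fails at
    # import time with an "undefined symbol" error from the compiled
    # extension, so importing any symbol is enough to probe the install.
    from flash_attn import flash_attn_func
    del flash_attn_func  # only probing the import; the symbol isn't used here
except ImportError as e:
    if 'undefined symbol' in str(e):
        raise ImportError(
            'The flash_attn package is not installed correctly. ...'
        ) from e
    # Sketch assumption: any other ImportError (e.g. flash-attn simply not
    # installed) falls through and flash attention is treated as unavailable.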

Before:

ImportError: /usr/lib/python3/dist-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

After:

ImportError: /usr/lib/python3/dist-packages/flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEi

The above exception was the direct cause of the following exception:

╭───────────────────── Traceback (most recent call last) ──────────────────────╮
│ /workspace/llm-foundry/scripts/train/train.py:25 in <module>                 │
│                                                                              │
│    22 from omegaconf import DictConfig                                       │
│    23 from omegaconf import OmegaConf as om                                  │
│    24                                                                        │
│ ❱  25 from llmfoundry.callbacks import AsyncEval, HuggingFaceCheckpointer    │
│    26 from llmfoundry.data.dataloader import build_dataloader                │
│    27 from llmfoundry.eval.metrics.nlp import InContextLearningMetric        │
│    28 from llmfoundry.layers_registry import ffns_with_megablocks            │
│                                                                              │
│ /workspace/llm-foundry/llmfoundry/__init__.py:17 in <module>                 │
│                                                                              │
│   14 │   del flash_attn_func                                                 │
│   15 except ImportError as e:                                                │
│   16 │   if "undefined symbol" in str(e):                                    │
│ ❱ 17 │   │   raise ImportError(                                              │
│   18 │   │   │   "The flash_attn package is not installed correctly. Usually │
│   19 │   │   │   " of PyTorch is different from the version that flash_attn  │
│   20 │   │   │   " workflow has resulted in PyTorch being reinstalled. This  │
╰──────────────────────────────────────────────────────────────────────────────╯
ImportError: The flash_attn package is not installed correctly. Usually this
means that your runtime version of PyTorch is different from the version that
flash_attn was installed with, which can occur when your workflow has resulted
in PyTorch being reinstalled. This probably happened because you are using an
old docker image with the latest version of LLM Foundry. Check that the PyTorch
version in your Docker image matches the PyTorch version in LLM Foundry setup.py
and update accordingly. The latest Docker image can be found in the README.
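
As a quick way to act on this message, the runtime PyTorch and the installed flash-attn build can be printed without re-triggering the broken import (a hedged example; compare the output against the versions pinned in LLM Foundry's setup.py and the Docker image):

import torch
from importlib.metadata import version

# The runtime PyTorch that flash-attn's compiled extension must match.
print('torch:', torch.__version__)

# Read flash-attn's installed version from package metadata; importing
# flash_attn itself would hit the same undefined-symbol error.
print('flash-attn:', version('flash-attn'))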