Dao-AILab / flash-attention

Fast and memory-efficient exact attention

[training] `python run.py` raises `ImportError: cannot import name 'GPTBigCodeConfig' from 'transformers'` #996

Open yumemio opened 4 months ago

yumemio commented 4 months ago

How to reproduce the error

  1. Launch the Docker training environment
    $ cd flash-attention/training
    $ docker build --tag 'flash-attention' .
    $ docker run --interactive --tty --volume "$(realpath ..)":/work flash-attention bash
  2. Run the training script from the mounted checkout
    # cd /work/training
    # HYDRA_FULL_ERROR=1 python run.py \
      experiment=owt/gpt2s-flash

What I expected to happen

The training starts without raising an error.

What happened instead

I got the following ImportError:

[2024-06-21 09:12:07,875][src.tasks.seq][INFO] - Instantiating model <flash_attn.models.gpt.GPTLMHeadModel>
Error executing job with overrides: ['experiment=owt/gpt2s-flash', 'trainer.devices=1']
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/hydra/_internal/utils.py", line 644, in _locate
    obj = getattr(obj, part)
AttributeError: module 'flash_attn.models' has no attribute 'gpt'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/hydra/_internal/utils.py", line 650, in _locate
    obj = import_module(mod)
  File "/usr/lib/python3.8/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1014, in _gcd_import
  File "<frozen importlib._bootstrap>", line 991, in _find_and_load
  File "<frozen importlib._bootstrap>", line 975, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 671, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 848, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/usr/local/lib/python3.8/dist-packages/flash_attn/models/gpt.py", line 17, in <module>
    from flash_attn.models.bigcode import remap_state_dict_hf_bigcode
  File "/usr/local/lib/python3.8/dist-packages/flash_attn/models/bigcode.py", line 7, in <module>
    from transformers import GPT2Config, GPTBigCodeConfig, PretrainedConfig
ImportError: cannot import name 'GPTBigCodeConfig' from 'transformers' (/usr/local/lib/python3.8/dist-packages/transformers/__init__.py)

Why

There are two dependency issues with the current Dockerfile:

  1. The transformers release installed in the image predates the one that added GPTBigCodeConfig, so flash_attn.models.bigcode (and therefore flash_attn.models.gpt, which Hydra tries to instantiate) fails at import time.
  2. sentencepiece is not installed in the image, so it also has to be added (see the workaround below).
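
A quick way to confirm the first point inside the container is to attempt the import directly. This is just a sanity-check one-liner; it should fail on the transformers version pinned in the image and succeed after the upgrade in the workaround below:

    # python -c "from transformers import GPTBigCodeConfig; print('GPTBigCodeConfig OK')"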

Workaround

Install transformers==4.33.1 and sentencepiece inside the container.

# pip install -U transformers==4.33.1 sentencepiece==0.1.99

# HYDRA_FULL_ERROR=1 python run.py \
  experiment=owt/gpt2s-flash
...
┏━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓
┃   ┃ Name          ┃ Type             ┃ Params ┃
┡━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩
│ 0 │ model         │ GPTLMHeadModel   │  125 M │
│ 1 │ loss_fn       │ CrossEntropyLoss │      0 │
│ 2 │ loss_fn_val   │ CrossEntropyLoss │      0 │
│ 3 │ train_metrics │ MetricCollection │      0 │
│ 4 │ val_metrics   │ MetricCollection │      0 │
│ 5 │ test_metrics  │ MetricCollection │      0 │
└───┴───────────────┴──────────────────┴────────┘
Trainable params: 125 M                                                         
Non-trainable params: 0                                                         
Total params: 125 M                                                             
Total estimated model params size (MB): 250                                     
...
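
If the upgraded versions keep working, a more permanent fix (an untested sketch on my end) would be to pin them in training/Dockerfile instead of patching the running container, e.g. alongside whatever pip install steps the Dockerfile already has:

    # Hypothetical Dockerfile addition; exact placement depends on the existing
    # pip install layers in training/Dockerfile.
    RUN pip install 'transformers==4.33.1' 'sentencepiece==0.1.99'
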
tridao commented 4 months ago

Yeah, I think the transformers version is the issue.