TinyLLaVA / TinyLLaVA_Factory

A Framework of Small-scale Large Multimodal Models
https://arxiv.org/abs/2402.14289
Apache License 2.0

flash_attn_2 undefined symbol #141

Open pspdada opened 2 days ago

pspdada commented 2 days ago

I followed the instructions to set up the environment:

git clone https://github.com/TinyLLaVA/TinyLLaVA_Factory.git
cd TinyLLaVA_Factory
conda create -n tinyllava_factory python=3.10 -y
conda activate tinyllava_factory
pip install --upgrade pip  # enable PEP 660 support
pip install -e .
pip install flash-attn --no-build-isolation

but encountered an error when trying to load the model:

Traceback (most recent call last):
  File "/root/llm-project/TinyLLaVA_Factory/mycode/inference.py", line 10, in <module>
    model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 550, in from_pretrained
    model_class = get_class_from_dynamic_module(
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/transformers/dynamic_module_utils.py", line 501, in get_class_from_dynamic_module
    return get_class_in_module(class_name, final_module)
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/transformers/dynamic_module_utils.py", line 201, in get_class_in_module
    module = importlib.machinery.SourceFileLoader(name, module_path).load_module()
  File "<frozen importlib._bootstrap_external>", line 548, in _check_name_wrapper
  File "<frozen importlib._bootstrap_external>", line 1063, in load_module
  File "<frozen importlib._bootstrap_external>", line 888, in load_module
  File "<frozen importlib._bootstrap>", line 290, in _load_module_shim
  File "<frozen importlib._bootstrap>", line 719, in _load
  File "<frozen importlib._bootstrap>", line 688, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 883, in exec_module
  File "<frozen importlib._bootstrap>", line 241, in _call_with_frames_removed
  File "/root/llm-project/utils/models/modules/transformers_modules/tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B/a98601f69e72442f71721aefcfbcdce26db8982a/modeling_tinyllava_phi.py", line 27, in <module>
    from transformers import AutoConfig, AutoModelForCausalLM, PhiForCausalLM
  File "<frozen importlib._bootstrap>", line 1075, in _handle_fromlist
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1501, in __getattr__
    value = getattr(module, name)
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1500, in __getattr__
    module = self._get_module(self._class_to_module[name])
  File "/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/transformers/utils/import_utils.py", line 1512, in _get_module
    raise RuntimeError(
RuntimeError: Failed to import transformers.models.phi.modeling_phi because of the following error (look up to see its traceback):
/root/anaconda3/envs/tinyllava_factory/lib/python3.10/site-packages/flash_attn_2_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c105Error4whatEv
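
For reference, the undefined symbol _ZNK3c105Error4whatEv demangles to c10::Error::what() const, which usually means the prebuilt flash-attn wheel was compiled against a different PyTorch C++ ABI than the torch installed in this environment. A minimal check that should reproduce the failure without loading the model (just a sketch, assuming torch and flash-attn are installed as above):

import torch
print("torch:", torch.__version__, "cuda:", torch.version.cuda)

# Importing flash_attn loads flash_attn_2_cuda; if the wheel does not match
# the installed torch, this raises the same "undefined symbol" ImportError.
import flash_attn
print("flash-attn:", flash_attn.__version__)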

The code I use:

from transformers import AutoModelForCausalLM, AutoTokenizer

hf_path = "tinyllava/TinyLLaVA-Phi-2-SigLIP-3.1B"
# trust_remote_code is required because the model class lives in the repo's
# modeling_tinyllava_phi.py rather than in transformers itself.
model = AutoModelForCausalLM.from_pretrained(hf_path, trust_remote_code=True)
model.cuda()
config = model.config
tokenizer = AutoTokenizer.from_pretrained(
    hf_path,
    use_fast=False,
    model_max_length=config.tokenizer_model_max_length,
    padding_side=config.tokenizer_padding_side,
)
prompt = "What are these?"
image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
output_text, generation_time = model.chat(prompt=prompt, image=image_url, tokenizer=tokenizer)

print("model output:", output_text)
print("running time:", generation_time)

How can I solve this?

pspdada commented 2 days ago

Pinning the version of flash-attn resolves this issue. It seems the latest flash-attn release is incompatible with the torch==2.0.1 specified in pyproject.toml. I was able to resolve the problem using the following version:

pip install flash-attn==2.1.0 --no-build-isolation
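
After reinstalling, a quick sanity check (a sketch, assuming the pinned wheel installed cleanly) is to run the imports that previously failed before loading the full model:

import torch
import flash_attn
from transformers import PhiForCausalLM  # the transformers import that failed in the traceback

# If the flash-attn build matches the installed torch, these imports succeed
# and the flash_attn_2_cuda "undefined symbol" error no longer appears.
print("torch:", torch.__version__)
print("flash-attn:", flash_attn.__version__)  # expected: 2.1.0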

Would it be possible to add this to the README to prevent others from encountering the same problem?

ZhangXJ199 commented 1 day ago

Thank you for your suggestion!