lm-sys / FastChat

An open platform for training, serving, and evaluating large language models. Release repo for Vicuna and Chatbot Arena.
Apache License 2.0
36.63k stars 4.52k forks

Anybody know what version of `flash_attn` is used for fine-tuning? #2970

Open Oscarjia opened 8 months ago

Oscarjia commented 8 months ago

When attempting to execute the `FastChat/scripts/train_vicuna_7b.sh` script, it raises an exception with the following error message:

```
File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 16, in <module>
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
ImportError: cannot import name 'flash_attn_unpadded_func' from 'flash_attn.flash_attn_interface' (/usr/local/lib/python3.10/dist-packages/flash_attn/flash_attn_interface.py)
```

Does anyone know why this error occurred? Additionally, why hasn't the repository provided a requirements.txt file to specify the required environment for fine-tuning the model?
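For context on why this symbol can go missing: flash-attn 2.x renamed `flash_attn_unpadded_func` to `flash_attn_varlen_func`, so a `transformer_engine` build that imports the 1.x name fails against a 2.x install. A minimal compatibility shim could look like the sketch below (the two flash-attn symbol names are real; the helper function and the stand-in module are hypothetical, so the logic can be shown without flash-attn installed):

```python
from types import SimpleNamespace


def resolve_unpadded_func(interface_module):
    """Return the unpadded/varlen attention entry point from
    flash_attn.flash_attn_interface, under whichever name this
    flash-attn version exposes it.

    flash-attn 1.x exposes `flash_attn_unpadded_func`;
    flash-attn 2.x renamed it to `flash_attn_varlen_func`.
    """
    for name in ("flash_attn_unpadded_func", "flash_attn_varlen_func"):
        fn = getattr(interface_module, name, None)
        if fn is not None:
            return fn
    raise ImportError(
        "no unpadded/varlen flash attention function found; "
        "check the installed flash-attn version"
    )


# Stand-in for the real module so the resolution logic can be
# exercised without flash-attn installed:
fake_v2 = SimpleNamespace(flash_attn_varlen_func=lambda *a, **kw: "v2")
print(resolve_unpadded_func(fake_v2)())  # resolves the 2.x name
```

In practice the simpler fix is to match versions: either install a flash-attn 1.x release for code that imports the old name, or upgrade the calling code to use `flash_attn_varlen_func`.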

dhruvpes commented 5 months ago

[screenshot of a similar flash_attn import error]

Hello, I am also currently facing a similar issue during fine-tuning. Despite setting all the environment variables correctly, the issue persists.

I am able to run the model locally (with both the CPU and GPU options). If you have figured out a solution, could you please share your approach?

Thanks
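Before digging further, it helps to confirm which flash-attn version is actually installed in the training environment, since the available symbol names differ between 1.x and 2.x. A quick check (the helper function name is just for illustration):

```python
# Print which flash-attn is installed, if any; useful to run inside
# the environment used by the FastChat training scripts.
import importlib.metadata as metadata


def flash_attn_version():
    """Return the installed flash-attn version string, or None."""
    try:
        return metadata.version("flash-attn")
    except metadata.PackageNotFoundError:
        return None


version = flash_attn_version()
if version is None:
    print("flash-attn is not installed")
elif version.startswith("1."):
    print(f"flash-attn {version}: exposes flash_attn_unpadded_func")
else:
    print(f"flash-attn {version}: exposes flash_attn_varlen_func instead")
```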

Oscarjia commented 4 months ago

Yes, actually I recommend creating a Docker container for fine-tuning. For example, start from `FROM nvcr.io/nvidia/pytorch:23.11-py3`, install the basic packages below, and then run `train_vicuna_7b.sh`:

```shell
pip install peft==0.5.0 \
    transformers==4.37.1 \
    transformers-stream-generator==0.0.4 \
    deepspeed==0.12.3 \
    accelerate==0.26.1 \
    gunicorn==20.1.0 \
    flask==2.1.2 \
    flask_api==3.1 \
    langchain==0.1.4 \
    fastapi==0.109.1 \
    uvicorn==0.19.0 \
    jinja2==3.1.2 \
    huggingface_hub==0.20.3 \
    grpcio-tools==1.60.0 \
    bitsandbytes==0.42.0 \
    sentencepiece==0.1.99 \
    safetensors==0.4.2 \
    datasets==2.16.1 \
    texttable==1.7.0 \
    toml==0.10.2 \
    numpy==1.24.4 \
    scikit-learn==1.3.2 \
    loguru==0.7.0 \
    protobuf==4.24.4 \
    pydantic==2.5.1 \
    python-dotenv==1.0.0 \
    tritonclient[all]==2.41.1 \
    sse-starlette==2.0.0 \
    boto3==1.34.30 \
    jsonlines==4.0.0
```
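Put together, the recipe above might look like the following Dockerfile (a sketch based on the base image and pins in this comment; the `/workspace` directory, the `requirements.txt` file holding the pins, and the `CMD` line are assumptions, not the exact setup the FastChat authors ship):

```dockerfile
# Sketch of the container described above.
FROM nvcr.io/nvidia/pytorch:23.11-py3

WORKDIR /workspace

# requirements.txt holds the pinned packages from the pip install
# command above (hypothetical file name).
COPY requirements.txt .
RUN pip install -r requirements.txt

# Copy the FastChat checkout and run the training script.
COPY . .
CMD ["bash", "scripts/train_vicuna_7b.sh"]
```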

Hope this can help you.