HuangLK / transpeeder

train llama on a single A100 80G node using 🤗 transformers and 🚀 Deepspeed Pipeline Parallelism
Apache License 2.0
208 stars 18 forks source link

ImportError: cannot import name 'flash_attn_unpadded_qkvpacked_func' from 'flash_attn.flash_attn_interface' #38

Closed BastianChen closed 1 year ago

BastianChen commented 1 year ago

你好,我在执行python convert2ckpt.py --mp_world_size 4 --model_name_or_path /path/to/llama-7b-hf --output_dir /path/to/llama-7b-init-ckpt时报了以下错误:

`ImportError: cannot import name 'flash_attn_unpadded_qkvpacked_func' from 'flash_attn.flash_attn_interface'

看了下flash_attn.flash_attn_interface脚本里面确实没有flash_attn_unpadded_qkvpacked_func函数,我用的环境是pytorch1.13, python3.10, flash-attn.2.0.8, 能否提供下你的环境或者解决方案吗?

HuangLK commented 1 year ago

不好意思,这个代码库之前用的是flash-attn 1.x版本,还没有升级到flash-attn 2,你可以降级到1.0.9

BastianChen commented 1 year ago

我看这个网站里https://mygit.osfipin.com/repository/494232964写了2.x版本将flash_attn_unpadded_qkvpacked_func 更名成了 flash_attn_varlen_qkvpacked_func.

我在代码了使用了最新的命名,可以跑通生成相应的文件,就是还不知道后面的finetune会不会有问题。

HuangLK commented 1 year ago

你可以对比开启与关闭flash-attn的loss曲线,一般来说loss曲线的趋势是一致的,loss值也应该非常接近