Open · amyeroberts opened this issue 4 months ago
cc @Rocketknight1 do you want to take this on?
Sure, I'll take it!
Hey, I'm slightly confused about this issue. In the reproduction branch (refs/pr/8), the import guard block is just:
```python
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
```
Later in the code, the attention module to use is chosen based on the value of `config._attn_implementation`, and the code never checks whether imported names like `flash_attn_func` are actually defined. So it makes sense that an error results if the flash attention modules aren't installed.
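For contrast, here is a minimal sketch (not the actual transformers code; `attention_forward` is an illustrative helper) of the guard pattern that in-repo models rely on: the flash-attn names are always bound, and the flash path only fails at call time, so eager attention works without flash-attn installed.

```python
def is_flash_attn_2_available() -> bool:
    # Stand-in for transformers.utils.is_flash_attn_2_available
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func
else:
    flash_attn_func = None  # name is defined either way

def attention_forward(attn_implementation: str) -> str:
    if attn_implementation == "flash_attention_2":
        # Fail only when the flash path is actually requested
        if flash_attn_func is None:
            raise ImportError("flash-attn must be installed for flash_attention_2")
        return "flash path"
    return "eager path"
```

With this shape, selecting eager attention never touches the missing library, which is the behaviour the remote-code model fails to reproduce.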
Can someone clarify what the expected behaviour here is, and what exactly the bug is?
@Rocketknight1 - I'll let @younesbelkada confirm, but my understanding is that there shouldn't be an error if we have the `is_flash_attn_2_available` guard, flash attention isn't installed, and the model has a flash attention class.

Specifically, @younesbelkada reported that an import error is raised on the `from flash_attn` lines when loading the model with `AutoModel`, even if eager attention was selected. The equivalent doesn't happen for the models in the repo, e.g. `model = AutoModel.from_pretrained('distilbert-base-uncased')` works fine if I don't have flash attention installed.
Hi everyone! Yes, I second what @amyeroberts said. IMO users should be able to properly load such a model even if FA2 is not installed. Right now, having FA2 modules in trust-remote-code models directly leads to an error advising users to install FA2 in their environment, which is not possible on some devices such as T4 GPUs or CPU-only machines. I think the fix should be to check whether these optional libraries are correctly guarded, and skip the environment check for them if they are.
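The guard-aware check could be sketched roughly like this (a hypothetical illustration, not the transformers implementation; `GUARD_FUNCTIONS` and `find_guarded_imports` are made-up names): parse the remote-code file with `ast` and collect the modules imported only inside an availability-guard block, so the dependency checker can skip them.

```python
import ast

# Availability helpers whose `if` blocks mark imports as optional (assumed set)
GUARD_FUNCTIONS = {"is_flash_attn_2_available"}

def find_guarded_imports(source: str) -> set:
    """Return top-level module names imported inside a guard block."""
    tree = ast.parse(source)
    guarded = set()

    class Visitor(ast.NodeVisitor):
        def visit_If(self, node):
            test = node.test
            is_guard = (
                isinstance(test, ast.Call)
                and isinstance(test.func, ast.Name)
                and test.func.id in GUARD_FUNCTIONS
            )
            if is_guard:
                # Collect every import nested under the guarded `if`
                for stmt in ast.walk(node):
                    if isinstance(stmt, ast.Import):
                        guarded.update(a.name.split(".")[0] for a in stmt.names)
                    elif isinstance(stmt, ast.ImportFrom) and stmt.module:
                        guarded.add(stmt.module.split(".")[0])
            self.generic_visit(node)

    Visitor().visit(tree)
    return guarded

SAMPLE = """
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func
    from flash_attn.bert_padding import pad_input
import torch
"""
```

Here `find_guarded_imports(SAMPLE)` would report `flash_attn` as guarded while leaving `torch` as a hard requirement.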
Got it! I'll try to investigate when I get a chance, but right now I'm being DDoSed by a thousand other issues!
No stale, please!
Gentle ping @Rocketknight1
System Info
transformers 4.39.0.dev
Who can help?
@younesbelkada as you reported :)
@amyeroberts @ArthurZucker as it's probably a core maintenance issue
Information

Tasks

An officially supported task in the examples folder (such as GLUE/SQuAD, ...)

Reproduction
When trying to load certain models on the Hub that have import guards, e.g. `if is_flash_attn_2_available():`, the model will still raise an error if flash attention isn't installed in the environment.

For reference:
Expected behavior
Same behaviour for models in the transformers repo and for remote code models: if an import guard is used, we shouldn't get errors saying that a certain library isn't installed.