huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Import guards not working properly for remote models on the hub #29735

Open amyeroberts opened 4 months ago

amyeroberts commented 4 months ago

System Info

transformers 4.39.0.dev

Who can help?

@younesbelkada as you reported :)

@amyeroberts @ArthurZucker as it's probably a core maintenance issue

Reproduction

When loading certain models on the Hub that use import guards, e.g. if is_flash_attn_2_available():, the model still raises an error if flash attention isn't installed in the environment.

For reference:

Expected behavior

Models in the transformers repo and remote models should behave the same: if an import guard is used, no error should be raised saying that the guarded library isn't installed.

ArthurZucker commented 3 months ago

cc @Rocketknight1 do you want to take this on?

Rocketknight1 commented 3 months ago

Sure, I'll take it!

Rocketknight1 commented 3 months ago

Hey, I'm slightly confused about this issue. In the reproduction branch (refs/pr/8), the import guard block is just:

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input

Later in the code, the attention class is chosen solely from the value of config._attn_implementation, and the code never checks whether the imported names, like flash_attn_func, are actually defined. So it makes sense that an error results if the flash attention modules aren't installed.

Can someone clarify what the expected behaviour here is, and what exactly the bug is?
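A minimal, self-contained sketch of the guard-then-dispatch pattern described in this comment, assuming a simplified stand-in for is_flash_attn_2_available and placeholder attention functions. The names eager_attention, flash_attention_2, ATTENTION_FUNCTIONS and attend are hypothetical and not transformers code. In plain Python, executing the guarded module never fails on its own; a missing package only matters if the flash attention path is actually called.

```python
import importlib.util


def is_flash_attn_2_available() -> bool:
    # Hypothetical stand-in: only checks that the package can be found.
    # The real transformers helper may perform additional checks.
    return importlib.util.find_spec("flash_attn") is not None


# The guarded import: flash_attn_func is only bound if flash_attn is installed.
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func


def eager_attention(q, k, v):
    # Placeholder for the eager implementation.
    return "eager"


def flash_attention_2(q, k, v):
    # Raises NameError at call time if the guarded import above was skipped.
    return flash_attn_func(q, k, v)


# Dispatch on the configured implementation, mirroring config._attn_implementation.
ATTENTION_FUNCTIONS = {
    "eager": eager_attention,
    "flash_attention_2": flash_attention_2,
}


def attend(attn_implementation, q=None, k=None, v=None):
    return ATTENTION_FUNCTIONS[attn_implementation](q, k, v)


print(attend("eager"))  # works whether or not flash_attn is installed
```

In this sketch, selecting "eager" never touches flash_attn_func, so a missing flash_attn would only show up as a NameError on the flash_attention_2 path.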

amyeroberts commented 3 months ago

@Rocketknight1 - I'll let @younesbelkada confirm, but my understanding is that there shouldn't be an error if we have the is_flash_attn_2_available guard, flash attention isn't installed and the model has a flash attention class.

Specifically, @younesbelkada reported that an import error is raised on the from flash_attn lines when loading the model with AutoModel, even when eager attention is selected. This doesn't happen for models in the repo: for example, running model = AutoModel.from_pretrained('distilbert-base-uncased') without flash attention installed raises no such error.
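One plausible source of the discrepancy, offered as an assumption rather than a description of the actual transformers loader: if remote code is checked by scanning its import lines as text before the file is executed, a guarded import looks exactly like an unguarded one. The hypothetical naive_get_imports below illustrates this with two line-based regexes.

```python
import re

# Trimmed copy of the guarded block from the reproduction branch.
REMOTE_MODELING_CODE = '''
import torch
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
'''


def naive_get_imports(source: str) -> list[str]:
    # Hypothetical scanner: collects package names from every `import x` and
    # `from x import y` line, regardless of whether the line sits inside an `if` guard.
    imports = re.findall(r"^\s*import\s+(\S+)\s*$", source, flags=re.MULTILINE)
    imports += re.findall(r"^\s*from\s+(\S+)\s+import", source, flags=re.MULTILINE)
    # Keep only the top-level package and drop relative imports.
    return sorted({name.split(".")[0] for name in imports if not name.startswith(".")})


print(naive_get_imports(REMOTE_MODELING_CODE))
```

Running it prints ['flash_attn', 'torch', 'transformers']: flash_attn is reported as a dependency even though it only appears under the guard.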

younesbelkada commented 3 months ago

Hi everyone! Yes, I second what @amyeroberts said. IMO users should be able to load that model even if FA2 is not installed. Right now, having FA2 modules in trust-remote-code models leads directly to an error telling users to install FA2 in their environment, which isn't possible on some devices such as T4 GPUs or CPUs. I think the fix should be to check whether these optional libraries are correctly guarded and, if they are, skip checking for their presence in the environment.
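A minimal sketch of the fix described above, assuming the check can parse the file with Python's ast module instead of scanning lines. Imports in the body of an if is_<name>_available(): guard, or inside a try block, are treated as optional and not reported as required. The function names required_imports, _guard_name and _is_availability_guard, and the naming-convention heuristic for spotting guards, are hypothetical and not the transformers API.

```python
import ast

# Trimmed copy of the guarded block from the reproduction branch.
REMOTE_MODELING_CODE = '''
import torch
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input
'''


def _guard_name(test: ast.expr) -> str:
    # Name of the called function for tests like is_x_available() or utils.is_x_available().
    if not isinstance(test, ast.Call):
        return ""
    if isinstance(test.func, ast.Name):
        return test.func.id
    if isinstance(test.func, ast.Attribute):
        return test.func.attr
    return ""


def _is_availability_guard(node: ast.If) -> bool:
    # Heuristic: a guard is a call to a function named is_*_available.
    name = _guard_name(node.test)
    return name.startswith("is_") and name.endswith("_available")


def required_imports(source: str) -> set[str]:
    """Top-level packages imported at module level outside optional-dependency guards."""
    required: set[str] = set()

    def visit(statements: list[ast.stmt]) -> None:
        for node in statements:
            if isinstance(node, ast.Import):
                required.update(alias.name.split(".")[0] for alias in node.names)
            elif isinstance(node, ast.ImportFrom):
                # Skip relative imports such as `from .configuration import ...`.
                if node.level == 0 and node.module:
                    required.add(node.module.split(".")[0])
            elif isinstance(node, ast.If):
                # The body of an availability guard is optional; the else branch is not.
                if not _is_availability_guard(node):
                    visit(node.body)
                visit(node.orelse)
            # Statements not handled above (try/except blocks, function and class
            # bodies) are not scanned, so imports inside them are treated as optional.

    visit(ast.parse(source).body)
    return required


print(sorted(required_imports(REMOTE_MODELING_CODE)))
```

This prints ['torch', 'transformers']. Working on the syntax tree means guard bodies are skipped based on structure rather than indentation or line patterns; the trade-off in this sketch is that a guard is recognised purely by its is_*_available naming convention, and negated guards such as if not is_x_available(): are not treated specially.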

Rocketknight1 commented 3 months ago

Got it! I'll investigate when I get a chance, but right now I'm being DDoSed by a thousand other issues!

Rocketknight1 commented 2 months ago

No stale, please!

amyeroberts commented 4 days ago

Gentle ping @Rocketknight1