conda-forge / flash-attn-feedstock

A conda-smithy repository for flash-attn.

Feature Request: Add multiple outputs for fused_dense_lib and layer_norm #18

Open rongou opened 2 weeks ago

rongou commented 2 weeks ago

Some models (e.g. InternVideo2 multi-modality) depend on flash attention extensions. We would like to add additional outputs for the following (see the import sketch below):

- fused_dense_lib: csrc/fused_dense_lib
- layer_norm: csrc/layer_norm
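For reference, flash-attn's Python wrappers pull in these compiled extensions by module name, roughly like this (a sketch based on the upstream flash_attn/ops modules; the alias is an assumption taken from how fused_dense.py refers to it):

```python
# Why the package needs these outputs: flash-attn's Python wrappers import the
# compiled extensions at module level (sketch based on flash_attn/ops/fused_dense.py
# and flash_attn/ops/layer_norm.py upstream).
import fused_dense_lib as fused_dense_cuda  # built from csrc/fused_dense_lib
import dropout_layer_norm                   # built from csrc/layer_norm
```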

weiji14 commented 2 weeks ago

Hi @rongou, we're in the midst of transitioning our CI to Quansight's OpenStack servers at #10 to build flash-attn v2.6.3, and will consider your request after that. We'll need to wait for that PR so that we have more CI capacity for longer builds; you're welcome to open a PR to add those additional outputs afterwards.

If I'm not mistaken, you'll need to modify these lines in setup.py to include csrc/fused_dense_lib and csrc/layer_norm:

https://github.com/conda-forge/flash-attn-feedstock/blob/ba22f20da7063ead07294dc872b47fcb5e4074ca/recipe/setup.py#L131-L135
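For anyone picking this up, a rough sketch of what the two extra CUDAExtension entries could look like; the extension names match the modules flash-attn imports, but the source lists and compile flags are assumptions and should be copied from upstream's setup.py:

```python
from torch.utils.cpp_extension import CUDAExtension

# Sketch only: extension names match what flash_attn/ops imports; the exact
# source files and nvcc flags need to be taken from upstream's setup.py.
extra_ext_modules = [
    CUDAExtension(
        name="fused_dense_lib",
        sources=[
            "csrc/fused_dense_lib/fused_dense.cpp",
            "csrc/fused_dense_lib/fused_dense_cuda.cu",
        ],
    ),
    CUDAExtension(
        name="dropout_layer_norm",
        # upstream also lists a set of per-hidden-size ln_*.cu kernel files here
        sources=["csrc/layer_norm/ln_api.cpp"],
    ),
]
```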

However, I took a quick look at https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/csrc/layer_norm/README.md and https://github.com/Dao-AILab/flash-attention/issues/794#issuecomment-1913893699, and it sounds like layer_norm is deprecated in flash-attn, so I'm unsure whether we should support it here. Could you point to the code in InternVideo2 that uses this?

rongou commented 2 weeks ago

My understanding is that InternVideo2 was trained using an older version of flash-attn and relied on layer_norm; I'm not sure if we can just swap out the module for inference. Let me play around a bit to see if I can get it to work with the new implementation.

rongou commented 2 weeks ago

Looks like the current flash-attn code still references dropout_layer_norm, which is provided by the layer_norm extension: https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/layer_norm.py#L4
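A quick way to check whether a given build actually ships that extension (just a sketch):

```python
# The layer_norm-dependent code paths only work if this import succeeds,
# since flash_attn/ops/layer_norm.py imports dropout_layer_norm at module level.
try:
    import dropout_layer_norm  # compiled from csrc/layer_norm
    print("dropout_layer_norm is available")
except ImportError:
    print("dropout_layer_norm is missing; flash_attn.ops.layer_norm will fail to import")
```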

weiji14 commented 2 weeks ago

> My understanding is that InternVideo2 was trained using an older version of flash-attn and relied on layer_norm; I'm not sure if we can just swap out the module for inference. Let me play around a bit to see if I can get it to work with the new implementation.

Ok, if you're confident that it's possible to get InternVideo2 to work with the layer_norm function in flash-attn>=2.6.3, then we can try to add support for it here. I see at https://github.com/OpenGVLab/InternVideo/blob/eca2cdc5a67d7442063d19963515b5bd0feef627/InternVideo2/single_modality/requirements.txt#L6 that they used flash_attn==2.0.8, though I'm unsure how compatible those versions would be.

rongou commented 2 weeks ago

Yes, I've tested with flash-attn 2.6.3.

weiji14 commented 1 week ago

Cool, @rongou, now that #10 is merged, would you like to open a PR to add fused_dense_lib and layer_norm? While testing, it would be good if you could uncomment these lines:

https://github.com/conda-forge/flash-attn-feedstock/blob/bc32f152ad78c9940dc01a7a2cb29c86bce22072/recipe/meta.yaml#L25-L27

so that we're only using 1 OpenStack runner at a time. I'd also recommend reducing TORCH_CUDA_ARCH_LIST here:

https://github.com/conda-forge/flash-attn-feedstock/blob/bc32f152ad78c9940dc01a7a2cb29c86bce22072/recipe/meta.yaml#L21

to just TORCH_CUDA_ARCH_LIST=8.0+PTX as you try things out; otherwise the CI builds will take 6-9 hours. Let us know if you need any help.
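For local trial builds, the same restriction can be applied by setting the environment variable before PyTorch's extension builder runs (a sketch of an assumed local workflow, not the feedstock's CI script):

```python
import os

# Compile only for SM 8.0 plus a PTX fallback while experimenting; the full
# architecture list in meta.yaml is what pushes the CI builds to 6-9 hours.
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0+PTX"

# torch.utils.cpp_extension reads this variable when generating -gencode flags,
# so it must be set before the build step (e.g. before running pip install).
```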