rongou opened this issue 2 weeks ago
Hi @rongou, we're in the midst of transitioning our CI to Quansight's OpenStack servers at #10 to build flash-attn v2.6.3, and will consider your request after that. We'll need to wait for that PR to have more CI capability for longer builds, and you're welcome to open a PR to add those additional outputs afterwards.
If I'm not mistaken, you'll need to modify these lines in `setup.py` to include `csrc/fused_dense_lib` and `csrc/layer_norm`:
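The referenced lines aren't quoted here, but for context, a rough sketch of the kind of change being suggested, assuming the two directories get wired in as extra `CUDAExtension` entries (the module names and source globs below are illustrative guesses, not the repository's actual lines):

```python
# Hedged sketch only: extension names and source lists are assumptions for illustration.
import glob
from torch.utils.cpp_extension import CUDAExtension

extra_ext_modules = [
    CUDAExtension(
        name="fused_dense_lib",  # kernels under csrc/fused_dense_lib
        sources=sorted(
            glob.glob("csrc/fused_dense_lib/*.cpp")
            + glob.glob("csrc/fused_dense_lib/*.cu")
        ),
    ),
    CUDAExtension(
        name="dropout_layer_norm",  # the module flash_attn.ops.layer_norm imports
        sources=sorted(
            glob.glob("csrc/layer_norm/*.cpp")
            + glob.glob("csrc/layer_norm/*.cu")
        ),
    ),
]
# These would be appended to the ext_modules list already passed to setup().
```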
However, I took a quick look at https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/csrc/layer_norm/README.md and https://github.com/Dao-AILab/flash-attention/issues/794#issuecomment-1913893699, and it sounds like `layer_norm` is deprecated in flash-attn, so I'm unsure if we should support it here. Could you point to the code in InternVideo2 that is using this?
My understanding is InternVideo2 was trained using an older version of flash-attn and relied on `layer_norm`, so I'm not sure if we can just swap out the module for inference. Let me play around a bit to see if I can get it to work with the new implementation.
Looks like the current flash-attn code still references `dropout_layer_norm`, which is provided by the `layer_norm` extension:
https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/layer_norm.py#L4
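As a hedged sketch of what that dependency looks like in practice (the Triton fallback path below is an assumption based on newer flash-attn releases and should be double-checked against 2.6.3):

```python
# The CUDA extension built from csrc/layer_norm provides the dropout_layer_norm module;
# without it, importing flash_attn.ops.layer_norm fails at the line linked above.
try:
    import dropout_layer_norm  # noqa: F401  # only present if csrc/layer_norm was built
    print("csrc/layer_norm extension is available")
except ImportError:
    # Newer flash-attn releases also ship a Triton-based layer norm; whether InternVideo2
    # can use it as a drop-in replacement is exactly what needs to be tested.
    from flash_attn.ops.triton.layer_norm import layer_norm_fn  # noqa: F401
    print("falling back to the Triton layer norm implementation")
```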
> My understanding is InternVideo2 was trained using an older version of flash-attn and relied on `layer_norm`, so I'm not sure if we can just swap out the module for inference. Let me play around a bit to see if I can get it to work with the new implementation.
Ok, if you're confident that it's possible to get InternVideo2 to work with the `layer_norm` function in flash-attn>=2.6.3, then we can try to add support for it here. I see at https://github.com/OpenGVLab/InternVideo/blob/eca2cdc5a67d7442063d19963515b5bd0feef627/InternVideo2/single_modality/requirements.txt#L6 that they used `flash_attn==2.0.8`, though I'm unsure how compatible those versions would be.
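If it's useful while testing, a trivial sanity check to confirm which version is actually installed before comparing behaviour against that 2.0.8 pin:

```python
# Confirm the installed flash-attn version before running InternVideo2 inference tests.
import flash_attn
print(flash_attn.__version__)  # InternVideo2 pins 2.0.8; the test environment targets 2.6.3
```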
Yes, I've tested with flash-attn 2.6.3.
Cool, @rongou, now that #10 is merged, would you like to open a PR to add in `fused_dense_lib` and `layer_norm`? While testing, it would be good if you could uncomment these lines:

so that we're only using 1 OpenStack runner at a time. I'd also recommend reducing `TORCH_CUDA_ARCH_LIST` here:

to just `TORCH_CUDA_ARCH_LIST=8.0+PTX` as you try things out, otherwise the CI builds will take 6-9 hours. Let us know if you need any help.
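For example, a minimal way to pin that down while iterating (`TORCH_CUDA_ARCH_LIST` is the standard variable read by `torch.utils.cpp_extension`; how it gets set in this repo's CI may differ):

```python
# Restrict the build to SM80 plus PTX so a single test build stays short.
# In CI this would normally be set in the workflow environment rather than in Python.
import os
os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0+PTX"
```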
Some models (e.g. InternVideo2 multi modality) depend on flash attention extensions. We would like to add additional outputs for:

- fused_dense_lib: csrc/fused_dense_lib
- layer_norm: csrc/layer_norm