facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

[FA3] Link to cuda library to fix the FA3 extension build #1157

Closed: xuzhao9 closed this pull request 1 day ago

xuzhao9 commented 1 day ago

What does this PR do?

Fix the FA3 extension build by linking against the CUDA driver library (libcuda).

The upstream flash-attn repo notes that -lcuda is required to build and install the FA3 library: https://github.com/Dao-AILab/flash-attention/blob/main/hopper/setup.py#L222C13-L223C31

We need to add the same library to the xformers build so that the extension's .so file links with all of its symbols resolved.
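For illustration, a minimal sketch of what the change amounts to in a setuptools/PyTorch extension build. This is not xformers' actual setup.py; the extension name is taken from the import error below, and the source list is a placeholder.

# Sketch only: how the CUDA driver library could be added to the FA3
# extension's link step. Source list is a placeholder, not the real build.
from torch.utils.cpp_extension import CUDAExtension

fa3_sources = ["hopper/flash_api.cpp"]  # placeholder; the real list comes from the FA3 build

ext = CUDAExtension(
    name="xformers._C_flashattention3",
    sources=fa3_sources,
    # Link against the CUDA driver API (libcuda.so); it provides
    # cuTensorMapEncodeTiled, which the Hopper (SM90) kernels call.
    extra_link_args=["-lcuda"],
)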

Test Plan:

Before this PR:

$ pip install -e .
# install xformers from source...
$ python -c "import xformers._C_flashattention3"
ImportError: /data/users/xzhao9/tritonbench/submodules/xformers/xformers/_C_flashattention3.so: undefined symbol: cuTensorMapEncodeTiled

After this PR:

$ pip install -e .
# install xformers from source...
$ python -c "import xformers._C_flashattention3"
# success!
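As a side note, the undefined symbol can be traced to the driver library directly. This is a hedged diagnostic sketch, not part of the PR; it assumes an NVIDIA driver (CUDA 12 or newer) is installed on the machine.

import ctypes

# cuTensorMapEncodeTiled lives in the CUDA driver library (libcuda.so),
# not in the CUDA runtime, which is why the extension must link with -lcuda.
libcuda = ctypes.CDLL("libcuda.so.1")
print(hasattr(libcuda, "cuTensorMapEncodeTiled"))  # True with a CUDA 12+ driver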

Before submitting

PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

xuzhao9 commented 1 day ago

The CI workflow failed because the CI runner does not come with libcuda.so installed. To provide it, we need to install the NVIDIA driver package, as done in: https://github.com/pytorch-labs/tritonbench/blob/main/docker/tritonbench-nightly.dockerfile#L47
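For build-only CI machines where a real driver cannot be installed, a common alternative (not what the linked dockerfile does, and not verified against this CI setup) is to link against the stub libcuda.so shipped with the CUDA toolkit; the real driver is still required at runtime on the machine that imports the extension.

$ export LIBRARY_PATH=/usr/local/cuda/lib64/stubs:$LIBRARY_PATH
$ pip install -e .
# -lcuda now resolves against the toolkit's stub at link time only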

xuzhao9 commented 1 day ago

Thanks for explaining the details! I will close this PR and wait for upstream FA3 to fix this and upgrade CUTLASS.