facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

Windows builds of xformers do not work on pytorch>=2.2 #1073

Closed: KohakuBlueleaf closed this issue 2 months ago

KohakuBlueleaf commented 2 months ago

🐛 Bug

In the latest release of xformers (0.0.27.post1), xformers introduced a feature that reuses the flash_attn package and PyTorch's built-in SDPA instead of compiling its own kernels, to reduce wheel size and compile time. The problem is that this behavior breaks on the Windows platform, where:

  1. flash-attn doesn't ship a pre-built wheel for Windows.
  2. PyTorch dropped flash-attn 2 support on Windows after 2.2.0.

Basically, this means xformers was the ONLY flash-attn implementation that had a pre-built Windows wheel, but now it has dropped that support too.

Since we only need to set some environment variables to make the build process actually compile it, I think this is a bug, not a question/help or feature request.

Command

Running python -m xformers.info shows that xformers is using "flashattF@2.5.6-pt", which is ACTUALLY NOT SUPPORTED.

To Reproduce

Any script that uses xformers' flash attention on Windows with version 0.0.27.post1; a minimal sketch follows.
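
For example, a minimal repro sketch (assuming a CUDA GPU and fp16 inputs; the explicit op= just forces the flash path instead of letting xformers pick a backend):

import torch
import xformers.ops as xops

# [batch, seq_len, heads, head_dim]; flash-attn needs fp16/bf16 on CUDA
q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

try:
    # Force the flash-attention op so the failure is visible instead of
    # silently falling back to another backend.
    xops.memory_efficient_attention(q, q, q, op=xops.MemoryEfficientAttentionFlashAttentionOp)
    print("flash op ran")
except Exception as exc:
    print("flash op not usable:", exc)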

Expected behavior

The pre-built xformers wheel should have flash-attn/cutlass compiled in, not just import the PyTorch ones.

Environment

PyTorch version: 2.4.0+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Pro
GCC version: (GCC) 13.2.0
Clang version: Could not collect
CMake version: version 3.20.3
Libc version: N/A

Python version: 3.11.7 | packaged by conda-forge | (main, Dec 15 2023, 08:28:06) [MSC v.1937 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22631-SP0
Is CUDA available: True
CUDA runtime version: 12.4.99
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 4090
Nvidia driver version: 560.70
cuDNN version: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.1\bin\cudnn_ops_train64_8.dll
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=3000
DeviceID=CPU0
Family=207
L2CacheSize=32768
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=3000
Name=13th Gen Intel(R) Core(TM) i9-13900K
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] numpy==1.23.5
[pip3] pytorch-lightning==2.3.3
[pip3] torch==2.4.0+cu124
[pip3] torchaudio==2.4.0+cu124
[pip3] torchsde==0.2.6
[pip3] torchtext==0.6.0
[pip3] torchvision==0.19.0+cu124
[conda] Could not collect

Additional context

I can confirm that xformers built from source works normally, with XFORMERS_PT_CUTLASS_ATTN/XFORMERS_PT_FLASH_ATTN set to 0.
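
For reference, here's a minimal sketch of that source build (not an official build script: the env var names are the ones above, and the pip invocation is the usual install-from-git route, so adjust to taste):

import os
import subprocess
import sys

# Disable the PyTorch-backed shims so the native kernels get compiled into the wheel.
env = os.environ.copy()
env["XFORMERS_PT_FLASH_ATTN"] = "0"
env["XFORMERS_PT_CUTLASS_ATTN"] = "0"

# Build xformers from source against the already-installed torch.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-v", "--no-build-isolation",
     "git+https://github.com/facebookresearch/xformers.git"],
    env=env,
    check=True,
)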

KohakuBlueleaf commented 2 months ago

related issue: https://github.com/facebookresearch/xformers/issues/1069

USE_FLASH_ATTENTION not being enabled for the build is a PyTorch-side error, but it means xformers SHOULD NOT ASSUME that users on Windows will have flash-attn.
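
For what it's worth, a stricter check than trusting flash_sdp_enabled() is to force the flash SDP backend once and see whether a kernel actually exists. A minimal sketch using only public torch APIs (torch >= 2.3 for torch.nn.attention); this is not xformers' actual detection code:

import torch
from torch.nn.attention import SDPBackend, sdpa_kernel

def pt_flash_actually_built() -> bool:
    # flash_sdp_enabled() only says the backend is *enabled*, not that the
    # binary was compiled with it, so run a tiny probe restricted to flash.
    if not torch.backends.cuda.flash_sdp_enabled():
        return False
    try:
        q = torch.randn(1, 4, 32, 64, device="cuda", dtype=torch.float16)
        with sdpa_kernel([SDPBackend.FLASH_ATTENTION]):
            torch.nn.functional.scaled_dot_product_attention(q, q, q)
        return True
    except RuntimeError:  # "No available kernel" on builds without flash-attn
        return False

print(pt_flash_actually_built())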

FurkanGozukara commented 2 months ago

this is why I am getting the error

please fix it :(

so the latest we can use is 2.2.0 and xformers 0.0.24.post1?

ananosleep commented 2 months ago

this is why I am getting the error

please fix it :(

so the latest we can use is 2.2.0 and xformers 0.0.24.post1?

2.3.0 and xformers 0.0.26.post1 work well, since this "feature" hadn't been introduced yet. 2.3.1 and xformers 0.0.27 should also work, but I didn't test them.

FurkanGozukara commented 2 months ago

this is why I am getting the error please fix it :( so the latest we can use is 2.2.0 and xformers 0.0.24.post1?

2.3.0 and xformers 0.0.26.post1 work well, since this "feature" hadn't been introduced yet. 2.3.1 and xformers 0.0.27 should also work, but I didn't test them.

thanks, I am gonna test 2.3.0 and 0.0.26.post1

lw commented 2 months ago

Sorry about this, indeed it looks like PyTorch has never supported building FlashAttention v2 on Windows, and this occurred in the 2.2.0 release (https://github.com/pytorch/pytorch/pull/105602).

We're looking into re-enabling FlashAttention in xFormers just for Windows.

However, this will likely be a temporary fix. We'll see whether we can get PyTorch to ship it by default, but I'd also recommend that you look into whether you have to use FlashAttention, or whether you can switch to some of the other backends provided by PyTorch on Windows.
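
For anyone reading along, a minimal sketch of what that switch can look like (assuming a CUDA GPU; MemoryEfficientAttentionCutlassOp is the non-flash xformers op, and plain PyTorch SDPA is the other option):

import torch
import xformers.ops as xops

# [batch, seq_len, heads, head_dim], fp16 on CUDA
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)

# Option 1: keep xformers but explicitly pick the cutlass (memory-efficient)
# kernels instead of flash-attn; op=None would let xformers choose.
out = xops.memory_efficient_attention(q, k, v, op=xops.MemoryEfficientAttentionCutlassOp)

# Option 2: bypass xformers and use PyTorch's SDPA, which selects a backend
# that is available on Windows. SDPA expects [batch, heads, seq_len, head_dim].
out_pt = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
).transpose(1, 2)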

FurkanGozukara commented 2 months ago

thanks a lot @lw

KohakuBlueleaf commented 2 months ago

Sorry about this, indeed it looks like PyTorch has never supported building FlashAttention v2 on Windows, and this occurred in the 2.2.0 release (pytorch/pytorch#105602).

We're looking into re-enabling FlashAttention in xFormers just for Windows.

However, this will likely be a temporary fix. We'll see whether we can get PyTorch to ship it by default, but I'd also recommend that you look into whether you have to use FlashAttention, or whether you can switch to some of the other backends provided by PyTorch on Windows.

I will say it is definitely possible for me to run things on the cutlass attention (from PyTorch).

But it also means I would need to reimplement all the related things from diffusers to get it working right, with worse performance (speed-wise) as a result.

Which seems like nonsense to me.

Also, this issue indicates that the flash-attn-pt detection method cannot correctly check whether PyTorch was compiled with Flash Attention or not.
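
For context, this is roughly what switching a single diffusers attention block from the xformers path to PyTorch SDPA looks like; a minimal sketch assuming a recent diffusers version (the standalone Attention block is just for illustration):

import torch
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

# A standalone diffusers attention block, purely for illustration.
attn = Attention(query_dim=64, heads=4, dim_head=16).to("cuda", torch.float16)

# Route it through torch.nn.functional.scaled_dot_product_attention
# (the backends PyTorch does ship on Windows) instead of xformers' flash path.
attn.set_processor(AttnProcessor2_0())

hidden = torch.randn(1, 128, 64, device="cuda", dtype=torch.float16)
out = attn(hidden)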

lw commented 2 months ago

this issue indicates that the flash-attn-pt detection method cannot correctly check whether PyTorch was compiled with Flash Attention or not

Right, this is what we're currently investigating, because if that worked as intended we wouldn't be having this issue.

We don't have access to any Windows machine to debug. Could you please help us by installing PyTorch 2.4.0 and xFormers 0.0.27.post1 and give us the output of these commands?

import torch
from xformers.ops.fmha.torch_attention_compat import is_pt_flash_compatible
print(torch.backends.cuda.flash_sdp_enabled())
print(is_pt_flash_compatible(force=False))

Thanks!

ear361 commented 2 months ago

this issue indicates that the flash-attn-pt detection method cannot correctly check whether PyTorch was compiled with Flash Attention or not

Right, this is what we're currently investigating, because if that worked as intended we wouldn't be having this issue.

We don't have access to any Windows machine to debug. Could you please help us by installing PyTorch 2.4.0 and xFormers 0.0.27.post1 and give us the output of these commands?

import torch
from xformers.ops.fmha.torch_attention_compat import is_pt_flash_compatible
print(torch.backends.cuda.flash_sdp_enabled())
print(is_pt_flash_compatible(force=False))

Thanks!

>>> version('torch')
'2.4.0+cu121'
>>> version('xformers')
'0.0.27.post1'
>>> import torch
>>> from xformers.ops.fmha.torch_attention_compat import is_pt_flash_compatible
C:\Users\Admin\Downloads\cfui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\xformers\ops\fmha\flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
C:\Users\Admin\Downloads\cfui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\xformers\ops\fmha\flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
A matching Triton is not available, some optimizations will not be enabled
Traceback (most recent call last):
  File "C:\Users\Admin\Downloads\cfui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\xformers\__init__.py", line 57, in _is_triton_available
    import triton  # noqa
    ^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'triton'
C:\Users\Admin\Downloads\cfui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\xformers\ops\swiglu_op.py:127: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  @torch.cuda.amp.custom_fwd
C:\Users\Admin\Downloads\cfui\ComfyUI_windows_portable\python_embeded\Lib\site-packages\xformers\ops\swiglu_op.py:148: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  @torch.cuda.amp.custom_bwd
>>> print(torch.backends.cuda.flash_sdp_enabled())
True
>>> print(is_pt_flash_compatible(force=False))
True
>>>

lw commented 2 months ago

I just triggered a new release, v0.0.27.post2, which should include FlashAttention in its Windows builds.

Moreover, I'm trying to make PyTorch include FlashAttention on Windows by default, so that in the future you won't have to depend on xFormers: https://github.com/pytorch/pytorch/pull/131875

FurkanGozukara commented 2 months ago

@lw awesome thank you so much

I wish shameless OpenAI were following your example

They still don't support Triton on Windows, so I can't use CogVLM2 there

https://github.com/triton-lang/triton/issues/4395

https://github.com/THUDM/CogVLM2/issues/169

KohakuBlueleaf commented 2 months ago

@lw awesome thank you so much

I wish shameless OpenAI were following your example

They still don't support Triton on Windows, so I can't use CogVLM2 there

triton-lang/triton#4395

THUDM/CogVLM2#169

There is a fork of Triton that makes the Windows build work, although OpenAI directly closed the PR from it, lol.

FurkanGozukara commented 2 months ago

@lw awesome thank you so much I wish shameless OpenAI were following your example They still don't support Triton on Windows, so I can't use CogVLM2 there triton-lang/triton#4395 THUDM/CogVLM2#169

There is a fork of Triton that makes the Windows build work, although OpenAI directly closed the PR from it, lol.

Can I get that fork please? I really need it. I have a Triton 2.1 pre-compiled wheel, but it looks like Triton 3 is mandatory for CogVLM.

KohakuBlueleaf commented 2 months ago

Can I get that fork please? I really need it. I have a Triton 2.1 pre-compiled wheel, but it looks like Triton 3 is mandatory for CogVLM.

https://github.com/triton-lang/triton/pull/4045