thu-ml / SageAttention

Quantized Attention that achieves speedups of 2.1-3.1x and 2.7-5.1x compared to FlashAttention2 and xformers, respectively, without losing end-to-end metrics across various models.
Apache License 2.0
614 stars, 28 forks

Windows compile issue when testing CogVideoX script #7

Open · SoftologyPro opened this issue 1 month ago

SoftologyPro commented 1 month ago

I'm trying this under Windows (adding SageAttention to CogVideoX as in your demo script):

  File "D:\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\build.py", line 52, in _build
    raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
RuntimeError: Failed to find C compiler. Please specify via CC environment variable.

This happens with both Triton 2.1.0 and Triton 3.0.0. Any idea what the CC environment variable should be set to?

jt-zhang commented 1 month ago

Thank you for reaching out. Maybe you could try the newest example code in ./example in our repository.

SoftologyPro commented 1 month ago

Yes, this is using the latest example code. Full output:

D:\Tests\CogVideoX>python sageattn_cogvideo.py
Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.66s/it]
Loading pipeline components...: 100%|██████████| 5/5 [00:05<00:00,  1.17s/it]
  0%|          | 0/50 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "D:\Tests\CogVideoX\sage.py", line 20, in <module>
    video = pipe(
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\diffusers\pipelines\cogvideo\pipeline_cogvideox.py", line 671, in __call__
    noise_pred = self.transformer(
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\accelerate\hooks.py", line 170, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\diffusers\models\transformers\cogvideox_transformer_3d.py", line 456, in forward
    hidden_states, encoder_hidden_states = block(
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\diffusers\models\transformers\cogvideox_transformer_3d.py", line 131, in forward
    attn_hidden_states, attn_encoder_hidden_states = self.attn1(
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\nn\modules\module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\torch\nn\modules\module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\diffusers\models\attention_processor.py", line 490, in forward
    return self.processor(
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\diffusers\models\attention_processor.py", line 1934, in __call__
    hidden_states = F.scaled_dot_product_attention(
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\sageattention\core.py", line 36, in sageattn
    q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\sageattention\quant_per_block.py", line 63, in per_block_int8
    q_kernel_per_block_int8[grid](
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\jit.py", line 166, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\jit.py", line 348, in run
    device = driver.get_current_device()
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\driver.py", line 230, in __getattr__
    self._initialize_obj()
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\driver.py", line 227, in _initialize_obj
    self._obj = self._init_fn()
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\driver.py", line 260, in initialize_driver
    return CudaDriver()
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\driver.py", line 122, in __init__
    self.utils = CudaUtils()
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\runtime\driver.py", line 69, in __init__
    so = _build("cuda_utils", src_path, tmpdir)
  File "D:\Tests\CogVideoX\CogVideo\venv\lib\site-packages\triton\common\build.py", line 101, in _build
    raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
RuntimeError: Failed to find C compiler. Please specify via CC environment variable.

Current pip list

Package                   Version
------------------------- ---------------
accelerate                0.34.2
aiofiles                  23.2.1
aiohappyeyeballs          2.4.0
aiohttp                   3.10.5
aiosignal                 1.3.1
altair                    5.4.1
annotated-types           0.7.0
anyio                     4.4.0
async-timeout             4.0.3
attrs                     24.2.0
blinker                   1.8.2
boto3                     1.35.20
botocore                  1.35.20
braceexpand               0.1.7
cachetools                5.5.0
certifi                   2024.8.30
charset-normalizer        3.3.2
click                     8.1.7
colorama                  0.4.6
contourpy                 1.3.0
cpm-kernels               1.0.11
cycler                    0.12.1
Cython                    3.0.11
datasets                  3.0.0
decorator                 4.4.2
deepspeed                 0.12.7+40342055
diffusers                 0.31.0.dev0
dill                      0.3.8
distro                    1.9.0
einops                    0.8.0
exceptiongroup            1.2.2
fastapi                   0.114.2
ffmpy                     0.4.0
filelock                  3.16.0
fonttools                 4.53.1
frozenlist                1.4.1
fsspec                    2024.6.1
gitdb                     4.0.11
GitPython                 3.1.43
gradio                    4.44.0
gradio_client             1.3.0
h11                       0.14.0
hjson                     3.1.0
httpcore                  1.0.5
httpx                     0.27.2
huggingface-hub           0.24.7
idna                      3.10
imageio                   2.35.1
imageio-ffmpeg            0.5.1
importlib_metadata        8.5.0
importlib_resources       6.4.5
Jinja2                    3.1.4
jiter                     0.5.0
jmespath                  1.0.1
jsonschema                4.23.0
jsonschema-specifications 2023.12.1
kiwisolver                1.4.7
markdown-it-py            3.0.0
MarkupSafe                2.1.5
matplotlib                3.9.2
mdurl                     0.1.2
moviepy                   1.0.3
mpmath                    1.3.0
multidict                 6.1.0
multiprocess              0.70.16
narwhals                  1.8.1
networkx                  3.3
ninja                     1.11.1.1
numpy                     1.26.4
openai                    1.45.1
orjson                    3.10.7
packaging                 24.1
pandas                    2.2.2
Pillow                    9.5.0
pip                       24.2
proglog                   0.1.10
protobuf                  5.28.1
psutil                    6.0.0
py-cpuinfo                9.0.0
pyarrow                   17.0.0
pydantic                  2.9.1
pydantic_core             2.23.3
pydeck                    0.9.1
pydub                     0.25.1
Pygments                  2.18.0
pynvml                    11.5.3
pyparsing                 3.1.4
python-dateutil           2.9.0.post0
python-multipart          0.0.9
pytz                      2024.2
PyYAML                    6.0.2
referencing               0.35.1
regex                     2024.9.11
requests                  2.32.3
rich                      13.8.1
rpds-py                   0.20.0
ruff                      0.6.5
s3transfer                0.10.2
safetensors               0.4.5
sageattention             1.0.1
semantic-version          2.10.0
sentencepiece             0.2.0
setuptools                65.5.0
shellingham               1.5.4
six                       1.16.0
smmap                     5.0.1
sniffio                   1.3.1
starlette                 0.38.5
streamlit                 1.38.0
SwissArmyTransformer      0.4.12
sympy                     1.13.2
tenacity                  8.5.0
tensorboardX              2.6.2.2
tokenizers                0.19.1
toml                      0.10.2
tomlkit                   0.12.0
torch                     2.4.0+cu118
torchaudio                2.4.0+cu118
torchvision               0.19.0+cu118
tornado                   6.4.1
tqdm                      4.66.5
transformers              4.44.2
triton                    2.1.0
typer                     0.12.5
typing_extensions         4.12.2
tzdata                    2024.1
urllib3                   2.2.3
uvicorn                   0.30.6
watchdog                  4.0.2
webdataset                0.2.100
websockets                12.0
xxhash                    3.5.0
yarl                      1.11.1
zipp                      3.20.2

I do have the Visual Studio command-line compiler (cl.exe) and cmake.exe, which compile other AI/ML projects without problems when needed. Both are on my PATH.
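The error message comes from Triton's `_build` helper. As far as I can tell from the Triton 2.x/3.0 source (`triton/runtime/build.py`; please check the version you actually have installed), the lookup works roughly like the sketch below: an explicit `CC` variable wins, otherwise only `gcc` and then `clang` are searched on PATH. If that is accurate, it would explain why having `cl.exe` and `cmake.exe` on PATH does not help. `find_c_compiler` is a hypothetical re-creation for illustration, not a Triton API; `env` and `which` are injectable so the logic can be checked without touching the real environment.

```python
import os
import shutil
from typing import Callable, Mapping, Optional


def find_c_compiler(
    env: Mapping[str, str] = os.environ,
    which: Callable[[str], Optional[str]] = shutil.which,
) -> str:
    """Approximate Triton's C compiler lookup (assumed from Triton 2.x/3.0 build.py)."""
    # 1. An explicit CC environment variable always wins.
    cc = env.get("CC")
    if cc is not None:
        return cc
    # 2. Otherwise only gcc and clang are searched on PATH, gcc first.
    #    MSVC's cl.exe is never looked up.
    gcc = which("gcc")
    clang = which("clang")
    cc = gcc if gcc is not None else clang
    if cc is None:
        raise RuntimeError("Failed to find C compiler. Please specify via CC environment variable.")
    return cc
```

On a machine where only MSVC is installed, `find_c_compiler()` raises the same `RuntimeError` seen in the traceback.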

jt-zhang commented 1 month ago

Sorry, but I cannot reproduce the bug. I suggest updating CUDA from 11.8 to 12.1 or higher and using Triton 3.0.0.

SoftologyPro commented 1 month ago

Does it work for you under Windows? I get the same error with CUDA 12.4 and Triton 3.0.0. Do you have a requirements.txt with the packages and versions you use? Or could you list the packages needed to get this working, starting from a new, empty venv?
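To compare environments, a small script can print the relevant versions in requirements.txt format. This is a sketch using the standard library's `importlib.metadata`; `pin_versions` and the package list are illustrative choices, not part of SageAttention. The Ubuntu pip list later in this thread installs Triton as `triton-nightly`, so both distribution names are queried. The local-version suffix on PyTorch builds (for example `+cu118`) also records which CUDA build of PyTorch is installed.

```python
import platform
from importlib.metadata import PackageNotFoundError, version
from typing import Callable, Iterable, List


def pin_versions(
    packages: Iterable[str],
    get_version: Callable[[str], str] = version,
) -> List[str]:
    """Return requirements.txt-style lines for the given distribution names."""
    # Header comment with the OS and Python version, since both matter here.
    lines = [f"# {platform.system()}, Python {platform.python_version()}"]
    for name in packages:
        try:
            lines.append(f"{name}=={get_version(name)}")
        except PackageNotFoundError:
            lines.append(f"# {name} is not installed")
    return lines


if __name__ == "__main__":
    wanted = ["torch", "triton", "triton-nightly", "sageattention",
              "diffusers", "transformers", "accelerate"]
    print("\n".join(pin_versions(wanted)))
```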

If I set CC=cl.exe, I get a different error:

subprocess.CalledProcessError: Command '['cl.exe', 'C:\\Users\\Jason\\AppData\\Local\\Temp\\tmpa_lxe3au\\main.c', '-O3', '-shared', '-lcuda', '-LD:\\Tests\\CogVideoX\\CogVideo\\venv\\Lib\\site-packages\\triton\\backends\\nvidia\\lib', '-LC:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\lib\\x64', '-LD:\\Python\\libs', '-ID:\\Tests\\CogVideoX\\CogVideo\\venv\\Lib\\site-packages\\triton\\backends\\nvidia\\include', '-IC:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\include', '-IC:\\Users\\Jason\\AppData\\Local\\Temp\\tmpa_lxe3au', '-ID:\\Python\\Include', '-o', 'C:\\Users\\Jason\\AppData\\Local\\Temp\\tmpa_lxe3au\\cuda_utils.cp310-win_amd64.pyd']' returned non-zero exit status 2.

If I set CC=cmake.exe, I get a similar error:

subprocess.CalledProcessError: Command '['cmake.exe', 'C:\\Users\\Jason\\AppData\\Local\\Temp\\tmp42timoi4\\main.c', '-O3', '-shared', '-lcuda', '-LD:\\Tests\\CogVideoX\\CogVideo\\venv\\Lib\\site-packages\\triton\\backends\\nvidia\\lib', '-LC:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\lib\\x64', '-LD:\\Python\\libs', '-ID:\\Tests\\CogVideoX\\CogVideo\\venv\\Lib\\site-packages\\triton\\backends\\nvidia\\include', '-IC:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.4\\include', '-IC:\\Users\\Jason\\AppData\\Local\\Temp\\tmp42timoi4', '-ID:\\Python\\Include', '-o', 'C:\\Users\\Jason\\AppData\\Local\\Temp\\tmp42timoi4\\cuda_utils.cp310-win_amd64.pyd']' returned non-zero exit status 1.
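Both failures are consistent with the command shown in the logs: Triton passes GCC/Clang-style flags (`-O3`, `-shared`, `-lcuda`, `-L<dir>`, `-I<dir>`, `-o`). `cl.exe` (MSVC) uses a different flag syntax, and `cmake.exe` is a build-system generator rather than a compiler, so it cannot compile `main.c` directly. The hypothetical helper `gcc_args_to_msvc` below illustrates roughly how those flags map to MSVC equivalents (`/O2`, `/LD`, `/I`, `/Fe`, `/link /LIBPATH:`, `cuda.lib`). It is not a tested fix: upstream Triton targets Linux, so other parts of the build may still fail on Windows even with translated flags.

```python
from typing import List


def gcc_args_to_msvc(args: List[str]) -> List[str]:
    """Hypothetical: map a GCC-style compile command to rough cl.exe equivalents.

    Illustration only; not a tested or supported fix.
    """
    rest = args[1:]  # drop the compiler name (args[0])
    compile_args: List[str] = []
    link_args: List[str] = []
    i = 0
    while i < len(rest):
        arg = rest[i]
        if arg == "-O3":
            compile_args.append("/O2")  # MSVC has no /O3; /O2 optimizes for speed
        elif arg == "-shared":
            compile_args.append("/LD")  # build a DLL (a .pyd is a DLL)
        elif arg == "-o" and i + 1 < len(rest):
            i += 1
            compile_args.append("/Fe" + rest[i])  # name of the output file
        elif arg.startswith("-I"):
            compile_args.append("/I" + arg[2:])  # include directory
        elif arg.startswith("-L"):
            link_args.append("/LIBPATH:" + arg[2:])  # library search directory
        elif arg.startswith("-l"):
            link_args.append(arg[2:] + ".lib")  # -lcuda -> cuda.lib
        else:
            compile_args.append(arg)  # source files pass through unchanged
        i += 1
    # Everything after /link is forwarded to the linker (link.exe).
    return ["cl.exe", *compile_args, "/link", *link_args]
```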

jt-zhang commented 1 month ago

Sorry about that. I don't have a Windows machine. Here is my pip list on Ubuntu:

absl-py                        2.1.0
accelerate                     1.0.0
anaconda-anon-usage            0.4.4
anyio                          4.4.0
archspec                       0.2.3
argon2-cffi                    23.1.0
argon2-cffi-bindings           21.2.0
arrow                          1.3.0
asttokens                      2.4.1
async-lru                      2.0.4
attrs                          23.2.0
Babel                          2.15.0
beautifulsoup4                 4.12.3
bleach                         6.1.0
boltons                        23.0.0
Brotli                         1.0.9
build                          1.2.2.post1
certifi                        2024.2.2
cffi                           1.16.0
charset-normalizer             2.0.4
comm                           0.2.2
conda                          24.4.0
conda-content-trust            0.2.0
conda-libmamba-solver          24.1.0
conda-package-handling         2.2.0
conda_package_streaming        0.9.0
contourpy                      1.2.1
cryptography                   42.0.5
cycler                         0.12.1
debugpy                        1.8.1
decorator                      5.1.1
defusedxml                     0.7.1
diffusers                      0.30.3
distro                         1.9.0
docutils                       0.21.2
executing                      2.0.1
fastjsonschema                 2.19.1
filelock                       3.14.0
fonttools                      4.53.0
fqdn                           1.5.1
fsspec                         2024.5.0
grpcio                         1.64.0
h11                            0.14.0
httpcore                       1.0.5
httpx                          0.27.0
huggingface-hub                0.25.1
idna                           3.7
imageio                        2.35.1
imageio-ffmpeg                 0.5.1
importlib_metadata             8.5.0
iniconfig                      2.0.0
ipykernel                      6.29.4
ipython                        8.25.0
ipywidgets                     8.1.3
isoduration                    20.11.0
jaraco.classes                 3.4.0
jaraco.context                 6.0.1
jaraco.functools               4.1.0
jedi                           0.19.1
jeepney                        0.8.0
Jinja2                         3.1.4
json5                          0.9.25
jsonpatch                      1.33
jsonpointer                    2.1
jsonschema                     4.22.0
jsonschema-specifications      2023.12.1
jupyter_client                 8.6.2
jupyter_core                   5.7.2
jupyter-events                 0.10.0
jupyter-lsp                    2.2.5
jupyter_server                 2.14.1
jupyter_server_terminals       0.5.3
jupyterlab                     4.2.1
jupyterlab-language-pack-zh-CN 4.2.post1
jupyterlab_pygments            0.3.0
jupyterlab_server              2.27.2
jupyterlab_widgets             3.0.11
keyring                        25.4.1
kiwisolver                     1.4.5
libmambapy                     1.5.8
Markdown                       3.6
markdown-it-py                 3.0.0
MarkupSafe                     2.1.5
matplotlib                     3.9.0
matplotlib-inline              0.1.7
mdurl                          0.1.2
menuinst                       2.0.2
mistune                        3.0.2
more-itertools                 10.5.0
mpmath                         1.3.0
nbclient                       0.10.0
nbconvert                      7.16.4
nbformat                       5.10.4
nest-asyncio                   1.6.0
networkx                       3.3
nh3                            0.2.18
notebook_shim                  0.2.4
numpy                          1.26.4
nvidia-cublas-cu12             12.1.3.1
nvidia-cuda-cupti-cu12         12.1.105
nvidia-cuda-nvrtc-cu12         12.1.105
nvidia-cuda-runtime-cu12       12.1.105
nvidia-cudnn-cu12              9.1.0.70
nvidia-cufft-cu12              11.0.2.54
nvidia-curand-cu12             10.3.2.106
nvidia-cusolver-cu12           11.4.5.107
nvidia-cusparse-cu12           12.1.0.106
nvidia-nccl-cu12               2.20.5
nvidia-nvjitlink-cu12          12.5.40
nvidia-nvtx-cu12               12.1.105
opencv-python                  4.10.0.84
overrides                      7.7.0
packaging                      23.2
pandas                         2.2.3
pandocfilters                  1.5.1
parso                          0.8.4
pexpect                        4.9.0
pillow                         10.3.0
pip                            24.0
pkginfo                        1.10.0
platformdirs                   3.10.0
pluggy                         1.5.0
prometheus_client              0.20.0
prompt_toolkit                 3.0.45
protobuf                       5.27.0
psutil                         5.9.8
ptyprocess                     0.7.0
pure-eval                      0.2.2
pycosat                        0.6.6
pycparser                      2.21
Pygments                       2.18.0
pyparsing                      3.1.2
pyproject_hooks                1.2.0
PySocks                        1.7.1
pytest                         8.3.3
python-dateutil                2.9.0.post0
python-json-logger             2.0.7
pytz                           2024.2
PyYAML                         6.0.1
pyzmq                          26.0.3
readme_renderer                44.0
referencing                    0.35.1
regex                          2024.9.11
requests                       2.31.0
requests-toolbelt              1.0.0
rfc3339-validator              0.1.4
rfc3986                        2.0.0
rfc3986-validator              0.1.1
rich                           13.9.2
rpds-py                        0.18.1
ruamel.yaml                    0.17.21
safetensors                    0.4.5
sageattention                  1.0.2
SecretStorage                  3.3.3
Send2Trash                     1.8.3
sentencepiece                  0.2.0
setuptools                     69.5.1
six                            1.16.0
sniffio                        1.3.1
soupsieve                      2.5
stack-data                     0.6.3
supervisor                     4.2.5
sympy                          1.12.1
tensorboard                    2.16.2
tensorboard-data-server        0.7.2
terminado                      0.18.1
tinycss2                       1.3.0
tokenizers                     0.20.0
torch                          2.4.1
torchvision                    0.19.1
tornado                        6.4
tqdm                           4.66.2
traitlets                      5.14.3
transformers                   4.45.2
triton-nightly                 3.0.0.post20240716052845
truststore                     0.8.0
twine                          5.1.1
types-python-dateutil          2.9.0.20240316
typing_extensions              4.12.1
tzdata                         2024.2
uri-template                   1.3.0
urllib3                        2.1.0
wcwidth                        0.2.13
webcolors                      1.13
webencodings                   0.5.1
websocket-client               1.8.0
Werkzeug                       3.0.3
wheel                          0.43.0
widgetsnbextension             4.0.11
zipp                           3.20.2
zstandard                      0.22.0

guyan364 commented 1 month ago

For now, Triton only supports Linux; maybe you can try it with WSL2.
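To confirm that a shell is actually running inside WSL2 (a Linux environment) rather than native Windows, a quick check like the one below can help. It is a heuristic sketch: it assumes WSL kernels include "microsoft" in the kernel release string (for example `5.15.153.1-microsoft-standard-WSL2`), which is common but not guaranteed. `detect_environment` is a hypothetical helper.

```python
import platform
from typing import Optional


def detect_environment(system: Optional[str] = None, release: Optional[str] = None) -> str:
    """Classify the OS as 'wsl', 'linux', or the lowercased platform name."""
    if system is None:
        system = platform.system()
    if release is None:
        release = platform.release()
    if system == "Linux":
        # Heuristic (assumption): WSL kernels mention "microsoft" in their release string.
        return "wsl" if "microsoft" in release.lower() else "linux"
    return system.lower()


if __name__ == "__main__":
    print(detect_environment())
```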

SoftologyPro commented 1 month ago

> For now, Triton only supports Linux; maybe you can try it with WSL2.

There are working Triton wheels (.whl installers) for Windows out there that I use for other systems and scripts. Something different here causes Triton to compile code and then fail.
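One way to narrow this down is to launch a trivial Triton kernel outside SageAttention and CogVideoX. In the traceback above, the compiler is needed as soon as Triton launches its first kernel: `jit.py` calls `driver.get_current_device()`, which constructs `CudaDriver` and `CudaUtils`, which in turn calls `_build("cuda_utils", src_path, tmpdir)`. The sketch below is a trimmed version of the vector-addition example from the Triton tutorials, assuming PyTorch, Triton, and a CUDA GPU are available; `add_kernel` and `triton_smoke_test` are illustrative names. If it fails with the same `RuntimeError`, the problem lies in the Triton installation or toolchain rather than in SageAttention.

```python
try:
    import torch
    import triton
    import triton.language as tl
    HAVE_TRITON = True
except ImportError:
    HAVE_TRITON = False

if HAVE_TRITON:
    @triton.jit
    def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        # Each program instance handles one BLOCK_SIZE-wide slice of the vectors.
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements  # guard the tail of the last block
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        tl.store(out_ptr + offsets, x + y, mask=mask)


def triton_smoke_test(n: int = 4096) -> str:
    """Launch one Triton kernel; a missing C compiler surfaces here as RuntimeError."""
    if not HAVE_TRITON:
        return "skipped: torch/triton not installed"
    if not torch.cuda.is_available():
        return "skipped: no CUDA device"
    x = torch.rand(n, device="cuda")
    y = torch.rand(n, device="cuda")
    out = torch.empty_like(x)

    def grid(meta):
        return (triton.cdiv(n, meta["BLOCK_SIZE"]),)

    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)  # first launch triggers the build
    torch.testing.assert_close(out, x + y)
    return "ok"


if __name__ == "__main__":
    print(triton_smoke_test())
```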

woct0rdho commented 1 month ago

Just an update: I've built Triton wheels for Windows, with (hopefully) better logic to detect the compiler toolchain and clearer installation instructions: https://github.com/woct0rdho/triton-windows