triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License

[BUG] DeepSpeed Inference with HF diffusers breaks on T4 #1054

tchaton opened this issue 1 year ago (status: Open)

tchaton commented 1 year ago

I am trying to use DeepSpeed Inference with Diffusers on a T4 GPU, but it seems there is a Triton error.

I have also reported the bug on the DeepSpeed repository for better tracking: https://github.com/microsoft/DeepSpeed/issues/2702

import os, torch, diffusers, deepspeed

hf_auth_key = os.getenv("HF_AUTH_KEY")
if not hf_auth_key:
    raise ValueError("HF_AUTH_KEY is not set")

# Load Stable Diffusion v1.4 in fp16.
pipe = diffusers.StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=hf_auth_key,
    torch_dtype=torch.float16,
    revision="fp16")

# Wrap the pipeline with DeepSpeed Inference and run a single prompt.
model = deepspeed.init_inference(pipe.to("cuda"), dtype=torch.float16)
model("hello from here")

Here is the error trace from the inference call. It seems related to Triton's kernel caching.

Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-42648570729a4835b21c1c18cebedbfe-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, 'fp32', torch.float32, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 64, 128), (True, True, True, (False,), True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False, False), (False, False), (True, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "test.py", line 14, in <module>
    model("hello from here")
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/inference/engine.py", line 524, in forward
    outputs = self.module(*inputs, **kwargs)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py", line 387, in __call__
    noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/model_implementations/diffusers/unet.py", line 41, in forward
    return self._forward(*inputs, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/model_implementations/diffusers/unet.py", line 63, in _forward
    return self.unet(sample, timestamp, encoder_hidden_states, return_dict)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/models/unet_2d_condition.py", line 307, in forward
    sample, res_samples = downsample_block(
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/models/unet_2d_blocks.py", line 598, in forward
    hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/diffusers/models/attention.py", line 202, in forward
    hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_transformer_block.py", line 99, in forward
    out_attn_1 = self.attn_1(out_norm_1)
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_attention.py", line 225, in forward
    output = DeepSpeedDiffusersAttentionFunction.apply(
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_attention.py", line 115, in forward
    output = selfAttention_fp(input, context, input_mask)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/diffusers_attention.py", line 79, in selfAttention_fp
    context_layer = triton_flash_attn_kernel(qkv_out[0],
  File "/home/zeus/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/venv/lib/python3.8/site-packages/deepspeed/ops/transformer/inference/triton_ops.py", line 119, in forward
    _fwd_kernel[grid](
  File "/content/venv/lib/python3.8/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "<string>", line 43, in _fwd_kernel
RuntimeError: Triton Error [CUDA]: invalid argument
Environment:
NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.5

Installed packages (pip freeze):
aiobotocore==2.4.2
aiohttp==3.8.3
aioitertools==0.11.0
aiosignal==1.3.1
anyio==3.6.2
arrow==1.2.3
async-timeout==4.0.2
attrs==22.2.0
beautifulsoup4==4.11.1
black==22.12.0
bleach==5.0.1
blessed==1.19.1
botocore==1.27.59
certifi==2019.11.28
cffi==1.15.1
chardet==3.0.4
charset-normalizer==2.1.1
click==8.1.3
cmake==3.25.0
commonmark==0.9.1
croniter==1.3.8
cryptography==39.0.0
dbus-python==1.2.16
deepdiff==6.2.3
deepspeed==0.7.7
diffusers==0.7.1
diffusion-with-autoscaler @ file:///content
distlib==0.3.6
dnspython==2.2.1
docker==6.0.1
docutils==0.19
email-validator==1.3.0
exceptiongroup==1.1.0
fastapi==0.89.1
filelock==3.9.0
frozenlist==1.3.3
fsspec==2022.11.0
h11==0.14.0
hjson==3.1.0
httpcore==0.16.3
httptools==0.5.0
httpx==0.23.3
huggingface-hub==0.11.1
idna==2.8
importlib-metadata==6.0.0
importlib-resources==5.10.2
iniconfig==2.0.0
inquirer==3.1.2
isort==5.11.4
itsdangerous==2.1.2
jaraco.classes==3.2.3
jeepney==0.8.0
Jinja2==3.1.2
jmespath==1.0.1
keyring==23.13.1
lightning @ https://github.com/Lightning-AI/lightning/archive/refs/tags/1.8.6.zip
lightning-api-access @ git+https://github.com/Lightning-AI/LAI-API-Access-UI-Component.git@ec3016c1bd2165f9e720b686a83376def1705a60
lightning-cloud==0.5.16
lightning-launcher==0.0.43
lightning-utilities==0.5.0
MarkupSafe==2.1.1
more-itertools==9.0.0
multidict==6.0.4
mypy-extensions==0.4.3
ninja==1.11.1
numpy==1.24.1
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
ordered-set==4.1.0
orjson==3.8.5
packaging==23.0
pathspec==0.10.3
Pillow==9.4.0
pkginfo==1.9.6
platformdirs==2.6.2
pluggy==1.0.0
protobuf==3.20.1
psutil==5.9.4
py-cpuinfo==9.0.0
pycparser==2.21
pydantic==1.10.4
Pygments==2.14.0
PyGObject==3.36.0
PyJWT==2.6.0
pytest==7.2.0
python-apt==2.0.0+ubuntu0.20.4.8
python-dateutil==2.8.2
python-dotenv==0.21.0
python-editor==1.0.4
python-multipart==0.0.5
PyYAML==6.0
readchar==4.0.3
readme-renderer==37.3
redis==4.4.2
regex==2022.10.31
requests==2.28.1
requests-toolbelt==0.10.1
requests-unixsocket==0.2.0
rfc3986==1.5.0
rich==13.0.1
s3fs==2022.11.0
SecretStorage==3.3.3
six==1.14.0
sniffio==1.3.0
soupsieve==2.3.2.post1
starlette==0.22.0
starsessions==1.3.0
tabulate==0.9.0
tensorboardX==2.5.1
tokenizers==0.13.2
tomli==2.0.1
torch==1.13.1
torchmetrics==0.11.0
tqdm==4.64.1
traitlets==5.8.1
transformers==4.24.0
triton==2.0.0.dev20221202
twine==4.0.2
typing_extensions==4.4.0
ujson==5.7.0
urllib3==1.26.14
uvicorn==0.20.0
uvloop==0.17.0
virtualenv==20.17.1
watchfiles==0.18.1
wcwidth==0.2.5
webencodings==0.5.1
websocket-client==1.4.2
websockets==10.4
wrapt==1.14.1
yarl==1.8.2
zipp==3.11.0
Jokeren commented 1 year ago

It would be better if you could attach the kernel that caused the problem.

Jokeren commented 1 year ago

Is it triton_flash_attn_kernel?

tchaton commented 1 year ago

Hey @jokeren, yes, it is. I didn't attach it since it appears in the trace, apologies for that. Here it is: https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/ops/transformer/inference/triton_ops.py#L119. Do you have any idea why there is a KeyError and how to debug it?

tchaton commented 1 year ago

It would be fantastic if I could run this model on a T4.

Jokeren commented 1 year ago

Flash attention used to work on A100 only, for reasons I don't remember clearly 🤣

@ptillet Is it still true?

tchaton commented 1 year ago

Note: I tried the original Flash Attention and it seems to produce the same result as the Triton version, but it works on T4 and is slightly faster. I am not blocked anymore, but it would be great to have this resolved. https://github.com/Lightning-AI/stablediffusion/pull/8

ptillet commented 1 year ago

If I remember correctly, the forward pass used to work on pre-Ampere hardware, but the backward pass only worked on post-Ampere hardware. It may be the case that now neither works on Turing :D I'm still working on A100 performance optimizations, but I agree that the forward pass should work well on all hardware. I don't think there's any major roadblock against this. I'll look into it.

tchaton commented 1 year ago

Hey @ptillet, thanks for the update. Please ping me once you have a PR ready.

stephen-youn commented 1 year ago

Is there any progress on making flash attention work with the latest Triton on a T4 GPU? Is the forward pass working, at least?

ptillet commented 1 year ago

The new tutorial should work on pre-Ampere hardware for the forward pass: https://github.com/openai/triton/blob/main/python/tutorials/06-fused-attention.py. Let me know if it doesn't.
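
A forward-only smoke test could look like the sketch below. Assumptions, not part of the tutorial itself: a local copy of 06-fused-attention.py importable as fused_attention, and an attention helper taking q, k, v and a softmax scale; the exact signature has changed across Triton versions, so check the copy you are using.

import torch
from fused_attention import attention  # hypothetical import of a local copy of the tutorial

# Batch, heads, sequence length, and head dimension for a small fp16 test.
Z, H, N_CTX, D_HEAD = 1, 8, 1024, 64
q = torch.randn(Z, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16)
k = torch.randn(Z, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16)
v = torch.randn(Z, H, N_CTX, D_HEAD, device="cuda", dtype=torch.float16)
sm_scale = D_HEAD ** -0.5

with torch.no_grad():  # forward pass only; the backward pass is the part tied to Ampere+
    out = attention(q, k, v, sm_scale)
print(out.shape)  # expected: torch.Size([1, 8, 1024, 64])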

stephen-youn commented 1 year ago

It failed in my test running the forward pass on a T4. My environment: Ubuntu 22.04, with cuda_11.7.1_515.65.01_linux installed. The error messages are as follows:

[from the latest head in triton main repo at 34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5] python: /mnt/styoun/projects/triton/lib/Analysis/Allocation.cpp:40: std::pair<llvm::SmallVector, llvm::SmallVector > mlir::triton::getCvtOrder(mlir::Attribute, mlir::Attribute): Assertion `!(srcMmaLayout && dstMmaLayout) && "Unexpected mma -> mma layout conversion"' failed. Aborted (core dumped)

[triton-2.0.0.post1] error: 'tt.reduce' op inferred type(s) 'tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.mma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [4, 1]}>}>>' are incompatible with return type(s) of operation 'tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.mma<{versionMajor = 1, versionMinor = 2, warpsPerCTA = [2, 2]}>}>>' Traceback (most recent call last): File "", line 21, in _fwd_kernel KeyError: ('2-.-0-.-0-d82511111ad128294e9d31a6ac684238-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f839199e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, 'fp32', torch.float32, torch.float32, torch.float16, 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), (128, 64, 128), (True, True, True, (False,), True, True, True, (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (True, False), (True, False), (True, False), (False, True), (False, True), (False, False), (True, False)))

[Update] I tried the same with PyTorch 1.12.0+cu113, but it failed again with the same error messages.

stephen-youn commented 1 year ago

Is there a specific CUDA version and Triton commit that make flash attention work on a T4? If so, please let me know.

stephen-youn commented 1 year ago

I also tried on a V100 with torch 1.12 / CUDA 11.6 and the latest release, triton==2.0.0.post1, but it failed:

error: 'tt.reduce' op inferred type(s) 'tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.mma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [4, 1]}>}>>' are incompatible with return type(s) of operation 'tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #triton_gpu.mma<{versionMajor = 1, versionMinor = 2, warpsPerCTA = [2, 2]}>}>>' Traceback (most recent call last):

Jokeren commented 1 year ago

Yes, we acknowledge the issue concerning the MMA-to-MMA conversion error. While it hasn't been our top priority due to existing workarounds, we have indeed raised its importance a bit. As such, we anticipate a resolution in the near future.

stephen-youn commented 1 year ago

Thanks for the reply. What workaround can I use to make it run on a T4? Is there a specific Triton release or commit that works, at least for the flash attention forward pass?

Jokeren commented 1 year ago

I thought I mentioned it somewhere else, but I cannot remember where.

You can try to store the result of the dot to a piece of temporary global memory and then reload it into a tensor.
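
For illustration, a minimal sketch of that pattern; all names, shapes, and block sizes below are made up, and this is not DeepSpeed's actual kernel. The accumulator of the first tl.dot is stored to a caller-provided scratch buffer in global memory and loaded back before it feeds the second tl.dot, so the compiler never has to convert one MMA layout directly into another.

import torch
import triton
import triton.language as tl


@triton.jit
def _two_dots_via_scratch(A, B, C, Out, Scratch,
                          BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
                          BLOCK_N: tl.constexpr, BLOCK_P: tl.constexpr):
    # Single-program sketch over contiguous row-major fp16 tiles.
    offs_m = tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_K)
    offs_n = tl.arange(0, BLOCK_N)
    offs_p = tl.arange(0, BLOCK_P)

    a = tl.load(A + offs_m[:, None] * BLOCK_K + offs_k[None, :])
    b = tl.load(B + offs_k[:, None] * BLOCK_N + offs_n[None, :])
    acc = tl.dot(a, b)  # first dot; on tensor-core paths its result carries an MMA layout

    # Workaround: round-trip the intermediate through global memory so the
    # second dot consumes a freshly loaded tensor instead of requiring an
    # mma -> mma layout conversion.
    scratch_ptrs = Scratch + offs_m[:, None] * BLOCK_N + offs_n[None, :]
    tl.store(scratch_ptrs, acc.to(tl.float16))
    tl.debug_barrier()
    ab = tl.load(scratch_ptrs)

    c = tl.load(C + offs_n[:, None] * BLOCK_P + offs_p[None, :])
    out = tl.dot(ab, c)  # second dot
    tl.store(Out + offs_m[:, None] * BLOCK_P + offs_p[None, :], out.to(tl.float16))


# Hypothetical single-program launch; the caller allocates the scratch buffer.
M, K, N, P = 64, 64, 64, 64
a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(K, N, device="cuda", dtype=torch.float16)
c = torch.randn(N, P, device="cuda", dtype=torch.float16)
scratch = torch.empty(M, N, device="cuda", dtype=torch.float16)
out = torch.empty(M, P, device="cuda", dtype=torch.float16)
_two_dots_via_scratch[(1,)](a, b, c, out, scratch, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, BLOCK_P=P)

The extra round trip costs memory bandwidth, but it sidesteps the "Unexpected mma -> mma layout conversion" path that asserts on these kernels.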

mirh commented 1 month ago

MMA conversion was supposedly fixed in #2627 I think?