junjie18 opened this issue 1 year ago
And when I comment out https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L322-L327 and run the following code with head dim = 128 (for example, q of size (2, 128, 2, 128)), the process hangs. With head dim = 64 (size (2, 128, 2, 64)) it works fine.
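import torch
from flash_attn import flash_attn_triton
# q, k, v: (batch, seqlen, nheads, headdim); headdim = 128 hangs, 64 works
q, k, v = (torch.randn(2, 128, 2, 128, device="cuda", dtype=torch.float16) for _ in range(3))
# positional args: q, k, v, bias=None, causal (tried both True and False), softmax_scale=None
out = flash_attn_triton.flash_attn_func(q, k, v, None, True, None)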
It's the Triton version. As mentioned at the beginning of the file:
Tested with triton==2.0.0.dev20221202.
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions other than 64:
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
We'll update this implementation with the new Triton backend once this is fixed.
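One way to act on this note in calling code is a small guard that decides whether to take the Triton path at all. This is a minimal sketch of a rule of thumb derived from the note, not a check that flash-attn performs; the function name is hypothetical, and treating every version other than the tested one as the new MLIR backend is a simplification.

```python
# Tested version quoted in the header of flash_attn_triton.py.
TESTED_TRITON_VERSION = "2.0.0.dev20221202"


def triton_path_expected_to_work(triton_version: str, head_dim: int) -> bool:
    """Rule of thumb from the maintainer's note; flash-attn does not perform this check.

    On the tested build, defer to the kernel's own head_dim checks.
    On any other build, assume the new MLIR backend, where only
    head_dim == 64 is expected to work.
    """
    if triton_version == TESTED_TRITON_VERSION:
        return True
    return head_dim == 64
```

A model could, for example, fall back to its regular attention implementation whenever this returns False.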
Hey, @tridao. How do you install this version, though? It seems that PyTorch 2.0.0 and 2.0.1 both require triton 2.0.0. If I try to run pip install triton==2.0.0.dev20221202, I see that it starts downloading an older PyTorch version without CUDA (Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)). If I install it with --no-deps, I get this error:
fatal error: cuda.h: No such file or directory
2 | #include "cuda.h"
| ^~~~~~~~
compilation terminated.
So, I wonder how I should install this specific triton version.
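The version conflict described here (torch 2.0.x pinning triton 2.0.0, while the Triton file was tested with 2.0.0.dev20221202) can be inspected from the installed package metadata. This is a minimal sketch using only the standard library's importlib.metadata; the helper names are hypothetical, and it only reports versions, it does not resolve the conflict.

```python
import re
from importlib import metadata
from typing import Iterable, Optional


def installed_version(dist: str) -> Optional[str]:
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def _normalize(name: str) -> str:
    # PEP 503-style normalization so "Triton" and "triton" compare equal
    return re.sub(r"[-_.]+", "-", name).lower()


def find_requirement(requirements: Optional[Iterable[str]], name: str) -> Optional[str]:
    """Return the requirement spec for `name` (environment markers stripped), if any."""
    for req in requirements or []:
        spec = req.split(";", 1)[0].strip()
        # The distribution name ends where a version specifier or extras begin.
        dist = re.split(r"[<>=!~\[\s(]", spec, maxsplit=1)[0]
        if _normalize(dist) == _normalize(name):
            return spec
    return None


if __name__ == "__main__":
    torch_reqs = metadata.requires("torch") if installed_version("torch") else None
    print("torch :", installed_version("torch"))
    print("triton:", installed_version("triton"))
    print("torch declares:", find_requirement(torch_reqs, "triton"))
```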
Yeah idk, it works fine with the PyTorch container from NVIDIA (e.g. 23.05 has PyTorch 2.0) by pip install triton==2.0.0.dev20221202.
I tried to install it using conda instead of pipenv with the following environment.yml:
name: page-tagging-svc
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - python=3.10
  - pytorch=2.0
  - pytorch-cuda=11.7
  - transformers=4.29
  - datasets=2.12
  - pytorch-lightning=2.0
  - torchmetrics=0.11
  - pydantic=1.10
  - python-json-logger=2.0
  - scikit-learn=1.2
  - xgboost=1.7
  - pandas=1.5
  - beautifulsoup4=4.12
  - sentencepiece=0.1
  - fastcore=1.5
  # Dev
  - black=23.3
  - isort=5.12
  - mypy=1.3
  - pylint=2.17
  - pytest=7.3
  - pytest-cov=4.1
  - pytest-mock=3.10
  - click=8.1
  - google-cloud-storage=2.9
  - ultimate-sitemap-parser=0.5
  - pycountry=22.3
  - pillow=9.5
  - jupyterlab=4.0
  - matplotlib=3.7
  - scaleapi=2.14
  - google-cloud-aiplatform=1.26
  - pip
  - pip:
      - jsonargparse[signatures]==4.20.1
      - bentoml==1.0.19
      - starlette==0.24.*
      - iterative-stratification==0.1.7
      - triton==2.0.0.dev20221202
I checked that the correct version of triton was installed. However, I still get:
/tmp/tmpz4gpqr27/main.c:2:10: fatal error: cuda.h: No such file or directory
2 | #include "cuda.h"
| ^~~~~~~~
compilation terminated.
Traceback (most recent call last):
File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, torch.float32, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False)))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
results = self._run_stage()
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
self._run_sanity_check()
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
val_loop.run()
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 174, in _decorator
return loop_run(self, *args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
self._evaluation_step(batch, batch_idx, dataloader_idx)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 375, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in validation_step
return self.model.validation_step(*args, **kwargs)
File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/lightning.py", line 120, in validation_step
logits, loss = self._step(batch)
File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/lightning.py", line 73, in _step
logits = self._model(batch)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/model/markup_lm.py", line 161, in forward
markup_lm_outputs = self.markup_lm(
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 915, in forward
encoder_outputs = self.encoder(
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 665, in forward
layer_outputs = layer_module(
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 549, in forward
self_attention_outputs = self.attention(
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 507, in forward
self_outputs = self.self(
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/model/markup_lm.py", line 310, in forward
context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_mask)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/flash_attn_triton.py", line 810, in forward
o, lse, ctx.softmax_scale = _flash_attn_forward(
File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/flash_attn_triton.py", line 623, in _flash_attn_forward
_fwd_kernel[grid](
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
return self.run(*args, grid=grid, **kwargs)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
return self.fn.run(*args, **kwargs)
File "<string>", line 41, in _fwd_kernel
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/compiler.py", line 1239, in compile
so = _build(fn.__name__, src_path, tmpdir)
File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/compiler.py", line 1169, in _build
ret = subprocess.check_call(cc_cmd)
File "/home/fernando/anaconda3/envs/python_310_cuda_113/lib/python3.10/subprocess.py", line 369, in check_call
raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpz4gpqr27/main.c', '-O3', '-I/usr/local/cuda/include', '-I/home/fernando/anaconda3/envs/python_310_cuda_113/include/python3.10', '-I/tmp/tmpz4gpqr27', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmpz4gpqr27/_fwd_kernel.cpython-310-x86_64-linux-gnu.so', '-L/usr/lib', '-L/usr/lib32']' returned non-zero exit status 1.
wandb: Waiting for W&B process to finish... (failed 1).
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /tmp/wandb/offline-run-20230608_230948-5b856dbf-a5e4-47b3-ad20-6570054489f7
wandb: Find logs at: /tmp/wandb/offline-run-20230608_230948-5b856dbf-a5e4-47b3-ad20-6570054489f7/logs
Process finished with exit code 1
Do you know what it could be, @tridao? Thank you very much in advance.
Idk what's wrong. I use the PyTorch Docker container and haven't seen this error.
To help other people who might have this problem: I needed to add two more conda dependencies, cuda-nvcc and cudatoolkit-dev, and I also needed to set the environment variable C_INCLUDE_PATH=$CONDA_HOME/envs/<env_name>/include.
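Because Triton builds its launcher by calling gcc through subprocess.check_call (see the traceback above) and the child process inherits os.environ, the same C_INCLUDE_PATH fix can also be applied from Python before the first kernel compiles; gcc reads C_INCLUDE_PATH as an extra include search path. This is a minimal sketch, assuming the headers end up in <env prefix>/include as in the fix described above and that CONDA_PREFIX points at the active env; the helper names are hypothetical.

```python
import os
from typing import Dict, Mapping, Optional


def cuda_header_dir(conda_env_prefix: str) -> Optional[str]:
    """Return <prefix>/include if it contains cuda.h, else None."""
    include_dir = os.path.join(conda_env_prefix, "include")
    return include_dir if os.path.isfile(os.path.join(include_dir, "cuda.h")) else None


def env_with_include_dir(
    include_dir: str, base_env: Optional[Mapping[str, str]] = None
) -> Dict[str, str]:
    """Copy of base_env (default: os.environ) with include_dir first in C_INCLUDE_PATH."""
    env = dict(os.environ if base_env is None else base_env)
    existing = env.get("C_INCLUDE_PATH")
    env["C_INCLUDE_PATH"] = include_dir + (os.pathsep + existing if existing else "")
    return env


# Typical use inside the activated conda env, before the first Triton compile:
#   inc = cuda_header_dir(os.environ["CONDA_PREFIX"])
#   if inc is None:
#       raise SystemExit("cuda.h not found: check the conda packages listed above")
#   os.environ.update(env_with_include_dir(inc))
```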
I get error: 'scf.for' op expects region #0 to have 0 or 1 blocks for q with torch.Size([1120, 280, 4, 32]) (printed as: (1120, 280, 4, 32) True torch.bfloat16), using the Triton nightly triton_nightly-2.1.0.dev20230822000928-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.
I found that this "error: 'scf.for' op expects region #0 to have 0 or 1 blocks" is triggered by the _bwd_kernel autotuning, under the configuration "SEQUENCE_PARALLEL": False.
I strongly suspect the error happens in the loop for start_n in range(0, num_block_n): _bwd_kernel_one_col_block(...), i.e. inside _bwd_kernel_one_col_block.
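If the failure really comes only from the SEQUENCE_PARALLEL=False configuration, one workaround to try is autotuning _bwd_kernel over the SEQUENCE_PARALLEL=True configs only. This is a minimal sketch that filters configs by their meta-parameters; the function name is hypothetical, the BLOCK_M/BLOCK_N values in the test are illustrative, and whether this avoids the error on a given Triton version is untested here. In flash_attn_triton.py the configs are triton.Config objects, whose meta-parameters are held in the .kwargs dict, hence the `meta` accessor.

```python
from typing import Any, Callable, List, Mapping, Sequence, TypeVar

C = TypeVar("C")


def keep_sequence_parallel_only(
    configs: Sequence[C],
    meta: Callable[[C], Mapping[str, Any]] = lambda c: c,  # plain dicts by default
) -> List[C]:
    """Drop autotune configs whose SEQUENCE_PARALLEL meta-parameter is False.

    Configs that do not set SEQUENCE_PARALLEL are kept unchanged.
    """
    kept = [c for c in configs if meta(c).get("SEQUENCE_PARALLEL", True)]
    if not kept:
        raise ValueError("no autotune config left after filtering")
    return kept


# With Triton configs (their meta-parameters live in Config.kwargs):
#   configs = keep_sequence_parallel_only(configs, meta=lambda c: c.kwargs)
# and then pass `configs` to @triton.autotune(configs=..., key=...).
```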
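On the question of installing triton==2.0.0.dev20221202 without pip pulling in torch 1.13.1: try installing triton with pip install --no-deps so that pip doesn't update torch.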
Hi all,
Thanks for your excellent work. I ran into the following problem with triton==2.0.0: the forward pass succeeds, but the backward pass fails with:
error: 'scf.for' op expects region #0 to have 0 or 1 blocks
When I comment out these lines, the error disappears: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L322-L327
How can I solve this? Thanks.
The full error message is: