Dao-AILab / flash-attention

Fast and memory-efficient exact attention
BSD 3-Clause "New" or "Revised" License

Flash-attention under Triton 2.0 #234

Open junjie18 opened 1 year ago

junjie18 commented 1 year ago

Hi all,

Thanks for your excellent work. I ran into the following problem when using triton == 2.0.0: the forward pass succeeds, but the backward pass fails. How can I solve it? Thanks.

error: 'scf.for' op expects region #0 to have 0 or 1 blocks

And when I comment out these lines, the error disappears: https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L322-L327

The full error message is:

UserWarning: While `attn_impl: triton` can be faster than `attn_impl: flash` it uses more memory. When training larger models this can trigger alloc retries which hurts performance. If encountered, we recommend using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.
  warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')

error: 'scf.for' op expects region #0 to have 0 or 1 blocks
Traceback (most recent call last):
  File "<string>", line 21, in _bwd_kernel
KeyError: ('2-.-0-.-0-7d1eb0d2fed8ff2032dccb99c2cc311a-d6252949da17ceb5f3a278a70250af13-3b85c7bef5f0a641282f3b73af50f599-3d2aedeb40d6d81c66a42791e268f98b-3498c340fd4b6ee7805fd54b882a04f5-e1f133f98d04093da2078dfc51c36b72-b26258bf01f83919
9e39d64851821f26-d7c06e3b46e708006c15224aac7a1378-f585402118c8a136948ce0a49cfe122c', (torch.float16, torch.float16, torch.float16, None, torch.float16, torch.float32, torch.float16, torch.float16, torch.float32, torch.float32, 'fp32', '
i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('none', True, 64,
False, False, False, True, 128, 128), (True, True, True, (False,), True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (Tr
ue, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False),
 (False, False), (False, False), (False, False), (True, False), (True, False), (False, False), (False, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "xxxx/attention.py", line 308, in <module>
    l.backward()
  File "~/anaconda3/lib/python3.10/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "~/anaconda3/lib/python3.10/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "~/anaconda3/lib/python3.10/site-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "~/anaconda3/lib/python3.10/site-packages/flash_attn-0.2.8-py3.10-linux-x86_64.egg/flash_attn/flash_attn_triton.py", line 830, in backward
    _flash_attn_backward(do, q, k, v, o, lse, dq, dk, dv,
  File "~/anaconda3/lib/python3.10/site-packages/flash_attn-0.2.8-py3.10-linux-x86_64.egg/flash_attn/flash_attn_triton.py", line 698, in _flash_attn_backward
    _bwd_kernel[grid](
  File "~/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "~/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 77, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "~/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 65, in _bench
    return do_bench(kernel_call)
  File "~/anaconda3/lib/python3.10/site-packages/triton/testing.py", line 143, in do_bench
    fn()
  File "~/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 63, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "~/anaconda3/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 199, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _bwd_kernel
  File "~/anaconda3/lib/python3.10/site-packages/triton/compiler.py", line 1621, in compile
    next_module = compile(module)
  File "~/anaconda3/lib/python3.10/site-packages/triton/compiler.py", line 1550, in <lambda>
    lambda src: ast_to_ttir(src, signature, configs[0], constants)),
  File "~/anaconda3/lib/python3.10/site-packages/triton/compiler.py", line 963, in ast_to_ttir
    return optimize_triton_ir(mod)
  File "~/anaconda3/lib/python3.10/site-packages/triton/compiler.py", line 957, in optimize_triton_ir
    pm.run(mod)
RuntimeError: PassManager::run failed
junjie18 commented 1 year ago

There is a second problem once I comment out https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L322-L327.

If I run the following code with head dim = 128 (for example, q of size (2, 128, 2, 128)), the process hangs. With head dim = 64 (q of size (2, 128, 2, 64)) it works fine.

flash_attn_triton.flash_attn_func(
      q, k, v, None, causal, None  # bias=None, causal=True or False (tried both), softmax_scale=None
)
tridao commented 1 year ago

It's the Triton version. As mentioned at the beginning of the file:

Tested with triton==2.0.0.dev20221202.
Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions
other than 64:
https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207
We'll update this implementation with the new Triton backend once this is fixed.
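A quick way to see whether an environment matches the build named in that header is to compare the installed version string against it. The snippet below is only a minimal sketch using the standard library; the function names are made up for illustration and are not part of flash-attention or Triton.

```python
import warnings
from importlib import metadata

# Version named in the header of flash_attn_triton.py, as quoted above.
TESTED_TRITON = "2.0.0.dev20221202"


def installed_triton_version():
    """Return the installed triton version string, or None if triton is absent."""
    try:
        return metadata.version("triton")
    except metadata.PackageNotFoundError:
        return None


def matches_tested_triton(version):
    """True only when the version string is exactly the tested dev build."""
    return version == TESTED_TRITON


def warn_if_untested():
    """Hypothetical helper: warn before using the Triton kernels on another build."""
    version = installed_triton_version()
    if not matches_tested_triton(version):
        warnings.warn(
            f"triton {version!r} is not the tested build {TESTED_TRITON}; "
            "the Triton attention kernels may fail to compile."
        )
```

Calling `warn_if_untested()` once at import time would flag a mismatch early instead of at the first backward pass.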
fernandocamargoai commented 1 year ago

Hey, @tridao. How do you install this version, though? It seems that PyTorch 2.0.0 and 2.0.1 both require triton 2.0.0. If I try to run pip install triton==2.0.0.dev20221202, I see that it starts downloading an older PyTorch version without CUDA (Downloading torch-1.13.1-cp310-cp310-manylinux1_x86_64.whl (887.5 MB)). If I install it with --no-deps, I get this error:

fatal error: cuda.h: No such file or directory
    2 | #include "cuda.h"
      |          ^~~~~~~~
compilation terminated.

So, I wonder how I should install this specific triton version.

tridao commented 1 year ago

Yeah, I'm not sure. It works fine in the PyTorch container from Nvidia (e.g. 23.05, which has PyTorch 2.0) after running pip install triton==2.0.0.dev20221202.

fernandocamargoai commented 1 year ago

I tried to install it using conda instead of pipenv with the following environment.yml:

name: page-tagging-svc
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - python=3.10
  - pytorch=2.0
  - pytorch-cuda=11.7
  - transformers=4.29
  - datasets=2.12
  - pytorch-lightning=2.0
  - torchmetrics=0.11
  - pydantic=1.10
  - python-json-logger=2.0
  - scikit-learn=1.2
  - xgboost=1.7
  - pandas=1.5
  - beautifulsoup4=4.12
  - sentencepiece=0.1
  - fastcore=1.5
  # Dev
  - black=23.3
  - isort=5.12
  - mypy=1.3
  - pylint=2.17
  - pytest=7.3
  - pytest-cov=4.1
  - pytest-mock=3.10
  - click=8.1
  - google-cloud-storage=2.9
  - ultimate-sitemap-parser=0.5
  - pycountry=22.3
  - pillow=9.5
  - jupyterlab=4.0
  - matplotlib=3.7
  - scaleapi=2.14
  - google-cloud-aiplatform=1.26
  - pip
  - pip:
    - jsonargparse[signatures]==4.20.1
    - bentoml==1.0.19
    - starlette==0.24.*
    - iterative-stratification==0.1.7
    - triton==2.0.0.dev20221202

I checked that the correct version of triton was installed. However, I still get:

/tmp/tmpz4gpqr27/main.c:2:10: fatal error: cuda.h: No such file or directory
    2 | #include "cuda.h"
      |          ^~~~~~~~
compilation terminated.
Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, torch.float32, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 174, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 375, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/lightning.py", line 120, in validation_step
    logits, loss = self._step(batch)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/lightning.py", line 73, in _step
    logits = self._model(batch)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/model/markup_lm.py", line 161, in forward
    markup_lm_outputs = self.markup_lm(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 915, in forward
    encoder_outputs = self.encoder(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 665, in forward
    layer_outputs = layer_module(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 549, in forward
    self_attention_outputs = self.attention(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 507, in forward
    self_outputs = self.self(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/model/markup_lm.py", line 310, in forward
    context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_mask)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/flash_attn_triton.py", line 810, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/flash_attn_triton.py", line 623, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/compiler.py", line 1239, in compile
    so = _build(fn.__name__, src_path, tmpdir)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/compiler.py", line 1169, in _build
    ret = subprocess.check_call(cc_cmd)
  File "/home/fernando/anaconda3/envs/python_310_cuda_113/lib/python3.10/subprocess.py", line 369, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpz4gpqr27/main.c', '-O3', '-I/usr/local/cuda/include', '-I/home/fernando/anaconda3/envs/python_310_cuda_113/include/python3.10', '-I/tmp/tmpz4gpqr27', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmpz4gpqr27/_fwd_kernel.cpython-310-x86_64-linux-gnu.so', '-L/usr/lib', '-L/usr/lib32']' returned non-zero exit status 1.
python-BaseException
Traceback (most recent call last):
  File "<string>", line 21, in _fwd_kernel
KeyError: ('2-.-0-.-0--2b0c5161c53c71b37ae20a9996ee4bb8-c1f92808b4e4644c1732e8338187ac87-d962222789c30252d492a16cca3bf467-12f7ac1ca211e037f62a7c0c323d9990-5c5e32ff210f3b7f56c98ca29917c25e-06f0df2d61979d629033f4a22eff5198-0dd03b0bd512a184b3512b278d9dfa59-d35ab04ae841e2714a253c523530b071', (torch.float16, torch.float16, torch.float16, torch.float32, torch.float16, torch.float32, torch.float32, 'fp32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32', 'i32'), ('vector', False, 64, False, False, True, 128, 128), (True, True, True, True, True, True, True, (False,), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (True, False), (False, False), (True, False), (True, False), (True, False), (True, False), (False, False), (False, False)))

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/fernando/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/223.8836.43/plugins/python/helpers/pydev/pydevd.py", line 2195, in <module>
    main()
  File "/home/fernando/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/223.8836.43/plugins/python/helpers/pydev/pydevd.py", line 2177, in main
    globals = debugger.run(setup['file'], None, None, is_module)
  File "/home/fernando/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/223.8836.43/plugins/python/helpers/pydev/pydevd.py", line 1489, in run
    return self._exec(is_module, entry_point_fn, module_name, file, globals, locals)
  File "/home/fernando/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/223.8836.43/plugins/python/helpers/pydev/pydevd.py", line 1496, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/home/fernando/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/223.8836.43/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/trainer.py", line 106, in <module>
    cli_main()
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/trainer.py", line 97, in cli_main
    cli = PageTaggingLightningCLI(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 353, in __init__
    self._run_subcommand(self.subcommand)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/cli.py", line 642, in _run_subcommand
    fn(**fn_kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 520, in fit
    call._call_and_handle_interrupt(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 935, in _run
    results = self._run_stage()
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 976, in _run_stage
    self._run_sanity_check()
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1005, in _run_sanity_check
    val_loop.run()
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/utilities.py", line 174, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 115, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 375, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_kwargs.values())
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 288, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/pytorch_lightning/strategies/strategy.py", line 378, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/lightning.py", line 120, in validation_step
    logits, loss = self._step(batch)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/lightning.py", line 73, in _step
    logits = self._model(batch)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/model/markup_lm.py", line 161, in forward
    markup_lm_outputs = self.markup_lm(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 915, in forward
    encoder_outputs = self.encoder(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 665, in forward
    layer_outputs = layer_module(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 549, in forward
    self_attention_outputs = self.attention(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/transformers/models/markuplm/modeling_markuplm.py", line 507, in forward
    self_outputs = self.self(
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/model/markup_lm.py", line 310, in forward
    context_layer = flash_attn_func(query_layer, key_layer, value_layer, attention_mask)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/torch/autograd/function.py", line 506, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/flash_attn_triton.py", line 810, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/fernando/klue_workspace/page-tagging-svc/src/torch/flash_attn_triton.py", line 623, in _flash_attn_forward
    _fwd_kernel[grid](
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/runtime/jit.py", line 106, in launcher
    return self.run(*args, grid=grid, **kwargs)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/runtime/autotuner.py", line 200, in run
    return self.fn.run(*args, **kwargs)
  File "<string>", line 41, in _fwd_kernel
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/compiler.py", line 1239, in compile
    so = _build(fn.__name__, src_path, tmpdir)
  File "/home/fernando/.local/share/virtualenvs/page-tagging-svc-6EE2Yu73/lib/python3.10/site-packages/triton/compiler.py", line 1169, in _build
    ret = subprocess.check_call(cc_cmd)
  File "/home/fernando/anaconda3/envs/python_310_cuda_113/lib/python3.10/subprocess.py", line 369, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['/usr/bin/gcc', '/tmp/tmpz4gpqr27/main.c', '-O3', '-I/usr/local/cuda/include', '-I/home/fernando/anaconda3/envs/python_310_cuda_113/include/python3.10', '-I/tmp/tmpz4gpqr27', '-shared', '-fPIC', '-lcuda', '-o', '/tmp/tmpz4gpqr27/_fwd_kernel.cpython-310-x86_64-linux-gnu.so', '-L/usr/lib', '-L/usr/lib32']' returned non-zero exit status 1.
python-BaseException
wandb: Waiting for W&B process to finish... (failed 1).
wandb: You can sync this run to the cloud by running:
wandb: wandb sync /tmp/wandb/offline-run-20230608_230948-5b856dbf-a5e4-47b3-ad20-6570054489f7
wandb: Find logs at: /tmp/wandb/offline-run-20230608_230948-5b856dbf-a5e4-47b3-ad20-6570054489f7/logs

Process finished with exit code 1

Do you know what it could be, @tridao? Thank you very much in advance.

tridao commented 1 year ago

Idk what's wrong, I use the pytorch docker container and haven't seen this error.

fernandocamargoai commented 1 year ago

To help other people that might have this problem:

I needed to add two other dependencies in conda: cuda-nvcc and cudatoolkit-dev. I also needed to set the environment variable C_INCLUDE_PATH=$CONDA_HOME/envs/<env_name>/include so that gcc can find cuda.h.
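To confirm the header is actually visible before launching a training run, something like the following can help. It is a rough sketch, not part of Triton: `find_cuda_header` is a hypothetical helper, and the guess that cuda.h lands in `$CONDA_PREFIX/include` is an assumption based on the fix above. `/usr/local/cuda/include` is the path gcc was given in the failing command.

```python
import os
from pathlib import Path


def find_cuda_header(extra_dirs=()):
    """Return the first directory that contains cuda.h, or None.

    Searches the directories in C_INCLUDE_PATH (which gcc honours),
    then any extra_dirs supplied by the caller.
    """
    env_value = os.environ.get("C_INCLUDE_PATH", "")
    candidates = [d for d in env_value.split(os.pathsep) if d]
    candidates.extend(str(d) for d in extra_dirs)
    for directory in candidates:
        if (Path(directory) / "cuda.h").is_file():
            return directory
    return None


if __name__ == "__main__":
    # CONDA_PREFIX is set by `conda activate`; using its include/ directory
    # is an assumption about where the conda CUDA packages put cuda.h.
    prefix = os.environ.get("CONDA_PREFIX")
    extra = [Path(prefix) / "include"] if prefix else []
    extra.append("/usr/local/cuda/include")
    print(find_cuda_header(extra) or "cuda.h not found")
```

If this prints "cuda.h not found", the gcc step that Triton runs will most likely fail with the same fatal error.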

jimmieliu commented 1 year ago

I get the same error with the Triton nightly build triton_nightly-2.1.0.dev20230822000928-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl:

error: 'scf.for' op expects region #0 to have 0 or 1 blocks

This is for q torch.Size([1120, 280, 4, 32]) (1120, 280, 4, 32) True torch.bfloat16

jimmieliu commented 1 year ago

I found out that "error: 'scf.for' op expects region #0 to have 0 or 1 blocks" is triggered by the _bwd_kernel autotune, under the configuration "SEQUENCE_PARALLEL": False.

I strongly suspect the error occurs in this loop, inside the function it calls:

for start_n in range(0, num_block_n):
    _bwd_kernel_one_col_block(   <--- in this function
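One way to test this hypothesis would be to drop the "SEQUENCE_PARALLEL": False entries from the autotune list so that the loop path is never compiled. The sketch below only shows the filtering step, with plain dicts standing in for the autotune configs; the config values are illustrative assumptions, and nothing in this thread confirms that the remaining sequence-parallel path compiles on a given Triton build.

```python
# Plain dicts standing in for the meta-parameters of the _bwd_kernel
# autotune configs. Illustrative values, not copied from flash_attn_triton.py.
CONFIGS = [
    {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False},
    {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True},
]


def drop_serial_configs(configs):
    """Keep only configs with SEQUENCE_PARALLEL set to True.

    Configs with SEQUENCE_PARALLEL False (or missing) take the
    `for start_n in range(0, num_block_n)` path and are filtered out.
    """
    return [c for c in configs if c.get("SEQUENCE_PARALLEL", False)]


remaining = drop_serial_configs(CONFIGS)
```

If the error goes away with only the filtered list, that would support the suspicion that the loop path is what the compiler rejects.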

idontkonwher commented 6 months ago

Try pip install --no-deps triton==2.0.0.dev20221202 to install Triton without updating torch.