facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

`memory_efficient_attention` causes graph break in `torch.compile()` #765

Open grazder opened 1 year ago

grazder commented 1 year ago

🐛 Bug

I am trying to use `memory_efficient_attention` with `torch.compile()`, but it seems that `memory_efficient_attention` leads to graph breaks. `xformers.ops.unbind` also causes graph breaks.


To Reproduce

Model

import math
import torch
import torch.nn as nn

from xformers.ops import memory_efficient_attention, unbind, fmha

class MultiHeadAttention(nn.Module):
    """Multi-Head Attention layer of Transformer.
    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        dropout_rate (float): dropout rate
    """

    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct an MultiHeadedAttention object."""
        super(MultiHeadAttention, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.s_d_k = math.sqrt(self.d_k)
        self.h = n_head

        self.linear_qkv = nn.Linear(n_feat, n_feat * 3)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x, lengths):
        B, N, C = x.shape
        qkv = self.linear_qkv(x).reshape(1, B * N, 3, self.h, C // self.h)
        q, k, v = unbind(qkv, 2)

        list_lengths = lengths.tolist()
        mask = fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
            q_seqlen=list_lengths,
            kv_padding=N,
            kv_seqlen=list_lengths,
        )
        x = memory_efficient_attention(q, k, v, attn_bias=mask)

        x = x.reshape([B, N, C])

        x = self.linear_out(x)
        x = self.dropout(x)
        return x

Behavior without `torch.compile()`:

FEAT_SIZE = 128
MAX_LEN = 100
BATCH_SIZE = 8 

mha = MultiHeadAttention(8, FEAT_SIZE, 0.2).to('cuda')
batch = torch.randn(BATCH_SIZE, MAX_LEN, FEAT_SIZE).to('cuda')
lengths = torch.randint(10, MAX_LEN, (BATCH_SIZE, )).to('cuda')

output = mha(batch, lengths)
output.shape
> torch.Size([8, 100, 128])

Behavior with `torch.compile()`:

import logging
import torch._dynamo as dynamo

compile_mha = torch.compile(mha)
compile_output = compile_mha(batch, lengths)

print(compile_output.shape)

explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = dynamo.explain(mha, batch, lengths)
print(explanation_verbose)

Output:

[2023-06-09 11:07:43,964] torch._inductor.utils: [WARNING] using triton random, expect difference from eager
torch.Size([8, 100, 128])
Dynamo produced 4 graphs with 3 graph break and 6 ops
 Break reasons: 

1. autograd.Function with requires_grad
  File "<ipython-input-2-a910d6af9905>", line 57, in forward
    q, k, v = unbind(qkv, 2)
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/unbind.py", line 115, in unbind
    return _Unbind.apply(x, dim)

2. return_value
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/unbind.py", line 84, in <graph break in forward>
    return x.unbind(dim)

3. autograd.Function with requires_grad
  File "<ipython-input-2-a910d6af9905>", line 65, in <graph break in forward>
    x = memory_efficient_attention(q, k, v, attn_bias=mask)
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py", line 192, in memory_efficient_attention
    return _memory_efficient_attention(
  File "/usr/local/lib/python3.10/dist-packages/xformers/ops/fmha/__init__.py", line 295, in _memory_efficient_attention
    return _fMHA.apply(

4. return_value
  File "<ipython-input-2-a910d6af9905>", line 71, in <graph break in forward>
    return x

TorchDynamo compilation metrics:
Function                        Runtimes (s)
------------------------------  --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
_compile                        0.0190, 0.0031, 0.0033, 0.0054, 0.0093, 0.0616, 0.0182, 0.0492, 0.0112, 0.0099, 0.0159, 0.0013, 0.0142, 0.0105, 0.0232, 0.0031, 0.0083, 0.0086, 0.0057, 0.0038, 0.0029, 0.0019, 0.0051, 0.0027, 0.0176, 0.0128, 0.0126, 0.0132, 0.0157, 0.0129, 0.0132, 0.0118, 0.0128, 0.0018, 0.0170
OutputGraph.call_user_compiler  0.0000, 0.0000, 0.0000, 0.0000

Expected behavior

It would be nice if xformers didn't cause graph breaks.

Environment

Current Google Colab runtime.

python -m xformers.info

xFormers 0.0.20
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        available
memory_efficient_attention.tritonflashattB:        available
indexing.scaled_index_addF:                        available
indexing.scaled_index_addB:                        available
indexing.index_select:                             available
swiglu.dual_gemm_silu:                             available
swiglu.gemm_fused_operand_sum:                     available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
is_functorch_available:                            False
pytorch.version:                                   2.0.1+cu118
pytorch.cuda:                                      available
gpu.compute_capability:                            7.5
gpu.name:                                          Tesla T4
build.info:                                        available
build.cuda_version:                                1108
build.python_version:                              3.10.11
build.torch_version:                               2.0.1+cu118
build.env.TORCH_CUDA_ARCH_LIST:                    5.0+PTX 6.0 6.1 7.0 7.5 8.0 8.6
build.env.XFORMERS_BUILD_TYPE:                     Release
build.env.XFORMERS_ENABLE_DEBUG_ASSERTIONS:        None
build.env.NVCC_FLAGS:                              None
build.env.XFORMERS_PACKAGE_FROM:                   wheel-v0.0.20
build.nvcc_version:                                11.8.89
source.privacy:                                    open source

This problem also reproduces in my local setup.

bottler commented 1 year ago

Surely the graph break around `memory_efficient_attention` is a good thing? The xformers implementation of `memory_efficient_attention` is better than what `torch.compile` can currently come up with on its own, so the user should want `torch.compile` to optimize only the parts of the logic that are outside `memory_efficient_attention`.
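
For what it's worth, a minimal sketch of making that split explicit in user code (assuming `torch._dynamo.disable` behaves as in PyTorch 2.0; the wrapper name here is just illustrative):

import torch
import torch._dynamo as dynamo
from xformers.ops import memory_efficient_attention

# Illustrative wrapper: dynamo.disable keeps this call out of the traced graph,
# so torch.compile optimizes the surrounding ops and leaves the xformers kernel as-is.
@dynamo.disable
def xformers_attention(q, k, v, attn_bias=None):
    return memory_efficient_attention(q, k, v, attn_bias=attn_bias)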

vadimkantorov commented 1 year ago

Is there also a graph break around standard PyTorch's F.scaled_dot_product_attention?
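
A quick way to check, following the same `dynamo.explain()` pattern as in the report above (untested sketch; the shapes are arbitrary):

import torch
import torch.nn.functional as F
import torch._dynamo as dynamo

# Same explain() call as in the report, applied to plain scaled_dot_product_attention.
def sdpa(q, k, v):
    return F.scaled_dot_product_attention(q, k, v)

q = k = v = torch.randn(8, 8, 100, 16, device='cuda')
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = dynamo.explain(sdpa, q, k, v)
print(explanation_verbose)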

If `torch.compile` is able to preserve the calls to xFormers' memory-efficient and flash-attention implementations, it might still be less confusing for end users if these (and xFormers' optimized `unbind`) did not trigger graph breaks, since for now the rule of thumb seems to be that fewer graph breaks mean better performance.

Or at least document that these graph breaks are okay, since it currently looks like the breaks happen only because these ops use custom autograd functions (and it's not clear whether all custom autograd functions cause graph breaks, or whether there are special requirements/conditions: https://github.com/pytorch/pytorch/issues/103318).
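
For example, a minimal way to check the custom-autograd-function case (a sketch assuming PyTorch 2.0.x behaviour; `MyIdentity` is just a stand-in for wrappers like xformers' `_Unbind` / `_fMHA`):

import torch
import torch._dynamo as dynamo

# Trivial stand-in for xformers' autograd.Function wrappers (_Unbind, _fMHA).
class MyIdentity(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output

def fn(x):
    return MyIdentity.apply(x) * 2

x = torch.randn(4, requires_grad=True)
explanation, out_guards, graphs, ops_per_graph, break_reasons, explanation_verbose = dynamo.explain(fn, x)
print(explanation_verbose)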