facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

[RTX 3090] Raise NotImplementedError: No operator found for this attention: Inputs when I backward the loss #628

Open leeruibin opened 1 year ago

leeruibin commented 1 year ago

🐛 Bug


To Reproduce

Steps to reproduce the behavior:

  1. I initialize a UNetModel from Stable Diffusion and feed it a simulated input to get the model's output.
  2. Then I initialize a fake_label tensor with the same shape as the output.
  3. Finally, I compute the loss with an MSE loss function and call backward. The forward pass through the UNet works and returns an output, but calling backward raises:

    Traceback (most recent call last):
      File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/torch/autograd/function.py", line 399, in wrapper
        outputs = fn(ctx, *args)
      File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 111, in backward
        grads = _memory_efficient_attention_backward(
      File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 376, in _memory_efficient_attention_backward
        op = _dispatch_bw(inp)
      File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/xformers/ops/fmha/dispatch.py", line 68, in _dispatch_bw
        raise NotImplementedError(f"No operator found for this attention: {inp}")
    NotImplementedError: No operator found for this attention: Inputs(query=tensor([[[[ 0.1457, 0.8941, -0.0281, ..., -0.0386, -0.2712, 0.9171]],

             [[-0.2015,  0.8000,  0.3302,  ...,  0.3778,  0.0166,  0.7670]],

             [[ 0.1928,  1.0940,  0.1479,  ...,  0.3554,  0.1671,  1.2954]],
            .....


Here is my code. I downloaded the Stable Diffusion project and use the UNetModel from ldm.modules.diffusionmodules.openaimodel:

import ldm.modules.diffusionmodules.openaimodel as DFUnet
import torch
model = DFUnet.UNetModel(use_checkpoint=True,
                         num_classes=1000,  # timesteps for noise conditioning (here constant, just need one)
                         image_size=128,
                         in_channels=7,
                         out_channels=4,
                         model_channels=256,
                         attention_resolutions=[2, 4, 8],
                         num_res_blocks=2,
                         channel_mult=[1, 2, 2, 4],
                         disable_self_attentions=[True, True, True, False],
                         disable_middle_self_attn=False,
                         num_heads=8,
                         use_scale_shift_norm=True,
                         # use_fp16=True,
                         use_spatial_transformer=True,
                         transformer_depth=1,
                         context_dim=1024,
                         legacy=False,
                         use_linear_in_transformer=True
                         )
model.cuda()

x_in = torch.randn([8,7,128,128]).cuda()
context = torch.randn([8,77,1024]).cuda()
timesteps = torch.randint(0,1000,[8]).long().cuda()
y = torch.ones([8])*20
y = y.long().cuda()
out = model(x_in, timesteps=timesteps, context=context, y=y)
fake_label = torch.rand_like(out)
loss_fn = torch.nn.MSELoss()
loss = loss_fn(out,fake_label)
loss.backward()
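For reference, the per-head size K of each attention block in this repro can be worked out from the constructor arguments. The sketch below is hypothetical (attention_head_dims is not part of ldm) and assumes ldm's rule dim_head = channels // num_heads, which applies when num_head_channels is left at its default of -1. Under that assumption, the deepest level and the middle block use 128-dimensional heads.

```python
def attention_head_dims(model_channels, channel_mult, attention_resolutions, num_heads):
    """Hypothetical helper: per-head size K of the encoder-side and middle-block
    attention layers, assuming dim_head = channels // num_heads."""
    dims = []
    ds = 1  # downsampling factor of the current level
    ch = model_channels
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        if ds in attention_resolutions:
            dims.append((f"input level {level}", ch // num_heads))
        if level != len(channel_mult) - 1:
            ds *= 2  # every level except the last ends with a downsample
    # The middle block runs at the deepest channel count
    dims.append(("middle block", ch // num_heads))
    return dims


# Same arguments as the UNetModel in the repro above
for name, k in attention_head_dims(256, [1, 2, 2, 4], [2, 4, 8], num_heads=8):
    print(name, k)
```

This prints 64 for input levels 1 and 2 and 128 for input level 3 and the middle block; the decoder side is assumed to mirror the encoder levels, so it should add no new head sizes.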

python -m xformers.info

xFormers 0.0.15.dev395+git.7e05e2c
memory_efficient_attention.cutlassF:               available
memory_efficient_attention.cutlassB:               available
memory_efficient_attention.flshattF:               available
memory_efficient_attention.flshattB:               available
memory_efficient_attention.smallkF:                available
memory_efficient_attention.smallkB:                available
memory_efficient_attention.tritonflashattF:        available
memory_efficient_attention.tritonflashattB:        available
swiglu.fused.p.cpp:                                available
is_triton_available:                               True
is_functorch_available:                            False
pytorch.version:                                   1.12.1
pytorch.cuda:                                      available
gpu.compute_capability:                            8.6
gpu.name:                                          NVIDIA GeForce RTX 3090
danthe3rd commented 1 year ago

Oh I see, this is related to https://github.com/facebookresearch/xformers/issues/517. You should be able to train in fp16, though, if that's supported.

leeruibin commented 1 year ago

I tried running the demo in fp16, and it returns:

  Traceback (most recent call last):
    File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/torch/autograd/function.py", line 399, in wrapper
      outputs = fn(ctx, *args)
    File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 111, in backward
      grads = _memory_efficient_attention_backward(
    File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/xformers/ops/fmha/__init__.py", line 381, in _memory_efficient_attention_backward
      grads = op.apply(ctx, inp, grad)
    File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/xformers/ops/fmha/cutlass.py", line 184, in apply
      (grad_q, grad_k, grad_v,) = cls.OPERATOR(
    File "/home/anaconda/envs/pyDF/lib/python3.9/site-packages/torch/_ops.py", line 143, in __call__
      return self._op(*args, **kwargs or {})
  RuntimeError: CUDA error: invalid argument
  CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
  For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

I use half() to enable fp16 in demo.py:

model.half()

batch_size = 1
x_in = torch.randn([batch_size,7,128,128]).cuda().half()
context = torch.randn([batch_size,77,1024]).cuda().half()
timesteps = torch.randint(0,1000,[batch_size]).long().cuda()
y = torch.ones([batch_size])*20
y = y.long().cuda()
out = model(x_in, timesteps=timesteps, context=context, y=y)
fake_label = torch.rand_like(out).half()
loss_fn = torch.nn.MSELoss()
loss = loss_fn(out,fake_label)
loss.backward()

According to conda list, the torch version is 1.13.1 and the cudatoolkit version is 11.6.0.

danthe3rd commented 1 year ago

It looks like you are doing the right things. Unfortunately, I don't have an RTX 3090 at hand to test, and this GPU is also not a priority for us, as we focus on V100/A100 mostly. If you can find a fix, we can get it landed, but that's not something we will prioritize at this point.

leeruibin commented 1 year ago

Thanks

zaptrem commented 1 year ago

@danthe3rd I have one reproducible situation where it works and one where it doesn't. How can I help drill down to solve this issue?

Works:

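from typing import Optional

import torch.nn as nn
import xformers.ops
from torch import Tensor

# `Module` is a helper from the author's codebase (not shown); it is assumed to
# wrap a list of submodules and a forward function into an nn.Module.

def AttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
    mid_features = head_features * num_heads
    to_out = nn.Linear(in_features=mid_features, out_features=features, bias=False)

    def forward(
        q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
    ) -> Tensor:
        # Use memory efficient attention (mask is accepted but not used)
        out = xformers.ops.memory_efficient_attention(q, k, v)
        return to_out(out)

    return Module([to_out], forward)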

Doesn't work:

import torch.nn as nn
import xformers.ops
from einops import rearrange
from torch import Tensor

# `Module` is a helper from the author's codebase (not shown).

def OldLinearAttentionBase(features: int, head_features: int, num_heads: int) -> nn.Module:
    mid_features = head_features * num_heads
    to_out = nn.Linear(in_features=mid_features, out_features=features, bias=False)

    # supposed to be functionally equivalent to memory_efficient_attention
    # source: https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention:~:text=to%20be%201-,Equivalent%20pytorch%20code,-scale%20%3D%201
    def atten(query, key, value):
        scale = 1 / query.shape[-1] ** 0.5
        query = query * scale
        attn = query @ key.transpose(-2, -1)
        attn = attn.softmax(-1)
        return attn @ value

    def forward(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # Swap time and channels so attention runs over the channel dim
        q, k, v = map(lambda t: rearrange(t, "b t c -> b c t").contiguous(), (q, k, v))
        attn = xformers.ops.memory_efficient_attention(q, k, v)  # crashes during backward pass
        # attn = atten(q, k, v)  # works fine
        attn = rearrange(attn, "b c t -> b t c")
        return to_out(attn)

    return Module([to_out], forward)
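Why the channel-wise variant can produce very large heads: memory_efficient_attention treats the last dimension of a 3-D [B, M, K] input as the per-head embedding size K, so the "b t c -> b c t" rearrange turns the time length into K. A minimal sketch in plain PyTorch with hypothetical sizes (transpose(1, 2) is used here as the equivalent of that rearrange):

```python
import torch

b, t, c = 2, 4096, 64  # hypothetical sizes: a long sequence with few channels
x = torch.randn(b, t, c)  # [B, M, K] = [batch, time, channels]

# Time-wise attention: M = t positions, each with a K = c dimensional embedding
print(x.shape[-1])  # 64

# Channel-wise attention, as in OldLinearAttentionBase ("b t c -> b c t"):
# the channels become the sequence and the time axis becomes the embedding
x_ch = x.transpose(1, 2).contiguous()
print(x_ch.shape[-1])  # 4096
```

With this layout, K grows with the input length, so longer clips push the heads further past the sizes the fused kernels are built for.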
danthe3rd commented 1 year ago

Can you also provide the inputs that lead to the NaNs?

zaptrem commented 1 year ago

> Can you also provide the inputs that lead to the NaNs?

I'm not sure what the best way is to share large raw matrices over the internet. Is there a standard way to do so? In the meantime, I inserted a print statement here (cutlass.py, line 183):

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
        if inp.attn_bias is not None and not isinstance(
            inp.attn_bias, LowerTriangularMask
        ):
            raise NotImplementedError("Unsupported attn_bias type")
        causal = isinstance(inp.attn_bias, LowerTriangularMask)
        dtype = inp.query.dtype

        print("grad: ", grad.shape)

        force_pad_inf = torch.cuda.get_device_capability(inp.query.device) == (7, 5)
        (grad_q, grad_k, grad_v,) = cls.OPERATOR(
            grad.to(dtype),
            inp.query,
            inp.key,
            inp.value,
            ctx.get_padded_lse(32, force_pad_inf=force_pad_inf),
            ctx.out.to(dtype),
            causal=causal,
            scale=inp.scale,
        )
        return Gradients(dq=grad_q, dk=grad_k, dv=grad_v)

I also enabled CUDA_LAUNCH_BLOCKING, and this was the result: https://pastebin.com/dXjkDgXe
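CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, so the Python stack trace points at the call that actually failed rather than a later one. It has to be set before CUDA is initialized; a minimal sketch of doing that from inside the script (exporting it in the shell works just as well):

```python
import os

# Must be set before CUDA initializes; setting it before importing torch
# is the simplest way to guarantee that.
# Shell alternative: CUDA_LAUNCH_BLOCKING=1 python demo.py
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # noqa: E402  (deliberately imported after the variable is set)

# ... build the model and run the failing backward pass here ...
```

Synchronous launches slow everything down, so this is meant for debugging runs only.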

danthe3rd commented 1 year ago

Wait, you sometimes have embed_dim_per_head = 4096?! This is usually < 128. So this error looks related to your GPU model not being fully supported by xFormers at the moment. This might change in the future if we manage to reduce shmem usage (@jfc4050 might have something), but likely not in the near future. The error message could be improved, though, and we need to fix that at least.

zaptrem commented 1 year ago

> Wait, you sometimes have embed_dim_per_head = 4096?! This is usually < 128. So this error looks related to your GPU model not being fully supported by xFormers at the moment. This might change in the future if we manage to reduce shmem usage (@jfc4050 might have something), but likely not in the near future. The error message could be improved, though, and we need to fix that at least.

It is possible I messed something up in the channel-wise linear attention function above. The idea is to apply attention channel-wise instead of time-wise. I added more print statements for the other arguments going into cls.OPERATOR():

grad:  torch.Size([10, 512, 1, 128])
query:  torch.Size([10, 512, 1, 128])
key:  torch.Size([10, 512, 1, 128])
value:  torch.Size([10, 512, 1, 128])
padded lse torch.Size([10, 1, 512])
ctx:  torch.Size([10, 512, 1, 128])
causal:  False
scale:  None
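The printed shapes follow the [B, M, H, K] layout that memory_efficient_attention uses (batch, sequence length, heads, per-head size): here M = 512 is the channel count, because of the rearrange, and K = 128. When comparing several calls, it can help to print dtype, device and strides next to the shape, since these can also affect which kernel is allowed to run; if the tensors themselves need to be shared, torch.save can write them to a single file. A hypothetical helper (describe_inputs is not an xFormers function):

```python
import torch


def describe_inputs(**tensors):
    """Hypothetical debugging helper: one line of metadata per tensor."""
    lines = []
    for name, t in tensors.items():
        lines.append(
            f"{name}: shape={tuple(t.shape)} dtype={t.dtype} device={t.device} "
            f"stride={t.stride()} contiguous={t.is_contiguous()}"
        )
    return "\n".join(lines)


# Same layout as the tensors printed above: [B, M, H, K] = [10, 512, 1, 128]
q = torch.zeros(10, 512, 1, 128, dtype=torch.float16)
print(describe_inputs(query=q, key=q, value=q))
```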

How do I make it gracefully fall back to a compatible implementation for only the backward pass when shared memory can't handle it? It works fine both ways for most of my attention units, and at least for the forward pass on the others.

EDIT: I tried forcing each of the other implementations (small-K, Flash) and none of them worked, so I'm now leaning towards this being an issue with my model. I agree that a more descriptive error message with possible solutions would help others in the future.

danthe3rd commented 1 year ago

A few things: (1) I believe we don't support 64 < embedding_per_head <= 128 on the RTX 3090 for the backward pass in any implementation. (2) For embedding_per_head > 128, the kernel will be very slow (possibly slower than a regular PyTorch implementation), so you might want to drop memory-efficient attention and use a vanilla PyTorch implementation instead.
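The two limits in this reply can be written down as a small selection rule. The sketch below is hypothetical (pick_attention_impl is not an xFormers API) and encodes only what the reply states; it uses compute capability 8.6, as reported by xformers.info for the RTX 3090 above, as a stand-in for that GPU, which assumes other 8.6 GPUs behave the same.

```python
def pick_attention_impl(head_dim, compute_capability, needs_backward):
    """Hypothetical selector based on the limits described in this thread.

    Returns "xformers" for memory_efficient_attention or "pytorch" for a
    vanilla attention implementation.
    """
    if head_dim > 128:
        # Kernels exist but may be slower than plain PyTorch
        return "pytorch"
    is_sm86 = compute_capability == (8, 6)  # assumption: stands in for RTX 3090
    if is_sm86 and needs_backward and 64 < head_dim <= 128:
        # Reportedly no backward kernel for this range on RTX 3090
        return "pytorch"
    return "xformers"


# The K = 128 call printed earlier in the thread, during training on an RTX 3090
print(pick_attention_impl(128, (8, 6), needs_backward=True))  # pytorch
```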

danthe3rd commented 1 year ago

Related issue: https://github.com/facebookresearch/xformers/issues/517

zaptrem commented 1 year ago

> A few things: (1) I believe we don't support 64 < embedding_per_head <= 128 on the RTX 3090 for the backward pass in any implementation. (2) For embedding_per_head > 128, the kernel will be very slow (possibly slower than a regular PyTorch implementation), so you might want to drop memory-efficient attention and use a vanilla PyTorch implementation instead.

It would be great if xFormers could detect cases that only the vanilla PyTorch implementation supports and fall back to it, so we keep the speed/memory benefits for the vast majority of attention calls that are within those bounds (and also get the speedups when we're allocated A100s).
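The NotImplementedError at the top of this issue shows that the dispatcher does not fall back on its own, but the fallback can be done in user code. A minimal sketch, assuming 3-D [B, M, K] inputs and the maintainer's statement that on the RTX 3090 the backward pass is only available up to K = 64; attention_with_fallback and vanilla_attention are hypothetical names, and the vanilla path follows the "Equivalent pytorch code" from the xFormers docs linked earlier in the thread.

```python
import torch

try:
    import xformers.ops as xops
except ImportError:  # xFormers not installed: always take the PyTorch path
    xops = None


def vanilla_attention(query, key, value):
    # softmax(Q K^T / sqrt(K)) V for 3-D [B, M, K] inputs
    scale = 1 / query.shape[-1] ** 0.5
    attn = (query * scale) @ key.transpose(-2, -1)
    return attn.softmax(-1) @ value


def attention_with_fallback(query, key, value, max_head_dim=64):
    """Hypothetical wrapper: use memory_efficient_attention only for CUDA
    tensors whose per-head size is at most max_head_dim; otherwise fall back
    to the vanilla PyTorch implementation."""
    if xops is not None and query.is_cuda and query.shape[-1] <= max_head_dim:
        return xops.memory_efficient_attention(query, key, value)
    return vanilla_attention(query, key, value)


q = torch.randn(2, 16, 8)  # CPU tensor, so this takes the PyTorch path
out = attention_with_fallback(q, q, q)
print(out.shape)  # torch.Size([2, 16, 8])
```

The trade-off is that layers taking the fallback path materialize the full M x M attention matrix and lose the memory savings, while the other layers keep them; max_head_dim can be raised on GPUs where larger heads are supported.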

jfc4050 commented 1 year ago

> This might change in the future if we manage to reduce shmem usage (@jfc4050 might have something), but likely not in the near future.

Yes, it might not be for a little while, unfortunately; it needs some reworking to have a unique code path for half precision, k <= 128, and SM80. In general, these changes apply to you if you have head_dim > 128.