SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Backward pass is very slow with many tokens and large kernels #161

Closed: pwangcs closed this issue 1 month ago

pwangcs commented 1 month ago

Hi Ali, I have been looking forward to implementing convolution-like attention, and NATTEN is a very nice implementation.

A naive PyTorch implementation, using operators like Unfold and ReplicationPad, mainly suffers from large memory usage. torch.utils.checkpoint can save memory at the cost of speed. Combining a naive PyTorch implementation with torch.utils.checkpoint is an intuitive solution, but it is very slow with many tokens and large kernels.
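To put the memory issue in numbers, here is a back-of-the-envelope sketch. It assumes an Unfold-based implementation gathers every query's window of keys (and, likewise, values) into its own copy, and it uses the 2D configuration from the benchmark script in this issue (8 heads of 32 dims, 1024x1024 tokens, 9x9 kernel, FP32). The helper `tensor_gib` is hypothetical and just multiplies out a shape.

```python
def tensor_gib(*shape, bytes_per_elem=4):
    """Size in GiB of a dense tensor with the given shape (FP32 by default)."""
    n = 1
    for s in shape:
        n *= s
    return n * bytes_per_elem / 2**30

heads, head_dim = 8, 32            # as in the benchmark script in this issue
tokens, window = 1024 * 1024, 9 * 9

k_dense = tensor_gib(heads, tokens, head_dim)              # K as stored
k_unfolded = tensor_gib(heads, tokens, window, head_dim)   # one window copy per query
print(k_dense, k_unfolded)   # 1.0 81.0
```

The unfolded copy is `window` times larger than K itself, and the same holds for V, which is why checkpointing (recomputing instead of storing) becomes tempting despite the slowdown.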

Unfortunately, I find that NATTEN is still very slow during the backward pass:

  1. with many tokens and large kernels;
  2. 3D is slower than 2D;
  3. fused NA is slower than unfused NA.
import time
import natten
import torch
from natten.functional import na2d, na2d_qk, na2d_av, na3d, na3d_qk, na3d_av
import os
from torch.nn import functional as F
natten.use_fused_na(True)
natten.use_kv_parallelism_in_fused_na(True)

device = 'cuda'

method = 'unfused'  # 'fused' or 'unfused'.  'fused' for {na2d, na3d}; 'unfused' for {na2d_qk, na2d_av, na3d_qk, na3d_av}
data_type = '3D' # '2D' or '3D'

b = 1
num_heads, head_dim = 8, 32
if data_type =='3D':
    t, h, w = 16, 256, 256
    kernel_size = (3,5,5)
elif data_type =='2D':
    t, h, w = 1, 1024, 1024
    kernel_size = (9,9)

print(f"\nUsing {method} NA for {data_type} data [{t},{h},{w}] with kernel size of {kernel_size}\n")

for i in range(5):
    print(f"----------\ntrial {i}")
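    if method == 'fused':
        # Fused ops take [batch, (T,) H, W, heads, head_dim]; squeeze(1) drops T=1 for 2D.
        query = torch.randn(b, t, h, w, num_heads, head_dim, requires_grad=True).squeeze(1).to(device)
    elif method == 'unfused':
        # Unfused ops take [batch, heads, (T,) H, W, head_dim]; squeeze(2) drops T=1 for 2D.
        query = torch.randn(b, num_heads, t, h, w, head_dim, requires_grad=True).squeeze(2).to(device)
    # The requires_grad leaf is created on the CPU, so backward() also copies dQ back to the host.
    # key/value come from randn_like and do not require grad.
    key, value = torch.randn_like(query).to(device), torch.randn_like(query).to(device)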

    torch.cuda.synchronize()
    start = time.time()
    if method == 'fused':
        if data_type =='2D':
            output = na2d(query, key, value, kernel_size=kernel_size)
        elif data_type =='3D':
            output = na3d(query, key, value, kernel_size=kernel_size)
    elif method == 'unfused':
        if data_type =='2D':
            attn = na2d_qk(query, key, kernel_size=kernel_size, dilation=1)
            attn = F.softmax(attn, dim=-1)
            output = na2d_av(attn,value, kernel_size=kernel_size, dilation=1)
        elif data_type =='3D':
            attn = na3d_qk(query, key, kernel_size=kernel_size, dilation=1)
            attn = F.softmax(attn, dim=-1)
            output = na3d_av(attn,value, kernel_size=kernel_size, dilation=1)
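    # No torch.cuda.synchronize() before stopping the timer, so "forward" mostly measures kernel launch time.
    end = time.time()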
    print(f"forward: {end - start:3f}")

    loss = output.sum()

    torch.cuda.synchronize()
    start = time.time()
    loss.backward()
    torch.cuda.synchronize()
    end = time.time()
    print(f"backward: {end - start:3f}")

Unfused 2D NA result:

Using unfused NA for 2D data [1,1024,1024] with kernel size of (9, 9)

----------
trial 0
forward: 0.006576
backward: 0.855989
----------
trial 1
forward: 0.000540
backward: 1.041160
----------
trial 2
forward: 0.000619
backward: 0.810329
----------
trial 3
forward: 0.000577
backward: 0.801688
----------
trial 4
forward: 0.000591
backward: 0.800081

Fused 2D NA result:

Using fused NA for 2D data [1,1024,1024] with kernel size of (9, 9)

----------
trial 0
forward: 0.002605
backward: 1.035360
----------
trial 1
forward: 0.000530
backward: 1.051781
----------
trial 2
forward: 0.000379
backward: 1.063205
----------
trial 3
forward: 0.000358
backward: 1.075845
----------
trial 4
forward: 0.000622
backward: 1.121022

Unfused 3D NA result:

Using unfused NA for 3D data [16,256,256] with kernel size of (3, 5, 5)

----------
trial 0
forward: 0.007787
backward: 1.141421
----------
trial 1
forward: 0.000634
backward: 1.134172
----------
trial 2
forward: 0.000909
backward: 1.124384
----------
trial 3
forward: 0.000777
backward: 1.093474
----------
trial 4
forward: 0.000623
backward: 1.187929

Fused 3D NA result:

Using fused NA for 3D data [16,256,256] with kernel size of (3, 5, 5)

----------
trial 0
forward: 0.002399
backward: 3.432977
----------
trial 1
forward: 0.003000
backward: 3.588447
----------
trial 2
forward: 0.000728
backward: 3.208563
----------
trial 3
forward: 0.000629
backward: 3.856050
----------
trial 4
forward: 0.001306
backward: 3.289000

Note that the 2D data [1024, 1024] has the same number of tokens as the 3D data [16, 256, 256], and the 2D kernel size [9, 9] (= 81) is close to the 3D kernel size [3, 5, 5] (= 75).
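The comparison can be checked directly. This sketch also counts the attention weights the unfused path materializes, assuming na2d_qk/na3d_qk return one weight per (head, query, neighbor), i.e. a tensor shaped [batch, heads, *tokens, window]; `unfused_weight_count` is a hypothetical helper and heads=8 matches the script.

```python
from math import prod

def unfused_weight_count(grid, kernel, heads=8):
    """Attention weights the unfused path materializes: one per (head, query, neighbor)."""
    return heads * prod(grid) * prod(kernel)

for name, grid, kernel in [("2D", (1024, 1024), (9, 9)), ("3D", (16, 256, 256), (3, 5, 5))]:
    n = unfused_weight_count(grid, kernel)   # heads=8 as in the benchmark script
    print(f"{name}: tokens={prod(grid)}, window={prod(kernel)}, "
          f"weights={n} ({n * 4 / 2**30:.2f} GiB in FP32)")
```

Both problems have 1,048,576 tokens, so the 2D case actually does about 8% more query-key work (81 vs 75 neighbors per query), and each attention-weight tensor (before and after softmax) is roughly 2.3 to 2.5 GiB in FP32.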

For high-resolution images or video, the current NATTEN seems too slow to train with. For reference, I built a NATTEN-based model with kernel size (5,7,7) for input size (8,128,128), and 3k iterations take about 2.5 hours.

alihassanijr commented 1 month ago

Thank you for your interest. Something's definitely not right here. Could you please share what GPU you're running this on, and what your torch/cuda version is?

Also, have you tried setting memory usage in backprop to max?

natten.set_memory_usage_preference("unrestricted")

This setting is what caps how high KV parallelism can go in the fused backward pass.

alihassanijr commented 1 month ago

Also, is this FP32? Or FP16/BF16?

pwangcs commented 1 month ago

Thanks for your reply. With natten.set_memory_usage_preference("unrestricted") as you suggested, the speedup is small:

Using fused NA for 3D data [16,256,256] with kernel size of (3, 5, 5)

----------
trial 0
forward: 0.002260
backward: 3.523521
----------
trial 1
forward: 0.000404
backward: 2.927348
----------
trial 2
forward: 0.001115
backward: 2.918803
----------
trial 3
forward: 0.000685
backward: 3.260884
----------
trial 4
forward: 0.000731
backward: 2.952460

I ran these tests with PyTorch 2.0.1 and CUDA 11.8 on an NVIDIA A40.

pwangcs commented 1 month ago

FP32

alihassanijr commented 1 month ago

Is this FP32? If so, then I wouldn't be surprised if unfused is better than fused.

alihassanijr commented 1 month ago

Okay that explains a lot. Can you try FP16/BF16?

pwangcs commented 1 month ago

OK. What I am most interested in is whether a backward pass of about 1 s is reasonable for unfused 2D/3D NA.

alihassanijr commented 1 month ago

Yes, I would think so. Note that the backward pass is always relatively more complex, and your measurement probably includes some overhead from autograd as well. The only sure way to measure how much time the backward pass actually takes is to benchmark the ops separately, as we do in the NATTEN profiler.
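Two details of the original script also inflate or distort its numbers: the forward timer is stopped without a `torch.cuda.synchronize()`, so it mostly captures launch time, and the `requires_grad` leaf is created on the CPU, so `backward()` includes a device-to-host copy of dQ. Together with autograd overhead, that may account for part of the gap between the ~1 s readings and the kernel times in the profiler tables in this thread. Below is a minimal, framework-agnostic sketch of the usual pattern; `time_region` is a hypothetical helper, not part of NATTEN or PyTorch.

```python
import statistics
import time

def time_region(fn, sync, warmup=2, iters=5):
    """Median wall-clock seconds of fn(), with sync() bracketing every sample."""
    for _ in range(warmup):   # first calls pay for lazy init, allocation, etc.
        fn()
    samples = []
    for _ in range(iters):
        sync()                # drain previously queued (asynchronous) work
        start = time.perf_counter()
        fn()
        sync()                # wait until the work launched by fn() has finished
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)

# CPU-only usage: nothing is asynchronous here, so sync can be a no-op.
elapsed = time_region(lambda: sum(range(100_000)), sync=lambda: None)
print(f"{elapsed * 1e3:.3f} ms")
```

On the GPU you would pass `sync=torch.cuda.synchronize` and create inputs directly on the device (e.g. `torch.randn(..., device="cuda", requires_grad=True)`). To isolate the backward pass, one option is to time forward+backward together and subtract the forward-only time; `torch.utils.benchmark.Timer`, which handles warmup and CUDA synchronization itself, is another.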

To provide some context on the FP32 vs FP16 issue: Unfused attention is usually heavily bound by memory bandwidth, and fused attention can make it compute bound, but usually only as long as you do FP16, because FP32 peak FLOPS is usually terrible to begin with.

So if you run attention with FP32 or higher precision, you'll likely be memory bound. If your kernel size to input size ratio is too small you'll again wind up slightly memory bound.

For example, the arithmetic intensity of your 1024x1024 problem with a 9x9 kernel in FP32 is around 9 for unfused and around 20 for fused attention, which means both are very likely memory bandwidth bound: the A40's peak FP32 throughput is 37.4 TFLOPS and its bandwidth is 696 GB/s, so the threshold is around 37.4e12 / 696e9 ≈ 54 FLOPs per byte. Your 3D problem is in a similar boat, at roughly 9 and 19, again heavily bandwidth bound.
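The threshold argument is the standard roofline model. This sketch reproduces the ridge point from the A40 figures quoted in this comment and the throughput cap each quoted arithmetic-intensity estimate implies; the intensities (9, 20, 9, 19) are Ali's estimates, not derived here, and the helper names are hypothetical.

```python
def ridge_point(peak_flops, bandwidth):
    """Arithmetic intensity (FLOP/byte) at which a kernel stops being bandwidth bound."""
    return peak_flops / bandwidth

def attainable_tflops(intensity, peak_flops=37.4e12, bandwidth=696e9):
    """Roofline cap: min(peak compute, bandwidth * intensity), in TFLOPS (A40 FP32 defaults)."""
    return min(peak_flops, bandwidth * intensity) / 1e12

print(f"ridge point: {ridge_point(37.4e12, 696e9):.1f} FLOP/byte")  # 53.7
for case, ai in [("2D unfused", 9), ("2D fused", 20), ("3D unfused", 9), ("3D fused", 19)]:
    print(f"{case:10s} AI={ai:2d} -> at most {attainable_tflops(ai):.1f} of 37.4 TFLOPS")
```

Every case lands well below the ridge, so neither path can approach peak FP32 compute; fusion raises the cap but does not make the problem compute bound.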

This basically means that fused won't necessarily work to your advantage in these cases, because we're not bound by compute to begin with. And of course, because of the complexity of the fused backwards pass, it's easy for the unfused variant to outperform it, because it just has less overhead and more room for occupancy than fused typically does.

alihassanijr commented 1 month ago

If you're curious, here is the profiling result for your 2D FP32 case. This is on an A100, so peak FLOPS and bandwidth differ, but it's the same architecture (Ampere), and the observations remain the same: unfused can actually provide better runtime than fused (see above for why).

The gradient ops for Q, K, V, and the attention weights take around 69 ms; adding the softmax backprop and the unfused scale ops brings it to about 77.5 ms. The fused op unfortunately takes more than twice that: around 175 ms.

Unfused:

                         Profiler results
┏━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Kernel type ┃ Arch ┃           Operation           ┃ CUDA time ┃
┡━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ natten.gemm │ Sm80 │             QKRPB             │ 14.525ms  │
│ natten.gemm │ Sm80 │              AV               │ 10.399ms  │
│ natten.gemm │ Sm80 │             QGRAD             │ 10.399ms  │
│ natten.gemm │ Sm80 │             KGRAD             │ 22.282ms  │
│ natten.gemm │ Sm80 │             VGRAD             │ 22.282ms  │
│ natten.gemm │ Sm80 │             AGRAD             │ 14.103ms  │
│      -      │  -   │     softmax_warp_forward      │  3.410ms  │
│  at.native  │  -   │ vectorized_elementwise_kernel │  2.382ms  │
│  at.native  │  -   │ vectorized_elementwise_kernel │  2.398ms  │
│      -      │  -   │     softmax_warp_backward     │  5.003ms  │
│             │      │             Total             │ 107.184ms │
└─────────────┴──────┴───────────────────────────────┴───────────┘

Fused:

                                 Profiler results
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃         Kernel type          ┃ Arch ┃           Operation           ┃ CUDA time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│         natten.fused         │ Sm80 │         FusedForward          │ 26.211ms  │
│         natten.fused         │ Sm80 │         FusedBackward         │ 169.742ms │
│ natten.cuda.reduction.kernel │  -   │         ComputeDelta          │  4.117ms  │
│          at.native           │  -   │ vectorized_elementwise_kernel │  1.556ms  │
│          at.native           │  -   │ vectorized_elementwise_kernel │  1.556ms  │
│                              │      │             Total             │ 203.182ms │
└──────────────────────────────┴──────┴───────────────────────────────┴───────────┘
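The backward comparison can be recomputed from the two FP32 tables. This sketch assumes ComputeDelta belongs to the backward pass (it is the usual row-sum preprocessing step in fused attention backward) and leaves out the vectorized_elementwise_kernel rows because the tables do not say which pass they belong to, so the unfused figure lands a little below the 77.5 ms quoted above.

```python
# CUDA times (ms) from the FP32 (A100) profiler tables in this thread.
unfused = {
    "QKRPB": 14.525, "AV": 10.399, "QGRAD": 10.399, "KGRAD": 22.282,
    "VGRAD": 22.282, "AGRAD": 14.103, "softmax_warp_forward": 3.410,
    "elementwise_1": 2.382, "elementwise_2": 2.398, "softmax_warp_backward": 5.003,
}
fused = {
    "FusedForward": 26.211, "FusedBackward": 169.742, "ComputeDelta": 4.117,
    "elementwise_1": 1.556, "elementwise_2": 1.556,
}

grad_ops = sum(unfused[k] for k in ("QGRAD", "KGRAD", "VGRAD", "AGRAD"))
unfused_bwd = grad_ops + unfused["softmax_warp_backward"]
fused_bwd = fused["FusedBackward"] + fused["ComputeDelta"]

print(f"unfused gradient kernels: {grad_ops:.1f} ms")            # 69.1 ms
print(f"  + softmax backward:     {unfused_bwd:.1f} ms")         # 74.1 ms
print(f"fused backward + delta:   {fused_bwd:.1f} ms")           # 173.9 ms
print(f"fused / unfused backward: {fused_bwd / unfused_bwd:.2f}x")  # 2.35x
```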
alihassanijr commented 1 month ago

But of course, for FP16 the same 2D problem looks --very-- different (unfused first, then fused):

                         Profiler results
┏━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Kernel type ┃ Arch ┃           Operation           ┃ CUDA time ┃
┡━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ natten.gemm │ Sm80 │             QKRPB             │ 11.153ms  │
│ natten.gemm │ Sm80 │              AV               │ 11.821ms  │
│ natten.gemm │ Sm80 │             QGRAD             │ 11.821ms  │
│ natten.gemm │ Sm80 │             KGRAD             │ 18.099ms  │
│ natten.gemm │ Sm80 │             VGRAD             │ 18.099ms  │
│ natten.gemm │ Sm80 │             AGRAD             │ 11.138ms  │
│      -      │  -   │     softmax_warp_forward      │  2.832ms  │
│  at.native  │  -   │ vectorized_elementwise_kernel │  2.393ms  │
│      -      │  -   │     softmax_warp_backward     │  2.711ms  │
│             │      │             Total             │ 90.066ms  │
└─────────────┴──────┴───────────────────────────────┴───────────┘
                                 Profiler results
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃         Kernel type          ┃ Arch ┃           Operation           ┃ CUDA time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│         natten.fused         │ Sm80 │         FusedForward          │  9.598ms  │
│         natten.fused         │ Sm80 │         FusedBackward         │ 62.713ms  │
│ natten.cuda.reduction.kernel │  -   │         ComputeDelta          │  4.203ms  │
│          at.native           │  -   │ vectorized_elementwise_kernel │  1.556ms  │
│          at.native           │  -   │ vectorized_elementwise_kernel │  1.556ms  │
│                              │      │             Total             │ 79.626ms  │
└──────────────────────────────┴──────┴───────────────────────────────┴───────────┘
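Putting the FP32 and FP16 tables side by side makes the bandwidth argument concrete: the unfused gradient kernels barely speed up in FP16, while the fused backward kernel gets roughly 2.7x faster, which flips which path wins end to end. The numbers below are copied from the tables; the grouping into "gradient kernels" is an assumption of this sketch.

```python
# Backward-pass CUDA times (ms) from the A100 profiler tables in this thread.
grads = {  # QGRAD + KGRAD + VGRAD + AGRAD of the unfused path
    "fp32": 10.399 + 22.282 + 22.282 + 14.103,
    "fp16": 11.821 + 18.099 + 18.099 + 11.138,
}
fused_bwd = {"fp32": 169.742, "fp16": 62.713}   # FusedBackward kernel only
totals = {"fp32": (107.184, 203.182), "fp16": (90.066, 79.626)}  # (unfused, fused)

print(f"unfused grad kernels: {grads['fp32'] / grads['fp16']:.2f}x faster in FP16")    # 1.17x
print(f"fused backward:       {fused_bwd['fp32'] / fused_bwd['fp16']:.2f}x faster in FP16")  # 2.71x
for prec, (unf, fus) in totals.items():
    winner = "fused" if fus < unf else "unfused"
    print(f"{prec}: {winner} wins end to end ({unf:.1f} vs {fus:.1f} ms)")
```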
pwangcs commented 1 month ago

I agree with you, although unfused NA does need a larger memory footprint.

alihassanijr commented 1 month ago

So your expected pattern would essentially allow different kernel sizes for different query pixels/tokens? Is that right?

pwangcs commented 1 month ago

Rather, a fixed non-cube kernel shared by all query tokens, i.e. one whose size varies along the time dimension.

alihassanijr commented 1 month ago

Oh I see -- so kind of like attending to more context the more "recent" the frame is?

pwangcs commented 1 month ago

Rather, paying more attention to farther frames, which leads to kernel sizes that increase along time.

pwangcs commented 1 month ago

In my view, it would be ideal if na2d_qk or na3d_qk accepted an index tensor that determines the attended tokens, instead of the default cube-shaped receptive field.

alihassanijr commented 1 month ago

Okay yes I think this should be achievable -- but I have to think more about how to actually implement it. It's too specific a feature to include in NATTEN directly, but I can try and provide a branch with that implementation; but of course I still have to think about how to approach it.

One thing I'll note though is that your expected efficiency gain will very likely be the same as if you used a constant kernel size across time (meaning the runtime will be the same, but the masking will be different.)

If you think you'll need to stick with FP32 precision though, it might just be easier to use unfused (since it can probably match or outperform FNA) and just mask the weights. But if you'd like the fused op to support this feature, it might take some time for me to figure it out and implement it.

pwangcs commented 1 month ago

Do you mean that implementing sparse attention within the current cube-shaped receptive field cannot save any computation, so I should just use a corresponding mask to discard the computed attention weights? Currently, I do that by masking the result of na3d_qk.
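A minimal sketch of such a mask for the pattern described here (more spatial context in farther frames), inside a fixed 3x5x5 window. `pyramid_mask`, `base_radius`, and `growth` are hypothetical. Two assumptions need checking against your NATTEN version: that the last dimension of na3d_qk's output enumerates the window in row-major (t, h, w) order, and that center-relative offsets are what you want; near borders NATTEN shifts the window instead of centering it on the query, so there the offsets are no longer relative to the query.

```python
def pyramid_mask(kernel_size, base_radius, growth):
    """Hypothetical keep-mask over a (kt, kh, kw) window, flattened row-major.

    At temporal offset dt from the window center, spatial offsets up to
    base_radius + growth * |dt| are kept, so farther frames see more context.
    """
    kt, kh, kw = kernel_size
    ct, ch, cw = kt // 2, kh // 2, kw // 2
    mask = []
    for t in range(kt):
        radius = base_radius + growth * abs(t - ct)
        for h in range(kh):
            for w in range(kw):
                mask.append(abs(h - ch) <= radius and abs(w - cw) <= radius)
    return mask

m = pyramid_mask((3, 5, 5), base_radius=1, growth=1)
print(len(m), sum(m))   # 75 59
```

Under those assumptions, the mask would be applied before the softmax, e.g. `attn.masked_fill(~torch.tensor(m, device=attn.device), float("-inf"))`, which broadcasts over the leading dimensions. The kept window here is 25 + 9 + 25 = 59 of 75 positions, but every position is still computed by na3d_qk.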

alihassanijr commented 1 month ago

Kind of. Many of our implementations, particularly the GEMM-based and FNA kernels, tile their inputs, and tile shapes are inflexible: it is very difficult to express them as anything but rectangular boxes. Because of that, KV tiles fall into two groups: either completely masked, or partially masked (or not masked at all). The kernels gain all of their speedup from skipping the completely masked tiles.

So if the farthest neighbors (the outer bound of the window) don't change when switching from the current behavior to your expected behavior, the number of completely masked tiles is unlikely to increase much, which means you probably won't get much of a speedup, if any.
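A toy model of this argument: a KV tile can be skipped only if no query in a query tile attends to any key in it. The sketch compares a full 3x7 (time x width) box with a pyramid-shaped pattern that has the same outer extent. Tile shapes, window radii, and the function names are hypothetical, borders are ignored, and this is not NATTEN's actual tiling or masking logic.

```python
from itertools import product

R_T, R_W = 1, 3          # window radii -> a 3 x 7 (time x width) neighborhood
TILE_T, TILE_W = 2, 4    # hypothetical KV tile shape

def box(dt, dw):
    """Full rectangular neighborhood, the shape tiled kernels express naturally."""
    return True

def pyramid(dt, dw):
    """Non-cube pattern: width radius 1 at dt == 0 and 3 at |dt| == 1 (same outer extent)."""
    return abs(dw) <= 1 + 2 * abs(dt)

def tile_stats(query_tile, allowed):
    """Return (#query-key dot products, #KV tiles that cannot be skipped)."""
    pairs, kv_tiles = 0, set()
    for qt, qw in query_tile:
        for dt, dw in product(range(-R_T, R_T + 1), range(-R_W, R_W + 1)):
            if allowed(dt, dw):
                pairs += 1
                kv_tiles.add(((qt + dt) // TILE_T, (qw + dw) // TILE_W))
    return pairs, len(kv_tiles)

# One 2 x 4 query tile away from the borders (boundary handling is ignored).
query_tile = list(product(range(4, 6), range(12, 16)))
print("box:    ", tile_stats(query_tile, box))      # (168, 9)
print("pyramid:", tile_stats(query_tile, pyramid))  # (136, 9)
```

The pyramid removes about 19% of the dot products, yet the query tile still has to visit the same 9 KV tiles, because neighboring queries together cover the whole box. A tiled kernel would therefore do essentially the same work and only mask more weights inside those tiles.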

The naive kernels are different: they don't do tiling and just do one dot product per thread, and those --might-- gain a small but, in my opinion, negligible speedup, because the only thing affected is the length of the KV for loop for some queries but not all, so the longest per-CTA runtime will likely stay the same.

Because of this, and the fact that you're doing FP32, where FNA might not give you that much speedup to begin with, I would guess that masking as you do right now is a good short-term solution. FNA + autotuning --might-- work out for your case if we figure out the implementation, but autotuning is currently unstable for training.

Overall I think it would be somewhat difficult to gain more of a speedup from your current approach, especially in FP32.

pwangcs commented 1 month ago

From my testing, I also find that tiling and indexing operations, rather than the matrix computations, may be the speed bottleneck. Non-cube kernels would make tiling and indexing even more difficult. Is that right?

For low-resolution images and video, NATTEN's speed may not be an issue. But for high-resolution vision tasks like video super-resolution, the current NATTEN seems slower than Swin attention.

pwangcs commented 1 month ago

Larger input size and kernel size: do tiling-related ops need more time?

I ask because I find that NATTEN's speed is highly sensitive to the input size and kernel size, particularly during the backward pass. With large sizes, the speed issue is severe. So I am curious how to speed it up, in principle and/or in practice.

alihassanijr commented 1 month ago

From my testing, I also find that tiling and indexing operations, rather than the matrix computations, may be the speed bottleneck. Non-cube kernels would make tiling and indexing even more difficult. Is that right?

Kind of; software predication is definitely a bottleneck in 2D and 3D, and unfortunately there's not too much that can be done about that, but we are working on some optimizations there. Auto-tuning (once we figure out how to make it faster and more stable) will also resolve a lot of those, because there are many more possible tile shapes in 2D than in 1D, and in 3D than in 2D.

For low-resolution images and video, NATTEN's speed may not be an issue. But for high-resolution vision tasks like video super-resolution, the current NATTEN seems slower than Swin attention.

Again, this is highly dependent on a lot of factors. NATTEN isn't just one kernel, and neither is Swin. If you can share more details about your specific workload, we can try to help figure out where the issue is. Both methods face the same basic challenges in terms of performance optimization, but Swin is less invasive to the implementation, so it can always pick up the best attention kernel available, though it is in turn less flexible. NA is more flexible in what you can do with it, but requires its own kernels, and it's nearly impossible to tell what the issue is without looking at which kernels are being launched, on what software/hardware, and with what parameters.

Larger input size and kernel size: do tiling-related ops need more time?

I'm not sure I understand this.

I ask because I find that NATTEN's speed is highly sensitive to the input size and kernel size, particularly during the backward pass. With large sizes, the speed issue is severe. So I am curious how to speed it up, in principle and/or in practice.

Yes, the speed of any operation depends on the input size and its parameters. And again, many things can affect performance, and without knowing the specifics of your use case I can't tell what is or isn't the issue. If I did, it would just be a guess.

If you're still talking about the unfused case, and in FP32, and you can't use the GEMM-based kernels (again judging by your original post), then you're left with naive kernels, and yes we don't expect them to be performant. If they were more performant than say a standard BMM, I would be very surprised.

There's a plan in place to improve and potentially replace our naive backend, but it hasn't been a priority so far. But we'll certainly bump it up in our list.

alihassanijr commented 1 month ago

So I am curious how to speed it up, in principle and/or in practice.

There's just a bunch of factors. Off the top of my head, you rarely run the same kernel for different use cases. NATTEN's dispatcher will decide which kernels are available based on your use case and your environment, and picks one and runs it. If it's a naive kernel (which is likely in the unfused case), then it's just a kernel that doesn't do any tiling, doesn't go through smem and directly loads into registers, doesn't do LDGSTS, or use tensor cores, and its only objectives are functional correctness and maximum occupancy.

Maximum occupancy can be useful in smaller problem sizes, because it's much easier to issue as big of a wave as possible with kernels that are one-thread one-value like our naive kernels, instead of tiled. That's just a matter of utilizing more SMs basically.
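Wave quantization shows why one-thread-one-value kernels can fill the GPU more easily on small problems. The sketch below uses the A100's 108 SMs; the problem size (128x128 tokens, 9x9 kernel), threads per CTA, queries per tile, and resident CTAs per SM are all hypothetical, and `wave_stats` is an illustrative helper, not a NATTEN API.

```python
from math import ceil

def wave_stats(num_ctas, num_sms, ctas_per_sm):
    """Return (number of waves, fraction of CTA slots that do useful work)."""
    slots = num_sms * ctas_per_sm            # CTAs resident on the GPU at once
    waves = ceil(num_ctas / slots)
    return waves, num_ctas / (waves * slots)

SMS = 108                          # SM count of an A100
queries, window = 16384, 9 * 9     # hypothetical small problem: 128x128 tokens, 9x9 kernel

# One-thread-one-value kernel: one output per thread, 256 threads per CTA,
# assume 8 CTAs fit on an SM because each CTA is light on registers/smem.
naive = wave_stats(ceil(queries * window / 256), SMS, ctas_per_sm=8)
# Tiled kernel: assume 128 queries per CTA and only 1 resident CTA per SM.
tiled = wave_stats(ceil(queries / 128), SMS, ctas_per_sm=1)

print("naive:", naive)   # (6, 1.0)
print("tiled:", tiled)   # 2 waves, about 59% of slots used
```

With these assumptions the tiled kernel launches only 128 CTAs, so its second wave leaves most SMs idle, while the naive kernel's 5184 small CTAs fill six complete waves. The tiled kernel can still win on larger problems, where its per-CTA efficiency dominates.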

And of course, like I said, the smaller your problem size, or the smaller your kernel size relative to your problem size, the more you're memory bandwidth bound rather than compute bound, so tensor core kernels will likely not have much of an advantage there.

With the GEMM-based kernels, memory alignment and the explicit gather/scatter hurt performance a lot, but the biggest factor is the software predication required for the multi-modal layouts that we rely on.

With FNA, software predication has been the biggest contributor to latency over the baseline; our 1D kernels basically match FMHA in most cases, but fall further short in 2D and 3D. This is something we can likely optimize more, but probably not by much. On newer architectures like Hopper this won't be as bad, because we can do hardware predication.
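A simplified model of what software predication means here: when a tile hangs off the edge of the token grid, every element needs an in-bounds guard, and in N-D layouts that guard costs one comparison per axis (plus, on a GPU, the div/mod needed to recover each element's coordinates). The tile shapes and grid extents below are hypothetical, and `predicate_mask` only counts guards; it is not NATTEN's kernel code.

```python
from itertools import product

def predicate_mask(tile_origin, tile_shape, extent):
    """In-bounds flag for every element of one tile over an N-D token grid.

    Each element costs one bounds comparison per axis, so the guard work
    per element grows with the number of dimensions.
    """
    mask = []
    for offset in product(*(range(s) for s in tile_shape)):
        coord = [o + d for o, d in zip(tile_origin, offset)]
        mask.append(all(c < e for c, e in zip(coord, extent)))
    return mask

# Hypothetical 64-element tiles that hang off the edge of the token grid.
print(sum(predicate_mask((96,), (64,), (100,))))                 # 1D: 4 valid
print(sum(predicate_mask((8, 8), (8, 8), (10, 10))))             # 2D: 4 valid
print(sum(predicate_mask((0, 8, 8), (4, 4, 4), (16, 10, 10))))   # 3D: 16 valid
```

In 1D the valid part of a tile is a contiguous prefix, so a single bound suffices; in 2D and 3D the valid region is a sub-box, which is why the per-element guards add up there.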

And so many more factors that I'm just not actively thinking about but do matter.

pwangcs commented 1 month ago

Sorry, after more thorough testing, I find that unfused NATTEN is as fast as Swin for FP32 2D input. Also, na3d_qk and na3d_av seem slower than na2d_qk and na2d_av for the same number of tokens and almost the same kernel size. Just my experiment.