SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Memory bottleneck of NATTEN #69

Closed confucianism72 closed 7 months ago

confucianism72 commented 10 months ago

I am testing the memory usage of NATTEN. Here is a memory snapshot visualized with torch.cuda.memory:

[Screenshot 2023-11-28 18:14:55: memory snapshot]

My code is below:

import torch
from natten import NeighborhoodAttention2D

# Example sizes; the original post does not specify N, H, W, D, or device.
N, H, W, D = 2, 56, 56, 128
device = torch.device("cuda")

model = NeighborhoodAttention2D(dim=D, kernel_size=25, dilation=1, num_heads=4).to(device)
optimizer = torch.optim.Adam(model.parameters())

def train(model, optimizer):
    # NATTEN modules take channels-last inputs of shape (N, H, W, D).
    fake_input = torch.randn(N, H, W, D, dtype=torch.float32, device=device)

    out = model(fake_input)

    # Dummy quadratic loss so that backward runs through the attention module.
    loss = out ** 2

    loss.sum().backward()
    optimizer.step()
    optimizer.zero_grad()
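
For reference, here is a minimal sketch of how I check peak memory around a single training step, reusing the assumed sizes above; torch.cuda.max_memory_allocated reports the allocator's high-water mark since the last reset:

# Peak-memory check around one training step (sketch).
torch.cuda.reset_peak_memory_stats(device)

train(model, optimizer)
torch.cuda.synchronize(device)

peak_bytes = torch.cuda.max_memory_allocated(device)
print(f"peak allocated: {peak_bytes / 2**20:.1f} MiB")
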
confucianism72 commented 10 months ago

What I found:

The memory bottleneck is largely caused by the explicit attention matrix and the softmax over it, in natten2d.py:

        attn = natten2dqkrpb(q, k, self.rpb, self.kernel_size, self.dilation)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
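
For scale, here is a rough estimate of that intermediate, assuming the weights are stored as one fp32 score per (batch, head, query pixel, neighbor); the sizes reuse the example configuration assumed above:

# Back-of-the-envelope size of the materialized attention weights (assumed
# layout: N x heads x H x W x K*K scores in float32). The softmax output is
# an additional tensor of the same size that autograd keeps for backward.
N, heads, H, W, K = 2, 4, 56, 56, 25
bytes_per_el = 4  # float32

neighborhood_bytes = N * heads * H * W * K * K * bytes_per_el
dense_bytes = N * heads * (H * W) ** 2 * bytes_per_el  # full self-attention, for comparison

print(f"neighborhood attn weights: {neighborhood_bytes / 2**20:.1f} MiB")
print(f"dense attn weights:        {dense_bytes / 2**20:.1f} MiB")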

You mentioned that a newer version of NATTEN is on your near-term roadmap.

I recommend considering FlashAttention, which is also available through PyTorch's standard attention operator (scaled_dot_product_attention).

As far as I know, FlashAttention never constructs the attention matrix explicitly, which saves memory.
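
For reference only (this is plain global self-attention, not neighborhood attention), a minimal sketch of that PyTorch operator with assumed, illustrative shapes:

import torch
import torch.nn.functional as F

# scaled_dot_product_attention can dispatch to a FlashAttention-style fused
# kernel on supported GPUs and never materializes the (L x L) score matrix.
N, heads, L, head_dim = 2, 4, 56 * 56, 32
q = torch.randn(N, heads, L, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)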

alihassanijr commented 10 months ago

Thank you so much for looking into this.

Yes, the memory overhead is expected, not just for Neighborhood Attention but for any BMM-style implementation of attention, because they all store attention weights in global memory.
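
To make "BMM-style" concrete, here is a schematic sketch with plain global attention and assumed shapes (not our actual kernels); every intermediate below is a full tensor in GPU global memory:

import torch

# Schematic BMM-style attention: each step materializes its result.
N, heads, L, head_dim = 2, 4, 56 * 56, 32
q = torch.randn(N, heads, L, head_dim, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

scores = q @ k.transpose(-2, -1) / head_dim ** 0.5  # (N, heads, L, L) in global memory
weights = scores.softmax(dim=-1)                    # another (N, heads, L, L), saved for backward
out = weights @ v                                   # (N, heads, L, head_dim)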

Fused kernels (like FlashAttention) are our primary goal. We will likely not continue to maintain our standard BMM-style kernels (naive, tiled, and GEMM), but they will likely be kept in NATTEN for older architectures.

alihassanijr commented 7 months ago

Fused neighborhood attention is now available (#111).