SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

circular/wrapping attention #142

Open fzimmermann89 opened 1 month ago

fzimmermann89 commented 1 month ago

Maybe it is too niche, but we would be interested in a fused circular windowed attention.

Currently, we pad q, k, v, use the fused kernel, and crop. It would help if either k, v could have different dimensions than q (avoiding some unnecessary calculations), or if the iterators within the FNA kernels could wrap instead of masking.

alihassanijr commented 1 month ago

Could you clarify what a circular windowed attention exactly means?

I think many types of attention masks that are programmable can be directly implemented through FNA just by writing a custom mask.

Different shapes for Q and KV might be tricky, though, since we assume they match all throughout FNA, especially when you factor in dilation and causal masking. Tricky but not impossible.

fzimmermann89 commented 1 month ago

In a 2D 'image', I would like the top-left pixel to also attend to the top-right, bottom-right, and bottom-left pixels, similar to what a convolution with circular padding would do. So the attention neighborhood should wrap around at the borders instead of being smaller at the edges.

So

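import torch
from natten.functional import na2d

q=torch.ones(1,5,5,1,8).cuda()
k=torch.ones(1,5,5,1,8).cuda()
v=torch.zeros(1,5,5,1,8).cuda()
v[:,0,0,:,:]=1 # set top left pixel

# mode='wrap' is the requested behavior; it is not an existing NATTEN argument
x=na2d(q,k,v,kernel_size=3, mode='wrap')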

print(x[0,:,:,0,0])
tensor([[0.1111, 0.1111, 0.0000, 0.0000, 0.1111],
        [0.1111, 0.1111, 0.0000, 0.0000, 0.1111],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1111, 0.1111, 0.0000, 0.0000, 0.1111]])
# the top-left pixel of v also influences the pixels wrapped around the opposite borders
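To pin down the intended semantics, here is a minimal pure-PyTorch reference for the wrapped 2D neighborhood. It is a naive sketch, not a NATTEN API: the name `wrapped_na2d_reference` is hypothetical, it assumes dilation 1, an odd kernel size no larger than the input, the heads-last layout used above, and a `D ** -0.5` scale. It materializes every neighbor, so it is only meant for checking small cases like the 5x5 one.

```python
import torch

def wrapped_na2d_reference(q, k, v, kernel_size):
    # q, k, v: (B, H, W, heads, D); assumes odd kernel_size <= H, W
    B, H, W, nh, D = q.shape
    r = kernel_size // 2
    scale = D ** -0.5
    ks, vs = [], []
    for di in range(-r, r + 1):
        for dj in range(-r, r + 1):
            # torch.roll with negative shifts gives rolled[i, j] == k[(i + di) % H, (j + dj) % W]
            ks.append(torch.roll(k, shifts=(-di, -dj), dims=(1, 2)))
            vs.append(torch.roll(v, shifts=(-di, -dj), dims=(1, 2)))
    kn = torch.stack(ks, dim=-2)  # (B, H, W, heads, K*K, D)
    vn = torch.stack(vs, dim=-2)
    attn = (q.unsqueeze(-2) * kn).sum(-1) * scale  # (B, H, W, heads, K*K)
    attn = attn.softmax(dim=-1)
    return (attn.unsqueeze(-1) * vn).sum(-2)  # (B, H, W, heads, D)
```

With `q = k = ones` every score is equal, so each output is the mean of its 9 wrapped neighbors of `v`, which reproduces the 0.1111 pattern printed above.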

This can currently be achieved by first padding q,k and v, performing the NA, and cropping.

I was wondering if there is a more performant implementation possible by having the iterators within the CUDA kernels wrapping, but, tbh, I haven't spent the time to understand how and where the kernels are generated and whether this change would be feasible.

alihassanijr commented 1 month ago

Thank you for explaining. This should be possible to do by modifying NATTEN kernels, specifically FNA, but doing it and saving the compute will be more challenging.

FNA masks are defined here: standard NA, causal NA.

I'd note that you'd probably need to modify two more conditions in the kernels themselves here and here -- this is where the masking actually takes place. For NA we'd only check if the key/value coordinate falls within the query coordinate's window_start and window_end, but in your case it would probably need an extra condition to see if it's one of the corner cases.
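To make the "extra condition" concrete, here are per-pair predicates in plain Python for 1D with dilation 1. `na_mask` mirrors the window_start/window_end check described above for the (non-wrapping) NA window; `wrapped_mask` is a hypothetical replacement that tests circular distance instead. These are illustrative helpers, not NATTEN's actual mask code, and assume an odd kernel_size no larger than the length.

```python
def na_window_start(i, length, kernel_size):
    # standard NA: center the window on i, then shift it inward at the borders
    return min(max(i - kernel_size // 2, 0), length - kernel_size)

def na_mask(i, j, length, kernel_size):
    # True if key j is inside query i's non-wrapping NA window
    start = na_window_start(i, length, kernel_size)
    return start <= j < start + kernel_size

def wrapped_mask(i, j, length, kernel_size):
    # True if key j is within kernel_size // 2 of query i, measured circularly
    r = kernel_size // 2
    return (j - i + r) % length < kernel_size
```

The two predicates only disagree for queries within `kernel_size // 2` of a border, so only the CTAs handling border tiles would see a different mask.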

However, as I mentioned, one big limiter in your formulation might be that it would prevent you from skipping most of the computation.

> I was wondering if there is a more performant implementation possible by having the iterators within the CUDA kernels wrapping

This could work, but it will definitely require more work. Instead of changing the NA masking and behavior, you would have the kernels pad on the fly and work on the padded problem, as in your current approach, while saving on the extra ops and activations.

The main GEMM iterators for FNA are here and [here](https://github.com/SHI-Labs/NATTEN/blob/main/csrc/include/natten/cuda/fna/iterators/predicated_tile_access_iterator_residual_last.

Another challenge here is that if you don't pad Q and only pad KV, then many assumptions throughout FNA (mostly about Q and KV having the same spatial size) break and would need adjustment. If you pad Q and KV together, then you'd need to modify the epilogue iterator here and avoid storing outputs for the implicitly padded pixels.

But the more significant challenge, in my opinion, is the backward pass. Any non-static padding in attention (NA is actually one of them) will affect the backward pass, and that won't be solved by custom iterators. In the forward pass it's always simple: tokens/pixels can just attend to anything anywhere without bookkeeping. In the backward pass, though, KV tokens/pixels need to be multiplied with the queries that attended to them in the forward pass, and that makes everything more complicated.
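The backward-pass bookkeeping becomes visible when the neighborhoods are inverted: for each key, which queries attended to it? This plain-Python sketch (1D, dilation 1, hypothetical helper names) shows that with standard NA the number of attending queries varies near the borders, while with wrapping every key is attended by exactly kernel_size queries, and that set itself wraps around.

```python
def forward_neighborhood(i, length, kernel_size, wrap):
    # keys that query i attends to in the forward pass
    r = kernel_size // 2
    if wrap:
        return [(i + o) % length for o in range(-r, r + 1)]
    start = min(max(i - r, 0), length - kernel_size)
    return list(range(start, start + kernel_size))

def inverse_neighborhood(j, length, kernel_size, wrap):
    # queries whose forward window contains key j: the set needed for dK/dV of key j
    return [i for i in range(length) if j in forward_neighborhood(i, length, kernel_size, wrap)]
```

For example, with length 7 and kernel size 3, key 0 is attended by queries 0, 1 and 6 under wrapping, so in this sketch the queries a border key needs in the backward pass are not contiguous.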

I realize this might be a lot of information that doesn't necessarily tell you where to start, but at this point I'd only say that implementing this idea might require quite a lot of work regardless of the approach. I would guess that modifying the attention mask is always the easiest solution, but it won't save you any compute.

In theory you should be able to modify most fused attention kernels and implement your own custom mask if you don't care about saving compute that much (and to be honest, unless your attention window size is significantly smaller than your input size, or you dilate your windows, you probably wouldn't save much on runtime anyway even if you save on compute).

In either case, I can definitely give feedback, and probably better feedback if you're doing it within FNA or FMHA since I'm familiar with those the most.

fzimmermann89 commented 1 month ago

Thank you for the detailed response, I will try to understand the iterator logic :)

A bit more background information: we are working with 3D data, ~15-20 patches in each dim. At window size 5, this would require us to pad by 4 patches in each dim, resulting in more than twice the compute.
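As a rough check of that estimate: padding by `kernel_size // 2` on each side grows a dimension of n patches to n + kernel_size - 1, and since every padded query still does a full window of work, compute scales with the number of query positions. The helper below (hypothetical name; it ignores tiling and any skipped work) computes the ratio when some of the dimensions are padded.

```python
def padded_compute_ratio(n, kernel_size, padded_dims):
    # each padded dim grows from n to n + 2 * (kernel_size // 2) = n + kernel_size - 1
    # (odd kernel_size); unpadded dims are unchanged, so they cancel in the ratio
    return ((n + kernel_size - 1) / n) ** padded_dims

for n in (15, 20):
    print(n, [round(padded_compute_ratio(n, 5, d), 2) for d in (1, 2, 3)])
```

With n = 15 and all three dims padded this gives about 2.03x, and about 1.73x at n = 20; padding only one or two dims reduces the overhead further.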

In our domain (MR images), it is somewhat important to use wrapping in at least 1-2 dimensions, as the data is wrapped/looping in these dimensions.

So, modifying the illustrations from your paper in Paint, we would like to achieve this [image] for the edge cases, instead of stopping the sliding window at the edge.

For 1D, I tested whether wrapping iterators would be enough for both the forward and backward pass with this simple test:

import torch
from natten.autotuner import autotune_fna
from natten.utils import check_all_args
from natten import libnatten

class WrappedFusedNeighborhoodAttention1D(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        query,
        key,
        value,
        kernel_size_,
        dilation_,
        scale: float,
    ):
        kernel_size, dilation, _ = check_all_args(
            1, kernel_size_, dilation_, False)

        tiling_config, tiling_config_backward = autotune_fna( 1, query, kernel_size, dilation, False)
        pad = kernel_size[0]//2 * dilation[0]

        # circular padding of qkv, equivalent to using wrapping iterator for reading
        query_pad, key_pad,value_pad = (torch.nn.functional.pad(x,(0,0,0,0,pad,pad),'circular').float().contiguous() for x in (query,key,value))

        output_pad = torch.empty_like(value_pad)
        logsumexp = torch.empty(
            query_pad.shape[:-1], dtype=torch.float32, device=query.device
        )

        libnatten.na1d_forward(
            output_pad,
            query_pad,
            key_pad,
            value_pad,
            None,
            logsumexp,
            kernel_size,
            dilation,
            (False,),
            scale,
            *tiling_config,
        )

        # cropping of output to remove the special cases done inside libnatten
        output = output_pad[...,pad:-pad,:,:].clone()

        ctx.save_for_backward(query, key, value, logsumexp, output)
        ctx.kernel_size = kernel_size
        ctx.dilation = dilation
        ctx.scale = scale
        ctx.tiling_config = tiling_config
        ctx.tiling_config_backward = tiling_config_backward
        ctx.pad = pad

        return output

    @staticmethod
    def backward(ctx, grad_out):
        query, key, value, logsumexp, output = ctx.saved_tensors
        pad = ctx.pad

        # padding of qkv: equivalent to using wrapping iterator to read
        query_pad, key_pad, value_pad, = (torch.nn.functional.pad(x,(0,0,0,0,pad,pad),'circular').contiguous().float() for x in (query,key,value,))

        # zero-pad grad_out so the padded (duplicated) queries contribute nothing to d_key/d_value;
        # with a zero gradient there, the padded output values do not matter.
        d_output,output_pad = (torch.nn.functional.pad(x,(0,0,0,0,pad,pad),'constant').contiguous().float() for x in (grad_out,output))

        d_query = torch.empty_like(query_pad)
        d_key = torch.empty_like(key_pad)
        d_value = torch.empty_like(value_pad)

        q_tile_shape, k_tile_shape, kv_splits, compute_delta_with_pt = (
            ctx.tiling_config_backward
        )
        libnatten.na1d_backward(
            d_query,
            d_key,
            d_value,
            query_pad,
            key_pad,
            value_pad,
            output_pad,
            d_output,
            logsumexp,
            ctx.kernel_size,
            ctx.dilation,
            (False,),
            ctx.scale,
            q_tile_shape,
            k_tile_shape,
            kv_splits,
            compute_delta_with_pt,
        )

        # this corresponds to a wrapping write to d_query, d_key, d_value
        def wrapped_output(x):
            result = x[...,pad:-pad,:,:]
            result[...,:pad,:,:]+=x[...,-pad:,:,:]
            result[...,-pad:,:,:]+=x[...,:pad,:,:]
            return result
        # one gradient per forward input: query, key, value, kernel_size_, dilation_, scale
        return *(wrapped_output(x) for x in (d_query,d_key,d_value)), None, None, None

## Grad Check
Q=K=V=10
D=4
B=2
H=2

# inputs are float32 (forward casts to float anyway), so gradcheck needs a large eps and loose rtol
q=(torch.rand(B,Q,H,D,device="cuda")*10-5).requires_grad_(True)
k=(torch.rand(B,K,H,D,device="cuda")*10-5).requires_grad_(True)
v=(torch.rand(B,V,H,D,device="cuda")*10-5).requires_grad_(True)

def wrapped1d(q,k,v):
    return WrappedFusedNeighborhoodAttention1D.apply(q,k,v,3,1,1)

torch.autograd.gradcheck(wrapped1d,(q,k,v),eps=1e-1,rtol=1e-2,fast_mode=True)

So both the reads of query, key, and value and the writes to d_query, d_key, and d_value should wrap.

fzimmermann89 commented 1 month ago

Unfortunately, I am quite lost in the code.

Do you think it would be easier to maybe somehow add an option to just skip the computation of the edges within libnatten? So instead of clamping the sliding window to the edges, only compute the "steady-state" where the full window is inside the volume?

Then the padding, wrapping and unwrapping could be done in a few Python instructions, without wasting many computations for the edge case handling within the kernels that will be discarded anyway?

alihassanijr commented 4 weeks ago

Sorry for the late response.

> Do you think it would be easier to maybe somehow add an option to just skip the computation of the edges within libnatten? So instead of clamping the sliding window to the edges, only compute the "steady-state" where the full window is inside the volume?
>
> Then the padding, wrapping and unwrapping could be done in a few Python instructions, without wasting many computations for the edge case handling within the kernels that will be discarded anyway?

I see -- so you're saying that if NATTEN could accept cases where Q is smaller than KV because KV alone is padded, then you'd save on padding Q and cropping the final output?

I think that can be done -- it's just that I still think modifying the mask might end up giving you a faster solution. That's because explicit padding to save on compute has rarely been a performant solution in my experience, at least in 2D and higher-rank spaces. You're essentially occupying more resources and spending more energy to save compute, and that cost can easily exceed the compute (and time) you'd save.

I gave it another thought, and I think modifying the FNA mask could actually work best in your case. CTAs targeting corner tokens/pixels will inevitably compute their entire row of attention weights, but you could still gain some of the efficiency back. But what remains true is that you would never explicitly pad any of the tensors.

I might be misunderstanding your method though, so just to clarify, are you saying that if NATTEN could skip the corner cases where the query isn't exactly centered and just compute a smaller output, then you can avoid padding all together, or you would pad fewer things?

fzimmermann89 commented 3 weeks ago

Sorry for the late response.

> Do you think it would be easier to maybe somehow add an option to just skip the computation of the edges within libnatten? So instead of clamping the sliding window to the edges, only compute the "steady-state" where the full window is inside the volume? Then the padding, wrapping and unwrapping could be done in a few Python instructions, without wasting many computations for the edge case handling within the kernels that will be discarded anyway?

> I see -- so you're saying that if NATTEN could accept cases where Q is smaller than KV because KV alone is padded, then you'd save on padding Q and cropping the final output?

And NATTEN could skip calculating the attention output for all the padded Qs, which currently can make up half of the compute in 2D and 3D.

> I think that can be done -- it's just that I still think modifying the mask might end up giving you a faster solution. That's because explicit padding to save on compute has rarely been a performant solution in my experience, at least in 2D and higher-rank spaces. You're essentially occupying more resources and spending more energy to save compute, and that cost can easily exceed the compute (and time) you'd save.

I'm sorry, I don't really get this, most likely because I still haven't understood how FNA actually operates. I was under the assumption that the compute is done blockwise, so while the output for a certain Q is calculated, not all possible K/V locations are available in the kernel, and the mask is only applied within this block. So I fail to see how to implement a mask like the one in the image above, where an edge query can attend to KV on the opposite side. Is my mental picture wrong?

> I gave it another thought, and I think modifying the FNA mask could actually work best in your case. CTAs targeting corner tokens/pixels will inevitably compute their entire row of attention weights, but you could still gain some of the efficiency back. But what remains true is that you would never explicitly pad any of the tensors.

In 2D and 3D, a lot of pixels are corner cases, so a lot of CTAs would have to do full attention?

> I might be misunderstanding your method though, so just to clarify, are you saying that if NATTEN could skip the corner cases where the query isn't exactly centered and just compute a smaller output, then you can avoid padding all together, or you would pad fewer things?

If it also skips on compute and never writes outside the "valid" region of Q, it would mean I only have to pad KV.

But you have much more experience with both neighborhood attention and CUDA in general, so just to clarify what the goal is:

I want to have 2D and 3D neighborhood attention. For some of the directions (2D: 1, 3D: 2-3), depending on the exact problem, the Qs should attend to KVs by wrapping around. (For the other directions, whatever behavior is fastest can be used.)

For now, I would use the approach above, although this results in overhead from both padding and compute. Otherwise, I would need quite a bit of help getting this to work more efficiently.

alihassanijr commented 3 weeks ago

> I'm sorry, I don't really get this, most likely because I still haven't understood how FNA actually operates. I was under the assumption that the compute is done blockwise, so while the output for a certain Q is calculated, not all possible K/V locations are available in the kernel, and the mask is only applied within this block. So I fail to see how to implement a mask like the one in the image above, where an edge query can attend to KV on the opposite side. Is my mental picture wrong?

No, you are correct. Q and the attention output are tiled together, and K and V are tiled together, so in both the forward and backward pass you're dealing with Q/O tiles and K/V tiles. In the forward pass, you typically parallelize across Q/O, meaning every CTA loads a certain Q tile, loops through all valid K/V tiles, and stores the O tile corresponding to the Q tile. In the backward pass it's more complicated, but usually the opposite.

> In 2D and 3D, a lot of pixels are corner cases, so a lot of CTAs would have to do full attention?

Kind of. It is possible to predict which KV tiles will be fully masked, so you can skip them, but it's a more complicated change.
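To illustrate the tiling point, here is a plain-Python sketch (1D, dilation 1, hypothetical helper name and tile size) that lists which KV tiles contain at least one unmasked key for a given Q tile. With wrapping, the first Q tile also needs the last KV tile, so in this model a CTA reaches KV on the opposite side simply by including that tile in its loop, while the remaining KV tiles are fully masked and could in principle be skipped.

```python
def kv_tiles_visited(q_tile, tile_size, length, kernel_size, wrap):
    # KV tile indices containing at least one unmasked key for this Q tile
    r = kernel_size // 2
    tiles = set()
    for i in range(q_tile * tile_size, min((q_tile + 1) * tile_size, length)):
        if wrap:
            keys = [(i + o) % length for o in range(-r, r + 1)]
        else:
            start = min(max(i - r, 0), length - kernel_size)
            keys = range(start, start + kernel_size)
        tiles.update(j // tile_size for j in keys)
    return sorted(tiles)
```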

What I would note, though, is that it is an illusion that all of the time spent is spent on compute. That is almost never the case. The kernel has to load and store bytes, which is time spent on data movement. And there are also other non-compute instructions, which at higher levels are just the indexing, setting up shared memory, softmax, etc.

Therefore, reducing the number of KV iterations in the forward pass (and Q iterations in the backward pass), or rather, the number of KV tiles visited for a given query tile, isn't necessarily going to give you a faster kernel.

Because of this, my suggestion is to not try to save any of the computation (unless you're dealing with really large problem sizes) and just implement this pattern as an implicit mask. The reason is that, even if we manage to figure out a way to save compute, I suspect the overhead from the extra bounds checking might undo the compute time you save, in addition to the extra time spent implementing it.

> I want to have 2D and 3D neighborhood attention. For some of the directions (2D: 1, 3D: 2-3), depending on the exact problem, the Qs should attend to KVs by wrapping around. (For the other directions, whatever behavior is fastest can be used.)

> For now, I would use the approach above, although this results in overhead from both padding and compute. Otherwise, I would need quite a bit of help getting this to work more efficiently.

Yeah, so unfortunately I might not have time to help implement this specific mask in the near future, but I think you might be able to modify FNA directly to get your desired pattern. Some of these changes are easier to implement but end up doing more computation, so you might get similar, if not slightly worse, overall performance than pad and crop (even though you'll definitely save on memory because of fewer activations).

So to summarize, here's some of those ideas that could work, in order of time required, from least time consuming to most:

  1. Modifying the "mask" in FNA so it performs your operation instead

    • Quick to implement
    • Hard to optimize, because dimensions that wrap around will essentially skip no computation
  2. Doing 1. and modifying the FNA kernels themselves to implement your idea

    • Much more time consuming to implement, but I can still help you along the way
    • Will be able to skip computation in the same way FNA does for NA, so perf will probably be similar to standard NA.
    • This will be the best for perf if you're dealing with large problem sizes (if the number of compute blocks you skip over is a noticeable fraction of the total.)
  3. Doing 1. in something already faster like FAv2

    • Giving up on saving (most) of the compute, but instead using an attention kernel that can already achieve better perf than FNA, and writing an implicit mask for it.
    • I might not be able to help as much because I'm less familiar with it.