SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Registers, self-cross attention #82

Closed Birch-san closed 8 months ago

Birch-san commented 10 months ago

Sometimes a query needs to attend to (a small quantity of) keys outside its local neighbourhood! Even outside its modality!

Vision Transformers Need Registers shows us that it can be helpful to include in our sequence a small number of registers, for depositing/withdrawing global information (see DINOv2 source).


DINOv2 also splices a CLS token into the key sequence, alongside all the vision tokens! This token has no spatial coordinates and isn't even part of the vision modality, so it isn't compatible with neighbourhood attention.

Imagen introduced self-cross attention, where queries may self-attend or cross-attend, as the key sequence consists of both vision-mode and text-mode tokens, concatenated.

We implement the cross attention by concatenating the text embedding sequence to the key-value pairs of each self-attention layer in the base 64 × 64 and 64 × 64 → 256 × 256 models.
We found this to improve both fidelity and image-text alignment with minimal computational costs.
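
A minimal sketch of self-cross attention in plain PyTorch (assumes PyTorch 2.x for F.scaled_dot_product_attention; the function name and shapes are illustrative, not Imagen's or DeepFloyd IF's code). The text tokens are simply concatenated onto the keys and values, so a single softmax covers both modalities.

```python
import torch
import torch.nn.functional as F

def self_cross_attention(q, k, v, text_k, text_v):
    """Hypothetical helper: Imagen-style self-cross attention.

    q, k, v:        [B, heads, N_img, head_dim]  vision tokens
    text_k, text_v: [B, heads, N_txt, head_dim]  projected text tokens
    """
    # Concatenate text tokens onto the key/value sequence, so every query
    # attends jointly over image keys and text keys.
    k_all = torch.cat([k, text_k], dim=-2)
    v_all = torch.cat([v, text_v], dim=-2)
    return F.scaled_dot_product_attention(q, k_all, v_all)

B, heads, N_img, N_txt, D = 2, 4, 64, 8, 32
q, k, v = (torch.randn(B, heads, N_img, D) for _ in range(3))
text_k, text_v = (torch.randn(B, heads, N_txt, D) for _ in range(2))
out = self_cross_attention(q, k, v, text_k, text_v)
print(out.shape)  # torch.Size([2, 4, 64, 32])
```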

DeepFloyd IF implements self-cross attention too.

Attention Is Off By One (aka MultiheadAttention's add_zero_attn=True) shows us that it can be useful to give the attention head a way to attend to "nothing".
Language models achieve this by including a BOS token (i.e. a token with no predictive power) in all sequences.
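
To make "attend to nothing" concrete, here's a sketch (function names are hypothetical) showing that appending one all-zero key and value, in the spirit of add_zero_attn=True, gives the same output as a softmax with an extra 1 in its denominator. The zero key always scores 0, so it soaks up probability mass, and the zero value adds nothing to the output.

```python
import torch

def attention_with_zero_kv(q, k, v):
    """Append one all-zero key and one all-zero value, then do ordinary attention."""
    zeros = q.new_zeros(*k.shape[:-2], 1, k.shape[-1])
    k0 = torch.cat([k, zeros], dim=-2)
    v0 = torch.cat([v, zeros], dim=-2)
    scale = q.shape[-1] ** -0.5
    probs = torch.softmax(q @ k0.transpose(-2, -1) * scale, dim=-1)
    return probs @ v0

def softmax1_attention(q, k, v):
    """Same result, written as a softmax with +1 (i.e. exp(0)) in the denominator."""
    scale = q.shape[-1] ** -0.5
    logits = q @ k.transpose(-2, -1) * scale
    # Stabilize with a max that includes the implicit zero logit.
    m = logits.amax(dim=-1, keepdim=True).clamp(min=0)
    e = torch.exp(logits - m)
    probs = e / (e.sum(dim=-1, keepdim=True) + torch.exp(-m))
    return probs @ v

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 10, 16) for _ in range(3))
print((attention_with_zero_kv(q, k, v) - softmax1_attention(q, k, v)).abs().max())
```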

====

So, the ask is this:

Would it be possible for natten2dqk to support passing in additional key tokens, to which all queries would attend?
And likewise for natten2dav to support passing in additional values corresponding to those keys.

With pure-PyTorch I think it might look like this:
https://github.com/Birch-san/natten-fwd-ad/commit/be295395a5c537da9070c12a4f1f7d3299b3639b
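
For readers without the linked commit at hand, a rough sketch of a mask-based approach in pure PyTorch: build a dense neighbourhood mask, append an all-True block for the registers, concatenate the register keys/values, and call F.scaled_dot_product_attention. The border rule in window_starts is my assumption of NATTEN's behaviour without dilation, and all helper names are hypothetical. The dense [H*W, H*W + R] mask makes this a correctness reference only, not a fast path.

```python
import torch
import torch.nn.functional as F

def window_starts(length: int, kernel_size: int) -> torch.Tensor:
    # Assumed NATTEN-style borders (no dilation, kernel_size <= length): the window
    # is centred on the query but shifted inward at the edges, so every query
    # sees exactly kernel_size keys per axis.
    idx = torch.arange(length)
    return (idx - kernel_size // 2).clamp(0, length - kernel_size)

def neighborhood_mask_1d(length: int, kernel_size: int) -> torch.Tensor:
    start = window_starts(length, kernel_size)[:, None]  # [L, 1]
    keys = torch.arange(length)[None, :]                 # [1, L]
    return (keys >= start) & (keys < start + kernel_size)  # [L, L] bool

def neighborhood_mask_2d(height: int, width: int, kernel_size: int) -> torch.Tensor:
    mh = neighborhood_mask_1d(height, kernel_size)  # [H, H]
    mw = neighborhood_mask_1d(width, kernel_size)   # [W, W]
    # Query (qh, qw) may attend to key (kh, kw) iff both axes are in-window.
    m = mh[:, None, :, None] & mw[None, :, None, :]  # [H, W, H, W]
    return m.reshape(height * width, height * width)

def na2d_with_registers(q, k, v, reg_k, reg_v, kernel_size: int):
    """q, k, v: [B, heads, H, W, D]; reg_k, reg_v: [B or 1, heads, R, D]."""
    B, nh, H, W, D = q.shape
    q, k, v = (t.reshape(B, nh, H * W, D) for t in (q, k, v))
    reg_k = reg_k.expand(B, -1, -1, -1)
    reg_v = reg_v.expand(B, -1, -1, -1)
    na_mask = neighborhood_mask_2d(H, W, kernel_size).to(q.device)
    # Every query may attend to every register.
    reg_mask = torch.ones(H * W, reg_k.shape[-2], dtype=torch.bool, device=q.device)
    mask = torch.cat([na_mask, reg_mask], dim=-1)  # [H*W, H*W + R]
    out = F.scaled_dot_product_attention(
        q,
        torch.cat([k, reg_k], dim=-2),
        torch.cat([v, reg_v], dim=-2),
        attn_mask=mask,  # boolean: True means "may attend"
    )
    return out.reshape(B, nh, H, W, D)
```

A handy sanity check: with kernel_size equal to the spatial size, the neighbourhood covers the whole image, so the result should match unmasked attention over all tokens plus registers.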

The use-cases would be things like:

Oh, and thanks for your responses to previous feature requests. 🙂

alihassanijr commented 10 months ago

Thank you for bringing this up. This would be a useful feature that I'd want included in NATTEN.

Cross-attending between queries and a custom key-value pair (registers, BOS, CLIP tokens, IP Adapter tokens, etc.) is just a BMM, and I would guess it would be more performant if done in a separate kernel call. Adding an extra KV set to the naive kernels is easy, but not so for the implicit GEMM kernels, since their logic is heavily tied to the assumptions made about the problem.

I can have it done with a standard CUTLASS GEMM/GEMV, or even just a PyTorch BMM (I'd expect PyTorch to be more performant, because NATTEN does not have a profiler, and adding one would be counter-intuitive unless it were specifically intended for these kernels).

Since softmax is done outside of NATTEN, we can concatenate the two resulting attention-weight tensors, apply softmax, and split them, then proceed to do the standard AV operation on the scaled neighborhood attention weights, and another BMM on the cross-attention weights.
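
A sketch of the concat-and-split route (hypothetical helper names; the neighbourhood AV step, which would go through NATTEN, is left out): the joint softmax runs on a concatenated tensor, which is then split back into two contiguous tensors, and the extra-KV half of AV is a plain batched matmul.

```python
import torch

def concat_softmax_split(attn_na, attn_extra):
    """Joint softmax over NA weights and extra-KV weights via explicit concat/split.

    attn_na:    [B, heads, N, K]  (K = kernel_size ** 2 for 2D NA)
    attn_extra: [B, heads, N, R]  (R = number of additional KV tokens)
    """
    K = attn_na.shape[-1]
    probs = torch.softmax(torch.cat([attn_na, attn_extra], dim=-1), dim=-1)
    # Split back; .contiguous() because the kernels discussed here expect
    # contiguous inputs. This extra copy is the IO cost being worried about.
    return probs[..., :K].contiguous(), probs[..., K:].contiguous()

def extra_kv_branch(p_extra, extra_v):
    # Cross-attention half of AV: [B, heads, N, R] @ [B, heads, R, D] -> [B, heads, N, D]
    return p_extra @ extra_v
```

The full output would then be the NA AV result plus extra_kv_branch(p_extra, extra_v).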

The only thing I don't like about this solution is the extra concat and split; and because NATTEN kernels assume contiguous inputs and outputs, we can't really play around with views.

But what we might be able to do is write a custom softmax operator that takes in two input tensors instead of one, and computes the statistics as if the two were concatenated. At that point, it can scale them in place and we'll never have to concat or split the attention weights.
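
One way the two-input softmax could look, sketched in PyTorch rather than as a kernel (joint_softmax is a hypothetical name): take the row max and the sum of exponentials across both inputs, then normalize each input separately. The result matches a softmax over the concatenation without ever building it.

```python
import torch

def joint_softmax(x1: torch.Tensor, x2: torch.Tensor):
    """Softmax over the last dim of cat([x1, x2], -1), without building the concat.

    x1: [..., K] (e.g. NA weights), x2: [..., R] (extra-KV weights).
    """
    # Row statistics computed across both inputs, as if they were concatenated.
    m = torch.maximum(x1.amax(dim=-1, keepdim=True), x2.amax(dim=-1, keepdim=True))
    e1 = torch.exp(x1 - m)
    e2 = torch.exp(x2 - m)
    denom = e1.sum(dim=-1, keepdim=True) + e2.sum(dim=-1, keepdim=True)
    # Normalize each input separately; a fused kernel could write these in place,
    # whereas this sketch allocates new outputs.
    return e1 / denom, e2 / denom
```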

When we eventually add fused kernels, this will no longer be possible, but we'll cross that bridge when we get to it.

Let me know if you have any thoughts; I'll start looking for a good softmax/reduction reference kernel.

Birch-san commented 10 months ago

ah yeah, I had wondered about a concat-and-split approach, but my first reaction was "hmm seems like a lot of IO, and we know that's the worst part of attention".

but a custom softmax could fix that, yeah. that sounds like a good solution. not immediately sure how to compute the jvp of it though, so might need to follow up later to restore forward-mode autodiff support.
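
On the JVP question: the tangent rule for a softmax along the last dim is dp = p * (t - sum(p * t)), and with two inputs the only coupling is that the inner sum runs over both halves. A sketch (hypothetical names; torch.func.jvp needs PyTorch 2.x) that takes the saved forward outputs, as a custom autograd function would, and checks the rule against forward-mode AD through a plain concat-softmax-split:

```python
import torch

def joint_softmax_jvp(p1, p2, t1, t2):
    """Forward-mode derivative of a joint softmax over two inputs.

    p1, p2: forward outputs (softmax of cat([x1, x2], -1), split back apart)
    t1, t2: tangents of the inputs x1, x2
    """
    # The inner sum spans both halves; that is the only cross-term.
    s = (p1 * t1).sum(dim=-1, keepdim=True) + (p2 * t2).sum(dim=-1, keepdim=True)
    return p1 * (t1 - s), p2 * (t2 - s)

# Reference: forward-mode AD through an explicit concat-softmax-split.
K, R = 9, 4

def concat_softmax_split(x1, x2):
    p = torch.softmax(torch.cat([x1, x2], dim=-1), dim=-1)
    return p[..., :K], p[..., K:]

torch.manual_seed(0)
x1, t1 = (torch.randn(2, 3, 5, K, dtype=torch.float64) for _ in range(2))
x2, t2 = (torch.randn(2, 3, 5, R, dtype=torch.float64) for _ in range(2))
(p1, p2), (ad1, ad2) = torch.func.jvp(concat_softmax_split, (x1, x2), (t1, t2))
d1, d2 = joint_softmax_jvp(p1, p2, t1, t2)
print(torch.allclose(d1, ad1), torch.allclose(d2, ad2))
```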

would the custom softmax output two attn probabilities (one for the neighbourhood, one for the bonus)? the user would pass the neighbourhood to natten2dav as usual, and the bonus probabilities to a normal BMM?
or are you saying the custom softmax would be even more custom (and would do the BMM too)? this might make it harder to derive the backward pass and jvp. it might also be hard to achieve the same matmul performance as PyTorch's BMM.

I guess generally: any time we can solve the problem with existing pytorch operations without a performance cost, re-using PyTorch bits should probably be our preference, in order to piggyback on their existing support for things like jvp, vjp and nested tensor (and any optimizations they've done). it's fine for the user to have to write more lines of code, so I wouldn't consider API to be a strong enough reason on its own to fuse operations.

alihassanijr commented 10 months ago

That's a good point; we'd have to figure out the JVP on top of a backwards kernel.

The custom softmax idea I had was to do the scaling in place, meaning row-wise reductions first, then scaling the input by the resulting softmax statistics. But this could be more difficult to implement. The alternative is to take two tensors as input and return two as output, which I think will still work fine, and still be way better than concat and split: in both of these cases the only extra work is allocating the extra tensors, and because of torch's memory planning, we're unlikely to feel much impact from that.

Alignment would be another challenge, given that the two input/output tensors will almost never be more than 1-aligned, at least the NA weights: their last dimension is always a power of the kernel size, and kernel sizes in NA are odd by definition, so it's never even a multiple of two, let alone a multiple of a higher power of two. I'll keep thinking about this. It might be worth modifying all the kernels in NATTEN so that they can handle attention weight tensors that are not contiguous. If we get past that, we can just do a standard torch softmax without a concat/split (at least not an explicit one). And if we're lucky and kernel_size ** 2 + num_extra_kv_tokens is a power of two, the softmax could even run faster than without any extra tokens.
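
A quick arithmetic sketch of the alignment point: an odd kernel size gives an odd kernel_size ** d, so the NA weights' last dim has alignment 1, while adding the right number of extra tokens can land on a power of two. The cap of 8 elements is an assumption (e.g. 8 fp16 values = 16 bytes) standing in for whatever vector width a given kernel actually uses; max_alignment is a hypothetical helper.

```python
def max_alignment(n: int, cap: int = 8) -> int:
    """Largest power of two <= cap that divides n (hypothetical helper)."""
    a = 1
    while a * 2 <= cap and n % (a * 2) == 0:
        a *= 2
    return a

# kernel_size ** 2 alone vs. kernel_size ** 2 + num_extra_kv_tokens
for kernel_size, num_extra in [(7, 0), (7, 15), (13, 0), (13, 87)]:
    n = kernel_size ** 2 + num_extra
    print(kernel_size, num_extra, n, max_alignment(n))
```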

I don't think we'd benefit much from a softmax+GEMM; we'd still have to profile the GEMM to get anywhere near the perf of a torch GEMM, and NATTEN isn't really set up to do profiling.

I guess generally: any time we can solve the problem with existing pytorch operations without a performance cost, re-using PyTorch bits should probably be our preference, in order to piggyback on their existing support for things like jvp, vjp and nested tensor (and any optimizations they've done). it's fine for the user to have to write more lines of code, so I wouldn't consider API to be a strong enough reason on its own to fuse operations.

Yes; that's fair to say. PyTorch has pretty strong profiling, and as long as the ops we're interested in are implemented directly in cuBLAS, cuDNN, or any other highly optimized CUDA libraries that torch would have a backend for, going with torch would probably be the best choice.

alihassanijr commented 10 months ago

This will probably be easier than I thought. We don't need a custom softmax or anything like that. Changing all the existing kernels to support attention weights that are views is much easier.
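
A sketch of what "attention weights that are views" can mean in practice (plain PyTorch, with copy_ standing in for the QK kernel and the BMM; this does not show NATTEN's actual API): both sets of weights live side by side in one buffer, each producer fills its own strided view, and a single ordinary softmax over the buffer normalizes across both with no explicit concat or split.

```python
import torch

B, heads, N, K, R = 2, 4, 256, 49, 15  # K = 7 ** 2 neighbourhood weights, R extra tokens

# Logits each producer would compute (random stand-ins here).
na_logits = torch.randn(B, heads, N, K)
extra_logits = torch.randn(B, heads, N, R)

# One buffer holds both sets of weights side by side: [..., K + R].
attn = torch.empty(B, heads, N, K + R)
attn_na = attn[..., :K]     # strided view: row stride K + R, not contiguous
attn_extra = attn[..., K:]  # view into the same storage, offset by K elements

# Each producer writes into its own view.
attn_na.copy_(na_logits)
attn_extra.copy_(extra_logits)

# One ordinary softmax over the buffer normalizes across both parts.
probs = attn.softmax(dim=-1)
p_na, p_extra = probs[..., :K], probs[..., K:]  # views again, ready for the AV step
```

Here K + R = 64, which is the lucky power-of-two case mentioned earlier in the thread.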

I have a proof of concept working for QK forward and backward. I'll finish AV, and then add support for JVP, and nested tensors (as long as torch.bmm supports it, we should be fine).

Birch-san commented 10 months ago

brilliant! 😄 hope it goes smoothly. will be interesting to see how close local attn + global registers comes to the performance of global attn.

alihassanijr commented 10 months ago

Okay, it's merged. On my setup I'm seeing around an 18% improvement in latency in most cases, compared to explicitly concatenating and splitting.

alihassanijr commented 10 months ago

Before I forget: you'll need to build from source again until the next release.

If you're not extremely short on disk space and can throw more CPU cores at the build, you can break the kernel instantiations apart into more source files and use more workers to build faster. That way it can rebuild everything in a few minutes, but the binary might grow to as much as double its size (~500 MB).

Birch-san commented 10 months ago

thanks! the API looks lovely. hm, I'm not getting the same result as in my pure-PyTorch registers/self-cross implementation though.

NATTEN:
https://github.com/Birch-san/natten-fwd-ad/commit/c135012a3fa8cafee6aa67823e2cd27934552dbd

pure-PyTorch:
https://github.com/Birch-san/natten-fwd-ad/commit/be295395a5c537da9070c12a4f1f7d3299b3639b

test script:
https://github.com/Birch-san/natten-fwd-ad/blob/registers/script/demo_registers.py

I guess I could try concat-and-split and see if it matches either of those.

alihassanijr commented 10 months ago

Did you happen to run the NATTEN unit tests on your end? If they pass, then it must be something that's implemented differently. I'm looking at your commit right now, and it looks okay, but there might be something I'm missing.

The only thing that I'm finding different is that you're repeating the registers along the batch dim, which should be fine, given that the registers are fed directly into torch BMMs, which work with views and will just copy the tensor if necessary.

But just to clarify, reg_k and reg_v are 4-D tensors, correct? (i.e. [B, heads, seq_len, dim]).

Birch-san commented 10 months ago

I tried make test just now. some XFAILs, not sure how usual that is:

(venv-next) birch@tree-diagram:~/git/NATTEN (main)$ make test
pytest -v -x ./tests
============================================================ test session starts =============================================================
platform linux -- Python 3.11.1, pytest-7.4.4, pluggy-1.3.0 -- /home/birch/git/natten-fwd-ad/venv-next/bin/python3.11
cachedir: .pytest_cache
rootdir: /home/birch/git/NATTEN
collected 33 items                                                                                                                           

tests/test_na1d.py::NA1DTests::test_autograd_cpu PASSED                                                                                [  3%]
tests/test_na1d.py::NA1DTests::test_autograd_cuda_gemm PASSED                                                                          [  6%]
tests/test_na1d.py::NA1DTests::test_autograd_cuda_naive PASSED                                                                         [  9%]
tests/test_na1d.py::NA1DTests::test_cpu_vs_cuda PASSED                                                                                 [ 12%]
tests/test_na1d.py::NA1DTests::test_fwad_cpu PASSED                                                                                    [ 15%]
tests/test_na1d.py::NA1DTests::test_fwad_cuda_gemm PASSED                                                                              [ 18%]
tests/test_na1d.py::NA1DTests::test_fwad_cuda_naive PASSED                                                                             [ 21%]
tests/test_na1d.py::NA1DTests::test_invalid_dilation XFAIL                                                                             [ 24%]
tests/test_na1d.py::NA1DTests::test_invalid_kernel_size XFAIL                                                                          [ 27%]
tests/test_na1d.py::NA1DTests::test_nested_forward_cpu PASSED                                                                          [ 30%]
tests/test_na1d.py::NA1DTests::test_nested_forward_cuda PASSED                                                                         [ 33%]
tests/test_na2d.py::NA2DTests::test_autograd_cpu PASSED                                                                                [ 36%]
tests/test_na2d.py::NA2DTests::test_autograd_cuda_gemm PASSED                                                                          [ 39%]
tests/test_na2d.py::NA2DTests::test_autograd_cuda_naive PASSED                                                                         [ 42%]
tests/test_na2d.py::NA2DTests::test_autograd_cuda_tiled PASSED                                                                         [ 45%]
tests/test_na2d.py::NA2DTests::test_cpu_vs_cuda PASSED                                                                                 [ 48%]
tests/test_na2d.py::NA2DTests::test_fwad_cpu PASSED                                                                                    [ 51%]
tests/test_na2d.py::NA2DTests::test_fwad_cuda_gemm PASSED                                                                              [ 54%]
tests/test_na2d.py::NA2DTests::test_fwad_cuda_naive PASSED                                                                             [ 57%]
tests/test_na2d.py::NA2DTests::test_fwad_cuda_tiled PASSED                                                                             [ 60%]
tests/test_na2d.py::NA2DTests::test_invalid_dilation XFAIL                                                                             [ 63%]
tests/test_na2d.py::NA2DTests::test_invalid_kernel_size XFAIL                                                                          [ 66%]
tests/test_na2d.py::NA2DTests::test_nested_forward_cpu PASSED                                                                          [ 69%]
tests/test_na2d.py::NA2DTests::test_nested_forward_cuda PASSED                                                                         [ 72%]
tests/test_na3d.py::NA3DTests::test_autograd_cpu PASSED                                                                                [ 75%]
tests/test_na3d.py::NA3DTests::test_autograd_cuda PASSED                                                                               [ 78%]
tests/test_na3d.py::NA3DTests::test_cpu_vs_cuda PASSED                                                                                 [ 81%]
tests/test_na3d.py::NA3DTests::test_fwad_cpu PASSED                                                                                    [ 84%]
tests/test_na3d.py::NA3DTests::test_fwad_cuda PASSED                                                                                   [ 87%]
tests/test_na3d.py::NA3DTests::test_invalid_dilation XFAIL                                                                             [ 90%]
tests/test_na3d.py::NA3DTests::test_invalid_kernel_size XFAIL                                                                          [ 93%]
tests/test_na3d.py::NA3DTests::test_nested_forward_cpu PASSED                                                                          [ 96%]
tests/test_na3d.py::NA3DTests::test_nested_forward_cuda PASSED                                                                         [100%]

============================================================== warnings summary ==============================================================
tests/test_na1d.py::NA1DTests::test_nested_forward_cpu
  /home/birch/git/NATTEN/tests/test_na1d.py:449: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
    query = torch.nested.nested_tensor(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================ 27 passed, 6 xfailed, 1 warning in 215.95s (0:03:35) ============================================

Birch-san commented 10 months ago

But just to clarify, reg_k and reg_v are 4-D tensors, correct? (i.e. [B, heads, seq_len, dim]).

# registers before repetition [1, heads, seq_len, head_dim]
reg_k.shape
torch.Size([1, 2, 16, 64])
reg_v.shape
torch.Size([1, 2, 16, 64])

# registers after repetition [B, heads, seq_len, head_dim]
reg_k.shape
torch.Size([2, 2, 16, 64])
reg_v.shape
torch.Size([2, 2, 16, 64])

# qkv before rearrange [B, h, w, channels]
qkv.shape
torch.Size([2, 32, 32, 384])

# q,k,v after rearrange [B, heads, h, w, head_dim]
q.shape
torch.Size([2, 2, 32, 32, 64])
k.shape
torch.Size([2, 2, 32, 32, 64])
v.shape
torch.Size([2, 2, 32, 32, 64])
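
For reference, one way to get from the fused qkv tensor to these q, k, v shapes in plain PyTorch. The (t, heads, head_dim) channel order is an assumption, mirroring the register rearrange pattern shown later in this thread; the actual block may order channels differently.

```python
import torch

B, H, W, heads, head_dim = 2, 32, 32, 2, 64
qkv = torch.randn(B, H, W, 3 * heads * head_dim)  # [2, 32, 32, 384]

# Assumed channel layout: (t, heads, head_dim) with t indexing q, k, v.
# view splits the channels, permute moves t and heads ahead of the spatial dims,
# unbind peels off q, k, v; .contiguous() because NATTEN expects contiguous inputs.
q, k, v = (
    t.contiguous()
    for t in qkv.view(B, H, W, 3, heads, head_dim).permute(3, 0, 4, 1, 2, 5).unbind(0)
)
print(q.shape)  # torch.Size([2, 2, 32, 32, 64])
```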

alihassanijr commented 10 months ago

Okay these all look correct to me. Let me try running your example on my end and see where it goes.

Birch-san commented 10 months ago

I've now written a concat-and-split NATTEN block.
https://github.com/Birch-san/natten-fwd-ad/commit/127649ef3511dbe9febcac641c242f111c39e8bc

the built-in NATTEN cross-attention does match the concat-and-split approach. that's encouraging.

maybe my pure-PyTorch implementation is the one that's wrong then? it doesn't use concat-and-split, it concatenates an "attend to all registers" mask onto the end of a neighbourhood mask, and also concats keys and values.

alihassanijr commented 10 months ago

Okay that's good to hear. Yeah I've been looking at the mask trying to find a needle, but nothing so far. Can't say what could be the issue.

Birch-san commented 10 months ago

okay I've now implemented a pure-PyTorch concat-and-split.
https://github.com/Birch-san/natten-fwd-ad/commit/30ff09e9d2c26ff6e6f7041290613a7db5747824

NATTEN built-in xattn matches NATTEN concat-and-split and pure-PyTorch concat-and-split.

whereas my pure-PyTorch concat-and-mask-sdp does not match pure-PyTorch concat-and-split.

so perhaps my pure-PyTorch concat-and-mask-sdp is the only implementation here that was wrong?
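
A compact 1D sketch of the two pure-PyTorch strategies being compared here (hypothetical names, assumed NATTEN-style border rule): concat-and-split gathers each query's neighbour keys and softmaxes them jointly with the register logits, while concat-and-mask runs SDPA over all keys plus registers with a boolean mask. Given identical q, k, v and registers, the two should agree to float tolerance, so a mismatch points at something outside the attention math itself.

```python
import torch
import torch.nn.functional as F

def window_starts(length, kernel_size):
    # Assumed NATTEN-style borders (no dilation): windows shift inward at the edges.
    return (torch.arange(length) - kernel_size // 2).clamp(0, length - kernel_size)

def na1d_concat_split(q, k, v, reg_k, reg_v, kernel_size):
    """q, k, v: [B, h, L, D]; reg_k, reg_v: [B, h, R, D]."""
    L, D = q.shape[-2], q.shape[-1]
    idx = window_starts(L, kernel_size)[:, None] + torch.arange(kernel_size)  # [L, K]
    k_nb, v_nb = k[..., idx, :], v[..., idx, :]  # [B, h, L, K, D]
    scale = D ** -0.5
    logits_na = torch.einsum("bhld,bhlkd->bhlk", q, k_nb) * scale
    logits_reg = q @ reg_k.transpose(-2, -1) * scale  # [B, h, L, R]
    probs = torch.softmax(torch.cat([logits_na, logits_reg], dim=-1), dim=-1)
    p_na, p_reg = probs.split([kernel_size, reg_k.shape[-2]], dim=-1)
    return torch.einsum("bhlk,bhlkd->bhld", p_na, v_nb) + p_reg @ reg_v

def na1d_concat_mask(q, k, v, reg_k, reg_v, kernel_size):
    L = q.shape[-2]
    start = window_starts(L, kernel_size)[:, None]
    keys = torch.arange(L)[None, :]
    na_mask = (keys >= start) & (keys < start + kernel_size)  # [L, L]
    reg_mask = torch.ones(L, reg_k.shape[-2], dtype=torch.bool)
    mask = torch.cat([na_mask, reg_mask], dim=-1)  # [L, L + R]
    return F.scaled_dot_product_attention(
        q, torch.cat([k, reg_k], dim=-2), torch.cat([v, reg_v], dim=-2), attn_mask=mask
    )
```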

Birch-san commented 10 months ago

okay I figured it out.

for some reason bias was enabled on the Linear projections, for the pure-PyTorch concat-and-mask class only.
https://github.com/Birch-san/natten-fwd-ad/commit/11503b6429eb32f8963c3dc358a8b382d4d9ce54

giving it bias=False, like the other blocks, finally makes all four implementations agree. 🙂

alihassanijr commented 10 months ago

Thanks for confirming; sorry I couldn't get a chance to run your reference on my end. I did notice immediately, though, that the unit tests didn't cover everything, so I started adding new ones. But yeah, the additional KV stuff was all passing the tests.

Birch-san commented 10 months ago

no worries, thanks for reading the code in any case.

btw, is it possible to support broadcasting of the registers' batch dim?

I make my registers like this:

registers.shape
torch.Size([16, 128])

reg_kv: FloatTensor = linear(registers, kv_proj_weight, kv_proj_bias)
reg_kv.shape
torch.Size([16, 256])

reg_k, reg_v = rearrange(reg_kv, "r (t nh e) -> t 1 nh r e", t=2, e=self.d_head)

reg_k.shape
torch.Size([1, 2, 16, 64]) # [batch, head, seq, dim]

my registers end up with a singleton batch dim. it would be nice if this could be broadcast.

but check_additional_keys disallows attempting broadcast:

Shape mismatch between input tensor and additional tokens; they must match in batch size, heads, and dim per head. Got input_tensor.shape=torch.Size([2, 2, 32, 32, 64]), additional_keys.shape=torch.Size([1, 2, 16, 64]).

plan B was to try expanding the batch dim to explicitly match my batch-of-2.

reg_k, reg_v = (reg.expand(qkv.size(0), *reg.shape[1:]) for reg in (reg_k, reg_v))

this fails because merging the expanded batch dim into a view would span two contiguous subspaces:

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
  File "/home/birch/git/NATTEN/src/natten/ops.py", line 34, in qk_cross_forward
    key_transposed_bmm_view = key.view(
                              ^^^^^^^^^

ultimately I had to resort to repeating instead of expanding:

reg_k, reg_v = (reg.repeat(qkv.size(0), *(1,)*(reg.ndim-1)) for reg in (reg_k, reg_v))

whereas expanding worked in my pure-PyTorch implementation.

alihassanijr commented 10 months ago

The reason why expand isn't working is that torch.bmm only accepts 3D tensors, which means we have to merge the heads and batch axes, but the expanded batch axis is a stride-0 view, so merging it with heads via view is not possible. I'm trying to think whether there's a way we can avoid having to merge those two.
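
A small sketch reproducing the failure in isolation (shapes taken from this thread): expand gives the batch dim a stride of 0, so .view can't fold it into heads, while .reshape quietly falls back to a copy, which is effectively what repeat does up front.

```python
import torch

B, heads, R, D = 2, 2, 16, 64
reg_k = torch.randn(1, heads, R, D)

reg_k_exp = reg_k.expand(B, -1, -1, -1)  # no copy: the batch dim gets stride 0
print(reg_k_exp.stride())  # (0, 1024, 64, 1)

# Folding batch into heads for a 3D torch.bmm needs stride[0] == stride[1] * size[1],
# which a stride-0 batch dim can't satisfy, so .view raises.
try:
    reg_k_exp.view(B * heads, R, D)
    view_failed = False
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# .reshape falls back to a copy when a view is impossible; .repeat copies up front.
merged = reg_k_exp.reshape(B * heads, R, D)
repeated = reg_k.repeat(B, 1, 1, 1).view(B * heads, R, D)
print(torch.equal(merged, repeated))  # True
```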

alihassanijr commented 10 months ago

Okay, torch.matmul to the rescue! It does appear to call some extra kernels when the batch is being repeated through a view, but their latency is tiny, within the margin of error between different runs on my end.
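
A sketch of the torch.matmul route (shapes follow this thread; variable names are illustrative): matmul broadcasts leading dims, so singleton-batch registers can be used against 4D queries directly, and the result matches a repeat-and-bmm reference.

```python
import torch

B, heads, N, R, D = 2, 2, 256, 16, 64
q = torch.randn(B, heads, N, D)
reg_k = torch.randn(1, heads, R, D)  # singleton batch dim, never repeated
reg_v = torch.randn(1, heads, R, D)

# torch.matmul broadcasts the leading dims: [B, h, N, D] @ [1, h, D, R] -> [B, h, N, R]
logits = torch.matmul(q, reg_k.transpose(-2, -1)) * D ** -0.5
# Softmax over the registers only, just to exercise the AV matmul; in the real op
# these probabilities come out of the joint softmax with the NA weights.
probs = logits.softmax(dim=-1)
out = torch.matmul(probs, reg_v)  # [B, h, N, R] @ [1, h, R, D] -> [B, h, N, D]

# Reference: repeat over batch, merge batch and heads, then use 3D torch.bmm.
k3 = reg_k.repeat(B, 1, 1, 1).view(B * heads, R, D)
v3 = reg_v.repeat(B, 1, 1, 1).view(B * heads, R, D)
ref_logits = torch.bmm(q.reshape(B * heads, N, D), k3.transpose(1, 2)) * D ** -0.5
ref_out = torch.bmm(ref_logits.softmax(dim=-1), v3)
```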

I'll let all the tests run and start a PR for that soon.

Birch-san commented 10 months ago

oh brilliant!

is bmm on a 3D tensor faster than matmul on a 4D tensor?
should bmm be used wherever possible? is torch likely to do that internally anyway?

I guess what I'm wondering is: should the user prefer to repeat their cross-attn tensor over batch dims, in order to make it possible to view it as a 3D tensor and be eligible for a batched-matmul fast-path?

alihassanijr commented 10 months ago

On my device, it does look like torch.matmul and torch.bmm end up with the same graph, even with the view change (comparing merging the batch and heads dims against not merging them at all and calling matmul on 4D tensors). I'm seeing it call the same GEMM kernels, and nothing else.

When I switch the extra KV tokens from a batched tensor to a single-batch tensor that's been broadcast, it's still the same GEMM kernel and config, with the same latency, but this time there are two extra kernel calls that would only make sense as transformations. Those are 0.04% of the overall latency in my specific use case, so well within the margin of error and unnoticeable. Those transformations would likely also happen in any BMM implementation of attention. But again, my observations are limited to the use case I'm working with, and the specific card and software, so the behavior might vary.

But overall I think views + matmul is going to work out pretty nicely, and it's as good as it can get when the two attention branches and the QK and AV ops are not merged into a single op.

alihassanijr commented 8 months ago

Closing this due to inactivity; feel free to reopen if there's anything else.

Xynonners commented 5 months ago

Hi, if I'm understanding this correctly, the current additional_keys/values support allows NA to attend to extra tokens outside of the original self-attn q/k/v.

I've been thinking though, with architectures like FIT https://arxiv.org/abs/2305.12689 (grouped perceiver architecture), would it be possible to invert this? aka have registers attend to multiple different neighborhoods?