SHI-Labs / NATTEN

Neighborhood Attention Extension. Bringing attention to a neighborhood near you!
https://shi-labs.com/natten/

Arbitrary masks #124

Open Birch-san opened 7 months ago

Birch-san commented 7 months ago

Congratulations on shipping FNA backward! Looking forward to using it.

On another note: would it be possible to support arbitrary masking?

MaskDiT outperformed regular DiT, with a 70% reduction in training time, by masking-out 50% of patches (and predicting those via an auxiliary reconstruction loss + encoder-decoder architecture):
https://arxiv.org/abs/2306.09305v2

Perhaps there's also an opportunity to exploit sparsity (some regions may not require computation), but I think even without such optimization, arbitrary masking would still be useful due to enabling new training objectives.

Note: arbitrary masks/biases is something the PyTorch team are attempting with templated attention:
https://github.com/pytorch/pytorch/pull/121845
(hence the questions in https://github.com/SHI-Labs/NATTEN/issues/120 about whether a generic method could achieve the same effect as NATTEN performantly).

alihassanijr commented 7 months ago

Thank you!

Complicated but excellent question!

Supporting any type of masking is possible, but the efficiency gain is not so straightforward.

I can't think of a way to guarantee that completely masked-out attention blocks are skipped if your mask is truly arbitrary (e.g. dropout). And even if you get your kernel to compute that before proceeding to compute every block, I don't really think it'll improve performance much: I would guess the predication logic for that won't be cheap, and it probably won't be easy to parallelize in the backward pass.

Another issue: if your mask is in tensor form, you're gathering yet another set of elements from global memory without being able to guarantee alignment or contiguity (in 2D and 3D), and that's just going to slow things down more.

However, if there are certain assumptions that can be made about the mask, then it could actually be trivial to implement that for FNA. Right now there are actually two masks: causal and standard NA. Adding more, as long as they meet those requirements, is trivial. This also extends to efficiency gains; FNA figures out exactly which tiles need to be skipped and which don't, and it can still parallelize easily.

I guess what I'm trying to say is: I don't think supporting arbitrary masks in a way that roughly maintains performance is too challenging, but I'm skeptical that it can lead to efficiency gains (I can get into details as to why if you're interested.)

(Side note: I'll even go as far as to say that I think FNA could even fit into a similar framework as templated attention, because FNA is templated too, just on a different level.)

Birch-san commented 7 months ago

thanks for thinking about it!

I guess another consideration here is that the currently-foreseeable use-case for this is training only. so optimization effort in the forward pass doesn't translate into a speedup for any inference use-case.

if sparsity can't practically be exploited for a speedup, that's still okay. it still unlocks an objective that makes the overall training take fewer steps.

how hard is it likely to be for the user to comply with the mask requirements, in the trivial approach that you describe? happy to do some tensor gymnastics to get it to the preferred layout.

alihassanijr commented 7 months ago

So the forward pass is actually both easier to optimize and easier to customize than the backward pass, and since both are needed for training, the implementation has to apply to both.

As far as the mask tensor requirements go, the issue isn't just masking itself, but rather the properties of multi-dimensional attention.

Masks, attention weights, and attention bias all share a common problem: their contiguous mode (last axis) is no longer a single dimension; it is multiple! This is hard on the processor, because the iteration and pointer logic gets more complicated, and worst of all, it needs to be a gather operation because there's just no way to guarantee alignment (that I know of) across multiple modes at the same time. This is actually why our GEMM kernels can't accelerate things much, despite being much better solutions at their core.

We're actually planning to drop RPB support entirely for this reason, and that's why FNA backward doesn't support it. But of course, we know there are better alternatives to RPB for most cases, so that is why we're okay making that decision 🙂 .

If, however, your mask isn't expressed as a tensor, and you can instead define it as a coordinate mapping of sorts, then as long as it meets certain requirements, not only can FNA support it, it will also guarantee that it skips computing tiles that are entirely masked, which means perf should be mostly similar to FNA. I'm not sure if this is what you need though.

Birch-san commented 7 months ago

Thanks for explaining.

I think the mask can be expressed as a coordinate mapping, yes. By that do you mean something like starting with a 2D BoolTensor then using .to_sparse() to turn it into a list of coordinates?
We wouldn't need per-head masks, and (if it helps) I think we could give all batch items the same mask.

Having said that…
it looks like the MaskDiT masking is typically implemented as a torch.gather(), which makes the sequence shorter:
https://github.com/Anima-Lab/MaskDiT/blob/master/models/maskdit.py#L88-L163

that makes sense (excludes tokens from FFNs and suchlike), but it destroys the locality on which neighbourhood attention relies.
maybe the gathered sequence can be scattered again onto a tensor of zeros, restoring the 2D canvas prior to neighbourhood attention?

rather than a typical attention mask "these queries shouldn't attend to these keys", it'd serve as a combination of a sparsity map "queries at these coordinates can be ignored (don't compute attentions)" plus a typical mask "no query should attend to any keys at these coordinates".

after neighbourhood attentions are computed, I guess you'd use gather (according to the original mask specification) to shorten the sequence again.

does this make sense? it's a bit convoluted, but probably the smarter training objective would be worth the complexity + mild overhead of the extra scatter and gather.

alihassanijr commented 6 months ago

Sorry I'm responding to this one so late. It was already a crazy last week on campus for me, and I started my internship this week, so still playing catch up.

I think I might have misspoken earlier about the mask being expressible as a mapping. I was thinking more of a specific pattern that we know ahead of time (i.e. given a query index, what's the range in KV to which it attends). That would be very easy to build into FNA: you'd get your own customized kernel, and it would guarantee saving on compute. This would not be applicable to any random / non-deterministic masking pattern, because you can't really compile a kernel at every call.
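A minimal sketch of the "given a query index, what's the range in KV" idea for the standard (non-causal) mask, assuming 1D NA with dilation 1 and NATTEN-style border handling (the window shifts inward at the edges so every query gets kernel_size keys). Because the window bounds only move forward as the query index grows, a whole query tile maps to one contiguous KV range, and any KV tile outside it can be skipped. The function names are hypothetical; this is an illustration, not FNA's implementation.

```python
def na_window_1d(i, n, kernel_size):
    """Return [start, end) of the KV indices query i attends to.

    Assumes 1D neighborhood attention, dilation 1, odd kernel_size <= n,
    with the window shifted inward at the borders so it always holds kernel_size keys.
    """
    r = kernel_size // 2
    start = max(i - r, 0)
    if i + r >= n:
        start += n - i - r - 1  # right border: shift the window left
    return start, start + kernel_size


def kv_tiles_needed(q_tile, tile, n, kernel_size):
    """KV tile indices a query tile must visit; every other KV tile can be skipped."""
    q0 = q_tile * tile
    q1 = min(q0 + tile, n)  # exclusive
    # window start/end are non-decreasing in i, so the union of all windows
    # over [q0, q1) is the single range [start(q0), end(q1 - 1))
    lo = na_window_1d(q0, n, kernel_size)[0]
    hi = na_window_1d(q1 - 1, n, kernel_size)[1]
    return list(range(lo // tile, (hi - 1) // tile + 1))
```

With 64 tokens, tiles of 16 and a kernel of 7, each query tile visits at most 3 of the 4 KV tiles; the fraction skipped grows as the sequence gets longer relative to the kernel.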

rather than a typical attention mask "these queries shouldn't attend to these keys", it'd serve as a combination of a sparsity map "queries at these coordinates can be ignored (don't compute attentions)" plus a typical mask "no query should attend to any keys at these coordinates".

I really like this idea. I think this might be easy enough to implement. One thing that concerns me, though, is that pretty much all throughout FNA and NATTEN really, we assume we're always dealing with a self attention problem, so Q and KV share their coordinate space and other properties. If there are queries that don't attend to anything, and likewise key/value tokens to which no query attends, I'd assume we could get away with just one vector of length N (it can be different per-batch per-head, doesn't really make a difference here), where N is the total number of tokens in QKV, and it could be either a boolean or float tensor, either way works. It won't be any different from the way we're already reading LSE and delta in backprop. It won't be free, but I don't expect it to be that bad either. It'll be pretty easy to verify though.

Then it's pretty simple, we just mask out all attention weights associated with them. But this won't save you any compute.

I guess in theory you could try and predict if there are any Q or KV tiles that are completely masked out, and just have the kernel skip over those. But I don't know 1. how likely this event is if masking is done randomly, or 2. whether skipping those would save that much on latency. The auto-tuner will likely help here, as long as multiple spatially proximate tokens are likely to be masked together.

One last thing, sorry if I made it extra confusing. When I referred to gather/scatter earlier, I was referring to reading/writing elements one by one, which, aside from preventing vectorized accesses, can also incur additional bounds checks and index-mapping latency.

Birch-san commented 6 months ago

no worries, and congrats on your internship!

any random / non-deterministic masking pattern

ah okay yeah, not possible with a MaskDiT-style mask. completely random.

pretty much all throughout FNA and NATTEN really, we assume we're always dealing with a self attention problem, so Q and KV share their coordinate space and other properties.

I think continuing to assume self-attention is still reasonable as "tokens near me" are so often "within the same sequence".
I can imagine cross-attention use-cases with shared-coordinate space, e.g. time-iterative processes, where a canvas attends to the canvas from the previous timestep. perhaps diffusion models could self-attend to "current noised image" and additionally cross-attend to "an older, more noised image". and maybe videos or motion-estimation could use this too, but I guess that's what your 3D kernels are for.

I'd assume we could get away with just one vector of length N, where N is the total number of tokens in QKV

yes, I think that's all that's needed. err, do you just mean the total number of tokens in Q, or is there some significance in mentioning KV as well?

boolean or float tensor, either way works

certainly at least boolean. I can only think of contrived use-cases for floats. it's essentially an absolute position bias, where all queries would be biased toward the same keys. maybe useful for regional attention with soft edges.

we just mask out all attention weights associated with them

certainly this solves the "no query attends to those keys", which is the most important part and enables MaskDiT.
I guess you still end up computing attentions for queries that are destined to be discarded. I wondered whether a sparse tensor could make it clear which queries we want to keep / discard, but 50% random masking seems like really poor sparsity; probably expensive to represent as a sparse tensor.

I guess in theory you could try and predict if there are any Q or KV tiles that are completely masked out, and just have the kernel skip over those. But I don't know how 1. how likely this event is if masking is done randomly

yeah, for 50% random masking I think it's unlikely to get many complete kernel skips.
but there's one other use-case that this enables, which would have a much higher hit rate: mixed aspect ratio image training.

the ideal way to handle mixed-aspect is via nested tensors of course, but since there's a lot of operations in the rest of the model that don't support nested tensors (e.g. einops.rearrange and L2 loss): it's still useful to find more ways to handle multi-aspect.
so, imagine we take a batch of images, all with different aspect ratios, padded to fit into a standard square via letterboxing or pillarboxing. it would be nice to be able to mask out such padding, and the sparsity here is a lot more exploitable.

even without sparsity optimizations, this masking can still enable downstream speedups.
for example an HDiT model is hierarchical, having both neighbourhood attention and global attention. MaskDiT can speed up the global attention because gather() halves the sequence length, but only if the neighbourhood attention can be masked too. so even a not-any-faster implementation of masking in NATTEN unlocks a FLOP speedup in the wider picture, and overall the MaskDiT objective should result in a steps-speedup over training.

and even without speedup, having a way to mask out padding tokens on a padded image is still useful! admittedly you can just train the model to draw the padding, but I'm interested in seeing whether hiding the padding from the model but telling it the aspect ratio, could help it generalize, enabling you to then ask it "fill a square canvas with a wide-aspect-style image"; perhaps it would frame the subjects in a wide-aspect style but extend the canvas in some creative way to fill the square.

alihassanijr commented 6 months ago

yes, I think that's all that's needed. err, do you just mean the total number of tokens in Q, or is there some significance in mentioning KV as well?

No just mentioned QKV for completeness; I actually just call that num_tokens in the kernel for lack of a better term. It applies to all three.

and even without speedup, having a way to mask out padding tokens on a padded image is still useful! admittedly you can just train the model to draw the padding, but I'm interested in seeing whether hiding the padding from the model but telling it the aspect ratio, could help it generalize, enabling you to then ask it "fill a square canvas with a wide-aspect-style image"; perhaps it would frame the subjects in a wide-aspect style but extend the canvas in some creative way to fill the square.

I think we're on the same page; this should be fairly easy to implement, and I can template it for now so it doesn't affect existing kernels (NATTEN's binary size and build time are nearly insane now, but I'll figure out how to deal with that later).

Although I'd probably be able to work on it in like two weeks, because right now I only have weekends to get back to research, and trying to catch a conference deadline. Would that work for you?

the ideal way to handle mixed-aspect is via nested tensors of course, but since there's a lot of operations in the rest of the model that don't support nested tensors (e.g. einops.rearrange and L2 loss): it's still useful to find more ways to handle multi-aspect.

This is probably something we could handle better once nested tensor support matures. I haven't checked the docs in a while, but if the torch C++ API can expose nested tensors, I can't imagine it would be too difficult to get most NATTEN kernels, especially FNA, to handle them. FNA is mostly agnostic to batch/head, since we issue at least 1 CTA per batch/head and offset the pointers, so in theory, having different pointers and different problem sizes per batch/head is easy. As a matter of fact, FNA is already doing an implicit padding for dilation (e.g. if you have 15 tokens and dilation 2, you have 7 tokens in one partition and 8 in the other).
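The dilation example can be checked in a few lines: token i lands in partition i % dilation, so partitions can differ in size by one, which is exactly a "different problem size per partition" situation. The padding comment below is one way to picture it and an assumption for illustration, not a description of FNA's internals; `dilation_partitions` is a hypothetical helper.

```python
def dilation_partitions(num_tokens, dilation):
    """Split token indices into `dilation` interleaved partitions (token i -> partition i % dilation)."""
    return [list(range(p, num_tokens, dilation)) for p in range(dilation)]


parts = dilation_partitions(15, 2)
sizes = [len(p) for p in parts]
# one way to picture the implicit padding: each partition is treated as if it
# had max(sizes) tokens, with the missing ones masked (assumption, not FNA's code)
padded_len = max(sizes)
```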

Birch-san commented 6 months ago
* Given Q, K and V of shape `[B, *, heads, dim]`, and any valid kernel size, dilation, causal masking,

* Take an additional optional boolean tensor, `qkv_mask` of shape `[B, *, heads]` indicating which inputs are masked,

* Mask out all attention weights corresponding to `[b, i, h]` (`attn[b, i, h, :]`) if `qkv_mask[b, i, h] == true`.

yup, this looks like a good API. as for the meaning of the boolean, could follow torch sdp:

A boolean mask where a value of True indicates that the element should take part in attention.
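A minimal unfused reference for the proposed `qkv_mask` semantics, following the SDPA convention quoted above (True = takes part in attention): masked keys are excluded from every window, and masked queries just get a placeholder output, since they're discarded afterwards. It assumes a single head, 1D NA with dilation 1, and kernel_size no larger than the sequence; the names are hypothetical, not NATTEN's API. An unmasked query always keeps at least itself, so its softmax is never empty.

```python
import math


def na_window_1d(i, n, kernel_size):
    """[start, end) of the keys query i attends to (1D NA, dilation 1, odd kernel_size <= n)."""
    r = kernel_size // 2
    start = max(i - r, 0)
    if i + r >= n:
        start += n - i - r - 1  # right border: shift the window left
    return start, start + kernel_size


def masked_na_1d(q, k, v, kernel_size, qkv_mask, scale=None):
    """Reference 1D NA for one head; qkv_mask[i] == True means token i takes part in attention."""
    n, dim = len(q), len(q[0])
    scale = dim ** -0.5 if scale is None else scale
    out = []
    for i in range(n):
        if not qkv_mask[i]:
            out.append([0.0] * len(v[0]))  # masked query: output is thrown away downstream
            continue
        start, end = na_window_1d(i, n, kernel_size)
        # dropping masked keys is equivalent to giving their logits -inf
        idx = [j for j in range(start, end) if qkv_mask[j]]  # non-empty: contains i itself
        logits = [scale * sum(a * b for a, b in zip(q[i], k[j])) for j in idx]
        m = max(logits)
        w = [math.exp(s - m) for s in logits]
        z = sum(w)
        out.append([sum(wj * v[j][c] for wj, j in zip(w, idx)) / z for c in range(len(v[0]))])
    return out
```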

Would that work for you?

that'd be great, thanks. I'm currently training a little HDiT at home in multiple stages, starting with global attn, then will stack local attn later to reach higher resolutions. I'm gonna try implementing MaskDiT on the global stage first, so there's still preparations on my side before I can try extending the MaskDiT to local attn.

there's also one more complication regarding MaskDiT specifically; I wonder whether there's a way for NATTEN to streamline it or whether I'll have to resort to a workaround.

MaskDiT isn't just an attention mask; it also makes the sequence shorter and randomizes the order of tokens. in other words, it creates a random index selecting 50% of the patches in the vision sequence, and you'd torch.gather() with that index to obtain the shortened sequence.

is there a way to give to NATTEN the shortened sequence (its tokens are in random order and half of them are missing) plus the index explaining where to put each token to reconstruct the original sequence? or rather is any optimization possible like that?
otherwise: I can just make a zeroed tensor the same size as the original sequence, and scatter the shortened sequence into the right locations using that index, prior to showing anything to NATTEN.
this is of course getting into super-obscure territory, so I'll understand if it doesn't make sense to support it. but at any rate curious to hear your thoughts and about whether anything could be done there that's faster than just reconstructing the original sequence with a scatter.
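A minimal sketch of the gather / scatter round trip described above, using plain Python lists with scalar tokens. `random_keep_indices` follows the common MAE-style recipe (argsort uniform noise, keep the first half), which is an assumption about how the linked MaskDiT code picks tokens; all names here are hypothetical. The scatter also builds a boolean token mask using the SDPA convention (True = takes part in attention).

```python
import random


def random_keep_indices(seq_len, keep_ratio, rng):
    """Argsort of uniform noise is a random permutation; keep its first len_keep entries (shuffled order)."""
    noise = [rng.random() for _ in range(seq_len)]
    ids_shuffle = sorted(range(seq_len), key=lambda i: noise[i])
    len_keep = int(seq_len * keep_ratio)
    return ids_shuffle[:len_keep]


def gather(tokens, ids):
    """Shortened sequence, in the (random) order of ids."""
    return [tokens[i] for i in ids]


def scatter_to_canvas(short_seq, ids, seq_len, fill=0.0):
    """Put kept tokens back at their original positions on a zero-filled canvas."""
    canvas = [fill] * seq_len
    token_mask = [False] * seq_len  # True = takes part in attention
    for tok, i in zip(short_seq, ids):
        canvas[i] = tok
        token_mask[i] = True
    return canvas, token_mask


rng = random.Random(0)
tokens = [float(i) for i in range(16)]
ids_keep = random_keep_indices(len(tokens), 0.5, rng)
short_seq = gather(tokens, ids_keep)
canvas, token_mask = scatter_to_canvas(short_seq, ids_keep, len(tokens))
```

Sorting `ids_keep` before the gather would keep the shortened sequence in raster order, if the shuffling turns out not to matter.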

alihassanijr commented 6 months ago

Great, I'll try and add that feature soon.

is there a way to give to NATTEN the shortened sequence (its tokens are in random order and half of them are missing) plus the index explaining where to put each token to reconstruct the original sequence? or rather is any optimization possible like that?

I actually think there might be; I think adding scatter/gather iterators to FNA might be possible. This might take longer, but I think as long as I figure out a way to handle tokens that get removed, it should be possible.

One thing that I don't understand though is the logic behind the random ordering. Now I'm not sure what the network being trained looks like, but attention and linear layers are both equivariant to permutations (along the token mode), and most activation functions are elementwise, so they commute with permutations as well. Unless there's convolution (like 1D convs) somewhere (maybe that's what I'm missing), would reordering the tokens really make a difference?
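To make the permutation argument concrete, here is a minimal sketch (hypothetical helper names) of single-head self-attention with no positional information, using q = k = v = x for brevity; per-token projections don't change the argument. Permuting the tokens and then attending gives the same result as attending and then permuting. Note this does not hold for neighborhood attention, whose windows are defined by each token's position on the grid, which is why a gathered sequence has to be restored to its 2D canvas before NA.

```python
import math


def attention(x, scale=1.0):
    """Global single-head self-attention with q = k = v = x and no positional information."""
    out = []
    for qi in x:
        logits = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in x]
        m = max(logits)
        w = [math.exp(s - m) for s in logits]
        z = sum(w)
        out.append([sum(wj * vj[c] for wj, vj in zip(w, x)) / z for c in range(len(x[0]))])
    return out


def permute(x, perm):
    return [x[p] for p in perm]


x = [[0.1, 0.2], [0.3, -0.1], [-0.2, 0.5], [0.4, 0.4]]
perm = [2, 0, 3, 1]
lhs = attention(permute(x, perm))  # permute, then attend
rhs = permute(attention(x), perm)  # attend, then permute
```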

Birch-san commented 6 months ago

I think the random ordering is not a desired property (they don't mention shuffling in the paper). I think you're right that it doesn't make a difference, so they went with the easiest "discard 50%" algorithm they could think of; it happened to have the property of shuffling the sequence, and since this didn't cause any problems, they didn't put in any effort to prevent it:
https://github.com/Anima-Lab/MaskDiT/blob/73cfb96ca9c6d1d78ab567ffeea099f96692a6e4/models/maskdit.py#L435-L438

I'm also going to read up on whether MDT is a more promising (or complementary) approach to MaskDiT. I think the SotA for diffusion transformers looks like this right now:

No-CFG:

Model      | params | batch  | steps |  FID
===========|========|========|=======|======
DiT-XL/2   |   675M |    256 | 7000k | 9.62
DiffiT     |   561M |    256 |     ? | 9.5x
MDTv1      |   700M |    256 | 2500k | 7.41
HDiT       |   557M |      † | 2200k | 6.92
MDTv1      |   700M |    256 | 6500k | 6.23
MaskDiT    |   735M?|   1024 | 2000k | 5.69
MDTv2      |   676M |    256 | 2000k | 5.06
†HDiT batch size was:
2000k steps  256 +
 100k steps  512 +
 100k steps 1024

CFG:

Model      | params | batch  | steps |  FID
===========|========|========|=======|======
HDiT       |   557M |      † | 2200k | 3.21
DiT-XL/2-G |   675M |    256 | 7000k | 2.27
MDTv1-G*   |   700M |    256 | 2500k | 2.15
MDTv1-G*   |   700M |    256 | 6500k | 1.79
DiffiT*    |   561M |    256 |     ? | 1.73
MaskDiT-G  |   735M?|   1024 | 2000k | 1.73
MDTv2-G*   |   676M |    256 | 3500k | 1.63
MDTv2-G*   |   676M |    256 | 4600k | 1.58

* uses power-cosine CFG proposed in MDT paper

DiffiT: Diffusion Vision Transformers for Image Generation
preprint Dec 2023
https://arxiv.org/abs/2312.02139

MDTv2: Masked Diffusion Transformer is a Strong Image Synthesizer
v1 Mar 2023, v2 Feb 2024
https://arxiv.org/abs/2303.14389

(MaskDiT) Fast Training of Diffusion Models with Masked Transformers
preprint Jun 2023, published Mar 2024
https://arxiv.org/abs/2306.09305

so MaskDiT looks very promising, but I should check MDT too to see whether that prescribes a change-in-direction… or whether I should frankenstein them together.

Birch-san commented 6 months ago

okay yeah, MDT uses the same trick of randomly picking n% of tokens, using torch.gather() on the chosen indices to create a shortened sequence.

so arbitrary masks will be useful for incorporating either of these SotA masking architectures, and also for multi-aspect training. likewise, scatter/gather iterators look useful for both. if the lack of orderedness complicates it: it's probably possible to sort the randomly-chosen indices before doing the gather, since I don't think this technique utilises the shuffledness. I just don't know the relative perf costs of "scattering random indices onto a zeroed canvas" vs "sort my indices + use a NATTEN scatter/gather iterator which assumes sorted indices" vs "NATTEN scatter/gather iterator that tolerates unsorted indices".

alihassanijr commented 6 months ago

Yeah so if the shuffled order is in fact just there for convenience and doesn't contribute anything, I would assume re-ordering them back into place once would definitely be more performant than using a gather/scatter iterator, at least if we're calling NA more than once. But the downside is we'd still need the gather/scatter if we don't want to put the tokens back into their original 2D grid, with masked out elements just being zeros. That will be just terrible.

I think we could just give the gather/scatter a shot; it'll be really easy to tell exactly how much it'll slow things down compared to a non-masked example. That might take a little more time to implement though, but it's definitely worth it if it unlocks both masking architectures.

And while this way we might not save any compute beyond what NA normally saves, just the fact that we're not loading masked tokens from global memory, or storing them back, should undo some of the overhead.

But the one potential issue to look out for is the NA window size. Queries that get most of their neighborhood masked out will likely converge to projections of themselves very quickly (although a query not having any neighbors is not possible, it'll at least have itself :) ).

Birch-san commented 6 months ago

I would assume re-ordering them back into place once would definitely be more performant than using a gather/scatter iterator, at least if we're calling NA more than once.

yeah, my indices will be prepared for the lowest-resolution that the model downsamples to, so 16**2 * 0.5=128 tokens. cheap to sort, and I could even precompute-and-reuse a few hundred variations if that were a concern.
NATTEN will handle higher resolutions, but those masks will just be a nearest-neighbour upsample of the low-res mask. not immediately sure how to produce the indices that correspond to that, but no sorting would be involved.
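One way to get the high-res indices, assuming a row-major (raster) token order and an integer upsampling factor: map every high-res position back to its low-res cell and keep it if that cell was kept. Iterating in raster order makes the output sorted, so the order of the low-res indices doesn't matter. `upsample_keep_indices` is a hypothetical helper.

```python
def upsample_keep_indices(low_idx, h, w, factor):
    """Nearest-neighbour upsample of a low-res keep set, returned as sorted high-res raster indices."""
    keep = set(low_idx)  # membership only, so unsorted input is fine
    W = w * factor
    return [
        r * W + c
        for r in range(h * factor)
        for c in range(W)
        if (r // factor) * w + (c // factor) in keep
    ]
```

For a 16x16 grid with 128 kept cells and a factor of 2, this yields 512 indices on the 32x32 grid.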

we'd still need the gather/scatter if we don't want to put the tokens back into their original 2D grid
I think we could just give the gather/scatter a shot

do you mean "the user scatters the tokens back onto a zeroed 2D grid", or do you mean "NATTEN scatters the indices without materializing the full grid"?
certainly built-in NATTEN could have surprising benefits; it probably paves the way for operating on sparse tensors. I wonder what kind of problem might use local attention on a sparse sequence.

Queries that get most of their neighborhood masked out will likely converge to projections of themselves very quickly

dropout poses a similar problem.

at the core of this is that attention entropy increases with sequence length.
as more elements are incorporated: the softmax denominator gets larger, outputting a more diffuse distribution of probabilities.
conversely: fewer elements leads to a sharper distribution of probabilities, in the worst-case one-hot (as you say, projection-of-self).

there's an easy way to potentially mitigate this. Jin et al 2023 proposes that you can scale the attention logits based on the difference between train-time and inference-time sequence length:
https://arxiv.org/abs/2306.08645

in our case, "train-time sequence length" is diminished by both dropout and masking.
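The rule in the pseudocode that follows can be written compactly. This is a reading of that code rather than a quote from the paper, with assumed notation: $d_{qk}$ is the query/key head dim, $k$ the kernel size, $m$ the number of spatial dims (so $N = k^m$ keys per window at inference), $N_i$ the keys query $i$ actually sees during training, and $\rho$ a kept fraction. Since $N_i \le N$, the scale never exceeds $1/\sqrt{d_{qk}}$: queries that see fewer keys get smaller logits.

```latex
\begin{aligned}
s_i &= \frac{1}{\sqrt{d_{qk}}}\,\sqrt{\log_{\max(2,\,N)} \max(2,\,N_i)}, \qquad N = k^{m} \\
N_i = \rho N \;&\Longrightarrow\; \log_N(\rho N) = 1 + \log_N \rho
\end{aligned}
```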

import math

# start with the usual attention scale
scale = qk_head_dim**-.5

# 2D kernels
natten_kernel_dim = 2
# number of key tokens in kernel
inference_key_len = kernel_size**natten_kernel_dim

# this is pseudocode at this point
for q in queries:
  # this kernel may be missing some keys.
  # there may actually be overlap between "missing due to random dropout" versus "missing due to user-provided mask"
  # so we should be careful not to double-count them.
  keys_lost_to_dropout_or_masking = …

  # thus this query attends to a shorter key than it would in inference
  current_key_len = inference_key_len - keys_lost_to_dropout_or_masking

  # make logits smaller when our key is shorter than the standard key length we will use during inference
  # this makes the softmax numerator smaller, which helps balance the fact that a shorter sequence makes the denominator smaller too
  # q_scale = scale * math.log(current_key_len, inference_key_len) ** .5

  # clamp the log operand and log base to at least 2, otherwise we get 0 or Inf, which is worse lol
  q_scale = scale * math.log(max(2, current_key_len), max(2, inference_key_len)) ** .5
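For the mask-only case, counting the keys in each window is straightforward when done outside a fused kernel. A minimal unfused sketch, assuming 1D NA with dilation 1, no dropout, and a per-token boolean key mask known up front (True = kept); `na_window_1d` and `per_query_scales` are hypothetical helpers, not NATTEN APIs.

```python
import math


def na_window_1d(i, n, kernel_size):
    """[start, end) of the keys query i attends to (1D NA, dilation 1, odd kernel_size <= n)."""
    r = kernel_size // 2
    start = max(i - r, 0)
    if i + r >= n:
        start += n - i - r - 1  # right border: shift the window left
    return start, start + kernel_size


def per_query_scales(key_mask, kernel_size, head_dim):
    """Entropy-corrected scale per query; key_mask[j] == True means key j is kept."""
    n = len(key_mask)
    scale = head_dim ** -0.5
    inference_key_len = kernel_size  # kernel_size ** 1 for a 1D kernel
    scales = []
    for i in range(n):
        start, end = na_window_1d(i, n, kernel_size)
        current_key_len = sum(key_mask[start:end])  # unmasked keys in this window
        # same clamps as above: avoid log(1) == 0 and a log base of 1
        scales.append(scale * math.log(max(2, current_key_len), max(2, inference_key_len)) ** 0.5)
    return scales
```

With nothing masked every query keeps the plain `head_dim ** -0.5` scale; masking every other key shrinks it for the affected windows.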

does that seem worth a try? it seems like an entirely overlooked paper. Lumina-T2X recently used it. I've had some luck myself using it, to inference smaller-than-trained-distribution images:

(comparison image: left = original, right = entropy-scaled)

entropy scaling seems to help un-fry smaller-than-trained images a bit, by diffusing the attention entropy.
though the composition remains weak because stable-diffusion's convolutions are responsible for implicit position embedding, and their dilation remains out-of-distribution (FouriScale and ScaleCrafter look into fixing the convolution part of the problem).

Birch-san commented 6 months ago

noticed a fun algebraic simplification, so I'll note it here.

in the case where:

the user can combat the attention entropy difference caused by dropout like so:

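import math

# start with the usual attention scale
scale = qk_head_dim**-.5

# 2D kernels
natten_kernel_dim = 2
# number of key tokens in kernel
inference_key_len = kernel_size**natten_kernel_dim
assert inference_key_len > 2, "how did you make a kernel that small"

# during training each query saw ~inference_key_len * (1 - dropout) keys; at inference it sees all of them.
# make my queries bigger, to combat the fact that my softmax denominator will be bigger than in training:
# sqrt(log_{N*(1-p)}(N)) simplifies to (1 + log_N(1 - p)) ** -.5
scale *= (1 + math.log(1 - dropout, inference_key_len))**-.5

q *= scale

# now invoke NATTEN APIs as usual, with your larger q
# (if that call also applies its default qk_head_dim**-.5 scale, don't apply it twice:
# pre-multiply q by the correction factor only, or pass the combined value via NATTEN's scale parameter)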

=====

it's also possible during inference for a user to overcome the difference between train-with-dropout and test by inferencing with dropout enabled.

For vision models, you can even make multiple predictions with different dropout random seeds, and use multi-cond CFG to average them:

  cond = (  cond_pred_0 +  cond_pred_1  + …)/cond_count
uncond = (uncond_pred_0 + uncond_pred_1 + …)/uncond_count
result = uncond + (cond - uncond) * cfg_scale
alihassanijr commented 6 months ago

do you mean "the user scatters the tokens back onto a zeroed 2D grid", or do you mean "NATTEN scatters the indices without materializing the full grid"? certainly built-in NATTEN could have surprising benefits; it probably paves the way for operating on sparse tensors. I wonder what kind of problem might use local attention on a sparse sequence.

If users scatter it back into a 2D grid then NATTEN wouldn't need to do anything other than mask interactions with any KV index that is masked. We don't really care what the output value for masked queries is, because post-NA they'll get thrown out, right?

But FNA could probably do this on its own as well. It would likely slow down the 1D kernel (if there ever is a use case for it), but it probably wouldn't affect the 2D/3D kernels much, since we're already recomputing offsets per QKV token, so it probably won't be that different from prefetching offsets and masks (token/pixel masks not attention) from global memory and using those to offset pointers. But I need to think about this more.

As for the query scaling, are you saying that it's possible to have just one floating point scale for all queries? Because you could just multiply the attention scale (dim ** -0.5) by that and let FNA do that within the kernel. You'll save exactly one elementwise op 🙂 .

But even if not, per query scaling wouldn't be too difficult if we add masking; it's just multiplying attention weights by a different number for each row.

Birch-san commented 6 months ago

possible to have just one floating point scale for all queries?

not really. each query attends to a different number of keys.
it's hard for the user to compute "for each of my queries: how many keys will it see?"
they'd have to scatter their sequence, materialize a mask mimicking the NATTEN kernel pattern, take their [self-attn mask or gather indices] and roll it around somehow to count how many unmasked keys fall into each kernel area.
on top of this the user doesn't know which keys NATTEN's dropout will drop out for each query.
it's probably a lot easier in-kernel; NATTEN will know which keys it's skipping (due to dropout or masking), so perhaps it has a mechanism to count how many keys will be skipped, and use that quantity to adjust scaling per-query?

the "no attn mask, just dropout" case might be simpler, reducible to "each query attends to 10% fewer keys", but only if "NATTEN drops out 10% of keys per kernel" rather than "NATTEN drops out 10% of keys, amortized over all kernels".
this doesn't answer the attn-masking problem, but might be generally-helpful for standardizing attention entropy when using NATTEN with dropout.

alihassanijr commented 6 months ago

NATTEN will know how which keys it's skipping (due to dropout or masking), so perhaps it has a mechanism to count how many keys will be skipped, and use that quantity to adjust scaling per-query?

Actually this is not always the case. Fused kernels have no mechanism through which they can determine how many attention weights a query will have; that's kind of the point. They can only refer to the size of KV, or the window size, when there is no masking involved, but when we're masking things out, the kernel actually wouldn't have that information. You could count how many weights are being masked for each query in the forward pass, but the last place where you can scale queries / attention weights is when they're computed. So, for example, if I just computed the first tile of attention weights for a specific query tile, I can only do one extra pass on this particular attention tile and scale according to how many attention weights I masked in the first pass (assuming I count them), but I can't predict how many will get masked in the following attention tiles.

Fused attention kernels generally don't guarantee realizing all attention weights for even 1 query, that's how they can parallelize more.
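A minimal sketch of this point for a single query, in plain Python with hypothetical names: keys are visited tile by tile with the usual online-softmax running max and denominator, fully masked tiles are skipped, and the number of unmasked keys (`seen`) is only final after the last tile. By then, earlier tiles have already been folded into `denom` and `acc` using whatever `scale` was chosen up front, and because the scale sits inside the exponent it can't simply be corrected afterwards. A direct, fully materialized version is included for comparison.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def online_masked_attention(q, keys, values, key_mask, tile, scale):
    """One query; keys visited tile by tile with a running max / denominator (online softmax)."""
    m = -math.inf                  # running max logit
    denom = 0.0                    # running sum of exp(logit - m)
    acc = [0.0] * len(values[0])   # running unnormalized output
    seen = 0                       # unmasked keys so far: only final after the LAST tile
    for t0 in range(0, len(keys), tile):
        idx = [j for j in range(t0, min(t0 + tile, len(keys))) if key_mask[j]]
        if not idx:
            continue               # fully masked tile: skipped entirely
        logits = [scale * dot(q, keys[j]) for j in idx]
        m_new = max(m, max(logits))
        corr = math.exp(m - m_new)  # rescales everything accumulated so far
        probs = [math.exp(s - m_new) for s in logits]
        denom = denom * corr + sum(probs)
        acc = [a * corr + sum(p * values[j][c] for p, j in zip(probs, idx))
               for c, a in enumerate(acc)]
        m = m_new
        seen += len(idx)
    return [a / denom for a in acc], seen


def direct_masked_attention(q, keys, values, key_mask, scale):
    """Reference: materialize every logit for the query, then softmax."""
    idx = [j for j, keep in enumerate(key_mask) if keep]
    logits = [scale * dot(q, keys[j]) for j in idx]
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    return [sum(wj * values[j][c] for wj, j in zip(w, idx)) / sum(w) for c in range(len(values[0]))]


q = [1.0, 0.5]
keys = [[0.1 * j, 1.0 - 0.1 * j] for j in range(10)]
values = [[float(j), -float(j)] for j in range(10)]
key_mask = [j % 3 != 1 for j in range(10)]
out, seen = online_masked_attention(q, keys, values, key_mask, tile=4, scale=0.7)
ref = direct_masked_attention(q, keys, values, key_mask, scale=0.7)
```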

Birch-san commented 6 months ago

Fused attention kernels generally don't guarantee realizing all attention weights for even 1 query

ah, okay. if neither the user nor NATTEN can know (early enough to inform scaling) the number of keys to which each query attends: I guess we can't standardize the attention entropy perfectly.

the next best thing we can do is approach it using averages.

import math

# example values; substitute your own
qk_head_dim = 64
kernel_size = 7

# start with the usual attention scale
scale = qk_head_dim**-.5

# 2D kernels
natten_kernel_dim = 2
# number of key tokens in kernel
inference_key_len = kernel_size**natten_kernel_dim

mask = 0.5
mask_keep = 1 - mask
dropout = 0.1

# dropout happens _after_ masking. we diminish the dropout rate, to account for the fact that it may drop out tokens that are already lost to the mask.
current_key_len = inference_key_len * mask_keep * (1 - dropout * mask_keep)
# scale *= math.log(current_key_len, inference_key_len) ** .5
# clamp the log operand and log base to at least 2, otherwise we get 0 or Inf, which is worse lol
scale *= math.log(max(2, current_key_len), max(2, inference_key_len)) ** .5

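algebra note: I think you could equally express this like so (because current_key_len is defined as a multiple of inference_key_len). since max(2, N * f) == N * max(2 / N, f), the clamp has to move onto the kept fraction as max(2 / inference_key_len, f); clamping the fraction itself to 2 would always yield 2, since the fraction is below 1:

# scale *= (1 + math.log(mask_keep - dropout * mask_keep ** 2, inference_key_len))**.5
# assumes inference_key_len > 2 (e.g. 49 for a 7x7 kernel)
scale *= (1 + math.log(max(2 / inference_key_len, mask_keep - dropout * mask_keep ** 2), inference_key_len))**.5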
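A quick numerical check of the algebra note, comparing the key-length form with the kept-fraction form, clamps included. The function names are hypothetical and head_dim=64 plus the rates are example values; the expected-fraction formula is the one from this comment. The second case masks so heavily that the clamp kicks in.

```python
import math


def scale_direct(head_dim, kernel_size, mask, dropout, kernel_dim=2):
    """Clamp applied to the expected key length."""
    inference_key_len = kernel_size ** kernel_dim
    mask_keep = 1 - mask
    current_key_len = inference_key_len * mask_keep * (1 - dropout * mask_keep)
    return head_dim ** -0.5 * math.log(max(2, current_key_len), max(2, inference_key_len)) ** 0.5


def scale_simplified(head_dim, kernel_size, mask, dropout, kernel_dim=2):
    """Same value via log_N(N * f) = 1 + log_N(f), with the clamp moved onto f."""
    inference_key_len = kernel_size ** kernel_dim
    assert inference_key_len > 2
    mask_keep = 1 - mask
    keep_frac = mask_keep - dropout * mask_keep ** 2
    return head_dim ** -0.5 * (
        1 + math.log(max(2 / inference_key_len, keep_frac), inference_key_len)
    ) ** 0.5
```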

=====

small kernel sizes mean that there may be significant deviation between how many keys one query sees versus how many another sees, but I think (especially at 50% masking rate) all queries will be a lot closer to inference-time attention entropy than if we'd used standard scaling.
and the average-across-all-queries will look good too (which is what matters for the weight update).

so yeah, I think we can just use the existing scale parameter that NATTEN provides. as is often the case in ML: a less correct method with a fast implementation, will learn more per second.

I guess the point of this whole thread was "are we worried about it converging on queries-attend-only-to-themselves". it's a slight worry, but I think a masking rate of 50% with a dropout rate of 10% means (on average) keeping 47.5% of tokens in a 7x7 kernel, so 23.
so on average it's a similar situation to using a 5x5 kernel with 10% dropout. it's slightly uncomfortable, but hopefully the entropy correction will help. some kernels will be outliers that get affected heavily by both masking and dropout, but the weights should learn from the average outcome, so I think outliers won't influence it too much.

alihassanijr commented 6 months ago

I guess the point of this whole thread was "are we worried about it converging on queries-attend-only-to-themselves". it's a slight worry, but I think a masking rate of 50% with a dropout rate of 10% means (on average) keeping 47.5% of tokens in a 7x7 kernel, so 23. so on average it's a similar situation to using a 5x5 kernel with 10% dropout. it's slightly uncomfortable, but hopefully the entropy correction will help. some kernels will be outliers that get affected heavily by both masking and dropout, but the weights should learn from the average outcome, so I think outliers won't influence it too much.

Yeah so even if entropy ends up being a problem, I guess we could just compute such cases with unfused NA, where we have full control. But even if that doesn't work or doesn't end up meeting the requirements, if there ever is an entropy problem, we can certainly think about better alternatives. I think if we write out the online softmax math, we might find that it is possible to address the scaling online as well.

Sorry again that I keep responding sporadically. If work wasn't enough, conference deadline just got added to the mix.

alihassanijr commented 2 weeks ago

@Birch-san I guess Flex Attention pretty much does exactly what's expected here? 😅