microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

mu parametrization for channel attention #18

Closed xwjabc closed 2 years ago

xwjabc commented 2 years ago

Hi, I have another question about the mu parametrization for a special attention mechanism - channel attention.

In standard scaled dot-product attention (also regarded as spatial attention), we have Q, K, V with shape n x d (ignoring heads), and we calculate softmax(scale * Q K^T) V to get an n x d output, where scale = 1/sqrt(d) in SP and scale = 1/d in muP (or 1/sqrt(d_0) / width_mult in muP for backward compatibility).
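For concreteness, here is a minimal PyTorch sketch of that spatial-attention scaling; the names d0 (base width) and width_mult = d / d0 are illustrative and not part of the mup package API:

```python
import torch
import torch.nn.functional as F

def attention_logit_scale(d, d0, parametrization="mup"):
    """Attention-logit scale: 1/sqrt(d) in SP; 1/d in muP, written here as
    1/sqrt(d0) / width_mult so that it matches SP at the base width d0."""
    if parametrization == "sp":
        return d ** -0.5
    width_mult = d / d0
    return d0 ** -0.5 / width_mult  # == sqrt(d0) / d, i.e. Theta(1/d)

def spatial_attention(q, k, v, d0=64, parametrization="mup"):
    # q, k, v: (n, d) token-by-channel matrices (heads ignored, as above)
    n, d = q.shape
    scale = attention_logit_scale(d, d0, parametrization)
    attn = F.softmax(scale * q @ k.transpose(-2, -1), dim=-1)  # (n, n) attention map
    return attn @ v                                            # (n, d) output

q = k = v = torch.randn(16, 256)
out = spatial_attention(q, k, v, d0=64)  # (16, 256)
```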

In channel attention, we still have Q, K, V with shape n x d (ignoring heads). The difference is that we calculate (softmax(scale * Q^T K) V^T)^T to get an n x d output, where scale = 1/sqrt(n) in SP. Since the attention map Q^T K now has shape d x d instead of n x n, I am not sure how the scale should be modified in muP accordingly. Should we use 1/sqrt(n) / width_mult?
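For comparison, a minimal sketch of the SP channel-attention computation described above (same illustrative names):

```python
import torch.nn.functional as F

def channel_attention_sp(q, k, v):
    # q, k, v: (n, d) with n tokens (finite) and d channels (the width)
    n, d = q.shape
    scale = n ** -0.5                                           # SP choice described above
    attn = F.softmax(scale * q.transpose(-2, -1) @ k, dim=-1)   # (d, d) attention map
    out = attn @ v.transpose(-2, -1)                            # (d, n); note the sum over d
    return out.transpose(-2, -1)                                # (n, d) output
```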

In addition, Appendix B - Matrix-Like, Vector-Like, Scalar-Like Parameters has some interpretation behind the scale:

a multiplier of order 1/fan_in should accompany any weight that maps an infinite dimension to a finite one. This interpretation then nicely covers both the output logits and the attention logits (i.e. 1/d attention).

However, such an interpretation may not directly serve as guidance for setting up the scale in channel attention.

Thanks!

edwardjhu commented 2 years ago

Since n (which I assume to be the batch size) is finite, the coordinates of Q^T K are \Theta(1) in d. The matmul of Q^T K and V^T, however, involves a summation over d, which needs to be scaled down by 1/d. So you want something like (softmax(scale * Q^T K) V^T / width_mult)^T.

Another way to look at this is that we are mapping a tensor with two infinite dimensions (Q^T K) to a tensor with just one infinite dimension; hence, we need to scale by 1/fan_in after this mapping.
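In code, a minimal sketch of that correction might look like the following (width_mult = d / d0 is again illustrative, with d0 the base width):

```python
import torch.nn.functional as F

def channel_attention_mup(q, k, v, d0=64):
    # q, k, v: (n, d); n (tokens) is finite, d (channels) goes to infinity
    n, d = q.shape
    width_mult = d / d0
    scale = n ** -0.5                                           # unchanged: the logits do not sum over d
    attn = F.softmax(scale * q.transpose(-2, -1) @ k, dim=-1)   # (d, d), entries Theta(1) in d
    out = attn @ v.transpose(-2, -1) / width_mult               # this matmul sums over d -> scale by 1/width_mult
    return out.transpose(-2, -1)                                # (n, d) output
```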

xwjabc commented 2 years ago

Yes, n is the number of tokens, which is finite, whereas d is the feature dimension, which is infinite. I think I got it; if I understood it correctly, what I need should be (softmax(scale * Q^T K) V^T / width_mult)^T = (softmax(1/sqrt(n) * Q^T K) V^T / width_mult)^T. Thank you for your quick response!

xwjabc commented 2 years ago

In addition, if we scale n_head instead of d_head in channel attention, does that mean we can simply use the original (softmax(scale * Q^T K) V^T)^T? Thanks!

edwardjhu commented 2 years ago

Sorry for the delay. I missed this earlier.

For a fixed d_head, the original formulation looks good as long as scale isn't a function of d_head, since we no longer have a summation over infinitely many coordinates.
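As a sketch of that case, assuming width is grown by increasing n_head with d_head held fixed (names illustrative):

```python
import torch.nn.functional as F

def channel_attention_fixed_dhead(q, k, v, n_head):
    # q, k, v: (n, d) with d = n_head * d_head; width grows via n_head, so d_head stays finite
    n, d = q.shape
    d_head = d // n_head
    q, k, v = (t.reshape(n, n_head, d_head).transpose(0, 1) for t in (q, k, v))  # (n_head, n, d_head)
    scale = n ** -0.5                                            # not a function of d_head
    attn = F.softmax(scale * q.transpose(-2, -1) @ k, dim=-1)    # (n_head, d_head, d_head)
    out = attn @ v.transpose(-2, -1)                             # (n_head, d_head, n): sum over finite d_head
    return out.transpose(-2, -1).transpose(0, 1).reshape(n, d)   # (n, d); no 1/width_mult needed
```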

xwjabc commented 2 years ago

Gotcha. Thank you, Edward!