qhfan / RMT

(CVPR2024)RMT: Retentive Networks Meet Vision Transformer

A question about decay mask #21

Open Josh00-Lu opened 5 months ago

Josh00-Lu commented 5 months ago

Congratulations on your wonderful work. I learned a lot from it.

I have a small question. In your code:

https://github.com/qhfan/RMT/blob/18c09a80a8cc58dfb41899abe84b5f7a8a35d8aa/classfication_release/RMT.py#L227-L229

Here `mask[some_pos]` = $\log(\gamma^d)$.
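Below is a minimal NumPy sketch of how a log-decay mask of this kind can be built. It assumes the 2D case described in the RMT paper, where $d$ is the Manhattan distance between two tokens on an $h \times w$ grid. The function name, grid size and $\gamma$ value are hypothetical, and this is not the code at the linked lines.

```python
import numpy as np

def manhattan_log_decay_mask(h: int, w: int, gamma: float) -> np.ndarray:
    """Hypothetical helper: return an (h*w, h*w) matrix whose (n, m) entry is
    log(gamma ** d), with d the Manhattan distance between tokens n and m
    on an h x w grid (tokens flattened in row-major order)."""
    # (row, col) coordinates of every token.
    ys, xs = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    coords = np.stack([ys.ravel(), xs.ravel()], axis=1)  # shape (h*w, 2)
    # Pairwise |y_n - y_m| + |x_n - x_m|.
    dist = np.abs(coords[:, None, :] - coords[None, :, :]).sum(axis=-1)
    # log(gamma ** d) == d * log(gamma), which stays finite for large d.
    return dist * np.log(gamma)

mask = manhattan_log_decay_mask(2, 2, gamma=0.9)  # gamma chosen for illustration
print(np.exp(mask))  # recovers the multiplicative decay gamma ** d
```

Storing $d \log\gamma$ instead of $\gamma^d$ is what allows the mask to be *added* to the logits.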

I am wondering why the decay mask here is added to the dot product before `softmax()`.

In the paper, the decay mask is multiplied in after `softmax()`, which I suspect is not equivalent.
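The following NumPy sketch compares the two formulations on random logits. `decay_before_softmax` mirrors the code (add $\log D$, then softmax), and `decay_after_softmax` mirrors the paper's description (softmax, then multiply by $D$). The helper names, the 1D toy distances and $\gamma = 0.9$ are illustrative assumptions.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable row-wise softmax.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def decay_before_softmax(qk: np.ndarray, decay: np.ndarray) -> np.ndarray:
    # Code-style: add log(decay) to the logits, then apply softmax.
    return softmax(qk + np.log(decay))

def decay_after_softmax(qk: np.ndarray, decay: np.ndarray) -> np.ndarray:
    # Paper-style: apply softmax, then scale element-wise by the decay.
    return softmax(qk) * decay

rng = np.random.default_rng(0)
qk = rng.normal(size=(4, 4))
# Toy decay gamma ** |i - j| using 1D distances (an assumption for brevity).
decay = 0.9 ** np.abs(np.arange(4)[:, None] - np.arange(4)[None, :])

a = decay_before_softmax(qk, decay)
b = decay_after_softmax(qk, decay)
# Paper-style rows sum to less than 1; renormalizing them recovers a.
b_renorm = b / b.sum(axis=-1, keepdims=True)
print(np.allclose(a, b_renorm))  # True
print(np.allclose(a, b))         # False
```

The two results differ only by a per-row scale factor. Once the paper-style weights are renormalized so that each row sums to one, they match the code-style weights.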

Thanks very much.

Josh00-Lu commented 5 months ago


For example, suppose `qk_mat` is:

[
  [-10, 10], 
  [10, 10],
]

I'm not sure whether the mask would produce a noticeable decay in this case.
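The following hedged NumPy sketch runs this example. It assumes the two tokens are one step apart, so the off-diagonal decay is $\gamma^1$, and it sets $\gamma = 0.9$ purely for illustration.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable row-wise softmax.
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

gamma = 0.9  # hypothetical decay rate
qk = np.array([[-10.0, 10.0],
               [10.0, 10.0]])
# The two tokens are assumed to be distance 1 apart.
dist = np.array([[0, 1],
                 [1, 0]])
log_mask = dist * np.log(gamma)

without_decay = softmax(qk)
with_decay = softmax(qk + log_mask)
print(without_decay.round(4))
print(with_decay.round(4))
```

Row 1 moves from $[0.5, 0.5]$ to $[0.9/1.9,\ 1/1.9] \approx [0.474, 0.526]$. Row 0 stays at essentially $[0, 1]$. The decay rescales each probability by $\gamma^d$ before renormalization, so a factor of 0.9 cannot overturn a logit gap of 20. After renormalization, the paper-style formulation behaves the same way.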

syz247179876 commented 4 months ago

I am also confused about this.

longzilicart commented 3 months ago

+1, I am also confused about this.

qhfan commented 3 months ago

Thank you for your interest in our work. Our implementation follows the decay-mask approach of RetNet. Adding the mask to the attention logits before applying softmax is equivalent to applying softmax to the logits, multiplying element-wise by the decay mask, and then renormalizing each row to sum to one. Because this is a fairly simple derivation, we did not spell it out in the paper. To make the addition work, we take the logarithm of the decay weights when generating the mask.
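The derivation referred to above can be written out as follows. Here $s_{nm}$ is the attention logit between tokens $n$ and $m$, and $D_{nm} = \gamma^{d(n,m)} > 0$ is the decay. This notation is chosen for illustration.

$$
\begin{aligned}
\operatorname{softmax}(S + \log D)_{nm}
&= \frac{e^{s_{nm} + \log D_{nm}}}{\sum_j e^{s_{nj} + \log D_{nj}}}
 = \frac{e^{s_{nm}}\, D_{nm}}{\sum_j e^{s_{nj}}\, D_{nj}} \\
&= \frac{\operatorname{softmax}(S)_{nm}\, D_{nm}}{\sum_j \operatorname{softmax}(S)_{nj}\, D_{nj}}
\end{aligned}
$$

The last step divides the numerator and denominator by $\sum_k e^{s_{nk}}$. The result is the paper's $\operatorname{softmax}(S) \odot D$ with each row renormalized to sum to one, so the two forms differ only by that per-row factor. The logarithm is well defined because $\gamma^d > 0$ for any $\gamma \in (0, 1)$.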