AmeenAli / HiddenMambaAttn

Official PyTorch Implementation of "The Hidden Attention of Mamba Models"

Mamba attention matrix aggregation. #6

Closed (patronum08 closed this issue 8 months ago)

patronum08 commented 8 months ago

Thank you for your work. I have learned a lot of useful insights about mamba from it.

I am a bit confused about why Mamba's attention has 4 columns in Figure 3. Also, how is the attention from Mamba's different channels aggregated?

patronum08 commented 8 months ago

Additionally, isn't it unfair to the negative attention scores in the Mamba attention when they are normalized using softmax and then plotted in the figure?

Itamarzimm commented 8 months ago

Hi @patronum08 , thanks for your questions.

  1. The (hidden) Mamba attention matrices are extracted from each S6 channel. We selected four representative channels for visualization; similarly, for transformers we focus on a single representative head. Please note that the channels are not aggregated in this figure.
  2. To present a comparative visualization, both the Mamba and transformer attention matrices must be normalized to the same domain. A natural approach is to map the scores obtained from Mamba to the 0-1 range, either via softmax (just like standard attention) or via min-max normalization. We acknowledge that softmax can be problematic for negative attention scores. Therefore, in the second version of our paper we apply min-max normalization to the absolute values of the scores, as illustrated in Figure 4 of version 2 of our paper (see the sketch after this list).
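For concreteness, here is a minimal, illustrative sketch of both steps for a single S6 channel: materializing the implicit attention matrix from per-step discretized parameters, and mapping it to [0, 1] for plotting. This is not the repository's actual code; the function names and the `(L, N)` tensor shapes are assumptions for illustration only.

```python
import torch

def hidden_attention_single_channel(A_bar, B_bar, C):
    # Sketch (hypothetical helper, not the repo API): materialize the implicit
    # attention matrix of one S6 channel from its per-step discretized parameters.
    #   A_bar, B_bar, C: (L, N) tensors (sequence length L, state size N), assumed shapes
    # Returns alpha: (L, L), lower-triangular, with
    #   alpha[i, j] = C[i] . (prod over k = j+1..i of A_bar[k]) . B_bar[j]   for j <= i
    L, N = A_bar.shape
    alpha = torch.zeros(L, L)
    for i in range(L):
        decay = torch.ones(N)  # running product of A_bar over k = j+1..i
        for j in range(i, -1, -1):
            alpha[i, j] = (C[i] * decay * B_bar[j]).sum()
            decay = decay * A_bar[j]
    return alpha

def normalize_for_plot(alpha, mode="minmax_abs"):
    # Sketch: map the scores to [0, 1] for a side-by-side plot with transformer attention.
    if mode == "softmax":
        # Row-wise softmax over the causal (lower-triangular) part; large negative
        # scores get squashed toward zero, which is the concern raised above.
        mask = torch.triu(torch.ones_like(alpha), diagonal=1).bool()
        return torch.softmax(alpha.masked_fill(mask, float("-inf")), dim=-1)
    # Min-max normalization of the absolute scores (the scheme used for Figure 4 in v2).
    scores = alpha.abs()
    return (scores - scores.min()) / (scores.max() - scores.min() + 1e-8)
```

Under those assumptions, reproducing the four Mamba columns in the figure would amount to running this over four chosen channels and plotting each normalized matrix next to a single representative transformer head.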