cmsflash / efficient-attention

An implementation of the efficient attention module.
https://arxiv.org/abs/1812.01243

in_channels, key_channels, head_count, value_channels #10

Closed feimadada closed 1 year ago

feimadada commented 1 year ago

I am trying to use efficient attention, and I am confused about the parameters. I changed key_channels and value_channels, but the output dimension stays the same. Why?

cmsflash commented 1 year ago

Hi @feimadada, I replied to you in https://github.com/cmsflash/efficient-attention/issues/6, but I think it's better to isolate the discussion in this separate thread, so I'm replying here again.

value_channels only controls the channel count for the output of the core attention step. The default EA module, as I implemented it here, has a residual connection around the attention step. Therefore, it has an additional self.reprojection module that projects the attention output back to the input's channel count, then adds it to the input before returning.
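
To make this concrete, here is a minimal single-head sketch of the structure described above. It is illustrative only, not the repo code verbatim (the actual module also splits key_channels and value_channels across head_count heads), but it shows why the output channel count is pinned to in_channels:

```python
import torch
from torch import nn
from torch.nn import functional as F


class EfficientAttentionSketch(nn.Module):
    """Single-head sketch of the EA module (illustrative, not the repo code)."""

    def __init__(self, in_channels, key_channels, value_channels):
        super().__init__()
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        # Projects the value_channels-wide attention output back to
        # in_channels so the residual addition below is well-typed.
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, x):
        n, _, h, w = x.shape
        keys = F.softmax(self.keys(x).reshape(n, -1, h * w), dim=2)
        queries = F.softmax(self.queries(x).reshape(n, -1, h * w), dim=1)
        values = self.values(x).reshape(n, -1, h * w)
        context = keys @ values.transpose(1, 2)       # (n, key_ch, value_ch)
        attended = context.transpose(1, 2) @ queries  # (n, value_ch, h * w)
        attended = attended.reshape(n, -1, h, w)
        # Whatever key_channels/value_channels you pick, the result has
        # in_channels channels again after the reprojection and residual.
        return self.reprojection(attended) + x


x = torch.randn(2, 64, 32, 32)
ea = EfficientAttentionSketch(in_channels=64, key_channels=128, value_channels=128)
print(ea(x).shape)  # torch.Size([2, 64, 32, 32]), regardless of key/value_channels
```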

If you want the output dimension to differ from the input's, I'd suggest you either add a linear layer after the EA module or modify the code to remove the residual connection and the reprojection.
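
For instance, the first option could be a thin wrapper like this sketch. The wrapper class and the import path are hypothetical; only the EfficientAttention constructor signature is taken from this issue's title:

```python
import torch
from torch import nn

# Import path is an assumption; adjust to wherever the module lives in your code.
from efficient_attention import EfficientAttention


class ProjectedEfficientAttention(nn.Module):
    """Hypothetical wrapper: EA followed by a 1x1 conv to change the width."""

    def __init__(self, in_channels, key_channels, head_count, value_channels,
                 out_channels):
        super().__init__()
        self.attention = EfficientAttention(
            in_channels, key_channels, head_count, value_channels)
        # EA returns in_channels (reprojection + residual); this extra layer
        # maps the result to the desired out_channels.
        self.projection = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        return self.projection(self.attention(x))


x = torch.randn(2, 64, 32, 32)
module = ProjectedEfficientAttention(
    in_channels=64, key_channels=128, head_count=4, value_channels=128,
    out_channels=256)
print(module(x).shape)  # torch.Size([2, 256, 32, 32])
```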

feimadada commented 1 year ago

Got it, thanks!