cmsflash / efficient-attention

An implementation of the efficient attention module.
https://arxiv.org/abs/1812.01243
MIT License
276 stars, 26 forks

How to use efficient attention class #6

Status: Closed (chandlerbing65nm closed this issue 2 years ago)

chandlerbing65nm commented 2 years ago

image

I'm trying to understand how to use your attention module based on the figure above and the code below.

From what I understand from the non-local paper, if I have an input feature with m_batchsize, channels, height, width = input_.size(),

then n = m_batchsize*height*width and d = channels.

So in the code below, I should set in_channels, key_channels, and value_channels all equal to channels.

But what should head_count be? Should the number of channels be divisible by it?

import torch
from torch import nn
from torch.nn import functional as f

class EfficientAttention(nn.Module):

    def __init__(self, in_channels, key_channels, head_count, value_channels):
        super().__init__()
        self.in_channels = in_channels
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels

        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        n, _, h, w = input_.size()
        keys = self.keys(input_).reshape((n, self.key_channels, h * w))
        queries = self.queries(input_).reshape(n, self.key_channels, h * w)
        values = self.values(input_).reshape((n, self.value_channels, h * w))
        head_key_channels = self.key_channels // self.head_count
        head_value_channels = self.value_channels // self.head_count

        attended_values = []
        for i in range(self.head_count):
            key = f.softmax(keys[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=2)
            query = f.softmax(queries[
                :,
                i * head_key_channels: (i + 1) * head_key_channels,
                :
            ], dim=1)
            value = values[
                :,
                i * head_value_channels: (i + 1) * head_value_channels,
                :
            ]
            context = key @ value.transpose(1, 2)
            attended_value = (
                context.transpose(1, 2) @ query
            ).reshape(n, head_value_channels, h, w)
            attended_values.append(attended_value)

        aggregated_values = torch.cat(attended_values, dim=1)
        reprojected_value = self.reprojection(aggregated_values)
        attention = reprojected_value + input_

        return attention

cmsflash commented 2 years ago

Hi Chandler,

The channel counts correspond as follows: d = in_channels, d_k = key_channels, d_v = value_channels. head_count must divide key_channels and value_channels, but not necessarily in_channels. If you don't know what value to set it to, 8 is usually a good default value.
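As a minimal sketch of the divisibility rule above: the module code in the first post computes the per-head channel counts with integer division, so a head_count that does not divide key_channels or value_channels would leave trailing channels out of the per-head slices. The helper name split_channels_per_head is hypothetical and is not part of the repository.

```python
def split_channels_per_head(key_channels, value_channels, head_count):
    """Return (head_key_channels, head_value_channels).

    Hypothetical helper, not part of the repository. It raises if
    head_count does not divide both channel counts, instead of silently
    flooring the division as `//` would.
    """
    for name, count in (("key_channels", key_channels),
                        ("value_channels", value_channels)):
        if count % head_count != 0:
            raise ValueError(
                f"head_count={head_count} does not divide {name}={count}"
            )
    return key_channels // head_count, value_channels // head_count


print(split_channels_per_head(256, 256, 8))  # (32, 32)
```

Note that in_channels does not appear in the check, which matches the statement that head_count need not divide it.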

chandlerbing65nm commented 2 years ago

@cmsflash What values should I use for d = in_channels, d_k = key_channels, and d_v = value_channels, or what are the defaults? Should I set them all equal to channels? My input looks like this: m_batchsize, channels, height, width = input_.size()

or like this:

1, 256, 28, 28 = x.size()

Attention = EfficientAttention(in_channels = 256, key_channels = 256, head_count = 8, value_channels = 256)
Attention(x)

cmsflash commented 2 years ago

Hi Chandler, d = in_channels, d_k = key_channels, d_v = value_channels is just the correspondence between the variables in the figure and those in the code: the figure uses mathematical notation, while the code uses programmatic names.

in_channels is the channel count of your input, so it is equal to channels. value_channels decides the channel count of the output, so if you want to keep it the same as the input, then it's also channels. You are free to set key_channels to tune the computational cost of the module: the higher it is, the more costly the module is, which usually also leads to better performance. If you don't have a value in mind, then key_channels = in_channels or key_channels = in_channels // 2 are good defaults.
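To give a rough feel for how key_channels affects cost, here is a small sketch that counts the learnable parameters of the four 1x1 convolutions in the module code from the first post. Parameter count is only one proxy for cost, and the helper names conv1x1_params and attention_params are hypothetical. It assumes each nn.Conv2d keeps its default bias term.

```python
def conv1x1_params(in_channels, out_channels):
    """Weights plus biases of a 1x1 convolution (bias enabled, the default)."""
    return in_channels * out_channels + out_channels


def attention_params(in_channels, key_channels, value_channels):
    """Total parameters of the keys, queries, values, and reprojection layers."""
    keys = conv1x1_params(in_channels, key_channels)
    queries = conv1x1_params(in_channels, key_channels)
    values = conv1x1_params(in_channels, value_channels)
    reprojection = conv1x1_params(value_channels, in_channels)
    return keys + queries + values + reprojection


# Halving key_channels shrinks only the keys and queries layers.
print(attention_params(256, 256, 256))  # 263168
print(attention_params(256, 128, 256))  # 197376
```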

chandlerbing65nm commented 2 years ago

Thank you very much @cmsflash. Now I understand how to use attention blocks.

feimadada commented 1 year ago

Hi, I am trying to use the efficient attention module, but I found that key_channels and value_channels have no impact on the output dimension.

cmsflash commented 1 year ago

Hi @feimadada, sorry for the misunderstanding. value_channels only controls the channel count for the output of the core attention step. The default EA module, as I implemented here, has a residual connection around the attention step. Therefore, it has an additional self.reprojection module to project the output back to the same number of channels as the input and then add them up before returning.
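The behaviour described here can be checked with a quick shape test. The sketch below repeats the module code from the first post in a slightly condensed form (the per-head slicing is written with a local slice variable, which is an editorial shortening) and compares the output shape for two different value_channels settings. The helper output_shape is hypothetical.

```python
import torch
from torch import nn
from torch.nn import functional as f


class EfficientAttention(nn.Module):
    """Condensed copy of the module posted at the top of this thread."""

    def __init__(self, in_channels, key_channels, head_count, value_channels):
        super().__init__()
        self.key_channels = key_channels
        self.head_count = head_count
        self.value_channels = value_channels
        self.keys = nn.Conv2d(in_channels, key_channels, 1)
        self.queries = nn.Conv2d(in_channels, key_channels, 1)
        self.values = nn.Conv2d(in_channels, value_channels, 1)
        self.reprojection = nn.Conv2d(value_channels, in_channels, 1)

    def forward(self, input_):
        n, _, h, w = input_.size()
        keys = self.keys(input_).reshape(n, self.key_channels, h * w)
        queries = self.queries(input_).reshape(n, self.key_channels, h * w)
        values = self.values(input_).reshape(n, self.value_channels, h * w)
        hk = self.key_channels // self.head_count
        hv = self.value_channels // self.head_count

        attended_values = []
        for i in range(self.head_count):
            key = f.softmax(keys[:, i * hk:(i + 1) * hk, :], dim=2)
            query = f.softmax(queries[:, i * hk:(i + 1) * hk, :], dim=1)
            value = values[:, i * hv:(i + 1) * hv, :]
            context = key @ value.transpose(1, 2)  # (n, hk, hv)
            attended = (context.transpose(1, 2) @ query).reshape(n, hv, h, w)
            attended_values.append(attended)

        aggregated = torch.cat(attended_values, dim=1)  # (n, value_channels, h, w)
        # The reprojection maps value_channels back to in_channels so the
        # residual addition with input_ is shape-compatible.
        return self.reprojection(aggregated) + input_


def output_shape(value_channels):
    """Hypothetical helper: output shape for a (1, 256, 28, 28) input."""
    module = EfficientAttention(256, 256, 8, value_channels)
    with torch.no_grad():
        return tuple(module(torch.randn(1, 256, 28, 28)).shape)


print(output_shape(64))   # (1, 256, 28, 28)
print(output_shape(256))  # (1, 256, 28, 28)
```

Both calls give the same shape as the input, because the reprojection and residual sit after the core attention step.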

If you want the output dimension to differ from the input, I'd suggest you either add an additional linear layer after the EA module or modify the code to remove the residual connection and reprojection.
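A minimal sketch of the first option, assuming a 1x1 convolution serves as the per-pixel linear layer. The wrapper class WithOutputProjection is hypothetical and not part of the repository, and nn.Identity() stands in for an EfficientAttention instance only to keep the example short. Any block whose output has in_channels channels would fit.

```python
import torch
from torch import nn


class WithOutputProjection(nn.Module):
    """Hypothetical wrapper: run a shape-preserving block, then change the
    channel count with a 1x1 convolution (a per-pixel linear layer)."""

    def __init__(self, block, in_channels, out_channels):
        super().__init__()
        self.block = block
        self.projection = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, input_):
        return self.projection(self.block(input_))


# nn.Identity() is a placeholder for EfficientAttention(256, 256, 8, 256).
module = WithOutputProjection(nn.Identity(), in_channels=256, out_channels=128)
with torch.no_grad():
    out = module(torch.randn(1, 256, 28, 28))
print(tuple(out.shape))  # (1, 128, 28, 28)
```

The second option, removing the residual connection and reprojection, would instead make the module return the concatenated per-head outputs directly, so its output would have value_channels channels.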

cmsflash commented 1 year ago

For follow-up discussion on the issue @feimadada raised, refer to the dedicated issue #10.