rish-16 / tokenlearner-pytorch

Unofficial PyTorch implementation of TokenLearner by Google AI
MIT License
64 stars 10 forks source link

Implementation details of TokenFuser #2

Closed leijue222 closed 2 years ago

leijue222 commented 2 years ago

[Screenshot: Snipaste_2022-04-19_15-03-16]

The paper says:

where $X^j_t$ is the residual input to the previous TokenLearner module.

So the fuser output should be $BY + X^j$.

But in the code, the fuser output is $BY + \mathrm{SpatialAttention}(X^j)$: https://github.com/rish-16/tokenlearner-pytorch/blob/a6908107c5b53b837127806fc1d46c64694bffc5/tokenlearner_pytorch/tokenlearner_pytorch.py#L59-L62

Why does the residual connection add $\mathrm{SpatialAttention}(X^j)$ instead of $X^j$?

leijue222 commented 2 years ago

I wrote this version of the code based on the official open-source implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Plain feed-forward network (FFN) built from stacked linear layers.

    For ``num_layers >= 1`` the first layer maps ``input_dim -> hidden_dim``,
    each intermediate layer maps ``hidden_dim -> hidden_dim`` and the final
    layer maps ``hidden_dim -> output_dim``. GELU is applied after every layer
    except the last.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        in_sizes = [input_dim, *hidden_sizes]
        out_sizes = [*hidden_sizes, output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_sizes, out_sizes)]
        )

    def forward(self, x):
        last_index = self.num_layers - 1
        for idx, linear in enumerate(self.layers):
            x = linear(x)
            # No activation after the output projection.
            if idx < last_index:
                x = F.gelu(x)
        return x

class MLPBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Port of https://github.com/google-research/scenic/blob/5b5a78da05855dc8111aaaa68bd6e71c783e1422/scenic/model_lib/layers/attention_layers.py#L393

    Computes ``drop2(layer_out(drop1(gelu(layer_in(x)))))`` over the last axis.

    Args:
        mlp_dim: Size of the input's last axis; also used as the hidden width.
        out_dim: Size of the output's last axis. ``None`` (the default) keeps
            it equal to ``mlp_dim``.
        dropout_rate: Dropout probability applied after the activation and
            after the output projection.
    """
    def __init__(self, mlp_dim, out_dim=None, dropout_rate=0.1):
        super().__init__()
        # out_dim now defaults to None so the "same width as input" case does
        # not require callers to pass None explicitly.
        actual_out_dim = mlp_dim if out_dim is None else out_dim
        self.layer_in = nn.Linear(mlp_dim, mlp_dim)
        self.activation_fn = F.gelu
        self.drop1 = nn.Dropout(dropout_rate)

        self.layer_out = nn.Linear(mlp_dim, actual_out_dim)
        self.drop2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the block to ``x`` whose last axis has size ``mlp_dim``."""
        x = self.layer_in(x)
        x = self.activation_fn(x)
        x = self.drop1(x)
        x = self.layer_out(x)
        x = self.drop2(x)

        return x

class TokenLearner(nn.Module):
    """Learns a small set of tokens from a 2D feature map.

    Each output token is a softmax-weighted average over all spatial
    positions. The weights come from a LayerNorm + MLPBlock applied to every
    position.

    Args:
        d_model: Channel width ``c`` of the input features.
        n_token: Number of tokens to learn.
    """
    def __init__(self, d_model, n_token=8):
        super().__init__()
        self.dropout_rate = 0.1
        self.num_tokens = n_token
        self.MlpBlock  = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.layerNorm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(-1)

    def forward(self, inputs):
        """Applies learnable tokenization to the 2D inputs.

        Args:
            inputs: Inputs of shape `[bs, h, w, c]`, or already flattened to
                `[bs, hw, c]`.

        Returns:
            Output of shape `[bs, n_token, c]`.

        Raises:
            ValueError: If `inputs` is not rank 3 or rank 4.
        """
        if inputs.ndim == 4:
            bs, h, w, c = inputs.shape
            inputs = inputs.reshape(bs, h * w, c)
        elif inputs.ndim != 3:
            raise ValueError(
                f"inputs must have shape [bs, h, w, c] or [bs, hw, c], got {tuple(inputs.shape)}"
            )

        # Per-position logits for every token. Shape: [bs, hw, n_token].
        selected = self.MlpBlock(self.layerNorm(inputs))
        # Softmax over positions, so each token is a weighted average of positions.
        selected = self.softmax(selected.permute(0, 2, 1))  # Shape: [bs, n_token, hw].

        # Pool the un-normalized features with those weights. Shape: [bs, n_token, c].
        return torch.einsum('...si,...id->...sd', selected, inputs)

class TokenFuser(nn.Module):
    """Fuses learned tokens back into the full-resolution feature map.

    Args:
        d_model: Channel width ``c`` of both the learned tokens and ``original``.
        num_tokens: Number of learned tokens ``n_token`` being fused.
        use_normalization: Whether to apply LayerNorm to the tokens before and
            after the token-mixing linear layer.
    """
    def __init__(self, d_model, num_tokens, use_normalization=True):
        super().__init__()
        self.num_tokens = num_tokens
        self.dropout_rate = 0.
        self.use_normalization = use_normalization
        self.fuser_mix_norm1 = nn.LayerNorm(d_model)
        # Mixes information across tokens (applied on the token axis in forward).
        self.layer_inputs = nn.Linear(num_tokens, num_tokens)
        self.fuser_mix_norm2 = nn.LayerNorm(d_model)
        self.original_norm = nn.LayerNorm(d_model)

        # Predicts one gate per learned token at every position of `original`.
        self.MlpBlock  = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.drop_inputs = nn.Dropout(self.dropout_rate)

    def forward(self, inputs, original):
        """Applies token fusion to generate the 2D outputs.

        Only the fused term is returned. `original` is *not* added back, so
        any residual connection must be applied by the caller.

        Args:
            inputs: Learned tokens of shape `[bs, n_token, c]`.
            original: Feature map of shape `[bs, hw, c]` or `[bs, h, w, c]`.

        Returns:
            Output tensor with the same layout as `original`.

        Raises:
            ValueError: If `original` is not rank 3 or rank 4.
        """
        original_shape = original.shape
        if original.ndim == 4:
            n, h, w, c = original_shape
            original = original.reshape(n, h * w, c)
        elif original.ndim != 3:
            raise ValueError(
                f"original must have shape [bs, hw, c] or [bs, h, w, c], got {tuple(original_shape)}"
            )

        if self.use_normalization:
            inputs = self.fuser_mix_norm1(inputs)

        # nn.Linear acts on the last axis, so move the token axis there to mix tokens.
        inputs = inputs.permute(0, 2, 1)  # Shape: [bs, c, n_token].
        inputs = self.layer_inputs(inputs)
        inputs = inputs.permute(0, 2, 1)  # Shape: [bs, n_token, c].

        if self.use_normalization:
            inputs = self.fuser_mix_norm2(inputs)

        original = self.original_norm(original)
        mix = self.MlpBlock(original)     # Shape: [bs, hw, n_token].
        mix = self.sigmoid(mix)

        # Each position receives a sigmoid-gated sum of the learned tokens.
        inputs = torch.einsum('...sc,...hs->...hc', inputs, mix)    # Shape: [bs, hw, c].
        inputs = self.drop_inputs(inputs)

        # Restore the caller's layout. A rank-3 `original` keeps its flat [bs, hw, c] form.
        return inputs.reshape(*original_shape[:-1], -1)

if __name__ == '__main__':
    # Smoke test on a channels-last feature map of shape [B, H, W, C].
    x = torch.rand(10, 64, 48, 96)                  # torch.Size([10, 64, 48, 96])
    tklr = TokenLearner(d_model=96, n_token=8)
    tklr_res = tklr(x)                              # torch.Size([10, 8, 96])  B, N, C
    print('tklr_res shape: ', tklr_res.shape)

    # TokenFuser returns only the fused term; add `x` yourself for the residual.
    tkfr = TokenFuser(96, 8)
    tkfr_res = tkfr(tklr_res, x)                    # torch.Size([10, 64, 48, 96])
    print('tkfr_res shape: ', tkfr_res.shape)
hpppppp8 commented 1 year ago

Can this be applied to video tasks simply by adding a time dimension? For video, the paper requires multi-head attention. Does that need a separate implementation? Or could you point out what is wrong in my code? Thanks!

import torch
import torch.nn as nn
import torch.nn.functional as F

class MLPBlock(nn.Module):
    """Transformer MLP / feed-forward block.

    Port of https://github.com/google-research/scenic/blob/5b5a78da05855dc8111aaaa68bd6e71c783e1422/scenic/model_lib/layers/attention_layers.py#L393

    Computes ``drop2(layer_out(drop1(gelu(layer_in(x)))))`` over the last axis,
    so any number of leading (batch / frame / position) axes is allowed.

    Args:
        mlp_dim: Size of the input's last axis; also used as the hidden width.
        out_dim: Size of the output's last axis. ``None`` (the default) keeps
            it equal to ``mlp_dim``.
        dropout_rate: Dropout probability applied after the activation and
            after the output projection.
    """
    def __init__(self, mlp_dim, out_dim=None, dropout_rate=0.1):
        super().__init__()
        # out_dim now defaults to None so the "same width as input" case does
        # not require callers to pass None explicitly.
        actual_out_dim = mlp_dim if out_dim is None else out_dim
        self.layer_in = nn.Linear(mlp_dim, mlp_dim)
        self.activation_fn = F.gelu
        self.drop1 = nn.Dropout(dropout_rate)

        self.layer_out = nn.Linear(mlp_dim, actual_out_dim)
        self.drop2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Apply the block to ``x`` whose last axis has size ``mlp_dim``."""
        x = self.layer_in(x)
        x = self.activation_fn(x)
        x = self.drop1(x)
        x = self.layer_out(x)
        x = self.drop2(x)

        return x

class TokenLearner(nn.Module):
    """Learns a small set of tokens per frame from a video feature map.

    Tokens are learned independently for every frame. Batch and frame axes
    are treated as leading batch dimensions, and each token is a
    softmax-weighted average over that frame's spatial positions.

    Args:
        d_model: Channel width ``c`` of the input features.
        n_token: Number of tokens to learn per frame.
    """
    def __init__(self, d_model, n_token=8):
        super().__init__()
        self.dropout_rate = 0.1
        self.num_tokens = n_token
        self.MlpBlock  = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.layerNorm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(-1)

    def forward(self, inputs):
        """Applies learnable tokenization to the 3D (video) inputs.

        Args:
            inputs: Inputs of shape `[bs, f, h, w, c]`, or already flattened to
                `[bs, f, hw, c]`.

        Returns:
            Output of shape `[bs, f, n_token, c]`.

        Raises:
            ValueError: If `inputs` is not rank 4 or rank 5.
        """
        if inputs.ndim == 5:
            bs, f, h, w, c = inputs.shape
            inputs = inputs.reshape(bs, f, h * w, c)
        elif inputs.ndim != 4:
            raise ValueError(
                f"inputs must have shape [bs, f, h, w, c] or [bs, f, hw, c], got {tuple(inputs.shape)}"
            )

        # Per-position logits for every token. Shape: [bs, f, hw, n_token].
        selected = self.MlpBlock(self.layerNorm(inputs))
        # Softmax over positions within each frame. Shape: [bs, f, n_token, hw].
        selected = self.softmax(selected.permute(0, 1, 3, 2))

        # Pool the un-normalized features per frame. Shape: [bs, f, n_token, c].
        return torch.einsum('...si,...id->...sd', selected, inputs)

class TokenFuser(nn.Module):
    """Fuses per-frame learned tokens back into the full video feature map.

    Args:
        d_model: Channel width ``c`` of both the learned tokens and ``original``.
        num_tokens: Number of learned tokens ``n_token`` per frame.
        use_normalization: Whether to apply LayerNorm to the tokens before and
            after the token-mixing linear layer.
    """
    def __init__(self, d_model, num_tokens, use_normalization=True):
        super().__init__()
        self.num_tokens = num_tokens
        self.dropout_rate = 0.
        self.use_normalization = use_normalization
        self.fuser_mix_norm1 = nn.LayerNorm(d_model)
        # Mixes information across the tokens of a frame (applied on the token axis).
        self.layer_inputs = nn.Linear(num_tokens, num_tokens)
        self.fuser_mix_norm2 = nn.LayerNorm(d_model)
        self.original_norm = nn.LayerNorm(d_model)

        # Predicts one gate per learned token at every position of `original`.
        self.MlpBlock  = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.drop_inputs = nn.Dropout(self.dropout_rate)

    def forward(self, inputs, original):
        """Applies token fusion to generate the 3D (video) outputs.

        Only the fused term is returned. `original` is *not* added back, so
        any residual connection must be applied by the caller.

        Args:
            inputs: Learned tokens of shape `[bs, f, n_token, c]`.
            original: Feature map of shape `[bs, f, hw, c]` or `[bs, f, h, w, c]`.

        Returns:
            Output tensor with the same layout as `original`.

        Raises:
            ValueError: If `original` is not rank 4 or rank 5.
        """
        original_shape = original.shape
        if original.ndim == 5:
            bs, f, h, w, c = original_shape
            original = original.reshape(bs, f, h * w, c)
        elif original.ndim != 4:
            raise ValueError(
                f"original must have shape [bs, f, hw, c] or [bs, f, h, w, c], got {tuple(original_shape)}"
            )

        if self.use_normalization:
            inputs = self.fuser_mix_norm1(inputs)

        # nn.Linear acts on the last axis, so move the token axis there to mix tokens.
        inputs = inputs.permute(0, 1, 3, 2)  # Shape: [bs, f, c, n_token].
        inputs = self.layer_inputs(inputs)
        inputs = inputs.permute(0, 1, 3, 2)  # Shape: [bs, f, n_token, c].

        if self.use_normalization:
            inputs = self.fuser_mix_norm2(inputs)

        original = self.original_norm(original)
        mix = self.MlpBlock(original)     # Shape: [bs, f, hw, n_token].
        mix = self.sigmoid(mix)

        # Each position receives a sigmoid-gated sum of its frame's tokens.
        inputs = torch.einsum('...sc,...hs->...hc', inputs, mix)    # Shape: [bs, f, hw, c].
        inputs = self.drop_inputs(inputs)

        # Restore the caller's layout. A rank-4 `original` keeps its flat [bs, f, hw, c] form.
        return inputs.reshape(*original_shape[:-1], -1)

if __name__ == '__main__':
    # Smoke test on a video feature map of shape [B, F, H, W, C].
    # NOTE(review): the input alone is ~0.8 GB of float32 and the MLP
    # activations are larger still; shrink B for a quick local check.
    x = torch.rand(256, 16, 7, 7, 1024)             # [B, F, H, W, C]
    print("x.shape", x.shape)
    tklr = TokenLearner(d_model=1024, n_token=8)
    tklr_res = tklr(x)                              # torch.Size([256, 16, 8, 1024])  B, F, N, C
    print('tklr_res shape: ', tklr_res.shape)

    # TokenFuser returns only the fused term; add `x` yourself for the residual.
    tkfr = TokenFuser(1024, 8)
    tkfr_res = tkfr(tklr_res, x)                    # torch.Size([256, 16, 7, 7, 1024])
    print('tkfr_res shape: ', tkfr_res.shape)

    # Average over the spatial axes H, W.
    final_res = torch.mean(tkfr_res, dim=(2, 3))    # torch.Size([256, 16, 1024])  B, F, C
    print("final_res.shape", final_res.shape)