I wrote this version of the code with reference to the official open source code.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLP(nn.Module):
    """Very simple multi-layer perceptron (also called FFN)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k)
                                    for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.gelu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class MLPBlock(nn.Module):
    """Transformer MLP / feed-forward block.
    https://github.com/google-research/scenic/blob/5b5a78da05855dc8111aaaa68bd6e71c783e1422/scenic/model_lib/layers/attention_layers.py#L393
    """

    def __init__(self, mlp_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        actual_out_dim = mlp_dim if out_dim is None else out_dim
        self.layer_in = nn.Linear(mlp_dim, mlp_dim)
        self.activation_fn = F.gelu
        self.drop1 = nn.Dropout(dropout_rate)
        self.layer_out = nn.Linear(mlp_dim, actual_out_dim)
        self.drop2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.layer_in(x)
        x = self.activation_fn(x)
        x = self.drop1(x)
        x = self.layer_out(x)
        x = self.drop2(x)
        return x


class TokenLearner(nn.Module):
    def __init__(self, d_model, n_token=8):
        super().__init__()
        self.dropout_rate = 0.1
        self.num_tokens = n_token
        self.MlpBlock = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.layerNorm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(-1)

    def forward(self, inputs):
        """Applies learnable tokenization to the 2D inputs.
        Args:
            inputs: Inputs of shape `[bs, h, w, c]`.
        Returns:
            Output of shape `[bs, n_token, c]`.
        """
        bs, h, w, c = inputs.shape
        inputs = inputs.reshape(bs, h * w, c)
        feature_shape = inputs.shape

        selected = inputs
        selected = self.layerNorm(selected)
        selected = self.MlpBlock(selected)
        selected = selected.reshape(feature_shape[0], -1, self.num_tokens)  # Shape: [bs, h*w, n_token].
        selected = selected.permute(0, 2, 1)  # Shape: [bs, n_token, h*w].
        selected = self.softmax(selected)  # Softmax over the spatial axis.

        feat = inputs
        feat = torch.einsum('...si,...id->...sd', selected, feat)  # Shape: [bs, n_token, c].
        return feat


class TokenFuser(nn.Module):
    def __init__(self, d_model, num_tokens, use_normalization=True):
        super().__init__()
        self.num_tokens = num_tokens
        self.dropout_rate = 0.
        self.use_normalization = use_normalization
        self.fuser_mix_norm1 = nn.LayerNorm(d_model)
        self.layer_inputs = nn.Linear(num_tokens, num_tokens)
        self.fuser_mix_norm2 = nn.LayerNorm(d_model)
        self.original_norm = nn.LayerNorm(d_model)
        self.MlpBlock = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.drop_inputs = nn.Dropout(self.dropout_rate)

    def forward(self, inputs, original):
        """Applies token fusion to generate the 2D outputs.
        Args:
            inputs: Inputs of shape `[bs, n_token, c]`.
            original: Inputs of shape `[bs, h*w, c]` or `[bs, h, w, c]`.
        Returns:
            Output tensor with a shape identical to `original`.
        """
        was_4d = original.ndim == 4  # Only reshape back if the input was 4D.
        if was_4d:
            n, h, w, c = original.shape
            original = original.reshape(n, h * w, c)

        if self.use_normalization:
            inputs = self.fuser_mix_norm1(inputs)

        inputs = inputs.permute(0, 2, 1)  # Shape: [bs, c, n_token].
        inputs = self.layer_inputs(inputs)
        inputs = inputs.permute(0, 2, 1)  # Shape: [bs, n_token, c].

        if self.use_normalization:
            inputs = self.fuser_mix_norm2(inputs)

        original = self.original_norm(original)
        mix = self.MlpBlock(original)  # Shape: [bs, h*w, n_token].
        mix = self.sigmoid(mix)

        inputs = torch.einsum('...sc,...hs->...hc', inputs, mix)  # Shape: [bs, h*w, c].
        inputs = self.drop_inputs(inputs)
        if was_4d:
            inputs = inputs.reshape(n, h, w, -1)
        return inputs


if __name__ == '__main__':
    x = torch.rand(10, 64, 48, 96)  # [bs, h, w, c]
    tklr = TokenLearner(d_model=96, n_token=8)
    tklr_res = tklr(x)  # torch.Size([10, 8, 96]); [bs, n_token, c]
    print('tklr_res shape: ', tklr_res.shape)
    tkfr = TokenFuser(96, 8)
    tkfr_res = tkfr(tklr_res, x)  # torch.Size([10, 64, 48, 96])
    print('tkfr_res shape: ', tkfr_res.shape)
```
Can it be applied to video tasks simply by adding a time dimension? For the video task the paper requires multi-head attention; does that need an additional implementation? Or can you point out what is wrong in my code below? Thanks!
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPBlock(nn.Module):
    """Transformer MLP / feed-forward block.
    https://github.com/google-research/scenic/blob/5b5a78da05855dc8111aaaa68bd6e71c783e1422/scenic/model_lib/layers/attention_layers.py#L393
    """

    def __init__(self, mlp_dim, out_dim, dropout_rate=0.1):
        super().__init__()
        actual_out_dim = mlp_dim if out_dim is None else out_dim
        self.layer_in = nn.Linear(mlp_dim, mlp_dim)
        self.activation_fn = F.gelu
        self.drop1 = nn.Dropout(dropout_rate)
        self.layer_out = nn.Linear(mlp_dim, actual_out_dim)
        self.drop2 = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.layer_in(x)
        x = self.activation_fn(x)
        x = self.drop1(x)
        x = self.layer_out(x)
        x = self.drop2(x)
        return x


class TokenLearner(nn.Module):
    def __init__(self, d_model, n_token=8):
        super().__init__()
        self.dropout_rate = 0.1
        self.num_tokens = n_token
        self.MlpBlock = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.layerNorm = nn.LayerNorm(d_model)
        self.softmax = nn.Softmax(-1)

    def forward(self, inputs):
        """Applies learnable tokenization to the video inputs.
        Args:
            inputs: Inputs of shape `[bs, f, h, w, c]`.
        Returns:
            Output of shape `[bs, f, n_token, c]`.
        """
        bs, f, h, w, c = inputs.shape
        inputs = inputs.reshape(bs, f, h * w, c)
        feature_shape = inputs.shape

        selected = inputs
        selected = self.layerNorm(selected)
        selected = self.MlpBlock(selected)
        selected = selected.reshape(bs, feature_shape[1], -1, self.num_tokens)  # Shape: [bs, f, h*w, n_token].
        selected = selected.permute(0, 1, 3, 2)  # Shape: [bs, f, n_token, h*w].
        selected = self.softmax(selected)  # Softmax over the spatial axis.

        feat = inputs
        feat = torch.einsum('...si,...id->...sd', selected, feat)  # Shape: [bs, f, n_token, c].
        return feat


class TokenFuser(nn.Module):
    def __init__(self, d_model, num_tokens, use_normalization=True):
        super().__init__()
        self.num_tokens = num_tokens
        self.dropout_rate = 0.
        self.use_normalization = use_normalization
        self.fuser_mix_norm1 = nn.LayerNorm(d_model)
        self.layer_inputs = nn.Linear(num_tokens, num_tokens)
        self.fuser_mix_norm2 = nn.LayerNorm(d_model)
        self.original_norm = nn.LayerNorm(d_model)
        self.MlpBlock = MLPBlock(d_model, self.num_tokens, self.dropout_rate)
        self.sigmoid = nn.Sigmoid()
        self.drop_inputs = nn.Dropout(self.dropout_rate)

    def forward(self, inputs, original):
        """Applies token fusion to generate the video outputs.
        Args:
            inputs: Inputs of shape `[bs, f, n_token, c]`.
            original: Inputs of shape `[bs, f, h*w, c]` or `[bs, f, h, w, c]`.
        Returns:
            Output tensor with a shape identical to `original`.
        """
        was_5d = original.ndim == 5  # Only reshape back if the input was 5D.
        if was_5d:
            bs, f, h, w, c = original.shape
            original = original.reshape(bs, f, h * w, c)

        if self.use_normalization:
            inputs = self.fuser_mix_norm1(inputs)

        inputs = inputs.permute(0, 1, 3, 2)  # Shape: [bs, f, c, n_token].
        inputs = self.layer_inputs(inputs)
        inputs = inputs.permute(0, 1, 3, 2)  # Shape: [bs, f, n_token, c].

        if self.use_normalization:
            inputs = self.fuser_mix_norm2(inputs)

        original = self.original_norm(original)
        mix = self.MlpBlock(original)  # Shape: [bs, f, h*w, n_token].
        mix = self.sigmoid(mix)

        inputs = torch.einsum('...sc,...hs->...hc', inputs, mix)  # Shape: [bs, f, h*w, c].
        inputs = self.drop_inputs(inputs)
        if was_5d:
            inputs = inputs.reshape(bs, f, h, w, -1)
        return inputs


if __name__ == '__main__':
    x = torch.rand(256, 16, 7, 7, 1024)  # [bs, f, h, w, c]
    tklr = TokenLearner(d_model=1024, n_token=8)
    tklr_res = tklr(x)  # torch.Size([256, 16, 8, 1024]); [bs, f, n_token, c]
    print('tklr_res shape: ', tklr_res.shape)
    tkfr = TokenFuser(1024, 8)
    tkfr_res = tkfr(tklr_res, x)  # torch.Size([256, 16, 7, 7, 1024])
    print('tkfr_res shape: ', tkfr_res.shape)
    final_res = torch.mean(tkfr_res, dim=(2, 3))  # torch.Size([256, 16, 1024])
    print('final_res shape: ', final_res.shape)
```
In the paper, the fuser output is given as BY + X^j.
But in the code, the fuser output = BY + SpatialAttention(X^j) https://github.com/rish-16/tokenlearner-pytorch/blob/a6908107c5b53b837127806fc1d46c64694bffc5/tokenlearner_pytorch/tokenlearner_pytorch.py#L59-L62
Why does the residual connection add SpatialAttention(X^j) instead of X^j?
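To make the contrast concrete, this is how I read the two variants, reusing the 2D `TokenFuser` above (`spatial_attention` is only a placeholder for that repo's `SpatialAttention` module, not something defined here):

```python
import torch

bs, hw, c, n_token = 4, 64, 96, 8
x = torch.rand(bs, hw, c)       # X^j: [bs, h*w, c]
y = torch.rand(bs, n_token, c)  # Y: learned tokens, [bs, n_token, c]
fuser = TokenFuser(c, n_token)  # the 2D TokenFuser defined above

# The paper's equation, as I read it: B(Y, X^j) + X^j
out_paper = fuser(y, x) + x

# rish-16's code, as I read it: B(Y, X~) + X~, where X~ = SpatialAttention(X^j)
# x_tilde = spatial_attention(x)          # placeholder, not defined here
# out_code = fuser(y, x_tilde) + x_tilde
```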