Open Mareeta26 opened 2 years ago
Hi!
You can replace any self-attention module with SimA. It may require some parameter tuning for specific models. We tried it on CvT, ViT and XCiT. It also works with the DINO loss (self-supervised). I plan to try it with MAE in the future. When you said decoder, which model were you referring to (e.g., the decoder of DETR or MAE)?
Thanks! Have a good day!
@soroush-abbasi Thanks for the reply. I meant the decoder of ConvTransformer, which incorporates convolutions into a transformer.
@soroush-abbasi Also, is it possible to share the code for ViT with SimA? Thank you in advance!
I guess it should work with the decoder of ConvTransformer. You can simply replace self-attention with SimA attention (the SimA class below). To run with the ViT/DeiT architecture, please replace these classes in sima.py as below (removing the LPI layer and the class attention layer):
import torch
import torch.nn as nn
import torch.nn.functional as F


class SimA(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # qk_scale and attn_drop are kept for interface compatibility; forward() does not use them.
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        # L1-normalize each channel of Q and K across the token axis (dim=-2).
        k = F.normalize(k, p=1.0, dim=-2)
        q = F.normalize(q, p=1.0, dim=-2)

        # Pick the cheaper multiplication order: (QK^T)V when N < head_dim, else Q(K^T V).
        if (N / (C // self.num_heads)) < 1:
            x = ((q @ k.transpose(-2, -1)) @ v).transpose(1, 2).reshape(B, N, C)
        else:
            x = (q @ (k.transpose(-2, -1) @ v)).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

class SimABlock(nn.Module):
    # DropPath and Mlp are assumed to be already imported in sima.py.
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 num_tokens=196, eta=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = SimA(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
            proj_drop=drop
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer,
                       drop=drop)

        # Per-channel residual scales; eta must be a number here (the None default would fail).
        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)

    def forward(self, x, H, W):
        # H and W are accepted so the call signature stays the same, but are not used.
        x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        return x

class SimAVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768,
                 depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None,
                 cls_attn_layers=2, use_pos=True, patch_proj='linear', eta=None, tokens_norm=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = ConvPatchEmbed(img_size=img_size, embed_dim=embed_dim,
                                          patch_size=patch_size)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [drop_path_rate for i in range(depth)]
        self.blocks = nn.ModuleList([
            SimABlock(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
                norm_layer=norm_layer, num_tokens=num_patches, eta=eta)
            for i in range(depth)])

        self.norm = norm_layer(embed_dim)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        self.pos_embeder = PositionalEncodingFourier(dim=embed_dim)
        self.use_pos = use_pos

        # Initialize the class token and the remaining weights
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token', 'dist_token'}

    def forward_features(self, x):
        B, C, H, W = x.shape

        x, (Hp, Wp) = self.patch_embed(x)

        if self.use_pos:
            pos_encoding = self.pos_embeder(B, Hp, Wp).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
            x = x + pos_encoding

        x = self.pos_drop(x)

        # Prepend the class token so it passes through every SimA block
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x, Hp, Wp)

        x = self.norm(x)[:, 0]
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        if self.training:
            return x, x
        else:
            return x
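
For readers wondering about the two branches in SimA.forward: matrix multiplication is associative, so both orderings give the same result and only the cost differs. The sketch below checks this numerically. The function name `sima_attention` and the tensor sizes are illustrative assumptions, not part of sima.py.

```python
import torch
import torch.nn.functional as F


def sima_attention(q, k, v):
    """SimA-style attention on tensors shaped (B, heads, N, D). Illustrative helper."""
    # L1-normalize each channel across the token axis, as SimA.forward does.
    q = F.normalize(q, p=1.0, dim=-2)
    k = F.normalize(k, p=1.0, dim=-2)
    n_tokens, head_dim = q.shape[-2], q.shape[-1]
    if n_tokens < head_dim:
        # (N x N) intermediate: cost grows like N^2 * D.
        return (q @ k.transpose(-2, -1)) @ v
    # (D x D) intermediate: cost grows like N * D^2.
    return q @ (k.transpose(-2, -1) @ v)


torch.manual_seed(0)
# 197 tokens (196 patches + class token) and 64 channels per head, chosen for illustration.
q, k, v = (torch.randn(2, 4, 197, 64, dtype=torch.float64) for _ in range(3))

qn = F.normalize(q, p=1.0, dim=-2)
kn = F.normalize(k, p=1.0, dim=-2)
left = (qn @ kn.transpose(-2, -1)) @ v    # (QK^T)V
right = qn @ (kn.transpose(-2, -1) @ v)   # Q(K^T V)
orders_match = torch.allclose(left, right)
out = sima_attention(q, k, v)
```

With N=197 and 64 channels per head, the second ordering avoids building the 197x197 attention matrix.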
@soroush-abbasi Thank you! So, don't we need the SimABlock class for ConvTransformer? What is its purpose? Can you please explain? The input to my self-attention module is a 5D tensor, e.g. 8, 19, 128, 16, 16. How should I modify the SimA class for such an input?
SimABlock is a regular transformer block that contains both a self-attention layer (SimA) and an MLP layer. As long as you replace the self-attention in your code with SimA, you should be fine, I guess. You need to figure out which dimension of your input is the sequence (N) and which is the token dimension (D). Or, if your features have already been split into multiple heads, you need to find the ordering of B (batch size), H (heads), D (dimension after splitting) and N (sequence length / number of tokens). Sometimes tokens are not flattened. For example, one can look at an image feature map as a set of tokens laid out in 2D. If you have a 512x16x16 feature map, you can flatten the last two dimensions to get a 512x256 tensor of tokens (D=512, N=256). I guess the last two dimensions are the image feature maps in your case, but I'm not sure. Unfortunately, I'm not familiar with ConvTransformer.
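
A minimal sketch of that flattening for the 5D example, assuming the layout is (batch, time, channels, height, width). The thread does not confirm that layout, and the helper names `to_tokens` and `from_tokens` are hypothetical. Each frame is treated as its own sequence by folding time into the batch axis, which gives the (B, N, C) shape that SimA.forward expects.

```python
import torch


def to_tokens(x):
    """(B, T, C, H, W) -> (B*T, H*W, C): one token per spatial location. Hypothetical helper."""
    b, t, c, h, w = x.shape
    tokens = x.reshape(b * t, c, h * w).transpose(1, 2)
    return tokens, (b, t, c, h, w)


def from_tokens(tokens, shape):
    """Inverse of to_tokens: (B*T, H*W, C) -> (B, T, C, H, W)."""
    b, t, c, h, w = shape
    return tokens.transpose(1, 2).reshape(b, t, c, h, w)


x = torch.randn(8, 19, 128, 16, 16)   # the shape from the question; layout is assumed
tokens, shape = to_tokens(x)          # tokens has N = 16*16 = 256 and D = 128
restored = from_tokens(tokens, shape)
```

Under this assumption, `tokens` could be passed to `SimA(dim=128, ...)` and the output mapped back with `from_tokens`. If attention should also mix information across the 19 frames, all T*H*W positions would have to be flattened into one sequence instead.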
Thanks!
Sure, thanks for the reply!!
@soroush-abbasi Can we use SimA if it's a masked self-attention?
Hi,
It's a little complicated. We normalize tokens in the channel dimension before doing the QKV dot product (in the code above, each channel of Q and K is L1-normalized across all tokens). Because of this normalization, each token has an effect on the other tokens. Therefore, if you want to mask tokens, you need to apply the masking before the L1 normalization. Please let me know if you have more questions.
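
One possible reading of "mask before L1 normalization", sketched for a padding-style mask only: zero out the masked tokens in Q and K first, so they contribute neither to the per-channel L1 norms nor to the K^T V product. The function name `masked_sima_attention` and the `keep` argument are assumptions for illustration. A causal mask, where every query sees a different set of keys, is not covered by this sketch.

```python
import torch
import torch.nn.functional as F


def masked_sima_attention(q, k, v, keep):
    """q, k, v: (B, heads, N, D); keep: (B, N) bool, True for real tokens. Illustrative helper."""
    m = keep[:, None, :, None].to(q.dtype)   # broadcast to (B, 1, N, 1)
    # Zero the masked tokens BEFORE normalizing so they do not affect the L1 norms.
    q = q * m
    k = k * m
    q = F.normalize(q, p=1.0, dim=-2)
    k = F.normalize(k, p=1.0, dim=-2)
    # Masked keys are all-zero rows, so they add nothing to K^T V.
    return q @ (k.transpose(-2, -1) @ v)


torch.manual_seed(0)
q, k, v = (torch.randn(1, 2, 6, 8, dtype=torch.float64) for _ in range(3))
keep = torch.tensor([[True, True, True, True, False, False]])

masked_out = masked_sima_attention(q, k, v, keep)

# Reference: run the same computation on the four real tokens only.
all_kept = torch.ones(1, 4, dtype=torch.bool)
reference = masked_sima_attention(q[:, :, :4], k[:, :, :4], v[:, :, :4], all_kept)
```

In this sketch the outputs at the real positions match what you would get by dropping the padded tokens entirely, and the outputs at the padded positions are zero.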
Thanks! Have a great day!
Hi, may I know whether I can use SimA instead of multi-head attention in the decoder to reduce complexity?
Thanks!