Closed: nightsnack closed this issue 1 year ago.
For vision tasks, we have empirically found that Sub-LN performs better than Post-LN. Moreover, it has the same theoretical guarantee of training stability as DeepNet, so we suggest using Sub-LN rather than DeepNorm for vision tasks.
Indeed. I also found that Post-LN works on some CV tasks, such as supervised image classification, but on MIM (masked image modeling) it failed to converge. DeepNet also slows down convergence. I will try Sub-LN. Thanks for mentioning it.
For vision tasks, we have empirically found that Sub-LN performs better than Post-LN. Moreover, it has the same theoretical guarantee of training stability as DeepNet, so we suggest using Sub-LN rather than DeepNorm for vision tasks.
Hi Shuming, I implemented Sub-LN on BEiT v1 with the training settings from the appendix. However, I did not observe the magnitude of each block's feature map staying constant. On the contrary, the magnitude blows up as training goes on. If I build an extremely deep model like the one in DeepNet, the feature magnitude soon exceeds fp16's upper bound. The same phenomenon also occurs with MAE and with supervised training of ViT-Base.
As this repo does not provide examples for BEiT or other vision models, my implementation is based on unilm/beit, and I added Sub-LN to Attention and Mlp. For weight initialization, I follow the rules for the encoder-only architecture in the paper. Did I miss something?
For vision tasks, we have empirically found that Sub-LN performs better than Post-LN. Moreover, it has the same theoretical guarantee of training stability as DeepNet, so we suggest using Sub-LN rather than DeepNorm for vision tasks.
Hi Shuming, I implemented Sub-LN on BEiT v1 with the training settings from the appendix. However, I did not observe the magnitude of each block's feature map staying constant. On the contrary, the magnitude blows up as training goes on. If I build an extremely deep model like the one in DeepNet, the feature magnitude soon exceeds fp16's upper bound. The same phenomenon also occurs with MAE and with supervised training of ViT-Base.
As this repo does not provide examples for BEiT or other vision models, my implementation is based on unilm/beit, and I added Sub-LN to Attention and Mlp. For weight initialization, I follow the rules for the encoder-only architecture in the paper. Did I miss something?
This may be caused by an issue in the implementation. Could you provide the code for integrating Sub-LN into the BEiT codebase?
Yes.
class Mlp(nn.Module):
    """Two-layer feed-forward block with an optional Sub-LN between the layers.

    The forward pass computes ``drop(fc2(subln(act(fc1(x)))))``, where ``subln`` is a
    LayerNorm over the hidden width when ``subln=True`` and an identity otherwise.

    Args:
        in_features: size of the last dimension of the input.
        hidden_features: hidden width; falls back to ``in_features`` when falsy.
        out_features: output width; falls back to ``in_features`` when falsy.
        act_layer: activation module class, instantiated with no arguments.
        drop: dropout probability applied after ``fc2``.
        subln: if True, insert a LayerNorm on the activated hidden features.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., subln=False):
        super().__init__()
        width = hidden_features if hidden_features else in_features
        # Registration order (fc1, act, fc2, drop, subln) fixes the seeded init and state_dict key order.
        self.fc1 = nn.Linear(in_features, width)
        self.act = act_layer()
        self.fc2 = nn.Linear(width, out_features if out_features else in_features)
        self.drop = nn.Dropout(drop)
        # Sub-LN: normalize the activated hidden features right before the output projection.
        if subln:
            self.subln = nn.LayerNorm(width)
        else:
            self.subln = nn.Identity()

    def forward(self, x):
        """Apply fc1 -> activation -> (Sub-LN) -> fc2 -> dropout to ``x``."""
        hidden = self.act(self.fc1(x))
        # No dropout after the activation; the original snippet left it commented out,
        # noting this matches the original BERT implementation.
        hidden = self.subln(hidden)
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """BEiT-style multi-head self-attention with optional relative position bias and Sub-LN.

    A single fused ``qkv`` projection (no built-in bias) produces queries, keys and values.
    When ``qkv_bias`` is set, learnable biases are added to q and v only; k gets a zero bias.
    When ``window_size`` is given, a learnable table of relative position biases is indexed
    per (query, key) token pair. The table covers ``window_size[0] * window_size[1]`` patch
    tokens plus one extra leading token, which uses three dedicated table entries.
    With ``subln=True`` a LayerNorm is applied to the concatenated head outputs before
    ``out_proj``.
    """

    def __init__(
            self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
            proj_drop=0., window_size=None, attn_head_dim=None, subln = False):
        super().__init__()
        self.num_heads = num_heads
        per_head = dim // num_heads
        if attn_head_dim is not None:
            per_head = attn_head_dim
        inner_dim = per_head * self.num_heads
        self.scale = qk_scale or per_head ** -0.5

        # Fused q/k/v projection; its bias (if any) is assembled in forward().
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if not qkv_bias:
            self.q_bias = None
            self.v_bias = None
        else:
            self.q_bias = nn.Parameter(torch.zeros(inner_dim))
            self.v_bias = nn.Parameter(torch.zeros(inner_dim))

        if not window_size:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None
        else:
            self.window_size = window_size
            wh, ww = window_size[0], window_size[1]
            # (2*Wh-1)*(2*Ww-1) patch-to-patch offsets + 3 entries for the extra leading
            # token: token->patch, patch->token, token->token.
            self.num_relative_distance = (2 * wh - 1) * (2 * ww - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            grid = torch.stack(torch.meshgrid([torch.arange(wh), torch.arange(ww)]))  # 2, Wh, Ww
            flat = torch.flatten(grid, 1)  # 2, Wh*Ww
            rel = (flat[:, :, None] - flat[:, None, :]).permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            # Shift both offsets to be non-negative, then encode (dy, dx) as one row index.
            rel[:, :, 0] += wh - 1
            rel[:, :, 1] += ww - 1
            rel[:, :, 0] *= 2 * ww - 1

            n_tokens = wh * ww + 1
            index = torch.zeros(size=(n_tokens, ) * 2, dtype=rel.dtype)
            index[1:, 1:] = rel.sum(-1)  # Wh*Ww, Wh*Ww
            # Later writes override earlier ones at [0, 0].
            index[0, 0:] = self.num_relative_distance - 3
            index[0:, 0] = self.num_relative_distance - 2
            index[0, 0] = self.num_relative_distance - 1
            self.register_buffer("relative_position_index", index)

        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        # Sub-LN: normalize the concatenated head outputs before the output projection.
        self.subln = nn.LayerNorm(inner_dim) if subln else nn.Identity()

    def forward(self, x, rel_pos_bias=None):
        """Self-attention over ``x``; ``rel_pos_bias`` (optional) is added to the logits.

        ``x`` is unpacked as (batch, tokens, channels). The output has the same batch
        and token dimensions and ``dim`` channels.
        """
        batch, tokens, _ = x.shape

        fused_bias = None
        if self.q_bias is not None:
            # Keys get a fixed zero bias; only q and v biases are learnable.
            fused_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
        projected = F.linear(input=x, weight=self.qkv.weight, bias=fused_bias)
        projected = projected.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = projected[0], projected[1], projected[2]  # indexed, not tuple-unpacked, for torchscript

        logits = (q * self.scale) @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            side = self.window_size[0] * self.window_size[1] + 1
            table_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
                side, side, -1)  # Wh*Ww+1, Wh*Ww+1, nH
            logits = logits + table_bias.permute(2, 0, 1).contiguous().unsqueeze(0)
        if rel_pos_bias is not None:
            logits = logits + rel_pos_bias

        weights = self.attn_drop(logits.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        out = self.subln(out)
        return self.proj_drop(self.out_proj(out))
class VisionTransformerForMaskedImageModeling(nn.Module):
    """BEiT backbone for masked image modeling, with Sub-LN style weight rescaling.

    Images are embedded by ``PatchEmbed``. Masked patch positions are replaced by a learnable
    mask token and a class token is prepended. The sequence then runs through ``depth``
    ``Block`` layers and a final norm, and ``lm_head`` maps token features to ``vocab_size``
    logits.

    NOTE(review): no ``subln`` flag is forwarded to ``Block`` here, so whether the
    Attention/Mlp Sub-LN layers are enabled depends on ``Block``'s own defaults. Confirm
    against ``Block``. ``_rescale_weights`` is applied unconditionally at construction.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, vocab_size=8192, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=None, init_values=None, attn_head_dim=None,
                 use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, init_std=0.02, **kwargs):
        super().__init__()
        if norm_layer is None:
            # ``norm_layer(embed_dim)`` below raised TypeError for the default None.
            # Fall back to a plain LayerNorm; pass a custom factory to override (e.g. a different eps).
            norm_layer = nn.LayerNorm
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.total_depth = depth  # N in the Sub-LN rescaling factor sqrt(log(2N))

        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        if use_abs_pos_emb:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        else:
            self.pos_embed = None
        self.pos_drop = nn.Dropout(p=drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None,
                attn_head_dim=attn_head_dim,
            )
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        self.init_std = init_std
        self.lm_head = nn.Linear(embed_dim, vocab_size)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=self.init_std)
        trunc_normal_(self.cls_token, std=self.init_std)
        trunc_normal_(self.mask_token, std=self.init_std)
        trunc_normal_(self.lm_head.weight, std=self.init_std)
        self.apply(self._init_weights)
        # self.fix_init_weight()  # depth-based rescaling; disabled in favour of the Sub-LN rescaling below.
        self._rescale_weights()

    def fix_init_weight(self):
        """Divide each block's residual-branch output projections by sqrt(2 * layer_id).

        ``layer_id`` is 1-based. Currently not called from ``__init__``.
        """
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            # The Attention class in this file names its output projection ``out_proj``;
            # the previous ``layer.attn.proj`` would raise AttributeError.
            rescale(layer.attn.out_proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _rescale_weights(self):
        """Sub-LN initialization: scale selected projections by sqrt(log(2N)), N = ``total_depth``.

        Every parameter whose name contains ``fc1``, ``fc2`` or ``out_proj`` is scaled, as
        are the value rows of each fused ``qkv`` weight. Query and key rows are left unchanged.
        All updates are in place.
        """
        init_scale = math.sqrt(math.log(self.total_depth * 2))
        for name, p in self.named_parameters():
            if "fc1" in name or "fc2" in name or "out_proj" in name:
                p.data.mul_(init_scale)
            elif "qkv" in name:
                # Attention reshapes the fused output as (..., 3, heads, head_dim), so dim 0 of the
                # weight is laid out [q; k; v]. Scale only the value rows. This is derived from the
                # actual weight shape, so it also holds when attn_head_dim != embed_dim / num_heads.
                # The previous ``p.data.mul(...)`` was out of place, so its result was discarded.
                v_rows = p.shape[0] // 3
                p.data[-v_rows:].mul_(init_scale)

    def _init_weights(self, m):
        """Truncated-normal init for Linear/Conv2d weights; zero biases; unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=self.init_std)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=self.init_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_num_layers(self):
        """Number of transformer blocks."""
        return len(self.blocks)

    def forward_features(self, x, bool_masked_pos):
        """Embed patches, substitute the mask token at masked positions, and run all blocks.

        Returns normalized features for the class token followed by all patch tokens.
        """
        x = self.patch_embed(x, bool_masked_pos=bool_masked_pos)
        batch_size, seq_len, _ = x.size()

        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        mask_token = self.mask_token.expand(batch_size, seq_len, -1)

        # replace the masked visual tokens by mask_token (w is 1.0 at masked positions)
        w = bool_masked_pos.unsqueeze(-1).type_as(mask_token)
        x = x * (1 - w) + mask_token * w

        x = torch.cat((cls_tokens, x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            x = blk(x, rel_pos_bias=rel_pos_bias)
        return self.norm(x)

    def forward(self, x, bool_masked_pos, return_all_tokens=False):
        """Return ``lm_head`` logits for masked patch tokens, or for all patch tokens.

        The class token is dropped before the head. ``bool_masked_pos`` is used as a boolean
        index over the patch tokens (presumably shape (batch, num_patches); confirm with caller).
        """
        x = self.forward_features(x, bool_masked_pos=bool_masked_pos)
        x = x[:, 1:]
        if return_all_tokens:
            return self.lm_head(x)
        # return the masked tokens
        return self.lm_head(x[bool_masked_pos])
I commented out the call to fix_init_weight
in __init__ and disabled the layer scale (gamma) in Block when running these experiments.
Do the Post-LN, the residual-branch scaling, and the initialization in DeepNet also work for vision tasks, such as ImageNet classification and masked image modeling?