Hi @RaptorMai,

To use FacT-TT on Swin-B, please modify the `fact_forward_attn`, `fact_forward_mlp`, and `set_FacT` functions as follows:
```python
import math

import timm
import torch
import torch.nn as nn
from timm.models import create_model


def fact_forward_attn(self, x, mask):
    B_, N, C = x.shape
    # Swin-B stage widths are 128, 256, 512, 1024, so this maps C to the stage index 0-3
    idx = int(math.log(C / 128, 2))
    # `vit` is the Swin model created at module level (see the loading snippet below)
    FacTu = vit.FacTu[idx]
    FacTv = vit.FacTv[idx]
    qkv = self.qkv(x)
    # low-rank FacT-TT updates for q, k, and v
    q = FacTv(self.dp(self.q_FacTs(FacTu(x))))
    k = FacTv(self.dp(self.k_FacTs(FacTu(x))))
    v = FacTv(self.dp(self.v_FacTs(FacTu(x))))
    qkv += torch.cat([q, k, v], dim=2) * self.s
    qkv = qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)

    if mask is not None:
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)
    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    proj = self.proj(x)
    proj += FacTv(self.dp(self.proj_FacTs(FacTu(x)))) * self.s
    x = self.proj_drop(proj)
    return x
```
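For reference, `idx` selects the shared per-stage factors; a quick sanity check of the formula for Swin-B's channel widths (illustration only, assuming `math` is imported as above):

```python
# stage width C -> index into vit.FacTu / vit.FacTv
for C in [128, 256, 512, 1024]:
    print(C, int(math.log(C / 128, 2)))  # -> 0, 1, 2, 3
```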
```python
def fact_forward_mlp(self, x):
    B, N, C = x.shape
    idx = int(math.log(C / 128, 2))
    FacTu = vit.FacTu[idx]
    FacTv = vit.FacTv[idx]
    h = self.fc1(x)  # B, N, 4*C
    # low-rank update for fc1: expand to 4 rank-dim chunks, then project each back to C
    h += FacTv(self.dp(self.fc1_FacTs(FacTu(x))).reshape(
        B, N, 4, self.dim)).reshape(
        B, N, 4 * C) * self.s
    x = self.act(h)
    x = self.drop(x)
    h = self.fc2(x)
    # low-rank update for fc2: fold the 4*C features back through the shared factors
    x = x.reshape(B, N, 4, C)
    h += FacTv(self.dp(self.fc2_FacTs(FacTu(x).reshape(
        B, N, 4 * self.dim)))) * self.s
    x = self.drop(h)
    return x
```
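(Note the reshapes: `fc1`'s 4*C-dimensional hidden features are treated as four chunks of size C so the same per-stage `FacTu`/`FacTv` factors can be reused inside the MLP; this cross-layer sharing of the factors is what makes FacT-TT parameter-efficient.)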
```python
def set_FacT(model, dim=8, s=1):
    if type(model) == timm.models.swin_transformer.SwinTransformer:
        # one shared (FacTu, FacTv) pair per Swin stage width: 128, 256, 512, 1024
        model.FacTu = [nn.Linear(128 * expand, dim, bias=False) for expand in [1, 2, 4, 8]]
        model.FacTv = [nn.Linear(dim, 128 * expand, bias=False) for expand in [1, 2, 4, 8]]
        for weight in model.FacTv:
            nn.init.zeros_(weight.weight)  # zero init: training starts from the pre-trained model
        model.FacTu = nn.ModuleList(model.FacTu)
        model.FacTv = nn.ModuleList(model.FacTv)
    for _ in model.children():
        if type(_) == timm.models.swin_transformer.WindowAttention:
            _.q_FacTs = nn.Linear(dim, dim, bias=False)
            _.k_FacTs = nn.Linear(dim, dim, bias=False)
            _.v_FacTs = nn.Linear(dim, dim, bias=False)
            _.proj_FacTs = nn.Linear(dim, dim, bias=False)
            _.dp = nn.Dropout(0.1)
            _.s = s
            bound_method = fact_forward_attn.__get__(_, _.__class__)
            setattr(_, 'forward', bound_method)
        elif type(_) == timm.models.layers.mlp.Mlp:
            _.fc1_FacTs = nn.Linear(dim, dim * 4, bias=False)
            _.fc2_FacTs = nn.Linear(4 * dim, dim, bias=False)
            _.dim = dim
            _.s = s
            _.dp = nn.Dropout(0.1)
            bound_method = fact_forward_mlp.__get__(_, _.__class__)
            setattr(_, 'forward', bound_method)
        elif len(list(_.children())) != 0:
            set_FacT(_, dim, s)
```
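(The `__get__` call binds the plain function as a method of the existing module instance, so the patched `forward` can access the module's own attributes such as `self.qkv` and `self.q_FacTs` without subclassing.)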
Load the pre-trained weights with:
```python
if __name__ == '__main__':
    ...
    vit = create_model('swin_base_patch4_window7_224_in22k', drop_path_rate=0.1)
    vit.load_state_dict(torch.load('<path>/swin_base_patch4_window7_224_22k.pth')['model'])
    ...
```
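A minimal sketch of how the pieces fit together after loading; the freezing logic below is my assumption of the usual FacT training setup (train only the inserted factors), not part of the original snippet:

```python
set_FacT(vit, dim=8, s=1)

# Assumption: freeze the pre-trained backbone and train only the FacT parameters,
# which all contain 'FacT' in their names (FacTu, FacTv, q_FacTs, fc1_FacTs, ...).
for name, param in vit.named_parameters():
    param.requires_grad = 'FacT' in name

trainable = sum(p.numel() for p in vit.parameters() if p.requires_grad)
print(f'trainable parameters: {trainable}')
```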
Thank you very much for sharing the code. I am curious whether you will share FacT for Swin in the near future. Looking forward to hearing back from you.