JieShibo / PETL-ViT

[ICCV 2023 & AAAI 2023] Binary Adapters & FacT, [Tech report] Convpass

FacT for Swin #9

Closed. RaptorMai closed this issue 1 year ago.

RaptorMai commented 1 year ago

Thank you very much for sharing the code. I am curious whether you plan to release FacT for Swin in the near future. I look forward to hearing back from you.

JieShibo commented 1 year ago

Hi @RaptorMai, to use FacT-TT on Swin-B, modify the fact_forward_attn, fact_forward_mlp, and set_FacT functions as follows. Note that both forward replacements read the shared FacT factors from a global named vit, which is created in the __main__ block at the end.

import math

import timm
import torch
import torch.nn as nn
from timm import create_model


def fact_forward_attn(self, x, mask):
    # Replacement forward for timm's WindowAttention. `vit` is the global
    # SwinTransformer built in the __main__ block below.
    B_, N, C = x.shape

    # Swin-B stage widths are 128, 256, 512, 1024; log2(C // 128) picks the
    # FacTu/FacTv pair shared by all blocks of the current stage.
    # (math.log2 is exact for powers of two, unlike math.log(x, 2).)
    idx = int(math.log2(C // 128))
    FacTu = vit.FacTu[idx]
    FacTv = vit.FacTv[idx]
    qkv = self.qkv(x)
    # Low-rank FacT-TT updates to the q, k, and v projections.
    q = FacTv(self.dp(self.q_FacTs(FacTu(x))))
    k = FacTv(self.dp(self.k_FacTs(FacTu(x))))
    v = FacTv(self.dp(self.v_FacTs(FacTu(x))))

    # Add the scaled updates on top of the frozen qkv projection.
    qkv += torch.cat([q, k, v], dim=2) * self.s
    qkv = qkv.reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)

    if mask is not None:
        nW = mask.shape[0]
        attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, self.num_heads, N, N)
    attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    proj = self.proj(x)
    # Low-rank FacT-TT update to the output projection.
    proj += FacTv(self.dp(self.proj_FacTs(FacTu(x)))) * self.s
    x = self.proj_drop(proj)
    return x
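
For intuition (this check is mine, not part of the original reply): with dropout disabled, each per-projection update is an ordinary linear map of rank at most dim, since FacTv(q_FacTs(FacTu(x))) equals x @ (V @ S @ U).T, where U, S, V are the three factor weight matrices. A quick sanity check:

# Hypothetical shape/rank check; U, S, V stand in for FacTu, q_FacTs, FacTv.
C, dim = 128, 8
U = nn.Linear(C, dim, bias=False)
S = nn.Linear(dim, dim, bias=False)
V = nn.Linear(dim, C, bias=False)
x = torch.randn(2, 49, C)  # 49 = 7x7 window tokens
delta_w = V.weight @ S.weight @ U.weight  # (C, C) matrix of rank <= dim
assert torch.allclose(V(S(U(x))), x @ delta_w.T, atol=1e-5)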

def fact_forward_mlp(self, x):
    # Replacement forward for timm's Mlp (hidden size 4C in Swin blocks).
    B, N, C = x.shape
    idx = int(math.log2(C // 128))
    FacTu = vit.FacTu[idx]
    FacTv = vit.FacTv[idx]
    h = self.fc1(x)  # (B, N, 4C)
    # fc1 update: compress to dim, expand to 4*dim, then map each of the
    # four dim-sized chunks back to C with the shared FacTv.
    h += FacTv(self.dp(self.fc1_FacTs(FacTu(x))).reshape(
        B, N, 4, self.dim)).reshape(
        B, N, 4 * C) * self.s
    x = self.act(h)
    x = self.drop(x)
    h = self.fc2(x)
    # fc2 update: apply FacTu to each C-sized chunk of the 4C hidden state,
    # mix with fc2_FacTs (4*dim -> dim), then expand back to C with FacTv.
    x = x.reshape(B, N, 4, C)
    h += FacTv(self.dp(self.fc2_FacTs(FacTu(x).reshape(
        B, N, 4 * self.dim)))) * self.s
    x = self.drop(h)
    return x

def set_FacT(model, dim=8, s=1):
    if type(model) == timm.models.swin_transformer.SwinTransformer:
        # One FacTu/FacTv pair per stage width (128, 256, 512, 1024 for
        # Swin-B); all blocks within a stage share them.
        model.FacTu = nn.ModuleList(
            [nn.Linear(128 * expand, dim, bias=False) for expand in [1, 2, 4, 8]])
        model.FacTv = nn.ModuleList(
            [nn.Linear(dim, 128 * expand, bias=False) for expand in [1, 2, 4, 8]])
        # Zero-init FacTv so the adapted model starts identical to the backbone.
        for layer in model.FacTv:
            nn.init.zeros_(layer.weight)
    for module in model.children():
        if type(module) == timm.models.swin_transformer.WindowAttention:
            module.q_FacTs = nn.Linear(dim, dim, bias=False)
            module.k_FacTs = nn.Linear(dim, dim, bias=False)
            module.v_FacTs = nn.Linear(dim, dim, bias=False)
            module.proj_FacTs = nn.Linear(dim, dim, bias=False)
            module.dp = nn.Dropout(0.1)
            module.s = s
            # Rebind forward to the FacT version defined above.
            module.forward = fact_forward_attn.__get__(module, module.__class__)
        elif type(module) == timm.models.layers.mlp.Mlp:
            module.fc1_FacTs = nn.Linear(dim, dim * 4, bias=False)
            module.fc2_FacTs = nn.Linear(4 * dim, dim, bias=False)
            module.dim = dim
            module.s = s
            module.dp = nn.Dropout(0.1)
            module.forward = fact_forward_mlp.__get__(module, module.__class__)
        elif len(list(module.children())) != 0:
            # Recurse into container modules (stages, blocks, ...).
            set_FacT(module, dim, s)
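
set_FacT only attaches the new parameters; for parameter-efficient tuning you would also freeze the backbone. A minimal sketch (this helper is mine, not from the repo), relying on the fact that every layer added above has 'FacT' in its parameter name:

def mark_only_fact_trainable(model):
    # Hypothetical helper: FacTu/FacTv and the per-block *_FacTs layers all
    # contain 'FacT' in their parameter names; everything else is frozen.
    for name, param in model.named_parameters():
        param.requires_grad = 'FacT' in name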

Load the pre-trained weights with:

if __name__ == '__main__':
    ...
    # The name matters: the FacT forward functions above read FacTu/FacTv
    # from this global `vit`.
    vit = create_model('swin_base_patch4_window7_224_in22k', drop_path_rate=0.1)
    vit.load_state_dict(torch.load('<path>/swin_base_patch4_window7_224_22k.pth')['model'])
    ...
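
The elided steps are not shown in the thread. Continuing the __main__ block, a minimal sketch of what could follow (attach FacT, freeze the backbone, and run a dummy forward pass to verify shapes):

    set_FacT(vit, dim=8, s=1)
    mark_only_fact_trainable(vit)  # hypothetical helper sketched above
    vit.eval()  # disables the FacT dropout layers for the check
    with torch.no_grad():
        out = vit(torch.randn(1, 3, 224, 224))
    print(out.shape)  # (1, num_classes) logits from the in22k head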