ma-xu / Rewrite-the-Stars

[CVPR 2024] Rewrite the Stars
Apache License 2.0

About the Star Opt #9

Closed: tarv33 closed this issue 6 months ago

tarv33 commented 6 months ago

If we replace the two convs with a single conv, it might be a little faster:

import torch
import torch.nn as nn
from timm.models.layers import DropPath

# ConvBN (a Conv2d + optional BatchNorm wrapper) is defined in starnet.py of this repo.
from starnet import ConvBN


class Block(nn.Module):
    def __init__(self, dim, mlp_ratio=3, drop_path=0.0):
        super().__init__()
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)

        # self.f1 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        # self.f2 = ConvBN(dim, mlp_ratio * dim, 1, with_bn=False)
        # Replacing the two 1x1 convs with a single conv (and splitting its output)
        # might be a little faster. groups=1 keeps the fused conv mathematically
        # equivalent to the two separate convs above; with groups=2 each half of the
        # output would only see half of the input channels.
        self.f = ConvBN(dim, mlp_ratio * dim * 2, 1, with_bn=False)

        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.c = mlp_ratio * dim

    def forward(self, x):
        input = x
        x = self.dwconv(x)
        # x1, x2 = self.f1(x), self.f2(x)
        gx = self.f(x)
        x1, x2 = torch.split(gx, self.c, dim=1)
        # Star operation: element-wise product of the two branches.
        x = self.act(x1) * x2
        x = self.dwconv2(self.g(x))
        x = input + self.drop_path(x)
        return x
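
A quick way to sanity-check the modified block is a forward pass on a dummy tensor; the input shape below is an arbitrary choice for illustration, not taken from the paper:

# Minimal smoke test (input shape is an arbitrary assumption).
x = torch.randn(1, 64, 56, 56)
block = Block(dim=64, mlp_ratio=3).eval()
with torch.no_grad():
    y = block(x)
print(y.shape)  # expected: torch.Size([1, 64, 56, 56])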
ma-xu commented 6 months ago

@tarv33 Thanks for your kind suggestions.

However, this is hardware-dependent. The fused implementation could be faster on GPU but slower on iPhone devices (due to the split operation, if I remember correctly).
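
Whether the fused-conv-plus-split variant actually helps can be checked per device with a rough micro-benchmark. The sketch below (tensor shape, warm-up, and iteration counts are arbitrary assumptions) compares the two separate 1x1 convs against one fused conv followed by torch.split:

import time
import torch
import torch.nn as nn

def bench(module, x, iters=100):
    # Crude wall-clock timing; synchronize when the tensor lives on a GPU.
    with torch.no_grad():
        for _ in range(10):  # warm-up
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(iters):
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()
    return (time.time() - t0) / iters

dim, ratio = 64, 3
x = torch.randn(8, dim, 56, 56)

# Variant A: two separate 1x1 convs (original f1/f2).
class TwoConvs(nn.Module):
    def __init__(self):
        super().__init__()
        self.f1 = nn.Conv2d(dim, ratio * dim, 1)
        self.f2 = nn.Conv2d(dim, ratio * dim, 1)
    def forward(self, x):
        return self.f1(x), self.f2(x)

# Variant B: one fused 1x1 conv followed by a split.
class FusedConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.f = nn.Conv2d(dim, ratio * dim * 2, 1)
    def forward(self, x):
        return torch.split(self.f(x), ratio * dim, dim=1)

print("two convs  :", bench(TwoConvs().eval(), x))
print("fused+split:", bench(FusedConv().eval(), x))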