Closed tarv33 closed 6 months ago
If the two 1x1 convolutions are replaced with a single convolution, it might be a little faster.
class Block(nn.Module):
    """StarNet-style block: depthwise conv -> gated channel-MLP -> depthwise conv,
    wrapped in a residual connection with optional stochastic depth.

    Args:
        dim: number of input (and output) channels.
        mlp_ratio: channel expansion factor for the two gated branches.
        drop_path: stochastic-depth drop probability for the residual branch
            (0.0 disables it and uses an identity).
    """

    def __init__(self, dim, mlp_ratio=3, drop_path=0.0):
        super().__init__()
        # 7x7 depthwise conv (padding keeps spatial size), followed by BN.
        self.dwconv = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=True)
        # Fuse the former f1/f2 1x1 convs into a single conv that emits both
        # branches at once (one kernel launch instead of two).
        # NOTE: the original fused suggestion used groups=2, but that is NOT
        # equivalent to two separate full convs — each group would only see
        # half of the input channels and the parameter count would be halved.
        # groups=1 (the default) keeps the fusion mathematically equivalent
        # to concatenating the outputs of f1 and f2.
        self.f = ConvBN(dim, mlp_ratio * dim * 2, 1, with_bn=False)
        self.g = ConvBN(mlp_ratio * dim, dim, 1, with_bn=True)
        self.dwconv2 = ConvBN(dim, dim, 7, 1, (7 - 1) // 2, groups=dim, with_bn=False)
        self.act = nn.ReLU6()
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        # Channel width of each gated branch; used to split the fused output.
        self.c = mlp_ratio * dim

    def forward(self, x):
        """Apply the block to a feature map.

        Args:
            x: input tensor; assumed shape (batch, dim, H, W) — channel
                dimension must match ``dim`` passed to ``__init__``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        shortcut = x  # renamed from `input` to avoid shadowing the builtin
        x = self.dwconv(x)
        gx = self.f(x)  # single conv producing both gate and value branches
        x1, x2 = torch.split(gx, self.c, dim=1)
        x = self.act(x1) * x2  # gating: ReLU6(x1) elementwise-multiplied by x2
        x = self.dwconv2(self.g(x))
        return shortcut + self.drop_path(x)
@tarv33 Thanks for your kind suggestions.
However, this is hardware-platform dependent. This implementation could be faster on GPU but slower on iPhone devices (due to the split operation, if I remember correctly).
If the two 1x1 convolutions are replaced with a single convolution, it might be a little faster.