cxzhou95 / XLSR

PyTorch implementation of paper "Extremely Lightweight Quantization Robust Real-Time Single-Image Super Resolution for Mobile Devices"
MIT License
56 stars, 9 forks

Unoptimized Pytorch group conv2d block #7

Closed deepernewbie closed 2 years ago

deepernewbie commented 2 years ago

For those experiencing slow inference (and hence slow training) of the XLSR model, you can try the following custom module in place of PyTorch's built-in grouped convolution.

In the `GBlock` class, replace the line

```python
self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)
```

with

```python
self.conv0 = GConv2d(in_channels, out_channels, kernel_size=3, groups=groups)
```

where `GConv2d` is defined as:

```python
import torch
import torch.nn as nn

class GConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, groups=4):
        super(GConv2d, self).__init__()
        self.groups = groups
        # One independent convolution per group, each over 1/groups of the channels
        self.conv2d_block = nn.ModuleList(
            nn.Conv2d(in_channels=in_channels // groups,
                      out_channels=out_channels // groups,
                      kernel_size=kernel_size,
                      padding=kernel_size // 2)
            for _ in range(groups)
        )

    def forward(self, x):
        # Split channels into groups, convolve each chunk separately, re-concatenate
        return torch.cat(
            [filterg(xg) for filterg, xg in zip(self.conv2d_block,
                                                torch.chunk(x, self.groups, dim=1))],
            dim=1)
```
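To convince yourself that the two formulations compute the same thing, you can copy a grouped `nn.Conv2d`'s weight slices into per-group plain convolutions and compare outputs. This is a sanity-check sketch, not part of the original post; the names and sizes are arbitrary:

```python
import torch
import torch.nn as nn

groups, cin, cout = 4, 32, 32
grouped = nn.Conv2d(cin, cout, kernel_size=3, padding=1, groups=groups)

# One plain conv per group; copy the matching weight/bias slices over.
# grouped.weight has shape (cout, cin // groups, 3, 3), so group g owns rows
# [g * cout // groups : (g + 1) * cout // groups].
per_group = nn.ModuleList(
    nn.Conv2d(cin // groups, cout // groups, kernel_size=3, padding=1)
    for _ in range(groups)
)
with torch.no_grad():
    for g, conv in enumerate(per_group):
        lo, hi = g * cout // groups, (g + 1) * cout // groups
        conv.weight.copy_(grouped.weight[lo:hi])
        conv.bias.copy_(grouped.bias[lo:hi])

x = torch.randn(1, cin, 16, 16)
y_grouped = grouped(x)
y_split = torch.cat(
    [conv(xg) for conv, xg in zip(per_group, torch.chunk(x, groups, dim=1))],
    dim=1)
assert torch.allclose(y_grouped, y_split, atol=1e-5)
```

With matched weights the outputs agree to floating-point tolerance, so the list-of-convs version is a drop-in replacement up to initialization.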

Personally, I see almost a 2x speedup with this approach during training.

*The `groups` parameter is known to result in slow code, according to https://github.com/pytorch/pytorch/issues/18631
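If you want to measure the difference on your own hardware, a minimal CPU timing sketch (shapes and iteration counts are arbitrary assumptions, and results will vary by backend and PyTorch version) looks like this:

```python
import time
import torch
import torch.nn as nn

def avg_ms(fn, iters=30):
    # Warm up, then report average wall-clock time per call in milliseconds
    for _ in range(5):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3

x = torch.randn(1, 32, 128, 128)
grouped = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=4)
# The same computation expressed as 4 independent convolutions
convs = nn.ModuleList(nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(4))

with torch.no_grad():
    t_grouped = avg_ms(lambda: grouped(x))
    t_split = avg_ms(lambda: torch.cat(
        [c(xg) for c, xg in zip(convs, torch.chunk(x, 4, dim=1))], dim=1))

print(f"groups=4 nn.Conv2d: {t_grouped:.3f} ms, split convs: {t_split:.3f} ms")
```

Which variant wins depends on the convolution backend in use, so it is worth benchmarking both before committing to either.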