If you experience slow inference (and hence slow training) with the XLSR module, you can try the following custom module instead of PyTorch's built-in grouped convolution:
import torch
import torch.nn as nn

class GConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, groups=4):
        super(GConv2d, self).__init__()
        self.groups = groups
        # One independent convolution per channel group
        self.conv2d_block = nn.ModuleList()
        for _ in range(groups):
            self.conv2d_block.append(
                nn.Conv2d(in_channels=in_channels // groups,
                          out_channels=out_channels // groups,
                          kernel_size=kernel_size,
                          padding=kernel_size // 2))

    def forward(self, x):
        # Split the input along the channel dimension, run each chunk through
        # its own convolution, then concatenate the results back together
        return torch.cat([filterg(xg) for filterg, xg in
                          zip(self.conv2d_block, torch.chunk(x, self.groups, 1))],
                         dim=1)
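As a quick sanity check (my own sketch, not part of the original workaround), the chunk-and-concatenate trick computes the same result as a built-in grouped convolution once the weights are shared between the two, since each group convolves its own slice of input channels:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 16, 32, 32)

# Built-in grouped convolution: 16 -> 32 channels in 4 groups
gconv = nn.Conv2d(16, 32, kernel_size=3, padding=1, groups=4)

# Manual equivalent: 4 independent convs over 4-channel chunks
convs = nn.ModuleList([nn.Conv2d(4, 8, kernel_size=3, padding=1) for _ in range(4)])

# Copy weights/biases from the grouped conv so outputs are comparable;
# group g owns output channels g*8:(g+1)*8 of the grouped conv's weight
with torch.no_grad():
    for g, c in enumerate(convs):
        c.weight.copy_(gconv.weight[g * 8:(g + 1) * 8])
        c.bias.copy_(gconv.bias[g * 8:(g + 1) * 8])

y_builtin = gconv(x)
y_manual = torch.cat([c(xg) for c, xg in zip(convs, torch.chunk(x, 4, 1))], dim=1)

print(torch.allclose(y_builtin, y_manual, atol=1e-5))  # → True
```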
In my experience, this gives almost a 2x speedup during training.
To use it, change this line in the GBlock class:

self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, groups=groups)

to

self.conv0 = GConv2d(in_channels, out_channels, kernel_size=3, groups=groups)

where GConv2d is the module defined above.
*The groups parameter is known to result in slow code, according to [https://github.com/pytorch/pytorch/issues/18631]
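The actual speedup depends on hardware, backend, and tensor sizes, so it is worth measuring on your own setup. A minimal timing sketch (input shape and iteration counts are arbitrary choices of mine):

```python
import time
import torch
import torch.nn as nn

def bench(fn, x, iters=50):
    # Warm up, then average wall-clock time per forward pass
    with torch.no_grad():
        for _ in range(5):
            fn(x)
        start = time.perf_counter()
        for _ in range(iters):
            fn(x)
    return (time.perf_counter() - start) / iters

x = torch.randn(8, 32, 64, 64)

# Built-in grouped conv vs. the split-and-concatenate workaround
grouped = nn.Conv2d(32, 32, kernel_size=3, padding=1, groups=4)
split = nn.ModuleList([nn.Conv2d(8, 8, kernel_size=3, padding=1) for _ in range(4)])

t_grouped = bench(grouped, x)
t_split = bench(
    lambda xx: torch.cat([c(xg) for c, xg in zip(split, torch.chunk(xx, 4, 1))], dim=1),
    x)
print(f"grouped: {t_grouped * 1e3:.2f} ms, split: {t_split * 1e3:.2f} ms")
```

No expected ratio is printed on purpose: on some backends the built-in grouped conv is already fast, and the workaround only pays off where the linked issue applies.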