K-H-Ismail / Dilated-Convolution-with-Learnable-Spacings-PyTorch

[ICLR 2023] "Dilated convolution with learnable spacings" Ismail Khalfaoui Hassani, Thomas Pellegrini and Timothée Masquelier
MIT License

How can I modify the program to accommodate non-square kernels? #5

Closed: Lutong-Qin closed this issue 5 months ago.

Lutong-Qin commented 10 months ago

Hello, thank you very much for your excellent work. I want to build a Dcls2d convolution with a non-square kernel and unequal stride, padding and dilation. However, it seems that kernel_count only supports an int. Is it feasible to modify it to support non-square kernels?

```python
import torch
from DCLS.construct.modules import Dcls2d

# non-square kernels and unequal stride and with padding and dilation
m = Dcls2d(16, 33, kernel_count=(3, 1), dilated_kernel_size=(7, 1), stride=(2, 1), padding=(1, 0))
input = torch.randn(20, 16, 50, 100)
output = m(input)
loss = output.sum()
loss.backward()
print(output, m.weight.grad, m.P.grad)
```

I look forward to your response!

K-H-Ismail commented 10 months ago

Hello, actually, kernel_count is the number of kernel elements inside the kernel; it should be an int greater than or equal to 1. dilated_kernel_size controls the bounds of the kernel. So if you want a rectangular kernel, set dilated_kernel_size=(5, 7), for example, and then specify the number of kernel elements inside it with kernel_count=10, for example.
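
To make the distinction concrete, here is a toy sketch in plain Python, not the DCLS implementation: kernel_count weights are scattered into a dense kernel whose shape is dilated_kernel_size. The helper name place_kernel_elements and the top-left origin for positions are hypothetical, and positions are simply rounded here, whereas DCLS interpolates so that the positions can be learned by gradient descent.

```python
def place_kernel_elements(weights, positions, dilated_kernel_size):
    """Toy scatter of kernel_count weights into a dense kernel (hypothetical helper).

    weights: list of kernel_count floats
    positions: list of kernel_count (row, col) floats, origin at the top-left
    dilated_kernel_size: (height, width) of the dense kernel
    """
    if len(weights) != len(positions) or len(weights) < 1:
        raise ValueError("need one position per weight and kernel_count >= 1")
    height, width = dilated_kernel_size
    kernel = [[0.0] * width for _ in range(height)]
    for weight, (row, col) in zip(weights, positions):
        # Round to the nearest cell and clamp into the dilated_kernel_size box
        # (DCLS interpolates instead of rounding)
        r = min(max(round(row), 0), height - 1)
        c = min(max(round(col), 0), width - 1)
        kernel[r][c] += weight  # elements landing on the same cell add up
    return kernel


# kernel_count = 10 elements inside a rectangular 5 x 7 dilated kernel
kernel_count = 10
weights = [1.0] * kernel_count
positions = [(0.2, 0.4), (1.6, 3.1), (4.0, 6.0), (2.5, 2.5), (3.3, 5.8),
             (0.9, 6.4), (4.4, 0.1), (2.0, 4.7), (1.1, 1.9), (3.7, 3.2)]
kernel = place_kernel_elements(weights, positions, (5, 7))
for row in kernel:
    print(" ".join(f"{v:.0f}" for v in row))
```

The point of the sketch is that kernel_count (how many elements) and dilated_kernel_size (the box they live in) are independent: only the kernel_count weights and their positions are parameters, and every other cell of the 5 x 7 kernel stays zero.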

K-H-Ismail commented 10 months ago

In your example, the problem is rather the width of 1 in dilated_kernel_size=(7, 1). Is that really your use case?

Lutong-Qin commented 10 months ago

Hello, I actually want to replace the nn.Conv2d layers in my original model with your Dcls2d. I may have misunderstood the 'kernel_count'; initially, I thought it was equivalent to 'kernel_size'. In my original network model, the convolutional layer is defined as follows: nn.Conv2d(16, 33, kernel_size=(3, 1), dilation=(7, 1), stride=(2, 1), padding=(1, 0)). Now, I think I understand that this convolutional layer should be equivalent to: Dcls2d(16, 33, kernel_count=3, dilated_kernel_size=(7, 1), stride=(2, 1), padding=(1, 0)). Am I correct in this interpretation? The convolutional kernels in my model are indeed elongated, with a kernel size of (7, 1).

tmasquelier commented 10 months ago

Hi @Lutong-Qin, first of all, thanks a lot for your interest in our work! I think you should increase dilated_kernel_size. Think in terms of receptive field (RF): your Conv2d layer with kernel_size=(3, 1) and dilation=(7, 1) has an RF of (15, 1), since 7 * (3 - 1) + 1 = 15, so I would use at least that for dilated_kernel_size. kernel_count is a free parameter. For the stride, I would use the same as in your Conv2d layer (stride=(2, 1)). You may want to increase the padding to get the same output size. I hope this helps. Do not hesitate if you have other questions!
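
The receptive-field and output-size arithmetic behind this advice can be checked in a few lines of plain Python. This is a sketch that assumes Dcls2d yields the same output size as a standard convolution whose kernel spans dilated_kernel_size; the helper names are hypothetical, and the 50 x 100 input is taken from the example in the first post.

```python
def receptive_field(kernel_size, dilation):
    """Per-axis extent of a dilated convolution kernel: d * (k - 1) + 1."""
    return tuple(d * (k - 1) + 1 for k, d in zip(kernel_size, dilation))


def conv_output_size(size, extent, stride, padding):
    """Per-axis output size of a convolution (same formula as nn.Conv2d),
    written in terms of the kernel's spatial extent."""
    return tuple((n + 2 * p - e) // s + 1
                 for n, e, s, p in zip(size, extent, stride, padding))


rf = receptive_field(kernel_size=(3, 1), dilation=(7, 1))
print(rf)  # (15, 1): use at least this as dilated_kernel_size

input_hw = (50, 100)
print(conv_output_size(input_hw, rf, stride=(2, 1), padding=(1, 0)))  # (19, 100)
# For an odd extent, padding of extent // 2 gives ceil(input / stride)
print(conv_output_size(input_hw, rf, stride=(2, 1), padding=(rf[0] // 2, rf[1] // 2)))  # (25, 100)
```

With the original padding=(1, 0) the 15-tall kernel shrinks the height from 50 to 19, while padding=(7, 0) keeps it at 25, i.e. only the stride halves it.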

K-H-Ismail commented 10 months ago

@Lutong-Qin I've released a new module called Dcls2dK1d, which is better suited to your use case of flat convolutions because it avoids creating positions along the non-flat dimension. Here is an example:

```python
import torch
from DCLS.construct.modules import Dcls2dK1d

m = Dcls2dK1d(
    out_channels=32,
    in_channels=32,
    kernel_count=3,
    dilated_kernel_size=11,
    flat_dim=0,  # the flat dimension; here this is equivalent to torch.nn.Conv2d with kernel_size=(1, 11)
    groups=1,
    padding=(0, 11 // 2),
)
# The last dimension of the input is always where positions are learned
input = torch.randn(8, 32, 56, 56)
output = m(input)
loss = output.sum()
loss.backward()
print(output.size(), m.weight.grad.size(), m.P.grad.size())
```

You can use either this one, which is cleaner, or Dcls2d.

K-H-Ismail commented 10 months ago

Also, to make your life easier, here is a snippet that replaces all the flat 2D convolutions, (1, 7) and (7, 1), in a model with the corresponding Dcls2dK1d. It synchronizes the positions along H and W separately for all convolutions within the same stage. We have tried it on the ParCNetV2 model, which uses flat convolutions.

```python
import copy
from functools import partial

import torch.nn as nn
from timm.models import register_model  # assumption: timm's model registry decorator
from DCLS.construct.modules import Dcls2dK1d

# ParCNetV2, ParC_V2 and BGU come from the ParCNetV2 code base
# (their import path is not shown here)


# Helper function that turns a dotted module name such as
# "stages.0.blocks.1.conv" into the parent expression
# ".stages[0].blocks[1]" plus the child name "conv"
def replace_dots_brackets(name):
    name_split = name.split(".")
    parent = ["[" + i + "]" if i.isdigit() else "." + i for i in name_split[:-1]]
    # Return the last component unchanged, so that numeric child names
    # (e.g. "0" inside an nn.Sequential) also work with setattr
    return "".join(parent), name_split[-1]


# Helper function that replaces all the flat 2D depthwise
# convolutions, (1, k) or (k, 1), in a model by synchronized
# Dcls2dK1d ones
def replace_depthwise_dcls(model, version="gauss"):
    in_channelsh, in_channelsw, Ph, SIGh, Pw, SIGw = 0, 0, None, None, None, None
    # Loop over a snapshot of the model modules, since the model is modified
    for name, module in list(model.named_modules()):
        # If the module is a flat depthwise Conv2d module
        # (1 x 1 convolutions are skipped: they are not flat kernels)
        if (
            isinstance(module, nn.Conv2d)
            and module.groups == module.in_channels == module.out_channels
            and len(module.kernel_size) == 2
            and min(module.kernel_size) == 1
            and max(module.kernel_size) > 1
        ):
            print(name)  # log each replaced layer
            name_eval, last_layer = replace_dots_brackets(name)
            dilated_kernel_size = max(module.kernel_size)
            dcls_conv = Dcls2dK1d(
                module.in_channels,
                module.out_channels,
                # kernel_count must be >= 1, even for small kernels
                kernel_count=max(1, dilated_kernel_size // 3),
                dilated_kernel_size=dilated_kernel_size,
                padding=module.padding,
                groups=module.groups,
                version=version,
                bias=module.bias is not None,
                flat_dim=module.kernel_size.index(dilated_kernel_size),
            )

            nn.init.normal_(dcls_conv.weight, std=0.02)
            if module.bias is not None:
                nn.init.constant_(dcls_conv.bias, 0)

            # Synchronize positions and standard deviations belonging to the
            # same stage: a new stage is detected when in_channels grows
            if in_channelsh < module.in_channels and module.kernel_size[0] == 1:
                in_channelsh = module.in_channels
                Ph, SIGh = dcls_conv.P, dcls_conv.SIG

            if in_channelsw < module.in_channels and module.kernel_size[-1] == 1:
                in_channelsw = module.in_channels
                Pw, SIGw = dcls_conv.P, dcls_conv.SIG

            if module.kernel_size[0] == 1:
                dcls_conv.P, dcls_conv.SIG = Ph, SIGh

            if module.kernel_size[-1] == 1:
                dcls_conv.P, dcls_conv.SIG = Pw, SIGw

            setattr(eval("model" + name_eval), last_layer, dcls_conv)
    return model


@register_model
def parcnetv2_dcls_tiny(pretrained=False, **kwargs):
    model = ParCNetV2(
        depths=[3, 3, 12, 3],
        dims=[64, 128, 320, 512],
        token_mixers=ParC_V2,
        mlps=partial(BGU, mlp_ratio=5),
        **kwargs,
    )
    model = replace_depthwise_dcls(copy.deepcopy(model), version="gauss")
    return model
```

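
The position-sharing rule in replace_depthwise_dcls is easy to misread, so here is a plain-Python sketch that replays only its bookkeeping on a hypothetical list of flat depthwise convolutions (the names and channel counts are made up); it needs neither DCLS nor PyTorch. Each conv is mapped to the name of the conv whose P/SIG it would end up using.

```python
def shared_position_owner(convs):
    """Replay the P/SIG sharing rule of replace_depthwise_dcls.

    convs: (name, in_channels, kernel_size) tuples in named_modules() order,
    with kernel_size of the form (1, k) or (k, 1), k > 1.
    Returns {name: name of the conv whose P/SIG it ends up using}.
    """
    in_channels_h = in_channels_w = 0
    owner_h = owner_w = None
    owners = {}
    for name, in_channels, kernel_size in convs:
        # (1, k) kernels: a new position set whenever in_channels grows
        if kernel_size[0] == 1:
            if in_channels_h < in_channels:
                in_channels_h, owner_h = in_channels, name
            owners[name] = owner_h
        # (k, 1) kernels: same rule, tracked separately
        if kernel_size[-1] == 1:
            if in_channels_w < in_channels:
                in_channels_w, owner_w = in_channels, name
            owners[name] = owner_w
    return owners


# Hypothetical model: two blocks in a 64-channel stage, one in a 128-channel stage
convs = [
    ("stage0.block0.conv_h", 64, (1, 7)),
    ("stage0.block0.conv_w", 64, (7, 1)),
    ("stage0.block1.conv_h", 64, (1, 7)),
    ("stage0.block1.conv_w", 64, (7, 1)),
    ("stage1.block0.conv_h", 128, (1, 7)),
    ("stage1.block0.conv_w", 128, (7, 1)),
]
for name, owner in shared_position_owner(convs).items():
    print(f"{name} -> {owner}")
```

Because a new position set is only created when in_channels grows, the rule relies on channel counts strictly increasing from one stage to the next, as with dims=[64, 128, 320, 512] in the ParCNetV2 example; two consecutive stages with equal widths would end up sharing positions.
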
Lutong-Qin commented 10 months ago

I see: 'dilated_kernel_size' is the effective kernel size of the dilated convolution. So, in form, the Dcls2d layers below are equivalent to my original nn.Conv2d layers with the same kernel_size, stride, padding and dilation:

```python
import torch.nn as nn
import torch.nn.functional as F
from DCLS.construct.modules import Dcls2d


class CNN(nn.Module):
    def __init__(self, input_channel, num_classes):
        super(CNN, self).__init__()

        self.layer1 = self._make_layers(input_channel, 64, (3, 1), (2, 1), (1, 0), (3, 1))
        self.layer2 = self._make_layers(64, 128, (3, 1), (2, 1), (1, 0), (3, 1))
        self.layer3 = self._make_layers(128, 256, (3, 1), (2, 1), (1, 0), (3, 1))
        # Assumes the pooled feature map is 4 x 1 (this depends on the input size)
        self.fc = nn.Linear(256 * 4 * 1, num_classes)

    def _make_layers(self, input_channel, output_channel, kernel_size, stride, padding, dilation):
        # Effective (dilated) kernel size per axis: k + (k - 1) * (d - 1)
        dilated_kernel_size = tuple(
            k + (k - 1) * (d - 1) for k, d in zip(kernel_size, dilation)
        )
        return nn.Sequential(
            Dcls2d(input_channel, output_channel, kernel_count=3,
                   dilated_kernel_size=dilated_kernel_size,
                   stride=stride,
                   padding=padding),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = F.max_pool2d(x, (4, 3))
        x = x.view(x.size(0), -1)
        out = self.fc(x)
        return out
```
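
The dilated_kernel_size expression used in _make_layers above is the same receptive-field formula as d(k - 1) + 1, just written differently. Per axis, with kernel size $k$ and dilation $d$:

$$
k + (k - 1)(d - 1) = k + kd - k - d + 1 = d(k - 1) + 1
$$

For the layers above, kernel_size=(3, 1) and dilation=(3, 1) give dilated_kernel_size=(7, 1).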