Stability-AI / stablediffusion

High-Resolution Image Synthesis with Latent Diffusion Models

What's the idea behind asymmetric padding during downsampling? #355

Arksyd96 opened this issue 6 months ago

Arksyd96 commented 6 months ago

What's the idea behind this downsampling with asymmetric padding? Why don't we just use a symmetric padding of 1? Everything would fit perfectly.

import torch
import torch.nn as nn

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            # pad one pixel on the right and bottom only: (left, right, top, bottom)
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x
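
To make the question concrete, here is a minimal sketch (tensor shapes are arbitrary, chosen just for illustration) comparing symmetric padding=1 against the (0,1,0,1) scheme above. With kernel 3 and stride 2 the output height is floor((H + pad_total - 3)/2) + 1, so for an even-sized input both variants produce the same output size:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 64, 64)   # arbitrary example input
w = torch.randn(4, 4, 3, 3)     # arbitrary 3x3 conv weights

# symmetric: pad 1 on all four sides
y_sym = F.conv2d(x, w, stride=2, padding=1)

# asymmetric: pad 1 on the right and bottom only, then no conv padding
y_asym = F.conv2d(F.pad(x, (0, 1, 0, 1)), w, stride=2)

print(y_sym.shape)                    # torch.Size([1, 4, 32, 32])
print(y_asym.shape)                   # torch.Size([1, 4, 32, 32]) -- same size
print(torch.allclose(y_sym, y_asym))  # False -- the windows are shifted by one pixel

The two outputs have identical shapes but differ numerically: symmetric padding centers the first window on a zero-padded top-left corner, while (0,1,0,1) keeps the first window flush with the unpadded top-left and pushes the extra zeros to the bottom-right, which is how TensorFlow's "SAME" padding distributes an odd amount of padding.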
ChaosAdmStudent commented 4 months ago

Did you get an answer to this question? I have been wondering myself.

Arksyd96 commented 4 months ago

No, unfortunately not yet!

tian-2024 commented 3 months ago

> What's the idea behind this downsampling with asymmetric padding? Why don't we just use a symmetric padding of 1? Everything would fit perfectly.

I'm also wondering why.