yaoppeng / U-Net_v2


Small object segmentation results may not be good; how to enhance? #11

Closed · lhwcv closed this 9 months ago

lhwcv commented 9 months ago

Thank you for your work. However, I found that U-Net_v2 is not better than a full-resolution U-Net baseline for small object segmentation, for example in vessel segmentation tasks. How can I modify the encoder to output full resolution instead of 4x downsampled features? Could you give an example, so we can compare more results across different tasks?

Thank you again!

yaoppeng commented 9 months ago

You need to avoid too many downsamplings. In my dataset, the target objects are relatively large, so I used the PVT encoder, which downsamples the input by a factor of 4 in its first stage, to reduce computation. For small objects, you should instead keep a full-resolution stage. For example, when employing a plain convolutional encoder as in U-Net, the network definition would be as follows:

import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        # Stage 1: stride 1 throughout, so f1 keeps the full input resolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(1e-2),
            nn.Conv2d(32, 32, 3, 1, 1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(1e-2)
        )

        # Stage 2: the first conv uses stride 2, halving the resolution.
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(1e-2),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(1e-2)
        )

        # Stage 3: another stride-2 conv, giving 1/4 of the input resolution.
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(1e-2),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(1e-2)
        )

    def forward(self, x):
        f1 = self.conv1(x)
        f2 = self.conv2(f1)
        f3 = self.conv3(f2)
        return f1, f2, f3

class UNet(nn.Module):
    def __init__(self, deep_supervision=False):
        super().__init__()

        self.encoder = Encoder()

        # Placeholders: substitute your own decoder and the SDI module from this repo.
        self.decoder = YourDecoder()

        self.sdi1 = SDI()
        self.sdi2 = SDI()
        self.sdi3 = SDI()

        self.deep_supervision = deep_supervision

    def forward(self, x):
        f1, f2, f3 = self.encoder(x)

        # ... (any intermediate processing)

        # Fuse the multi-scale features. Store the results in new names so every
        # SDI call sees the raw encoder outputs, not already-fused features.
        s1 = self.sdi1(f1, f2, f3)
        s2 = self.sdi2(f1, f2, f3)
        s3 = self.sdi3(f1, f2, f3)

        seg_outs = self.decoder(s1, s2, s3)

        if self.deep_supervision:
            return seg_outs[::-1]
        else:
            return seg_outs[-1]

When using the encoder in practice:

model = Encoder()
f1, f2, f3 = model(torch.rand(2, 1, 256, 256))

You would obtain: f1: (2, 32, 256, 256), f2: (2, 64, 128, 128), f3: (2, 128, 64, 64). As observed, f1 maintains the same resolution as the input image, f2 is downsampled by a factor of 2, and f3 is downsampled by a factor of 4. This strategy ensures the output retains full resolution instead of a 4x downsampled version. The process aligns with the original U-Net.

You may also consider experimenting with the VGG encoder.
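A minimal sketch of such a VGG-based encoder might look like this (assuming torchvision is available; the slice indices follow torchvision's plain vgg16 feature layout and would need adjusting for the batch-norm variant, and note that VGG expects a 3-channel input):

import torch
from torch import nn
from torchvision.models import vgg16

class VGGEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        features = vgg16(weights=None).features
        # Split the VGG16 feature stack at its pooling layers so that
        # f1 stays at full resolution and each later stage halves it.
        self.stage1 = features[:4]    # conv1_x: full resolution, 64 channels
        self.stage2 = features[4:9]   # pool + conv2_x: 1/2 resolution, 128 channels
        self.stage3 = features[9:16]  # pool + conv3_x: 1/4 resolution, 256 channels

    def forward(self, x):
        f1 = self.stage1(x)
        f2 = self.stage2(f1)
        f3 = self.stage3(f2)
        return f1, f2, f3

f1, f2, f3 = VGGEncoder()(torch.rand(2, 3, 256, 256))
# f1: (2, 64, 256, 256), f2: (2, 128, 128, 128), f3: (2, 256, 64, 64)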

I am curious about the results. Let me know if you get any.

lhwcv commented 9 months ago

Thank you for your guidance.

SwimmingLiu commented 9 months ago

> You need to avoid too many downsamplings. […]

Hello, sir. I wonder whether there is a difference between the two downsampling methods, "Conv with stride=2" and "MaxPool". And is "avoid too many downsamplings" equivalent to "making U-Net shallower"?
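For clarity, a minimal sketch of the two ops I am comparing (with an arbitrary 32-channel feature map):

import torch
from torch import nn

x = torch.rand(2, 32, 256, 256)

# Strided convolution: learnable downsampling (the filter weights are trained).
strided_conv = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)

# Max pooling: fixed downsampling (no parameters; keeps the strongest activation).
max_pool = nn.MaxPool2d(kernel_size=2)

print(strided_conv(x).shape)  # torch.Size([2, 32, 128, 128])
print(max_pool(x).shape)      # torch.Size([2, 32, 128, 128])

Both halve the spatial resolution; they differ only in whether the downsampling is learned.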