yaoppeng / U-Net_v2


It is recommended that you write like this to support more backbones #7

Open Jacky-Android opened 11 months ago

Jacky-Android commented 11 months ago

I recommend writing the UNetV2 class as follows, so that it can use more backbones through timm (timm==0.9.12):

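import timm
import torch.nn as nn

# ChannelAttention, SpatialAttention, BasicConv2d and SDI are the helper modules
# from the repository's U-Net v2 code; import them from there.


class UNetV2(nn.Module):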
    """
    U-Net v2 with a timm encoder; uses SpatialAttention + ChannelAttention
    on each encoder stage.
    """
    def __init__(self, channel=32, n_classes=1, deep_supervision=True,
                 backbone='pvt_v2_b2', pretrained=False):
        super().__init__()
        self.deep_supervision = deep_supervision

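        # Any timm model with at least four feature stages that supports
        # features_only can serve as the encoder. Which strides out_indices
        # (0, 1, 2, 3) selects depends on the backbone; check
        # self.encoder.feature_info.reduction().
        self.encoder = timm.create_model(backbone, pretrained=pretrained,
                                         features_only=True, out_indices=(0, 1, 2, 3))

        # Per-stage output channels, used to size the attention and 1x1 conv layers.
        channel1, channel2, channel3, channel4 = self.encoder.feature_info.channels()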

        self.ca_1 = ChannelAttention(channel1)
        self.sa_1 = SpatialAttention()

        self.ca_2 = ChannelAttention(channel2)
        self.sa_2 = SpatialAttention()

        self.ca_3 = ChannelAttention(channel3)
        self.sa_3 = SpatialAttention()

        self.ca_4 = ChannelAttention(channel4)
        self.sa_4 = SpatialAttention()

        self.Translayer_1 = BasicConv2d(channel1, channel, 1)
        self.Translayer_2 = BasicConv2d(channel2, channel, 1)
        self.Translayer_3 = BasicConv2d(channel3, channel, 1)
        self.Translayer_4 = BasicConv2d(channel4, channel, 1)

        self.sdi_1 = SDI(channel)
        self.sdi_2 = SDI(channel)
        self.sdi_3 = SDI(channel)
        self.sdi_4 = SDI(channel)

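        # One independent 1x1 prediction head per decoder level. Writing
        # [nn.Conv2d(...)] * 4 would put the same layer (and weights) in all four slots.
        self.seg_outs = nn.ModuleList([
            nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4)])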

        self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2, padding=1,
                                          bias=False)
        self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
        self.deconv4 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
        self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4, stride=2,
                                          padding=1, bias=False)
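The proposal keeps `out_indices=(0, 1, 2, 3)` for every backbone, but timm numbers feature maps per model, so the same indices can land on different strides. PVT-v2 reports its four stages at strides 4, 8, 16 and 32, while a timm ResNet reports five features at strides 2, 4, 8, 16 and 32, so `(0, 1, 2, 3)` would drop its stride-32 stage. The sketch below picks indices by stride instead, assuming the decoder expects the PVT-v2 strides. `select_out_indices` and `TARGET_REDUCTIONS` are hypothetical names, not part of timm or this repository; the input is a stride list such as the one returned by timm's `feature_info.reduction()`.

```python
from typing import Sequence, Tuple

# Assumption: the decoder is built around the four PVT-v2 stage strides.
TARGET_REDUCTIONS: Tuple[int, ...] = (4, 8, 16, 32)


def select_out_indices(
    reductions: Sequence[int],
    targets: Sequence[int] = TARGET_REDUCTIONS,
) -> Tuple[int, ...]:
    """Map each target stride to the index of the first feature with that stride.

    `reductions` is a per-feature stride list, e.g. from timm's
    `feature_info.reduction()`. Hypothetical helper, not a timm API.
    """
    strides = list(reductions)
    indices = []
    for target in targets:
        if target not in strides:
            raise ValueError(
                f"no feature map at stride {target}; available strides: {strides}"
            )
        indices.append(strides.index(target))
    return tuple(indices)


if __name__ == "__main__":
    # Stride lists of the kind timm reports for a ResNet (five features) and PVT-v2 (four).
    print(select_out_indices([2, 4, 8, 16, 32]))  # (1, 2, 3, 4)
    print(select_out_indices([4, 8, 16, 32]))     # (0, 1, 2, 3)
```

For the ResNet-style list the helper returns `(1, 2, 3, 4)`, which could then be passed as `out_indices` to `timm.create_model`; a backbone with no stride-32 feature raises an error instead of silently feeding the decoder mismatched resolutions.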

yaoppeng commented 11 months ago

Thanks for your valuable recommendation.

I will definitely modify it later and make it more general, especially for 3D volumes.