[Open] Jacky-Android opened this issue 11 months ago
I recommend writing the UNetV2 class as follows so it supports more backbones (tested with timm==0.9.12):
import timm


class UNetV2(nn.Module):
    """U-Net v2 variant using SpatialAttention + ChannelAttention over a
    generic timm encoder.

    The encoder is created with ``features_only=True`` so any timm backbone
    exposing a 4-stage feature pyramid can be used; per-stage channel counts
    are read from ``encoder.feature_info`` instead of being hard-coded.

    Args:
        channel: common channel width all stages are projected to (1x1 convs).
        n_classes: number of output segmentation channels.
        deep_supervision: stored flag; presumably controls multi-scale outputs
            in ``forward`` — not visible in this snippet, confirm against it.
        backbone: timm model name (must support ``features_only``).
        pretrained: whether to load timm pretrained encoder weights.
    """

    def __init__(self, channel=32, n_classes=1, deep_supervision=True,
                 backbone='pvt_v2_b2', pretrained=False):
        super().__init__()
        self.deep_supervision = deep_supervision

        # Backbone-agnostic feature extractor: out_indices picks the
        # four pyramid stages (strides 4/8/16/32 for typical backbones).
        self.encoder = timm.create_model(
            backbone, pretrained=pretrained,
            features_only=True, out_indices=(0, 1, 2, 3))
        # Per-stage channel counts come from timm, not hard-coded numbers.
        channel1, channel2, channel3, channel4 = self.encoder.feature_info.channels()

        # One channel-attention / spatial-attention pair per encoder stage.
        self.ca_1 = ChannelAttention(channel1)
        self.sa_1 = SpatialAttention()
        self.ca_2 = ChannelAttention(channel2)
        self.sa_2 = SpatialAttention()
        self.ca_3 = ChannelAttention(channel3)
        self.sa_3 = SpatialAttention()
        self.ca_4 = ChannelAttention(channel4)
        self.sa_4 = SpatialAttention()

        # 1x1 translation layers: project each stage to the common width.
        self.Translayer_1 = BasicConv2d(channel1, channel, 1)
        self.Translayer_2 = BasicConv2d(channel2, channel, 1)
        self.Translayer_3 = BasicConv2d(channel3, channel, 1)
        self.Translayer_4 = BasicConv2d(channel4, channel, 1)

        # Semantic-and-detail fusion module per stage.
        self.sdi_1 = SDI(channel)
        self.sdi_2 = SDI(channel)
        self.sdi_3 = SDI(channel)
        self.sdi_4 = SDI(channel)

        # BUG FIX: the original `[nn.Conv2d(...)] * 4` registered the SAME
        # Conv2d object four times, so every deep-supervision head shared one
        # set of weights. Build four independent heads instead.
        self.seg_outs = nn.ModuleList(
            nn.Conv2d(channel, n_classes, 1, 1) for _ in range(4))

        # 2x upsampling deconvolutions used by the decoder path.
        self.deconv2 = nn.ConvTranspose2d(channel, channel, kernel_size=4,
                                          stride=2, padding=1, bias=False)
        self.deconv3 = nn.ConvTranspose2d(channel, channel, kernel_size=4,
                                          stride=2, padding=1, bias=False)
        self.deconv4 = nn.ConvTranspose2d(channel, channel, kernel_size=4,
                                          stride=2, padding=1, bias=False)
        self.deconv5 = nn.ConvTranspose2d(channel, channel, kernel_size=4,
                                          stride=2, padding=1, bias=False)
Thanks for your valuable recommendation.
I will definitely modify it later and make it more general, especially for 3D volumes.
I recommend writing the UNetV2 class as follows so it supports more backbones (tested with timm==0.9.12):