bigmb / Unet-Segmentation-Pytorch-Nest-of-Unets

Implementation of different kinds of Unet models for image segmentation: Unet, RCNN-Unet, Attention Unet, RCNN-Attention Unet, Nested Unet
MIT License

r2unet #67

Open huasheng76 opened 5 months ago

huasheng76 commented 5 months ago

Hello, my R2UNet's training performance is very poor, even worse than UNet's. Have you run into the same problem, and how can it be resolved?

bigmb commented 5 months ago

Using R2Unet can go wrong because the recurrent blocks do not learn well in some cases. Try running the nested Unet model and compare the results.

huasheng76 commented 5 months ago


Okay, I'm trying to combine nested UNet and attention UNet to see how that performs. Would you be interested in writing code that combines them?
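As a rough illustration of one way the two ideas could be combined, here is a minimal PyTorch sketch, assuming an additive attention gate in the style of Attention U-Net is applied to each same-level skip connection entering a nested-UNet node, with the upsampled deeper node used as the gating signal. The class names `AttentionGate` and `GatedNestedNode`, their channel arguments, and the choice of gating signal are hypothetical; this is not code from this repository.

```python
import torch
import torch.nn as nn


class AttentionGate(nn.Module):
    """Additive attention gate (hypothetical sketch in the style of Attention U-Net)."""

    def __init__(self, g_ch, x_ch, int_ch):
        super().__init__()
        # 1x1 projections of the gating signal g and the skip feature x
        self.w_g = nn.Sequential(nn.Conv2d(g_ch, int_ch, kernel_size=1), nn.BatchNorm2d(int_ch))
        self.w_x = nn.Sequential(nn.Conv2d(x_ch, int_ch, kernel_size=1), nn.BatchNorm2d(int_ch))
        # Collapse to one attention coefficient per pixel, squashed into [0, 1]
        self.psi = nn.Sequential(nn.Conv2d(int_ch, 1, kernel_size=1), nn.BatchNorm2d(1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        # g and x must have the same spatial size (g is upsampled by the caller)
        alpha = self.psi(self.relu(self.w_g(g) + self.w_x(x)))
        return x * alpha  # (N, 1, H, W) broadcasts over the skip channels


class GatedNestedNode(nn.Module):
    """Hypothetical nested-UNet node whose same-level skip inputs are attention-gated."""

    def __init__(self, skip_ch, num_skips, below_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # One gate per same-level skip feature
        self.gates = nn.ModuleList(
            [AttentionGate(below_ch, skip_ch, max(skip_ch // 2, 1)) for _ in range(num_skips)]
        )
        in_ch = skip_ch * num_skips + below_ch
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, skips, below):
        g = self.up(below)  # bring the deeper node up to this level's resolution
        gated = [gate(g, s) for gate, s in zip(self.gates, skips)]
        # Concatenate the gated skips with the upsampled deeper feature, then convolve
        return self.conv(torch.cat(gated + [g], dim=1))


if __name__ == "__main__":
    node = GatedNestedNode(skip_ch=32, num_skips=2, below_ch=64, out_ch=32)
    skips = [torch.randn(2, 32, 16, 16), torch.randn(2, 32, 16, 16)]
    below = torch.randn(2, 64, 8, 8)
    print(node(skips, below).shape)
```

Gating every skip input adds one gate per input to each node; gating only the encoder feature at each level would be cheaper, but which variant trains better would have to be checked experimentally.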

bigmb commented 5 months ago

I won't be adding new features to this code now, but if you add a new model, send me a pull request.

zhibaishouheilab commented 1 month ago

```python
import torch.nn as nn


class Recurrent_block(nn.Module):
    """
    Recurrent Block for R2Unet_CNN
    """
    def __init__(self, out_ch, t=2):
        super(Recurrent_block, self).__init__()

        self.t = t
        self.out_ch = out_ch
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        for i in range(self.t):
            if i == 0:
                x = self.conv(x)
            # x + x is just 2 * x, so no recurrent state is combined here
            out = self.conv(x + x)
        return out
```

There may be an error here. It should probably be `x1 = self.conv(x)` followed by `out = self.conv(x + x1)`.
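To make the difference concrete, the sketch below traces only the control flow of three versions of `forward()`, with `conv` passed in as any callable (in the real module it would be `self.conv`, and `x` a tensor). It uses a toy stand-in `v + 1` so the numbers are easy to follow. `forward_quoted` mirrors the code above, `forward_suggested` applies the fix from the comment, and `forward_recurrent` is a hypothetical variant, assuming the intent is that each of the `t` steps adds the block input to the previous step's output. All three function names are made up for illustration.

```python
def forward_quoted(conv, x, t=2):
    """Control flow of the quoted forward()."""
    for i in range(t):
        if i == 0:
            x = conv(x)
        out = conv(x + x)  # doubles x; the result is the same for every i
    return out


def forward_suggested(conv, x, t=2):
    """The fix suggested in the comment: x1 = conv(x); out = conv(x + x1)."""
    for i in range(t):
        if i == 0:
            x1 = conv(x)
        out = conv(x + x1)  # x1 is never updated, so every iteration repeats this
    return out


def forward_recurrent(conv, x, t=2):
    """Hypothetical variant: feed the previous step's output back in at every step."""
    x1 = conv(x)
    for _ in range(t):
        x1 = conv(x + x1)
    return x1


if __name__ == "__main__":
    step = lambda v: v + 1  # toy stand-in for self.conv
    for t in (1, 2, 3):
        print(t, forward_quoted(step, 1, t), forward_suggested(step, 1, t), forward_recurrent(step, 1, t))
    # prints:
    # 1 5 4 4
    # 2 5 4 6
    # 3 5 4 8
```

The trace shows that in both the quoted code and the suggested fix, `t` has no effect on the output for any `t >= 1`, because the value inside the loop is never carried into the next iteration; only a version that updates its state each step makes the number of recurrent steps matter.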