LeeJunHyun / Image_Segmentation

PyTorch implementation of U-Net, R2U-Net, Attention U-Net, and Attention R2U-Net.

Attention gate after upsampling #74

Closed Glaadiss closed 1 year ago

Glaadiss commented 3 years ago

Hey, in this paper https://arxiv.org/pdf/1804.03999.pdf and in other tutorials, the attention gate takes the gating signal from the same block that is being upsampled in the next step, whereas in the code in this repo the gating signal comes from the block that has already been upsampled. Is this correct and I'm missing something, or is this a mistake?

how it is now:

  # AttU_Net in forward func (as currently implemented in this repo)
  d5 = self.Up5(x5)                  # upsample the deepest feature map first
  x4 = self.Att5(g=d5, x=x4)         # gating signal is the already-upsampled d5
  d5 = torch.cat((x4, d5), dim=1)
  d5 = self.Up_conv5(d5)

how it should be based on the paper:

  # AttU_Net in forward func
  # (Att5 would also need F_g to match x5's channel count and to handle
  #  x5's halved spatial resolution)
  x4 = self.Att5(g=x5, x=x4)         # gating signal from x5, i.e. before upsampling
  d5 = self.Up5(x5)
  d5 = torch.cat((x4, d5), dim=1)
  d5 = self.Up_conv5(d5)

GivralNguyen commented 3 years ago

Yeah, I'm finding this really strange too. Have you figured out the problem? Is this code wrong?

Glaadiss commented 3 years ago

I implemented the architecture the way it's described in the paper, so I'll try to create a PR fixing that in my free time.

Lloyd-Pottiger commented 3 years ago

According to the official code of the paper (https://github.com/ozan-oktay/Attention-Gated-Networks), the gating signal is mainly:

nn.Sequential(
    nn.Conv3d(in_size, out_size, kernel_size, (1, 1, 1), (0, 0, 0)),
    nn.BatchNorm3d(out_size),
    nn.ReLU(inplace=True),
)

The gating signal is then fed into the attention gate, and to match the size of the encoding layer we need to upsample the output of the gating-signal path. So this repo's version simply drops the one conv that is called the gating signal. I think that conv is not really necessary, and removing it can reduce a lot of computation. I will make a comparison when I am free.
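
For reference, here is a minimal 2D sketch of that gating-signal block as it could look in this repo's 2D networks; the class name GatingSignal2D, the attribute self.gating5, and the channel sizes in the usage comments are my own assumptions, not code from either repository:

  # Minimal 2D sketch of the gating-signal block above (assumption, not repo code)
  import torch.nn as nn

  class GatingSignal2D(nn.Module):
      """1x1 conv + BN + ReLU on the coarsest feature map (e.g. x5);
      a 2D analogue of the official 3D gating-signal block."""
      def __init__(self, in_size, out_size):
          super().__init__()
          self.block = nn.Sequential(
              nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0),
              nn.BatchNorm2d(out_size),
              nn.ReLU(inplace=True),
          )

      def forward(self, x):
          return self.block(x)

  # Hypothetical use in AttU_Net.forward (assuming x5 has 1024 channels):
  # g5 = self.gating5(x5)        # gating signal at the coarse resolution
  # x4 = self.Att5(g=g5, x=x4)   # the gate then has to bridge the resolution gap,
  #                              # e.g. with a strided conv on x as in the official code
  # d5 = self.Up5(x5)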

Ankan1998 commented 2 years ago

Yes, I was also going through the Attention U-Net code; it is not consistent with the paper.

[attached image: unetattn]

The gating signal's channel dimension should be double that of the skip connection, and its feature map should be half the size. But here it is implemented in a different way.
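
For clarity, here is a rough sketch of an attention gate laid out the way the official implementation does it, where the gating signal g (e.g. x5) has twice the channels and half the spatial size of the skip connection x (e.g. x4). The class name and the channel numbers in the usage comments are my assumptions, and the output conv from the official code is omitted:

  # Rough sketch (assumption, not this repo's Attention_block) of an attention
  # gate where g has 2x the channels and 1/2 the spatial size of x
  import torch
  import torch.nn as nn
  import torch.nn.functional as F

  class PaperStyleAttentionBlock(nn.Module):
      def __init__(self, F_g, F_l, F_int):
          super().__init__()
          # strided 1x1 conv brings the skip connection down to g's resolution
          self.theta_x = nn.Conv2d(F_l, F_int, kernel_size=1, stride=2, bias=False)
          self.phi_g = nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, bias=True)
          self.psi = nn.Conv2d(F_int, 1, kernel_size=1, stride=1, bias=True)
          self.relu = nn.ReLU(inplace=True)

      def forward(self, g, x):
          theta_x = self.theta_x(x)                      # (N, F_int, H/2, W/2)
          phi_g = F.interpolate(self.phi_g(g), size=theta_x.shape[2:],
                                mode='bilinear', align_corners=False)
          att = torch.sigmoid(self.psi(self.relu(theta_x + phi_g)))
          # upsample the attention coefficients back to the skip connection's size
          att = F.interpolate(att, size=x.shape[2:],
                              mode='bilinear', align_corners=False)
          return x * att

  # e.g. at the deepest stage (assuming x5 has 1024 channels and x4 has 512):
  # self.Att5 = PaperStyleAttentionBlock(F_g=1024, F_l=512, F_int=256)
  # x4 = self.Att5(g=x5, x=x4)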

HulahaVla commented 2 years ago

Yeah, same here.