Closed — HenryLyu closed this issue 1 year ago.
Hello,
Thank you for using our library!
The gating signal is composed of the skip-connection features + the upsampled features, not just one of them. Maybe that image is not very clear, as both the dashed line representing the skip connection and the green arrow end in the gate function. You can see better that they use both inputs in Figure 2 of the original paper:
Notice how we also sum both inputs in the code: `g1_x1 = Add()([g1, x1])`.
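To make the data flow concrete, here is a minimal numpy sketch of an additive attention gate of this kind. It is not the library's Keras code: the shapes are illustrative, the 1×1 convolutions are modeled as per-pixel matrix multiplies, and batch normalization is omitted. The point is that both the skip-connection features and the upsampled features enter the gate and are summed before the sigmoid.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy feature maps, (H, W, C) for a single sample. Shapes are illustrative.
H, W, C = 4, 4, 8
skip = rng.standard_normal((H, W, C))  # skip-connection features ("shortcut")
up = rng.standard_normal((H, W, C))    # upsampled decoder features ("x")

# A 1x1 convolution reduces to a per-pixel matrix multiply over channels.
W_g = rng.standard_normal((C, C))
W_x = rng.standard_normal((C, C))
W_psi = rng.standard_normal((C, 1))


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


# Additive attention gate: BOTH inputs are projected and summed.
g1 = skip @ W_g                        # ~ conv(filters, kernel_size=1)(shortcut)
x1 = up @ W_x                          # ~ conv(filters, kernel_size=1)(x)
summed = np.maximum(g1 + x1, 0.0)      # ~ Add() followed by Activation('relu')
psi = sigmoid(summed @ W_psi)          # ~ conv(1, 1) + sigmoid -> coefficients in (0, 1)
gated = up * psi                       # ~ Multiply(): attention applied to x

print(gated.shape)  # (4, 4, 8)
```

Note that `psi` has a single channel and is broadcast across all channels of the gated tensor, matching the `conv(1, kernel_size = 1)` in the snippet.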
Here is another implementation that builds the attention gate the same way we do:
Then the output of the attention module (e.g. the AttentionBlock function in our code) is concatenated with the upsampled features:
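The concatenation step itself is simple; here is a dependency-free numpy sketch (Keras's `Concatenate` would do the same along the channel axis; shapes are illustrative):

```python
import numpy as np

# Illustrative tensors: the attention-gated features and the upsampled features.
attn_out = np.ones((4, 4, 8))
up = np.zeros((4, 4, 8))

# Channels-last concatenation, i.e. Concatenate(axis=-1) in Keras.
merged = np.concatenate([attn_out, up], axis=-1)
print(merged.shape)  # (4, 4, 16)
```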
Best regards,
Dani
Thank you very much for your detailed explanation! I also checked the original code repository, which follows what you clarified.
Attention gate calculation: https://github.com/ozan-oktay/Attention-Gated-Networks/blob/eee4881fdc31920efd873773e0b744df8dacbfb6/models/layers/grid_attention_layer.py#L84
Network forward: https://github.com/ozan-oktay/Attention-Gated-Networks/blob/eee4881fdc31920efd873773e0b744df8dacbfb6/models/networks/unet_CT_single_att_dsv_3D.py#L68
Best Regards, Henry
Dear authors, thank you very much for making the code publicly available.
I have a question regarding the attention block implementation in Attention U-Net. Below is the code snippet from your implementation:
```python
g1 = conv(filters, kernel_size = 1)(shortcut)
g1 = BatchNormalization()(g1) if batch_norm else g1
x1 = conv(filters, kernel_size = 1)(x)
x1 = BatchNormalization()(x1) if batch_norm else x1

g1_x1 = Add()([g1, x1])
psi = Activation('relu')(g1_x1)
psi = conv(1, kernel_size = 1)(psi)
psi = BatchNormalization()(psi) if batch_norm else psi
psi = Activation('sigmoid')(psi)
x = Multiply()([x, psi])
```
In the code above, shortcut features are used as the gating signal, thus the final feature map after concatenation is: upsampled features * attention coefficient + upsampled features
However, in the original paper of Attention U-net, it seems that they used upsampled features as the gating signal. So their feature map after concatenation is: shortcut features * attention coefficient + upsampled features.
I conducted a simple experiment by swapping the two variables passed to the attention block, but it gave comparable (even slightly worse) performance. I am not sure whether I have misunderstood the code.
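The swap described above can be demonstrated with a pure-numpy stand-in for the attention block. This `attention_block` helper is hypothetical (it mirrors the structure of the snippet, with 1×1 convolutions modeled as matrix multiplies and batch normalization omitted); the only thing it is meant to show is that the first argument is the tensor that ends up multiplied by the attention coefficients, so swapping the arguments changes which feature map gets gated:

```python
import numpy as np


def attention_block(x, shortcut, rng):
    """Hypothetical numpy stand-in for the Keras AttentionBlock."""
    C = x.shape[-1]
    W_g = rng.standard_normal((C, C))      # ~ conv(filters, 1) on the gating input
    W_x = rng.standard_normal((C, C))      # ~ conv(filters, 1) on the gated input
    W_psi = rng.standard_normal((C, 1))    # ~ conv(1, 1)
    summed = np.maximum(shortcut @ W_g + x @ W_x, 0.0)       # Add + ReLU
    psi = 1.0 / (1.0 + np.exp(-(summed @ W_psi)))            # sigmoid coefficients
    return x * psi                          # the FIRST argument is gated


rng = np.random.default_rng(0)
up = rng.standard_normal((4, 4, 8))    # upsampled decoder features
skip = rng.standard_normal((4, 4, 8))  # skip-connection features

gated_up = attention_block(up, skip, rng)    # library order: upsampled features gated
gated_skip = attention_block(skip, up, rng)  # swapped order: skip features gated

print(gated_up.shape, gated_skip.shape)  # (4, 4, 8) (4, 4, 8)
```

Both orderings are structurally valid attention gates, which may explain why the experiment produced comparable performance.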
Best Regards,