BiaPyX / BiaPy

Open source Python library for building bioimage analysis pipelines
https://BiaPyX.github.io
MIT License

Question about Attention U-net #26

Closed HenryLyu closed 1 year ago

HenryLyu commented 1 year ago

Dear authors, thank you very much for making the code publicly available.

I have a question regarding the attention block implementation in the Attention U-Net. Below is a code snippet from your implementation:

g1 = conv(filters, kernel_size=1)(shortcut)
g1 = BatchNormalization()(g1) if batch_norm else g1
x1 = conv(filters, kernel_size=1)(x)
x1 = BatchNormalization()(x1) if batch_norm else x1

g1_x1 = Add()([g1, x1])
psi = Activation('relu')(g1_x1)
psi = conv(1, kernel_size=1)(psi)
psi = BatchNormalization()(psi) if batch_norm else psi
psi = Activation('sigmoid')(psi)
x = Multiply()([x, psi])

In the code above, the shortcut features are used as the gating signal, so the final feature map after concatenation is: (upsampled features * attention coefficient) concatenated with the upsampled features.

However, in the original Attention U-Net paper, it seems that they used the upsampled features as the gating signal, so their feature map after concatenation is: (shortcut features * attention coefficient) concatenated with the upsampled features.

[Image: architecture figure from the Attention U-Net paper, showing the skip connection and gating signal entering the attention gate]

I conducted a simple experiment by swapping the two variables passed to the attention block, but it gave comparable (or even slightly worse) performance. I am not sure whether I have misunderstood the code.
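For reference, this is roughly the comparison I ran (a minimal, self-contained 2D/Keras sketch; the function name, argument order, and shapes are illustrative assumptions, not the exact BiaPy code):

```python
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     Activation, Add, Multiply)

# Standalone version of the quoted snippet, so the argument swap can be tried.
def attention_block(x, shortcut, filters, batch_norm=True):
    g1 = Conv2D(filters, kernel_size=1)(shortcut)   # project gating signal
    g1 = BatchNormalization()(g1) if batch_norm else g1
    x1 = Conv2D(filters, kernel_size=1)(x)          # project features to be gated
    x1 = BatchNormalization()(x1) if batch_norm else x1

    psi = Activation('relu')(Add()([g1, x1]))       # additive attention
    psi = Conv2D(1, kernel_size=1)(psi)
    psi = BatchNormalization()(psi) if batch_norm else psi
    psi = Activation('sigmoid')(psi)                # per-pixel coefficient
    return Multiply()([x, psi])                     # scale x by the coefficient

# Dummy decoder-level tensors of matching spatial size.
up = Input(shape=(64, 64, 32))    # upsampled features
skip = Input(shape=(64, 64, 32))  # skip-connection (shortcut) features

out_original = attention_block(x=up, shortcut=skip, filters=16)  # my reading of the code
out_swapped = attention_block(x=skip, shortcut=up, filters=16)   # the swap I tried
```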

Best Regards,

danifranco commented 1 year ago

Hello,

Thank you for using our library!

The gating signal is composed of the skip-connection features plus the upsampled features, not just one of them. Maybe that image is not very clear, since the dashed line representing the skip connection and the green arrow both end at the gate function. You can see more clearly that both inputs are used in Figure 2 of the original paper:

[Image: Figure 2 from the Attention U-Net paper, the attention gate schematic]

Notice how we also sum both inputs in our code: g1_x1 = Add()([g1, x1]).
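As a quick, purely illustrative check (not the library code; shapes and layer sizes are arbitrary), you can verify in eager mode that the attention coefficient is computed from the sum of both projected inputs:

```python
import tensorflow as tf
from tensorflow.keras.layers import Activation, Conv2D

# g and x stand for the already-projected gating (skip) and upsampled features.
tf.random.set_seed(0)
g = tf.random.normal((1, 8, 8, 4))
x = tf.random.normal((1, 8, 8, 4))

# Same structure as the snippet: ReLU over the sum, 1x1 conv, sigmoid.
to_psi = tf.keras.Sequential([
    Activation('relu'),
    Conv2D(1, kernel_size=1),
    Activation('sigmoid'),
])

psi = to_psi(g + x)   # additive attention: both inputs contribute
print(psi.shape)      # (1, 8, 8, 1): one coefficient per spatial position
```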

Here is another implementation that builds the attention gate the same way we do:

https://github.com/yingkaisha/keras-unet-collection/blob/d30f14a259656d2f26ea11ed978255d6a7d0ce37/keras_unet_collection/_model_att_unet_2d.py#L52

Then, the output of the attention module, i.e. the AttentionBlock function in our code, is concatenated with the upsampled features:

https://github.com/danifranco/BiaPy/blob/6265aaf5dfdcf0e1d9d413288c3db8da2f63b78f/models/attention_unet.py#L148
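Roughly, one decoder step then looks like this (a minimal sketch with illustrative names and the attention gate written inline, following the argument order quoted above; it is not the exact code from that line):

```python
from tensorflow.keras.layers import (Input, Conv2D, Conv2DTranspose, Add,
                                     Activation, Multiply, Concatenate)

# Illustrative decoder step: upsample, gate, then concatenate the gated
# output with the upsampled features.
skip = Input(shape=(64, 64, 32))   # encoder (skip-connection) features
deep = Input(shape=(32, 32, 64))   # features from the level below

up = Conv2DTranspose(32, kernel_size=2, strides=2)(deep)   # upsampled features

# Inline attention gate (same structure as the quoted snippet).
g1 = Conv2D(32, kernel_size=1)(skip)
x1 = Conv2D(32, kernel_size=1)(up)
psi = Activation('relu')(Add()([g1, x1]))
psi = Conv2D(1, kernel_size=1)(psi)
psi = Activation('sigmoid')(psi)
attn = Multiply()([up, psi])       # output of the attention module

out = Concatenate()([attn, up])    # concatenated with the upsampled features
```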

Best regards,

Dani

HenryLyu commented 1 year ago

Thank you very much for your detailed explanation! I also checked the original code repository, and it follows what you clarified.

Attention gate calculation: https://github.com/ozan-oktay/Attention-Gated-Networks/blob/eee4881fdc31920efd873773e0b744df8dacbfb6/models/layers/grid_attention_layer.py#L84
Network forward pass: https://github.com/ozan-oktay/Attention-Gated-Networks/blob/eee4881fdc31920efd873773e0b744df8dacbfb6/models/networks/unet_CT_single_att_dsv_3D.py#L68

Best Regards,

Henry