devLupin opened 1 year ago
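import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Conv2D


class SpatialGate(tf.keras.layers.Layer):
    """CBAM spatial attention gate; expects channels_last (NHWC) input."""

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        # Maps the 2-channel [avg; max] descriptor to 1-channel attention logits
        self.spatial = Conv2D(filters=1, kernel_size=kernel_size, strides=1,
                              padding='same', use_bias=False)
        # Keras momentum weights the old moving average, so PyTorch
        # BatchNorm2d(momentum=0.01) corresponds to Keras momentum=0.99
        self.bn = BatchNormalization(momentum=0.99, epsilon=1e-5)

    def call(self, x, training=None):
        # Pool along the channel axis: (B, H, W, C) -> (B, H, W, 1) each
        avg_pool = tf.reduce_mean(x, axis=3, keepdims=True)
        max_pool = tf.reduce_max(x, axis=3, keepdims=True)
        x_compress = tf.concat([avg_pool, max_pool], axis=3)  # (B, H, W, 2)
        x_out = self.spatial(x_compress)
        x_out = self.bn(x_out, training=training)
        # No ReLU before the sigmoid: sigmoid(relu(z)) never drops below 0.5,
        # which contradicts M_s(F) = sigmoid(f7x7([AvgPool(F); MaxPool(F)]))
        scale = tf.math.sigmoid(x_out)  # (B, H, W, 1)
        return x * scale  # broadcast over the channel axis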
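If a `momentum=0.01` setting is carried over from a PyTorch `BatchNorm2d`, keep in mind that the two frameworks define momentum in opposite directions: PyTorch weights the *new* batch statistic by `momentum`, while Keras weights the *old* moving average by it. The pure-Python sketch below (the helper names `torch_style_update`, `keras_style_update`, and `track` are hypothetical, and the batch means are made-up numbers) applies both update rules to the same sequence of batch means, showing that PyTorch `momentum=0.01` tracks the same value as Keras `momentum=0.99`.

```python
def torch_style_update(running, batch_stat, momentum):
    # PyTorch BatchNorm: running = (1 - momentum) * running + momentum * batch
    return (1.0 - momentum) * running + momentum * batch_stat


def keras_style_update(moving, batch_stat, momentum):
    # Keras BatchNormalization: moving = momentum * moving + (1 - momentum) * batch
    return momentum * moving + (1.0 - momentum) * batch_stat


def track(update, momentum, batch_means, start=0.0):
    """Apply a moving-average rule over a sequence of per-batch means."""
    value = start
    for m in batch_means:
        value = update(value, m, momentum)
    return value


batch_means = [1.0, 2.0, 3.0, 4.0, 5.0]  # illustrative values only
torch_001 = track(torch_style_update, 0.01, batch_means)
keras_099 = track(keras_style_update, 0.99, batch_means)
keras_001 = track(keras_style_update, 0.01, batch_means)

print(f"PyTorch momentum=0.01: {torch_001:.6f}")
print(f"Keras   momentum=0.99: {keras_099:.6f}")
print(f"Keras   momentum=0.01: {keras_001:.6f}")
```

With Keras `momentum=0.01`, the moving mean essentially copies the most recent batch, so the statistics used at inference time become as noisy as a single batch.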
CBAM: Convolutional Block Attention Module
Overview
Generates a spatial attention map.
Unlike channel attention, it focuses on where the informative parts of the feature map are.
Process
Average pooling ($F_{avg}$) and max pooling ($F_{max}$) are applied along the channel axis, and the spatial attention map is computed as:
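$\textbf{M}_{s}\left(F\right) = \sigma\left(f^{7\times 7}\left(\left[\text{AvgPool}\left(F\right);\text{MaxPool}\left(F\right)\right]\right)\right)$
where $\sigma$ is the sigmoid function and $f^{7\times 7}$ is a convolution with a 7×7 kernel.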
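To make the formula concrete, here is a minimal pure-Python sketch of $\textbf{M}_{s}(F)$ for a single feature map stored as nested lists indexed `[h][w][c]` (channels last, as in the TensorFlow code). The function names `spatial_attention` and `apply_gate`, the random kernel values, and the zero "same" padding are assumptions for illustration; in the real module the kernel is learned, and the Keras layer also applies batch normalization between the convolution and the sigmoid. Like deep-learning frameworks, the convolution is computed as a cross-correlation.

```python
import math
import random


def spatial_attention(feature, kernel):
    """M_s(F) = sigmoid(conv_kxk([AvgPool(F); MaxPool(F)])) for one H x W x C map.

    feature: nested list indexed [h][w][c]
    kernel:  nested list indexed [i][j][ch], shape k x k x 2
             (ch 0 sees the avg-pooled map, ch 1 the max-pooled map)
    """
    H, W = len(feature), len(feature[0])
    k = len(kernel)
    pad = k // 2  # zero "same" padding keeps the output at H x W
    # Pool along the channel axis: two H x W maps
    avg = [[sum(px) / len(px) for px in row] for row in feature]
    mx = [[max(px) for px in row] for row in feature]
    pooled = (avg, mx)

    attn = [[0.0] * W for _ in range(H)]
    for y in range(H):
        for x in range(W):
            z = 0.0
            for i in range(k):
                for j in range(k):
                    yy, xx = y + i - pad, x + j - pad
                    if 0 <= yy < H and 0 <= xx < W:  # outside = zero padding
                        for ch in range(2):
                            z += kernel[i][j][ch] * pooled[ch][yy][xx]
            attn[y][x] = 1.0 / (1.0 + math.exp(-z))  # sigmoid
    return attn


def apply_gate(feature, attn):
    """Scale every channel at (h, w) by the same weight attn[h][w]."""
    return [[[v * attn[y][x] for v in px] for x, px in enumerate(row)]
            for y, row in enumerate(feature)]


random.seed(0)
H, W, C, K = 5, 6, 3, 7
F = [[[random.uniform(-1, 1) for _ in range(C)] for _ in range(W)] for _ in range(H)]
kernel = [[[random.uniform(-0.3, 0.3) for _ in range(2)] for _ in range(K)]
          for _ in range(K)]

M = spatial_attention(F, kernel)
out = apply_gate(F, M)
print(len(M), len(M[0]))  # 5 6: one weight per spatial location
```

Because $\sigma$ is applied directly to the convolution output, each weight lies in $(0, 1)$, so the gate can both keep and suppress locations; all channels at one location share the same weight.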
Figure
Official code (PyTorch)