YimianDai / open-aff

code and trained models for "Attentional Feature Fusion"

__init__() takes from 1 to 3 positional arguments but 6 were given #32

Open Thar-un opened 2 years ago

Thar-un commented 2 years ago

Hi, I was trying to modify the class AFF() code to support a new version of Keras, but I am struggling with this error. The modified AFF class:

```python
class AFF(tf.keras.layers.Layer):
    '''
    Multi-feature fusion AFF
    '''

    def __init__(self, channels=64, r=4):
        super().__init__()
        inter_channels = int(channels // r)

        self.local_att = tf.keras.Sequential(
            Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same'),
            tf.keras.layers.BatchNormalization(inter_channels),
            tf.keras.layers.ReLU(),
            Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same'),
            tf.keras.layers.BatchNormalization(channels),
        )

        self.global_att = tf.keras.Sequential(
            tf.keras.layers.AveragePooling2D(1),
            Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same'),
            tf.keras.layers.BatchNormalization(inter_channels),
            tf.keras.layers.ReLU(),
            Conv2D(filters=64, kernel_size=(3,3), strides=1, padding='same'),
            tf.keras.layers.BatchNormalization(channels),
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        xa = x + residual
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)

        xo = 2 * x * wei + 2 * residual * (1 - wei)
        return xo
```

The error image
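The screenshot is not reproduced here, but the message in the issue title is what Python raises when a constructor that accepts at most two arguments besides `self` receives five. The sketch below reproduces that message without TensorFlow. `FakeSequential` is a hypothetical stand-in, and its `(layers=None, name=None)` signature is an assumption about the `tf.keras.Sequential` version in use, chosen because it matches the "from 1 to 3" wording of the error.

```python
class FakeSequential:
    """Hypothetical stand-in with an assumed Sequential-like signature."""

    def __init__(self, layers=None, name=None):
        self.layers = list(layers or [])
        self.name = name


# Five placeholder "layers", the same count as in local_att above.
five_layers = ["conv", "bn", "relu", "conv", "bn"]

message = ""
try:
    # Unpacked: five positional arguments plus self makes six.
    FakeSequential(*five_layers)
except TypeError as exc:
    message = str(exc)
print(message)

# Passing ONE list uses only the `layers` parameter and succeeds.
ok = FakeSequential(five_layers)
print(len(ok.layers))
```

If this is the cause, wrapping the layers of each `tf.keras.Sequential(...)` call in square brackets should make this particular error go away.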

The create model function

```python
def create_model():
    tf.keras.backend.clear_session()

    input = Input(shape=(256,256,3), name="input_layer")
    print("Input =",input.shape)

    conv_block = Convolutional_block()(input)
    print("Conv block =",conv_block.shape)
    ca_block = Channel_attention()(conv_block)
    sa_block = SpatialGate()(conv_block)
    # AFF block instead of concatenate
    ca_block = AFF()(ca_block)

    model = Model(inputs=[input], outputs=[ca_block])
    return model

model = create_model()
model.summary()
```

The input is an image of size 256×256×3.
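For reference, here is a minimal sketch of how the AFF class might look as a Keras layer. It is not the repository's code, and several choices are assumptions: channels-last input, 1x1 convolutions, and `GlobalAveragePooling2D(keepdims=True)` as the global branch (the `AveragePooling2D(1)` above leaves the tensor unchanged). It passes the layers to `tf.keras.Sequential` as a list, leaves `BatchNormalization` without a positional argument (its first parameter is `axis`, not a channel count), overrides `call` rather than `forward`, and replaces the PyTorch `nn.Sigmoid` with `tf.sigmoid`.

```python
import tensorflow as tf
from tensorflow.keras import layers


class AFF(layers.Layer):
    """Sketch of attentional feature fusion for channels-last tensors."""

    def __init__(self, channels=64, r=4, **kwargs):
        super().__init__(**kwargs)
        inter_channels = channels // r

        # Sequential receives ONE list of layers, not one argument per layer.
        self.local_att = tf.keras.Sequential([
            layers.Conv2D(inter_channels, kernel_size=1, padding="same"),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Conv2D(channels, kernel_size=1, padding="same"),
            layers.BatchNormalization(),
        ])

        # Assumption: the global branch pools each channel down to 1x1.
        self.global_att = tf.keras.Sequential([
            layers.GlobalAveragePooling2D(keepdims=True),
            layers.Conv2D(inter_channels, kernel_size=1, padding="same"),
            layers.BatchNormalization(),
            layers.ReLU(),
            layers.Conv2D(channels, kernel_size=1, padding="same"),
            layers.BatchNormalization(),
        ])

    def call(self, x, residual):
        xa = x + residual
        # The 1x1 global map broadcasts over height and width.
        wei = tf.sigmoid(self.local_att(xa) + self.global_att(xa))
        return 2.0 * x * wei + 2.0 * residual * (1.0 - wei)


# Usage: the layer fuses TWO tensors of identical shape.
x = tf.random.normal((2, 16, 16, 64))
residual = tf.random.normal((2, 16, 16, 64))
fused = AFF(channels=64, r=4)(x, residual)
print(fused.shape)
```

Because `call` takes both `x` and `residual`, the line in `create_model` would presumably need to pass two tensors, for example `AFF()(ca_block, sa_block)`, provided both blocks output the same shape with 64 channels.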