jakeret / tf_unet

Generic U-Net Tensorflow implementation for image segmentation
GNU General Public License v3.0

How to concatenate Conv2DTranspose and Conv2D #279

Closed sharkdeng closed 5 years ago

sharkdeng commented 5 years ago

Hello, I was building the architecture according to the paper. I got an error when concatenating the Conv2DTranspose tensor (56, 56, 512) with the 4th Conv2D tensor (64, 64, 512). I don't know what's going wrong. Thank you!

Following is my code:

import cv2
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Conv2D, Conv2DTranspose, MaxPooling2D,
                                     BatchNormalization, Activation, Concatenate)

# Build the graph on a concrete image tensor (TF1 graph mode);
# a_images is a list of H x W x 3 images.
net_input = cv2.resize(a_images[0], (572, 572))
net_input = net_input.reshape(1, 572, 572, 3).astype(np.float32)
net_input = tf.Variable(net_input)

## Contracting
# 1 - 64
conv1 = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1))(net_input) # (?, 570, 570, 64)
conv1 = BatchNormalization()(conv1)
conv1 = Activation('relu')(conv1)

conv1 = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1))(conv1) # (?, 568, 568, 64)
conv1 = BatchNormalization()(conv1)
conv1 = Activation('relu')(conv1)

pool1 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv1) # (?, 284, 284, 64)

# 2 - 128
conv2 = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1))(pool1) # (?, 282, 282, 128)
conv2 = BatchNormalization()(conv2)
conv2 = Activation('relu')(conv2)

conv2 = Conv2D(filters=128, kernel_size=(3, 3), strides=(1, 1))(conv2) # (?, 280, 280, 128)
conv2 = BatchNormalization()(conv2)
conv2 = Activation('relu')(conv2)

pool2 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv2) # (?, 140, 140, 128)

# 3 - 256
conv3 = Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1))(pool2) # (?, 138, 138, 256)
conv3 = BatchNormalization()(conv3)
conv3 = Activation('relu')(conv3)

conv3 = Conv2D(filters=256, kernel_size=(3, 3), strides=(1, 1))(conv3) # (?, 136, 136, 256)
conv3 = BatchNormalization()(conv3)
conv3 = Activation('relu')(conv3)

pool3 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv3) # (?, 68, 68, 256)

# 4 - 512
conv4 = Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1))(pool3) # (?, 66, 66, 512)
conv4 = BatchNormalization()(conv4)
conv4 = Activation('relu')(conv4)

conv4 = Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1))(conv4) # (?, 64, 64, 512)
conv4 = BatchNormalization()(conv4)
conv4 = Activation('relu')(conv4)

pool4 = MaxPooling2D(pool_size=(2, 2), strides=(2, 2), padding='same')(conv4) # (?, 32, 32, 512)

# 5 - 1024
conv5 = Conv2D(filters=1024, kernel_size=(3, 3), strides=(1, 1))(pool4) # (?, 30, 30, 1024)
conv5 = BatchNormalization()(conv5)
conv5 = Activation('relu')(conv5)

conv5 = Conv2D(filters=1024, kernel_size=(3, 3), strides=(1, 1))(conv5) # (?, 28, 28, 1024)
conv5 = BatchNormalization()(conv5)
conv5 = Activation('relu')(conv5)

## Expansive
# 1 - 512
dconv1 = Conv2DTranspose(filters=512, kernel_size=(2, 2), strides=(2, 2))(conv5) # (?, 56, 56, 512)
dconv1 = BatchNormalization()(dconv1)
dconv1 = Activation('relu')(dconv1)

# This is the line that fails: conv4 is (?, 64, 64, 512) but dconv1 is (?, 56, 56, 512)
cat1 = Concatenate(axis=3)([conv4, dconv1]) # expected (56, 56, 1024)

conv6 = Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1))(cat1) # (54, 54, 512)
conv6 = BatchNormalization()(conv6)
conv6 = Activation('relu')(conv6)

conv6 = Conv2D(filters=512, kernel_size=(3, 3), strides=(1, 1))(conv6) # (52, 52, 512)
conv6 = BatchNormalization()(conv6)
conv6 = Activation('relu')(conv6)

se = tf.Session()
se.run(tf.global_variables_initializer())
result = se.run(dconv1)
print(result.shape)  # (1, 56, 56, 512)
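
A quick way to see where the mismatch comes from is to do the shape bookkeeping explicitly. The sketch below (illustration only, not part of the model) reproduces the sizes in the inline comments above: every 'valid' 3x3 convolution shrinks the map by 2 pixels and every 2x2 max-pool halves it, so the level-4 skip tensor stays at 64 x 64 while the upsampled bottleneck only reaches 56 x 56.

# Shape bookkeeping for the contracting path (illustration only).
size = 572
skips = []
for _ in range(4):      # four contracting levels
    size -= 4           # two 3x3 'valid' convs, each shrinks the map by 2
    skips.append(size)  # the skip connection is taken before pooling
    size //= 2          # 2x2 max-pool with stride 2
size -= 4               # bottleneck: two more 3x3 convs -> 28
up = size * 2           # Conv2DTranspose with stride 2 -> 56
print(skips[3], up)     # 64 56: the two sizes that fail to concatenate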

sharkdeng commented 5 years ago

Found the solution: use Keras Cropping2D. With 'valid' padding every 3x3 convolution shrinks the feature map, so the skip tensor conv4 (64 x 64) is larger than the upsampled dconv1 (56 x 56). Cropping conv4 by 4 pixels on each border makes the shapes match, which is exactly the cropping of skip connections described in the U-Net paper.
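
For reference, a minimal sketch of that fix applied to the first skip connection above (the crop amounts are computed from the 64 x 64 and 56 x 56 shapes of this particular architecture and would differ for other input sizes):

from tensorflow.keras.layers import Cropping2D

# conv4 is (?, 64, 64, 512) and dconv1 is (?, 56, 56, 512);
# cropping 4 rows/columns from each border of conv4 yields (?, 56, 56, 512).
conv4_cropped = Cropping2D(cropping=((4, 4), (4, 4)))(conv4)
cat1 = Concatenate(axis=3)([conv4_cropped, dconv1])  # (?, 56, 56, 1024)

Each later skip level needs its own crop, e.g. conv3 at 136 x 136 would be cropped to 104 x 104 (16 pixels per border) before concatenating with the second transposed convolution.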