pai-plznw4me / stamp_sign_segmentor

stamp sign 객체를 segmentation 하는 모델 입니다.
0 stars 0 forks source link

unet 에서 fmap 의 크기가 다른 이슈 #1

Open pai-plznw4me opened 3 years ago

pai-plznw4me commented 3 years ago

input size 을 50x50 으로 할 시에 encode : (50x50) → (25x25) → (13x13) → (7x7) → (4x4) decode : (4 x 4) → (8 x 8)

concatenate 될 7x7 와 decode 8x8 이 크기가 다르기 때문에 아래와 같은 ValueError 발생

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concat axis. Got inputs shapes: [(None, 7, 7, 64), (None, 8, 8, 64)]
pai-plznw4me commented 3 years ago

dynamic shape 을 활용해 padding 을 통해 concatenate 할 layer간의 크기를 일정하게 맞춤

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate
import tensorflow as tf

# @tf.function
def dynamic_padding(inputs):
    encode = Conv2D(filters=64, padding='same', strides=2, kernel_size=3)(inputs)
    decode = Conv2DTranspose(filters=64, kernel_size=3, padding='same', strides=2)(encode)

    def _pad_h(x):
        return tf.pad(x, [[0, 0], [0, 1], [0, 0], [0, 0]])

    def _pad_w(x):
        return tf.pad(x, [[0, 0], [0, 0], [0, 1], [0, 0]])

    input_h = tf.shape(inputs)[1]
    input_w = tf.shape(inputs)[2]

    decode_h = tf.shape(decode)[1]
    decode_w = tf.shape(decode)[2]

    inputs = tf.cond(tf.equal(input_h, decode_h), lambda: inputs, false_fn=lambda: _pad_h(inputs))
    inputs = tf.cond(tf.equal(input_w, decode_w), lambda: inputs, false_fn=lambda: _pad_w(inputs))

    concat_layer = tf.concat([inputs, decode], axis=-1)

    return concat_layer

if __name__ == '__main__':
    input_shape = (1, 7, 7, 3)
    inputs = tf.zeros(shape=input_shape)
    dynamic_padding(inputs)

하지만 위 코드를 graph 형태로 변형시 에러가 발생

에러는 아래와 같음, 에러가 발생하는 원인은 keras.layers.Conv2D 에서 발생함. keras.layers.Conv2D에서 내부적으로 tf.Variable 을 사용하는데 tf.function 안에서 tf.Variable 을 호출 할 수 없기 때문이다.

    ValueError: tf.function-decorated function tried to create variables on non-first call.

아래와 같이 수정

from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, Concatenate
import tensorflow as tf

@tf.function
def dynamic_padding(inputs, **kwargs):
    encode = kwargs['conv_fn'](inputs)
    decode = kwargs['conv_tp_fn'](encode)

    def _pad_h(x):
        return tf.pad(x, [[0, 0], [0, 1], [0, 0], [0, 0]])

    def _pad_w(x):
        return tf.pad(x, [[0, 0], [0, 0], [0, 1], [0, 0]])

    input_h = tf.shape(inputs)[1]
    input_w = tf.shape(inputs)[2]

    decode_h = tf.shape(decode)[1]
    decode_w = tf.shape(decode)[2]

    inputs = tf.cond(tf.equal(input_h, decode_h), lambda: inputs, false_fn=lambda: _pad_h(inputs))
    inputs = tf.cond(tf.equal(input_w, decode_w), lambda: inputs, false_fn=lambda: _pad_w(inputs))

    concat_layer = tf.concat([inputs, decode], axis=-1)

    return concat_layer

if __name__ == '__main__':
    input_shape = (1, 7, 7, 3)
    conv_fn = Conv2D(filters=64, padding='same', strides=2, kernel_size=3)
    conv_tp_fn = Conv2DTranspose(filters=64, kernel_size=3, padding='same', strides=2)
    inputs = tf.zeros(shape=input_shape)
    padded_layer = dynamic_padding(inputs, conv_fn=conv_fn, conv_tp_fn=conv_tp_fn)
    print(padded_layer)