keras-team / keras


Graph disconnected error when using skip connections in an autoencoder #13160

Closed Eichhof closed 3 years ago

Eichhof commented 5 years ago

Hello

I have implemented a simple variational autoencoder in Keras with 2 convolutional layers in the encoder and decoder. The code is shown below. Now I have extended my implementation with two skip connections (similar to U-Net). The skip connections are named merge1 and merge2 in the code below. Without the skip connections everything works fine, but with the skip connections I get the following error message:

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("encoder_input:0", shape=(?, 64, 80, 1), dtype=float32) at layer "encoder_input". The following previous layers were accessed without issue: []

Is there a problem in my code?

    import keras
    from keras import backend as K
    from keras.layers import (Dense, Input, Flatten)
    from keras.layers import Conv2D, Lambda, MaxPooling2D, UpSampling2D, concatenate
    from keras.models import Model
    from keras.layers import Reshape
    from keras.losses import mse

    def sampling(args):
        z_mean, z_log_var = args
        batch = K.shape(z_mean)[0]
        dim = K.int_shape(z_mean)[1]
        epsilon = K.random_normal(shape=(batch, dim))
        return z_mean + K.exp(0.5 * z_log_var) * epsilon

    image_size = (64,80,1)
    inputs = Input(shape=image_size, name='encoder_input')

    conv1 = Conv2D(64, 3, activation='relu', padding='same')(inputs)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = Conv2D(128, 3, activation='relu', padding='same')(pool1)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    shape = K.int_shape(pool2)

    x = Flatten()(pool2)
    x = Dense(16, activation='relu')(x)
    z_mean = Dense(6, name='z_mean')(x)
    z_log_var = Dense(6, name='z_log_var')(x)

    z = Lambda(sampling, output_shape=(6,), name='z')([z_mean, z_log_var])
    encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

    latent_inputs = Input(shape=(6,), name='z_sampling')
    x = Dense(16, activation='relu')(latent_inputs)
    x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(x)
    x = Reshape((shape[1], shape[2], shape[3]))(x)

    up1 = UpSampling2D((2, 2))(x)
    up1 = Conv2D(128, 2, activation='relu', padding='same')(up1)
    merge1 = concatenate([conv2, up1], axis=3)

    up2 = UpSampling2D((2, 2))(merge1)
    up2 = Conv2D(64, 2, activation='relu', padding='same')(up2)
    merge2 = concatenate([conv1, up2], axis=3)

    out = Conv2D(1, 1, activation='sigmoid')(merge2)

    decoder = Model(latent_inputs, out, name='decoder')

    outputs = decoder(encoder(inputs)[2])
    vae = Model(inputs, outputs, name='vae')

    def vae_loss(x, x_decoded_mean):
        reconstruction_loss = mse(K.flatten(x), K.flatten(x_decoded_mean))
        reconstruction_loss *= image_size[0] * image_size[1]
        kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
        kl_loss = K.sum(kl_loss, axis=-1)
        kl_loss *= -0.5
        vae_loss = K.mean(reconstruction_loss + kl_loss)
        return vae_loss

    optimizer = keras.optimizers.Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08, decay=0.000)
    vae.compile(loss=vae_loss, optimizer=optimizer)
    vae.fit(train_X, train_X,
            epochs=500,
            batch_size=128,
            verbose=1,
            shuffle=True,
            validation_data=(valid_X, valid_X))
gowthamkpr commented 5 years ago

@Eichhof Please refer to the following issue, especially the solution provided by brstiner. This should help you in solving the issue. Thanks!
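
The error arises because decoder = Model(latent_inputs, out) closes over conv1 and conv2, which trace back to encoder_input rather than to latent_inputs, so the decoder graph references a tensor that is not one of its own inputs. Below is a minimal sketch of one way to restructure it (it continues from the imports and encoder layers defined in the question and is not the exact solution from the linked issue): expose the skip feature maps as extra encoder outputs and feed them into the decoder through explicit Input layers. The skip1/skip2 input names are illustrative, and their shapes follow from the conv1 and conv2 shapes in the question.

    # Sketch, assuming the layers up to conv1/conv2/z from the question are already defined.
    # The encoder additionally returns the two feature maps used by the skip connections.
    encoder = Model(inputs, [z_mean, z_log_var, z, conv1, conv2], name='encoder')

    # The decoder gets explicit inputs for the skip tensors, so every layer in it
    # traces back to one of its own Input layers.
    latent_inputs = Input(shape=(6,), name='z_sampling')
    skip1_in = Input(shape=(64, 80, 64), name='skip1')   # same shape as conv1
    skip2_in = Input(shape=(32, 40, 128), name='skip2')  # same shape as conv2

    x = Dense(16, activation='relu')(latent_inputs)
    x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(x)
    x = Reshape((shape[1], shape[2], shape[3]))(x)

    up1 = UpSampling2D((2, 2))(x)
    up1 = Conv2D(128, 2, activation='relu', padding='same')(up1)
    merge1 = concatenate([skip2_in, up1], axis=3)

    up2 = UpSampling2D((2, 2))(merge1)
    up2 = Conv2D(64, 2, activation='relu', padding='same')(up2)
    merge2 = concatenate([skip1_in, up2], axis=3)

    out = Conv2D(1, 1, activation='sigmoid')(merge2)
    decoder = Model([latent_inputs, skip1_in, skip2_in], out, name='decoder')

    # Wire the two models together; the skip tensors now flow through model inputs.
    z_mean_out, z_log_var_out, z_out, c1, c2 = encoder(inputs)
    outputs = decoder([z_out, c1, c2])
    vae = Model(inputs, outputs, name='vae')

Note that routing encoder feature maps straight into the decoder bypasses the latent bottleneck, so the reconstruction no longer depends only on z; whether that is acceptable depends on what you want the latent code to capture.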