avolkov1 / keras_experiments

Experimental Keras libraries and examples.

Incompatible shapes for model with multiple inputs and custom loss. #8

Closed: experiencor closed this issue 7 years ago

experiencor commented 7 years ago

Hi @avolkov1

I encountered the following error when training a model with multiple inputs and a custom loss:

InvalidArgumentError: Incompatible shapes: [8] vs. [16] [[Node: tower_1/model_4/lambda_3/add_6 = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:1"](tower_1/model_4/lambda_3/add_5, tower_1/model_4/lambda_3/mul_12)]] [[Node: loss/mul/_1069 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_9575_loss/mul", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

The batch size is 16, so the sliced batch size is 8 on each of the 2 GPUs. The model code is:

true_image = Input(shape=(FRAME_H, FRAME_W, 3))  # adapt this if using channels_first image data format
jitt_image = Input(shape=(FRAME_H, FRAME_W, 3)) 

# encoder part
x = make_vgg16('encoder', trainable=True, return_all_layers=False)(true_image)

x = Flatten()(x)

z_mean = Dense(latent_dim)(x)
z_logv = Dense(latent_dim)(x)

# variational part
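# reparameterization trick: z = z_mean + exp(z_logv) * eps with
# eps ~ N(0, epsilon_std), so gradients can flow through z_mean and z_logv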
def sampling(args):
    z_mean, z_logv = args
    epsilon = K.random_normal(shape=(K.shape(z_mean)[0], latent_dim), mean=0., stddev=epsilon_std)

    return z_mean + K.exp(z_logv) * epsilon

z = Lambda(sampling)([z_mean, z_logv])

# decoder part
y = Dense(14 * 14 * 256, activation='relu')(z)
y = Reshape((14, 14, 256))(y)

y = Deconv2D(filters=256, kernel_size=(5, 5), strides=(2,2), padding='same')(y)
y = BatchNormalization()(y)
y = LeakyReLU(alpha=0.1)(y)

y = Deconv2D(filters=128, kernel_size=(5, 5), strides=(2,2), padding='same')(y)
y = BatchNormalization()(y)
y = LeakyReLU(alpha=0.1)(y)

y = Deconv2D(filters= 64, kernel_size=(5, 5), strides=(2,2), padding='same')(y)
y = BatchNormalization()(y)
y = LeakyReLU(alpha=0.1)(y)

y = Deconv2D(filters=  3, kernel_size=(5, 5), strides=(2,2), padding='same')(y)
y = BatchNormalization()(y)
fake_image = Activation('tanh')(y)

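# per-channel ImageNet mean pixel values (BGR order) used for VGG preprocessing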
mean = tf.concat([tf.ones((FRAME_H,FRAME_W,1)) * 103.939, 
                  tf.ones((FRAME_H,FRAME_W,1)) * 116.779,
                  tf.ones((FRAME_H,FRAME_W,1)) * 123.68 ], axis=2)

def model_loss_no_vgg_head(args):
    true_image, fake_image  = args

    # perception loss
    perception_loss = 0.5 * K.mean(K.square(true_image - fake_image), axis=[-3, -2, -1])

    # Kullback-Leibler (KL) divergence
    kl_loss = -0.5 * K.mean(1 + z_logv - K.square(z_mean) - K.exp(z_logv), axis=-1)

    # print message for debugging
    #kl_loss = tf.Print(kl_loss, [tf.shape(lv1_loss)[0]], message='DEBUG', summarize=100) 

    return perception_loss * per_w + kl_loss  * kbd_w

total_loss_no_vgg_head = Lambda(model_loss_no_vgg_head, output_shape=(1,))([true_image, fake_image])

def model_loss_vgg_head(args):
    jitt, fake  = args

    # join jitt and fake image batches
    jitt_fake  = tf.concat([jitt, fake], axis=0)
    batch_size = tf.shape(jitt_fake)[0] / 2

    # preprocessing before vgg16
    jitt_fake = (jitt_fake + 1.) / 2. * 255.
    jitt_fake = jitt_fake - mean

    # extract features using vgg16
    jitt_fake = make_vgg16('vgg16')(jitt_fake)

    # perception loss
    lv1_loss, lv2_loss, lv3_loss, lv4_loss, lv5_loss = jitt_fake

    lv1_loss = tf.sigmoid(lv1_loss)
    lv2_loss = tf.sigmoid(lv2_loss)
    lv3_loss = tf.sigmoid(lv3_loss)
    lv4_loss = tf.sigmoid(lv4_loss)
    lv5_loss = tf.sigmoid(lv5_loss)

    lv1_loss = 0.5 * K.mean(K.square(lv1_loss[:batch_size] - lv1_loss[batch_size:]), axis=[-3, -2, -1])
    lv2_loss = 0.5 * K.mean(K.square(lv2_loss[:batch_size] - lv2_loss[batch_size:]), axis=[-3, -2, -1])
    lv3_loss = 0.5 * K.mean(K.square(lv3_loss[:batch_size] - lv3_loss[batch_size:]), axis=[-3, -2, -1])
    lv4_loss = 0.5 * K.mean(K.square(lv4_loss[:batch_size] - lv4_loss[batch_size:]), axis=[-3, -2, -1])
    lv5_loss = 0.5 * K.mean(K.square(lv5_loss[:batch_size] - lv5_loss[batch_size:]), axis=[-3, -2, -1])

    # Kullback-Leibler (KL) divergence
    kl_loss = -0.5 * K.mean(1 + z_logv - K.square(z_mean) - K.exp(z_logv), axis=-1)

    # print message for debugging
    #kl_loss = tf.Print(kl_loss, [tf.shape(lv1_loss)[0]], message='DEBUG', summarize=100) 

    return lv1_loss * lv1_w + lv2_loss * lv2_w + lv3_loss * lv3_w + \
           lv4_loss * lv4_w + lv5_loss * lv5_w + kl_loss  * kbd_w

total_loss_vgg_head = Lambda(model_loss_vgg_head, output_shape=(1,))([jitt_image, fake_image])

# models and encoder
model_vgg_head      = Model([true_image, jitt_image], total_loss_vgg_head)

This is the code that creates the multi-GPU model:

gdev_list = get_available_gpus()
mgpu_model = make_parallel(model_vgg_head, gdev_list)

I encounter this error when I run mgpu_model.fit_generator. Can you give me some pointers on how to fix this problem? Thanks in advance.

avolkov1 commented 7 years ago

@experiencor Did you run .compile? Could you post the printout of the model? Run this:

from keras_exp.multigpu import print_mgpu_modelsummary

# your code above
print_mgpu_modelsummary(mgpu_model)

I wasn't able to build the model just from the code you posted. There's a missing function, and several parameters whose definitions I couldn't find.

Unspecified parameters:
FRAME_H, FRAME_W, latent_dim, epsilon_std, per_w, kbd_w, lv1_w, lv2_w, lv3_w, lv4_w, lv5_w

Undefined function:
make_vgg16

The other parameters/layers I figured out:

from keras.models import Model
from keras.layers import (
    Input, Lambda, Dense, Reshape, Deconv2D, BatchNormalization, LeakyReLU,
    Activation, Flatten)
import keras.backend as K
import tensorflow as tf

Ideally, if you can share a working, slimmed-down, single-GPU example, I'll debug it and try to run it on multiple GPUs. If there's proprietary code you don't want to share, then at the very least please post the output of print_mgpu_modelsummary(mgpu_model) so I can see the layers and dimensions.

Maybe there's a bug in how I'm slicing and concatenating with multiple inputs, but I need more info.
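
For reference, make_parallel follows the usual data-parallel pattern: slice each input along the batch dimension, run the model on each GPU's slices, and concatenate the outputs back on the CPU. Roughly like this (a simplified sketch, not the exact keras_exp code; slice_batch and make_parallel_sketch are illustrative names):

import tensorflow as tf
from keras.layers import Lambda, concatenate
from keras.models import Model

def slice_batch(x, n_gpus, part):
    # take the part-th 1/n_gpus slice of the batch dimension of x
    sub_batch = tf.shape(x)[0] // n_gpus
    return x[part * sub_batch:(part + 1) * sub_batch]

def make_parallel_sketch(model, n_gpus):
    towers = []
    for g in range(n_gpus):
        with tf.device('/gpu:%d' % g):
            # slice every input for this tower and re-apply the model
            # to the slices, rebuilding the layer graph on each GPU
            slices = [Lambda(slice_batch,
                             arguments={'n_gpus': n_gpus, 'part': g})(x)
                      for x in model.inputs]
            towers.append(model(slices))
    with tf.device('/cpu:0'):
        # concatenate the per-tower outputs back into one full batch
        merged = concatenate(towers, axis=0)
    return Model(inputs=model.inputs, outputs=merged)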

experiencor commented 7 years ago

This was due to a bug in my model construction code.
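
For anyone who hits the same error: one construction bug that produces exactly this [8] vs. [16] mismatch is a loss Lambda that closes over z_mean and z_logv from the enclosing scope. Closed-over tensors stay wired to the original full-batch graph, so when make_parallel re-applies the model to the per-GPU slices, the perception term has batch size 8 while the KL term keeps batch size 16, and the final add fails. A minimal sketch of the fix, assuming that is the cause (it reuses K, per_w, and kbd_w from the original post), is to pass those tensors to the Lambda explicitly:

# sketch of the fix, assuming the closure problem described above: feed
# z_mean and z_logv into the Lambda as inputs so they are recomputed on
# the per-GPU slices instead of being captured at full batch size
def model_loss_no_vgg_head(args):
    true_image, fake_image, z_mean, z_logv = args

    # perception loss, computed on the per-GPU slice
    perception_loss = 0.5 * K.mean(K.square(true_image - fake_image),
                                   axis=[-3, -2, -1])

    # KL divergence, now also computed on the per-GPU slice
    kl_loss = -0.5 * K.mean(1 + z_logv - K.square(z_mean) - K.exp(z_logv),
                            axis=-1)

    return perception_loss * per_w + kl_loss * kbd_w

total_loss_no_vgg_head = Lambda(model_loss_no_vgg_head, output_shape=(1,))(
    [true_image, fake_image, z_mean, z_logv])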