kuza55 / keras-extras

Extra batteries for Keras
Apache License 2.0

Potential incompatibility with keras model checkpointing #23

Closed: munsanje closed this issue 7 years ago

munsanje commented 7 years ago

I recently adopted the multi_gpu module to parallelize training across multiple GPUs. On 8 Tesla K80s I get a speed-up of roughly 4x, and learning appears to take place, since the loss decreases with each iteration. However, when I test the model and visualize the results, it performs exactly as it did before any training. Previously, a model trained to the same loss without multi_gpu performed drastically differently. I've been working with this model for months, so I've established that the problem is learnable and that the architecture works; these results make no sense. I'm using Keras's built-in ModelCheckpoint callback to save my model automatically after every epoch in which the validation loss has decreased. My guess is that there is a silent conflict between how the model is saved and this module. Any help debugging this would be greatly appreciated.
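A save-best-only checkpoint calls `save` on whatever model `fit` was invoked on, which here is the parallel wrapper rather than the original single-GPU model. The sketch below is a pure-Python stand-in that needs no Keras. The classes `FakeModel` and `BestOnlyCheckpoint` and the `fit` helper are hypothetical and only imitate the policy described above (save when `val_loss` improves). It is a simplified illustration of that policy, not Keras's actual `ModelCheckpoint` code.

```python
class FakeModel:
    """Hypothetical stand-in for a Keras model; records where it was saved."""

    def __init__(self, name):
        self.name = name
        self.saved_to = []

    def save(self, filepath, overwrite=True):
        self.saved_to.append(filepath)


class BestOnlyCheckpoint:
    """Simplified imitation of a save-best-only checkpoint on val_loss (assumption)."""

    def __init__(self, filepath):
        self.filepath = filepath
        self.best = float("inf")
        self.model = None  # attached by the training loop

    def on_epoch_end(self, epoch, logs):
        current = logs["val_loss"]
        if current < self.best:  # save only when validation loss improved
            self.best = current
            self.model.save(self.filepath.format(epoch=epoch + 1), overwrite=True)


def fit(model, val_losses, callback):
    """Toy training loop: the callback is bound to the model fit() was called on."""
    callback.model = model
    for epoch, loss in enumerate(val_losses):
        callback.on_epoch_end(epoch, {"val_loss": loss})


template = FakeModel("template")
parallel = FakeModel("parallel")  # plays the role of the multi_gpu wrapper
ckpt = BestOnlyCheckpoint("weights.{epoch:02d}.h5")
fit(parallel, [0.9, 0.7, 0.8, 0.5], ckpt)
print(parallel.saved_to)  # ['weights.01.h5', 'weights.02.h5', 'weights.04.h5']
print(template.saved_to)  # []
```

Only the wrapper is ever written. That matches the suspicion that the checkpoints contain the parallel model rather than the underlying one.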

munsanje commented 7 years ago

Managed to fix this issue by slightly adapting @tstandley's solution in #3, and can confirm that the original multi_gpu code was the cause. Redefining the parallel model's save and save_weights methods so that they save the underlying single-GPU model, as described in #3, solved the problem. Code:

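    # End of the multi_gpu module: `model` is the original single-GPU model,
    # `new_model` is the parallel wrapper that training (and ModelCheckpoint) sees.

    # Redirect new_model.save() to save the underlying model instead of the wrapper.
    # type(model.save) is types.MethodType, so this binds the function to new_model.
    save_model_function = type(model.save)

    def save_old_model(self_, model_path, overwrite=True):
        model.save(model_path, overwrite)
    new_model.save = save_model_function(save_old_model, new_model)

    # Likewise redirect new_model.save_weights() to the underlying model's weights.
    save_weights_function = type(model.save_weights)

    def save_old_weights(self_, weights_path, overwrite=True):
        model.save_weights(weights_path, overwrite)
    new_model.save_weights = save_weights_function(save_old_weights, new_model)

    return new_model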
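The patch works because `type(model.save)` is `types.MethodType`. Calling it with a plain function and an instance produces a bound method, and assigning that to the instance shadows the class's `save` for that one object only. The sketch below reproduces the same redirection in plain Python, using `types.MethodType` directly. The `Net` class and the `redirect_saving` helper are hypothetical illustrations, not Keras APIs.

```python
import types


class Net:
    """Hypothetical stand-in for a Keras Model."""

    def __init__(self, name):
        self.name = name
        self.saved = []

    def save(self, filepath, overwrite=True):
        self.saved.append(("model", filepath))

    def save_weights(self, filepath, overwrite=True):
        self.saved.append(("weights", filepath))


def redirect_saving(new_model, model):
    """Make new_model.save/save_weights write the underlying `model` instead."""

    def save_old_model(self_, filepath, overwrite=True):
        model.save(filepath, overwrite)

    def save_old_weights(self_, filepath, overwrite=True):
        model.save_weights(filepath, overwrite)

    # Instance attributes shadow the class methods, but only on new_model.
    new_model.save = types.MethodType(save_old_model, new_model)
    new_model.save_weights = types.MethodType(save_old_weights, new_model)
    return new_model


template = Net("template")
parallel = redirect_saving(Net("parallel"), template)
parallel.save("best.h5", overwrite=True)
parallel.save_weights("best_weights.h5")
print(template.saved)  # [('model', 'best.h5'), ('weights', 'best_weights.h5')]
print(parallel.saved)  # []
```

Other `Net` instances keep the normal class methods. Only the patched wrapper forwards its saves.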

CeadeS commented 7 years ago

@munsanje do you have example code for the solution? I have a similar problem and unfortunately don't understand where to put the code you mentioned.

munsanje commented 7 years ago

@CeadeS Yeah sure. I replaced the last segment of the module with this code:

    # merge outputs on CPU
    with tf.device('/cpu:0'):
        merged = []
        for outputs in outputs_all:
            # Keras 1.x API (merge(mode=...)), as used by the original module
            merged.append(merge(outputs, mode='concat', concat_axis=0))

        # Keras 1.x argument names (input=, output=)
        new_model = Model(input=model.inputs, output=merged)

        # update model saving scheme to save underlying model rather than parallel
        save_model_function = type(model.save)

        def save_old_model(self_, model_path, overwrite=True):
            model.save(model_path, overwrite)
        new_model.save = save_model_function(save_old_model, new_model)

        # update weight saving scheme to save underlying model weights
        save_weights_function = type(model.save_weights)

        def save_old_weights(self_, weights_path, overwrite=True):
            model.save_weights(weights_path, overwrite)
        new_model.save_weights = save_weights_function(save_old_weights, new_model)
        return new_model