dmlc / keras

Deep Learning library for Python. Convnets, recurrent neural networks, and more. Runs on MXNet, Theano or TensorFlow.
http://keras.io/

Fixed bug related to KerasSymbol.tensor #93

Open jricheimer opened 6 years ago

jricheimer commented 6 years ago

This fixes a bug that occurs when layer.set_weights() is called on some or all layers of a model before training, and the model is then trained and saved.

When the weights are set, weight.tensor already exists from the random initialization, so it is reassigned to a new symbol, whereas weight._bind_values[weight.name] keeps pointing to the same symbol and only has its values replaced with the new data. When model._sync_weights() is eventually called, model._args[weight.name] and weight._bind_values[weight.name] contain the updated (trained) weights, but weight.tensor still holds the old initialization values. Because weight.tensor is what gets evaluated by model.save_weights() and layer.get_weights(), incorrect weights end up saved to the Keras model files.
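The aliasing problem described above can be reproduced in a toy model written in plain Python. This is not the keras-mxnet code: Buffer, Weight, set_weights_buggy, and train_step are hypothetical stand-ins, and the sketch assumes that weight.tensor and weight._bind_values[weight.name] start out referring to one shared buffer, which the executor arguments (model._args in the description) also share.

```python
class Buffer:
    """Hypothetical stand-in for a backend symbol that owns its values."""

    def __init__(self, values):
        self.values = list(values)


class Weight:
    """Hypothetical stand-in for a KerasSymbol weight.

    Assumption: right after initialization, `tensor` and `bind_values[name]`
    refer to the very same Buffer object.
    """

    def __init__(self, name, init_values):
        self.name = name
        shared = Buffer(init_values)
        self.tensor = shared
        self.bind_values = {name: shared}


def set_weights_buggy(weight, new_values):
    # Buggy pattern: `tensor` is rebound to a brand-new Buffer...
    weight.tensor = Buffer(new_values)
    # ...while the bound buffer is kept and only its contents are overwritten.
    weight.bind_values[weight.name].values[:] = list(new_values)


def train_step(args, name, delta=1.0):
    # Hypothetical training update: modifies the executor's buffer in place.
    buf = args[name]
    buf.values[:] = [v + delta for v in buf.values]


w = Weight("dense_1/kernel", [0.0, 0.0])
set_weights_buggy(w, [5.0, 5.0])

# Assumed: the executor arguments share the bound buffer.
args = {w.name: w.bind_values[w.name]}
train_step(args, w.name)

print("trained (bind_values):", w.bind_values[w.name].values)
print("saved   (tensor):     ", w.tensor.values)  # stale, not the trained values
```

In this toy model, tensor ends up holding the values passed to set_weights_buggy, while training only ever touches the buffer that bind_values and args share, so anything that reads tensor misses the update.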

This may also be the bug causing this issue.

The easy fix is to keep the symbol that weight.tensor points to and only replace its values with the new data, as is already done for weight._bind_values[weight.name].
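The fix can be sketched in the same toy model. Again, this is a hedged illustration rather than the actual patch: Buffer, Weight, set_weights_fixed, and train_step are hypothetical stand-ins, and it assumes that weight.tensor and weight._bind_values[weight.name] alias one buffer after initialization. The only change is that the buffer behind tensor is overwritten in place instead of being replaced.

```python
class Buffer:
    """Hypothetical stand-in for a backend symbol that owns its values."""

    def __init__(self, values):
        self.values = list(values)


class Weight:
    """Hypothetical stand-in for a KerasSymbol weight (tensor and
    bind_values[name] assumed to alias one Buffer after initialization)."""

    def __init__(self, name, init_values):
        self.name = name
        shared = Buffer(init_values)
        self.tensor = shared
        self.bind_values = {name: shared}


def set_weights_fixed(weight, new_values):
    # Fixed pattern: keep the Buffer that `tensor` already points to and
    # overwrite its contents in place, mirroring how bind_values is handled.
    weight.tensor.values[:] = list(new_values)
    weight.bind_values[weight.name].values[:] = list(new_values)


def train_step(args, name, delta=1.0):
    # Hypothetical training update: modifies the executor's buffer in place.
    buf = args[name]
    buf.values[:] = [v + delta for v in buf.values]


w = Weight("dense_1/kernel", [0.0, 0.0])
set_weights_fixed(w, [5.0, 5.0])
args = {w.name: w.bind_values[w.name]}
train_step(args, w.name)

print("trained (bind_values):", w.bind_values[w.name].values)
print("saved   (tensor):     ", w.tensor.values)  # now matches the trained values
```

Because the object identity of the buffer is preserved, anything that evaluates tensor later, such as a save path, observes the trained values.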