nengo / nengo-dl

Deep learning integration for Nengo
https://www.nengo.ai/nengo-dl

Implement NengoDL version of ModelCheckpoint #133

Closed: arvoelke closed this issue 4 years ago

arvoelke commented 4 years ago

Related to #132.

Minimal reproducer:

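import nengo
import nengo_dl

with nengo.Network() as model:
    nengo.Probe(nengo.Node(0))

# save weights from a simulator with minibatch_size=2
with nengo_dl.Simulator(model, minibatch_size=2) as sim:
    sim.keras_model.save_weights("temp.hdf5")

# loading them into a simulator with minibatch_size=1 raises a ValueError
with nengo_dl.Simulator(model, minibatch_size=1) as sim:
    sim.keras_model.load_weights("temp.hdf5")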

Stack trace:

ValueError                                Traceback (most recent call last)
<ipython-input-15-5bffe7574ed5> in <module>
      6 
      7 with nengo_dl.Simulator(model, minibatch_size=1) as sim:
----> 8     sim.keras_model.load_weights("temp.hdf5")

~/anaconda3/envs/nengo-dl/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in load_weights(self, filepath, by_name)
    179         raise ValueError('Load weights is not yet supported with TPUStrategy '
    180                          'with steps_per_run greater than 1.')
--> 181     return super(Model, self).load_weights(filepath, by_name)
    182 
    183   @trackable.no_automatic_dependency_tracking

~/anaconda3/envs/nengo-dl/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in load_weights(self, filepath, by_name)
   1175         saving.load_weights_from_hdf5_group_by_name(f, self.layers)
   1176       else:
-> 1177         saving.load_weights_from_hdf5_group(f, self.layers)
   1178 
   1179   def _updated_config(self):

~/anaconda3/envs/nengo-dl/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/hdf5_format.py in load_weights_from_hdf5_group(f, layers)
    697                        str(len(weight_values)) + ' elements.')
    698     weight_value_tuples += zip(symbolic_weights, weight_values)
--> 699   K.batch_set_value(weight_value_tuples)
    700 
    701 

~/anaconda3/envs/nengo-dl/lib/python3.7/site-packages/tensorflow_core/python/keras/backend.py in batch_set_value(tuples)
   3356             assign_placeholder = array_ops.placeholder(tf_dtype,
   3357                                                        shape=value.shape)
-> 3358             assign_op = x.assign(assign_placeholder)
   3359             x._assign_placeholder = assign_placeholder
   3360             x._assign_op = assign_op

~/anaconda3/envs/nengo-dl/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py in assign(self, value, use_locking, name, read_value)
    812     with _handle_graph(self.handle):
    813       value_tensor = ops.convert_to_tensor(value, dtype=self.dtype)
--> 814       self._shape.assert_is_compatible_with(value_tensor.shape)
    815       assign_op = gen_resource_variable_ops.assign_variable_op(
    816           self.handle, value_tensor, name=name)

~/anaconda3/envs/nengo-dl/lib/python3.7/site-packages/tensorflow_core/python/framework/tensor_shape.py in assert_is_compatible_with(self, other)
   1113     """
   1114     if not self.is_compatible_with(other):
-> 1115       raise ValueError("Shapes %s and %s are incompatible" % (self, other))
   1116 
   1117   def most_specific_compatible_shape(self, other):

ValueError: Shapes (1, 1) and (2, 1) are incompatible

Expected behaviour: this should work, since one might want to change the minibatch size from one run to another (e.g., to work around #132 or #121, or to experiment with different batch sizes) while reusing the model parameters from a previous run.

drasmuss commented 4 years ago

If you use sim.save_params/load_params then this will work. When you use keras_model.save_weights you also save the internal simulation state, which has a minibatch dimension; that is why those weights don't transfer between simulators with different minibatch sizes.
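The difference between the two save paths can be illustrated with a pure-Python toy (no NengoDL or TensorFlow involved). It assumes, following the explanation above, that trainable parameters have shapes independent of `minibatch_size` while simulator state carries a leading minibatch dimension. The variable names and the helpers `make_variables`, `save` and `load` are hypothetical, and matching saved values by name is a simplification of how Keras actually pairs saved weights with layers.

```python
# Pure-Python toy of the two save paths; nothing here is NengoDL's real API.
# Assumption: trainable params have fixed shapes, while simulator state has a
# leading minibatch dimension (the only thing that changes between runs).


def make_variables(minibatch_size):
    """Return (params, state) as dicts mapping a hypothetical name -> shape."""
    params = {"connection_weights": (1, 1)}  # same shape for any minibatch size
    state = {"node_state": (minibatch_size, 1)}  # first axis is the minibatch
    return params, state


def save(params, state, include_state):
    """Record the shapes that would be written to disk.

    include_state=True mimics keras_model.save_weights (state included);
    include_state=False mimics sim.save_params (trainable params only).
    """
    saved = dict(params)
    if include_state:
        saved.update(state)
    return saved


def load(saved, params, state):
    """Check each saved shape against its target, refusing mismatches."""
    targets = {**params, **state}
    loaded = []
    for name, shape in saved.items():
        if targets[name] != shape:
            raise ValueError(f"Shapes {targets[name]} and {shape} are incompatible")
        loaded.append(name)
    return loaded


# Save from a minibatch_size=2 simulator, load into a minibatch_size=1 one.
params2, state2 = make_variables(minibatch_size=2)
params1, state1 = make_variables(minibatch_size=1)

print(load(save(params2, state2, include_state=False), params1, state1))
# prints ['connection_weights']
try:
    load(save(params2, state2, include_state=True), params1, state1)
except ValueError as err:
    print("ValueError:", err)
    # prints ValueError: Shapes (1, 1) and (2, 1) are incompatible
```

In NengoDL itself, the corresponding workaround is to call sim.save_params in the minibatch_size=2 simulator and sim.load_params in the minibatch_size=1 one, rather than going through sim.keras_model.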

arvoelke commented 4 years ago

Good to know. For reference, I'm using [save|load]_weights rather than [save|load]_params because I'm using the ModelCheckpoint callback with save_weights_only=True. Setting this to False triggers an issue in the serialization/deserialization logic when loading the model back in.

drasmuss commented 4 years ago

Going to use this issue to track the idea of implementing a thin wrapper around ModelCheckpoint that calls sim.save_params instead of sim.keras_model.save_weights.
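The wrapper idea can be sketched as a Keras-style callback that delegates to the simulator instead of to keras_model.save_weights. This is a minimal sketch, not NengoDL code: the class `SaveParamsCheckpoint` and its `every_n_epochs` argument are hypothetical, the `filepath` template is formatted with a 1-based `epoch` plus the entries of `logs` (loosely following ModelCheckpoint), and the only thing assumed about the simulator is that it has a `save_params(path)` method. It subclasses `tf.keras.callbacks.Callback` when TensorFlow is importable and falls back to a plain stand-in class otherwise, so the logic can be exercised without TensorFlow.

```python
try:
    # Real Keras base class, so the callback can be passed to fit().
    from tensorflow.keras.callbacks import Callback
except ImportError:
    class Callback:
        """Stand-in base class so this sketch runs without TensorFlow."""


class SaveParamsCheckpoint(Callback):
    """Hypothetical ModelCheckpoint-like callback that calls sim.save_params."""

    def __init__(self, sim, filepath, every_n_epochs=1):
        super().__init__()
        self.sim = sim  # anything with a save_params(path) method
        self.filepath = filepath  # e.g. "params_{epoch:02d}"
        self.every_n_epochs = every_n_epochs

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch; format it 1-based, as ModelCheckpoint does.
        if (epoch + 1) % self.every_n_epochs != 0:
            return
        path = self.filepath.format(epoch=epoch + 1, **(logs or {}))
        # Delegate to the simulator instead of keras_model.save_weights, whose
        # weight list also includes the minibatch-shaped simulator state.
        self.sim.save_params(path)
```

Assuming sim.fit forwards extra keyword arguments to Keras fit, this would be passed as `callbacks=[SaveParamsCheckpoint(sim, "params_{epoch:02d}")]`, and a checkpoint could later be restored with sim.load_params in a simulator built with a different minibatch_size.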

drasmuss commented 4 years ago

Note to future self: an easier fix might be to store the simulator state as a simple tf.Variable (not added through layer.add_weight). Then Keras wouldn't track it, and sim.keras_model.save_weights would behave the same as sim.save_params (only saving the trainable parameters of the model).