keras-team / keras-core

A multi-backend implementation of the Keras API, with support for TensorFlow, JAX, and PyTorch.
Apache License 2.0

tf.saved_model.save() raises ValueError when saving a model with TF backend #445

Closed (ageron closed this issue 1 year ago)

ageron commented 1 year ago

I've installed Keras-Core with pip install keras-core and created a simple Sequential model using the TensorFlow backend, and I'm trying to export it to TensorFlow's SavedModel format using tf.saved_model.save(), but this failed with the following error:

ValueError: Unable to save the object ListWrapper([<KerasVariable shape=(3, 1), dtype=float32, name=variable_2>, <KerasVariable shape=(1,), dtype=float32, name=variable_3>]) (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified outside the wrapper (its final value was [<KerasVariable shape=(3, 1), dtype=float32, name=variable_2>, <KerasVariable shape=(1,), dtype=float32, name=variable_3>], its value when a checkpoint dependency was added was None), which breaks restoration on object creation.

If you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored.
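For readers unfamiliar with this error, the toy sketch below illustrates the kind of check the message describes. It is not TensorFlow's implementation: `ToyListWrapper` and its methods are hypothetical names, and the snapshot comparison is an assumption based on the wording of the error ("its value when a checkpoint dependency was added was None"). The wrapper shares the caller's list by reference, keeps a copy for comparison, and refuses to list its children once the two disagree.

```python
class ToyListWrapper:
    """Toy stand-in for a tracking list wrapper (NOT TensorFlow's real class)."""

    def __init__(self, wrapped_list):
        self._storage = wrapped_list          # shared with the caller by reference
        self._snapshot = list(wrapped_list)   # copy taken when the list is wrapped
        self._external_modification = False

    def _check_external_modification(self):
        # Assumption: a change is detected by comparing against the snapshot.
        if self._external_modification:
            return
        if self._storage != self._snapshot:
            self._external_modification = True
            self._snapshot = None  # the snapshot is no longer meaningful

    def append(self, item):
        # Mutations made through the wrapper keep the snapshot in sync.
        self._check_external_modification()
        self._storage.append(item)
        if not self._external_modification:
            self._snapshot = list(self._storage)

    def trackable_children(self):
        # Called at save time: fail if the list changed behind the wrapper's back.
        self._check_external_modification()
        if self._external_modification:
            raise ValueError(
                f"Unable to save the object {self._storage!r}: the wrapped list "
                "was modified outside the wrapper "
                f"(snapshot was {self._snapshot!r})."
            )
        return {str(i): item for i, item in enumerate(self._storage)}


weights = []
wrapper = ToyListWrapper(weights)
wrapper.append("kernel")                   # tracked: goes through the wrapper
children = wrapper.trackable_children()    # succeeds

weights.append("bias")                     # untracked: bypasses the wrapper
try:
    wrapper.trackable_children()
except ValueError as err:
    print(err)
```

In this toy model, the failure comes from appending to the underlying list directly rather than through the wrapper. Whether Keras-Core triggers the real error in the same way is not established by this sketch.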

Here is the code to reproduce the error (gist):

import keras_core as keras
import numpy as np
import tensorflow as tf

model = keras.Sequential([
    keras.layers.Dense(1)
])
model.compile(loss="mse", optimizer="adam")

X_train = np.random.rand(100, 3)
y_train = np.random.rand(100, 1)
model.fit(X_train, y_train)

tf.saved_model.save(model, "my_keras_core_model")

I've tried with both TensorFlow 2.12 and 2.13, and got the same error.

And here is the full stacktrace:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in ()
----> 1 tf.saved_model.save(model, "my_keras_core_model")

14 frames
/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in save(obj, export_dir, signatures, options)
   1278 # pylint: enable=line-too-long
   1279 metrics.IncrementWriteApi(_SAVE_V2_LABEL)
-> 1280 save_and_return_nodes(obj, export_dir, signatures, options)
   1281
   1282 metrics.IncrementWrite(write_version="2")

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, experimental_skip_checkpoint)
   1311
   1312 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1313 _build_meta_graph(obj, signatures, options, meta_graph_def))
   1314 saved_model.saved_model_schema_version = (
   1315 constants.SAVED_MODEL_SCHEMA_VERSION)

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def)
   1491
   1492 with save_context.save_context(options):
-> 1493 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def)

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def)
   1433
   1434 # Use _SaveableView to provide a frozen listing of properties and functions.
-> 1435 saveable_view = _SaveableView(augmented_graph_view, options)
   1436 object_saver = checkpoint.TrackableSaver(augmented_graph_view)
   1437 asset_info, exported_graph = _fill_meta_graph_def(

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in __init__(self, augmented_graph_view, options)
    267 (self._trackable_objects, self.node_paths, self.node_ids,
    268 self._slot_variables, self.object_names) = (
--> 269 checkpoint_util.objects_ids_and_slot_variables_and_paths(
    270 self.augmented_graph_view))
    271

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/util.py in objects_ids_and_slot_variables_and_paths(graph_view)
    155 object -> node id, slot variables, object_names)
    156 """
--> 157 trackable_objects, node_paths = graph_view.breadth_first_traversal()
    158 object_names = object_identity.ObjectIdentityDictionary()
    159 for obj, path in node_paths.items():

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/graph_view.py in breadth_first_traversal(self)
    122
    123 def breadth_first_traversal(self):
--> 124 return self._breadth_first_traversal()
    125
    126 def _breadth_first_traversal(self):

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in _breadth_first_traversal(self)
    144
    145 trackable_objects, _ = (
--> 146 super(_AugmentedGraphView, self)._breadth_first_traversal())
    147
    148 asset_paths = object_identity.ObjectIdentityDictionary()

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/graph_view.py in _breadth_first_traversal(self)
    126 def _breadth_first_traversal(self):
    127 """Find shortest paths to all dependencies of self.root."""
--> 128 return super(ObjectGraphView, self)._descendants_with_paths()
    129
    130 def serialize_object_graph(self, saveables_cache=None):

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/trackable_view.py in _descendants_with_paths(self)
    109 current_trackable = to_visit.popleft()
    110 bfs_sorted.append(current_trackable)
--> 111 for name, dependency in self.children(current_trackable).items():
    112 if dependency not in node_paths:
    113 node_paths[dependency] = (

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/graph_view.py in children(self, obj, save_type, **kwargs)
     95 """
     96 children = {}
---> 97 for name, ref in self.list_children(obj, **kwargs):
     98 children[name] = ref
     99 return children

/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/save.py in list_children(self, obj)
    178 children = self._children_cache[obj] = {}
    179
--> 180 for name, child in super(_AugmentedGraphView, self).list_children(
    181 obj,
    182 save_type=base.SaveType.SAVEDMODEL,

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/graph_view.py in list_children(self, obj, save_type, **kwargs)
     73 """
     74 children = []
---> 75 for name, ref in super(ObjectGraphView,
     76 self).children(obj, save_type, **kwargs).items():
     77 children.append(base.TrackableReference(name, ref))

/usr/local/lib/python3.10/dist-packages/tensorflow/python/checkpoint/trackable_view.py in children(cls, obj, save_type, **kwargs)
     82 obj._maybe_initialize_trackable()
     83 children = {}
---> 84 for name, ref in obj._trackable_children(save_type, **kwargs).items():
     85 ref = converter.convert_to_trackable(ref, parent=obj)
     86 children[name] = ref

/usr/local/lib/python3.10/dist-packages/tensorflow/python/trackable/data_structures.py in _trackable_children(self, save_type, **kwargs)
    561 "non-trackable object; it will be subsequently ignored.")
    562 if self._external_modification:
--> 563 raise ValueError(
    564 f"Unable to save the object {self} (a list wrapper constructed to "
    565 "track trackable TensorFlow objects). The wrapped list was modified "

ValueError: Unable to save the object ListWrapper([<KerasVariable shape=(3, 1), dtype=float32, name=variable_2>, <KerasVariable shape=(1,), dtype=float32, name=variable_3>]) (a list wrapper constructed to track trackable TensorFlow objects). The wrapped list was modified outside the wrapper (its final value was [<KerasVariable shape=(3, 1), dtype=float32, name=variable_2>, <KerasVariable shape=(1,), dtype=float32, name=variable_3>], its value when a checkpoint dependency was added was None), which breaks restoration on object creation.

If you don't need this list checkpointed, wrap it in a NoDependency object; it will be subsequently ignored.
```
bermeitinger-b commented 1 year ago

I think you're supposed to use the new keras.saving namespace:

keras.saving.save_model(model, 'model.keras')
ageron commented 1 year ago

Hi @bermeitinger-b, thanks for your response. I believe model.save(...) is an alias for keras.saving.save_model(model, ...). It works fine, but it saves the model using the Keras format. I'd like to export the model to TensorFlow's SavedModel format. That format is no longer supported by model.save() (or by keras.saving.save_model()), but it should be supported by tf.saved_model.save(), so I believe there's a bug, either in Keras-Core or in TF. I've updated my bug report to make it clear that I'm trying to export to the SavedModel format.

vishalsubbiah commented 1 year ago

This seems to be because the KerasVariable within the Dense layer isn't a TF trackable object like tf.Variable etc. Looks like KerasVariable needs to inherit from tensorflow.python.trackable.base.Trackable to work with tf.saved_model.save()? Both the Dense layer and Sequential seem to inherit from it, but the weights within Dense don't seem to.

asingh9530 commented 1 year ago

@ageron As per the code, this format most likely will not be supported; see the internal code below.

        if save_format in ["h5", "tf"]:
            raise ValueError(
                "`'h5'` and `'t5'` formats are no longer supported via the "
                "`save_format` option. Please use the new `'keras'` format. "
                f"Received: save_format={save_format}"
            )
        if save_format not in ["keras", "keras_v3"]:
            raise ValueError(
                "Unknown `save_format` value. Only the `'keras'` format is "
                f"currently supported. Received: save_format={save_format}"
            )
        if not str(filepath).endswith(".keras"):
            raise ValueError(
                "The filename must end in `.keras`. "
                f"Received: filepath={filepath}"
            )

Maybe this will be added in the future, but not as of now 😕
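To make the quoted checks easier to try out, here is a minimal standalone sketch that reproduces their order. `check_save_args` and `is_accepted` are hypothetical names, not part of Keras-Core, the `"keras"` default for `save_format` is an assumption made so the function can be called with a path alone, and the error messages are shortened paraphrases.

```python
def check_save_args(filepath, save_format="keras"):
    """Hypothetical helper mirroring the order of the checks quoted above."""
    # 1. The legacy formats are rejected outright.
    if save_format in ["h5", "tf"]:
        raise ValueError(
            "Legacy formats are no longer supported via `save_format`. "
            f"Received: save_format={save_format}"
        )
    # 2. Anything other than the Keras format is rejected as unknown.
    if save_format not in ["keras", "keras_v3"]:
        raise ValueError(
            f"Unknown `save_format` value. Received: save_format={save_format}"
        )
    # 3. The target path must carry the `.keras` extension.
    if not str(filepath).endswith(".keras"):
        raise ValueError(
            f"The filename must end in `.keras`. Received: filepath={filepath}"
        )


def is_accepted(filepath, save_format="keras"):
    """Return True if the arguments pass all three checks, False otherwise."""
    try:
        check_save_args(filepath, save_format)
    except ValueError:
        return False
    return True


print(is_accepted("model.keras"))                # passes every check
print(is_accepted("my_keras_core_model", "tf"))  # rejected by check 1
print(is_accepted("model.h5"))                   # rejected by check 3
```

These checks only describe the arguments a save call of this shape would accept. They say nothing about what tf.saved_model.save() does with the same model.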

ageron commented 1 year ago

@asingh9530, thanks for your message. However, I think you are referring to the model.save() method, whereas I am trying to export the Keras-Core model to the SavedModel format using tf.saved_model.save(), which is supposed to work (confirmed by @fchollet in personal communication). The Keras-Core announcement does say that Keras-Core models can be exported to the SavedModel format:

[image: screenshot of the Keras-Core announcement]

asingh9530 commented 1 year ago

@ageron Ah, sorry, my bad. Actually, it is a known issue; you can check it here.

fchollet commented 1 year ago

We're looking into this issue. The goal is to have tf.saved_model.save() work on any Keras model.

fchollet commented 1 year ago

It should be fixed now, please check that it works for you.

ageron commented 1 year ago

Awesome, it works now, thanks @fchollet!