Closed: markomitos closed this issue 1 week ago.
Hi @markomitos -
In Keras 3, the `DisableSharedObjectScope()` class still exists but is not used when cloning a model. I don't think it is needed in Keras 3: Keras 3 handles custom object creation through serialization/deserialization, and cloning a model works from each layer's config rather than recreating the whole model from its config.
What currently happens if you try to clone a model that contains shared embeddings?
It clones without any error, and the test continues. Is there a way to check whether the layers still share an embedding, or is this not even a problem in the new Keras version?
The test I mentioned is similar to this one, except that it uses Keras 3 instead of Keras 2: https://github.com/google-parfait/tensorflow-federated/blob/523c129676236f7060fafb95b2a8fed683a5e519/tensorflow_federated/python/learning/models/functional_test.py#L977
You can compare the layer objects in the cloned model by identity (`is`) to see whether the embedding layer is still shared.
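One way to automate that check is to count how many graph nodes call each layer: a layer that is called more than once in a functional model is shared. The sketch below assumes a Keras 3 functional model exposes its graph through the private `_nodes_by_depth` attribute (a dict mapping depth to a list of nodes, each with an `.operation` pointing at the layer it calls). That is an internal detail and may change between versions, and `find_shared_layers` is a hypothetical helper name.

```python
from collections import defaultdict


def find_shared_layers(nodes_by_depth):
    """Return the layers that are the operation of more than one graph node.

    Assumes `nodes_by_depth` is shaped like Keras 3's private
    `_nodes_by_depth` attribute: {depth: [node, ...]}, where every node has
    an `.operation` attribute referring to the layer it calls.
    """
    call_counts = defaultdict(int)  # id(layer) -> number of nodes calling it
    layers_by_id = {}  # id(layer) -> layer, in first-seen order
    for nodes in nodes_by_depth.values():
        for node in nodes:
            layer = node.operation
            call_counts[id(layer)] += 1
            layers_by_id.setdefault(id(layer), layer)
    # Object identity, not config equality, is what makes a layer shared.
    return [layer for key, layer in layers_by_id.items() if call_counts[key] > 1]
```

Under that assumption it would be called as `find_shared_layers(new_model._nodes_by_depth)`; an empty list means no layer in the cloned graph is called more than once. Only nodes inside the model's own graph are counted, so calls made to a layer outside the model do not affect the result.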
By default, if a layer is shared in a model, the corresponding layer in the cloned model is shared as well.
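A cloner can preserve sharing by remembering which original layers it has already cloned and handing back the same clone on repeat visits. The following is a plain-Python illustration of that idea, not the Keras implementation; `memoize_by_identity`, `FakeLayer`, and `clone_layer` are hypothetical names, and `FakeLayer` only mimics the `get_config`/`from_config` pair that Keras layers provide.

```python
def memoize_by_identity(clone_fn):
    """Wrap `clone_fn` so each original object is cloned at most once."""
    cache = {}  # id(original layer) -> its clone

    def wrapped(layer):
        key = id(layer)
        if key not in cache:
            cache[key] = clone_fn(layer)
        return cache[key]

    return wrapped


class FakeLayer:
    """Hypothetical stand-in for a layer with get_config/from_config."""

    def __init__(self, units):
        self.units = units

    def get_config(self):
        return {"units": self.units}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def clone_layer(layer):
    # Rebuild a new object of the same class from the original's config.
    return layer.__class__.from_config(layer.get_config())


shared = FakeLayer(8)
# The same layer object is visited twice, as when a shared layer is
# called on two inputs.
visits = [shared, shared]

memoized = memoize_by_identity(clone_layer)
kept = [memoized(layer) for layer in visits]
duplicated = [clone_layer(layer) for layer in visits]

print(kept[0] is kept[1])  # True: one clone reused, so sharing is preserved
print(duplicated[0] is duplicated[1])  # False: one independent copy per visit
```

The design choice is simply where the cache lives: with it, both visits resolve to the same clone; without it, every visit produces an independent copy.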
If you want to change that behavior, you can. For instance, if you want the layer to be duplicated in the cloned model instead of shared, pass a custom `clone_function` argument to `clone_model`.
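Example:
```python
from keras.models import clone_model


def clone_function(layer):
    # Build a brand-new layer from this layer's config (no caching of clones).
    config = layer.get_config()
    return layer.__class__.from_config(config)


# `model` is the existing Keras model to clone.
new_model = clone_model(model, clone_function=clone_function)
```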
I am trying to add Keras 3 support to TensorFlow Federated. When cloning a model, I need to check whether any embeddings are shared between layers and, if so, raise an error. Here is the code in question: https://github.com/google-parfait/tensorflow-federated/blob/523c129676236f7060fafb95b2a8fed683a5e519/tensorflow_federated/python/learning/models/functional.py#L502
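If the goal is to reject models whose layers share embedding weights, one hedged option is to look for the same variable object under more than one top-level layer and raise if it is found. The helper below, `check_no_shared_weights`, is hypothetical; it only assumes each layer has `.weights` (as Keras layers do) and that layers and weights have a `.name`, and it compares variables by identity. It catches tied weights across different layers, but not a single layer called twice, since that case shows up in the graph rather than in the weights.

```python
def check_no_shared_weights(layers):
    """Raise ValueError if any weight object belongs to more than one layer.

    `layers` is any iterable of objects with `.name` and `.weights`,
    e.g. the top-level `model.layers` of a Keras model.
    """
    owner = {}  # id(weight) -> first layer seen holding that weight
    for layer in layers:
        for weight in layer.weights:
            first = owner.setdefault(id(weight), layer)
            if first is not layer:
                weight_name = getattr(weight, "name", repr(weight))
                raise ValueError(
                    f"Weight {weight_name!r} is shared between layers "
                    f"{first.name!r} and {layer.name!r}."
                )
```

Under these assumptions, calling `check_no_shared_weights(new_model.layers)` after cloning would return silently for a model without tied weights and raise `ValueError` otherwise.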
Is there something in Keras 3 similar to this legacy function from tf_keras? https://github.com/keras-team/tf-keras/blob/c5f97730b2e495f5f56fc2267d22504075e46337/tf_keras/models/cloning.py#L525