rishigami / Swin-Transformer-TF

Tensorflow implementation of Swin Transformer model.
Apache License 2.0
198 stars 46 forks source link

NotImplementedError during model save #13

Closed Bibhash123 closed 2 years ago

Bibhash123 commented 2 years ago

I have defined a model as follows:

def buildModel(LR=LR):
    """Build and compile a Swin-Large regression model.

    A pretrained ``SwinTransformer('swin_large_224')`` backbone, created with
    ``include_top=False``, feeds a single ``Dense(1, activation="relu")``
    output unit. The model is compiled with Adam, MSE loss and an RMSE metric.

    Args:
        LR: Learning rate for the Adam optimizer. The default is bound to the
            module-level ``LR`` when this function is defined.

    Returns:
        A compiled ``tf.keras.Model`` taking ``(224, 224, 3)`` image inputs.

    Note:
        ``model.save("model.hdf5")`` fails with ``NotImplementedError``. The
        backbone is a subclassed Keras model whose ``get_config()`` is not
        implemented, so the HDF5 format cannot serialize the network. Persist
        the model with ``model.save_weights(...)`` instead. To restore it,
        rebuild it with this function and call ``load_weights(...)``.
    """
    backbone = SwinTransformer(
        'swin_large_224',
        num_classes=None,
        include_top=False,
        pretrained=True,
        use_tpu=False,
    )

    inp = L.Input(shape=(224, 224, 3))
    emb = backbone(inp)
    # ReLU keeps predictions non-negative; presumably the target is >= 0 (confirm).
    out = L.Dense(1, activation="relu")(emb)

    model = tf.keras.Model(inputs=inp, outputs=out)
    # `lr` is a deprecated alias that newer Keras optimizers reject;
    # `learning_rate` works on both old and new versions.
    optimizer = tf.keras.optimizers.Adam(learning_rate=LR)
    model.compile(
        loss="mse",
        optimizer=optimizer,
        metrics=[tf.keras.metrics.RootMeanSquaredError()],
    )
    return model

Now, when I save this model using `model.save("./model.hdf5")`, I get the following error:

NotImplementedError                       Traceback (most recent call last)
/tmp/ipykernel_43/131311624.py in <module>
----> 1 model.save("model.hdf5")

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
   2000     # pylint: enable=line-too-long
   2001     save.save_model(self, filepath, overwrite, include_optimizer, save_format,
-> 2002                     signatures, options, save_traces)
   2003 
   2004   def save_weights(self,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
    152           'or using `save_weights`.')
    153     hdf5_format.save_model_to_hdf5(
--> 154         model, filepath, overwrite, include_optimizer)
    155   else:
    156     saved_model_save.save(model, filepath, overwrite, include_optimizer,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/hdf5_format.py in save_model_to_hdf5(model, filepath, overwrite, include_optimizer)
    113 
    114   try:
--> 115     model_metadata = saving_utils.model_metadata(model, include_optimizer)
    116     for k, v in model_metadata.items():
    117       if isinstance(v, (dict, list, tuple)):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    156   except NotImplementedError as e:
    157     if require_config:
--> 158       raise e
    159 
    160   metadata = dict(

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saving_utils.py in model_metadata(model, include_optimizer, require_config)
    153   model_config = {'class_name': model.__class__.__name__}
    154   try:
--> 155     model_config['config'] = model.get_config()
    156   except NotImplementedError as e:
    157     if require_config:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_config(self)
    648 
    649   def get_config(self):
--> 650     return copy.deepcopy(get_network_config(self))
    651 
    652   @classmethod

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/functional.py in get_network_config(network, serialize_layer_fn)
   1347         filtered_inbound_nodes.append(node_data)
   1348 
-> 1349     layer_config = serialize_layer_fn(layer)
   1350     layer_config['name'] = layer.name
   1351     layer_config['inbound_nodes'] = filtered_inbound_nodes

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    248         return serialize_keras_class_and_config(
    249             name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
--> 250       raise e
    251     serialization_config = {}
    252     for key, item in config.items():

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/utils/generic_utils.py in serialize_keras_object(instance)
    243     name = get_registered_name(instance.__class__)
    244     try:
--> 245       config = instance.get_config()
    246     except NotImplementedError as e:
    247       if _SKIP_FAILED_SERIALIZATION:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py in get_config(self)
   2252 
   2253   def get_config(self):
-> 2254     raise NotImplementedError
   2255 
   2256   @classmethod

NotImplementedError: 
rishigami commented 2 years ago

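You should save the weights using `model.save_weights()` instead. The SwinTransformer backbone does not implement `get_config()`, so the full model cannot be serialized to HDF5. To restore it, rebuild the model with the same code and call `model.load_weights()`.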

li-qilei commented 4 months ago

Could you please suggest a way to make `model.save("./model.hdf5")` work with this model?