oarriaga / STN.keras

Implementation of spatial transformer networks (STNs) in keras 2 with tensorflow as backend.
MIT License

Loading trained model with custom layers #7

Closed: SivaKanishka closed this issue 6 years ago

SivaKanishka commented 6 years ago

Hi @oarriaga

Could you elaborate on the steps for loading the trained model? I am able to train the model using your code, but when I load the trained model I get an error saying the SpatialTransformer layer is not found.

Thanks

oarriaga commented 6 years ago

You have to pass the custom layers to Keras's load_model function through its custom_objects argument. Alternatively, you can save just the weights and load them back into a model built with the same code.
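
A rough illustration of why the layer is reported as not found: the saved model records each layer by class name plus a config, so the loader has to map that name back to a Python class, and it can only find your own classes if you hand them over in custom_objects. The sketch below is a toy, pure-Python stand-in for that lookup, not Keras's actual code; BUILTIN_LAYERS, deserialize_layer and the [32, 32] output size are hypothetical.

```python
# Toy, pure-Python stand-in for the class-name lookup done when a saved
# model is loaded. This is NOT Keras's implementation; all names here
# (BUILTIN_LAYERS, deserialize_layer) are hypothetical.


class Dense:
    """Stand-in for a built-in layer the loader already knows about."""

    def __init__(self, units, name=None):
        self.units = units
        self.name = name


class SpatialTransformer:
    """Stand-in for a user-defined layer the loader has never seen."""

    def __init__(self, output_size, name=None):
        self.output_size = output_size
        self.name = name


# Hypothetical registry of built-in layers, keyed by class name.
BUILTIN_LAYERS = {"Dense": Dense}


def deserialize_layer(layer_data, custom_objects=None):
    """Map the saved class_name to a class, then rebuild it from its config."""
    lookup = dict(BUILTIN_LAYERS)
    lookup.update(custom_objects or {})  # user-supplied classes take precedence
    class_name = layer_data["class_name"]
    if class_name not in lookup:
        raise ValueError("Unknown layer: " + class_name)
    return lookup[class_name](**layer_data["config"])


# What a saved model might record for one layer (values are illustrative).
saved = {"class_name": "SpatialTransformer",
         "config": {"output_size": [32, 32], "name": "stn"}}

try:
    deserialize_layer(saved)
except ValueError as err:
    print(err)  # Unknown layer: SpatialTransformer

layer = deserialize_layer(saved, custom_objects={"SpatialTransformer": SpatialTransformer})
print(layer.output_size)  # [32, 32]
```

Without custom_objects the name lookup fails; with it, the class is found and rebuilt from the stored config.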

mrgloom commented 5 years ago

When I try to load model that use BilinearInterpolation layer like this:

        self.model = load_model(model_filepath, custom_objects={
            ...
            'BilinearInterpolation' : BilinearInterpolation})

I get this error:

    'BilinearInterpolation' : BilinearInterpolation})
  File "/usr/local/lib/python3.6/site-packages/keras/models.py", line 270, in load_model
    model = model_from_config(model_config, custom_objects=custom_objects)
  File "/usr/local/lib/python3.6/site-packages/keras/models.py", line 347, in model_from_config
    return layer_module.deserialize(config, custom_objects=custom_objects)
  File "/usr/local/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
    list(custom_objects.items())))
  File "/usr/local/lib/python3.6/site-packages/keras/engine/topology.py", line 2525, in from_config
    process_layer(layer_data)
  File "/usr/local/lib/python3.6/site-packages/keras/engine/topology.py", line 2511, in process_layer
    custom_objects=custom_objects)
  File "/usr/local/lib/python3.6/site-packages/keras/layers/__init__.py", line 55, in deserialize
    printable_module_name='layer')
  File "/usr/local/lib/python3.6/site-packages/keras/utils/generic_utils.py", line 146, in deserialize_keras_object
    return cls.from_config(config['config'])
  File "/usr/local/lib/python3.6/site-packages/keras/engine/topology.py", line 1271, in from_config
    return cls(**config)
TypeError: __init__() missing 1 required positional argument: 'output_size'
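
The last frame of the traceback, return cls(**config), is the key: the layer is rebuilt from whatever its get_config returned, and if output_size was never written there, __init__ is called without it. A toy, pure-Python sketch of that round trip follows; the Layer base class here only roughly mimics the default get_config/from_config pattern and is not the Keras class.

```python
class Layer:
    """Toy base class roughly mimicking Keras's default config round trip."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        # The base class only knows about its own generic attributes.
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Same shape as the traceback's last frame: cls(**config).
        return cls(**config)


class BilinearInterpolation(Layer):
    # No get_config override, so output_size is never written to the config.
    def __init__(self, output_size, **kwargs):
        self.output_size = output_size
        super().__init__(**kwargs)


layer = BilinearInterpolation((32, 32), name="bilinear")
config = layer.get_config()
print(config)  # {'name': 'bilinear', 'trainable': True}

try:
    BilinearInterpolation.from_config(config)
except TypeError as err:
    # Message wording varies by Python version, but it names output_size.
    print(err)
```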

I solved it by giving output_size a default value:

    def __init__(self, output_size=(64,64), **kwargs):
        self.output_size = output_size
        super(BilinearInterpolation, self).__init__(**kwargs)
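
The default value works because from_config then has something to fall back on, but every loaded layer also gets (64, 64) regardless of the size it was trained with, so it is only safe when that default matches the saved model. A commonly recommended alternative is to override get_config so output_size is stored with the model. The sketch below compares both on a toy, pure-Python base class (a stand-in, not Keras); DefaultedBilinear and SerializedBilinear are hypothetical names.

```python
class Layer:
    """Toy base class roughly mimicking Keras's default config round trip."""

    def __init__(self, name=None, trainable=True):
        self.name = name
        self.trainable = trainable

    def get_config(self):
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


class DefaultedBilinear(Layer):
    """The workaround from this thread: a default value, no get_config."""

    def __init__(self, output_size=(64, 64), **kwargs):
        self.output_size = output_size
        super().__init__(**kwargs)


class SerializedBilinear(Layer):
    """Alternative: write output_size into the saved config."""

    def __init__(self, output_size, **kwargs):
        self.output_size = output_size
        super().__init__(**kwargs)

    def get_config(self):
        # In a real Keras 2 layer this has the same shape: start from the
        # parent's get_config() and add the layer's own constructor args.
        config = super().get_config()
        config["output_size"] = self.output_size
        return config


for cls in (DefaultedBilinear, SerializedBilinear):
    original = cls((30, 30), name="bilinear")
    restored = cls.from_config(original.get_config())
    print(cls.__name__, restored.output_size)
# DefaultedBilinear (64, 64)
# SerializedBilinear (30, 30)
```

Only the get_config version restores the size the layer was actually built with.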

skulhare commented 5 years ago

Thanks a lot, that helped.