frankkramer-lab / MIScnn

A framework for Medical Image Segmentation with Convolutional Neural Networks and Deep Learning
GNU General Public License v3.0

Can't Load Model (custom_objects) #112

Closed: ChrisJWest closed this issue 2 years ago

ChrisJWest commented 2 years ago

Hi! I was just playing around with model saving and model loading, and it seems like loading now requires custom_objects. I got this error when I tried to load a model from assets:

```
WARNING:tensorflow:Unable to restore custom metric. Please ensure that the layer implements get_config and from_config when saving. In addition, please use the custom_objects arg when calling load_model().
```

I saw this was already in the project tickets, so I assume it is a known issue. Is there a way I can get around this on my own, for example something I can pass manually in custom_objects to get unet_standard loading to work? I'm not familiar enough with how custom_objects works to figure out what is missing.

Thanks!

muellerdo commented 2 years ago

Hello @ChrisJWest,

MIScnn uses the (TensorFlow) Keras model management system:
https://keras.io/guides/serialization_and_saving/

In summary, there are two ways:

1. Using the custom_objects functionality:

You can always pass your custom architectures or loss/metric functions as a dictionary via the custom_objects argument when loading the model.

For example, from the Keras guide:

```python
import tensorflow as tf
from tensorflow import keras

# CustomModel is the keras.Model subclass defined earlier in the Keras guide
model = CustomModel([16, 16, 10])
# Build the model by calling it
input_arr = tf.random.uniform((1, 5))
outputs = model(input_arr)
model.save("my_model")

# Option 1: Load with the custom_objects argument.
loaded_1 = keras.models.load_model(
    "my_model", custom_objects={"CustomModel": CustomModel}
)
```

In MIScnn, this would look something like this:

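```python
from miscnn import Neural_Network

model = Neural_Network(...)
# Map each custom object name the saved model refers to onto the Python object
model.load("my_model", custom_objects={"unet_standard": unet_standard})
```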
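To make the custom_objects mechanism a bit more concrete: roughly speaking, Keras stores a custom object only by its name plus the config returned by its get_config(), and at load time it has to find that name again, either among its built-ins or in the custom_objects dictionary. The pure-Python sketch below imitates that lookup. MyMetric, serialize and deserialize are hypothetical stand-ins, not Keras or MIScnn APIs; real Keras does much more and, for metrics, may only warn instead of raising.

```python
import json


class MyMetric:
    """Hypothetical custom metric; only here to have something to serialize."""

    def __init__(self, smooth=1e-5):
        self.smooth = smooth

    def get_config(self):
        # Everything needed to rebuild the object later.
        return {"smooth": self.smooth}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def serialize(obj):
    # Only the class name and its config are written, not the Python code.
    return json.dumps({"class_name": type(obj).__name__,
                       "config": obj.get_config()})


def deserialize(payload, custom_objects=None):
    data = json.loads(payload)
    # Look the stored name up in the user-supplied mapping.
    cls = (custom_objects or {}).get(data["class_name"])
    if cls is None:
        # Loosely the situation behind "Unable to restore custom metric".
        raise ValueError(f"Unknown object {data['class_name']!r}: "
                         "pass it via custom_objects")
    return cls.from_config(data["config"])


saved = serialize(MyMetric(smooth=1.0))
restored = deserialize(saved, custom_objects={"MyMetric": MyMetric})
print(type(restored).__name__, restored.smooth)  # MyMetric 1.0
```

Without the custom_objects entry, deserialize(saved) fails because the stored name "MyMetric" cannot be resolved, which is why the dictionary key has to match the name the model was saved with.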

However, there is a much easier approach.

2. Using the HDF5 format for model management:

Keras also supports saving a single HDF5 file containing the model's architecture, weight values, and compile() information. It is a lightweight alternative to SavedModel.

You just have to save your model with the file extension .h5 or .hdf5.
The HDF5 file then also includes any custom architecture or loss function, which allows you (or anyone else) to load it simply by running model.load(path) ;)

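```python
from miscnn import Neural_Network

model = Neural_Network(...)
# The .hdf5 extension makes Keras write a single HDF5 file instead of a SavedModel directory
model.dump("my_model.hdf5")
model.load("my_model.hdf5")
```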
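The file extension is what switches Keras to HDF5 here. As a rough sketch, assuming tf.keras 2.x behaviour where model.save() without an explicit save_format falls back to the SavedModel directory format, the choice can be pictured like this. guess_save_format is a hypothetical helper, not part of Keras or MIScnn, and some TensorFlow versions also treat other suffixes as HDF5.

```python
def guess_save_format(file_path: str) -> str:
    """Hypothetical helper approximating tf.keras 2.x format selection."""
    # A path ending in .h5 or .hdf5 selects the single-file HDF5 format;
    # anything else falls back to a SavedModel directory ("tf").
    return "h5" if file_path.endswith((".h5", ".hdf5")) else "tf"


for path in ["my_model", "my_model.h5", "my_model.hdf5"]:
    print(path, "->", guess_save_format(path))
# my_model -> tf
# my_model.h5 -> h5
# my_model.hdf5 -> h5
```

So naming the file my_model.hdf5 in model.dump(), as in the MIScnn example above, is all it takes to get the single-file format.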

Happy holidays, Dominik

ChrisJWest commented 2 years ago

This solution worked perfectly :) my U-net is running great, thank you!