juglab / n2v

This is the implementation of Noise2Void training.
Other
394 stars 108 forks source link

Model export does not work for N2V2 #128

Closed tibuch closed 1 year ago

tibuch commented 1 year ago

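The BlurPool implementation (`MaxBlurPool2D`) does not support model export. The layer takes a `pool` argument in `__init__` but does not override `get_config()`. As a result, Keras cannot clone the model when `export_TF` calls `export_SavedModel`: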

1/1 [==============================] - 0s 173ms/step

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In [21], line 1
----> 1 model.export_TF(name='Noise2Void - 2D SEM Example', 
      2                 description='This is the 2D Noise2Void example trained on SEM data in python.', 
      3                 authors=["Tim-Oliver Buchholz", "Alexander Krull", "Florian Jug"],
      4                 test_img=X_val[0,...,0], axes='YX',
      5                 patch_shape=patch_shape)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/csbdeep/models/base_model.py:32, in suppress_without_basedir.<locals>._suppress_without_basedir.<locals>.wrapper(*args, **kwargs)
     30     warn is False or warnings.warn("Suppressing call of '%s' (due to basedir=None)." % f.__name__)
     31 else:
---> 32     return f(*args, **kwargs)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/n2v/models/n2v_standard.py:473, in N2V.export_TF(self, name, description, authors, test_img, axes, patch_shape, fname)
    464 # CSBDeep Export
    465 meta = {
    466     'type': self.__class__.__name__,
    467     'version': package_version,
   (...)
    471     'tile_overlap': self._axes_tile_overlap(self.config.axes),
    472 }
--> 473 export_SavedModel(self.keras_model, str(fname), meta=meta)
    474 # CSBDeep Export Done
    475 
    476 # Replace : with -
    477 name = name.replace(':', ' -')

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/csbdeep/utils/tf.py:230, in export_SavedModel(model, outpath, meta, format)
    228 with tempfile.TemporaryDirectory() as tmpdir:
    229     tmpsubdir = os.path.join(tmpdir,'model')
--> 230     export_to_dir(tmpsubdir)
    231     shutil.make_archive(outpath, format, tmpsubdir)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/csbdeep/utils/tf.py:207, in export_SavedModel.<locals>.export_to_dir(dirname)
    204 weights = model.get_weights()
    205 with tf.Graph().as_default():
    206     # clone model in new graph and set weights
--> 207     _model = clone_model(model)
    208     _model.set_weights(weights)
    209     _export(_model)

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:505, in clone_model(model, input_tensors, clone_function)
    501     return _clone_sequential_model(
    502         model, input_tensors=input_tensors, layer_fn=clone_function
    503     )
    504 else:
--> 505     return _clone_functional_model(
    506         model, input_tensors=input_tensors, layer_fn=clone_function
    507     )

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:208, in _clone_functional_model(model, input_tensors, layer_fn)
    202 if not callable(layer_fn):
    203     raise ValueError(
    204         "Expected `layer_fn` argument to be a callable. "
    205         f"Received: layer_fn={layer_fn}"
    206     )
--> 208 model_configs, created_layers = _clone_layers_and_model_config(
    209     model, new_input_layers, layer_fn
    210 )
    211 # Reconstruct model from the config, using the cloned layers.
    212 (
    213     input_tensors,
    214     output_tensors,
   (...)
    217     model_configs, created_layers=created_layers
    218 )

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:273, in _clone_layers_and_model_config(model, input_layers, layer_fn)
    270         created_layers[layer.name] = layer_fn(layer)
    271     return {}
--> 273 config = functional.get_network_config(
    274     model, serialize_layer_fn=_copy_layer
    275 )
    276 return config, created_layers

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/engine/functional.py:1563, in get_network_config(network, serialize_layer_fn, config)
   1558         node_data = node.serialize(
   1559             _make_node_key, node_conversion_map
   1560         )
   1561         filtered_inbound_nodes.append(node_data)
-> 1563 layer_config = serialize_layer_fn(layer)
   1564 layer_config["name"] = layer.name
   1565 layer_config["inbound_nodes"] = filtered_inbound_nodes

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:270, in _clone_layers_and_model_config.<locals>._copy_layer(layer)
    268     created_layers[layer.name] = InputLayer(**layer.get_config())
    269 else:
--> 270     created_layers[layer.name] = layer_fn(layer)
    271 return {}

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/models/cloning.py:50, in _clone_layer(layer)
     49 def _clone_layer(layer):
---> 50     return layer.__class__.from_config(layer.get_config())

File /tungstenfs/scratch/gmicro_share/_software/CondaEnvs/Linux/n2v-test-pr/lib/python3.10/site-packages/keras/engine/base_layer.py:786, in Layer.get_config(self)
    783 # Check that either the only argument in the `__init__` is  `self`,
    784 # or that `get_config` has been overridden:
    785 if extra_args and hasattr(self.get_config, "_is_default"):
--> 786     raise NotImplementedError(
    787         textwrap.dedent(
    788             f"""
    789   Layer {self.__class__.__name__} has arguments {extra_args}
    790   in `__init__` and therefore must override `get_config()`.
    791 
    792   Example:
    793 
    794   class CustomLayer(keras.layers.Layer):
    795       def __init__(self, arg1, arg2):
    796           super().__init__()
    797           self.arg1 = arg1
    798           self.arg2 = arg2
    799 
    800       def get_config(self):
    801           config = super().get_config()
    802           config.update({{
    803               "arg1": self.arg1,
    804               "arg2": self.arg2,
    805           }})
    806           return config"""
    807         )
    808     )
    810 return config

NotImplementedError: 
Layer MaxBlurPool2D has arguments ['pool']
in `__init__` and therefore must override `get_config()`.

Example:

class CustomLayer(keras.layers.Layer):
    def __init__(self, arg1, arg2):
        super().__init__()
        self.arg1 = arg1
        self.arg2 = arg2

    def get_config(self):
        config = super().get_config()
        config.update({
            "arg1": self.arg1,
            "arg2": self.arg2,
        })
        return config
jdeschamps commented 1 year ago

Fixed here: https://github.com/juglab/n2v/pull/130