neuronets / nobrainer_training_scripts

1 stars 2 forks source link

custom metric is failing in nobrainer #27

Open hvgazula opened 3 months ago

hvgazula commented 3 months ago

I replaced the `dice` function in `metrics.py` with the following class:

class dice(keras.metrics.Metric):
    """Streaming Dice coefficient, averaged over every sample seen so far.

    Each ``update_state`` call computes one Dice score per sample,
    ``(2 * sum(y_true * y_pred) + eps) / (sum(y_true) + sum(y_pred) + eps)``,
    with the sums taken over ``axis``. It then adds the (optionally weighted)
    scores and the number of samples to two running state variables.
    ``result`` returns their ratio, i.e. the mean Dice over all updates since
    the last ``reset_state``.

    NOTE: this is a class. Pass an *instance* (``dice()``) to
    ``model.compile(metrics=[...])``. If the class itself is passed, Keras
    reads ``dice.variables`` off the class, gets the unbound ``property``
    object, and fails with ``TypeError: 'property' object is not iterable``.

    Args:
        name: Metric name reported in training logs.
        axis: Axes summed to form one Dice score per sample. The default
            ``(1, 2, 3, 4)`` reduces every axis except axis 0 of a rank-5
            input (presumably ``(batch, x, y, z, channels)`` -- confirm).
        **kwargs: Forwarded to ``keras.metrics.Metric`` (e.g. ``dtype``).
    """

    def __init__(self, name="dice", axis=(1, 2, 3, 4), **kwargs):
        super().__init__(name=name, **kwargs)
        self.axis = axis
        # Running sums live in state variables and are only ever mutated via
        # assign/assign_add, so Keras can track, aggregate and reset them.
        self.total = self.add_weight(name="total", shape=(), initializer="zeros")
        self.count = self.add_weight(name="count", shape=(), initializer="zeros")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the per-sample Dice scores of one batch.

        Args:
            y_true: Ground-truth tensor; cast to ``y_pred``'s dtype.
            y_pred: Prediction tensor with the same shape as ``y_true``.
            sample_weight: Optional weights, one per sample (assumed to hold
                as many elements as there are per-sample scores). When
                omitted every sample has weight 1.
        """
        y_pred = tf.convert_to_tensor(y_pred)
        y_true = tf.cast(y_true, y_pred.dtype)
        eps = tf.keras.backend.epsilon()

        intersection = tf.reduce_sum(y_true * y_pred, axis=self.axis)
        union = tf.reduce_sum(y_true, axis=self.axis) + tf.reduce_sum(
            y_pred, axis=self.axis
        )
        # One score per sample, flattened so weights can be applied uniformly.
        scores = tf.reshape((2.0 * intersection + eps) / (union + eps), [-1])

        if sample_weight is None:
            weights = tf.ones_like(scores)
        else:
            weights = tf.reshape(tf.cast(sample_weight, scores.dtype), [-1])

        # Cast to the state dtype so e.g. mixed-precision inputs still update
        # float32 metric variables.
        self.total.assign_add(
            tf.cast(tf.reduce_sum(scores * weights), self.total.dtype)
        )
        self.count.assign_add(tf.cast(tf.reduce_sum(weights), self.count.dtype))

    def result(self):
        """Return the mean Dice score so far (0 before any update)."""
        return tf.math.divide_no_nan(self.total, self.count)

    def reset_state(self):
        """Clear the accumulated totals, e.g. at the start of each epoch."""
        self.total.assign(0.0)
        self.count.assign(0.0)

    def get_config(self):
        """Return the serializable config, including ``axis``."""
        config = super().get_config()
        config.update({"axis": self.axis})
        return config

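This gives the following error. `'property' object is not iterable` at `getattr(metric, "variables", [])` means the metric *class* `dice` reached `compile`, rather than an instance `dice()`: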

  File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer_training_scripts/1.2.0/scripts/misc/warm_start_multi_gpu.py", line 60, in <module>
    history = bem.fit(
  File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer/nobrainer/processing/segmentation.py", line 94, in fit
    _compile()
  File "/net/vast-storage/scratch/vast/gablab/hgazula/nobrainer/nobrainer/processing/segmentation.py", line 79, in _compile
    self.model_.compile(
  File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 70, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/om/user/hgazula/venvs/nobrainer_satra/lib/python3.10/site-packages/keras/src/engine/training.py", line 3893, in _validate_compile
    for v in getattr(metric, "variables", []):
TypeError: 'property' object is not iterable
hvgazula commented 3 months ago

TODO: run a small example outside nobrainer (using the sample data)