DeepRegNet / DeepReg

Medical image registration using deep learning
Apache License 2.0
564 stars 76 forks

Nan/inf loss encountered #690

Closed: mathpluscode closed this issue 3 years ago

mathpluscode commented 3 years ago

Subject of the issue

More discussion is here: https://github.com/Project-MONAI/MONAI/issues/1868

Using the script https://github.com/DeepRegNet/Benchmark/blob/9-data-for-miccai/benchmark/balakrishnan2019/voxel_morph_balakrishnan_2019.py with the latest main branch, training gives:

226/226 [==============================] - 3083s 14s/step - loss: -0.7058 - loss/regularization_GradientNorm: 2.1987e-04 - loss/regularization_GradientNorm_weighted: 2.1987e-04 - loss/image_LocalNormalizedCrossCorrelationLoss: -0.7060 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.7060
Epoch 2/2000
226/226 [==============================] - 3266s 14s/step - loss: -0.4778 - loss/regularization_GradientNorm: 0.3009 - loss/regularization_GradientNorm_weighted: 0.3009 - loss/image_LocalNormalizedCrossCorrelationLoss: -0.7787 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.7787
Epoch 3/2000
226/226 [==============================] - 3280s 15s/step - loss: nan - loss/regularization_GradientNorm: nan - loss/regularization_GradientNorm_weighted: nan - loss/image_LocalNormalizedCrossCorrelationLoss: -inf - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -inf
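In every epoch logged above, the total loss equals the sum of the two weighted terms (for example 0.3009 + (-0.7787) = -0.4778 at epoch 2), so a single non-finite term is enough to turn the total into nan. The pure-Python sketch below checks this against the logged values; the helper names are hypothetical, and both weights are assumed to be 1.0, as the identical _weighted columns suggest.

```python
import math

# Hypothetical helpers. Both weights are assumed to be 1.0, which is what the
# identical "*_weighted" columns in the logs suggest.
WEIGHTS = {
    "regularization_GradientNorm": 1.0,
    "image_LocalNormalizedCrossCorrelationLoss": 1.0,
}


def total_loss(terms, weights=WEIGHTS):
    """Weighted sum of the individual loss terms."""
    return sum(weights[name] * value for name, value in terms.items())


def non_finite_terms(terms):
    """Names of the terms whose value is nan or +/-inf."""
    return sorted(name for name, value in terms.items() if not math.isfinite(value))


# Values copied from the epoch 2 and epoch 3 logs.
epoch2 = {
    "regularization_GradientNorm": 0.3009,
    "image_LocalNormalizedCrossCorrelationLoss": -0.7787,
}
epoch3 = {
    "regularization_GradientNorm": float("nan"),
    "image_LocalNormalizedCrossCorrelationLoss": float("-inf"),
}

print(round(total_loss(epoch2), 4))  # -0.4778, matching the logged total
print(total_loss(epoch3))            # nan: one non-finite term poisons the sum
print(non_finite_terms(epoch3))
```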

Your environment

mathpluscode commented 3 years ago

The loss is not nan/inf when using the following config:

dataset:
  dir:
    train: "/raid/candi/Yunguan/DeepReg/neuroimaging/preprocessed" # required
    valid:
    test:
  format: "nifti"
  type: "unpaired" # paired / unpaired / grouped
  labeled: false # whether to use the labels if available, "true" or "false"
  image_shape: [196, 112, 96]

train:
  # define neural network structure
  method: "ddf" # options include "ddf", "dvf", "conditional"
  backbone:
    name: "unet" # options include "local", "unet" and "global"
    num_channel_initial: 16 # number of initial channels in the backbone network, controls the size of the network
    depth: 4
    # concat_skip: true
    # encode_num_channels: [16, 32, 32, 32, 32]
    # decode_num_channels: [32, 32, 32, 32, 32]

  # define the loss function for training
  loss:
    image:
      name: "lncc" # options include "lncc", "ssd" and "gmi", for local normalised cross correlation, sum of squared differences and global mutual information
      weight: 1.0
    label:
      weight: 0.0
      name: "dice" # options include "dice", "cross-entropy", "mean-squared", "generalised_dice" and "jaccard"
    regularization:
      weight: 1.0 # weight of regularization loss
      name: "gradient" # options include "bending", "gradient"

  # define the optimizer
  optimizer:
    name: "adam" # options include "adam", "sgd" and "rms"
    adam:
      learning_rate: 1.0e-4

  # define the hyper-parameters for preprocessing
  preprocess:
    data_augmentation:
      name: "affine"
    batch_size: 4
    shuffle_buffer_num_batch: 1 # shuffle_buffer_size = batch_size * shuffle_buffer_num_batch

  # other training hyper-parameters
  epochs: 2000 # number of training epochs
  save_period: 50 # the model will be saved every `save_period` epochs.
mathpluscode commented 3 years ago
226/226 [==============================] - 4779s 21s/step - loss: -0.1736 - loss/regularization_GradientNorm: 4.3662e-05 - loss/regularization_GradientNorm_weighted: 4.3662e-05 - loss/image_LocalNormalizedCrossCorrelationLoss: -0.1736 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.1736
Epoch 2/50
226/226 [==============================] - 5154s 23s/step - loss: -0.1405 - loss/regularization_GradientNorm: 0.0362 - loss/regularization_GradientNorm_weighted: 0.0362 - loss/image_LocalNormalizedCrossCorrelationLoss: -0.1767 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.1767
Epoch 3/50
226/226 [==============================] - 5218s 23s/step - loss: -0.1474 - loss/regularization_GradientNorm: 0.0315 - loss/regularization_GradientNorm_weighted: 0.0315 - loss/image_LocalNormalizedCrossCorrelationLoss: -0.1788 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.1788
Epoch 4/50
226/226 [==============================] - 5211s 23s/step - loss: nan - loss/regularization_GradientNorm: nan - loss/regularization_GradientNorm_weighted: nan - loss/image_LocalNormalizedCrossCorrelationLoss: -inf - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -inf
Epoch 5/50
226/226 [==============================] - ETA: 0s - loss: nan - loss/regularization_GradientNorm: nan - loss/regularization_GradientNorm_weighted: nan - loss/image_LocalNormalizedCrossCorrelationLoss: -0.2500 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.2500
2021-03-15 06:06:09.605122: W tensorflow/core/framework/op_kernel.cc:1767] OP_REQUIRES failed at summary_kernels.cc:242 : Invalid argument: Nan in summary histogram for: Unet/conv3d/kernel_0
Traceback (most recent call last):
  File "voxel_morph_balakrishnan_2019.py", line 217, in <module>
    main()
  File "voxel_morph_balakrishnan_2019.py", line 212, in main
    ckpt_path="",
  File "/raid/candi/yunguan/Git/DeepReg/deepreg/train.py", line 167, in train
    callbacks=callbacks,
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1137, in fit
    callbacks.on_epoch_end(epoch, epoch_logs)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 412, in on_epoch_end
    callback.on_epoch_end(epoch, logs)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 2182, in on_epoch_end
    self._log_weights(epoch)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/callbacks.py", line 2234, in _log_weights
    summary_ops_v2.histogram(weight_name, weight, step=epoch)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/ops/summary_ops_v2.py", line 836, in histogram
    return summary_writer_function(name, tensor, function, family=family)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/ops/summary_ops_v2.py", line 765, in summary_writer_function
    should_record_summaries(), record, _nothing, name="")
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/framework/smart_cond.py", line 54, in smart_cond
    return true_fn()
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/ops/summary_ops_v2.py", line 758, in record
    with ops.control_dependencies([function(tag, scope)]):
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/ops/summary_ops_v2.py", line 834, in function
    name=scope)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/ops/gen_summary_ops.py", line 480, in write_histogram_summary
    writer, step, tag, values, name=name, ctx=_ctx)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/ops/gen_summary_ops.py", line 499, in write_histogram_summary_eager_fallback
    attrs=_attrs, ctx=ctx, name=name)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Nan in summary histogram for: Unet/conv3d/kernel_0 [Op:WriteHistogramSummary]
mathpluscode commented 3 years ago


The problem seems to come from the deformation regularization.

mathpluscode commented 3 years ago

Changing the deformation loss to the VoxelMorph one (https://github.com/voxelmorph/voxelmorph/blob/legacy/src/losses.py#L88) does not solve the problem.
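Both DeepReg's GradientNorm and the VoxelMorph deformation loss linked above penalise spatial differences of the DDF. A minimal 1-D sketch of that idea (the function name and the forward-difference, mean-reduction choices are assumptions, not either library's exact code) shows that such a term contains no division: finite displacements always give a finite penalty, so a nan in this term more likely means the DDF, and hence the network weights, had already become non-finite.

```python
import math


def gradient_penalty_1d(ddf, penalty="l2"):
    """Hypothetical 1-D gradient regularizer: mean of squared (l2) or absolute
    (l1) forward differences of a displacement field. Needs at least 2 values.
    Illustrates the general idea only, not any library's exact code."""
    diffs = [b - a for a, b in zip(ddf[:-1], ddf[1:])]
    if penalty == "l2":
        diffs = [d * d for d in diffs]
    elif penalty == "l1":
        diffs = [abs(d) for d in diffs]
    else:
        raise ValueError(f"unknown penalty: {penalty}")
    return sum(diffs) / len(diffs)


smooth = [0.0, 0.1, 0.2, 0.3]
rough = [0.0, 2.0, -2.0, 2.0]
with_nan = [0.0, float("nan"), 0.0, 0.0]

print(gradient_penalty_1d(smooth))              # small: neighbours move together
print(gradient_penalty_1d(rough))               # (4 + 16 + 16) / 3 = 12.0
print(math.isnan(gradient_penalty_1d(with_nan)))  # True: nan in, nan out
```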


We need to double-check the resampler and related functions.

Eventually the inf appeared first in LNCC, but the nan/inf checks inside LNCC were not triggered. Why?


mathpluscode commented 3 years ago

Changed zero_boundary in the resampler to default to False; still got nan/inf.
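One reason the boundary mode alone may not matter: if a displacement sends every sample outside the image, the warped signal is constant either way (all zeros with zero_boundary, or the edge value with clamping), and a constant window has zero local variance, which is the case that stresses the LNCC denominator. The 1-D linear resampler below is a hypothetical sketch for illustration, not DeepReg's resample function.

```python
import math


def resample_1d(image, ddf, zero_boundary=True):
    """Hypothetical 1-D linear resampler: output[i] = image(i + ddf[i]).
    Out-of-range neighbours read as 0 when zero_boundary is True; otherwise
    their index is clamped to the nearest valid voxel."""
    n = len(image)

    def value_at(j):
        if 0 <= j < n:
            return image[j]
        return 0.0 if zero_boundary else image[min(max(j, 0), n - 1)]

    out = []
    for i, d in enumerate(ddf):
        x = i + d
        lo = math.floor(x)
        w = x - lo  # weight of the right-hand neighbour
        out.append((1 - w) * value_at(lo) + w * value_at(lo + 1))
    return out


image = [0.2, 0.5, 0.9, 0.4]
far = [-351.0] * len(image)  # displacement far outside a 4-voxel image

print(resample_1d(image, far, zero_boundary=True))   # constant zeros
print(resample_1d(image, far, zero_boundary=False))  # constant edge value 0.2
print(resample_1d(image, [0.5, 0.0, 0.0, 0.0]))      # first value about 0.35
```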


Maybe linked to leaky ReLU.

mathpluscode commented 3 years ago

Changing from leaky ReLU to ReLU does not solve the problem either.


mathpluscode commented 3 years ago

OK, I just found out that I had added an activation at the final layer before outputting the DDF :( It might be the cause: leaky ReLU and ReLU both push the DDF towards positive values, which may make the DDF too large.
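A minimal illustration of the point above, using hypothetical pre-activation values: ReLU can never output a negative displacement, and leaky ReLU (slope 0.2 here, an assumed value) shrinks negative ones, so either activation biases the DDF towards positive values. The identity function stands for leaving the displacement output layer linear.

```python
def relu(x):
    return max(x, 0.0)


def leaky_relu(x, alpha=0.2):
    # alpha = 0.2 is an assumed slope, for illustration only
    return x if x >= 0 else alpha * x


def identity(x):
    # no activation: the raw (linear) output is used as the displacement
    return x


# Hypothetical raw outputs of the last convolution for four voxels
pre_activation = [-3.0, -1.0, 0.5, 2.0]

for act in (identity, relu, leaky_relu):
    ddf = [act(v) for v in pre_activation]
    print(f"{act.__name__:>10}: {ddf}, mean = {sum(ddf) / len(ddf):.3f}")
```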

mathpluscode commented 3 years ago

After fixing the activation, we still got nan/inf.


mathpluscode commented 3 years ago

After changing the image loss to SSD, the nan/inf is gone.


So the problem must come from the LNCC loss.
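One plausible reason the switch helps, offered as an interpretation rather than something established in the thread: SSD is a plain mean of squared intensity differences with no denominator, so finite images always give a finite loss (at most 1 for intensities in [0, 1]), whereas LNCC divides by a product of local variances. A minimal 1-D sketch of SSD with hypothetical intensities:

```python
def ssd(fixed, warped):
    """Mean of squared intensity differences. There is no division by an
    image-dependent quantity, so finite inputs always give a finite loss."""
    return sum((f - w) ** 2 for f, w in zip(fixed, warped)) / len(fixed)


fixed = [0.2, 0.5, 0.9, 0.4]         # hypothetical intensities in [0, 1]
warped_outside = [0.0] * len(fixed)  # e.g. every sample fell outside the image

print(ssd(fixed, fixed))            # 0.0 for a perfect match
print(ssd(fixed, warped_outside))   # finite, at most 1 for intensities in [0, 1]
```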

mathpluscode commented 3 years ago

Changing the LNCC implementation to the VoxelMorph (VM) one does not solve the problem.
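For reference, here is a 1-D sketch of a windowed LNCC in the sum-based form that VoxelMorph-style NCC uses: the cross term and the variances are computed from window sums, then combined as cross^2 / (var_f * var_m + eps). It is reconstructed for illustration and should be treated as an assumption; the function name, the 1-D setting and the use of full windows only (no padding) are simplifications, and it is neither DeepReg's nor VoxelMorph's exact code.

```python
def lncc_1d(fixed, warped, win=3, eps=1e-5):
    """Hypothetical 1-D local NCC loss in the sum-based form.

    For each full window (no padding) it computes the cross term and both
    (unnormalised) variances from window sums, then
    cross^2 / (var_f * var_m + eps). Returns the negative mean of these
    ratios, so a perfect match gives a value just above -1.
    """
    ratios = []
    for start in range(len(fixed) - win + 1):
        f = fixed[start:start + win]
        m = warped[start:start + win]
        f_sum, m_sum = sum(f), sum(m)
        u_f, u_m = f_sum / win, m_sum / win
        fm_sum = sum(a * b for a, b in zip(f, m))
        f2_sum = sum(a * a for a in f)
        m2_sum = sum(b * b for b in m)
        # Expanded forms of sum((f - u_f) * (m - u_m)) and sum((x - u_x)^2)
        cross = fm_sum - u_m * f_sum - u_f * m_sum + u_f * u_m * win
        f_var = f2_sum - 2 * u_f * f_sum + u_f * u_f * win
        m_var = m2_sum - 2 * u_m * m_sum + u_m * u_m * win
        ratios.append(cross * cross / (f_var * m_var + eps))
    return -sum(ratios) / len(ratios)


fixed = [0.1, 0.4, 0.2, 0.8, 0.5]  # hypothetical intensities
print(lncc_1d(fixed, fixed))        # close to -1: perfect match
print(lncc_1d(fixed, [0.0] * 5))    # constant warped signal: every ratio is 0
```

Because the cross term and the variances are obtained by subtracting nearly equal sums, they carry rounding error (much larger in float32 than in the float64 used here), and a variance can come out slightly negative. The only protection against a vanishing denominator is then eps, which connects to the EPS value discussed later in this thread.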


mathpluscode commented 3 years ago
Epoch 3/200
318/453 [====================>.........] - ETA: 7:05 - loss: 1.0236 - metric/moving_image_mean: 0.1243 - metric/moving_image_min: 0.0211 - metric/moving_image_max: 0.8673 - metric/fixed_image_mean: 0.1231 - metric/fixed_image_min: 0.0217 - metric/fixed_image_max: 0.8634 - loss/regularization_GradientNorm: 1.8122 - loss/regularization_GradientNorm_weighted: 1.8122 - metric/ddf_mean: -50.4680 - metric/ddf_min: -351.1192 - metric/ddf_max: 0.5666 - loss/image_LocalNormalizedCrossCorrelationLoss: -0.7886 - loss/image_LocalNormalizedCrossCorrelationLoss_weighted: -0.7886
Traceback (most recent call last):
  File "voxel_morph_balakrishnan_2019.py", line 285, in <module>
    main()
  File "voxel_morph_balakrishnan_2019.py", line 280, in main
    ckpt_path="",
  File "/home/yunguan/Git/DeepReg/deepreg/train.py", line 169, in train
    callbacks=callbacks,
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 108, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py", line 1098, in fit
    tmp_logs = train_function(iterator)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py", line 807, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 2829, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1848, in _filtered_call
    cancellation_manager=cancellation_manager)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 1924, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/function.py", line 550, in call
    ctx=ctx)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 138, in execute_with_callbacks
    tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  File "/home/yunguan/miniconda3/envs/deepreg/lib/python3.7/site-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:

!!! Detected Infinity or NaN in output 0 of graph op "Sum" (# of outputs: 1) !!!
  dtype: <dtype: 'float32'>
  shape: ()

  Input tensors (2):
         0: Tensor("DDFModel/functional_1/tf_op_layer_RealDiv_1/RealDiv_1:0", shape=(), dtype=float32)
         1: Tensor("DDFModel/functional_1/add_metric_11/Const:0", shape=(0,), dtype=int32)
  Graph name: "train_function"

  Stack trace of op's creation ("->": inferred user code):
    + ... (Omitted 13 frames)
    + ...packages/tensorflow/python/framework/func_graph.py (L969) wrapper
    |   user_requested=True,
    + ...ackages/tensorflow/python/keras/engine/training.py (L806) train_function
    |   return step_function(self, iterator)
    + ...ackages/tensorflow/python/keras/engine/training.py (L796) step_function
    |   outputs = model.distribute_strategy.run(run_step, args=(data,))
    + ...ges/tensorflow/python/distribute/distribute_lib.py (L1211) run
    |   return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    + ...ges/tensorflow/python/distribute/distribute_lib.py (L2585) call_for_each_replica
    |   return self._call_for_each_replica(fn, args, kwargs)
    + ...ges/tensorflow/python/distribute/distribute_lib.py (L2945) _call_for_each_replica
    |   return fn(*args, **kwargs)
    + ...ackages/tensorflow/python/keras/engine/training.py (L789) run_step
    |   outputs = model.train_step(data)
    + ...ackages/tensorflow/python/keras/engine/training.py (L747) train_step
    |   y_pred = self(x, training=True)
    + ...kages/tensorflow/python/keras/engine/base_layer.py (L985) __call__
    |   outputs = call_fn(inputs, *args, **kwargs)
    + /home/yunguan/Git/DeepReg/deepreg/model/network.py (L248) call
 -> |   return self._model(inputs, training=training, mask=mask)  # pragma: no cover
    + ...kages/tensorflow/python/keras/engine/base_layer.py (L985) __call__
    |   outputs = call_fn(inputs, *args, **kwargs)
    + ...kages/tensorflow/python/keras/engine/functional.py (L386) call
    |   inputs, training=training, mask=mask)
    + ...kages/tensorflow/python/keras/engine/functional.py (L508) _run_internal_graph
    |   outputs = node.layer(*args, **kwargs)
    + ...kages/tensorflow/python/keras/engine/base_layer.py (L985) __call__
    |   outputs = call_fn(inputs, *args, **kwargs)
    + ...kages/tensorflow/python/keras/engine/base_layer.py (L3196) call
    |   self.add_metric(inputs, aggregation=self.aggregation, name=self.metric_name)
    + ...kages/tensorflow/python/keras/engine/base_layer.py (L1705) add_metric
    |   metric_obj(value)
    + ...7/site-packages/tensorflow/python/keras/metrics.py (L231) __call__
    |   replica_local_fn, *args, **kwargs)
    + ...hon/keras/distribute/distributed_training_utils.py (L1133) call_replica_local_fn
    |   return fn(*args, **kwargs)
    + ...7/site-packages/tensorflow/python/keras/metrics.py (L211) replica_local_fn
    |   update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
    + ...ges/tensorflow/python/keras/utils/metrics_utils.py (L90) decorated
    |   update_op = update_state_fn(*args, **kwargs)
    + ...7/site-packages/tensorflow/python/keras/metrics.py (L176) update_state_fn
    |   return ag_update_state(*args, **kwargs)
    + ...7/site-packages/tensorflow/python/keras/metrics.py (L371) update_state
    |   value_sum = math_ops.reduce_sum(values)
    + ...7/site-packages/tensorflow/python/util/dispatch.py (L201) wrapper
    |   return target(*args, **kwargs)
    + ....7/site-packages/tensorflow/python/ops/math_ops.py (L1983) reduce_sum
    |   _ReductionDims(input_tensor, axis))
    + ....7/site-packages/tensorflow/python/ops/math_ops.py (L1994) reduce_sum_with_dims
    |   gen_math_ops._sum(input_tensor, dims, keepdims, name=name))
    + ...ite-packages/tensorflow/python/ops/gen_math_ops.py (L10537) _sum
    |   name=name)
    + ...ages/tensorflow/python/framework/op_def_library.py (L744) _apply_op_helper
    |   attrs=attr_protos, op_def=op_def)
    + ...packages/tensorflow/python/framework/func_graph.py (L593) _create_op_internal
    |   compute_device)
    + ...7/site-packages/tensorflow/python/framework/ops.py (L3485) _create_op_internal
    |   op_def=op_def)
    + ...7/site-packages/tensorflow/python/framework/ops.py (L1949) __init__
    |   self._traceback = tf_stack.extract_stack()

 : Tensor had -Inf values
         [[node DDFModel/functional_1/add_metric_11/Sum/CheckNumericsV2 (defined at /home/yunguan/Git/DeepReg/deepreg/model/network.py:248) ]] [Op:__inference_train_function_25908]

Function call stack:
train_function
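Just before this failure, the log reports metric/ddf_min = -351.1192 and metric/ddf_mean = -50.4680. Assuming the DDF is expressed in voxels and that this run used the image_shape [196, 112, 96] from the config earlier in the thread (both assumptions), a displacement of -351 moves a sample outside the volume along any axis, so such voxels can only sample boundary values. A small check:

```python
# Assumptions: the DDF is in voxel units, and this run used the image_shape
# [196, 112, 96] from the config earlier in the thread.
IMAGE_SHAPE = (196, 112, 96)
DDF_MIN, DDF_MAX = -351.1192, 0.5666  # metric/ddf_min and metric/ddf_max from the log


def always_outside(displacement, size):
    """True if adding `displacement` to any index in [0, size - 1] gives a
    position outside [0, size - 1]."""
    return displacement < -(size - 1) or displacement > size - 1


for axis, size in enumerate(IMAGE_SHAPE):
    print(f"axis {axis} (size {size}): ddf_min always outside -> {always_outside(DDF_MIN, size)}")
```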
YipengHu commented 3 years ago

So which line is it?

mathpluscode commented 3 years ago

Confirmed that the problem comes from our LNCC implementation.


We should double-check against NiftyReg:

https://github.com/KCL-BMEIS/niftyreg/blob/master/reg-lib/cpu/_reg_lncc.cpp

mathpluscode commented 3 years ago

We also found that VM uses EPS = 1e-5, while we were using 1e-7; maybe this is the cause :( Otherwise, for the moment we cannot print the values that caused the inf.
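A per-window view shows why EPS matters. In the hypothetical statistics below (not values taken from the failing run), the true variance and cross term of the moving window are both 0, but rounding has left small residues, so the variance product sits just above -eps. With eps = 1e-7 the denominator nearly vanishes and the ratio lands far above 1, the largest value a squared correlation can take; eps = 1e-5 keeps it small. The guarded variant, which clamps the variances at zero and adds eps to numerator and denominator, is one possible mitigation and an assumption here, not the fix adopted in this issue. Any eps only helps while the rounding error in the variance product stays smaller than eps.

```python
def ncc_ratio(cross, f_var, m_var, eps):
    """Unguarded per-window ratio cross^2 / (var_f * var_m + eps)."""
    return cross * cross / (f_var * m_var + eps)


def ncc_ratio_guarded(cross, f_var, m_var, eps):
    """Hypothetical guard: clamp variances at zero, add eps to both sides."""
    f_var, m_var = max(f_var, 0.0), max(m_var, 0.0)
    return (cross * cross + eps) / (f_var * m_var + eps)


# Hypothetical window statistics: the moving window's true variance and cross
# term are 0, but rounding left residues of -9.99e-7 and 1e-4, so the
# variance product is about -9.99e-8.
cross, f_var, m_var = 1e-4, 0.1, -9.99e-7

for eps in (1e-7, 1e-5):
    print(f"eps={eps:g}: unguarded={ncc_ratio(cross, f_var, m_var, eps):.4g}, "
          f"guarded={ncc_ratio_guarded(cross, f_var, m_var, eps):.4g}")
```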

mathpluscode commented 3 years ago

Actually, our implementation was originally taken from VM, and I just found that multiple users have reported issues with NCC there.

mathpluscode commented 3 years ago

Reopening this issue, as the previous fix was buggy.