MIC-DKFZ / nnUNet

Apache License 2.0

overfitting and low dice #2512

Open sgengwbvkvjz opened 2 months ago

sgengwbvkvjz commented 2 months ago

Hi, I ran into some problems when training nnUNetv2. I have 101 PET images; you can check my dataset.json. I'm not sure whether the following output affects the results.
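To rule out a malformed dataset.json on my side, I first sanity-checked it programmatically. This is just a sketch; the `example` dict below is a stand-in for my real file (nnU-Net v2 expects label ids to be consecutive integers with background mapped to 0):

```python
def check_dataset_json(ds: dict):
    """Sanity-check an nnU-Net v2 style dataset.json that has already been
    loaded into a dict (e.g. via json.load)."""
    labels = ds["labels"]
    assert labels.get("background") == 0, "background must map to label 0"
    ids = sorted(labels.values())
    assert ids == list(range(len(ids))), "label ids must be consecutive integers"
    return ds["numTraining"], ids

# Stand-in for the real dataset.json (one PET channel, one foreground label)
example = {
    "channel_names": {"0": "PET"},
    "labels": {"background": 0, "organ": 1},
    "numTraining": 101,
    "file_ending": ".nii.gz",
}
print(check_dataset_json(example))  # (101, [0, 1])
```

The real file passes this check, so the dataset description itself seems fine.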

```
2024-09-16 10:07:57.622053: Using splits from existing split file: /data/gzz/project/nnUNet/PET_data_raw/nnUNet_preprocessed/Dataset502_Organ/splits_final.json
2024-09-16 10:07:57.622207: The split file contains 5 splits.
2024-09-16 10:07:57.622245: Desired fold for training: 1
2024-09-16 10:07:57.622275: This split has 73 training and 18 validation cases.
using pin_memory on device 0
using pin_memory on device 0
2024-09-16 10:08:04.112355: Using torch.compile...
/home/robot/anaconda3/envs/nnunet/lib/python3.10/site-packages/torch/optim/lr_scheduler.py:60: UserWarning: The verbose parameter is deprecated. Please use get_last_lr() to access the learning rate.
  warnings.warn(

This is the configuration used by this training:
Configuration name: 3d_lowres
{'data_identifier': 'nnUNetPlans_3d_lowres',
 'preprocessor_name': 'DefaultPreprocessor',
 'batch_size': 2,
 'patch_size': [80, 288, 96],
 'median_image_size_in_voxels': [132, 485, 134],
 'spacing': [4.325730846013894, 3.994898837381438, 4.325730846013894],
 'normalization_schemes': ['ZScoreNormalization'],
 'use_mask_for_norm': [False],
 'resampling_fn_data': 'resample_data_or_seg_to_shape',
 'resampling_fn_seg': 'resample_data_or_seg_to_shape',
 'resampling_fn_data_kwargs': {'is_seg': False, 'order': 3, 'order_z': 0, 'force_separate_z': None},
 'resampling_fn_seg_kwargs': {'is_seg': True, 'order': 1, 'order_z': 0, 'force_separate_z': None},
 'resampling_fn_probabilities': 'resample_data_or_seg_to_shape',
 'resampling_fn_probabilities_kwargs': {'is_seg': False, 'order': 1, 'order_z': 0, 'force_separate_z': None},
 'architecture': {'network_class_name': 'dynamic_network_architectures.architectures.unet.PlainConvUNet',
                  'arch_kwargs': {'n_stages': 6,
                                  'features_per_stage': [32, 64, 128, 256, 320, 320],
                                  'conv_op': 'torch.nn.modules.conv.Conv3d',
                                  'kernel_sizes': [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]],
                                  'strides': [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 1]],
                                  'n_conv_per_stage': [2, 2, 2, 2, 2, 2],
                                  'n_conv_per_stage_decoder': [2, 2, 2, 2, 2],
                                  'conv_bias': True,
                                  'norm_op': 'torch.nn.modules.instancenorm.InstanceNorm3d',
                                  'norm_op_kwargs': {'eps': 1e-05, 'affine': True},
                                  'dropout_op': None,
                                  'dropout_op_kwargs': None,
                                  'nonlin': 'torch.nn.LeakyReLU',
                                  'nonlin_kwargs': {'inplace': True},
                                  'deep_supervision': True},
                  '_kw_requires_import': ['conv_op', 'norm_op', 'dropout_op', 'nonlin']},
 'batch_dice': False,
 'next_stage': '3d_cascade_fullres'}
```

These are the global plan.json settings:

```
{'dataset_name': 'Dataset502_Organ',
 'plans_name': 'nnUNetPlans',
 'original_median_spacing_after_transp': [3.125, 2.885999917984009, 3.125],
 'original_median_shape_after_transp': [183, 671, 185],
 'image_reader_writer': 'SimpleITKIO',
 'transpose_forward': [1, 0, 2],
 'transpose_backward': [1, 0, 2],
 'experiment_planner_used': 'ExperimentPlanner',
 'label_manager': 'LabelManager',
 'foreground_intensity_properties_per_channel': {'0': {'max': 282966.65625,
                                                       'mean': 10044.5263671875,
                                                       'median': 7800.95458984375,
                                                       'min': 607.34716796875,
                                                       'percentile_00_5': 1188.931396484375,
                                                       'percentile_99_5': 53437.6796875,
                                                       'std': 8429.16015625}}}
```
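The intensity stats above show how heavy-tailed my PET values are (99.5th percentile ≈ 53k vs. max ≈ 283k), so I wanted to convince myself what the configured `ZScoreNormalization` does to such data. This is only a sketch of per-image z-scoring, not nnU-Net's exact implementation:

```python
import numpy as np

def zscore_normalize(img: np.ndarray) -> np.ndarray:
    """Per-image z-score normalization (sketch of the idea behind
    nnU-Net's ZScoreNormalization scheme, not the actual class)."""
    mean = float(img.mean())
    std = max(float(img.std()), 1e-8)  # guard against constant images
    return (img - mean) / std

# Synthetic heavy-tailed, PET-like intensities
rng = np.random.default_rng(0)
img = rng.lognormal(mean=9.0, sigma=1.0, size=(16, 16, 16)).astype(np.float32)
out = zscore_normalize(img)
print(round(float(out.mean()), 3), round(float(out.std()), 3))  # close to 0 and 1
```

So the very large raw SUV-scale values themselves should not be the problem, since each image ends up roughly zero-mean, unit-variance.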

```
2024-09-16 10:08:04.972868: unpacking dataset...
2024-09-16 10:08:08.791880: unpacking done...
2024-09-16 10:08:08.799376: Unable to plot network architecture: nnUNet_compile is enabled!
2024-09-16 10:08:08.807241:
2024-09-16 10:08:08.807357: Epoch 0
2024-09-16 10:08:08.807469: Current learning rate: 0.01
W0916 10:08:18.029000 140473405380416 torch/fx/experimental/symbolic_shapes.py:4449] [0/0] xindex is not in var_ranges, defaulting to unknown range.
[... same xindex warning repeated 4 more times ...]
W0916 10:08:24.831000 140467811383040 torch/fx/experimental/symbolic_shapes.py:4449] [0/0] d0 is not in var_ranges, defaulting to unknown range.
[... same d0 warning repeated 4 more times ...]
W0916 10:08:27.510000 140467811383040 torch/fx/experimental/symbolic_shapes.py:4449] [0/0] q0 is not in var_ranges, defaulting to unknown range.
W0916 10:08:27.574000 140467811383040 torch/fx/experimental/symbolic_shapes.py:4449] [0/0] z0 is not in var_ranges, defaulting to unknown range.
...
```

While training with 3d_fullres I got poor results: the lowest validation Dice was about 0.1 and the highest about 0.3. The loss and Dice curves seem to indicate overfitting. I tried modifying the learning rate and weight_decay, but that doesn't seem to help:

- fold 0: lr=0.01, weight_decay=3e-5, mean Dice = 0.2362 (see screenshot)
- lr=0.0001, weight_decay=3e-5, mean Dice = 0.2597 (see screenshot)
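In case the summary metrics were misleading, I also computed Dice on individual validation cases myself with a minimal per-label implementation (assuming integer label maps already loaded as numpy arrays, e.g. via SimpleITK), and got numbers in the same range:

```python
import numpy as np

def dice_score(pred: np.ndarray, gt: np.ndarray, label: int) -> float:
    """Dice for one label; returns NaN when the label is absent in both volumes."""
    p = pred == label
    g = gt == label
    denom = int(p.sum()) + int(g.sum())
    if denom == 0:
        return float("nan")
    return 2.0 * int(np.logical_and(p, g).sum()) / denom

# Tiny example: the prediction covers half of a 2x2x2 ground-truth cube
gt = np.zeros((4, 4, 4), dtype=np.uint8)
gt[:2, :2, :2] = 1
pred = np.zeros_like(gt)
pred[:2, :2, :1] = 1
print(dice_score(pred, gt, 1))  # 2*4 / (4 + 8) = 0.666...
```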

While training with 3d_lowres, fold 0's Dice dropped to 0, and neither the training loss nor the validation loss decreases (see screenshot).
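One thing I wondered about for the lowres Dice of 0: the target spacing (≈4.33 mm vs. an original median of ≈3.1 mm) might erase very small structures when the segmentation is downsampled. A crude nearest-neighbour check of that effect (integer factor for simplicity; nnU-Net's real resampling, `resample_data_or_seg_to_shape`, is fractional and smarter):

```python
import numpy as np

def downsample_labels(seg: np.ndarray, factor: int) -> np.ndarray:
    """Nearest-neighbour downsampling of a label map by an integer factor
    via strided indexing (crude sketch, not nnU-Net's resampling)."""
    return seg[::factor, ::factor, ::factor]

seg = np.zeros((8, 8, 8), dtype=np.uint8)
seg[1, 1, 1] = 1                             # a single-voxel structure
print(int(downsample_labels(seg, 2).sum()))  # 0 -- the structure vanished

seg[2:4, 2:4, 2:4] = 1                       # add a 2x2x2 structure
print(int(downsample_labels(seg, 2).sum()))  # 1 -- this one survives
```

If the foreground organs are only a few voxels wide at the original spacing, this kind of loss could explain a Dice of exactly 0 at lowres.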

I don't know what to do next and need some help. Thanks a lot!