mlcommons / GaNDLF

A generalizable application framework for segmentation, regression, and classification using PyTorch
https://gandlf.org
Apache License 2.0

Implement DenseVNet to GaNDLF #526

Closed carlpe closed 1 year ago

carlpe commented 1 year ago

Implement DenseVNet to GaNDLF

https://niftynet.readthedocs.io/en/dev/niftynet.network.dense_vnet.html#niftynet.network.dense_vnet.DenseVNet

https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6076994/

sarthakpati commented 1 year ago

Thank you for your feedback and suggestion! I found 3 implementations for this using PyTorch:

Could you please specify which one would be more appropriate for your use case?

carlpe commented 1 year ago

Hi Sarthak,

I am really not sure, but perhaps the last one, from the MONAI project, would be appropriate? The question is whether that would work, as it seems it has not been merged. If it does not work, we could start with the first one on the list.

Thank you

sarthakpati commented 1 year ago

A preliminary version based on the first implementation is ready here: https://github.com/sarthakpati/GaNDLF/tree/526-implement-densevnet-to-gandlf

The unit tests seem to pass, but I am unsure if the results are expected. Could you please try it out and let us know?
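
In case it helps while trying it out, here is a rough smoke test I would run first; the import path and constructor signature below are written from memory, so please adjust them to whatever the branch actually exposes:

# rough smoke test -- module path and constructor signature are assumptions,
# adjust to whatever the branch actually exposes
import torch
from GANDLF.models.densevnet import DenseVNet  # hypothetical import path

parameters = {
    "model": {
        "dimension": 3,
        "base_filters": 16,
        "num_channels": 1,      # number of input modalities (assumption)
        "num_classes": 3,       # matches class_list [0, 1, 2]
        "final_layer": "softmax",
        "norm_type": "batch",
    }
}

net = DenseVNet(parameters)     # assumed signature: a single parameters dict
net.eval()

x = torch.rand(1, 1, 128, 128, 128)  # (batch, channels, D, H, W), same as the patch_size
with torch.no_grad():
    y = net(x)

print(y.shape)              # expect (1, 3, 128, 128, 128)
print(y.sum(dim=1).mean())  # softmax channels should sum to ~1 per voxel

If the output shape or the channel sums look wrong here, a full training run is unlikely to learn anything useful.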

carlpe commented 1 year ago

That was quick, @sarthakpati!

I will try it out once I get the opportunity; currently traveling across the Atlantic.

carlpe commented 1 year ago

Config file:

# affix version
version:
  {
    minimum: 0.0.16,
    maximum: 0.0.16 # this should NOT be made a variable, but should be tested after every tag is created
  }
## Choose the model parameters here
model:
  {
    dimension: 3, # the dimension of the model and dataset: defines dimensionality of computations
    base_filters: 16, # Set base filters: number of filters present in the initial module of the U-Net convolution; for IncU-Net, keep this divisible by 4
    architecture: densevnet, # options: unet, resunet, deep_resunet, deep_unet, light_resunet, light_unet, fcn, uinc, vgg, densenet
    norm_type: batch, # options: batch, instance, or none (only for VGG); used for all networks
    final_layer: softmax, # can be either sigmoid, softmax or none (none == regression/logits)
    class_list: [0,1,2], # Set the list of labels the model should train on and predict
    ignore_label_validation: 0, # this is the location in the class_list whose performance is ignored during validation metric calculation
    amp: True, # Set if you want to use Automatic Mixed Precision for your operations or not - options: True, False
    print_summary: True, # prints the summary of the model before training; defaults to True

    ## unet_multilayer, unetr, transunet have the following optional parameter:
    depth: 3,

    ## imagenet_unet has the following optional parameter:
    # pretrained (bool) - if True (default), uses the pretrained imagenet weights
    # final_layer - one of ["sigmoid", "softmax", "logsoftmax", "tanh", "identity"]
    # encoder_name (str) - the name of the encoder to use, pick from https://github.com/qubvel/segmentation_models.pytorch#encoders
    # decoder_use_batchnorm (str) - whether to use batch norm or not or inplace, this will override 'norm_type', see https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/decoders/unet/model.py
    # decoder_attention_type (str) - the decoder attention type, see https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/decoders/unet/model.py
    # encoder_depth (int) - the depth of the encoder, also picked up from 'depth'
    # decoder_channels (list) - a list of numbers of channels for each decoder layer, should be same length as 'encoder_depth'
    # converter_type (str) - either acs (targets ACSConv) or conv3d (targets nn.Conv3d) or soft (targets SoftACSConv with learnable weights, default); see https://doi.org/10.1109/JBHI.2021.3049452
    # the following parameters can be used to convert the "imagenet_unet" model to a classifier/regression network; they only come into the picture when the "problem_type" is identified as not segmentation. 
    # - pooling (str): One of "max", "avg"; default is "avg"
    # - dropout (float): Dropout factor in [0, 1); default is 0.2
  }
## metrics to evaluate the validation performance
metrics:
  - dice # segmentation
## this customizes the inference, primarily used for segmentation outputs
inference_mechanism: {
  grid_aggregator_overlap: crop, # this controls the grid aggregation strategy; should be either 'crop' or 'average' - https://torchio.readthedocs.io/patches/patch_inference.html#grid-aggregator
  patch_overlap: 0, # amount of overlap of patches during inference, defaults to 0; see https://torchio.readthedocs.io/patches/patch_inference.html#gridsampler
}
# setting this to true reads all data into memory once during data loading (i.e., disables lazy loading), resulting in improvements
# in I/O at the expense of memory consumption
in_memory: True
# this will save the generated masks for validation and testing data for qualitative analysis
save_output: True
# this will save the patches used during training for qualitative analysis
save_training: False
# Set the Modality : rad for radiology, path for histopathology
modality: rad
## Patch size during training - 2D patch for breast images since third dimension is not patched 
patch_size: [128,128,128]
# uniform: UniformSampler or label: LabelSampler
patch_sampler: uniform
# patch_sampler: label
# patch_sampler:
#   {
#     label:
#       {
#         padding_type: constant # how the label gets padded, for options, see 'mode' in https://numpy.org/doc/stable/reference/generated/numpy.pad.html
#       }
#   }
#If enabled, this parameter pads images and labels when label sampler is used
enable_padding: False
# Number of epochs
num_epochs: 100000
# Set the patience - measured in number of epochs after which, if the performance metric does not improve, exit the training loop - defaults to the number of epochs
patience: 5000
# Set the batch size
batch_size: 6
# gradient clip : norm, value, agc
# clip_mode: norm
# clip_gradient value
# clip_grad: 0.1
## Set the initial learning rate
learning_rate: 0.01
# Learning rate scheduler - options:"triangle", "triangle_modified", "exp", "step", "reduce-on-plateau", "cosineannealing", "triangular", "triangular2", "exp_range"
# triangle/triangle_modified use LambdaLR but triangular/triangular2/exp_range use CyclicLR
scheduler: cosineannealing
  # {
  #   type: cosineannealing,
    # min_lr: 0.00001,
    # max_lr: 1,
  # }
# Set which loss function you want to use - options: 'dc' for dice only, 'dcce' for the sum of dice and CE, and so on (lower-case only)
# options: dc (dice only), dc_log (-log of dice), ce (cross entropy), dcce (sum of dice and ce), mse (mean squared error) ...
# mse is the MSE defined by torch and can define a variable 'reduction'; see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss
# use mse_torch for regression/classification problems and dice for segmentation
loss_function: dcce
# this parameter weights the loss to handle imbalanced losses better
weighted_loss: True 
#loss_function:
#  {
#    'mse':{
#      'reduction': 'mean' # see https://pytorch.org/docs/stable/generated/torch.nn.MSELoss.html#torch.nn.MSELoss for all options
#    }
#  }
optimizer: adam
nested_training:
  {
    testing: -30, # this controls the testing data splits for final model evaluation; use '1' if this is to be disabled
    validation: -9 # this controls the validation data splits for model training
  }
# data_augmentation: 
#   {
  #   default_probability: 1.0, # keeping probability 1.0 to ensure that all augmentations are applied
  #   'affine':{ # for options, see https://torchio.readthedocs.io/transforms/augmentation.html#randomaffine
  #     'scales': [0.5, 1.5],
  #     'degrees': 25,
  #     'translation': 2,
  #   },
  #   'elastic',
  #   'kspace':{
  #     'probability': 1
  #   },
  #   'motion':{
  #     'probability': 1
  #   },
  #   'bias',
  #   blur, # this is a gaussian blur, and can take 'std' as a subkey, however, the default 'std' is [0, 0.015 * std(image)]
  #   ## example of blur with specific std range
  #   # 'blur': {
  #   #   'std': [0, 1] # example std-dev range, for details, see https://torchio.readthedocs.io/transforms/augmentation.html#torchio.transforms.RandomBlur
  #   # },
  #   'noise': { # for details, see https://torchio.readthedocs.io/transforms/augmentation.html#torchio.transforms.RandomNoise
  #     'mean': 0, # default mean
  #     'std': [0, 1] # default std-dev range
  #   },
  #   noise_var, # this is a gaussian noise, and can take 'std' and 'mean' as a subkey, however, the default 'std' is [0, 0.015 * std(image)]
  #   'gamma',
  #   'swap':{
  #     'patch_size': 15, # patch size for swapping; if a single number is provided, the same number is used for all axes
  #     'num_iterations': 50, # number of times that two patches will be swapped, defaults to 100
  #   },
  #   'flip':{
  #     'axis': [0,1,2] # one or more axes can be put here. if this isn't defined, all axes are considered
  #   },
  #   'anisotropic':{
  #     'axis': [0,1],
  #     'downsampling': [2,2.5]
  #   },
  #   'rotate_90':{ # explicitly rotate image by 90
  #     'axis': [0,2] # one or more axes can be put here. if this isn't defined, all axes are considered
  #   },
  #   'rotate_180', # explicitly rotate image by 180; if 'axis' isn't defined, default is [1,2,3]
  # }
# ## post-processing steps - only applied before output labels are saved
# data_postprocessing:
#   {
#     'fill_holes', # this will fill holes in the image
#     'mapping': {0: 0, 1: 1, 2: 4}, # this will map the labels to a new set of labels, useful to convert labels from combinatorial training (i.e., combined segmentation labels)
#   }
## parallel training on HPC - here goes the command to prepend to send to a high performance computing
# cluster for parallel computing during multi-fold training
# not used for single fold training
# this gets passed before the training_loop, so ensure enough memory is provided along with other parameters
# that your HPC would expect
# ${outputDir} will be changed to the outputDir you pass in CLI + '/${fold_number}'
# ensure that the correct location of the virtual environment is getting invoked, otherwise it would pick up the system python, which might not have all dependencies
# parallel_compute_command: 'qsub -b y -l gpu -l h_vmem=32G -cwd -o ${outputDir}/\$JOB_ID.stdout -e ${outputDir}/\$JOB_ID.stderr `pwd`/sge_wrapper _correct_location_of_virtual_environment_/venv/bin/python'
## queue configuration - https://torchio.readthedocs.io/data/patch_training.html?#queue
# this determines the maximum number of patches that can be stored in the queue. Using a large number means that the queue needs to be filled less often, but more CPU memory is needed to store the patches
q_max_length: 40
# this determines the number of patches to extract from each volume. A small number of patches ensures a large variability in the queue, but training will be slower
q_samples_per_volume: 5
# this determines the number of subprocesses to use for data loading; '0' means the main process is used
q_num_workers: 20 # scale this according to available CPU resources
# used for debugging
q_verbose: True

logs_training.csv

epoch_no,train_loss,train_dice
0,1.0000572764512263,-2.3867104160176083e-05
1,1.0,0.0
2,1.0,0.0
3,1.0,0.0
4,1.0,0.0
5,1.0,0.0
6,0.9999999250426437,0.0
7,1.0,0.0
8,1.0,0.0
9,1.0,0.0
10,1.0,0.0
11,1.0,0.0
12,1.0,0.0
13,0.9999999250426437,0.0
14,1.0,0.0
15,1.0,0.0
16,1.0,0.0
17,1.0,0.0
18,1.0,0.0
19,1.0,0.0
20,0.9999999250426437,0.0
21,1.0,0.0
22,1.0,0.0
23,1.0,0.0
24,1.0,0.0
25,1.0,0.0
26,1.0,0.0
27,1.0,0.0
28,0.9999999259457444,0.0
29,1.0,0.0
30,1.0,0.0
31,1.0,0.0
32,1.0,0.0
33,1.0,0.0
34,1.0,0.0
35,0.9999999250426437,0.0
36,1.0,0.0
37,1.0,0.0
38,1.0,0.0
39,1.0,0.0
40,1.0,0.0
41,1.0,0.0
42,1.0,0.0
43,0.9999999250426437,0.0
44,1.0,0.0
45,1.0,0.0
46,1.0,0.0
47,1.0,0.0
48,1.0,0.0
49,1.0,0.0

logs_validation.csv

epoch_no,valid_loss,valid_dice
0,0.9969627439975739,0.003103292337618768
1,0.9972733199596405,0.0029670228599570692
2,0.9970957994461059,0.00343705068808049
3,0.997797566652298,0.003026781778316945
4,0.9980143249034882,0.002899824071209878
5,0.9978552162647247,0.003034979058429599
6,0.9971363425254822,0.003269073914270848
7,0.9971782982349395,0.0030285805580206214
8,0.9978388547897339,0.0029356601182371376
9,0.9978420197963714,0.0029096938436850906
10,0.9978804528713227,0.0028765499708242716
11,0.9978895425796509,0.002869846601970494
12,0.9978834569454194,0.002881262043956667
13,0.9976308465003967,0.0028953364468179642
14,0.9977571487426757,0.0028960482217371465
15,0.9978282272815704,0.0029001160990446808
16,0.9978475630283355,0.002899083390366286
17,0.997849702835083,0.002898866881150752
18,0.9978659987449646,0.002896873140707612
19,0.9978699326515198,0.0028959849616512654
20,0.9977187037467956,0.002893106988631189
21,0.9977897763252258,0.0028910385095514356
22,0.9977940022945404,0.0028907033847644925
23,0.9978367686271667,0.002887433546129614
24,0.997847831249237,0.0028861146885901688
25,0.9978702306747437,0.0028819039696827533
26,0.9978783130645752,0.002879004494752735
27,0.9978791177272797,0.002878568589221686
28,0.9977535724639892,0.0028740785433910786
29,0.9977844715118408,0.0028722704271785917
30,0.9978403806686401,0.00286685653263703
31,0.9978619992733002,0.0028632209519855677
32,0.997864431142807,0.002862657676450908
33,0.9978822529315948,0.002857255982235074
34,0.9978865921497345,0.002855152334086597
35,0.9977645993232727,0.002848858619108796
36,0.9978000402450562,0.002844794269185513
37,0.9978064119815826,0.002844170422758907
38,0.9978546559810638,0.0028383528580889104
39,0.9978667795658112,0.0028360979515127836
40,0.9978908419609069,0.00282959018368274
41,0.997899466753006,0.002825476776342839
42,0.9979004502296448,0.0028248236631043255
43,0.9977828383445739,0.002819003676995635
44,0.9978051722049713,0.002816858654841781
45,0.9978632867336273,0.002810529514681548
46,0.9978853225708008,0.0028066353290341793
47,0.9978877246379853,0.0028060500510036944
48,0.9979057013988495,0.002800512034446001
49,0.9979099929332733,0.0027983693638816477
sarthakpati commented 1 year ago

Since I have never used this network before, I can only guess why this is happening. Can you try playing with the learning rate? It seems like this network takes time to train, so perhaps start off with a learning rate of 1 or so...

carlpe commented 1 year ago

Yes @sarthakpati, I will try different settings, and also different data sets 🙂

Thank you

carlpe commented 1 year ago

So... I have now tried several different learning_rate values, and also different data.

The same thing happens: train_loss = 1.0 and train_dice = 0.0.

When I tried with a different data set (MRI data), the same thing happened during training, but here the validation dice was negative in some of the epochs (I am unsure if this is normal).

I tried the MRI data with the resunet config, and there everything seems to work fine, with good results straight away.

sarthakpati commented 1 year ago

Hmm... then it definitely means there is an issue with the DenseVNet implementation. I am unsure if I can personally debug this in more detail before the end of Q1 2023, but I will keep this open. It would be amazing if you are able to debug it and check what the problem might be; otherwise we will need to discuss internally where we can put it on the roadmap.
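
If you (or anyone else) do get a chance to dig in: a dice of exactly 0 with the loss pinned at 1 usually means the network predicts background everywhere, or that the prediction channels and the one-hot-encoded target do not line up. A very rough first check could look like the sketch below (same caveat as before: the DenseVNet import path and constructor are assumptions and may need adjusting):

# rough diagnostic sketch -- DenseVNet import path / constructor are assumptions
import torch
from GANDLF.models.densevnet import DenseVNet  # hypothetical import path

parameters = {"model": {"dimension": 3, "base_filters": 16, "num_channels": 1,
                        "num_classes": 3, "final_layer": "softmax", "norm_type": "batch"}}
net = DenseVNet(parameters)  # assumed signature

x = torch.rand(2, 1, 64, 64, 64)               # dummy batch
target = torch.randint(0, 3, (2, 64, 64, 64))  # dummy labels for classes 0/1/2

out = net(x)

# 1) has the network collapsed to a single class (typically background)?
print(torch.bincount(out.argmax(dim=1).flatten(), minlength=3))

# 2) per-class soft dice against a one-hot target -- foreground classes stuck
#    near 0 would reproduce what the training log shows
target_1h = torch.nn.functional.one_hot(target, num_classes=3).permute(0, 4, 1, 2, 3).float()
intersection = (out * target_1h).sum(dim=(0, 2, 3, 4))
union = out.sum(dim=(0, 2, 3, 4)) + target_1h.sum(dim=(0, 2, 3, 4))
dice = 2 * intersection / union.clamp(min=1e-6)
print("per-class soft dice:", dice)

# 3) does any gradient actually flow back to the parameters?
loss = 1 - dice.mean()
loss.backward()
grad_sum = sum(p.grad.abs().sum().item() for p in net.parameters() if p.grad is not None)
print("sum of |grad| over all parameters:", grad_sum)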

carlpe commented 1 year ago

No problem 😊

github-actions[bot] commented 1 year ago

Stale issue message