ivadomed / ms-lesion-agnostic

Deep Learning contrast-agnostic tool for MS lesion segmentation in the spinal cord
MIT License

Training of an STU-Net model for MS lesion segmentation #29

Open plbenveniste opened 1 month ago

plbenveniste commented 1 month ago

In this issue, I explore the work done to segment MS lesions in the spinal cord using the STU-Net.

I used the code from this repo: https://github.com/uni-medical/STU-Net

The dataset used is the nnUNet preprocessed data stored in: ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/
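
For reference, this folder is expected to follow the standard nnU-Net raw dataset layout; the sketch below is illustrative (the exact file names may differ), with labelsTs present because the test labels are used for evaluation further down:

Dataset201_msLesionAgnostic/
├── dataset.json
├── imagesTr/   # training images, e.g. <case>_0000.nii.gz
├── labelsTr/   # training lesion masks, e.g. <case>.nii.gz
├── imagesTs/   # test images
└── labelsTs/   # test lesion masks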

plbenveniste commented 1 month ago

Here are the steps taken to train an STU-Net (the documentation of the repo is not up to date).

In the project folder, the trainer needs to be updated with the following code:

# Imports used by this trainer (nnU-Net v2 module paths); the STUNet architecture class
# itself comes from the STU-Net repo and is defined alongside these trainers.
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels


class STUNetTrainer(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.num_epochs = 1000
        self.initial_lr = 1e-2

    @staticmethod
    def build_network_architecture(plans_manager,
                                   dataset_json,
                                   configuration_manager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        label_manager = plans_manager.get_label_manager(dataset_json)
        num_classes = label_manager.num_segmentation_heads
        kernel_sizes = [[3, 3, 3]] * 6
        # The network is built with a fixed number of stages, so force exactly 5 pooling ops:
        # truncate the plan's strides if there are too many, pad with [1, 1, 1] if too few.
        strides = configuration_manager.pool_op_kernel_sizes[1:]
        if len(strides) > 5:
            strides = strides[:5]
        while len(strides) < 5:
            strides.append([1, 1, 1])
        return STUNet(num_input_channels, num_classes, depth=[1] * 6, dims=[32 * x for x in [1, 2, 4, 8, 16, 16]],
                      pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes,
                      enable_deep_supervision=enable_deep_supervision)

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)

            self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
                                                           self.configuration_manager,
                                                           self.num_input_channels,
                                                           enable_deep_supervision=True).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Compiling network...')
                self.network = torch.compile(self.network)

            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])

            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")

Then I export the nnUNet environment variables (pointing to where the dataset has been preprocessed):

export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_preprocessed"
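
As an optional sanity check, nnU-Net reads these variables when it is imported, so they can be verified from Python via the nnunetv2.paths module:

# Optional check that the exported variables are picked up by nnU-Net.
from nnunetv2.paths import nnUNet_raw, nnUNet_preprocessed, nnUNet_results

print(nnUNet_raw)           # should point to .../second_exp/nnUNet_raw
print(nnUNet_preprocessed)  # should point to .../second_exp/nnUNet_preprocessed
print(nnUNet_results)       # should point to .../second_exp/nnUNet_results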

Then I launched the training (on koios):

CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 201 3d_fullres 1 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model  -tr STUNetTrainer_base_ft
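
For intuition, the fine-tuning entry point has to load the pretrained checkpoint into the freshly built network before training starts. Below is a minimal sketch of that general idea, assuming a plain PyTorch checkpoint and shape-based filtering; it is not the code of run_finetuning_stunet.py, and the checkpoint key lookup is an assumption.

# Hypothetical sketch of pretrained-weight loading (not the repo's exact code).
# The internal layout of base_ep4k.model is assumed; adapt the key lookup as needed.
import torch

def load_pretrained_weights(network: torch.nn.Module, checkpoint_path: str) -> None:
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    # The weights may be stored under a key such as 'state_dict' (assumption) or be the state dict itself.
    saved = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
    model_sd = network.state_dict()
    # Keep only parameters whose names and shapes match the current network
    # (e.g. the segmentation head may differ when the number of classes changes).
    filtered = {k: v for k, v in saved.items() if k in model_sd and v.shape == model_sd[k].shape}
    model_sd.update(filtered)
    network.load_state_dict(model_sd)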

For reference, this issue helped: https://github.com/uni-medical/STU-Net/issues/34

plbenveniste commented 4 weeks ago

Before running inference, the file predict_from_raw_data.py needed to be copied from STU-Net/nnUNet-2.2/nnunetv2/inference/ to the nnUNet repo.

The inference on the test set was done using:

CUDA_VISIBLE_DEVICES=0 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set -d 201 -c 3d_fullres -f 1 -chk checkpoint_best.pth
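
Equivalently, the same inference could be scripted through nnU-Net v2's Python API (nnUNetPredictor); the sketch below mirrors the command above but is an assumption about how one would call it, not the exact code used here.

# Rough Python equivalent of the nnUNetv2_predict call above (a sketch, not the code actually run).
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

predictor = nnUNetPredictor(tile_step_size=0.5, use_gaussian=True, use_mirroring=True,
                            device=torch.device('cuda', 0))
# Model folder = nnUNet_results/<dataset>/<trainer>__<plans>__<configuration>
predictor.initialize_from_trained_model_folder(
    '/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/'
    'Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres',
    use_folds=(1,),
    checkpoint_name='checkpoint_best.pth',
)
predictor.predict_from_files(
    '/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw/'
    'Dataset201_msLesionAgnostic/imagesTs',
    '/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/'
    'Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set',
    save_probabilities=False,
    overwrite=True,
)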

The results were computed (in the venv_nnunet environment) with:

python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/labelsTs  -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set
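
For reference, the per-image metric reported here is the Dice score; a minimal sketch of how it can be computed for one prediction/ground-truth pair is shown below (file names are illustrative; evaluate_predictions.py is the script actually used).

# Minimal Dice computation for one binary prediction/label pair (illustrative file names).
import nibabel as nib
import numpy as np

def dice_score(pred_path: str, label_path: str, eps: float = 1e-8) -> float:
    pred = nib.load(pred_path).get_fdata() > 0.5   # binarize the predicted lesion mask
    gt = nib.load(label_path).get_fdata() > 0.5    # binarize the ground-truth lesion mask
    intersection = np.logical_and(pred, gt).sum()
    return float(2 * intersection / (pred.sum() + gt.sum() + eps))

print(dice_score('sub-001_lesion-seg.nii.gz', 'sub-001_lesion-manual.nii.gz'))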

And the plots were obtained with:

python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
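
The plots group per-image Dice scores by contrast, orientation and site. As an illustration of that kind of grouping (not the code of plot_performance.py; the CSV and column names below are assumptions):

# Illustrative grouping/plotting of Dice scores per contrast (column/file names are assumptions).
import pandas as pd
import matplotlib.pyplot as plt

scores = pd.read_csv('dice_scores.csv')  # expected columns: 'contrast', 'dice'
groups = {name: g['dice'].values for name, g in scores.groupby('contrast')}

plt.boxplot(list(groups.values()), labels=list(groups.keys()))
plt.ylabel('Dice score')
plt.title('Dice score per contrast')
plt.savefig('dice_scores_contrast.png')
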
plbenveniste commented 4 weeks ago

Here are the results computed after training the STU-Net with the base model (attached figures: dice_scores_contrast, dice_scores_orientation, dice_scores_site).

The results are pretty similar to those of the nnUNet model trained previously.

TODO: