Open plbenveniste opened 1 month ago
Here are the steps taken to train an STU-Net (the documentation of the repo is not up to date).
In the project folder:
git clone https://github.com/uni-medical/STU-Net/
git clone https://github.com/MIC-DKFZ/nnUNet/
conda create -n venv_stunet2 python=3.9
conda activate venv_stunet2
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
cp STU-Net/nnUNet-2.2/nnunetv2/run/run_finetuning_stunet.py nnUNet/nnunetv2/run/
cp STU-Net/nnUNet-2.2/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py nnUNet/nnunetv2/training/nnUNetTrainer/
pip install -e .
Pretrained weights used: base_ep4k.model
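Before moving on, it can help to confirm that the two copied files landed in the nnUNet checkout. This is a minimal sketch, assuming the commands above were run from the project folder; `missing_files` is a hypothetical helper and not part of nnUNet or STU-Net.

```python
from pathlib import Path

# Destination paths taken from the two cp commands above.
EXPECTED = [
    "nnUNet/nnunetv2/run/run_finetuning_stunet.py",
    "nnUNet/nnunetv2/training/nnUNetTrainer/STUNetTrainer.py",
]


def missing_files(project_root, relative_paths=EXPECTED):
    """Return the relative paths that do not exist under project_root."""
    root = Path(project_root)
    return [p for p in relative_paths if not (root / p).is_file()]


if __name__ == "__main__":
    missing = missing_files(".")
    print("all files in place" if not missing else f"missing: {missing}")
```

An empty list means both files are where the later commands expect them.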
The trainer needs to be updated with the following code:
# Imports used by this snippet (add them if STUNetTrainer.py does not already have them).
# STUNet itself is defined in the same file.
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels


class STUNetTrainer(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.num_epochs = 1000
        self.initial_lr = 1e-2

    @staticmethod
    def build_network_architecture(plans_manager,
                                   dataset_json,
                                   configuration_manager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        label_manager = plans_manager.get_label_manager(dataset_json)
        num_classes = label_manager.num_segmentation_heads
        kernel_sizes = [[3, 3, 3]] * 6
        # Skip the first entry, then truncate or pad to exactly 5 downsampling steps
        strides = configuration_manager.pool_op_kernel_sizes[1:]
        if len(strides) > 5:
            strides = strides[:5]
        while len(strides) < 5:
            strides.append([1, 1, 1])
        return STUNet(num_input_channels, num_classes, depth=[1] * 6, dims=[32 * x for x in [1, 2, 4, 8, 16, 16]],
                      pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes,
                      enable_deep_supervision=enable_deep_supervision)

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)
            self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
                                                           self.configuration_manager,
                                                           self.num_input_channels,
                                                           enable_deep_supervision=True).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Compiling network...')
                self.network = torch.compile(self.network)

            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])

            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")
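The stride handling in `build_network_architecture` is easy to misread: it drops the first entry of `pool_op_kernel_sizes`, then truncates or pads the remainder to exactly five downsampling steps. The sketch below isolates that logic in plain Python. `normalize_strides` is a hypothetical name used only for illustration, and the example plan values are made up rather than taken from the actual Dataset201 plans.

```python
def normalize_strides(pool_op_kernel_sizes, num_stages=5):
    """Mirror the trainer's stride handling: skip the first entry, then
    truncate or pad with [1, 1, 1] to exactly num_stages entries."""
    strides = [list(s) for s in pool_op_kernel_sizes[1:]]  # copy so the input is untouched
    strides = strides[:num_stages]
    while len(strides) < num_stages:
        strides.append([1, 1, 1])
    return strides


# Channel widths passed as `dims` in the trainer.
dims = [32 * x for x in [1, 2, 4, 8, 16, 16]]
print(dims)  # [32, 64, 128, 256, 512, 512]

# Made-up plan with only three pooling steps after the first entry.
example = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [1, 2, 2]]
print(normalize_strides(example))
# [[2, 2, 2], [2, 2, 2], [1, 2, 2], [1, 1, 1], [1, 1, 1]]
```

Padding with `[1, 1, 1]` means the extra stages do not downsample, which keeps the network at the six resolution levels implied by `depth=[1]*6` whatever the plans contain.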
Then I exported the nnUNet folders (where the dataset has been preprocessed):
export nnUNet_raw="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw"
export nnUNet_results="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results"
export nnUNet_preprocessed="/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_preprocessed"
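nnUNet reads these three variables from the environment, so it can be worth checking them from Python before launching anything. A small sketch; `unset_nnunet_vars` is a hypothetical helper, not an nnUNet function.

```python
import os

REQUIRED_VARS = ("nnUNet_raw", "nnUNet_preprocessed", "nnUNet_results")


def unset_nnunet_vars(env=None):
    """Return the nnUNet variables that are missing or empty in env."""
    env = os.environ if env is None else env
    return [name for name in REQUIRED_VARS if not env.get(name)]


if __name__ == "__main__":
    print(unset_nnunet_vars() or "all nnUNet variables are set")
```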
Then I launched the training (on koios):
CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 201 3d_fullres 1 -pretrained_weights /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model -tr STUNetTrainer_base_ft
Referencing this issue which helped: https://github.com/uni-medical/STU-Net/issues/34
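To rerun the same fine-tuning on other folds, the training command above can be assembled programmatically. This is only a sketch: `build_finetune_command` and `launch` are hypothetical helpers, and the arguments simply mirror the command line used in this issue.

```python
import os
import subprocess


def build_finetune_command(dataset_id, configuration, fold, pretrained_weights, trainer):
    """Assemble the argument list for run_finetuning_stunet.py as used above."""
    return [
        "python", "nnUNet/nnunetv2/run/run_finetuning_stunet.py",
        str(dataset_id), configuration, str(fold),
        "-pretrained_weights", pretrained_weights,
        "-tr", trainer,
    ]


def launch(cmd, gpu="0"):
    """Run the command on a single GPU (not called in this sketch)."""
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=gpu)
    return subprocess.run(cmd, env=env, check=True)


cmd = build_finetune_command(
    201, "3d_fullres", 1,
    "/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/stunet_pretrained/base_ep4k.model",
    "STUNetTrainer_base_ft",
)
print(" ".join(cmd))
```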
Before running inference, the file predict_from_raw_data.py needed to be copied from STU-Net/nnUNet-2.2/nnunetv2/inference/ to the nnUNet repo.
The inference on the test set was done using:
CUDA_VISIBLE_DEVICES=0 nnUNetv2_predict -i /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -o /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set -d 201 -c 3d_fullres -f 1 -chk checkpoint_best.pth
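The output path in the command above follows the layout visible in this run: `<trainer>__<plans>__<configuration>/fold_<n>` under the dataset's results folder. The sketch below rebuilds it from its parts; `fold_dir` is a hypothetical helper, and the layout is inferred from the paths in this issue rather than from the nnUNet source.

```python
from pathlib import PurePosixPath


def fold_dir(results_root, dataset_name, trainer, plans, configuration, fold):
    """Build <results_root>/<dataset>/<trainer>__<plans>__<configuration>/fold_<fold>."""
    model_dir = f"{trainer}__{plans}__{configuration}"
    return PurePosixPath(results_root) / dataset_name / model_dir / f"fold_{fold}"


out = fold_dir(
    "/home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results",
    "Dataset201_msLesionAgnostic",
    "STUNetTrainer_base_ft",
    "nnUNetPlans",
    "3d_fullres",
    1,
) / "test_set"
print(out)
```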
The results were computed (in the venv_nnunet environment) with:
python nnunet/evaluate_predictions.py -pred-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ -label-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/labelsTs -image-folder ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/imagesTs/ -conversion-dict ~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/conversion_dict.json -output-folder /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set
The plots were generated with:
python nnunet/plot_performance.py --pred-dir-path /home/plbenveniste/net/ms-lesion-agnostic/stunet_experiments/second_exp/nnUNet_results/Dataset201_msLesionAgnostic/STUNetTrainer_base_ft__nnUNetPlans__3d_fullres/fold_1/test_set/ --data-json-path /home/plbenveniste/net/ms-lesion-agnostic/msd_data/dataset_2024-07-24_seed42_lesionOnly.json --split test
Here are the results obtained with the STU-Net fine-tuned from the base model:
The results are very similar to those of the trained nnUNet model.
TODO:
In this issue, I explore segmenting MS lesions in the spinal cord using STU-Net.
I used the code from this repo: https://github.com/uni-medical/STU-Net
The dataset used is the nnUNet preprocessed data stored in:
~/net/ms-lesion-agnostic/nnunet_experiments/nnUNet_raw/Dataset201_msLesionAgnostic/