MIC-DKFZ / nnUNet

How to combine a mean-teacher architecture with nnUnet? #1709

Closed by LIKP0 1 year ago

LIKP0 commented 1 year ago

Hi everyone, I'm trying to combine the semi-supervised mean-teacher (MT) architecture with nnUNet, that is, to use nnUNet in place of the plain U-Net for both the student and the teacher model in the following figure.

I know it's simple to apply MT to a plain U-Net, but how do I apply it to nnUNet? For example, what is the proper way to do the EMA (exponential moving average) update in nnUNet? I really don't know where to start in nnUNet's large codebase...

Could anyone give some advice? Thanks in advance!

[Figure: Fig. 1 of Yu et al., MICCAI 2019]

The picture is Fig. 1 in: Yu L, Wang S, Li X, et al. Uncertainty-aware self-ensembling model for semi-supervised 3D left atrium segmentation. In: Medical Image Computing and Computer Assisted Intervention – MICCAI 2019: 22nd International Conference, Shenzhen, China, October 13–17, 2019, Proceedings, Part II. Springer International Publishing, 2019: 605–613.

sten2lu commented 1 year ago

Hi @LIKP0,

thanks for your interest in nnU-Net. You are correct; implementing Mean Teacher in nnU-Net is not as straightforward as for simpler implementations of the U-Net.

However, I believe it is definitely possible. I will give you some guidance on which changes need to be made and where you can learn how to implement them yourself. I might overlook some things, but I hope this helps you get started on your project.

For the full mean teacher you would need a separate nnUNetTrainer (handling the training logic and using the EMA model), a dataloader (loading the unannotated data alongside the annotated data), and uncertainty estimation.

For the nnU-Net Trainer, I would suggest reading the documentation on how to extend nnU-Net.

Regarding the dataloader, I would suggest reading the code for the dataloaders, as this is a pretty specific change and not covered in the documentation.

Monte Carlo Dropout is currently not in the nnU-Net model. You need to implement this yourself following the instructions in the extending nnU-Net documentation.

Best regards,

Carsten

Update: There is no EMA model in standard nnU-Net.
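
Since the EMA has to be written by hand, here is a minimal sketch of a mean-teacher weight update in plain PyTorch. It is not nnU-Net code: `make_teacher` and `ema_update` are hypothetical helper names, the tiny `nn.Linear` stands in for the real network, and the decay of 0.99 is only an assumed default.

```python
import copy

import torch
from torch import nn


def make_teacher(student: nn.Module) -> nn.Module:
    """Clone the student and freeze the copy so only EMA updates change it."""
    teacher = copy.deepcopy(student)
    for p in teacher.parameters():
        p.requires_grad_(False)
    return teacher


@torch.no_grad()
def ema_update(student: nn.Module, teacher: nn.Module, step: int, decay: float = 0.99) -> float:
    """teacher <- alpha * teacher + (1 - alpha) * student, with a warm-up on alpha."""
    # Early on alpha is small, so the teacher tracks the student closely;
    # it is capped at `decay` once enough steps have passed.
    alpha = min(1.0 - 1.0 / (step + 1), decay)
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.mul_(alpha).add_(s_param, alpha=1.0 - alpha)
    return alpha


student = nn.Linear(4, 2)  # stand-in for the segmentation network
teacher = make_teacher(student)

# Pretend one optimizer step moved every student weight by +1.
with torch.no_grad():
    for p in student.parameters():
        p.add_(1.0)

before = teacher.weight.clone()
alpha = ema_update(student, teacher, step=1000)  # teacher moves 1% of the way to the student
```

In a custom trainer the update would typically be called once per iteration, right after the optimizer step, with the teacher only ever used under `torch.no_grad()`.
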

LIKP0 commented 1 year ago

Thanks for your great advice, especially the pointer about the EMA model!

I'm trying to write a custom nnUNetTrainer class that combines features from the trainers in training/nnUNetTrainer/variants/. The documentation on how to extend nnUNet is really useful.

Thanks again for your great work! I come from nnUNet V1 and I can really see the significant improvement in V2. It is a very worthwhile basis for research!

Give me a few days and I will share my results here.

LIKP0 commented 1 year ago

Hi @sten2lu ,

I just want to confirm that nnUNet no longer needs a nonlinearity at the end of the network.

For example, in the following code the output and target come directly from the network, so they don't need softmax applied again, right?

    def UA_consistency_loss(self, output, target, uncertainty):  # Uncertainty-Aware Consistency Loss
        # Your architecture should NOT apply any nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that?
        # output_softmax = F.softmax(output, dim=1)
        # target_softmax = F.softmax(target, dim=1)
        mse_loss = (output - target) ** 2
        mask = (uncertainty < self.UA_threshold).float()
        consistency_loss = torch.sum(mask * mse_loss) / (2 * torch.sum(mask) + 1e-16)
        return self.UA_consistency_weight * consistency_loss

Thanks in advance! Have a good day!

sten2lu commented 1 year ago

Hi @LIKP0, nnU-Net's standard forward pass returns the logits. Therefore, you will need to apply softmax yourself to obtain the class probabilities. In nnU-Net, the softmax is applied during postprocessing.
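
To illustrate why the softmax matters for a consistency loss, here is a small sketch. `masked_consistency` is a hypothetical helper, the tensor shapes are made up, and it normalizes by the number of kept voxel-class entries, which is one possible choice and not necessarily the one used elsewhere in this thread.

```python
import torch
import torch.nn.functional as F


def masked_consistency(student_logits, teacher_logits, uncertainty, threshold):
    """Mean squared difference of softmax probabilities over confident voxels.

    student_logits, teacher_logits: (B, C, ...) raw network outputs (logits)
    uncertainty: (B, 1, ...) values in [0, 1]; voxels below `threshold` are kept
    """
    student_prob = F.softmax(student_logits, dim=1)            # dim=1 is the class axis
    teacher_prob = F.softmax(teacher_logits, dim=1).detach()   # no gradient into the teacher
    sq_diff = (student_prob - teacher_prob) ** 2               # (B, C, ...)
    mask = (uncertainty < threshold).float()                   # broadcasts over the class axis
    # Normalize by the number of kept (voxel, class) entries (an assumption of this sketch).
    n_kept = mask.sum() * student_logits.shape[1]
    return (mask * sq_diff).sum() / (n_kept + 1e-16)


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4, 4)          # toy batch: 2 samples, 3 classes
uncertainty = torch.rand(2, 1, 4, 4, 4)

# Shifting all logits by a constant does not change the probabilities,
# so the loss on probabilities is ~0 while an MSE on raw logits would be 25.
same = masked_consistency(logits, logits + 5.0, uncertainty, threshold=0.75)
diff = masked_consistency(logits, -logits, uncertainty, threshold=0.75)
```

Comparing raw logits would penalize the student for differences that leave the predicted class distribution untouched, which is the practical reason to apply softmax first.
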

Best regards, Carsten

LIKP0 commented 1 year ago

Thank you @sten2lu, and thank you nnUNet! I have successfully completed my task and really appreciate all of your help!

sten2lu commented 11 months ago

Currently, there is no EMA model in the standard nnU-Net Trainer. This has been updated in my previous answer.

I am sorry for any resulting confusion.

Best regards, Carsten

lichen14 commented 4 weeks ago

Hi @LIKP0, could you please share the code that you mentioned above? I am also curious about Mean Teacher in nnU-Net.

Thanks.

LIKP0 commented 4 weeks ago

Hi @lichen14, my code is below. I can't promise it is fully correct; if you find any bugs, please tell me.

In my experiment, the improvement from MT was not that obvious (<1%). I think maybe nnUNet itself is robust enough.

  import os
  import shutil
  import sys
  import math
  import numpy as np
  from typing import List, Tuple

  import torch
  from torch import autocast
  import torch.nn.functional as F
  from torch import distributed as dist

  from batchgenerators.utilities.file_and_folder_operations import join, isfile, save_json, maybe_mkdir_p
  from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
  from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
      LimitedLenWrapper
  from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
  from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
  from nnunetv2.training.dataloading.utils import unpack_dataset
  from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
  from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
  from nnunetv2.utilities.helpers import dummy_context, empty_cache
  from nnunetv2.utilities.collate_outputs import collate_outputs

  class nnUNetTrainerMT(nnUNetTrainer):

      def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                   device: torch.device = torch.device('cuda')):
          """used for debugging plans etc
          When debug using the local port, you should change the paths.py file:
          # nnUNet_raw = os.environ.get('nnUNet_raw')
          # nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')
          # nnUNet_results = os.environ.get('nnUNet_results')

          nnUNet_raw = "nnUNet_raw/nnUNet_raw"
          nnUNet_preprocessed = "nnUNet_raw/nnUNet_preprocessed"
          nnUNet_results = "nnUNet_raw/nnUNet_results_MT"
          """
          super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
          self.num_iterations_per_epoch = 100
          self.num_epochs = 150
          self.extra_dataset_folder = 'nnUNet_raw/nnUNet_preprocessed/Dataset062_Atlas_only/nnUNetPlans_3d_fullres/'
          self.extra_dataset_src = 'nnUNet_raw/nnUNet_raw/Dataset062_Atlas_only/imagesTr/'
          self.dataloader_extra = None
          self.HQ_loss_weight = 1.0
          self.ema_model = None
          self.ema_decay = 0.99
          self.UA_base_consistency_weight = 10.0
          self.UA_base_threshold = 0.75
          self.UA_consistency_weight = None  # ramp up in num_epochs step from 0 to base weight
          self.UA_threshold = None  # ramp up in num_epochs step from 0.75 to 1

      def on_train_start(self):
          if not self.was_initialized:
              self.initialize()

          maybe_mkdir_p(self.output_folder)

          # make sure deep supervision is on in the network
          self.set_deep_supervision_enabled(True)

          self.print_plans()
          empty_cache(self.device)

          # maybe unpack
          if self.unpack_dataset and self.local_rank == 0:
              self.print_to_log_file('unpacking dataset...')
              unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
                             num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
              unpack_dataset(self.extra_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
                             num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
              self.print_to_log_file('unpacking done...')

          if self.is_ddp:
              dist.barrier()

          # dataloaders must be instantiated here because they need access to the training data which may not be present
          # when doing inference
          self.dataloader_train, self.dataloader_val, self.dataloader_extra = self.get_dataloaders()

          # copy plans and dataset.json so that they can be used for restoring everything we need for inference
          save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
          save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)

          # we don't really need the fingerprint but its still handy to have it with the others
          shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
                      join(self.output_folder_base, 'dataset_fingerprint.json'))

          # produces a pdf in output folder
          self.plot_network_architecture()

          self._save_debug_information()

          # print(f"batch size: {self.batch_size}")
          # print(f"oversample: {self.oversample_foreground_percent}")

      def on_train_end(self):
          # dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards.
          # This will lead to the wrong current epoch to be stored
          self.current_epoch -= 1
          self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
          self.current_epoch += 1

          # now we can delete latest
          if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
              os.remove(join(self.output_folder, "checkpoint_latest.pth"))

          # shut down dataloaders
          old_stdout = sys.stdout
          with open(os.devnull, 'w') as f:
              sys.stdout = f
              if self.dataloader_train is not None:
                  self.dataloader_train._finish()
              if self.dataloader_val is not None:
                  self.dataloader_val._finish()
              if self.dataloader_extra is not None:
                  self.dataloader_extra._finish()
              sys.stdout = old_stdout

          empty_cache(self.device)
          self.print_to_log_file("Training done.")

      def get_tr_and_val_datasets(self):
          # create dataset split
          tr_keys, val_keys = self.do_split()

          # load the datasets for training and validation. Note that we always draw random samples so we really don't
          # care about distributing training cases across GPUs.
          dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys,
                                     folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
                                     num_images_properties_loading_threshold=0)
          dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
                                      folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
                                      num_images_properties_loading_threshold=0)

          # Learn from do_split to get extra_keys
          extra_keys = sorted([i[:4] for i in os.listdir(self.extra_dataset_src)])
          dataset_extra = nnUNetDataset(self.extra_dataset_folder, extra_keys,
                                        folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
                                        num_images_properties_loading_threshold=0)
          return dataset_tr, dataset_val, dataset_extra

      def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
          dataset_tr, dataset_val, dataset_extra = self.get_tr_and_val_datasets()

          dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size,
                                     initial_patch_size,
                                     self.configuration_manager.patch_size,
                                     self.label_manager,
                                     oversample_foreground_percent=self.oversample_foreground_percent,
                                     sampling_probabilities=None, pad_sides=None)
          dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size,
                                      self.configuration_manager.patch_size,
                                      self.configuration_manager.patch_size,
                                      self.label_manager,
                                      oversample_foreground_percent=self.oversample_foreground_percent,
                                      sampling_probabilities=None, pad_sides=None)
          dl_extra = nnUNetDataLoader3D(dataset_extra, self.batch_size,
                                        self.configuration_manager.patch_size,
                                        self.configuration_manager.patch_size,
                                        self.label_manager,
                                        oversample_foreground_percent=self.oversample_foreground_percent,
                                        sampling_probabilities=None, pad_sides=None)

          return dl_tr, dl_val, dl_extra

      def get_dataloaders(self):
          # we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
          # we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
          patch_size = self.configuration_manager.patch_size
          dim = len(patch_size)

          # needed for deep supervision: how much do we need to downscale the segmentation targets for the different
          # outputs?
          deep_supervision_scales = self._get_deep_supervision_scales()

          rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
              self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()

          # training pipeline
          tr_transforms = self.get_training_transforms(
              patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
              order_resampling_data=3, order_resampling_seg=1,
              use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
              is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
              regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
              ignore_label=self.label_manager.ignore_label)

          # validation pipeline
          val_transforms = self.get_validation_transforms(deep_supervision_scales,
                                                          is_cascaded=self.is_cascaded,
                                                          foreground_labels=self.label_manager.foreground_labels,
                                                          regions=self.label_manager.foreground_regions if
                                                          self.label_manager.has_regions else None,
                                                          ignore_label=self.label_manager.ignore_label)

          dl_tr, dl_val, dl_extra = self.get_plain_dataloaders(initial_patch_size, dim)

          allowed_num_processes = get_allowed_n_proc_DA()
          if allowed_num_processes == 0:
              mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
              mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
              mt_gen_extra = SingleThreadedAugmenter(dl_extra, tr_transforms)
          else:
              mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
                                               num_processes=allowed_num_processes, num_cached=6, seeds=None,
                                               pin_memory=self.device.type == 'cuda', wait_time=0.02)
              mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
                                             transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
                                             num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
                                             wait_time=0.02)
              mt_gen_extra = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_extra,
                                               transform=tr_transforms,
                                               num_processes=allowed_num_processes, num_cached=6, seeds=None,
                                               pin_memory=self.device.type == 'cuda', wait_time=0.02)
          return mt_gen_train, mt_gen_val, mt_gen_extra

      def run_training(self):
          self.on_train_start()

          for epoch in range(self.current_epoch, self.num_epochs):
              self.on_epoch_start()

              self.on_train_epoch_start()
              train_outputs = []
              for batch_id in range(self.num_iterations_per_epoch):
                  train_outputs.append(self.train_step(next(self.dataloader_train), next(self.dataloader_extra)))
              self.on_train_epoch_end(train_outputs)

              with torch.no_grad():
                  self.on_validation_epoch_start()
                  val_outputs = []
                  for batch_id in range(self.num_val_iterations_per_epoch):
                      val_outputs.append(self.validation_step(next(self.dataloader_val)))
                  self.on_validation_epoch_end(val_outputs)

              self.on_epoch_end()

          self.on_train_end()

      def initialize(self):
          super().initialize()
          self.ema_model = self.build_network_architecture(self.plans_manager, self.dataset_json,
                                                           self.configuration_manager,
                                                           self.num_input_channels,
                                                           enable_deep_supervision=True).to(self.device)

      def on_train_epoch_start(self):
          super().on_train_epoch_start()
          self.ema_model.train()
          self.UA_consistency_weight = self.UA_base_consistency_weight * self.sigmoid_rampup(self.current_epoch,
                                                                                             self.num_epochs // 2)
          self.UA_threshold = self.UA_base_threshold + 0.25 * self.sigmoid_rampup(self.current_epoch,
                                                                                  self.num_epochs // 2)
          self.print_to_log_file(f'Epoch {self.current_epoch} UA_consistency_weight: {self.UA_consistency_weight}')
          self.print_to_log_file(f'Epoch {self.current_epoch} UA_threshold: {self.UA_threshold}')

      def train_step(self, batch: dict, extra_batch: dict) -> dict:
          data = batch['data']
          target = batch['target']
          extra_data = extra_batch['data']
          extra_target = extra_batch['target']

          data = data.to(self.device, non_blocking=True)
          if isinstance(target, list):
              target = [i.to(self.device, non_blocking=True) for i in target]
          else:
              target = target.to(self.device, non_blocking=True)

          extra_data = extra_data.to(self.device, non_blocking=True)
          if isinstance(extra_target, list):
              extra_target = [i.to(self.device, non_blocking=True) for i in extra_target]
          else:
              extra_target = extra_target.to(self.device, non_blocking=True)

          self.optimizer.zero_grad(set_to_none=True)
          with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
              # List: 5, [(2, 36, 128, 128, 128), (2, 36, 64, 64, 64), ..., (2, 36, 32, 32, 32)]
              # There are 36 labels including background, so output channel is 36
              output = self.network(data)
              l = self.loss(output, target)  # nnUNet base loss: compute loss on 5 scale and aggregate

              with torch.no_grad():
                  ema_output = self.ema_model(self.get_noisy_input(data))  # only use top scale
              uncertainty = self.MC_uncertainty(self.ema_model, data)  # forward T times on ema_model to get uncertainty
              l2 = self.UA_consistency_loss(output[0], ema_output[0], uncertainty)

              extra_output = self.network(extra_data)
              l3 = self.loss(extra_output, extra_target)
              l3 = self.HQ_loss_weight * l3

          total_loss = l + l2 + l3
          if self.grad_scaler is not None:
              self.grad_scaler.scale(total_loss).backward()
              self.grad_scaler.unscale_(self.optimizer)
              torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
              self.grad_scaler.step(self.optimizer)
              self.grad_scaler.update()
          else:
              total_loss.backward()
              torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
              self.optimizer.step()

          # Update EMA model one time per iteration
          self.update_ema_variables(self.network, self.ema_model, self.current_epoch)

          return {'loss': total_loss.detach().cpu().numpy(), 'base_loss': l.detach().cpu().numpy(),
                  'UA_loss': l2.detach().cpu().numpy(), 'HQ_loss': l3.detach().cpu().numpy()}

      def on_train_epoch_end(self, train_outputs: List[dict]):
          outputs = collate_outputs(train_outputs)
          loss_here = np.mean(outputs['loss'])
          base_loss = np.mean(outputs['base_loss'])
          UA_loss = np.mean(outputs['UA_loss'])
          HQ_loss = np.mean(outputs['HQ_loss'])

          self.logger.log('train_losses', loss_here, self.current_epoch)
          self.print_to_log_file(f'Epoch {self.current_epoch} base_losses: ', base_loss)
          self.print_to_log_file(f'Epoch {self.current_epoch} UA_consistency_loss: ', UA_loss)
          self.print_to_log_file(f'Epoch {self.current_epoch} HQ_loss: ', HQ_loss)

      def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):  # No mirroring
          rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
              super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
          mirror_axes = None
          self.inference_allowed_mirroring_axes = None
          return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes

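      def update_ema_variables(self, model, ema_model, global_step):
          # Use the true average until the exponential average is more correct.
          # Note: train_step passes self.current_epoch as global_step, so alpha is 0 during epoch 0
          # (the teacher simply copies the student) and reaches ema_decay only from epoch 99 onward.
          alpha = min(1 - 1 / (global_step + 1), self.ema_decay)
          for ema_param, param in zip(ema_model.parameters(), model.parameters()):
              ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)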

      def UA_consistency_loss(self, output, target, uncertainty):  # Uncertainty-Aware Consistency Loss
          output_softmax = F.softmax(output, dim=1)
          target_softmax = F.softmax(target, dim=1)
          mse_loss = (output_softmax - target_softmax) ** 2
          mask = (uncertainty < self.UA_threshold).float()
          consistency_loss = torch.sum(mask * mse_loss) / (2 * torch.sum(mask) + 1e-16)  # 2 for batch size 2
          return self.UA_consistency_weight * consistency_loss

      def MC_uncertainty(self, ema_model, data):
          T = 6  # number of noisy-input forward passes per sample
          num_classes = 36  # hard-coded for this dataset: 35 structures + background
          data_r = data.repeat(2, 1, 1, 1, 1)  # (2, 1, 128, 128, 128) ==> (4, 1, 128, 128, 128)
          stride = data_r.shape[0] // 2  # 2, the original batch size
          # allocate on the same device as the input instead of assuming CUDA
          preds = torch.zeros([stride * T, num_classes, data.shape[2], data.shape[3], data.shape[4]],
                              device=data.device)  # (12, 36, 128, 128, 128)
          for i in range(T // 2):  # each pass pushes two noisy copies of the batch through the teacher
              ema_inputs = self.get_noisy_input(data_r)
              with torch.no_grad():
                  preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)[0]  # only use top scale
          preds = F.softmax(preds, dim=1)
          preds = preds.reshape(T, stride, num_classes, data.shape[2], data.shape[3], data.shape[4])
          preds = torch.mean(preds, dim=0)
          # predictive entropy of the mean prediction, shape (stride, 1, D, H, W)
          uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)
          uncertainty = uncertainty / math.log(num_classes)  # normalize to [0, 1]; ln(num_classes) is the max entropy
          return uncertainty

      def get_noisy_input(self, data):
          return data + torch.clamp(torch.randn_like(data) * 0.1, -0.2, 0.2)

      def sigmoid_rampup(self, current, rampup_length):
          """Exponential rampup from https://arxiv.org/abs/1610.02242"""
          if rampup_length == 0:
              return 1.0
          else:
              current = np.clip(current, 0.0, rampup_length)
              phase = 1.0 - current / rampup_length
              return float(np.exp(-5.0 * phase * phase))

      # def save_MRI(self, img, path):
      #     ref_img = sitk.ReadImage("oasis_atlas.nii.gz")
      #     img = sitk.GetImageFromArray(img[0, 0, ...].cpu().detach().numpy())
      #     img.SetOrigin(ref_img.GetOrigin())
      #     img.SetDirection(ref_img.GetDirection())
      #     img.SetSpacing(ref_img.GetSpacing())
      #     sitk.WriteImage(img, path)
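
For readers who want to see how the two ramped quantities in the trainer above behave, here is a standalone sketch using only the standard library. It re-implements the ramp-up formula outside the class, and `normalized_entropy` is a hypothetical helper added only to show what a threshold on normalized entropy means; the constants (150 epochs, base weight 10.0, base threshold 0.75, 36 classes) are taken from the posted code.

```python
import math


def sigmoid_rampup(current: float, rampup_length: float) -> float:
    """exp(-5 * (1 - t)^2) with t = current / rampup_length clipped to [0, 1]."""
    if rampup_length == 0:
        return 1.0
    t = min(max(current, 0.0), rampup_length) / rampup_length
    return math.exp(-5.0 * (1.0 - t) ** 2)


def normalized_entropy(probs):
    """Entropy of a class distribution divided by its maximum, ln(number of classes)."""
    h = -sum(p * math.log(p + 1e-6) for p in probs)
    return h / math.log(len(probs))


num_epochs = 150
base_weight, base_threshold = 10.0, 0.75
ramp_epochs = num_epochs // 2  # the ramp finishes halfway through training

for epoch in (0, 25, 50, 75, 149):
    r = sigmoid_rampup(epoch, ramp_epochs)
    weight = base_weight * r                 # consistency weight grows from ~0 to 10
    threshold = base_threshold + 0.25 * r    # threshold grows from ~0.75 to 1.0
    print(f"epoch {epoch:3d}: weight={weight:.3f} threshold={threshold:.3f}")

# A uniform prediction over 36 classes has normalized entropy ~1 (maximally uncertain),
# a one-hot prediction has ~0, so a threshold near 1.0 keeps almost every voxel.
uniform = normalized_entropy([1.0 / 36] * 36)
one_hot = normalized_entropy([1.0] + [0.0] * 35)
```

The consequence is that early in training only low-entropy voxels contribute, with a small weight, and by the halfway point nearly all voxels contribute at full weight.
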

lichen14 commented 4 weeks ago

Hi @LIKP0, thanks for your reply.
Your code looks like a separate file, such as `nnUNetTrainerMT.py`, that is then run through `nnUNetv2_train`.
I'll try to test it and, if there's any progress, I'll @ you again.

By the way, which dataset did you use in your experiment to get the <1% improvement?

Thanks.

LIKP0 commented 4 weeks ago

@lichen14 My experiment is based on OASIS-I, and I recommend this preprocessed dataset. I generated many pseudo labels with registration methods, then used the original labels as high-quality labels and the pseudo labels as low-quality ones. The segmentation covers 35 brain structures, but I have forgotten the exact Dice scores.