Hi @LIKP0,
thanks for your interest in nnU-Net. You are correct; implementing Mean Teacher in nnU-Net is not as straightforward as it is for simpler U-Net implementations.
However, I believe it is definitely possible. So, I will give you some guidance on which changes need to be made and where you can learn how to implement them yourself. Of course, I might overlook some things, but I hope this helps you get started on your project.
For the full Mean Teacher you would need a separate nnUNetTrainer (handling the training logic and the EMA model), a custom dataloader (loading the unannotated data alongside the annotated data), as well as uncertainty estimation.
For the nnU-Net Trainer, I would suggest reading the documentation on how to extend nnU-Net.
Regarding the dataloader, I would suggest reading the code for the dataloaders, as this is a pretty specific change and not covered in the documentation.
Monte Carlo Dropout is currently not in the nnU-Net model. You need to implement this yourself following the instructions in the extending nnU-Net documentation.
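If it helps, here is a rough, untested sketch of the general idea (it assumes you have already added dropout layers to the architecture and that deep supervision is disabled, so the forward pass returns a single logits tensor; `enable_mc_dropout` and `mc_dropout_predict` are just illustrative names):

```python
import torch
import torch.nn as nn

def enable_mc_dropout(model: nn.Module) -> None:
    # Keep the network in eval mode overall, but switch the dropout layers
    # back to train mode so they stay stochastic at inference time.
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            m.train()

@torch.no_grad()
def mc_dropout_predict(model: nn.Module, x: torch.Tensor, T: int = 8) -> torch.Tensor:
    # Average the softmax output over T stochastic forward passes.
    enable_mc_dropout(model)
    probs = torch.stack([torch.softmax(model(x), dim=1) for _ in range(T)])
    return probs.mean(dim=0)
```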
Best regards,
Carsten
Update: There is no EMA model in standard nnU-Net.
Thanks for your great advice! Especially for the reminder about the EMA model.
I'm trying to customize an nnUNetTrainer class combining the features in training/nnUNetTrainer/variants/. I think the documentation on how to extend nnU-Net is really useful.
Thanks again for your great work! I come from nnU-Net V1 and I can really see the significant improvements in V2. It's really valuable for research!
Give me a few days and I will share my results here.
Hi @sten2lu ,
I just want to quickly confirm that nnU-Net doesn't need a nonlinearity at the end of the network anymore.
For example, in the following code, output and target come directly from the network, so they don't need softmax applied again, right?
```python
def UA_consistency_loss(self, output, target, uncertainty): # Uncertainty-Aware Consistency Loss
# Your architecture should NOT apply any nonlinearities at the end (softmax, sigmoid etc). nnU-Net does that?
# output_softmax = F.softmax(output, dim=1)
# target_softmax = F.softmax(target, dim=1)
mse_loss = (output - target) ** 2
mask = (uncertainty < self.UA_threshold).float()
consistency_loss = torch.sum(mask * mse_loss) / (2 * torch.sum(mask) + 1e-16)
    return self.UA_consistency_weight * consistency_loss
```
Thanks in advance! Have a good day!
Hi @LIKP0, nnU-Net's standard forward pass returns the logits. Therefore, you will need to apply softmax yourself to obtain the class probabilities. In nnU-Net, the softmax is applied during postprocessing.
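For example (a minimal sketch; with deep supervision enabled, the forward pass returns a list of tensors, full resolution first):

```python
output = network(data)                   # list of tensors, one per deep supervision scale
probs = torch.softmax(output[0], dim=1)  # softmax over the channel (class) dimension
```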
Best regards, Carsten
Thank you @sten2lu, and thank you nnU-Net! I have successfully completed my task and I'm really grateful to all of you!
Currently, there is no EMA model in the standard nnU-Net Trainer. This has been updated in my previous answer.
I am sorry for any resulting confusion.
Best regards, Carsten
Hi @LIKP0 , could you please share the code that you mentioned above? I am also curious about Mean Teacher in nnU-Net.
Thanks.
Hi @lichen14, my code is as follows. I can't promise it is fully correct; if you find any bugs, please tell me.
In my experiment, the improvement from MT is not that obvious (<1%). I think nnU-Net itself may simply be robust enough.
```python
import os
import shutil
import sys
import math
import numpy as np
from typing import List, Tuple
import torch
from torch import autocast
import torch.nn.functional as F
from torch import distributed as dist
from batchgenerators.utilities.file_and_folder_operations import join, isfile, save_json, maybe_mkdir_p
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
LimitedLenWrapper
from nnunetv2.training.dataloading.data_loader_3d import nnUNetDataLoader3D
from nnunetv2.training.dataloading.nnunet_dataset import nnUNetDataset
from nnunetv2.training.dataloading.utils import unpack_dataset
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.helpers import dummy_context, empty_cache
from nnunetv2.utilities.collate_outputs import collate_outputs
class nnUNetTrainerMT(nnUNetTrainer):
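    # Mean Teacher sketch: the student is self.network, the teacher is self.ema_model.
    # The teacher never receives gradients: its forward passes run under torch.no_grad()
    # and its weights are updated after every optimizer step via an exponential moving
    # average (EMA) of the student weights (see update_ema_variables below).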
def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
device: torch.device = torch.device('cuda')):
"""used for debugging plans etc
When debug using the local port, you should change the paths.py file:
# nnUNet_raw = os.environ.get('nnUNet_raw')
# nnUNet_preprocessed = os.environ.get('nnUNet_preprocessed')
# nnUNet_results = os.environ.get('nnUNet_results')
nnUNet_raw = "nnUNet_raw/nnUNet_raw"
nnUNet_preprocessed = "nnUNet_raw/nnUNet_preprocessed"
nnUNet_results = "nnUNet_raw/nnUNet_results_MT"
"""
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
self.num_iterations_per_epoch = 100
self.num_epochs = 150
self.extra_dataset_folder = 'nnUNet_raw/nnUNet_preprocessed/Dataset062_Atlas_only/nnUNetPlans_3d_fullres/'
self.extra_dataset_src = 'nnUNet_raw/nnUNet_raw/Dataset062_Atlas_only/imagesTr/'
self.dataloader_extra = None
self.HQ_loss_weight = 1.0
self.ema_model = None
self.ema_decay = 0.99
self.UA_base_consistency_weight = 10.0
self.UA_base_threshold = 0.75
self.UA_consistency_weight = None # ramp up in num_epochs step from 0 to base weight
self.UA_threshold = None # ramp up in num_epochs step from 0.75 to 1
def on_train_start(self):
if not self.was_initialized:
self.initialize()
maybe_mkdir_p(self.output_folder)
# make sure deep supervision is on in the network
self.set_deep_supervision_enabled(True)
self.print_plans()
empty_cache(self.device)
# maybe unpack
if self.unpack_dataset and self.local_rank == 0:
self.print_to_log_file('unpacking dataset...')
unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
unpack_dataset(self.extra_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
self.print_to_log_file('unpacking done...')
if self.is_ddp:
dist.barrier()
# dataloaders must be instantiated here because they need access to the training data which may not be present
# when doing inference
self.dataloader_train, self.dataloader_val, self.dataloader_extra = self.get_dataloaders()
# copy plans and dataset.json so that they can be used for restoring everything we need for inference
save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)
# we don't really need the fingerprint but its still handy to have it with the others
shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
join(self.output_folder_base, 'dataset_fingerprint.json'))
# produces a pdf in output folder
self.plot_network_architecture()
self._save_debug_information()
# print(f"batch size: {self.batch_size}")
# print(f"oversample: {self.oversample_foreground_percent}")
def on_train_end(self):
# dirty hack because on_epoch_end increments the epoch counter and this is executed afterwards.
# This will lead to the wrong current epoch to be stored
self.current_epoch -= 1
self.save_checkpoint(join(self.output_folder, "checkpoint_final.pth"))
self.current_epoch += 1
# now we can delete latest
if self.local_rank == 0 and isfile(join(self.output_folder, "checkpoint_latest.pth")):
os.remove(join(self.output_folder, "checkpoint_latest.pth"))
# shut down dataloaders
old_stdout = sys.stdout
with open(os.devnull, 'w') as f:
sys.stdout = f
if self.dataloader_train is not None:
self.dataloader_train._finish()
if self.dataloader_val is not None:
self.dataloader_val._finish()
if self.dataloader_extra is not None:
self.dataloader_extra._finish()
sys.stdout = old_stdout
empty_cache(self.device)
self.print_to_log_file("Training done.")
def get_tr_and_val_datasets(self):
# create dataset split
tr_keys, val_keys = self.do_split()
# load the datasets for training and validation. Note that we always draw random samples so we really don't
# care about distributing training cases across GPUs.
dataset_tr = nnUNetDataset(self.preprocessed_dataset_folder, tr_keys,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
num_images_properties_loading_threshold=0)
dataset_val = nnUNetDataset(self.preprocessed_dataset_folder, val_keys,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
num_images_properties_loading_threshold=0)
# Learn from do_split to get extra_keys
extra_keys = sorted([i[:4] for i in os.listdir(self.extra_dataset_src)])
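        # note: i[:4] assumes 4-character case identifiers in the raw imagesTr filenames; adjust for your naming scheme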
dataset_extra = nnUNetDataset(self.extra_dataset_folder, extra_keys,
folder_with_segs_from_previous_stage=self.folder_with_segs_from_previous_stage,
num_images_properties_loading_threshold=0)
return dataset_tr, dataset_val, dataset_extra
def get_plain_dataloaders(self, initial_patch_size: Tuple[int, ...], dim: int):
dataset_tr, dataset_val, dataset_extra = self.get_tr_and_val_datasets()
dl_tr = nnUNetDataLoader3D(dataset_tr, self.batch_size,
initial_patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None)
dl_val = nnUNetDataLoader3D(dataset_val, self.batch_size,
self.configuration_manager.patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None)
dl_extra = nnUNetDataLoader3D(dataset_extra, self.batch_size,
self.configuration_manager.patch_size,
self.configuration_manager.patch_size,
self.label_manager,
oversample_foreground_percent=self.oversample_foreground_percent,
sampling_probabilities=None, pad_sides=None)
return dl_tr, dl_val, dl_extra
def get_dataloaders(self):
# we use the patch size to determine whether we need 2D or 3D dataloaders. We also use it to determine whether
# we need to use dummy 2D augmentation (in case of 3D training) and what our initial patch size should be
patch_size = self.configuration_manager.patch_size
dim = len(patch_size)
# needed for deep supervision: how much do we need to downscale the segmentation targets for the different
# outputs?
deep_supervision_scales = self._get_deep_supervision_scales()
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
self.configure_rotation_dummyDA_mirroring_and_inital_patch_size()
# training pipeline
tr_transforms = self.get_training_transforms(
patch_size, rotation_for_DA, deep_supervision_scales, mirror_axes, do_dummy_2d_data_aug,
order_resampling_data=3, order_resampling_seg=1,
use_mask_for_norm=self.configuration_manager.use_mask_for_norm,
is_cascaded=self.is_cascaded, foreground_labels=self.label_manager.foreground_labels,
regions=self.label_manager.foreground_regions if self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)
# validation pipeline
val_transforms = self.get_validation_transforms(deep_supervision_scales,
is_cascaded=self.is_cascaded,
foreground_labels=self.label_manager.foreground_labels,
regions=self.label_manager.foreground_regions if
self.label_manager.has_regions else None,
ignore_label=self.label_manager.ignore_label)
dl_tr, dl_val, dl_extra = self.get_plain_dataloaders(initial_patch_size, dim)
allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
mt_gen_extra = SingleThreadedAugmenter(dl_extra, tr_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr, transform=tr_transforms,
num_processes=allowed_num_processes, num_cached=6, seeds=None,
pin_memory=self.device.type == 'cuda', wait_time=0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
transform=val_transforms, num_processes=max(1, allowed_num_processes // 2),
num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
wait_time=0.02)
mt_gen_extra = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_extra,
transform=tr_transforms,
num_processes=allowed_num_processes, num_cached=6, seeds=None,
pin_memory=self.device.type == 'cuda', wait_time=0.02)
return mt_gen_train, mt_gen_val, mt_gen_extra
def run_training(self):
self.on_train_start()
for epoch in range(self.current_epoch, self.num_epochs):
self.on_epoch_start()
self.on_train_epoch_start()
train_outputs = []
for batch_id in range(self.num_iterations_per_epoch):
train_outputs.append(self.train_step(next(self.dataloader_train), next(self.dataloader_extra)))
self.on_train_epoch_end(train_outputs)
with torch.no_grad():
self.on_validation_epoch_start()
val_outputs = []
for batch_id in range(self.num_val_iterations_per_epoch):
val_outputs.append(self.validation_step(next(self.dataloader_val)))
self.on_validation_epoch_end(val_outputs)
self.on_epoch_end()
self.on_train_end()
def initialize(self):
super().initialize()
self.ema_model = self.build_network_architecture(self.plans_manager, self.dataset_json,
self.configuration_manager,
self.num_input_channels,
enable_deep_supervision=True).to(self.device)
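        # Note: the teacher starts from a fresh random initialization here. Copying the
        # student weights first (self.ema_model.load_state_dict(self.network.state_dict()))
        # would be a common alternative; update_ema_variables partially compensates by
        # using the true average over the first steps.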
def on_train_epoch_start(self):
super().on_train_epoch_start()
self.ema_model.train()
self.UA_consistency_weight = self.UA_base_consistency_weight * self.sigmoid_rampup(self.current_epoch,
self.num_epochs // 2)
self.UA_threshold = self.UA_base_threshold + 0.25 * self.sigmoid_rampup(self.current_epoch,
self.num_epochs // 2)
self.print_to_log_file(f'Epoch {self.current_epoch} UA_consistency_weight: {self.UA_consistency_weight}')
self.print_to_log_file(f'Epoch {self.current_epoch} UA_threshold: {self.UA_threshold}')
def train_step(self, batch: dict, extra_batch: dict) -> dict:
data = batch['data']
target = batch['target']
extra_data = extra_batch['data']
extra_target = extra_batch['target']
data = data.to(self.device, non_blocking=True)
if isinstance(target, list):
target = [i.to(self.device, non_blocking=True) for i in target]
else:
target = target.to(self.device, non_blocking=True)
extra_data = extra_data.to(self.device, non_blocking=True)
if isinstance(extra_target, list):
extra_target = [i.to(self.device, non_blocking=True) for i in extra_target]
else:
extra_target = extra_target.to(self.device, non_blocking=True)
self.optimizer.zero_grad(set_to_none=True)
with autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
# List: 5, [(2, 36, 128, 128, 128), (2, 36, 64, 64, 64), ..., (2, 36, 32, 32, 32)]
# There are 36 labels including background, so output channel is 36
output = self.network(data)
l = self.loss(output, target) # nnUNet base loss: compute loss on 5 scale and aggregate
with torch.no_grad():
ema_output = self.ema_model(self.get_noisy_input(data)) # only use top scale
uncertainty = self.MC_uncertainty(self.ema_model, data) # forward T times on ema_model to get uncertainty
l2 = self.UA_consistency_loss(output[0], ema_output[0], uncertainty)
extra_output = self.network(extra_data)
l3 = self.loss(extra_output, extra_target)
l3 = self.HQ_loss_weight * l3
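            # total objective: supervised loss (l) + uncertainty-aware consistency against
            # the EMA teacher (l2) + weighted supervised loss on the extra batch (l3)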
total_loss = l + l2 + l3
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
self.grad_scaler.unscale_(self.optimizer)
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
total_loss.backward()
torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
self.optimizer.step()
# Update EMA model one time per iteration
self.update_ema_variables(self.network, self.ema_model, self.current_epoch)
return {'loss': total_loss.detach().cpu().numpy(), 'base_loss': l.detach().cpu().numpy(),
'UA_loss': l2.detach().cpu().numpy(), 'HQ_loss': l3.detach().cpu().numpy()}
def on_train_epoch_end(self, train_outputs: List[dict]):
outputs = collate_outputs(train_outputs)
loss_here = np.mean(outputs['loss'])
base_loss = np.mean(outputs['base_loss'])
UA_loss = np.mean(outputs['UA_loss'])
HQ_loss = np.mean(outputs['HQ_loss'])
self.logger.log('train_losses', loss_here, self.current_epoch)
self.print_to_log_file(f'Epoch {self.current_epoch} base_losses: ', base_loss)
self.print_to_log_file(f'Epoch {self.current_epoch} UA_consistency_loss: ', UA_loss)
self.print_to_log_file(f'Epoch {self.current_epoch} HQ_loss: ', HQ_loss)
def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): # No mirroring
rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes = \
super().configure_rotation_dummyDA_mirroring_and_inital_patch_size()
mirror_axes = None
self.inference_allowed_mirroring_axes = None
return rotation_for_DA, do_dummy_2d_data_aug, initial_patch_size, mirror_axes
def update_ema_variables(self, model, ema_model, global_step):
        # Use the true average until the exponential moving average becomes more accurate
alpha = min(1 - 1 / (global_step + 1), self.ema_decay)
for ema_param, param in zip(ema_model.parameters(), model.parameters()):
ema_param.data.mul_(alpha).add_(param.data, alpha=1 - alpha)
def UA_consistency_loss(self, output, target, uncertainty): # Uncertainty-Aware Consistency Loss
output_softmax = F.softmax(output, dim=1)
target_softmax = F.softmax(target, dim=1)
mse_loss = (output_softmax - target_softmax) ** 2
mask = (uncertainty < self.UA_threshold).float()
consistency_loss = torch.sum(mask * mse_loss) / (2 * torch.sum(mask) + 1e-16) # 2 for batch size 2
return self.UA_consistency_weight * consistency_loss
def MC_uncertainty(self, ema_model, data):
T = 6
data_r = data.repeat(2, 1, 1, 1, 1) # (2, 1, 128, 128, 128) ==> (4, 1, 128, 128, 128)
stride = data_r.shape[0] // 2 # 2
        preds = torch.zeros(
            [stride * T, 36, data.shape[2], data.shape[3], data.shape[4]],
            device=self.device)  # (12, 36, 128, 128, 128); the 36 output channels are hardcoded for my dataset
for i in range(T // 2):
ema_inputs = self.get_noisy_input(data_r)
with torch.no_grad():
preds[2 * stride * i:2 * stride * (i + 1)] = ema_model(ema_inputs)[0] # only use top scale
preds = F.softmax(preds, dim=1)
preds = preds.reshape(T, stride, 36, data.shape[2], data.shape[3], data.shape[4])
preds = torch.mean(preds, dim=0)
uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True)
        uncertainty = uncertainty / math.log(36)  # normalize entropy to [0, 1]; ln(36) is the maximum entropy for 36 classes
return uncertainty
def get_noisy_input(self, data):
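        # additive Gaussian noise (sigma = 0.1), clipped to [-0.2, 0.2], as used for the
        # teacher inputs in the original Mean Teacher / UA-MT setup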
return data + torch.clamp(torch.randn_like(data) * 0.1, -0.2, 0.2)
def sigmoid_rampup(self, current, rampup_length):
"""Exponential rampup from https://arxiv.org/abs/1610.02242"""
if rampup_length == 0:
return 1.0
else:
current = np.clip(current, 0.0, rampup_length)
phase = 1.0 - current / rampup_length
return float(np.exp(-5.0 * phase * phase))
# def save_MRI(self, img, path):
# ref_img = sitk.ReadImage("oasis_atlas.nii.gz")
# img = sitk.GetImageFromArray(img[0, 0, ...].cpu().detach().numpy())
# img.SetOrigin(ref_img.GetOrigin())
# img.SetDirection(ref_img.GetDirection())
# img.SetSpacing(ref_img.GetSpacing())
    #     sitk.WriteImage(img, path)
```
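In case it helps: I keep this in a separate file (e.g. `nnUNetTrainerMT.py`) next to the other trainers and select it with the `-tr` flag, roughly like this (dataset ID, configuration, and fold are placeholders for your own setup):

```bash
nnUNetv2_train DATASET_ID 3d_fullres FOLD -tr nnUNetTrainerMT
```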
Hi @LIKP0 , thanks for your reply. Your code looks like it lives in a separate file, such as `nnUNetTrainerMT.py`, and is then run through `nnUNetv2_train`. I'll try to test it, and if there's any progress, I'll continue to @ you~
By the way, I wonder which dataset was used in your experiment to get a <1% improvement?
Thanks.
@lichen14 My experiment is based on OASIS-I, and I recommend this preprocessed dataset. I generated many pseudo labels with registration methods, then used the original labels as high-quality labels and the pseudo labels as low-quality ones. The segmentation is performed on 35 brain structures, but I forget the exact Dice scores.
Hi everyone, I'm trying to combine the semi-supervised mean teacher (MT) architecture with nnU-Net, which means replacing the U-Net architecture with nnU-Net for the student and teacher models in the following figure.
I know it's simple to apply MT to a plain U-Net, but how do I apply it to nnU-Net? For example, how do I do the EMA update properly in nnU-Net? I really have no idea where to start in the huge nnU-Net codebase...
Could anyone give some advice? Thanks in advance!
The picture is Fig. 1 in: Yu L, Wang S, Li X, et al. Uncertainty-Aware Self-Ensembling Model for Semi-Supervised 3D Left Atrium Segmentation. Medical Image Computing and Computer Assisted Intervention (MICCAI 2019), Shenzhen, China, October 13-17, 2019, Proceedings, Part II. Springer International Publishing, 2019: 605-613.