Open noneedanick opened 2 months ago
Hey, so all you need to do to test different architectures is to overwrite build_network_architecture, plus any additional changes you'd like to make. Basically, what you did is exactly correct. 'Quick and dirty' means that your architecture will not be considered during experiment planning. It is expected to be able to deal with whatever patch and batch size nnU-Net gives it. Ideally it also uses the downsampling steps configured by nnU-Net, but that is optional. Best, Fabian
It just gave an error about this function, so I commented out some parts of it since I won't use deep supervision. Now it's working just fine! Thanks a lot!
NOTE: I am also a little inexperienced with staticmethod; I believe it is not common to use 'self' in a staticmethod. I just added it so I could use the incoming patch_size as the input image size, but those values are defined within the class and are only reachable through the self argument.
def set_deep_supervision_enabled(self, enabled: bool):
"""
This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
chances you need to change this as well!
"""
# Resolve the module to act on: under DDP the network is reached through `.module`.
if self.is_ddp:
mod = self.network.module
else:
mod = self.network
# NOTE: the three lines below are the parts that were commented out because the replacement
# network is trained without deep supervision. With them disabled, `mod` is computed but never
# used, so this method no longer changes anything on the network and `enabled` is ignored.
#if isinstance(mod, OptimizedModule):
#mod = mod._orig_mod
#mod.decoder.deep_supervision = enabled
Also, for researchers who are searching for implementations of the SwinUNETR and UNETR segmentation models within the nnUNet framework, here are my latest classes.
The classes were constructed based on the following notebook, which defines these structures: https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb#scrollTo=xQn18qtvZChG
UNETR:
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from monai.networks.nets import UNETR
from torch.optim import Adam, AdamW
import torch
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
class nnUNetTrainerUNETR(nnUNetTrainer):
    """nnU-Net trainer that replaces the default network with MONAI's UNETR.

    Deep supervision is switched off, and the learning rate, weight decay and
    number of epochs are overridden in ``__init__``. The network is optimised
    with AdamW and nnU-Net's polynomial learning-rate schedule.
    """

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        """Forward all arguments to ``nnUNetTrainer`` and override the training hyperparameters."""
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-4
        self.weight_decay = 1e-5
        self.num_epochs = 200

    def set_deep_supervision_enabled(self, enabled: bool):
        """Do nothing: this trainer never uses deep supervision.

        The implementation this overrides toggled ``decoder.deep_supervision``
        on the default nnU-Net network. That assignment was disabled for this
        architecture, so the override is an explicit no-op and ``enabled`` is
        ignored.
        """
        return None

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):
        """Build a UNETR with the requested number of input and output channels.

        The input size is taken from ``arch_init_kwargs['patch_size']`` when the
        plans file provides it (add it to ``arch_kwargs`` of the configuration
        being run). Otherwise the previous hard-coded ``(64, 96, 96)`` is used,
        so existing plans keep working unchanged.

        ``architecture_class_name``, ``arch_init_kwargs_req_import`` and
        ``enable_deep_supervision`` are accepted for interface compatibility
        but are not used.
        """
        # This is a staticmethod, so the trainer's configuration is not reachable
        # here; the plans-provided kwargs are the only channel for the patch size.
        patch_size = tuple((arch_init_kwargs or {}).get("patch_size", (64, 96, 96)))

        # NOTE(review): newer MONAI releases reportedly rename `pos_embed` to
        # `proj_type` -- confirm against the installed MONAI version.
        model = UNETR(
            in_channels=num_input_channels,
            out_channels=num_output_channels,
            img_size=patch_size,
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed="perceptron",
            norm_name="instance",
            res_block=True,
            dropout_rate=0.0,
        )
        return model

    def configure_optimizers(self):
        """Return ``(optimizer, lr_scheduler)``: AdamW (amsgrad) with ``PolyLRScheduler``."""
        optimizer = AdamW(
            self.network.parameters(),
            lr=self.initial_lr,
            weight_decay=self.weight_decay,
            amsgrad=True,
        )
        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
        return optimizer, lr_scheduler
SwinUNETR:
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from monai.networks.nets import SwinUNETR
from torch.optim import Adam, AdamW
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import torch
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.label_handling.label_handling import convert_labelmap_to_one_hot, determine_num_input_channels
from nnunetv2.training.loss.compound_losses import DC_and_BCE_loss, DC_and_CE_loss
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1
class nnUNetTrainerSwinUNETR(nnUNetTrainer):
    """nnU-Net trainer that replaces the default network with MONAI's SwinUNETR.

    Deep supervision is switched off, and the learning rate, weight decay and
    number of epochs are overridden in ``__init__``. Training uses a Dice-only
    loss, AdamW and a cosine-annealing learning-rate schedule.
    """

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        """Forward all arguments to ``nnUNetTrainer`` and override the training hyperparameters."""
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-4
        self.weight_decay = 1e-5
        self.num_epochs = 200

    def set_deep_supervision_enabled(self, enabled: bool):
        """Do nothing: this trainer never uses deep supervision.

        The implementation this overrides toggled ``decoder.deep_supervision``
        on the default nnU-Net network. That assignment was disabled for this
        architecture, so the override is an explicit no-op and ``enabled`` is
        ignored.
        """
        return None

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):
        """Build a SwinUNETR with the requested number of input and output channels.

        The input size is taken from ``arch_init_kwargs['patch_size']`` when the
        plans file provides it (add it to ``arch_kwargs`` of the configuration
        being run). Otherwise the previous hard-coded ``(64, 96, 96)`` is used,
        so existing plans keep working unchanged.

        ``architecture_class_name``, ``arch_init_kwargs_req_import`` and
        ``enable_deep_supervision`` are accepted for interface compatibility
        but are not used.
        """
        # This is a staticmethod, so the trainer's configuration is not reachable
        # here; the plans-provided kwargs are the only channel for the patch size.
        patch_size = tuple((arch_init_kwargs or {}).get("patch_size", (64, 96, 96)))

        # NOTE(review): newer MONAI releases reportedly deprecate `img_size` for
        # SwinUNETR -- confirm against the installed MONAI version.
        model = SwinUNETR(
            img_size=patch_size,
            in_channels=num_input_channels,
            out_channels=num_output_channels,
            feature_size=48,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
        )
        return model

    def _build_loss(self):
        """Build the training loss: a soft Dice loss, optionally wrapped for deep supervision.

        For region-based labels the loss applies a sigmoid and includes the
        background channel; otherwise it applies a softmax over dim 1 and
        excludes the background.
        """
        has_regions = self.label_manager.has_regions
        loss = MemoryEfficientSoftDiceLoss(
            apply_nonlin=torch.sigmoid if has_regions else softmax_helper_dim1,
            batch_dice=self.configuration_manager.batch_dice,
            do_bg=has_regions,
            smooth=1e-5,
            ddp=self.is_ddp,
        )

        # __init__ sets enable_deep_supervision to False, so this branch only
        # runs if the flag is changed afterwards (e.g. by a subclass).
        if self.enable_deep_supervision:
            deep_supervision_scales = self._get_deep_supervision_scales()

            # Each output gets a weight that halves as the resolution decreases,
            # so higher-resolution outputs contribute more to the loss.
            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
            # The last (lowest-resolution) output is excluded from the loss.
            weights[-1] = 0

            # Normalize the remaining weights so that they sum to 1.
            weights = weights / weights.sum()
            loss = DeepSupervisionWrapper(loss, weights)
        return loss

    def configure_optimizers(self):
        """Return ``(optimizer, lr_scheduler)``: AdamW (amsgrad) with cosine annealing over ``num_epochs``."""
        optimizer = AdamW(
            self.network.parameters(),
            lr=self.initial_lr,
            weight_decay=self.weight_decay,
            amsgrad=True,
        )
        lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
        return optimizer, lr_scheduler
Hey, thanks for sharing!
Three things to be aware of:
We have our own implementations for SwinUNetr internally that we were planning to release soon. Will be interesting to compare to yours. How has your experience with these architectures been so far? We haven't found them useful so far. They just can't keep up with a standard UNet or a UNet with residual encoder.
Best, Fabian
Thanks Fabian for your comments!
I couldn't find a proper way to automatically infer the patch size from the plans without changing the build_network_architecture function signature :(
So far the models in my comparisons have been close to each other, with very similar performance, but it looks like the 3D U-Net is still going to be the winner :laughing: . Right now I am trying to deal with hardware issues to speed up the process. It takes more than 300 seconds (the same for each model) to train an epoch. Also, on the validation side, both my custom classes and the unchanged 3d_fullres training take more than 20 minutes for each validation case. I am trying to lower these timings because it feels like an eternity to me :laughing:.
Will share my trials and configurations here soon !
Best; Murat
Hah you are right the signature is
def build_network_architecture(architecture_class_name: str,
arch_init_kwargs: dict,
arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
num_input_channels: int,
num_output_channels: int,
enable_deep_supervision: bool = True) -> nn.Module:
I just read your code and forgot about how it is supposed to be. Easiest way to achieve what you need would be to change the arch_init_kwargs in the plans file so that the patch size is in there as well. Would need to be done as part of the experiment planner.
Those epoch times sound horrible, take a look here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/benchmarking.md There should be information for you how to approach debugging. Best, Fabian
Hey, how are you progressing? Is everything working now?
Hey Fabian! Lately I have been preparing to start a postdoc position in the US, so I haven't been able to share my final results here yet. It is on my mind, though, and I am waiting for a proper time to share them (I am currently gathering my stuff for the full move to the US). Nevertheless, I dealt with the epoch time issue by increasing the resampling pixel spacing (the automatically defined values were causing a huge output array). My custom approach resulted in higher Dice scores with nnUNet (as expected 😆 ), and SwinUNETR showed nearly similar performance, but I need to test it again with more optimized parameters (such as patch size, resampling, etc.).
Thanks for your interest by the way, I am honored ❤️
How can I read the patch size from the ConfigurationManager? Static methods cannot read self variables. Writing it like this results in an error!
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import functools
from torch.distributions.uniform import Uniform
import numpy as np
from timm.models.layers import DropPath, trunc_normal_
# from torch.nn.init import xavier_uniform_, constant_, normal_
# import copy
# import math
from monai.networks.nets import SwinUNETR
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
# Custom nnU-Net trainer that builds a MONAI SwinUNETR instead of the default network.
# NOTE(review): this is the failing attempt under discussion; see the notes in
# build_network_architecture for the lines that raise.
class nnUNetTrainer_SwinUNETR(nnUNetTrainer):
def __init__(
self,
plans: dict,
configuration: str,
fold: int,
dataset_json: dict,
unpack_dataset: bool = True,
device: torch.device = torch.device("cuda"),
):
super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
# Overrides applied after the base-class constructor has run.
self.enable_deep_supervision = False
self.initial_lr = 1e-2
# self.weight_decay = 3e-5
# self.num_epochs = 1000
self.num_epochs = 500
# self.num_epochs = 50
# self.patch_size =
def set_deep_supervision_enabled(self, enabled: bool):
"""
This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
chances you need to change this as well!
"""
# Resolve the module to act on: under DDP the network is reached through `.module`.
if self.is_ddp:
mod = self.network.module
else:
mod = self.network
# With the lines below commented out, `mod` is computed but never used, so this method
# does not change anything on the network and `enabled` is ignored.
# if isinstance(mod, OptimizedModule):
# mod = mod._orig_mod
# mod.deep_supervised = enabled
@staticmethod
def build_network_architecture(architecture_class_name,
arch_init_kwargs,
arch_init_kwargs_req_import,
num_input_channels,
num_output_channels,
enable_deep_supervision):
# NOTE(review): `trainer` is neither a parameter of this staticmethod nor defined by the
# imports shown, so this line is expected to raise NameError. The local `patch_size` it
# assigns is also never used below.
patch_size = trainer.configuration_manager.patch_size
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
# NOTE(review): no `patch_size` class attribute is defined in this class (the only mention
# is the commented-out `self.patch_size` in __init__), so this lookup presumably raises
# AttributeError unless the base class provides one -- confirm. `arch_init_kwargs` is the
# only per-configuration input this staticmethod receives and is not used here.
model = SwinUNETR(img_size=nnUNetTrainer_SwinUNETR.patch_size, in_channels=num_input_channels, out_channels=num_output_channels)
# Re-initialise the model's weights with InitWeights_He(1e-2).
model.apply(InitWeights_He(1e-2))
return model
The function build_network_architecture is a static method, so it cannot access self.configuration_manager.patch_size!
Thank you for your reply!
Use the arch_init_kwargs
dictionary and add the patch size to that
Use the
arch_init_kwargs
dictionary and add the patch size to that
Sorry, I can't find arch_init_kwargs. Can you give me a minimal implementation? Thank you!
You can find it in the plans file as part of the configuration you are trying to run. It's called arch_kwargs in there. Just add the patch size there and it will be available in build_network_architecture as arch_init_kwargs['patch_size'].
You find it in the plans file as part of the configuration you are trying to run. It's called
arch_kwargs
in there. Just add the patch size there and then it will be available in build_network_architecture as arch_init_kwargs['patch_size']
I succeeded! I share the complete process! Open nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py Add the following code on line 33:
...
class ConfigurationManager(object):
def __init__(self, configuration_dict: dict):
self.configuration = configuration_dict
self.configuration["architecture"]["arch_kwargs"]["patch_size"] = self.configuration["patch_size"]  # <-- added line
# backwards compatibility
if 'architecture' not in self.configuration.keys():
...
Then you can introduce patch_size in the custom network, and other parameters can also be operated in this way.
@staticmethod
def build_network_architecture(architecture_class_name,
                               arch_init_kwargs,
                               arch_init_kwargs_req_import,
                               num_input_channels,
                               num_output_channels,
                               enable_deep_supervision):
    """Build a SwinUNETR whose input size comes from the plans' architecture kwargs.

    ``arch_init_kwargs`` must contain a ``'patch_size'`` entry; it is passed to
    SwinUNETR as ``img_size``. ``architecture_class_name``,
    ``arch_init_kwargs_req_import`` and ``enable_deep_supervision`` are accepted
    for interface compatibility but are not used.

    Raises:
        KeyError: if ``arch_init_kwargs`` has no ``'patch_size'`` entry.
    """
    from dynamic_network_architectures.initialization.weight_init import InitWeights_He

    if 'patch_size' not in arch_init_kwargs:
        raise KeyError(
            "arch_init_kwargs has no 'patch_size' entry; add the patch size to "
            "the architecture's arch_kwargs so it is passed through to this method"
        )

    # The patch size is read from arch_init_kwargs because a staticmethod has
    # no access to the trainer instance or its configuration manager.
    model = SwinUNETR(
        img_size=arch_init_kwargs['patch_size'],
        in_channels=num_input_channels,
        out_channels=num_output_channels,
    )
    # Re-initialise the model's weights with InitWeights_He(1e-2).
    model.apply(InitWeights_He(1e-2))
    return model
Thank you again for your help! @FabianIsensee
Hey @Yanfeng-Zhou glad to hear it works now. It would be cleaner to add the patch size to the arch_kwargs in the plans file because that would not affect the rest of nnU-Net. Now you will always have this added even if you don't need it
I am currently trying to train different architectures within nnUNet platform to compare with nnUNet baseline architectures. I believe it should be convenient to use same preprocessing steps before comparison. Monai platform has a lot of capability to create most recent architectures without effort. So following the so called "Quick and dirty" methods I just created a custom trainer class but I am not sure this alone is sufficient to train with nnUNetv2_train function (with -tr UNETRTrainer). I would be grateful for any ideas and help. Here is my class: