MIC-DKFZ / nnUNet

Using custom architectures #2069

Open noneedanick opened 2 months ago

noneedanick commented 2 months ago

I am currently trying to train different architectures within the nnUNet platform to compare them against the nnUNet baseline architectures, and I believe it is important to use the same preprocessing steps for a fair comparison. The MONAI platform makes it easy to build most recent architectures, so, following the so-called "quick and dirty" method, I created a custom trainer class. However, I am not sure this alone is sufficient to train with the nnUNetv2_train command (with -tr UNETRTrainer). I would be grateful for any ideas and help. Here is my class:


from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager
from monai.networks.nets import UNETR
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP
from typing import Union, Tuple, List

class UNETRTrainer(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.plans_manager = PlansManager(plans)
        self.configuration_manager = self.plans_manager.get_configuration(configuration)
        self.enable_deep_supervision = False
        self.dataset_json = dataset_json
        ### Some hyperparameters for you to fiddle with
        self.initial_lr = 1e-2
        self.weight_decay = 3e-5
        self.oversample_foreground_percent = 0.33
        self.num_iterations_per_epoch = 250
        self.num_val_iterations_per_epoch = 50
        self.num_epochs = 500
        self.current_epoch = 0

        ### Dealing with labels/regions
        self.label_manager = self.plans_manager.get_label_manager(dataset_json)

    def initialize(self):
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)

            self.network = self.build_network_architecture(
                self.configuration_manager.network_arch_class_name,
                self.configuration_manager.network_arch_init_kwargs,
                self.configuration_manager.network_arch_init_kwargs_req_import,
                self.num_input_channels,
                self.label_manager.num_segmentation_heads,
                self.enable_deep_supervision
            ).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Using torch.compile...')
                self.network = torch.compile(self.network)

            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])

            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")

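    # NOTE: 'self' in a @staticmethod is not bound to an instance. nnU-Net calls this
    # method on the class, so the first positional argument ends up in 'self' and the
    # lines below that use self.configuration_manager fail (see the discussion below).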
    @staticmethod
    def build_network_architecture(self,
                                   architecture_class_name: str,
                                   arch_init_kwargs: dict,
                                   arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                                   num_input_channels: int,
                                   num_output_channels: int,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        patch_size = self.configuration_manager.patch_size
        model = UNETR(
                        in_channels=num_input_channels,
                        out_channels=num_output_channels,
                        img_size=patch_size,
                        feature_size=16,
                        hidden_size=768,
                        mlp_dim=3072,
                        num_heads=12,
                        pos_embed="perceptron",
                        norm_name="instance",
                        res_block=True,
                        dropout_rate=0.0,
                    )
        return model

FabianIsensee commented 2 months ago

Hey, so all you need to test different architectures is to overwrite build_network_architecture, plus any additional changes you'd like to make. What you did is basically correct. "Quick and dirty" means that your architecture will not be considered during experiment planning: it is expected to be able to deal with whatever patch and batch size nnU-Net gives it. Ideally it also uses the downsampling steps nnU-Net configured, but that is optional. Best, Fabian
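
For reference, a minimal sketch of the override Fabian describes (the UNETR settings are illustrative, the class name is made up, and the hard-coded patch size only works if it matches what nnU-Net planned; the full classes shared further down flesh this out):

    from monai.networks.nets import UNETR
    from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

    class nnUNetTrainerUNETRMinimal(nnUNetTrainer):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # UNETR returns a single output, so nnU-Net's deep supervision must be off
            self.enable_deep_supervision = False

        @staticmethod
        def build_network_architecture(architecture_class_name, arch_init_kwargs,
                                       arch_init_kwargs_req_import, num_input_channels,
                                       num_output_channels, enable_deep_supervision=True):
            # ignore the plans-derived architecture and return a MONAI UNETR instead;
            # img_size must match the patch size nnU-Net feeds the network
            return UNETR(in_channels=num_input_channels,
                         out_channels=num_output_channels,
                         img_size=(64, 96, 96))  # hard-coded for illustration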

noneedanick commented 2 months ago

It gave an error about this function, so I commented out some parts of it since I won't use deep supervision. Now it's working just fine!! Thanks a lot!!

NOTE: I am also a little inexperienced with staticmethod; I believe it is not common to use 'self' in a staticmethod. I only added it to use the incoming patch_size as the input image size, since that is defined on the class and accessed through self.

    def set_deep_supervision_enabled(self, enabled: bool):
        """
        This function is specific for the default architecture in nnU-Net. If you change the architecture, there are
        chances you need to change this as well!
        """
        if self.is_ddp:
            mod = self.network.module
        else:
            mod = self.network
        #if isinstance(mod, OptimizedModule):
            #mod = mod._orig_mod

        #mod.decoder.deep_supervision = enabled
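
As a small illustration of why 'self' in a @staticmethod is problematic (hypothetical Foo class, not nnU-Net code):

    class Foo:
        value = 42

        @staticmethod
        def broken(self):
            # 'self' is an ordinary parameter here; nothing is bound to it automatically
            return self.value

    # Foo().broken()   # TypeError: broken() missing 1 required positional argument: 'self'
    print(Foo.broken(Foo()))  # 42, but only because an instance is passed by hand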

noneedanick commented 2 months ago

Also, for researchers looking for implementations of the SwinUNETR and UNETR segmentation models within the nnUNet framework, here are my latest classes.

The classes were constructed based on the following notebooks defining these architectures: https://colab.research.google.com/github/Project-MONAI/tutorials/blob/main/3d_segmentation/swin_unetr_brats21_segmentation_3d.ipynb#scrollTo=xQn18qtvZChG

https://github.com/Project-MONAI/tutorials/blob/main/3d_segmentation/unetr_btcv_segmentation_3d.ipynb

UNETR:

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from monai.networks.nets import UNETR
from torch.optim import AdamW
import torch
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler

class nnUNetTrainerUNETR(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-4
        self.weight_decay = 1e-5
        self.num_epochs = 200

    def set_deep_supervision_enabled(self, enabled: bool):
        # overridden as a no-op: the MONAI UNETR built below has no deep supervision
        # decoder, so there is nothing to toggle
        pass

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):

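        # hard-coded: this must match the patch size nnU-Net planned for this
        # configuration (see further down for passing it in via arch_init_kwargs instead)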
        patch_size = (64,96,96)

        model = UNETR(
                        in_channels=num_input_channels,
                        out_channels=num_output_channels,
                        img_size=patch_size,
                        feature_size=16,
                        hidden_size=768,
                        mlp_dim=3072,
                        num_heads=12,
                        pos_embed="perceptron",
                        norm_name="instance",
                        res_block=True,
                        dropout_rate=0.0,
                    )
        return model

    def configure_optimizers(self):
        optimizer = AdamW(self.network.parameters(),
                          lr=self.initial_lr,
                          weight_decay=self.weight_decay,
                          amsgrad=True)

        lr_scheduler = PolyLRScheduler(optimizer, self.initial_lr, self.num_epochs)
        return optimizer, lr_scheduler     
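
Assuming the default trainer discovery (nnU-Net searches the nnunetv2.training.nnUNetTrainer package for trainer classes by name), placing this class in a file under nnunetv2/training/nnUNetTrainer/ should be enough to select it with, e.g., nnUNetv2_train DATASET_ID 3d_fullres FOLD -tr nnUNetTrainerUNETR.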

SwinUNETR:

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from monai.networks.nets import SwinUNETR
from torch.optim import AdamW
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
from nnunetv2.training.loss.deep_supervision import DeepSupervisionWrapper
from nnunetv2.training.loss.dice import MemoryEfficientSoftDiceLoss
from nnunetv2.utilities.helpers import softmax_helper_dim1

class nnUNetTrainerSwinUNETR(nnUNetTrainer):
    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        unpack_dataset: bool = True,
        device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-4
        self.weight_decay = 1e-5
        self.num_epochs = 200

    def set_deep_supervision_enabled(self, enabled: bool):
        # overridden as a no-op: SwinUNETR has no deep supervision decoder,
        # so there is nothing to toggle
        pass

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):

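        # hard-coded: must match the planned patch size (see the arch_init_kwargs
        # discussion below for a cleaner way)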
        patch_size = (64,96,96)
        model = SwinUNETR(
            img_size=patch_size,
            in_channels=num_input_channels,
            out_channels=num_output_channels,
            feature_size=48,
            drop_rate=0.0,
            attn_drop_rate=0.0,
            dropout_path_rate=0.0,
            use_checkpoint=True,
        )
        return model

    def _build_loss(self):
        loss = MemoryEfficientSoftDiceLoss(
            **{'batch_dice': self.configuration_manager.batch_dice,
               'do_bg': self.label_manager.has_regions, 'smooth': 1e-5, 'ddp': self.is_ddp},
            apply_nonlin=torch.sigmoid if self.label_manager.has_regions else softmax_helper_dim1)

        if self.enable_deep_supervision:
            deep_supervision_scales = self._get_deep_supervision_scales()

            # give each output a weight that decreases exponentially (division by 2) as the
            # resolution decreases, so higher-resolution outputs get more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(len(deep_supervision_scales))])
            # the lowest-resolution output is not used; normalize the remaining weights
            # so that they sum to 1
            weights[-1] = 0
            weights = weights / weights.sum()
            # now wrap the loss
            loss = DeepSupervisionWrapper(loss, weights)
        return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.network.parameters(),
                          lr=self.initial_lr,
                          weight_decay=self.weight_decay,
                          amsgrad=True)

        lr_scheduler = CosineAnnealingLR(optimizer, T_max=self.num_epochs)
        return optimizer, lr_scheduler 

FabianIsensee commented 2 months ago

Hey, thanks for sharing!

Three things to be aware of:

1. We have our own implementation of SwinUNETR internally that we were planning to release soon. It will be interesting to compare it to yours.
2. How has your experience with these architectures been so far?
3. We haven't found them useful so far. They just can't keep up with a standard UNet or a UNet with a residual encoder.

Best, Fabian

noneedanick commented 2 months ago

Thanks Fabian for your comments!

I couldn't find a proper way to automatically infer the patch size from the plans without changing the build_network_architecture function signature :(

So far my comparisons have been close, with very similar performances, but it looks like the 3D UNet is still going to be the winner :laughing:. Right now I am trying to deal with hardware issues to speed up the process: it takes over 300 seconds (the same for each model) to train an epoch, and on the validation side both my custom classes and the unchanged 3d_fullres training take over 20 minutes per validation case. I am trying to lower these timings because it feels like an eternity to me :laughing:.

Will share my trials and configurations here soon!

Best, Murat

FabianIsensee commented 2 months ago

Hah, you are right, the signature is:

def build_network_architecture(architecture_class_name: str,
                               arch_init_kwargs: dict,
                               arch_init_kwargs_req_import: Union[List[str], Tuple[str, ...]],
                               num_input_channels: int,
                               num_output_channels: int,
                               enable_deep_supervision: bool = True) -> nn.Module:

I just read your code and forgot how it is supposed to be. The easiest way to achieve what you need would be to change the arch_init_kwargs in the plans file so that the patch size is in there as well. This would need to be done as part of the experiment planner.
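
A one-off alternative to extending the experiment planner is to edit the plans file directly. A minimal sketch, assuming the plans live at the usual nnUNet_preprocessed location (the dataset name and configuration are placeholders) and use the "architecture"/"arch_kwargs" layout of recent nnU-Net v2 plans files:

    import json

    # hypothetical path; adjust to your dataset and plans identifier
    plans_path = "nnUNet_preprocessed/Dataset001_Example/nnUNetPlans.json"
    with open(plans_path) as f:
        plans = json.load(f)

    cfg = plans["configurations"]["3d_fullres"]
    # copy the planned patch size into the architecture kwargs so that
    # build_network_architecture can read it as arch_init_kwargs['patch_size']
    cfg["architecture"]["arch_kwargs"]["patch_size"] = cfg["patch_size"]

    with open(plans_path, "w") as f:
        json.dump(plans, f, indent=4)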

Those epoch times sound horrible. Take a look here: https://github.com/MIC-DKFZ/nnUNet/blob/master/documentation/benchmarking.md. There is information there on how to approach debugging this. Best, Fabian

FabianIsensee commented 1 month ago

Hey, how are you progressing? Is everything working now?

noneedanick commented 1 month ago

Hey Fabian! Lately I have been preparing to start a postdoc position in the US, so I couldn't share my final results here yet. But it is on my mind, and I am waiting for a proper time (I am currently gathering my stuff for fully migrating to the US) to share them. Nevertheless, I dealt with the epoch time issue by increasing the target spacing used for resampling (the automatically determined spacing produced huge arrays). As expected 😆, nnUNet achieved higher Dice scores than my custom approach; SwinUNETR showed similar performance, but I need to test it again with better-optimized parameters (patch size, resampling, etc.).
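
(For anyone wanting to do the same: if I remember correctly, the target spacing can be overridden at planning time via nnUNetv2_plan_experiment's -overwrite_target_spacing option together with -overwrite_plans_name, or by editing the "spacing" entry of the configuration in the plans file before re-running preprocessing.)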

Thanks for your interest by the way, I am honored ❤️

Yanfeng-Zhou commented 1 month ago

How can the patch size be read from the ConfigurationManager? Static methods cannot read self variables, and writing it like this results in an error!

import torch
from monai.networks.nets import SwinUNETR
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

class nnUNetTrainer_SwinUNETR(nnUNetTrainer):
    def __init__(
            self,
            plans: dict,
            configuration: str,
            fold: int,
            dataset_json: dict,
            unpack_dataset: bool = True,
            device: torch.device = torch.device("cuda"),
    ):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.enable_deep_supervision = False
        self.initial_lr = 1e-2
        # self.weight_decay = 3e-5
        # self.num_epochs = 1000
        self.num_epochs = 500
        # self.num_epochs = 50
        # self.patch_size =

    def set_deep_supervision_enabled(self, enabled: bool):
        # overridden as a no-op: SwinUNETR has no deep supervision decoder
        pass

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):

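        # ERROR: 'trainer' is not defined inside this @staticmethod, and there is no
        # 'self' through which configuration_manager could be reached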
        patch_size = trainer.configuration_manager.patch_size
        from dynamic_network_architectures.initialization.weight_init import InitWeights_He
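        # ERROR: nnUNetTrainer_SwinUNETR.patch_size does not exist as a class attribute either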
        model = SwinUNETR(img_size=nnUNetTrainer_SwinUNETR.patch_size, in_channels=num_input_channels, out_channels=num_output_channels)
        model.apply(InitWeights_He(1e-2))
        return model

Yanfeng-Zhou commented 1 month ago

The build_network_architecture function is a static method, so it cannot access self.configuration_manager.patch_size!

Thank you for your reply!

FabianIsensee commented 1 month ago

Use the arch_init_kwargs dictionary and add the patch size to that

Yanfeng-Zhou commented 1 month ago

> Use the arch_init_kwargs dictionary and add the patch size to that

Sorry, I can't find arch_init_kwargs. Can you give me a minimal implementation? Thank you!

FabianIsensee commented 4 weeks ago

You find it in the plans file, as part of the configuration you are trying to run. It's called arch_kwargs in there. Just add the patch size there and it will be available in build_network_architecture as arch_init_kwargs['patch_size'].

Yanfeng-Zhou commented 4 weeks ago

> You find it in the plans file, as part of the configuration you are trying to run. It's called arch_kwargs in there. Just add the patch size there and it will be available in build_network_architecture as arch_init_kwargs['patch_size'].

I succeeded! Here is the complete process. Open nnUNet/nnunetv2/utilities/plans_handling/plans_handler.py and add the following line at line 33:

...
class ConfigurationManager(object):
    def __init__(self, configuration_dict: dict):
        self.configuration = configuration_dict
        self.configuration["architecture"]["arch_kwargs"]["patch_size"] = self.configuration["patch_size"]
        # backwards compatibility
        if 'architecture' not in self.configuration.keys():
...

Then you can use patch_size in the custom network, and other parameters can be passed through in the same way.

    @staticmethod
    def build_network_architecture(architecture_class_name,
                                   arch_init_kwargs,
                                   arch_init_kwargs_req_import,
                                   num_input_channels,
                                   num_output_channels,
                                   enable_deep_supervision):

        from dynamic_network_architectures.initialization.weight_init import InitWeights_He
        model = SwinUNETR(img_size=arch_init_kwargs['patch_size'], in_channels=num_input_channels, out_channels=num_output_channels)
        model.apply(InitWeights_He(1e-2))
        return model

Thank you again for your help! @FabianIsensee

FabianIsensee commented 2 weeks ago

Hey @Yanfeng-Zhou, glad to hear it works now. It would be cleaner to add the patch size to the arch_kwargs in the plans file, because that would not affect the rest of nnU-Net. With your change, patch_size will always be added, even when you don't need it.