uni-medical / STU-Net

The largest pre-trained medical image segmentation model (1.4B parameters) based on the largest public dataset (>100k annotations), as of April 2023.
Apache License 2.0

TypeError: STUNetTrainer_huge.build_network_architecture() takes from 4 to 5 positional arguments but 6 were given #34

Open SergioRodLla opened 2 months ago

SergioRodLla commented 2 months ago

Hi Ziyan,

I tried to fine-tune your pretrained huge model on my multi-modal dataset using the nnUNet v2 framework like this:

CUDA_VISIBLE_DEVICES=0 python nnUNet/nnunetv2/run/run_finetuning_stunet.py 500 3d_fullres 0 -tr STUNetTrainer_huge_ft -pretrained_weights STU-net/models/huge_ep4k.model

However, I got this error:

Traceback (most recent call last):
  File "/media/HDD_4TB_2/sergio/TFM/nnUNet/nnunetv2/run/run_finetuning_stunet.py", line 63, in <module>
    run_training_entry()
  File "/media/HDD_4TB_2/sergio/TFM/nnUNet/nnunetv2/run/run_training.py", line 275, in run_training_entry
    run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
  File "/media/HDD_4TB_2/sergio/TFM/nnUNet/nnunetv2/run/run_training.py", line 204, in run_training
    maybe_load_checkpoint(nnunet_trainer, continue_training, only_run_validation, pretrained_weights)
  File "/media/HDD_4TB_2/sergio/TFM/nnUNet/nnunetv2/run/run_training.py", line 94, in maybe_load_checkpoint
    nnunet_trainer.initialize()
  File "/media/HDD_4TB_2/sergio/TFM/nnUNet/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py", line 210, in initialize
    self.network = self.build_network_architecture(
TypeError: STUNetTrainer_huge.build_network_architecture() takes from 4 to 5 positional arguments but 6 were given

Looking at your implementation, it looks like you changed the way the base nnUNetTrainer class initializes the network when calling self.build_network_architecture: in the original trainer the call passes 6 arguments, and this is where the error occurs. I wanted to tell you so you are aware of this. I think defining a new initialize method in your custom STUNetTrainer, making the 4-argument call as you have done in your nnUNetTrainer.py code, would do the trick.
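To make the mismatch concrete, here is a minimal, self-contained sketch of the failure; the class name and argument values below are placeholders for illustration, not the real nnU-Net ones:

# Minimal reproduction of the TypeError, independent of nnU-Net itself.
# The STU-Net trainer defines build_network_architecture with the old
# 4/5-argument signature, while the current nnUNetTrainer.initialize()
# passes 6 positional arguments.
class OldSignatureTrainer:
    @staticmethod
    def build_network_architecture(plans_manager, dataset_json, configuration_manager,
                                   num_input_channels, enable_deep_supervision=True):
        return "network"

trainer = OldSignatureTrainer()

# Old-style call with 4 (or 5) positional arguments works:
trainer.build_network_architecture("plans", "dataset.json", "config", 4)

try:
    # A 6-argument call, like the one made by the newer nnUNetTrainer.initialize():
    trainer.build_network_architecture("plans", "dataset.json", "config", 4, 2, True)
except TypeError as e:
    print(e)  # ... takes from 4 to 5 positional arguments but 6 were given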

I'm a beginner at this so I might be wrong. Any ideas would be very much appreciated.

Best regards, Sergio

plbenveniste commented 2 months ago

Hi! Can you explain what installation you did to get this?

SergioRodLla commented 2 months ago

Hi @plbenveniste ! I had previously cloned the nnUNet v2 repository and just added the STUNetTrainer.py and run_finetuning_stunet.py scripts to the corresponding folders.

plbenveniste commented 2 months ago

I am having the same problem, so I looked at the code, and it seems that in the latest version of nnunetv2 the trainer is not built the same way. I think the solution might be to use this release, since it seems to match: https://github.com/MIC-DKFZ/nnUNet/releases/tag/v2.2
Looking into it now. Will let you know.
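If you switch releases, it may also help to double-check which nnunetv2 is actually installed in the active environment. A quick sanity check, assuming nnunetv2 was installed as a pip package (e.g. with pip install -e . from the cloned repo):

# Print the installed nnunetv2 package version.
from importlib.metadata import PackageNotFoundError, version

try:
    print("nnunetv2 version:", version("nnunetv2"))
except PackageNotFoundError:
    print("nnunetv2 is not installed as a package in this environment")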

SergioRodLla commented 2 months ago

If you don't want to use that specific release, you can just redefine the initialize method inside STUNetTrainer like this, so that build_network_architecture is called with the old 4/5-argument signature. It worked for me.

# Imports as typically needed when this lives in STUNetTrainer.py
# (the STUNet model class itself is defined in that same file in the STU-Net repo):
import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.label_handling.label_handling import determine_num_input_channels


class STUNetTrainer(nnUNetTrainer):
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
                 device: torch.device = torch.device('cuda')):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
        self.num_epochs = 1000
        self.initial_lr = 1e-2

    @staticmethod
    def build_network_architecture(plans_manager,
                                   dataset_json,
                                   configuration_manager,
                                   num_input_channels,
                                   enable_deep_supervision: bool = True) -> nn.Module:
        label_manager = plans_manager.get_label_manager(dataset_json)
        num_classes = label_manager.num_segmentation_heads
        kernel_sizes = [[3, 3, 3]] * 6
        # STU-Net uses a fixed 6-stage architecture, so clamp/pad the plan's pooling
        # strides to exactly 5 downsampling steps:
        strides = configuration_manager.pool_op_kernel_sizes[1:]
        if len(strides) > 5:
            strides = strides[:5]
        while len(strides) < 5:
            strides.append([1, 1, 1])
        return STUNet(num_input_channels, num_classes, depth=[1] * 6,
                      dims=[32 * x for x in [1, 2, 4, 8, 16, 16]],
                      pool_op_kernel_sizes=strides, conv_kernel_sizes=kernel_sizes,
                      enable_deep_supervision=enable_deep_supervision)

    def initialize(self):
        # Override nnUNetTrainer.initialize() so that build_network_architecture is
        # called with the old 4/5-argument signature used by the STU-Net trainers.
        if not self.was_initialized:
            self.num_input_channels = determine_num_input_channels(self.plans_manager, self.configuration_manager,
                                                                   self.dataset_json)

            self.network = self.build_network_architecture(self.plans_manager, self.dataset_json,
                                                           self.configuration_manager,
                                                           self.num_input_channels,
                                                           enable_deep_supervision=True).to(self.device)
            # compile network for free speedup
            if self._do_i_compile():
                self.print_to_log_file('Compiling network...')
                self.network = torch.compile(self.network)

            self.optimizer, self.lr_scheduler = self.configure_optimizers()
            # if ddp, wrap in DDP wrapper
            if self.is_ddp:
                self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network)
                self.network = DDP(self.network, device_ids=[self.local_rank])

            self.loss = self._build_loss()
            self.was_initialized = True
        else:
            raise RuntimeError("You have called self.initialize even though the trainer was already initialized. "
                               "That should not happen.")

plbenveniste commented 2 months ago

Thanks for the solution @SergioRodLla! It worked for me as well.

SergioRodLla commented 2 months ago

Happy to hear that! :)