MIC-DKFZ / nnUNet


Two-step initialization of the network may lead to unexpected model behaviour #2520

Open · machur opened this issue 1 month ago

machur commented 1 month ago

Hi, I'm working on exporting a raw nnUNet model to TorchScript format and I've run into some unexpected issues caused by the two-step initialization of the model. Initially, I created a predictor and called the _initialize_from_trained_modelfolder method, assuming that this would be enough to have a fully initialized model for the export. Unfortunately, the actual loading of the weights happens later in the code, during the prediction of logits, so I had to run the prediction once before the export (even though I didn't need its output):

    def predict_logits_from_preprocessed_data(self, data: torch.Tensor) -> torch.Tensor:
        n_threads = torch.get_num_threads()
        torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
        prediction = None

        for params in self.list_of_parameters:

            # messing with state dict names...
            if not isinstance(self.network, OptimizedModule):
                self.network.load_state_dict(params)
            else:
                self.network._orig_mod.load_state_dict(params)
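            # ... (rest of the method, including the actual prediction, omitted)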

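In case it is useful to others, a possible workaround for the export itself is to load the checkpoint into the network manually instead of running a dummy prediction. This is only a minimal sketch: I'm assuming the nnUNetPredictor class and its initialize_from_trained_model_folder method from nnunetv2.inference.predict_from_raw_data (names may differ slightly between versions and from the method I referred to above), and the model folder path, fold, checkpoint name and input shape are placeholders:

    # Sketch of a possible workaround (not a definitive recipe): copy the first
    # checkpoint's weights into the network by hand, then export with TorchScript.
    # The model folder path, fold, checkpoint name and input shape are placeholders.
    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor = nnUNetPredictor(device=torch.device("cpu"))
    predictor.initialize_from_trained_model_folder(
        "/path/to/nnUNet_results/DatasetXXX/nnUNetTrainer__nnUNetPlans__3d_fullres",
        use_folds=(0,),
        checkpoint_name="checkpoint_final.pth",
    )

    # The weights are not in predictor.network yet -- loading them here mirrors what
    # predict_logits_from_preprocessed_data does for every fold.
    network = predictor.network
    params = predictor.list_of_parameters[0]
    if hasattr(network, "_orig_mod"):   # unwrap torch.compile's OptimizedModule
        network = network._orig_mod
    network.load_state_dict(params)

    network.eval()
    example_input = torch.zeros(1, 1, 128, 128, 128)  # (batch, channels, *patch_size) -- placeholder
    scripted = torch.jit.trace(network, example_input)
    scripted.save("nnunet_model.pt")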
I've opened this issue as a humble suggestion to move the loop that loads the state dictionaries into the _initialize_from_trained_modelfolder method, so that the whole initialization is performed in a single step. I hope this is feasible, although I've only tried exporting a single model, not an ensemble, so I'm not sure whether it would make sense for multiple models.
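To illustrate the single-model case, here is a rough sketch of what I have in mind; the helper name is made up, predictor stands for an already-initialized predictor instance, and the attribute names mirror the snippet above (an ensemble would presumably keep the per-fold reload loop during prediction):

    # Hypothetical helper illustrating the suggestion: load the weights eagerly at the
    # end of initialization when there is exactly one set of parameters, so that
    # predictor.network is fully usable right away (e.g. for TorchScript export).
    from typing import Any


    def load_single_fold_weights(predictor: Any) -> None:
        if len(predictor.list_of_parameters) != 1:
            return  # ensembles would still swap state dicts per fold during prediction
        params = predictor.list_of_parameters[0]
        # unwrap torch.compile's OptimizedModule, mirroring the snippet above
        target = getattr(predictor.network, "_orig_mod", predictor.network)
        target.load_state_dict(params)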