MIC-DKFZ / nnUNet

Apache License 2.0

Stuck at the beginning of training #2470

Open · Wind-177 opened this issue 2 months ago

Wind-177 commented 2 months ago

Training gets stuck at the very beginning and never reaches epoch 0. Normally, training takes about 2 hours. I am training with my own customized network rather than one of the three default models. While troubleshooting, I found that the call to self.plot_network_architecture() hangs; after commenting it out, everything works fine.
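To confirm exactly where a run like this is stuck, Python's standard-library `faulthandler` module can dump the stack of every thread if a block is still running after a time limit. This is a minimal sketch, assuming you wrap the suspect call yourself; `dump_stack_if_slow` is a hypothetical helper name, not part of nnU-Net. The wrapped call is not interrupted, only reported on.

```python
import faulthandler
from contextlib import contextmanager
from typing import IO, Iterator


@contextmanager
def dump_stack_if_slow(timeout_s: float, log_file: IO) -> Iterator[None]:
    """Hypothetical helper: if the body is still running after timeout_s seconds,
    write the stack of every thread to log_file. The body keeps running."""
    # Arms a watchdog thread; log_file must stay open while the body runs.
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=log_file)
    try:
        yield
    finally:
        # Disarm the watchdog if the body finished (or raised) before the limit.
        faulthandler.cancel_dump_traceback_later()


# Example use inside on_train_start (not run here):
#     with open(join(self.output_folder, "hang_trace.txt"), "w") as f:
#         with dump_stack_if_slow(120, f):
#             self.plot_network_architecture()
```

If the call hangs, the dump in `hang_trace.txt` shows which function inside the plotting step it is waiting in.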

Wind-177 commented 2 months ago

def on_train_start(self):
    if not self.was_initialized:
        self.initialize()

    maybe_mkdir_p(self.output_folder)

    # make sure deep supervision is on in the network
    self.set_deep_supervision_enabled(True)

    self.print_plans()
    empty_cache(self.device)

    # maybe unpack
    if self.unpack_dataset and self.local_rank == 0:
        self.print_to_log_file('unpacking dataset...')
        unpack_dataset(self.preprocessed_dataset_folder, unpack_segmentation=True, overwrite_existing=False,
                       num_processes=max(1, round(get_allowed_n_proc_DA() // 2)))
        self.print_to_log_file('unpacking done...')

    if self.is_ddp:
        dist.barrier()

    # dataloaders must be instantiated here because they need access to the training data which may not be present
    # when doing inference
    self.dataloader_train, self.dataloader_val = self.get_dataloaders()

    # copy plans and dataset.json so that they can be used for restoring everything we need for inference
    save_json(self.plans_manager.plans, join(self.output_folder_base, 'plans.json'), sort_keys=False)
    save_json(self.dataset_json, join(self.output_folder_base, 'dataset.json'), sort_keys=False)

    # we don't really need the fingerprint but it's still handy to have it with the others
    shutil.copy(join(self.preprocessed_dataset_folder_base, 'dataset_fingerprint.json'),
                join(self.output_folder_base, 'dataset_fingerprint.json'))

    # produces a pdf in output folder
    # disabled: this call hangs with the custom network
    # self.plot_network_architecture()

    self._save_debug_information()

    # print(f"batch size: {self.batch_size}")
    # print(f"oversample: {self.oversample_foreground_percent}")
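Rather than commenting out the line in nnU-Net's own source, the same effect can usually be had by overriding `plot_network_architecture` in a custom trainer subclass. The sketch below is self-contained, so it uses a hypothetical stand-in base class (`BaseTrainer`) that only mimics the relevant part of `on_train_start`; in a real setup you would subclass `nnUNetTrainer` (from `nnunetv2.training.nnUNetTrainer.nnUNetTrainer`) instead. `MyTrainerNoPlot` is also a hypothetical name.

```python
class BaseTrainer:
    """Hypothetical stand-in for nnUNetTrainer: only the parts needed for the override."""

    def __init__(self) -> None:
        self.log: list[str] = []

    def print_to_log_file(self, *args: object) -> None:
        self.log.append(" ".join(str(a) for a in args))

    def plot_network_architecture(self) -> None:
        # Stand-in for the call that hangs with the custom network.
        raise RuntimeError("plotting failed")

    def on_train_start(self) -> None:
        # ... dataloaders, plans.json, dataset.json etc. would be set up here ...
        self.plot_network_architecture()
        self.print_to_log_file("training can start")


class MyTrainerNoPlot(BaseTrainer):
    """Custom trainer that skips the architecture PDF instead of editing library code."""

    def plot_network_architecture(self) -> None:
        self.print_to_log_file("skipping network architecture plot")


trainer = MyTrainerNoPlot()
trainer.on_train_start()
print(trainer.log)  # ['skipping network architecture plot', 'training can start']
```

With the real base class, such a subclass would be selected via the `-tr` option of `nnUNetv2_train`, assuming it is placed where your nnU-Net version looks for trainer classes.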
Wind-177 commented 2 months ago

Could this be because the plotting module cannot draw certain structures in my network?
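One way to check whether the plotting step is merely very slow on the custom network, rather than unable to handle it at all, is to give it a time limit. This is a minimal sketch, assuming a plain `threading`-based guard; `run_with_timeout` is a hypothetical helper, not an nnU-Net function. Python cannot kill a thread, so on timeout the call is abandoned and keeps running in the background (possibly still holding GPU memory), which makes this a diagnostic rather than a permanent fix.

```python
import threading
from typing import Callable


def run_with_timeout(fn: Callable[[], None], timeout_s: float) -> bool:
    """Hypothetical helper: run fn in a daemon thread and wait at most timeout_s.

    Returns True if fn finished in time, False if it is still running.
    An exception raised by fn within the limit is re-raised in the caller.
    """
    errors: list[BaseException] = []

    def target() -> None:
        try:
            fn()
        except Exception as e:  # keep the error instead of losing it in the thread
            errors.append(e)

    worker = threading.Thread(target=target, daemon=True)  # daemon: won't block exit
    worker.start()
    worker.join(timeout_s)
    if worker.is_alive():
        return False  # still running: abandoned, not stopped
    if errors:
        raise errors[0]
    return True


# Example use inside on_train_start (not run here):
#     if not run_with_timeout(self.plot_network_architecture, timeout_s=300):
#         self.print_to_log_file("plot_network_architecture timed out, continuing")
```

If the plot finishes with a generous limit, the module can draw the network and is only slow; if it never finishes, the stack-dump approach above is the better next step.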