MIC-DKFZ / nnUNet

Apache License 2.0

Problem occurs when inputting more than 10 images #2590

Closed: U-ma-s closed this issue 1 week ago

U-ma-s commented 2 weeks ago

I would like to extract features from the final layer of the encoder and visualize them. To do this, I wrote code that extracts the features while running main() of predict_from_raw_data.py. I tested it with 12 images: from the 10th image onward, the features were added twice per image. Changing both 'num_processes_preprocessing=2' and 'num_processes_segmentation_export=2' to 1 did not change the result.

Do you know what could be causing this?

Thanks in advance.

Output for the eighth and subsequent images:

Predicting 00014:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
  0%|          | 0/1 [00:00<?, ?it/s]
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([8, 512, 7, 8])
100%|██████████| 1/1 [00:00<00:00, 79.07it/s]
sending off prediction to background worker for resampling and export
done with 00014

Predicting 00015:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
  0%|          | 0/1 [00:00<?, ?it/s]
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([9, 512, 7, 8])
100%|██████████| 1/1 [00:00<00:00, 147.96it/s]
sending off prediction to background worker for resampling and export
done with 00015

Predicting 00025:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
  0%|          | 0/2 [00:00<?, ?it/s]
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([10, 512, 7, 8])
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([11, 512, 7, 8])
100%|██████████| 2/2 [00:00<00:00, 142.38it/s]
sending off prediction to background worker for resampling and export
done with 00025

Predicting 00026:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
  0%|          | 0/2 [00:00<?, ?it/s]
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([12, 512, 7, 8])
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([13, 512, 7, 8])
100%|██████████| 2/2 [00:00<00:00, 125.74it/s]
sending off prediction to background worker for resampling and export
done with 00026

Predicting 00027:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
  0%|          | 0/2 [00:00<?, ?it/s]
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([14, 512, 7, 8])
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([15, 512, 7, 8])
100%|██████████| 2/2 [00:00<00:00, 110.17it/s]
sending off prediction to background worker for resampling and export
done with 00027
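The progress bars in this log may point at the cause: 00014 and 00015 run 0/1 sliding-window steps, while 00025 to 00027 run 0/2. A forward hook fires once per forward pass, so an image that needs two tiles appends two feature maps. Because the same all_features list is passed in for every image, the printed first dimension is also a running total rather than a per-image count. The sketch below estimates how many times the hook fires for one image. It is an illustrative re-implementation, assuming the step count per axis is ceil((image - patch) / (patch * tile_step_size)) + 1 and that axes smaller than the patch are padded up to it (as pad_nd_image does in the code below). The helper name expected_hook_calls and the example shapes are hypothetical. Mirrored passes are not counted because the hook is removed before they run, but every fold in list_of_parameters repeats the whole sliding window.

```python
import math
from typing import Sequence


def expected_hook_calls(image_shape: Sequence[int],
                        patch_size: Sequence[int],
                        tile_step_size: float = 0.5,
                        num_folds: int = 1) -> int:
    """Estimate how often a forward hook fires for one image (hypothetical helper).

    Assumes the sliding-window step count per axis is
    ceil((image - patch) / (patch * tile_step_size)) + 1, and that axes
    smaller than the patch are padded to the patch size (one step).
    Mirrored passes are not counted: the hook is removed before they run.
    """
    assert 0 < tile_step_size <= 1, "tile_step_size must be in (0, 1]"
    tiles = 1
    for img, patch in zip(image_shape, patch_size):
        img = max(img, patch)  # padding to the patch size gives exactly one step
        step = patch * tile_step_size
        tiles *= math.ceil((img - patch) / step) + 1
    # each fold in list_of_parameters runs the full sliding window again
    return tiles * num_folds


# Hypothetical 2D patch size, for illustration only
patch = (224, 256)
print(expected_hook_calls((224, 256), patch))  # fits in one patch: 1
print(expected_hook_calls((224, 300), patch))  # slightly wider than the patch: 2
```

With tile_step_size=0.5, as in the script, any image that is even slightly larger than the patch along one axis needs a second tile, which would be consistent with the 0/2 runs.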

My code:


class nnUNetPredictor(object):

    # ~~~~~~ default nnUNetPredictor code omitted ~~~~~~

    ########################################### My func start(Almost same as original code) ###########################################
    def _My_internal_maybe_mirror_and_predict(self, x: torch.Tensor, all_features: List = None) -> torch.Tensor:
        mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
        print("in _My_internal_maybe_mirror_and_predict")
        # save = SaveFeatures(self.network,self.network.encoder.stages[6][0].convs[1])

        def hook_fn(model, input, output):
            all_features.append(output.cpu())

        # print("network",self.network)
        handle = self.network.encoder.stages[6][0].convs[1].register_forward_hook(hook_fn) # Set the hook on the last layer.

        prediction = self.network(x)
        handle.remove()

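        # all_features is the same list object for every tile and every image, so this
        # stacked shape is a running total, not the feature count of the current image
        stacked_features = torch.cat(all_features, dim=0)
        print("all_features", stacked_features.shape)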

        if mirror_axes is not None:
            print("mirror_axes is not None")
            # check for invalid numbers in mirror_axes
            # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
            assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

            mirror_axes = [m + 2 for m in mirror_axes]
            axes_combinations = [
                c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
            ]
            for axes in axes_combinations:
                prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
            prediction /= (len(axes_combinations) + 1)
        return prediction

    def _My_internal_predict_sliding_window_return_logits(self,
                                                       data: torch.Tensor,
                                                       slicers,
                                                       do_on_device: bool = True,
                                                       all_features: List = None
                                                       ):
        predicted_logits = n_predictions = prediction = gaussian = workon = None
        results_device = self.device if do_on_device else torch.device('cpu')
        print("in _My_internal_predict_sliding_window_return_logits")

        try:
            empty_cache(self.device)

            # move data to device
            if self.verbose:
                print(f'move image to device {results_device}')
            data = data.to(results_device)

            # preallocate arrays
            if self.verbose:
                print(f'preallocating results arrays on device {results_device}')
            predicted_logits = torch.zeros((self.label_manager.num_segmentation_heads, *data.shape[1:]),
                                           dtype=torch.half,
                                           device=results_device)
            n_predictions = torch.zeros(data.shape[1:], dtype=torch.half, device=results_device)

            if self.use_gaussian:
                gaussian = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                            value_scaling_factor=10,
                                            device=results_device)
            else:
                gaussian = 1

            if not self.allow_tqdm and self.verbose:
                print(f'running prediction: {len(slicers)} steps')
            for sl in tqdm(slicers, disable=not self.allow_tqdm):
                workon = data[sl][None]
                workon = workon.to(self.device)

                prediction = self._My_internal_maybe_mirror_and_predict(workon,all_features)[0].to(results_device)

                if self.use_gaussian:
                    prediction *= gaussian
                predicted_logits[sl] += prediction
                n_predictions[sl[1:]] += gaussian

            predicted_logits /= n_predictions
            # check for infs
            if torch.any(torch.isinf(predicted_logits)):
                raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                                   'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                                   'predicted_logits to fp32')
        except Exception as e:
            del predicted_logits, n_predictions, prediction, gaussian, workon
            empty_cache(self.device)
            empty_cache(results_device)
            raise e
        return predicted_logits

    def My_predict_sliding_window_return_logits(self, input_image: torch.Tensor, all_features: List = None) \
            -> Union[np.ndarray, torch.Tensor]:
        print("in My_predict_sliding_window_return_logits")
        with torch.no_grad():
            assert isinstance(input_image, torch.Tensor)
            self.network = self.network.to(self.device)
            self.network.eval()

            empty_cache(self.device)

            # Autocast can be annoying
            # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
            # and needs to be disabled.
            # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
            # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
            # So autocast will only be active if we have a cuda device.
            with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
                assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

                if self.verbose:
                    print(f'Input shape: {input_image.shape}')
                    print("step_size:", self.tile_step_size)
                    print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)

                # if input_image is smaller than tile_size we need to pad it to tile_size.
                data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
                                                           'constant', {'value': 0}, True,
                                                           None)

                slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

                if self.perform_everything_on_device and self.device != 'cpu':
                    # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
                    try:
                        predicted_logits = self._My_internal_predict_sliding_window_return_logits(data, slicers,
                                                                                               self.perform_everything_on_device,
                                                                                               all_features)
                    except RuntimeError:
                        print(
                            'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
                        empty_cache(self.device)
                        predicted_logits = self._My_internal_predict_sliding_window_return_logits(data, slicers, False, all_features)
                else:
                    predicted_logits = self._My_internal_predict_sliding_window_return_logits(data, slicers,
                                                                                           self.perform_everything_on_device,
                                                                                           all_features)

                empty_cache(self.device)
                # revert padding
                predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
        return predicted_logits

    def My_predict_logits_from_preprocessed_data(self, data: torch.Tensor, all_features: List = None) -> torch.Tensor:
        """
        IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
        TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!

        RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
        SEE convert_predicted_logits_to_segmentation_with_correct_shape
        """
        print("in My_predict_logits_from_preprocessed_data")
        n_threads = torch.get_num_threads()
        torch.set_num_threads(default_num_processes if default_num_processes < n_threads else n_threads)
        prediction = None

        for params in self.list_of_parameters:

            # messing with state dict names...
            if not isinstance(self.network, OptimizedModule):
                self.network.load_state_dict(params)
            else:
                self.network._orig_mod.load_state_dict(params)

            # why not leave prediction on device if perform_everything_on_device? Because this may cause the
            # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
            # this actually saves computation time
            if prediction is None:
                prediction = self.My_predict_sliding_window_return_logits(data,all_features).to('cpu')
            else:
                prediction += self.My_predict_sliding_window_return_logits(data,all_features).to('cpu')

        if len(self.list_of_parameters) > 1:
            prediction /= len(self.list_of_parameters)

        if self.verbose: print('Prediction done')
        torch.set_num_threads(n_threads)
        return prediction

    def My_predict_from_data_iterator(self,
                                   data_iterator,
                                   save_probabilities: bool = False,
                                   num_processes_segmentation_export: int = default_num_processes,
                                   all_features: List = None):
        """
        each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
        If 'ofile' is None, the result will be returned instead of written to a file
        """
        print("in My_predict_from_data_iterator")
        with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
            worker_list = [i for i in export_pool._pool]
            r = []
            for preprocessed in data_iterator:
                data = preprocessed['data']
                if isinstance(data, str):
                    delfile = data
                    data = torch.from_numpy(np.load(data))
                    os.remove(delfile)

                ofile = preprocessed['ofile']
                if ofile is not None:
                    print(f'\nPredicting {os.path.basename(ofile)}:')
                else:
                    print(f'\nPredicting image of shape {data.shape}:')

                print(f'perform_everything_on_device: {self.perform_everything_on_device}')

                properties = preprocessed['data_properties']

                # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with
                # npy files
                proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
                while not proceed:
                    sleep(0.1)
                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)

                prediction = self.My_predict_logits_from_preprocessed_data(data, all_features).cpu()

                if ofile is not None:
                    # this needs to go into background processes
                    # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
                    #                               self.dataset_json, ofile, save_probabilities)
                    print('sending off prediction to background worker for resampling and export')
                    r.append(
                        export_pool.starmap_async(
                            export_prediction_from_logits,
                            ((prediction, properties, self.configuration_manager, self.plans_manager,
                              self.dataset_json, ofile, save_probabilities),)
                        )
                    )
                else:
                    # convert_predicted_logits_to_segmentation_with_correct_shape(
                    #             prediction, self.plans_manager,
                    #              self.configuration_manager, self.label_manager,
                    #              properties,
                    #              save_probabilities)

                    print('sending off prediction to background worker for resampling')
                    r.append(
                        export_pool.starmap_async(
                            convert_predicted_logits_to_segmentation_with_correct_shape, (
                                (prediction, self.plans_manager,
                                 self.configuration_manager, self.label_manager,
                                 properties,
                                 save_probabilities),)
                        )
                    )
                if ofile is not None:
                    print(f'done with {os.path.basename(ofile)}')
                else:
                    print(f'\nDone with image of shape {data.shape}:')
            ret = [i.get()[0] for i in r]

        if isinstance(data_iterator, MultiThreadedAugmenter):
            data_iterator._finish()

        # clear lru cache
        compute_gaussian.cache_clear()
        # clear device cache
        empty_cache(self.device)
        return ret

    def My_predict_from_files(self,
                            list_of_lists_or_source_folder: Union[str, List[List[str]]],
                            output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
                            save_probabilities: bool = False,
                            overwrite: bool = True,
                            num_processes_preprocessing: int = default_num_processes,
                            num_processes_segmentation_export: int = default_num_processes,
                            folder_with_segs_from_prev_stage: str = None,
                            num_parts: int = 1,
                            part_id: int = 0,
                            all_features: List = None):
            """
            This is nnU-Net's default function for making predictions. It works best for batch predictions
            (predicting many images at once).
            """
            print("in My_predict_from_files")
            if isinstance(output_folder_or_list_of_truncated_output_files, str):
                output_folder = output_folder_or_list_of_truncated_output_files
            elif isinstance(output_folder_or_list_of_truncated_output_files, list):
                output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
            else:
                output_folder = None

            ########################
            # let's store the input arguments so that its clear what was used to generate the prediction
            if output_folder is not None:
                my_init_kwargs = {}
                for k in inspect.signature(self.predict_from_files).parameters.keys():
                    my_init_kwargs[k] = locals()[k]
                my_init_kwargs = deepcopy(
                    my_init_kwargs)  # let's not unintentionally change anything in-place. Take this as a
                recursive_fix_for_json_export(my_init_kwargs)
                maybe_mkdir_p(output_folder)
                save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))

                # we need these two if we want to do things with the predictions like for example apply postprocessing
                save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
                save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
            #######################

            # check if we need a prediction from the previous stage
            if self.configuration_manager.previous_stage_name is not None:
                assert folder_with_segs_from_prev_stage is not None, \
                    f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
                    f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
                    f' they are located via folder_with_segs_from_prev_stage'

            # sort out input and output filenames
            list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
                self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                                    output_folder_or_list_of_truncated_output_files,
                                                    folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
                                                    save_probabilities)
            if len(list_of_lists_or_source_folder) == 0:
                return

            data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                                    seg_from_prev_stage_files,
                                                                                    output_filename_truncated,
                                                                                    num_processes_preprocessing)

            return self.My_predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export, all_features)
        ########################################### My func end ###########################################

if __name__ == '__main__':
    # predict a bunch of files
    from nnunetv2.paths import nnUNet_results, nnUNet_raw

    all_features = []

    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=False,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0),
        verbose=False, 
        verbose_preprocessing=False, 
        allow_tqdm=True
    )

    predictor.initialize_from_trained_model_folder(
        join(nnUNet_results, 'Dataset301_pretrain_removed/nnUNetTrainer__nnUNetPlans_forPretrain_removed_RFA4__2d'),
        use_folds=(0,),
        checkpoint_name='checkpoint_final.pth',
    )
    predictor.My_predict_from_files(join(nnUNet_raw, 'Dataset301_pretrain_removed/imagesTr_mini'),
                                 join(nnUNet_raw, 'Dataset301_pretrain_removed/imagesTr_prediction'), 
                                 save_probabilities=False, overwrite=True,
                                 num_processes_preprocessing=2, num_processes_segmentation_export=2,
                                 folder_with_segs_from_prev_stage=None, num_parts=1, part_id=0,
                                 all_features=all_features)
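If the goal is one feature entry per image, one option is to key the stored features by image instead of appending everything to one flat list, so that the tiles of each image stay together. Below is a plain-Python sketch of that bookkeeping. FeatureStore, start_image and add are hypothetical names, and the strings stand in for the real output.cpu() tensors so the sketch runs without PyTorch. One way to wire it in (untested against nnU-Net): call store.start_image(os.path.basename(ofile)) inside the "for preprocessed in data_iterator" loop before the prediction, and have hook_fn call store.add(output.cpu()) instead of all_features.append(...).

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional


class FeatureStore:
    """Collects hook outputs per image instead of in one shared flat list.

    Hypothetical helper for illustration; not part of nnU-Net.
    """

    def __init__(self) -> None:
        self.per_image: Dict[str, List[Any]] = defaultdict(list)
        self.current: Optional[str] = None

    def start_image(self, name: str) -> None:
        # call once per image, before running the sliding-window prediction
        self.current = name

    def add(self, feature: Any) -> None:
        # meant to be called from the forward hook: once per tile (and per fold)
        if self.current is None:
            raise RuntimeError("start_image() must be called before add()")
        self.per_image[self.current].append(feature)

    def counts(self) -> Dict[str, int]:
        # number of stored feature maps for each image
        return {name: len(feats) for name, feats in self.per_image.items()}


# Simulated run: 00014 fits in one tile, 00025 needs two (as in the log above)
store = FeatureStore()
for name, n_tiles in [("00014", 1), ("00025", 2)]:
    store.start_image(name)
    for tile in range(n_tiles):
        store.add(f"feature map of tile {tile}")  # stand-in for output.cpu()
print(store.counts())  # {'00014': 1, '00025': 2}
```

Whether to keep all tiles of an image or merge them (for example with torch.stack followed by a mean over the tile axis) is a separate choice; overlapping tiles cover different regions of the image, so averaging them is only a rough summary.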

U-ma-s commented 1 week ago

It does not occur with fewer than 9 images. It occurs with more than 10 images.