MIC-DKFZ / nnUNet

Apache License 2.0

Performance drop with region-based training ==> caching? #2136

Goblaski closed this issue 5 months ago

Goblaski commented 6 months ago

Let me start by saying that I really like the new region-based training feature. However, I notice a significant CPU-side drop in performance when enabling it, especially when creating regions with a large number of labels (e.g. > 10). This creates a CPU bottleneck quite quickly. In my case, time per epoch goes from ~120 seconds (~40 labels trained separately) to ~400 seconds (3 regions defined from these labels), with the CPU at 100% and the GPU barely used.

I was wondering if there is perhaps an easy (caching) option within the nnUNet framework to merge these labels based on the dataset.json file. Alternatively, I could manually merge the labels with some external code, but since the labels are already processed by the pipeline, there might be a good option to merge them beforehand.

TaWald commented 6 months ago

So currently in the augmentation the regions are mapped to a single channel each, which heavily impacts CPU performance. The impact increases with the total number of regions you have, so the behavior you observed is not surprising.
Currently I am not aware of an elegant way to keep CPU consumption down, and I am also not sure how caching would resolve this.

That being said, if you find an elegant workaround for this issue, it would be much appreciated :)

I will also ask around to see who else has encountered this issue and whether someone has already found a workaround for this CPU bottleneck.

FabianIsensee commented 6 months ago

The conversion from labels to regions is done at the very end of the data augmentation pipeline, so the complexity of the data augmentation is linked to the number of labels, not the regions. Please see ConvertSegmentationToRegionsTransform in get_training_transforms (nnUNetTrainer). The conversion from labels to regions is not very expensive:

# assuming seg is (b, *shape)
seg_as_regions = np.zeros((b, num_regions, *shape), dtype=np.uint8)
for i, rl in enumerate(regions):
    for l in rl:
        seg_as_regions[:, i][seg == l] = 1

(pseudocode above) I noticed that ConvertSegmentationToRegionsTransform is not perfect in its implementation; I will upload a revised version in ~30 minutes. Still, neither the current nor the new implementation should affect the CPU much. These operations are all rather cheap. Can you please double-check and verify that this is actually what is causing the slowdown?

Best, Fabian

FabianIsensee commented 6 months ago

Just uploaded the new ConvertSegmentationToRegionsTransform, check the master. It uses bitwise boolean operations now which should be faster. Please still look whether this is actually what's causing the problem. Region-based training uses a different loss as well
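
To make the two formulations discussed here concrete, the following is a minimal NumPy sketch (not the actual nnU-Net transform) that compares the loop from the pseudocode with a variant that accumulates each region using boolean OR. The function names and the toy data are assumptions made for illustration only.

```python
import numpy as np

def regions_by_loop(seg, regions):
    """Loop formulation from the pseudocode: one masked assignment per label."""
    b, *shape = seg.shape
    out = np.zeros((b, len(regions), *shape), dtype=np.uint8)
    for i, region_labels in enumerate(regions):
        for label in region_labels:
            out[:, i][seg == label] = 1
    return out

def regions_by_boolean_or(seg, regions):
    """Same result, accumulated with boolean OR into a bool array."""
    b, *shape = seg.shape
    out = np.zeros((b, len(regions), *shape), dtype=bool)
    for i, region_labels in enumerate(regions):
        for label in region_labels:
            out[:, i] |= seg == label
    return out

rng = np.random.default_rng(0)
seg = rng.integers(0, 4, size=(2, 8, 8))   # toy batch with labels 0..3
regions = ((1, 2, 3), (2, 3), (3,))        # nested (overlapping) regions
out_loop = regions_by_loop(seg, regions)
out_or = regions_by_boolean_or(seg, regions)
print(np.array_equal(out_loop, out_or.astype(np.uint8)))
```

Both variants do one comparison per label in each region, so the cost grows with the summed length of all region lists rather than with the number of regions alone.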

Goblaski commented 6 months ago

Great! I will check this as soon as my current training run is done or can be paused. I've tried to look into the code to see how predictions are treated with regard to regions. Is it correct that the output labels follow the regions as specified and won't undergo any merging/splitting? I will look into the loss function as well.

FabianIsensee commented 6 months ago

The network directly generates the regions. The number of network outputs is thus the number of regions

Goblaski commented 6 months ago

It doesn't seem to solve the issue. With a similar dataset and similar network but two different dataset.json files, I still get wildly different epoch times. The number of labels (from the source segmentation files) is the same, but in the second example I've grouped the labels into regions for my use case. The dataset used for this example was a TotalSegmentator sub-dataset.

The following dataset.json averages 160 seconds/epoch on my machine:

{
    "description": "",
    "labels": {
        "background": 0,
        "clavicula_left": 1,
        "rib_left_2": 10,
        "rib_left_3": 11,
        "rib_left_4": 12,
        "rib_left_5": 13,
        "rib_left_6": 14,
        "rib_left_7": 15,
        "rib_left_8": 16,
        "rib_left_9": 17,
        "rib_left_10": 18,
        "rib_left_11": 19,
        "clavicula_right": 2,
        "rib_left_12": 20,
        "rib_right_1": 21,
        "rib_right_2": 22,
        "rib_right_3": 23,
        "rib_right_4": 24,
        "rib_right_5": 25,
        "rib_right_6": 26,
        "rib_right_7": 27,
        "rib_right_8": 28,
        "rib_right_9": 29,
        "femur_left": 3,
        "rib_right_10": 30,
        "rib_right_11": 31,
        "rib_right_12": 32,
        "sacrum": 33,
        "scapula_left": 34,
        "scapula_right": 35,
        "vertebrae_C1": 36,
        "vertebrae_C2": 37,
        "vertebrae_C3": 38,
        "vertebrae_C4": 39,
        "femur_right": 4,
        "vertebrae_C5": 40,
        "vertebrae_C6": 41,
        "vertebrae_C7": 42,
        "vertebrae_L1": 43,
        "vertebrae_L2": 44,
        "vertebrae_L3": 45,
        "vertebrae_L4": 46,
        "vertebrae_L5": 47,
        "vertebrae_T1": 48,
        "vertebrae_T2": 49,
        "hip_left": 5,
        "vertebrae_T3": 50,
        "vertebrae_T4": 51,
        "vertebrae_T5": 52,
        "vertebrae_T6": 53,
        "vertebrae_T7": 54,
        "vertebrae_T8": 55,
        "vertebrae_T9": 56,
        "vertebrae_T10": 57,
        "vertebrae_T11": 58,
        "vertebrae_T12": 59,
        "hip_right": 6,
        "humerus_left": 7,
        "humerus_right": 8,
        "rib_left_1": 9
    },
    "licence": "",
    "name": "Task558_BoneSegDirect",
    "numTraining": 797,
    "reference": "",
    "release": "0.0",
    "channel_names": {
        "0": "CT-Scan"
    },
    "file_ending": ".nii.gz"
}

While the following gives me ~400 seconds/epoch:

{
    "description": "",
    "labels": {
        "background": 0,
        "bones": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
        "spine": [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
        "vertebrae_C1": 36,
        "vertebrae_C2": 37,
        "vertebrae_C3": 38,
        "vertebrae_C4": 39,
        "vertebrae_C5": 40,
        "vertebrae_C6": 41,
        "vertebrae_C7": 42,
        "vertebrae_L1": 43,
        "vertebrae_L2": 44,
        "vertebrae_L3": 45,
        "vertebrae_L4": 46,
        "vertebrae_L5": 47,
        "vertebrae_T1": 48,
        "vertebrae_T2": 49,
        "vertebrae_T3": 50,
        "vertebrae_T4": 51,
        "vertebrae_T5": 52,
        "vertebrae_T6": 53,
        "vertebrae_T7": 54,
        "vertebrae_T8": 55,
        "vertebrae_T9": 56,
        "vertebrae_T10": 57,
        "vertebrae_T11": 58,
        "vertebrae_T12": 59
    },
    "regions_class_order": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26],
    "licence": "",
    "name": "Task558_BoneSegDirect",
    "numTraining": 797,
    "reference": "",
    "release": "0.0",
    "channel_names": {
        "0": "CT-Scan"
    },
    "file_ending": ".nii.gz"
}
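
As a rough illustration of what the second dataset.json asks for, the sketch below derives one region per foreground entry of "labels" and counts them. The helper regions_from_labels is hypothetical (it is not nnU-Net's own parser), and the label names are rebuilt programmatically from the file above.

```python
def regions_from_labels(labels):
    """Return one list of label values per foreground entry, in dict order."""
    regions = []
    for name, value in labels.items():
        if name == "background":
            continue
        regions.append(list(value) if isinstance(value, (list, tuple)) else [value])
    return regions

# Vertebra names in the order of their label values 36..59.
vertebrae = ([f"vertebrae_C{i}" for i in range(1, 8)]
             + [f"vertebrae_L{i}" for i in range(1, 6)]
             + [f"vertebrae_T{i}" for i in range(1, 13)])

labels = {"background": 0,
          "bones": list(range(1, 60)),
          "spine": list(range(36, 60))}
labels.update({name: 36 + i for i, name in enumerate(vertebrae)})

regions = regions_from_labels(labels)
num_regions = len(regions)                       # number of region channels
num_comparisons = sum(len(r) for r in regions)   # label values listed across all regions
print(num_regions, num_comparisons)
```

The region count matches the 26 entries in regions_class_order, so the target has 26 channels instead of a single label map.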

After my current training is done (after the weekend I think) I will try to profile some parts of the code to see where the delay difference is coming from.

FabianIsensee commented 6 months ago

It's good to know that your problem arises from regions with many classes in them. I just pushed another optimization which should improve speed in your particular case. Let me know how this goes!

Goblaski commented 6 months ago

I've benchmarked the latest code optimization and came up with the following epoch train times (in seconds). All runs used the same network and the same config (apart from the output, of course). This was after a few epochs.

All labels separately [Case 1]: train time: 183.8743233680725, data loading: 39.64355134963989, train_step: 144.23077201843262, val time: 12.878538608551025

{
    "description": "",
    "labels": {
        "background": 0,
        "clavicula_left": 1,
        "rib_left_2": 10,
        "rib_left_3": 11,
        "rib_left_4": 12,
        "rib_left_5": 13,
        "rib_left_6": 14,
        "rib_left_7": 15,
        "rib_left_8": 16,
        "rib_left_9": 17,
        "rib_left_10": 18,
        "rib_left_11": 19,
        "clavicula_right": 2,
        "rib_left_12": 20,
        "rib_right_1": 21,
        "rib_right_2": 22,
        "rib_right_3": 23,
        "rib_right_4": 24,
        "rib_right_5": 25,
        "rib_right_6": 26,
        "rib_right_7": 27,
        "rib_right_8": 28,
        "rib_right_9": 29,
        "femur_left": 3,
        "rib_right_10": 30,
        "rib_right_11": 31,
        "rib_right_12": 32,
        "sacrum": 33,
        "scapula_left": 34,
        "scapula_right": 35,
        "vertebrae_C1": 36,
        "vertebrae_C2": 37,
        "vertebrae_C3": 38,
        "vertebrae_C4": 39,
        "femur_right": 4,
        "vertebrae_C5": 40,
        "vertebrae_C6": 41,
        "vertebrae_C7": 42,
        "vertebrae_L1": 43,
        "vertebrae_L2": 44,
        "vertebrae_L3": 45,
        "vertebrae_L4": 46,
        "vertebrae_L5": 47,
        "vertebrae_T1": 48,
        "vertebrae_T2": 49,
        "hip_left": 5,
        "vertebrae_T3": 50,
        "vertebrae_T4": 51,
        "vertebrae_T5": 52,
        "vertebrae_T6": 53,
        "vertebrae_T7": 54,
        "vertebrae_T8": 55,
        "vertebrae_T9": 56,
        "vertebrae_T10": 57,
        "vertebrae_T11": 58,
        "vertebrae_T12": 59,
        "hip_right": 6,
        "humerus_left": 7,
        "humerus_right": 8,
        "rib_left_1": 9
    },
    "licence": "",
    "name": "Task558_BoneSegDirect",
    "numTraining": 797,
    "reference": "",
    "release": "0.0",
    "channel_names": {
        "0": "CT-Scan"
    },
    "file_ending": ".nii.gz"
}

Two composite regions [Case 2]: train time: 185.98632621765137, data loading: 55.6424024105072, train_step: 130.34392380714417, val time: 14.759568929672241

{
    "description": "",
    "labels": {
        "background": 0,
        "bones": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
        "spine": [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59]
    },
    "regions_class_order": [1,2],
    "licence": "",
    "name": "Task558_BoneSegDirect",
    "numTraining": 797,
    "reference": "",
    "release": "0.0",
    "channel_names": {
        "0": "CT-Scan"
    },
    "file_ending": ".nii.gz"
}

Two regions + separate labels [Case 3]: train time: 396.13739562034607, data loading: 179.99116373062134, train_step: 216.14623188972473, val time: 81.28881335258484

{
    "description": "",
    "labels": {
        "background": 0,
        "bones": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
        "spine": [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
        "vertebrae_C1": 36,
        "vertebrae_C2": 37,
        "vertebrae_C3": 38,
        "vertebrae_C4": 39,
        "vertebrae_C5": 40,
        "vertebrae_C6": 41,
        "vertebrae_C7": 42,
        "vertebrae_L1": 43,
        "vertebrae_L2": 44,
        "vertebrae_L3": 45,
        "vertebrae_L4": 46,
        "vertebrae_L5": 47,
        "vertebrae_T1": 48,
        "vertebrae_T2": 49,
        "vertebrae_T3": 50,
        "vertebrae_T4": 51,
        "vertebrae_T5": 52,
        "vertebrae_T6": 53,
        "vertebrae_T7": 54,
        "vertebrae_T8": 55,
        "vertebrae_T9": 56,
        "vertebrae_T10": 57,
        "vertebrae_T11": 58,
        "vertebrae_T12": 59
    },
    "regions_class_order": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26],
    "licence": "",
    "name": "Task558_BoneSegDirect",
    "numTraining": 797,
    "reference": "",
    "release": "0.0",
    "channel_names": {
        "0": "CT-Scan"
    },
    "file_ending": ".nii.gz"
}

It seems that the more regions are used, the longer both data loading and training take (the loss?). Additionally, validation takes longer with the combination of regions and separate labels.
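
The share of each epoch spent on data loading can be worked out directly from the three sets of timings reported above. The short script below does only that arithmetic; the case names are shortened labels, and the numbers are copied from the benchmark.

```python
# (train time, data loading, train_step) in seconds, copied from the three cases.
cases = {
    "Case 1": (183.8743233680725, 39.64355134963989, 144.23077201843262),
    "Case 2": (185.98632621765137, 55.6424024105072, 130.34392380714417),
    "Case 3": (396.13739562034607, 179.99116373062134, 216.14623188972473),
}

shares = {}
for name, (train_time, data_loading, train_step) in cases.items():
    shares[name] = data_loading / train_time
    print(f"{name}: data loading is {shares[name]:.1%} of train time")
```

Data loading accounts for roughly 22%, 30% and 45% of the train time in Cases 1, 2 and 3 respectively, and in each case data loading plus train_step adds up to the reported train time.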

I haven't looked that deep into the nnUNet's dataloaders, but I will see if I can come up with some code that can create a new dataset based on the dataset.json file which will have pre-merged labels. However, I am wondering, if no regions are used, are overlapping labels allowed?

FabianIsensee commented 6 months ago

Normal training does not allow overlapping labels. If you are able to share a dataset generator that reproduces this problem we are happy to take a look! You can also just measure how long the loss computation takes to see if you can nail it down
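
One lightweight way to measure a single step such as the loss computation is a small timing context manager. This is a generic sketch using only the standard library; the name timed and the stand-in workload are assumptions, and nothing here is part of nnU-Net.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(name, results):
    """Append the wall-clock duration of the enclosed block to results[name]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        results.setdefault(name, []).append(time.perf_counter() - start)

results = {}
for _ in range(3):
    with timed("loss", results):
        sum(i * i for i in range(10_000))   # stand-in for the loss computation

mean_loss_time = sum(results["loss"]) / len(results["loss"])
print(f"mean loss time: {mean_loss_time * 1e3:.3f} ms")
```

When the timed block launches GPU work, CUDA kernels run asynchronously, so calling torch.cuda.synchronize() before reading the clock is usually needed for the numbers to be meaningful.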

Goblaski commented 6 months ago

A preprocessor I've made works for non-overlapping merging of regions (which kinda defeats the purpose of the original regions_class_order system), but does result in more lightweight regions. It ensures that the generated preprocessed files already contain the labels used in training. Alternatively, to re-enable the regions, I could check for overlap between regions, define each overlap as a separate label, and re-introduce the regions in a more "lightweight" form.

I've used the default Data Generator as provided to create this issue in combination with the TotalSegmentator Dataset (subset). I can upload the raw or preprocessed files for download to provide a dataset which demonstrates this issue if that is also acceptable.

Btw, below is the preprocessor hook that works for non-overlapping label merging, which might be a base for creating one that allows overlapping labels.

import multiprocessing
import shutil
from time import sleep
from typing import Union, Tuple

import nnunetv2
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
    create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets
from tqdm import tqdm

from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor

class RegionPreprocessor(DefaultPreprocessor):

    @staticmethod
    def convert_dataset_json(dataset_json_in):
        dataset_json = dataset_json_in.copy()
        mapping_lut = {}
        if 'regions_class_order' in dataset_json:
            labels = dataset_json['labels']
            mapping = {}
            label_indices = {key: idx for idx, key in enumerate(labels.keys())}

            for key, value in labels.items():
                if isinstance(value, list):
                    for v in value:
                        mapping[v] = label_indices[key]
                elif isinstance(value, int):
                    mapping[value] = label_indices[key]

            max_index = max(mapping.keys()) + 1
            mapping_lut = np.full(max_index, -1, dtype=int)
            for key, value in mapping.items():
                mapping_lut[key] = value

            dataset_json['labels'] = label_indices
            del dataset_json['regions_class_order']

        return dataset_json, mapping_lut

    def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
                      configuration_manager: ConfigurationManager) -> np.ndarray:
        if 'regions_class_order' in dataset_json:
            dataset_json, mapping_lut = self.convert_dataset_json(dataset_json)

            if np.max(seg) > 127:
                seg = seg.astype(np.int16)
            else:
                seg = seg.astype(np.int8)

            mapping_lut = mapping_lut.astype(seg.dtype)

            seg = mapping_lut[seg]

        return seg
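
The core of the hook above is a lookup table applied with NumPy integer indexing. A toy version, with a made-up merge mapping chosen only for illustration, looks like this:

```python
import numpy as np

# Hypothetical merge: original labels 1-3 become class 1, labels 4-5 become class 2.
merge = {1: 1, 2: 1, 3: 1, 4: 2, 5: 2}

# Entries not listed in `merge` (here only label 0) stay 0, i.e. background.
lut = np.zeros(max(merge) + 1, dtype=np.uint8)
for old_label, new_label in merge.items():
    lut[old_label] = new_label

seg = np.array([[0, 1, 2],
                [3, 4, 5]])
merged = lut[seg]   # integer indexing applies the table to every voxel at once
print(merged)       # [[0 1 1]
                    #  [1 2 2]]
```

Note that negative values in seg would index the table from its end in NumPy, so any such values would need to be handled separately before the lookup.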

Goblaski commented 6 months ago

Small optimization for regions of length 1. Since some regions may contain only a single value (as in my case), np.isin is a rather "expensive" operation for them.

In a test of 100 cases, the average time (in debug mode) went from ~0.181 seconds per call to ~0.124 seconds per call, a saving of about 30%, with identical results.

Unfortunately it didn't influence the training time, but still nice to save some performance. I will keep looking for other potential bottlenecks.

from typing import List, Tuple, Union

from batchgenerators.transforms.abstract_transforms import AbstractTransform
import numpy as np
from time import time

class ConvertSegmentationToRegionsTransform(AbstractTransform):
    def __init__(self, regions: Union[List, Tuple],
                 seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
        """
        regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region,
        example:
        regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2
        :param regions:
        :param seg_key:
        :param output_key:
        """
        self.seg_channel = seg_channel
        self.output_key = output_key
        self.seg_key = seg_key
        self.regions = regions

    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        if seg is not None:
            b, c, *shape = seg.shape
            region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
            for region_id, region_labels in enumerate(self.regions):
                if isinstance(region_labels, int) or len(region_labels) == 1:
                    if not isinstance(region_labels, int):
                        region_labels = region_labels[0]
                    region_output[:, region_id] = seg[:, self.seg_channel] == region_labels
                else:
                    region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
            data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)

        return data_dict

Goblaski commented 6 months ago

I also identified a function that is slower with many channels (and since the regions are individual channels, it makes sense that this also causes a delay). With this new implementation of DownsampleSegForDSTransform2 I am about 20% faster (380->300 seconds) on my benchmark. I think it might be possible to get even faster, since my implementation might not be optimal.

Edit: this version has a small error for the initial [1 1 1] scale. Since I am still looking into the issue, I will leave it here for now.

from typing import Tuple, Union, List

from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.transforms.abstract_transforms import AbstractTransform
import numpy as np
from time import time

class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales: Union[List, Tuple],
                 order: int = 0, input_key: str = "seg",
                 output_key: str = "seg", axes: Tuple[int] = None):
        """
        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specified one deep supervision
        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
        for each axis independently
        """
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.order = order
        self.ds_scales = ds_scales

    def encode_channels(self, data):
        encoded = np.zeros(data.shape[1:], dtype=np.uint64)
        for c in range(data.shape[0]):
            encoded |= data[c].astype(np.uint64) << c

        encoded = encoded.astype(np.float32)
        return encoded

    def decode_channels(self, data, num_channels):
        data = data.astype(np.uint64)
        decoded = np.zeros((num_channels,) + data.shape, dtype=bool)
        for c in range(num_channels):
            decoded[c] = (data >> c) & 1
        return decoded

    def __call__(self, **data_dict):
        if self.axes is None:
            axes = list(range(2, data_dict[self.input_key].ndim))
        else:
            axes = self.axes

        output = []
        for s in self.ds_scales:
            if not isinstance(s, (tuple, list)):
                s = [s] * len(axes)
            else:
                assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
                                            f'for each axis) then the number of entried in that tuple (here ' \
                                            f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'

            if all([i == 1 for i in s]):
                output.append(data_dict[self.input_key])
            else:
                new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                for i, a in enumerate(axes):
                    new_shape[a] *= s[i]
                new_shape = np.round(new_shape).astype(int)
                out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)
                for b in range(data_dict[self.input_key].shape[0]):
                    num_channels = data_dict[self.input_key].shape[1]
                    if num_channels > 1 and num_channels < 64: #UINT64 limit
                        # Encode channels 
                        encoded = self.encode_channels(data_dict[self.input_key][b])

                        # Resize (as usual)
                        data_resized = resize_segmentation(encoded, new_shape[2:], self.order)
                        data_resized = data_resized.astype(np.uint64)

                        # Decode back to region shape
                        decoded = self.decode_channels(data_resized,num_channels)

                        out_seg[b] = decoded.astype(out_seg.dtype)
                    else:
                        for c in range(num_channels):
                            out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)

                output.append(out_seg)
        data_dict[self.output_key] = output
        return data_dict
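
The bit-packing idea can be checked in isolation with a round trip in plain NumPy. The sketch below uses hypothetical helper names and keeps everything in integer dtypes. It also shows one caveat that may be worth checking in the code above: float32 has a 24-bit significand, so integers above 2**24 are not all exactly representable, and a packed value that uses more than 24 bits may not survive a cast to float32.

```python
import numpy as np

def pack_channels(channels):
    """Pack a (c, *shape) binary array into one int64 array, channel k -> bit k."""
    packed = np.zeros(channels.shape[1:], dtype=np.int64)
    for k in range(channels.shape[0]):
        packed |= channels[k].astype(np.int64) << k
    return packed

def unpack_channels(packed, num_channels):
    """Inverse of pack_channels: returns a (num_channels, *shape) bool array."""
    out = np.zeros((num_channels, *packed.shape), dtype=bool)
    for k in range(num_channels):
        out[k] = (packed >> k) & 1
    return out

rng = np.random.default_rng(0)
channels = rng.integers(0, 2, size=(26, 4, 4, 4)).astype(bool)   # 26 region channels
packed = pack_channels(channels)
roundtrip_ok = np.array_equal(unpack_channels(packed, 26), channels)

# 2**24 + 1 needs 25 significant bits, one more than float32 provides.
float32_exact = int(np.float32(2**24 + 1)) == 2**24 + 1
print(roundtrip_ok, float32_exact)   # True False
```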

FabianIsensee commented 6 months ago

Hey, please try to find the bottleneck first before doing optimizations: is the bottleneck really the data augmentation? Or is it in the loss computation? This is straightforward to find out and a lot less effort than diving deep into the different components of nnU-Net ;-)

Goblaski commented 6 months ago

Loss timing is roughly the same in all 3 cases I've tested: ~3 ms. Prediction/inference takes about 20-30 ms for my current network. My bad, I should have reported that earlier.

Data loading (from the SSD) is also the same across all 3 cases. I've tracked it down to the data augmentation, where the time for two or three steps varies wildly when there are multiple regions (and therefore channels):

  1. SpatialTransform (in some cases; seems to be triggered randomly)
  2. DownsampleSegForDSTransform2 (consistently, but now reduced with the above method)
  3. SimulateLowResolutionTransform (in some cases; seems to be triggered randomly, and the non-region cases seem less affected)

I've been able to track it down to every case where samples of shape [b,c,x,y,z] have c > 1. So far I have confirmed this for DownsampleSegForDSTransform2.

In the case of SpatialTransform I think it is in the function interpolate_img.

FabianIsensee commented 6 months ago

OK so it's confirmed that you lose a lot of time in the data augmentation? You can try something like this to track down where the problem is:

https://github.com/MIC-DKFZ/batchgeneratorsv2/blob/master/batchgeneratorsv2/benchmarks/bg_comparison/nnUNet_pipeline_bg.py

(simulate the segmentations and regions to match your dataset)
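
Independently of the linked benchmark script, a standalone sketch of the "simulate the segmentations and regions" idea could look like the following. It generates a synthetic label map with labels 0..59 and times a simple np.isin-based region conversion for a 2-region and a 26-region setup. The function to_regions, the array sizes and the configurations are all assumptions chosen to resemble the cases in this thread.

```python
import time
import numpy as np

def to_regions(seg, regions):
    """Convert a (b, 1, *shape) label map into (b, num_regions, *shape) bool channels."""
    out = np.zeros((seg.shape[0], len(regions), *seg.shape[2:]), dtype=bool)
    for i, region_labels in enumerate(regions):
        out[:, i] = np.isin(seg[:, 0], region_labels)
    return out

rng = np.random.default_rng(0)
seg = rng.integers(0, 60, size=(2, 1, 32, 32, 32))   # synthetic labels 0..59

bones, spine = list(range(1, 60)), list(range(36, 60))
configs = {
    "2 regions": [bones, spine],
    "26 regions": [bones, spine] + [[label] for label in range(36, 60)],
}

timings, shapes = {}, {}
for name, regions in configs.items():
    start = time.perf_counter()
    out = to_regions(seg, regions)
    timings[name] = time.perf_counter() - start
    shapes[name] = out.shape
    print(f"{name}: {timings[name] * 1e3:.2f} ms, output shape {out.shape}")
```

The same pattern can be wrapped around any individual transform to see how its runtime scales with the number of channels.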

Goblaski commented 6 months ago

After a long time debugging, I've found that there were two major culprits, and I am now back down to ~190 seconds for the task that took me ~400 seconds before. The training run shows results identical to the 400-second run, so this doesn't seem to break anything. I can make a pull request if useful.

There are still a few optimizations that could be made, but those would be trade-offs between CPU cycles and RAM usage.

The culprits were:

  1. DownsampleSegForDSTransform2 takes a long time because resize_segmentation has to be run for every channel (https://github.com/MIC-DKFZ/nnUNet/blob/433dd9ca2ca77f3197c3c3cb6c61e009007a4dc1/nnunetv2/training/data_augmentation/custom_transforms/deep_supervision_donwsampling.py#L52). ~0.3 seconds in single-channel mode, ~4 seconds in multi-channel mode; now down to about 1.3 seconds after the fix.
  2. Copying the many channels to the device is also slow (https://github.com/MIC-DKFZ/nnUNet/blob/433dd9ca2ca77f3197c3c3cb6c61e009007a4dc1/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py#L963). ~0.01 seconds in single-channel mode, ~0.3 seconds in multi-channel mode (with high CPU load). After the fix, the speed is identical to single-channel mode.

What I've done:

  1. Updated DownsampleSegForDSTransform2 to first lazily encode the channels into a single channel using bitwise operators (cheap).
  2. Added a bitwise decoder for torch to do the bitwise decoding on the GPU in the training and validation steps. For that I've added a file, utilities/bitwise_converters.py.

The actual fix:

Updated deep_supervision_donwsampling.py:

from typing import Tuple, Union, List

from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.transforms.abstract_transforms import AbstractTransform
import numpy as np
from time import time

class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales: Union[List, Tuple],
                 order: int = 0, input_key: str = "seg",
                 output_key: str = "seg", axes: Tuple[int] = None):
        """
        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specified one deep supervision
        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
        for each axis independently
        """
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.order = order
        self.ds_scales = ds_scales

    def encode_channels(self, data):
        encoded = np.zeros(data.shape[1:], dtype=np.int64)
        for c in range(data.shape[0]):
            encoded |= data[c].astype(np.int64) << c

        encoded = encoded.astype(np.float32)
        return encoded

    # Unused, but for reference to decode channels if required
    def decode_channels(self, data, num_channels):
        data = data.astype(np.int64)
        decoded = np.zeros((num_channels,) + data.shape, dtype=bool)
        for c in range(num_channels):
            decoded[c] = (data >> c) & 1
        return decoded

    def __call__(self, **data_dict):
        if self.axes is None:
            axes = list(range(2, data_dict[self.input_key].ndim))
        else:
            axes = self.axes

        # First determine the number of channels and do the bitwise encoding. This is a lazy function and assumes that a single segmentation can occur in every channel.
        num_channels = data_dict[self.input_key].shape[1]
        # We use int64 because default PyTorch lacks native uint64 support, and since I'd rather not use the negative range of int64, we will only use at most 31 channels.
        # This can of course be optimized for memory use with a small impact on CPU cycles.
        if num_channels > 1 and num_channels < 32: 
            encoded_batch = []
            for b in range(data_dict[self.input_key].shape[0]):
                encoded_batch.append(self.encode_channels(data_dict[self.input_key][b]))

        output = []
        for s in self.ds_scales:
            if not isinstance(s, (tuple, list)):
                s = [s] * len(axes)
            else:
                assert len(s) == len(axes), f'If ds_scales is a tuple for each resolution (one downsampling factor ' \
                                            f'for each axis) then the number of entried in that tuple (here ' \
                                            f'{len(s)}) must be the same as the number of axes (here {len(axes)}).'

            # Check again if we have the right amount of channels for encoding.
            if num_channels > 1 and num_channels < 32:
                new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                for i, a in enumerate(axes):
                    new_shape[a] *= s[i]

                new_shape[1] = 1
                new_shape = np.round(new_shape).astype(int)
                out_seg = np.zeros(new_shape, dtype=np.float32)

                for b in range(data_dict[self.input_key].shape[0]):
                    # Encode channels 
                    encoded = encoded_batch[b]

                    if not all([i == 1 for i in s]):
                        # Resize (as usual)
                        data_resized = resize_segmentation(encoded, new_shape[2:], self.order)
                        data_resized = data_resized.astype(np.uint64)

                        # Assign
                        out_seg[b] = data_resized.astype(out_seg.dtype)
                    else:
                        out_seg[b] = encoded.astype(out_seg.dtype)

                output.append(out_seg)
            else: # Non-encodable
                if all([i == 1 for i in s]):
                    output.append(data_dict[self.input_key])
                else:
                    num_channels = data_dict[self.input_key].shape[1]
                    new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                    for i, a in enumerate(axes):
                        new_shape[a] *= s[i]

                    new_shape = np.round(new_shape).astype(int)
                    out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)

                    for b in range(data_dict[self.input_key].shape[0]):
                        for c in range(num_channels):
                            out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)

                    output.append(out_seg)
        data_dict[self.output_key] = output
        return data_dict

First part of nnUNetTrainer.py train_step:

    def train_step(self, batch: dict) -> dict:
        data = batch['data']
        target = batch['target']

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        if self.label_manager.num_segmentation_heads > 1 and self.label_manager.num_segmentation_heads < 32:
            target = [decode_channels_torch(t,self.label_manager.num_segmentation_heads) for t in target]

First part of nnUNetTrainer.py validation_step:

    def validation_step(self, batch: dict) -> dict:
        data = batch['data']
        target = batch['target']

        data = data.to(self.device, non_blocking=True)
        if isinstance(target, list):
            target = [i.to(self.device, non_blocking=True) for i in target]
        else:
            target = target.to(self.device, non_blocking=True)

        if self.label_manager.num_segmentation_heads > 1 and self.label_manager.num_segmentation_heads < 32:
            target = [decode_channels_torch(t,self.label_manager.num_segmentation_heads) for t in target]

bitwise_converters.py:

import torch

def encode_channels_torch(data : torch.Tensor):
    # Assuming shape: [b, c, *dims] (dims can be x, y, z, w, ...)
    b, c, *dims = data.shape
    encoded = torch.zeros((b, 1, *dims), dtype=torch.int64, device=data.device)
    data_tmp = data.to(torch.int64)
    for j in range(c):
        encoded[:, 0] |= data_tmp[:, j] << j
    return encoded

def decode_channels_torch(data : torch.Tensor, num_channels : int):
    # Assuming shape: [b, 1, *dims]
    b, _, *dims = data.shape
    decoded = torch.zeros((b, num_channels, *dims), dtype=bool, device=data.device)
    data_tmp = data.to(torch.int64, non_blocking=True)
    for j in range(num_channels):
        decoded[:, j] = (data_tmp[:, 0] >> j) & 1

    decoded = decoded.to(torch.float32, non_blocking=True)
    return decoded

FabianIsensee commented 5 months ago

Hey, really cool stuff! I like the idea of bitwise encoding, but there is a much easier solution: just make the dtype of the regions bool. This is now the case in the current master (and batchgenerators 2 master) and should alleviate your issues!