So currently in the augmentation the regions are mapped to a single channel each, which heavily impacts CPU performance. The impact increases with the total number of regions you have, so your observed behavior is not surprising.
I am currently not aware of an elegant way of keeping CPU consumption down, and I am also not sure how caching would resolve this.
That being said, if you find an elegant workaround for this issue it would be much appreciated :)
I will also ask around to see who else has encountered this issue and whether someone has already found a workaround for this CPU bottleneck.
The conversion from labels to regions is done at the very end of the data augmentation pipeline, so the complexity of the data augmentation is linked to the number of labels, not the number of regions. Please see ConvertSegmentationToRegionsTransform in get_training_transforms (nnUNetTrainer).
The conversion from labels to regions is not very expensive:
# assuming seg is (b, *shape)
seg_as_regions = np.zeros((b, num_regions, *shape), dtype=np.uint8)
for i, rl in enumerate(regions):
    for l in rl:
        seg_as_regions[:, i][seg == l] = 1
(pseudocode :arrow_up: ) I noticed that ConvertSegmentationToRegionsTransform is not perfect in its implementation; I will upload a revised version in ~30 minutes. Still, neither the current nor the new implementation should affect the CPU much. These operations are all rather cheap. Can you please double check and verify that this is actually what is causing the slowdown? Best, Fabian
Just uploaded the new ConvertSegmentationToRegionsTransform, check the master. It uses bitwise boolean operations now, which should be faster. Please still check whether this is actually what's causing the problem. Region-based training uses a different loss as well.
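(For reference, the new transform itself is not reproduced here; a minimal sketch of the idea, not the actual master code, would be to build each region mask by OR-ing boolean equality maps instead of writing 1s into a uint8 array label by label.)

import numpy as np

def regions_from_labels(seg: np.ndarray, regions) -> np.ndarray:
    # seg: (b, *spatial) integer label map; regions: tuple of tuples of label ids
    b, *spatial = seg.shape
    out = np.zeros((b, len(regions), *spatial), dtype=bool)
    for i, region_labels in enumerate(regions):
        for label in region_labels:
            out[:, i] |= (seg == label)  # bitwise OR on boolean arrays, no uint8 writes
    return out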
Great! I will check this as soon as this training process is done or can be paused. I've tried to look into the code to see how predictions are treated with regard to regions. Is it correct that the output labels follow the regions as specified and won't undergo any merging/splitting? I will look into the loss function as well.
The network directly generates the regions; the number of network outputs is thus the number of regions.
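(For reference, a rough sketch of what this means at export time: the per-region sigmoid outputs are thresholded and written into a label map in the order given by regions_class_order. This is a simplification, not the exact nnU-Net code.)

import numpy as np

def regions_to_labelmap(region_probs: np.ndarray, regions_class_order) -> np.ndarray:
    # region_probs: (num_regions, *spatial) sigmoid outputs, one channel per region
    labelmap = np.zeros(region_probs.shape[1:], dtype=np.uint8)
    for i, label_value in enumerate(regions_class_order):
        labelmap[region_probs[i] > 0.5] = label_value  # later regions overwrite earlier ones
    return labelmap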
It doesn't seem to solve the issue. With a similar dataset and similar network but two different dataset.json files I still get wildly different epoch times. The number of labels (from the source segmentation files) is the same, but in the second example I've grouped the labels into regions for my use case. The dataset used for this example was a TotalSegmentator sub-dataset.
The following dataset.json averages 160 seconds/epoch on my machine:
{
"description": "",
"labels": {
"background": 0,
"clavicula_left": 1,
"rib_left_2": 10,
"rib_left_3": 11,
"rib_left_4": 12,
"rib_left_5": 13,
"rib_left_6": 14,
"rib_left_7": 15,
"rib_left_8": 16,
"rib_left_9": 17,
"rib_left_10": 18,
"rib_left_11": 19,
"clavicula_right": 2,
"rib_left_12": 20,
"rib_right_1": 21,
"rib_right_2": 22,
"rib_right_3": 23,
"rib_right_4": 24,
"rib_right_5": 25,
"rib_right_6": 26,
"rib_right_7": 27,
"rib_right_8": 28,
"rib_right_9": 29,
"femur_left": 3,
"rib_right_10": 30,
"rib_right_11": 31,
"rib_right_12": 32,
"sacrum": 33,
"scapula_left": 34,
"scapula_right": 35,
"vertebrae_C1": 36,
"vertebrae_C2": 37,
"vertebrae_C3": 38,
"vertebrae_C4": 39,
"femur_right": 4,
"vertebrae_C5": 40,
"vertebrae_C6": 41,
"vertebrae_C7": 42,
"vertebrae_L1": 43,
"vertebrae_L2": 44,
"vertebrae_L3": 45,
"vertebrae_L4": 46,
"vertebrae_L5": 47,
"vertebrae_T1": 48,
"vertebrae_T2": 49,
"hip_left": 5,
"vertebrae_T3": 50,
"vertebrae_T4": 51,
"vertebrae_T5": 52,
"vertebrae_T6": 53,
"vertebrae_T7": 54,
"vertebrae_T8": 55,
"vertebrae_T9": 56,
"vertebrae_T10": 57,
"vertebrae_T11": 58,
"vertebrae_T12": 59,
"hip_right": 6,
"humerus_left": 7,
"humerus_right": 8,
"rib_left_1": 9
},
"licence": "",
"name": "Task558_BoneSegDirect",
"numTraining": 797,
"reference": "",
"release": "0.0",
"channel_names": {
"0": "CT-Scan"
},
"file_ending": ".nii.gz"
}
While the following gives me ~400 seconds/epoch:
{
"description": "",
"labels": {
"background": 0,
"bones": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
"spine": [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
"vertebrae_C1": 36,
"vertebrae_C2": 37,
"vertebrae_C3": 38,
"vertebrae_C4": 39,
"vertebrae_C5": 40,
"vertebrae_C6": 41,
"vertebrae_C7": 42,
"vertebrae_L1": 43,
"vertebrae_L2": 44,
"vertebrae_L3": 45,
"vertebrae_L4": 46,
"vertebrae_L5": 47,
"vertebrae_T1": 48,
"vertebrae_T2": 49,
"vertebrae_T3": 50,
"vertebrae_T4": 51,
"vertebrae_T5": 52,
"vertebrae_T6": 53,
"vertebrae_T7": 54,
"vertebrae_T8": 55,
"vertebrae_T9": 56,
"vertebrae_T10": 57,
"vertebrae_T11": 58,
"vertebrae_T12": 59
},
"regions_class_order": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26],
"licence": "",
"name": "Task558_BoneSegDirect",
"numTraining": 797,
"reference": "",
"release": "0.0",
"channel_names": {
"0": "CT-Scan"
},
"file_ending": ".nii.gz"
}
After my current training is done (after the weekend I think) I will try to profile some parts of the code to see where the delay difference is coming from.
It's good to know that your problem arises from regions with many classes in them. I just pushed another optimization which should improve speed in your particular case. Let me know how this goes!
I've benchmarked the latest code optimization and came up with the following epoch train times. All runs used the same network and the same config (minus the outputs, of course). This was measured after a few epochs.
All labels separately [Case 1]: train time 183.87 s, data loading 39.64 s, train_step 144.23 s, val time 12.88 s
{
"description": "",
"labels": {
"background": 0,
"clavicula_left": 1,
"rib_left_2": 10,
"rib_left_3": 11,
"rib_left_4": 12,
"rib_left_5": 13,
"rib_left_6": 14,
"rib_left_7": 15,
"rib_left_8": 16,
"rib_left_9": 17,
"rib_left_10": 18,
"rib_left_11": 19,
"clavicula_right": 2,
"rib_left_12": 20,
"rib_right_1": 21,
"rib_right_2": 22,
"rib_right_3": 23,
"rib_right_4": 24,
"rib_right_5": 25,
"rib_right_6": 26,
"rib_right_7": 27,
"rib_right_8": 28,
"rib_right_9": 29,
"femur_left": 3,
"rib_right_10": 30,
"rib_right_11": 31,
"rib_right_12": 32,
"sacrum": 33,
"scapula_left": 34,
"scapula_right": 35,
"vertebrae_C1": 36,
"vertebrae_C2": 37,
"vertebrae_C3": 38,
"vertebrae_C4": 39,
"femur_right": 4,
"vertebrae_C5": 40,
"vertebrae_C6": 41,
"vertebrae_C7": 42,
"vertebrae_L1": 43,
"vertebrae_L2": 44,
"vertebrae_L3": 45,
"vertebrae_L4": 46,
"vertebrae_L5": 47,
"vertebrae_T1": 48,
"vertebrae_T2": 49,
"hip_left": 5,
"vertebrae_T3": 50,
"vertebrae_T4": 51,
"vertebrae_T5": 52,
"vertebrae_T6": 53,
"vertebrae_T7": 54,
"vertebrae_T8": 55,
"vertebrae_T9": 56,
"vertebrae_T10": 57,
"vertebrae_T11": 58,
"vertebrae_T12": 59,
"hip_right": 6,
"humerus_left": 7,
"humerus_right": 8,
"rib_left_1": 9
},
"licence": "",
"name": "Task558_BoneSegDirect",
"numTraining": 797,
"reference": "",
"release": "0.0",
"channel_names": {
"0": "CT-Scan"
},
"file_ending": ".nii.gz"
}
Two composite regions [Case 2]: train time 185.99 s, data loading 55.64 s, train_step 130.34 s, val time 14.76 s
{
"description": "",
"labels": {
"background": 0,
"bones": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
"spine": [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59]
},
"regions_class_order": [1,2],
"licence": "",
"name": "Task558_BoneSegDirect",
"numTraining": 797,
"reference": "",
"release": "0.0",
"channel_names": {
"0": "CT-Scan"
},
"file_ending": ".nii.gz"
}
Two regions + separate labels [Case 3]: train time 396.14 s, data loading 179.99 s, train_step 216.15 s, val time 81.29 s
{
"description": "",
"labels": {
"background": 0,
"bones": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
"spine": [36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59],
"vertebrae_C1": 36,
"vertebrae_C2": 37,
"vertebrae_C3": 38,
"vertebrae_C4": 39,
"vertebrae_C5": 40,
"vertebrae_C6": 41,
"vertebrae_C7": 42,
"vertebrae_L1": 43,
"vertebrae_L2": 44,
"vertebrae_L3": 45,
"vertebrae_L4": 46,
"vertebrae_L5": 47,
"vertebrae_T1": 48,
"vertebrae_T2": 49,
"vertebrae_T3": 50,
"vertebrae_T4": 51,
"vertebrae_T5": 52,
"vertebrae_T6": 53,
"vertebrae_T7": 54,
"vertebrae_T8": 55,
"vertebrae_T9": 56,
"vertebrae_T10": 57,
"vertebrae_T11": 58,
"vertebrae_T12": 59
},
"regions_class_order": [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26],
"licence": "",
"name": "Task558_BoneSegDirect",
"numTraining": 797,
"reference": "",
"release": "0.0",
"channel_names": {
"0": "CT-Scan"
},
"file_ending": ".nii.gz"
}
It seems that the more regions are used, the longer both the data loading and the train step take (the loss, perhaps?). Additionally, validation takes longer with the combination of regions and separate labels.
I haven't looked that deeply into nnU-Net's dataloaders, but I will see if I can come up with some code that creates a new dataset, based on the dataset.json file, with pre-merged labels. However, I am wondering: if no regions are used, are overlapping labels allowed?
Normal training does not allow overlapping labels. If you are able to share a dataset generator that reproduces this problem, we are happy to take a look! You can also just measure how long the loss computation takes to see if you can nail it down.
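(For reference, one way to measure the loss time would be to instrument train_step roughly as below. This is not part of nnU-Net; output and target are the variables already present in train_step, and CUDA synchronization is needed for meaningful GPU timings.)

import time
import torch

torch.cuda.synchronize()            # finish pending GPU work before starting the clock
t0 = time.time()
l = self.loss(output, target)       # 'output' and 'target' as in the existing train_step
torch.cuda.synchronize()            # wait for the loss kernels to finish
print(f'loss computation took {time.time() - t0:.4f} s')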
A preprocessor I've made works for non-overlapping merging of regions (which somewhat defeats the purpose of the original regions_class_order system), but does result in more lightweight regions. This ensures that the generated preprocessed files already contain the labels used in training. Alternatively, to re-enable the regions, I could check for overlap between regions, define the overlap as a separate label, and re-introduce the more "lightweight" version.
I've used the default data generator as provided, in combination with the TotalSegmentator dataset (subset), to reproduce this issue. I can upload the raw or preprocessed files for download to provide a dataset that demonstrates the issue, if that is acceptable.
Btw, below is the preprocessor hook that performs non-overlapping label merging; it might be a basis for creating one that allows overlapping labels.
import multiprocessing
import shutil
from time import sleep
from typing import Union, Tuple

import nnunetv2
import numpy as np
from batchgenerators.utilities.file_and_folder_operations import *
from nnunetv2.paths import nnUNet_preprocessed, nnUNet_raw
from nnunetv2.preprocessing.cropping.cropping import crop_to_nonzero
from nnunetv2.preprocessing.resampling.default_resampling import compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.find_class_by_name import recursive_find_python_class
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.utilities.utils import get_identifiers_from_splitted_dataset_folder, \
    create_lists_from_splitted_dataset_folder, get_filenames_of_train_images_and_targets
from tqdm import tqdm

from nnunetv2.preprocessing.preprocessors.default_preprocessor import DefaultPreprocessor


class RegionPreprocessor(DefaultPreprocessor):
    @staticmethod
    def convert_dataset_json(dataset_json_in):
        # builds a lookup table that maps every original label value to the index of its
        # (possibly merged) entry in the labels dict and removes regions_class_order
        dataset_json = dataset_json_in.copy()
        mapping_lut = {}
        if 'regions_class_order' in dataset_json:
            labels = dataset_json['labels']
            mapping = {}
            label_indices = {key: idx for idx, key in enumerate(labels.keys())}
            for key, value in labels.items():
                if isinstance(value, list):
                    for v in value:
                        mapping[v] = label_indices[key]
                elif isinstance(value, int):
                    mapping[value] = label_indices[key]
            max_index = max(mapping.keys()) + 1
            mapping_lut = np.full(max_index, -1, dtype=int)
            for key, value in mapping.items():
                mapping_lut[key] = value
            dataset_json['labels'] = label_indices
            del dataset_json['regions_class_order']
        return dataset_json, mapping_lut

    def modify_seg_fn(self, seg: np.ndarray, plans_manager: PlansManager, dataset_json: dict,
                      configuration_manager: ConfigurationManager) -> np.ndarray:
        if 'regions_class_order' in dataset_json:
            dataset_json, mapping_lut = self.convert_dataset_json(dataset_json)
            if np.max(seg) > 127:
                seg = seg.astype(np.int16)
            else:
                seg = seg.astype(np.int8)
            mapping_lut = mapping_lut.astype(seg.dtype)
            seg = mapping_lut[seg]  # remap all labels in one vectorized lookup
        return seg
Small optimization for length-1 regions: since there are regions with only a single label value (as in my case), np.isin is a rather "expensive" operation for those.
In a test of 100 cases the average time (in debug) went from ~0.181 seconds per call to ~0.124 seconds per call, about a 30% saving, with identical results.
Unfortunately it didn't influence the training time, but it's still nice to save some performance. I will keep looking for other potential bottlenecks.
from typing import List, Tuple, Union

import numpy as np
from batchgenerators.transforms.abstract_transforms import AbstractTransform


class ConvertSegmentationToRegionsTransform(AbstractTransform):
    def __init__(self, regions: Union[List, Tuple],
                 seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
        """
        regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region,
        example:
        regions = ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1 & 2 and the other just 2
        :param regions:
        :param seg_key:
        :param output_key:
        """
        self.seg_channel = seg_channel
        self.output_key = output_key
        self.seg_key = seg_key
        self.regions = regions

    def __call__(self, **data_dict):
        seg = data_dict.get(self.seg_key)
        if seg is not None:
            b, c, *shape = seg.shape
            region_output = np.zeros((b, len(self.regions), *shape), dtype=bool)
            for region_id, region_labels in enumerate(self.regions):
                if isinstance(region_labels, int) or len(region_labels) == 1:
                    # single-label region: a plain equality check is much cheaper than np.isin
                    if not isinstance(region_labels, int):
                        region_labels = region_labels[0]
                    region_output[:, region_id] = seg[:, self.seg_channel] == region_labels
                else:
                    region_output[:, region_id] |= np.isin(seg[:, self.seg_channel], region_labels)
            data_dict[self.output_key] = region_output.astype(np.uint8, copy=False)
        return data_dict
I also identified a function that is slower with many channels (and since the regions are individual channels, this would explain a delay there as well). With this new implementation of DownsampleSegForDSTransform2 I am about 20% faster (380 -> 300 seconds) on my benchmark. I think it could be made even faster, since my implementation is probably not optimal.
Edit: this version has a small error for the initial [1, 1, 1] scale. Since I am still looking into the issue I will leave it here for now.
from typing import Tuple, Union, List

import numpy as np
from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.transforms.abstract_transforms import AbstractTransform


class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales: Union[List, Tuple],
                 order: int = 0, input_key: str = "seg",
                 output_key: str = "seg", axes: Tuple[int] = None):
        """
        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision
        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
        for each axis independently
        """
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.order = order
        self.ds_scales = ds_scales

    def encode_channels(self, data):
        # pack the binary region channels of one sample into a single bit mask per voxel
        encoded = np.zeros(data.shape[1:], dtype=np.uint64)
        for c in range(data.shape[0]):
            encoded |= data[c].astype(np.uint64) << c
        encoded = encoded.astype(np.float32)
        return encoded

    def decode_channels(self, data, num_channels):
        # unpack the bit mask back into one boolean channel per region
        data = data.astype(np.uint64)
        decoded = np.zeros((num_channels,) + data.shape, dtype=bool)
        for c in range(num_channels):
            decoded[c] = (data >> c) & 1
        return decoded

    def __call__(self, **data_dict):
        if self.axes is None:
            axes = list(range(2, data_dict[self.input_key].ndim))
        else:
            axes = self.axes
        output = []
        for s in self.ds_scales:
            if not isinstance(s, (tuple, list)):
                s = [s] * len(axes)
            else:
                assert len(s) == len(axes), \
                    f'If ds_scales is a tuple for each resolution (one downsampling factor for each axis) then the ' \
                    f'number of entries in that tuple (here {len(s)}) must be the same as the number of axes (here {len(axes)}).'
            if all([i == 1 for i in s]):
                output.append(data_dict[self.input_key])
            else:
                new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                for i, a in enumerate(axes):
                    new_shape[a] *= s[i]
                new_shape = np.round(new_shape).astype(int)
                out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)
                for b in range(data_dict[self.input_key].shape[0]):
                    num_channels = data_dict[self.input_key].shape[1]
                    if num_channels > 1 and num_channels < 64:  # uint64 limit
                        # Encode channels
                        encoded = self.encode_channels(data_dict[self.input_key][b])
                        # Resize (as usual)
                        data_resized = resize_segmentation(encoded, new_shape[2:], self.order)
                        data_resized = data_resized.astype(np.uint64)
                        # Decode back to region shape
                        decoded = self.decode_channels(data_resized, num_channels)
                        out_seg[b] = decoded.astype(out_seg.dtype)
                    else:
                        for c in range(num_channels):
                            out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)
                output.append(out_seg)
        data_dict[self.output_key] = output
        return data_dict
Hey, please try to find the bottleneck first before doing optimizations: is the bottleneck really the data augmentation, or is it in the loss computation? This is straightforward to find out and a lot less effort than diving deep into the different components of nnU-Net ;-)
Loss timing is roughly the same in all 3 cases I've tested, ~3 ms. Prediction/inference is about 20-30 ms for my current network. My bad, I should have reported that earlier.
Data loading (from the SSD) is also the same across all 3 cases. I've tracked it down to the data augmentation, which varies wildly in 2-3 steps whenever there are multiple regions (and therefore channels):
SpatialTransform (in some cases, seems to be triggered randomly), DownsampleSegForDSTransform2 (consistently, but now reduced with the above method) and SimulateLowResolutionTransform (in some cases, seems to be triggered randomly; the non-region cases seem less affected).
I've been able to track it down to every case where samples of shape [b, c, x, y, z] have c > 1. So far this has held for DownsampleSegForDSTransform2.
In the case of SpatialTransform I think the time is spent in the function interpolate_img.
OK, so it's confirmed that you lose a lot of time in the data augmentation? You can try something like this to track down where the problem is (simulate the segmentations and regions to match your dataset):
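(The snippet referred to here is not reproduced in this thread; a minimal stand-in could be to time the suspect transforms one by one on a simulated batch. The shapes, label range and deep supervision scales below are assumptions, and the two transform classes are the ones already posted above.)

import numpy as np
from time import time

seg = np.random.randint(0, 60, size=(2, 1, 128, 128, 128)).astype(np.int16)
regions = (tuple(range(1, 60)), tuple(range(36, 60))) + tuple((i,) for i in range(36, 60))

transforms = [
    ConvertSegmentationToRegionsTransform(regions),
    DownsampleSegForDSTransform2(ds_scales=(1, 0.5, 0.25)),
]

data_dict = {'seg': seg}
for t in transforms:
    start = time()
    data_dict = t(**data_dict)
    print(f'{t.__class__.__name__}: {time() - start:.3f} s')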
After a long time of debugging I've found two major culprits and am now back down to ~190 seconds for the task that previously took ~400 seconds. The training run shows results identical to the 400-second run, so this doesn't seem to break anything. I can make a pull request if useful.
There are still a few optimizations that could be made, but those would be trade-offs between CPU cycles and RAM usage.
The culprits were:
What I've done:
The actual fix:
Updated deep_supervision_donwsampling.py:
from typing import Tuple, Union, List

import numpy as np
from batchgenerators.augmentations.utils import resize_segmentation
from batchgenerators.transforms.abstract_transforms import AbstractTransform


class DownsampleSegForDSTransform2(AbstractTransform):
    '''
    data_dict['output_key'] will be a list of segmentations scaled according to ds_scales
    '''
    def __init__(self, ds_scales: Union[List, Tuple],
                 order: int = 0, input_key: str = "seg",
                 output_key: str = "seg", axes: Tuple[int] = None):
        """
        Downscales data_dict[input_key] according to ds_scales. Each entry in ds_scales specifies one deep supervision
        output and its resolution relative to the original data, for example 0.25 specifies 1/4 of the original shape.
        ds_scales can also be a tuple of tuples, for example ((1, 1, 1), (0.5, 0.5, 0.5)) to specify the downsampling
        for each axis independently
        """
        self.axes = axes
        self.output_key = output_key
        self.input_key = input_key
        self.order = order
        self.ds_scales = ds_scales

    def encode_channels(self, data):
        # pack the binary region channels of one sample into a single bit mask per voxel
        encoded = np.zeros(data.shape[1:], dtype=np.int64)
        for c in range(data.shape[0]):
            encoded |= data[c].astype(np.int64) << c
        encoded = encoded.astype(np.float32)
        return encoded

    # Unused, but kept for reference to decode channels if required
    def decode_channels(self, data, num_channels):
        data = data.astype(np.int64)
        decoded = np.zeros((num_channels,) + data.shape, dtype=bool)
        for c in range(num_channels):
            decoded[c] = (data >> c) & 1
        return decoded

    def __call__(self, **data_dict):
        if self.axes is None:
            axes = list(range(2, data_dict[self.input_key].ndim))
        else:
            axes = self.axes

        # First determine the number of channels and do the bitwise encoding. This is a lazy approach and assumes
        # that a segmentation can occur in every channel.
        num_channels = data_dict[self.input_key].shape[1]
        # Since we use int64 (there is no native uint64 support in default PyTorch) and I don't want to use the
        # negative range of int64, we use at most 31 channels. This can of course be optimized for memory use at a
        # small cost in CPU cycles.
        if num_channels > 1 and num_channels < 32:
            encoded_batch = []
            for b in range(data_dict[self.input_key].shape[0]):
                encoded_batch.append(self.encode_channels(data_dict[self.input_key][b]))

        output = []
        for s in self.ds_scales:
            if not isinstance(s, (tuple, list)):
                s = [s] * len(axes)
            else:
                assert len(s) == len(axes), \
                    f'If ds_scales is a tuple for each resolution (one downsampling factor for each axis) then the ' \
                    f'number of entries in that tuple (here {len(s)}) must be the same as the number of axes (here {len(axes)}).'

            # Check again whether we have the right number of channels for encoding.
            if num_channels > 1 and num_channels < 32:
                new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                for i, a in enumerate(axes):
                    new_shape[a] *= s[i]
                new_shape[1] = 1  # the encoded segmentation has a single channel
                new_shape = np.round(new_shape).astype(int)
                out_seg = np.zeros(new_shape, dtype=np.float32)
                for b in range(data_dict[self.input_key].shape[0]):
                    encoded = encoded_batch[b]
                    if not all([i == 1 for i in s]):
                        # Resize (as usual)
                        data_resized = resize_segmentation(encoded, new_shape[2:], self.order)
                        data_resized = data_resized.astype(np.uint64)
                        out_seg[b] = data_resized.astype(out_seg.dtype)
                    else:
                        out_seg[b] = encoded.astype(out_seg.dtype)
                output.append(out_seg)
            else:  # not encodable
                if all([i == 1 for i in s]):
                    output.append(data_dict[self.input_key])
                else:
                    new_shape = np.array(data_dict[self.input_key].shape).astype(float)
                    for i, a in enumerate(axes):
                        new_shape[a] *= s[i]
                    new_shape = np.round(new_shape).astype(int)
                    out_seg = np.zeros(new_shape, dtype=data_dict[self.input_key].dtype)
                    for b in range(data_dict[self.input_key].shape[0]):
                        for c in range(num_channels):
                            out_seg[b, c] = resize_segmentation(data_dict[self.input_key][b, c], new_shape[2:], self.order)
                    output.append(out_seg)
        data_dict[self.output_key] = output
        return data_dict
First part of nnUNetTrainer.py train_step:
def train_step(self, batch: dict) -> dict:
    data = batch['data']
    target = batch['target']

    data = data.to(self.device, non_blocking=True)
    if isinstance(target, list):
        target = [i.to(self.device, non_blocking=True) for i in target]
    else:
        target = target.to(self.device, non_blocking=True)

    # unpack the bit-encoded deep supervision targets back into one channel per region (on the GPU)
    if self.label_manager.num_segmentation_heads > 1 and self.label_manager.num_segmentation_heads < 32:
        target = [decode_channels_torch(t, self.label_manager.num_segmentation_heads) for t in target]
First part of nnUNetTrainer.py validation_step:
def validation_step(self, batch: dict) -> dict:
    data = batch['data']
    target = batch['target']

    data = data.to(self.device, non_blocking=True)
    if isinstance(target, list):
        target = [i.to(self.device, non_blocking=True) for i in target]
    else:
        target = target.to(self.device, non_blocking=True)

    if self.label_manager.num_segmentation_heads > 1 and self.label_manager.num_segmentation_heads < 32:
        target = [decode_channels_torch(t, self.label_manager.num_segmentation_heads) for t in target]
bitwise_converters.py:
import torch


def encode_channels_torch(data: torch.Tensor):
    # Assuming shape: [b, c, *dims] (dims can be x, y, z, w, ...)
    b, c, *dims = data.shape
    encoded = torch.zeros((b, 1, *dims), dtype=torch.int64, device=data.device)
    data_tmp = data.to(torch.int64)
    for j in range(c):
        encoded[:, 0] |= data_tmp[:, j] << j
    return encoded


def decode_channels_torch(data: torch.Tensor, num_channels: int):
    # Assuming shape: [b, 1, *dims]
    b, _, *dims = data.shape
    decoded = torch.zeros((b, num_channels, *dims), dtype=torch.bool, device=data.device)
    data_tmp = data.to(torch.int64, non_blocking=True)
    for j in range(num_channels):
        decoded[:, j] = (data_tmp[:, 0] >> j) & 1
    decoded = decoded.to(torch.float32, non_blocking=True)
    return decoded
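(A quick round-trip sanity check of the two helpers above; the shapes and channel count are arbitrary.)

import torch

regions = torch.randint(0, 2, (2, 26, 8, 8, 8)).float()
packed = encode_channels_torch(regions)                   # -> shape [2, 1, 8, 8, 8], int64
restored = decode_channels_torch(packed, num_channels=26)
assert torch.equal(restored, regions)                     # bit packing is lossless for 0/1 channels
print(packed.shape, restored.shape)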
Hey, really cool stuff! I like the idea of bitwise encoding, but there is a much easier solution: just make the dtype of the regions bool. This is now the case in the current master (and batchgenerators 2 master) and should alleviate your issues!
Let me start by saying that I really like the new region-based training feature. However, I do notice a significant drop in performance (CPU-side) when enabling it, especially when creating regions with a large number of labels (e.g. > 10). This creates a CPU bottleneck quite quickly. In my case the time per epoch goes from ~120 seconds (~40 regions kept separate) to ~400 seconds (3 regions defined from these labels), resulting in barely using the GPU while the CPU sits at 100%.
I was wondering if there is perhaps an easy (caching) option within the nnU-Net framework to merge these labels based on the dataset.json file. Alternatively I could manually merge the labels with some external code, but since the labels are already processed by the pipeline, there might be a good place to merge them beforehand.
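(For reference, a rough sketch of the kind of external pre-merging mentioned above. It is not part of nnU-Net, assumes SimpleITK is available, and takes the region lists from the dataset.json shown earlier; overlapping regions cannot be represented this way, which is exactly the limitation discussed in this thread.)

import numpy as np
import SimpleITK as sitk

# new label value -> original label values (order matters where regions overlap)
regions = {1: list(range(1, 60)),    # "bones"
           2: list(range(36, 60))}   # "spine", overwrites the overlapping "bones" voxels

def merge_labels(in_file: str, out_file: str) -> None:
    img = sitk.ReadImage(in_file)
    seg = sitk.GetArrayFromImage(img)
    merged = np.zeros_like(seg)
    for new_label, old_labels in regions.items():
        merged[np.isin(seg, old_labels)] = new_label
    out = sitk.GetImageFromArray(merged)
    out.CopyInformation(img)          # keep spacing, origin and direction
    sitk.WriteImage(out, out_file)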