I would like to extract features from the final layer of the encoder and visualize them. To do this, I wrote code that extracts the features while running main() of predict_from_raw_data.py.
I tested it with 12 images. From the 10th image onward, the features were added twice per image.
Changing both 'num_processes_preprocessing=2, num_processes_segmentation_export=2' to 1 did not change the result.
Do you know what could be causing this?
Thanks in advance.
Output for the eighth and subsequent images:
Predicting 00014:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
0%| | 0/1 [00:00<?, ?it/s]in _My_internal_maybe_mirror_and_predict
all_features torch.Size([8, 512, 7, 8])
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 79.07it/s]
sending off prediction to background worker for resampling and export
done with 00014
Predicting 00015:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
0%| | 0/1 [00:00<?, ?it/s]in _My_internal_maybe_mirror_and_predict
all_features torch.Size([9, 512, 7, 8])
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 147.96it/s]
sending off prediction to background worker for resampling and export
done with 00015
Predicting 00025:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
0%| | 0/2 [00:00<?, ?it/s]in _My_internal_maybe_mirror_and_predict
all_features torch.Size([10, 512, 7, 8])
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([11, 512, 7, 8])
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 142.38it/s]
sending off prediction to background worker for resampling and export
done with 00025
Predicting 00026:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
0%| | 0/2 [00:00<?, ?it/s]in _My_internal_maybe_mirror_and_predict
all_features torch.Size([12, 512, 7, 8])
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([13, 512, 7, 8])
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 125.74it/s]
sending off prediction to background worker for resampling and export
done with 00026
Predicting 00027:
perform_everything_on_device: True
in My_predict_logits_from_preprocessed_data
in My_predict_sliding_window_return_logits
in _My_internal_predict_sliding_window_return_logits
0%| | 0/2 [00:00<?, ?it/s]in _My_internal_maybe_mirror_and_predict
all_features torch.Size([14, 512, 7, 8])
in _My_internal_maybe_mirror_and_predict
all_features torch.Size([15, 512, 7, 8])
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 110.17it/s]
sending off prediction to background worker for resampling and export
done with 00027
My code:
class nnUNetPredictor(object):
~~~~~~default codes~~~~~
########################################### My func start(Almost same as original code) ###########################################
def _My_internal_maybe_mirror_and_predict(self, x: torch.Tensor, all_features: List = None) -> torch.Tensor:
    """
    Run the network on ONE sliding-window tile (optionally with test-time mirroring) and, if
    ``all_features`` is a list, append the output of
    ``self.network.encoder.stages[6][0].convs[1]`` for this tile to it.

    Why the feature count can grow by more than one per image: this method is called once per
    sliding-window tile (see the ``for sl in tqdm(slicers)`` loop of the caller), and the list is
    shared across all images and never cleared. An image whose padded size exceeds the patch size
    is covered by several tiles, so it contributes several entries (tqdm shows ``0/2`` for two
    tiles). That is expected behaviour, not a hook firing twice.

    Only the un-mirrored forward pass is captured; the hook is removed before any mirrored pass.

    :param x: input tile, passed unchanged to ``self.network`` (the mirror-axes check below
        expects 4D input for 2D and 5D for 3D data).
    :param all_features: list that receives one CPU tensor per call, or None to disable capturing.
    :return: network prediction for the tile, averaged over all mirrored passes if mirroring is on.
    """
    mirror_axes = self.allowed_mirroring_axes if self.use_mirroring else None
    print("in _My_internal_maybe_mirror_and_predict")

    if all_features is None:
        prediction = self.network(x)
    else:
        def hook_fn(module, inputs, output):
            # detach() keeps the stored tensor graph-free even if called outside torch.no_grad()
            all_features.append(output.detach().cpu())

        # The hooked layer is assumed to be the last encoder conv - verify with print(self.network).
        handle = self.network.encoder.stages[6][0].convs[1].register_forward_hook(hook_fn)
        try:
            prediction = self.network(x)
        finally:
            # Always unregister: if the forward pass raises (e.g. CUDA OOM, which the caller
            # catches and retries on CPU) a leaked hook would make every later pass append twice.
            handle.remove()

        # Print the shape torch.cat(all_features) would have, without actually concatenating the
        # whole (ever-growing) list on every tile, which was O(n^2) in time and memory.
        if all_features:
            n_total = sum(f.shape[0] for f in all_features)
            print("all_features", torch.Size([n_total, *all_features[-1].shape[1:]]))

    if mirror_axes is not None:
        print("mirror_axes is not None")
        # check for invalid numbers in mirror_axes
        # x should be 5d for 3d images and 4d for 2d. so the max value of mirror_axes cannot exceed len(x.shape) - 3
        assert max(mirror_axes) <= x.ndim - 3, 'mirror_axes does not match the dimension of the input!'

        mirror_axes = [m + 2 for m in mirror_axes]
        axes_combinations = [
            c for i in range(len(mirror_axes)) for c in itertools.combinations(mirror_axes, i + 1)
        ]
        for axes in axes_combinations:
            prediction += torch.flip(self.network(torch.flip(x, axes)), axes)
        prediction /= (len(axes_combinations) + 1)
    return prediction
def _My_internal_predict_sliding_window_return_logits(self,
                                                      data: torch.Tensor,
                                                      slicers,
                                                      do_on_device: bool = True,
                                                      all_features: List = None
                                                      ):
    """
    Sliding-window inference over one padded image.

    Every tile selected by ``slicers`` is run through ``_My_internal_maybe_mirror_and_predict``.
    The tile predictions are weighted by a Gaussian when ``self.use_gaussian`` is set, summed
    into a half-precision logits buffer and finally divided by the accumulated weights.

    :param data: padded image; the result buffers use ``data.shape[1:]`` as spatial shape.
    :param slicers: index tuples, one per tile, as produced by ``_internal_get_sliding_window_slicers``.
    :param do_on_device: keep the result buffers on ``self.device`` (True) or on the CPU (False).
    :param all_features: forwarded unchanged for every tile, so entries are added in slicer order.
    :return: the normalized logits tensor.
    :raises RuntimeError: if the normalized logits contain inf; any exception frees the buffers
        and device cache before being re-raised.
    """
    logits = weight_sum = tile_pred = gauss = tile = None
    target_device = self.device if do_on_device else torch.device('cpu')
    print("in _My_internal_predict_sliding_window_return_logits")
    try:
        empty_cache(self.device)

        if self.verbose:
            print(f'move image to device {target_device}')
        data = data.to(target_device)

        if self.verbose:
            print(f'preallocating results arrays on device {target_device}')
        spatial_shape = data.shape[1:]
        logits = torch.zeros((self.label_manager.num_segmentation_heads, *spatial_shape),
                             dtype=torch.half,
                             device=target_device)
        weight_sum = torch.zeros(spatial_shape, dtype=torch.half, device=target_device)

        gauss = compute_gaussian(tuple(self.configuration_manager.patch_size), sigma_scale=1. / 8,
                                 value_scaling_factor=10,
                                 device=target_device) if self.use_gaussian else 1

        if self.verbose and not self.allow_tqdm:
            print(f'running prediction: {len(slicers)} steps')

        for sl in tqdm(slicers, disable=not self.allow_tqdm):
            tile = data[sl][None].to(self.device)
            tile_pred = self._My_internal_maybe_mirror_and_predict(tile, all_features)[0].to(target_device)
            if self.use_gaussian:
                tile_pred *= gauss
            logits[sl] += tile_pred
            weight_sum[sl[1:]] += gauss

        logits /= weight_sum

        # half precision can overflow; refuse to return infs
        if torch.any(torch.isinf(logits)):
            raise RuntimeError('Encountered inf in predicted array. Aborting... If this problem persists, '
                               'reduce value_scaling_factor in compute_gaussian or increase the dtype of '
                               'predicted_logits to fp32')
    except Exception as e:
        # release the large buffers before propagating so a CPU retry has memory to work with
        del logits, weight_sum, tile_pred, gauss, tile
        empty_cache(self.device)
        empty_cache(target_device)
        raise e
    return logits
def My_predict_sliding_window_return_logits(self, input_image: torch.Tensor, all_features: List = None) \
        -> Union[np.ndarray, torch.Tensor]:
    """
    Predict logits for one preprocessed image with a sliding window.

    The image is padded to at least the patch size and tiled. Prediction is first attempted with
    the result arrays on ``self.device``. If that raises RuntimeError (usually OOM), it is retried
    with the result arrays on the CPU.

    :param input_image: 4D tensor (c, x, y, z) - asserted below.
    :param all_features: optional list forwarded to the per-tile predictor, which appends the
        captured encoder features to it. If the on-device attempt fails, whatever it appended is
        removed again before the CPU retry, so tiles are not recorded twice.
    :return: logits with the padding added by pad_nd_image removed.
    """
    print("in My_predict_sliding_window_return_logits")
    with torch.no_grad():
        assert isinstance(input_image, torch.Tensor)
        self.network = self.network.to(self.device)
        self.network.eval()

        empty_cache(self.device)

        # Autocast can be annoying
        # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection)
        # and needs to be disabled.
        # If the device_type is 'mps' then it will complain that mps is not implemented, even if enabled=False
        # is set. Whyyyyyyy. (this is why we don't make use of enabled=False)
        # So autocast will only be active if we have a cuda device.
        with torch.autocast(self.device.type, enabled=True) if self.device.type == 'cuda' else dummy_context():
            assert input_image.ndim == 4, 'input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)'

            if self.verbose:
                print(f'Input shape: {input_image.shape}')
                print("step_size:", self.tile_step_size)
                print("mirror_axes:", self.allowed_mirroring_axes if self.use_mirroring else None)

            # if input_image is smaller than tile_size we need to pad it to tile_size.
            data, slicer_revert_padding = pad_nd_image(input_image, self.configuration_manager.patch_size,
                                                       'constant', {'value': 0}, True,
                                                       None)

            slicers = self._internal_get_sliding_window_slicers(data.shape[1:])

            # Remember how many features existed before this image so a failed on-device
            # attempt can be rolled back before the CPU retry re-captures the same tiles.
            n_features_before = len(all_features) if all_features is not None else 0

            # torch.device never compares equal to a plain str, so compare its .type instead.
            if self.perform_everything_on_device and self.device.type != 'cpu':
                # we need to try except here because we can run OOM in which case we need to fall back to CPU as a results device
                try:
                    predicted_logits = self._My_internal_predict_sliding_window_return_logits(
                        data, slicers, self.perform_everything_on_device, all_features)
                except RuntimeError:
                    print(
                        'Prediction on device was unsuccessful, probably due to a lack of memory. Moving results arrays to CPU')
                    empty_cache(self.device)
                    if all_features is not None:
                        # drop the tiles captured by the failed attempt; the retry captures them again
                        del all_features[n_features_before:]
                    predicted_logits = self._My_internal_predict_sliding_window_return_logits(
                        data, slicers, False, all_features)
            else:
                predicted_logits = self._My_internal_predict_sliding_window_return_logits(
                    data, slicers, self.perform_everything_on_device, all_features)

            empty_cache(self.device)
            # revert padding
            predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
    return predicted_logits
def My_predict_logits_from_preprocessed_data(self, data: torch.Tensor, all_fratures: List = None,
                                             all_features: List = None) -> torch.Tensor:
    """
    IMPORTANT! IF YOU ARE RUNNING THE CASCADE, THE SEGMENTATION FROM THE PREVIOUS STAGE MUST ALREADY BE STACKED ON
    TOP OF THE IMAGE AS ONE-HOT REPRESENTATION! SEE PreprocessAdapter ON HOW THIS SHOULD BE DONE!

    RETURNED LOGITS HAVE THE SHAPE OF THE INPUT. THEY MUST BE CONVERTED BACK TO THE ORIGINAL IMAGE SIZE.
    SEE convert_predicted_logits_to_segmentation_with_correct_shape

    :param data: preprocessed image tensor passed to the sliding-window predictor.
    :param all_fratures: list that collects encoder features (misspelled name kept so existing
        positional and keyword callers keep working). Used when ``all_features`` is not given.
    :param all_features: correctly spelled alias of ``all_fratures``; takes precedence if given.
    :return: logits averaged over all entries of ``self.list_of_parameters``, on the CPU.

    NOTE: the sliding-window prediction (and therefore feature capture) runs once per entry of
    ``self.list_of_parameters`` (presumably one per loaded fold), so k parameter sets add k
    feature entries per tile.
    """
    # Take the feature list from the arguments only - never from a module-level global that
    # happens to share the name.
    if all_features is None:
        all_features = all_fratures
    print("in My_predict_logits_from_preprocessed_data")
    n_threads = torch.get_num_threads()
    torch.set_num_threads(min(default_num_processes, n_threads))
    try:
        prediction = None
        for params in self.list_of_parameters:
            # messing with state dict names...
            if not isinstance(self.network, OptimizedModule):
                self.network.load_state_dict(params)
            else:
                self.network._orig_mod.load_state_dict(params)

            # why not leave prediction on device if perform_everything_on_device? Because this may cause the
            # second iteration to crash due to OOM. Grabbing that with try except cause way more bloated code than
            # this actually saves computation time
            fold_logits = self.My_predict_sliding_window_return_logits(data, all_features).to('cpu')
            if prediction is None:
                prediction = fold_logits
            else:
                prediction += fold_logits

        if len(self.list_of_parameters) > 1:
            prediction /= len(self.list_of_parameters)

        if self.verbose:
            print('Prediction done')
    finally:
        # restore the caller's thread setting even if prediction raised
        torch.set_num_threads(n_threads)
    return prediction
def My_predict_from_data_iterator(self,
                                  data_iterator,
                                  save_probabilities: bool = False,
                                  num_processes_segmentation_export: int = default_num_processes,
                                  all_features: List = None):
    """
    Predict every item yielded by ``data_iterator`` and hand the logits to a spawn-based worker
    pool for resampling (and export, when an output file is given).

    Each item must be a dict with 'data', 'ofile' and 'data_properties' keys. If 'data' is a str
    it is treated as a path to a .npy file, which is loaded and then deleted. If 'ofile' is None
    the converted result is returned instead of written to a file.

    :param all_features: forwarded to My_predict_logits_from_preprocessed_data for every item;
        entries from all images end up in the same list without per-image separators.
    :return: list holding the first element of each background job's result, in submission order.
    """
    print("in My_predict_from_data_iterator")
    with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
        workers = list(export_pool._pool)
        pending = []
        for item in data_iterator:
            data = item['data']
            if isinstance(data, str):
                npy_path = data
                data = torch.from_numpy(np.load(npy_path))
                os.remove(npy_path)

            ofile = item['ofile']
            if ofile is None:
                print(f'\nPredicting image of shape {data.shape}:')
            else:
                print(f'\nPredicting {os.path.basename(ofile)}:')

            print(f'perform_everything_on_device: {self.perform_everything_on_device}')

            properties = item['data_properties']

            # Throttle: wait while the export workers already have too much queued so that a fast
            # GPU does not flood the disk with pending results.
            while check_workers_alive_and_busy(export_pool, workers, pending, allowed_num_queued=2):
                sleep(0.1)

            prediction = self.My_predict_logits_from_preprocessed_data(data, all_features).cpu()

            if ofile is None:
                print('sending off prediction to background worker for resampling')
                job = export_pool.starmap_async(
                    convert_predicted_logits_to_segmentation_with_correct_shape,
                    ((prediction, self.plans_manager, self.configuration_manager, self.label_manager,
                      properties, save_probabilities),)
                )
            else:
                print('sending off prediction to background worker for resampling and export')
                job = export_pool.starmap_async(
                    export_prediction_from_logits,
                    ((prediction, properties, self.configuration_manager, self.plans_manager,
                      self.dataset_json, ofile, save_probabilities),)
                )
            pending.append(job)

            if ofile is None:
                print(f'\nDone with image of shape {data.shape}:')
            else:
                print(f'done with {os.path.basename(ofile)}')

        # collect results while the pool is still alive
        ret = [res.get()[0] for res in pending]

    if isinstance(data_iterator, MultiThreadedAugmenter):
        data_iterator._finish()

    # clear lru cache
    compute_gaussian.cache_clear()
    # clear device cache
    empty_cache(self.device)
    return ret
def My_predict_from_files(self,
                          list_of_lists_or_source_folder: Union[str, List[List[str]]],
                          output_folder_or_list_of_truncated_output_files: Union[str, None, List[str]],
                          save_probabilities: bool = False,
                          overwrite: bool = True,
                          num_processes_preprocessing: int = default_num_processes,
                          num_processes_segmentation_export: int = default_num_processes,
                          folder_with_segs_from_prev_stage: str = None,
                          num_parts: int = 1,
                          part_id: int = 0,
                          all_features: List = None):
    """
    This is nnU-Net's default function for making predictions. It works best for batch predictions
    (predicting many images at once).

    Variant of predict_from_files that additionally passes ``all_features`` through to
    My_predict_from_data_iterator, so encoder features can be collected during prediction.

    :param all_features: list (or None) forwarded unchanged to My_predict_from_data_iterator.
        It is not written to predict_from_raw_data_args.json: the saved arguments are the
        parameters of self.predict_from_files (presumably the unmodified upstream method, which
        has no such parameter - verify).
    :return: the result of My_predict_from_data_iterator, or None when there is nothing to predict.
    """
    print("in My_predict_from_files")
    if isinstance(output_folder_or_list_of_truncated_output_files, str):
        output_folder = output_folder_or_list_of_truncated_output_files
    elif isinstance(output_folder_or_list_of_truncated_output_files, list):
        output_folder = os.path.dirname(output_folder_or_list_of_truncated_output_files[0])
    else:
        output_folder = None

    ########################
    # let's store the input arguments so that its clear what was used to generate the prediction
    if output_folder is not None:
        my_init_kwargs = {}
        # locals() is read directly in this function's scope (not inside a comprehension) so it
        # sees this method's parameters; every parameter name of predict_from_files must also
        # exist here, otherwise this raises KeyError.
        for k in inspect.signature(self.predict_from_files).parameters.keys():
            my_init_kwargs[k] = locals()[k]
        my_init_kwargs = deepcopy(
            my_init_kwargs)  # copy so the JSON fix-up below cannot modify the caller's arguments in place
        recursive_fix_for_json_export(my_init_kwargs)
        maybe_mkdir_p(output_folder)
        save_json(my_init_kwargs, join(output_folder, 'predict_from_raw_data_args.json'))

        # we need these two if we want to do things with the predictions like for example apply postprocessing
        save_json(self.dataset_json, join(output_folder, 'dataset.json'), sort_keys=False)
        save_json(self.plans_manager.plans, join(output_folder, 'plans.json'), sort_keys=False)
    #######################

    # check if we need a prediction from the previous stage
    if self.configuration_manager.previous_stage_name is not None:
        assert folder_with_segs_from_prev_stage is not None, \
            f'The requested configuration is a cascaded network. It requires the segmentations of the previous ' \
            f'stage ({self.configuration_manager.previous_stage_name}) as input. Please provide the folder where' \
            f' they are located via folder_with_segs_from_prev_stage'

    # sort out input and output filenames
    list_of_lists_or_source_folder, output_filename_truncated, seg_from_prev_stage_files = \
        self._manage_input_and_output_lists(list_of_lists_or_source_folder,
                                            output_folder_or_list_of_truncated_output_files,
                                            folder_with_segs_from_prev_stage, overwrite, part_id, num_parts,
                                            save_probabilities)
    # nothing left to predict (returns None)
    if len(list_of_lists_or_source_folder) == 0:
        return

    data_iterator = self._internal_get_data_iterator_from_lists_of_filenames(list_of_lists_or_source_folder,
                                                                             seg_from_prev_stage_files,
                                                                             output_filename_truncated,
                                                                             num_processes_preprocessing)

    return self.My_predict_from_data_iterator(data_iterator, save_probabilities, num_processes_segmentation_export, all_features)
########################################### My func end ###########################################
if __name__ == '__main__':
    # Predict a folder of images and collect encoder features along the way.
    from nnunetv2.paths import nnUNet_results, nnUNet_raw

    # Receives one CPU tensor per sliding-window tile. It is shared by all images and never
    # cleared, so an image covered by several tiles adds several entries.
    all_features = []

    predictor = nnUNetPredictor(tile_step_size=0.5,
                                use_gaussian=True,
                                use_mirroring=False,
                                perform_everything_on_device=True,
                                device=torch.device('cuda', 0),
                                verbose=False,
                                verbose_preprocessing=False,
                                allow_tqdm=True)
    predictor.initialize_from_trained_model_folder(
        join(nnUNet_results, 'Dataset301_pretrain_removed/nnUNetTrainer__nnUNetPlans_forPretrain_removed_RFA4__2d'),
        use_folds=(0,),
        checkpoint_name='checkpoint_final.pth',
    )
    predictor.My_predict_from_files(
        join(nnUNet_raw, 'Dataset301_pretrain_removed/imagesTr_mini'),
        join(nnUNet_raw, 'Dataset301_pretrain_removed/imagesTr_prediction'),
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=2,
        num_processes_segmentation_export=2,
        folder_with_segs_from_prev_stage=None,
        num_parts=1,
        part_id=0,
        all_features=all_features,
    )
I would like to extract features from the final layer of the encoder and visualize them. To do this, I wrote code that extracts the features while running main() of predict_from_raw_data.py. I tested it with 12 images. From the 10th image onward, the features were added twice per image. Changing both 'num_processes_preprocessing=2, num_processes_segmentation_export=2' to 1 did not change the result.
Do you know what could be causing this?
Thanks in advance.
Output for the eighth and subsequent images:
My code: