I propose adding the following (or similar) code to the on_before_batch_transfer hook to get rid of the patch size argument:
def on_before_batch_transfer(self, batch, dataloader_idx):
    if self.trainer.predicting:
        if not self.disable_inference_preprocessing:
            batch["data"], batch["data_properties"] = self.preprocessor.preprocess_case_for_inference(
                images=batch["data_paths"],
                patch_size=self.patch_size,
                ext=batch["extension"],
                sliding_window_prediction=self.sliding_window_prediction,
            )
        else:
            batch["data"], batch["data_properties"] = ensure_batch_fits_patch_size(batch, patch_size=self.patch_size)
    ...
import numpy as np
import torch
import torch.nn.functional as F


def ensure_batch_fits_patch_size(batch, patch_size):
    """
    Pads the spatial dimensions of the input tensor so that each one is at least as large
    as the corresponding patch dimension. If all spatial dimensions are already greater
    than or equal to the patch size, the input tensor is returned unchanged.

    Parameters:
    - batch: dict
        A dict with keys {"data": data, "data_properties": data_properties, "case_id": case_id},
        where data is a Tensor of shape (B, C, *spatial_dims).
    - patch_size: tuple of ints
        The minimum desired size for each spatial dimension.

    Returns:
    - padded_input: torch.Tensor
        The input tensor padded to the desired spatial dimensions.
    - image_properties: dict
        The data properties, updated with the padded shape and the applied padding.
    """
    image = batch["data"]
    image_properties = batch["data_properties"]

    spatial_dims = image.dim() - 2  # Subtract batch and channel dimensions
    if spatial_dims != len(patch_size):
        raise ValueError("Input spatial dimensions and patch size dimensions do not match.")

    current_sizes_tensor = torch.tensor(image.shape[2:])  # Spatial dimensions
    patch_size_tensor = torch.tensor(patch_size)

    # Nothing to do if every spatial dimension already fits the patch size.
    if torch.all(current_sizes_tensor >= patch_size_tensor):
        return image, image_properties

    pad_sizes = torch.clamp(patch_size_tensor - current_sizes_tensor, min=0)
    pad_left = pad_sizes // 2
    pad_right = pad_sizes - pad_left

    # F.pad expects the padding in reverse order: (left, right) for the last spatial
    # dimension first, then the second-to-last, and so on. The entries from .tolist()
    # are already plain Python ints.
    padding_reversed = []
    for left, right in zip(reversed(pad_left.tolist()), reversed(pad_right.tolist())):
        padding_reversed.extend([left, right])
    padded_input = F.pad(image, padding_reversed)
image_properties["padded_shape"] = np.array(image.shape)
image_properties["padding"] = list(reversed(padding_reversed))
return padded_input, image_properties
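For context, here is a minimal usage sketch of the helper. The shapes, the empty properties dict, and the case id are made up purely for illustration; the real batch dict would come from the predict dataloader.

# Minimal sketch with made-up shapes and a placeholder properties dict.
dummy_batch = {
    "data": torch.zeros(1, 1, 20, 20, 20),  # (B, C, *spatial_dims)
    "data_properties": {},                   # placeholder properties dict
    "case_id": "example_case",               # hypothetical case id
}
padded, props = ensure_batch_fits_patch_size(dummy_batch, patch_size=(32, 32, 32))
print(padded.shape)      # torch.Size([1, 1, 32, 32, 32])
print(props["padding"])  # [6, 6, 6, 6, 6, 6]

With this in place, the predict dataloader no longer needs to know the patch size; it is taken from self.patch_size at batch-transfer time.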