Sllambias / yucca

Apache License 2.0

implementation of test preprocessing into modules and pipeline #197

Closed Sllambias closed 4 weeks ago

asbjrnmunk commented 4 weeks ago

I propose adding the following (or similar) code to the on_before_batch_transfer hook to get rid of the patch size argument:

def on_before_batch_transfer(self, batch, dataloader_idx):
    if self.trainer.predicting:
        if not self.disable_inference_preprocessing:
            # Run inference preprocessing on the raw image paths
            batch["data"], batch["data_properties"] = self.preprocessor.preprocess_case_for_inference(
                images=batch["data_paths"],
                patch_size=self.patch_size,
                ext=batch["extension"],
                sliding_window_prediction=self.sliding_window_prediction,
            )
        else:
            # Preprocessing disabled: only make sure the data is at least patch-sized
            batch["data"], batch["data_properties"] = ensure_batch_fits_patch_size(batch, patch_size=self.patch_size)
...
import numpy as np
import torch
import torch.nn.functional as F


def ensure_batch_fits_patch_size(batch, patch_size):
    """
    Pads the spatial dimensions of the input tensor so that they are at least the size of the patch dimensions.
    If all spatial dimensions are already larger than or equal to the patch size, the input tensor is returned unchanged.

    Parameters:
    - batch: dict
        a dict with keys {"data": data, "data_properties": data_properties, "case_id": case_id},
        where data is a Tensor of shape (B, C, *spatial_dims)

    - patch_size: tuple of ints
        The minimum desired size for each spatial dimension.

    Returns:
    - padded_input: torch.Tensor
        The input tensor, padded so every spatial dimension is at least the corresponding patch dimension.
    - image_properties: dict
        The data_properties dict, updated with the padded shape and the applied padding.
    """
    image = batch["data"]
    image_properties = batch["data_properties"]

    spatial_dims = image.dim() - 2  # Subtract batch and channel dimensions

    if spatial_dims != len(patch_size):
        raise ValueError("Input spatial dimensions and patch size dimensions do not match.")

    current_sizes = image.shape[2:]  # Spatial dimensions

    current_sizes_tensor = torch.tensor(current_sizes)
    patch_size_tensor = torch.tensor(patch_size)

    # If every spatial dimension already meets or exceeds the patch size, return unchanged
    if torch.all(current_sizes_tensor >= patch_size_tensor).item():
        return image, image_properties

    pad_sizes = torch.clamp(patch_size_tensor - current_sizes_tensor, min=0)
    pad_left = pad_sizes // 2
    pad_right = pad_sizes - pad_left

    # Construct padding tuple in reverse order for F.pad
    padding_reversed = []
    for left, right in zip(reversed(pad_left.tolist()), reversed(pad_right.tolist())):
        padding_reversed.extend([left, right])

    padded_input = F.pad(image, padding_reversed)

    image_properties["padded_shape"] = np.array(padded_input.shape)
    image_properties["padding"] = list(reversed(padding_reversed))

    return padded_input, image_properties
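
For reference, a minimal sketch of how ensure_batch_fits_patch_size could be exercised on its own, outside the Lightning hook. The tensor shape, patch size, case_id value and empty data_properties dict below are illustrative assumptions, not part of the pipeline; only the short axis ends up padded:

import torch

# Hypothetical batch: one single-channel 3D volume of shape (1, 1, 16, 40, 40),
# to be padded up to a (32, 32, 32) patch.
batch = {
    "data": torch.zeros(1, 1, 16, 40, 40),
    "data_properties": {},
    "case_id": "case_000",  # illustrative placeholder
}

padded, props = ensure_batch_fits_patch_size(batch, patch_size=(32, 32, 32))
print(padded.shape)      # torch.Size([1, 1, 32, 40, 40]) -- only the depth axis was below the patch size
print(props["padding"])  # the applied padding, recorded so it can be undone after prediction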