Open · aeisenbarth opened 10 months ago
Hi @aeisenbarth, thanks for the feature request. The functionality that you described is already implemented with the `rasterize()` function, with a few comments to mention on this:

- we should make the `rasterize()` function more discoverable, both in the docstring and in the notebooks
- `rasterize()` doesn't allow specifying the fill value (currently 0), but this is a minor tweak
- `rasterize()` also requires the user to specify the target resolution; we could add a default value for `rasterize()` which keeps the "native" resolution of the input image
- `Translation` should not be a problem because it would be added to both labels (or would it?)

Reading your code, I can confirm that, except for the minor tweaks linked to my comments above, the `rasterize()` function should cover your cases. Nevertheless, your tests are more structured and deal better with the various transformations. Could you please check whether the `rasterize()` function fits your requirements and makes your tests pass? In that case it would be fantastic if you could make a small PR to add your tests. Thanks!
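For reference, a minimal sketch of the suggested approach, assuming the current `rasterize()` signature (exactly one of the `target_*` arguments must be given, which is where the missing "native resolution" default would come in; `labels1` and the bounding-box values are illustrative, not from the issue):

```python
from spatialdata import rasterize

# A minimal sketch, assuming the current rasterize() signature; labels1 and
# the bounding-box values are illustrative, not from the issue.
labels1_cropped = rasterize(
    labels1,
    axes=("y", "x"),
    min_coordinate=[0, 0],
    max_coordinate=[100, 100],
    target_coordinate_system="global",
    target_unit_to_pixels=1.0,  # 1 coordinate unit -> 1 pixel, approximating native resolution
)
```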
### Use case

As a user, I want to do an image operation (add/subtract, intersect labels, etc.) which requires two images with matching shapes, but the images stem from SpatialData elements in different coordinate systems and at different scales.
### Example
Using labels for demonstration since images have strong interpolation effects.
`labels1_cropped` is actually: *(figure from the original issue not preserved)*

That means I cannot immediately do e.g. `labels1_cropped + labels2_cropped`.
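A minimal sketch of the resulting failure, with made-up arrays standing in for the two cropped labels (real elements are dask-backed `DataArray`s, but the shape mismatch is the same):

```python
import numpy as np

# Hypothetical stand-ins for the two cropped labels: they cover the same
# spatial region but were rasterized at different scales, so their pixel
# shapes differ.
labels1_cropped = np.ones((3, 3), dtype=int)
labels2_cropped = np.ones((6, 6), dtype=int)

labels1_cropped + labels2_cropped
# ValueError: operands could not be broadcast together with shapes (3,3) (6,6)
```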
### Feature request
The current design focuses on "query" as "filtering". Maybe this feature request fits better into the task of transforming elements (transform to a coordinate system plus a bounding box). Also, a user may prefer to apply the transformation and cropping/expansion to a single, already selected element, not to all elements at the SpatialData level.
There are two missing pieces:

- `transform` prefers to ignore the translation component, so that the transformed image still has a `Translation` instead of an `Identity`.

### Requirements
### Code
I have tried several approaches:
```python
from typing import Optional, Union

import numpy as np
import xarray as xr
from dask_image.ndinterp import affine_transform
from datatree import DataTree
from multiscale_spatial_image import MultiscaleSpatialImage
from skimage.transform import AffineTransform
from spatial_image import SpatialImage
from spatialdata._types import ArrayLike
from spatialdata._utils import iterate_pyramid_levels
from spatialdata.models import (
    Image2DModel,
    Image3DModel,
    Labels2DModel,
    Labels3DModel,
    get_axes_names,
    get_model,
)
from spatialdata.models._utils import DEFAULT_COORDINATE_SYSTEM
from spatialdata.transformations import (
    Affine,
    Scale,
    Sequence as SequenceTransform,
    Translation,
    get_transformation,
)

C = "c"
YX = ("y", "x")
CYX = ("c", "y", "x")


def transform_image_to_bounding_box(
    image: Union[SpatialImage, MultiscaleSpatialImage],
    min_coordinate: ArrayLike,
    max_coordinate: ArrayLike,
    coordinate_system: Optional[str] = None,
    **kwargs,
) -> SpatialImage:
    """
    Raster-transform a SpatialImage to a defined bounding box in another coordinate system.

    The input image's transformation and location in space are taken into account.
    The output shape is guaranteed. If the image is smaller than or not fully overlapping
    the target bounding box, missing parts are filled up with zeroes.

    Args:
        image: A SpatialImage
        min_coordinate: The minimum coordinates of the bounding box.
        max_coordinate: The maximum coordinates of the bounding box.
        coordinate_system: Name of the coordinate system to which the bounding box refers.
            Defaults to the global coordinate system.
        kwargs: Keyword arguments for Dask/scipy.ndimage affine_transform

    Returns:
        A transformed SpatialImage with the requested shape and transformation
    """
    image = ensure_singlescale_spatial_image(image)
    # The output image's origin is the bounding-box minimum, kept as a Translation
    output_transformation = Translation(list(min_coordinate), YX)
    output_shape = tuple(np.ceil(np.asarray(max_coordinate) - min_coordinate).astype(int))
    if coordinate_system is None:
        coordinate_system = DEFAULT_COORDINATE_SYSTEM
    dims = get_axes_names(image)
    schema = get_model(image)
    # Labels need to be preserved after resizing of the image
    if schema in (Labels2DModel, Labels3DModel):
        kwargs = {"prefilter": False, "order": 0, **kwargs}
    elif schema in (Image2DModel, Image3DModel):
        kwargs = {**kwargs}
    else:
        raise ValueError(f"Unsupported schema {schema}")
    # Prepend non-spatial dimensions (e.g. channels), which are kept unchanged
    if len(output_shape) < image.ndim:
        output_shape = image.shape[: -len(output_shape)] + output_shape
    # Compute transformation from image to target coordinate system
    image_transform = get_transformation(image, to_coordinate_system=coordinate_system)
    transformation = SequenceTransform([image_transform, output_transformation.inverse()])
    transformation_matrix = transformation.inverse().to_affine_matrix(
        input_axes=dims, output_axes=dims
    )
    image_array = image.transpose(*dims).data
    # Transform the raster data
    image_array_transformed = affine_transform(
        image_array, matrix=transformation_matrix, output_shape=output_shape, **kwargs
    )
    # Construct a new image instance
    c_coords = image.indexes[C].values if C in image.indexes else None
    image_transforms = get_transformation(image, get_all=True)
    transformations = {
        k: SequenceTransform([transformation.inverse(), v]) for k, v in image_transforms.items()
    }
    return schema.parse(
        image_array_transformed, dims=dims, c_coords=c_coords, transformations=transformations
    )


def ensure_singlescale_spatial_image(
    image: Union[SpatialImage, xr.DataArray, MultiscaleSpatialImage, DataTree]
) -> "SpatialImage":
    """
    Convert a MultiscaleSpatialImage to SpatialImage
    """
    if isinstance(image, (MultiscaleSpatialImage, DataTree)):
        return SpatialImage(next(iterate_pyramid_levels(image)))
    return image
```
### Tests

```python
import pytest


@pytest.fixture
def spatialdata_image(request: "_pytest.fixtures.SubRequest") -> SpatialImage:
    kwargs = request.param
    w = kwargs.get("width", 10)
    h = kwargs.get("height", 10)
    image_translation = list(kwargs.get("translation", [0, 0]))
    image_rotation = kwargs.get("rotation", 0)
    image_scale = list(kwargs.get("scale", [1, 1]))
    rotation = Affine(
        AffineTransform(rotation=image_rotation).params, input_axes=YX, output_axes=YX
    )
    return Image2DModel.parse(
        data=np.arange(1, w * h + 1).reshape((1, h, w)),
        dims=CYX,
        transformations={
            DEFAULT_COORDINATE_SYSTEM: SequenceTransform(
                [Scale(image_scale, axes=YX), rotation, Translation(image_translation, axes=YX)]
            ).to_affine(YX, YX)
        },
    )


@pytest.mark.parametrize(
    ("spatialdata_image", "bounding_box", "expected"),
    [
        # No transformation, bounding box matches image
        (
            dict(width=3, height=3, translation=[0, 0]),
            ((0, 0), (3, 3)),
            np.array([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]),
        ),
        # Cropping to smaller bounding box
        (
            dict(width=3, height=3, translation=[0, 0]),
            ((0, 1), (3, 3)),
            np.array([[[2, 3], [5, 6], [8, 9]]]),
        ),
        # Cropping to larger bounding box
        (
            dict(width=3, height=3, translation=[0, 0]),
            ((0, -1), (3, 4)),
            np.array([[[0, 1, 2, 3, 0], [0, 4, 5, 6, 0], [0, 7, 8, 9, 0]]]),
        ),
        # Image with transformation, bounding box with offset
        (
            dict(width=3, height=3, translation=[1, 0]),
            ((0, 1), (3, 4)),
            np.array([[[0, 0, 0], [2, 3, 0], [5, 6, 0]]]),
        ),
        # Scaled image
        (
            dict(width=3, height=3, scale=[2, 2], translation=[0, 0]),
            ((0, 0), (3, 3)),
            np.array([[[1, 2, 2], [4, 5, 5], [4, 5, 5]]]),
        ),
        # Scaled image to bounding box with offset
        (
            dict(width=3, height=3, scale=[2, 2]),
            ((0, 1), (3, 4)),
            np.array([[[2, 2, 3], [5, 5, 6], [5, 5, 6]]]),
        ),
        # Rotated image
        (
            dict(width=4, height=4, rotation=np.pi / 2, translation=[-0.5, -0.5]),
            ((-3.5, 0.5), (0.5, 3.5)),
            np.array([[[8, 12, 16], [7, 11, 15], [6, 10, 14], [5, 9, 13]]]),
        ),
    ],
    indirect=["spatialdata_image"],
)
def test_transform_image_to_bounding_box(spatialdata_image, bounding_box, expected):
    actual = transform_image_to_bounding_box(
        image=spatialdata_image,
        min_coordinate=bounding_box[0],
        max_coordinate=bounding_box[1],
        prefilter=False,
        order=0,
    )
    assert actual.shape[-2:] == tuple(np.asarray(bounding_box[1]) - bounding_box[0])
    np.testing.assert_allclose(
        get_transformation(actual).to_affine_matrix(YX, YX),
        Translation(list(bounding_box[0]), YX).to_affine_matrix(YX, YX),
    )
    np.testing.assert_allclose(actual.data.compute(), expected)
```