albumentations-team / albumentations

Fast and flexible image augmentation library. Paper about the library: https://www.mdpi.com/2078-2489/11/2/125
https://albumentations.ai
MIT License

Grid-based Elastic Transformation #1865

Closed: 4pygmalion closed this issue 1 month ago

4pygmalion commented 1 month ago

Feature description

I would like to propose adding a grid-based Elastic Transformation to the Albumentations library. Instead of applying a single elastic transformation to the entire image, it overlays a grid on the image, randomly displaces the grid's interior vertices and warps each cell accordingly. This produces more localized distortions, which can make augmented images more realistic.
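To make the grid-by-grid idea concrete, here is a minimal NumPy sketch. The helper `displace_grid_vertices` is hypothetical and not part of Albumentations or Augmentor. It builds the lattice of cell corners, leaves the border vertices fixed and jitters each interior vertex by up to `magnitude` pixels. Because every interior vertex is a corner of up to four cells, moving it once bends all of those cells together, so the distortion is local but stays continuous across cell boundaries.

```python
import numpy as np


def displace_grid_vertices(n_cols, n_rows, cell_w, cell_h, magnitude, rng):
    """Hypothetical helper: jitter the interior vertices of a regular grid.

    Returns (grid, displaced), both of shape (n_rows + 1, n_cols + 1, 2),
    where [..., 0] is x and [..., 1] is y.
    """
    xs = np.arange(n_cols + 1) * cell_w
    ys = np.arange(n_rows + 1) * cell_h
    grid = np.stack(np.meshgrid(xs, ys), axis=-1)

    # One random (dx, dy) per interior vertex, drawn from [-magnitude, magnitude].
    offsets = rng.integers(-magnitude, magnitude + 1, size=(n_rows - 1, n_cols - 1, 2))

    displaced = grid.copy()
    displaced[1:-1, 1:-1] += offsets  # border vertices keep their positions
    return grid, displaced


rng = np.random.default_rng(0)
grid, displaced = displace_grid_vertices(
    n_cols=4, n_rows=3, cell_w=25, cell_h=20, magnitude=5, rng=rng
)
```

Warping each cell from its original rectangle onto the quad formed by its displaced corners is what turns this lattice into an image distortion.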

Motivation and context

This feature is particularly important for medical imaging applications, where localized distortions can simulate variations in patient anatomy more effectively than global transformations. In medical imaging, the precision and realism of data augmentations are crucial for training robust models. The grid-based elastic transformation aligns well with Albumentations' objective of providing diverse and powerful augmentation techniques.

Grid-based elastic transformations give more granular control over the augmentation process and therefore more varied and realistic training data. This could improve model performance, especially in tasks such as tumor detection, organ segmentation, and other medical image analysis applications.

I have implemented this algorithm based on Augmentor (https://augmentor.readthedocs.io/en/stable/userguide/mainfeatures.html#elastic-distortions)

Possible implementation

```python
from __future__ import annotations

import random
from typing import Any

import albumentations as A
import numpy as np
from PIL import Image


class GranularElasticDeform(A.DualTransform):
    """Elastic deformation Albumentations implementation

    This class applies elastic transformations on a grid-based approach,
    where the granularity of the distortions can be controlled using the
    width and height of the overlaying distortion grid. Larger grid sizes
    result in finer, less severe distortions.

    Only images and masks are handled; bounding boxes and keypoints are not.

    Original source:
        https://github.com/mdbloice/Augmentor/blob/master/Augmentor/Operations.py#L1355
    """

    def __init__(
        self,
        n_grid_width: int,
        n_grid_height: int,
        magnitude: int,
        p: float = 1.0,
        always_apply: bool | None = None,
    ):
        """
        Params:
            n_grid_width (int): Number of grid cells along the width (columns)
            n_grid_height (int): Number of grid cells along the height (rows)
            magnitude (int): Maximum displacement, in pixels, of each interior grid vertex
            p (float): Probability of applying the transform
            always_apply (bool): Whether to always apply the transform
        """
        super().__init__(p=p, always_apply=always_apply)
        self.n_grid_width = n_grid_width
        self.n_grid_height = n_grid_height
        self.magnitude = abs(magnitude)

    def calculate_dimensions(
        self,
        width_of_square,
        height_of_square,
        width_of_last_square,
        height_of_last_square,
    ):
        # Target rectangles [x1, y1, x2, y2] in row-major order. Rows run over
        # n_grid_height and columns over n_grid_width; the last row/column
        # absorbs the remainder of the integer division.
        dimensions = []
        for vertical_tile in range(self.n_grid_height):
            for horizontal_tile in range(self.n_grid_width):
                x1 = horizontal_tile * width_of_square
                y1 = vertical_tile * height_of_square
                x2 = x1 + (
                    width_of_last_square
                    if horizontal_tile == self.n_grid_width - 1
                    else width_of_square
                )
                y2 = y1 + (
                    height_of_last_square
                    if vertical_tile == self.n_grid_height - 1
                    else height_of_square
                )
                dimensions.append([x1, y1, x2, y2])

        return dimensions

    def calculate_polygons(self, dimensions, horizontal_tiles, vertical_tiles):
        # Source quads in Pillow order: upper-left, lower-left, lower-right, upper-right.
        polygons = []
        for x1, y1, x2, y2 in dimensions:
            polygons.append([x1, y1, x1, y2, x2, y2, x2, y1])

        # Cells in the last column/row have no neighbour to the right/below.
        last_column = [
            (horizontal_tiles - 1) + horizontal_tiles * i for i in range(vertical_tiles)
        ]
        last_row = range(
            (horizontal_tiles * vertical_tiles) - horizontal_tiles,
            horizontal_tiles * vertical_tiles,
        )

        # Each group [a, b, c, d] is a 2x2 block of cells (top-left, top-right,
        # bottom-left, bottom-right) that share one interior vertex.
        polygon_indices = []
        for i in range((vertical_tiles * horizontal_tiles) - 1):
            if i not in last_row and i not in last_column:
                polygon_indices.append(
                    [i, i + 1, i + horizontal_tiles, i + 1 + horizontal_tiles]
                )

        # Move the shared vertex by the same (dx, dy) in all four quads so that
        # neighbouring cells stay stitched together; the image border stays fixed.
        for a, b, c, d in polygon_indices:
            dx = random.randint(-self.magnitude, self.magnitude)
            dy = random.randint(-self.magnitude, self.magnitude)

            polygons[a][4] += dx  # lower-right corner of the top-left cell
            polygons[a][5] += dy
            polygons[b][2] += dx  # lower-left corner of the top-right cell
            polygons[b][3] += dy
            polygons[c][6] += dx  # upper-right corner of the bottom-left cell
            polygons[c][7] += dy
            polygons[d][0] += dx  # upper-left corner of the bottom-right cell
            polygons[d][1] += dy

        return polygons

    def generate_mesh(self, polygons, dimensions):
        # Pillow's MESH transform expects a list of (target_box, source_quad) pairs.
        return [[dimensions[i], polygons[i]] for i in range(len(dimensions))]

    def distort_image(
        self,
        image: np.ndarray,
        generated_mesh: list[list],
        resample: Image.Resampling = Image.Resampling.BICUBIC,
    ) -> np.ndarray:
        # Image.fromarray only accepts Pillow-compatible arrays
        # (e.g. uint8 HxW, HxWx3 or HxWx4).
        pil_image = Image.fromarray(image)
        return np.array(
            pil_image.transform(
                pil_image.size, Image.Transform.MESH, generated_mesh, resample=resample
            )
        )

    def get_params_dependent_on_data(
        self, params: dict[str, Any], data: dict[str, Any]
    ) -> dict[str, Any]:
        img = data["image"]
        h, w = img.shape[:2]

        horizontal_tiles = self.n_grid_width
        vertical_tiles = self.n_grid_height

        width_of_square = w // horizontal_tiles
        height_of_square = h // vertical_tiles

        width_of_last_square = w - (width_of_square * (horizontal_tiles - 1))
        height_of_last_square = h - (height_of_square * (vertical_tiles - 1))

        dimensions = self.calculate_dimensions(
            width_of_square,
            height_of_square,
            width_of_last_square,
            height_of_last_square,
        )
        polygons = self.calculate_polygons(dimensions, horizontal_tiles, vertical_tiles)
        generated_mesh = self.generate_mesh(polygons, dimensions)

        # The same mesh is passed to apply() and apply_to_mask(), keeping them aligned.
        return {"generated_mesh": generated_mesh}

    def apply(self, img, generated_mesh, **params):
        return self.distort_image(img, generated_mesh)

    def apply_to_mask(self, mask, generated_mesh, **params):
        # Nearest-neighbour resampling keeps mask labels intact; bicubic would
        # blend neighbouring class IDs into values that are not valid labels.
        return self.distort_image(
            mask, generated_mesh, resample=Image.Resampling.NEAREST
        )

    def get_transform_init_args_names(self):
        return ("n_grid_width", "n_grid_height", "magnitude")
```

Additional context

ternaus commented 1 month ago

Added