kornia / kornia

Geometric Computer Vision Library for Spatial AI
https://kornia.readthedocs.io
Apache License 2.0

Elastic transform offset generation may be inefficient #994

Open ChristophReich1996 opened 3 years ago

ChristophReich1996 commented 3 years ago

❓ Questions and Help

I'm very happy to see the elastic transform augmentation added to Kornia 🚀. However, the current implementation might be somewhat inefficient for large images. In the current implementation, a noise map for each of the x and y deformations is smoothed by a Gaussian filter, and the resulting displacement field is then applied to the image to be augmented (correct me if I'm wrong). For large images and large values of sigma, the two convolutions become expensive. In the elasticdeform package, the deformation offsets are generated differently: a random coarse displacement grid is generated and then upsampled to the size of the image to be augmented. From my understanding, this approach may constrain the variation of the offsets (limited to the coarse grid) but is far more computationally efficient than the current approach. 🤔
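
To make the cost argument concrete, here is a rough back-of-the-envelope count of multiply-adds for the two approaches. The sizes (a 2048 × 2048 image, a 63 × 63 kernel, and bilinear upsampling that reads 4 neighbours per output pixel) are assumptions picked for illustration. The count ignores memory traffic and how a GPU actually schedules the work, so the ratio is an order-of-magnitude hint only, not a predicted speed-up. A separable Gaussian (one 1D pass per axis) is included for comparison, since a Gaussian kernel can be factored that way.

```python
# Rough multiply-add counts for generating the x and y displacement fields.
# All sizes are illustrative assumptions, not measurements.
H = W = 2048   # image height and width
K = 63         # Gaussian kernel side length
FIELDS = 2     # one displacement field each for x and y

# Dense 2D convolution: K * K multiply-adds per output pixel.
dense_conv = FIELDS * H * W * K * K

# Separable Gaussian: one 1D pass along rows plus one along columns.
separable_conv = FIELDS * H * W * 2 * K

# Bilinear upsampling of a coarse grid: 4 weighted neighbours per output pixel.
bilinear_upsample = FIELDS * H * W * 4

print(f"dense 2D conv:     {dense_conv:.2e} multiply-adds")
print(f"separable conv:    {separable_conv:.2e} multiply-adds")
print(f"bilinear upsample: {bilinear_upsample:.2e} multiply-adds")
print(f"dense / upsample:  {dense_conv / bilinear_upsample:.0f}x")
```

Under these assumptions the dense convolution does roughly three orders of magnitude more arithmetic than the upsampling step, and the separable variant sits in between.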

shijianjian commented 3 years ago

@ChristophReich1996 Thank you for reporting. I am wondering if it is possible for you to benchmark these two implementations against each other.

If the performance boost is significant, we are happy to improve it.

ChristophReich1996 commented 3 years ago

Hey @shijianjian, I implemented a quick-and-dirty script to benchmark both implementations. For a single augmentation, the current Kornia implementation took ~590 ms and the PyTorch reference implementation based on the elasticdeform package took ~1 ms. The augmentation was applied to a grayscale image with a resolution of 2048 × 2048; additional parameters can be seen in the code. The benchmark was performed on an NVIDIA 2080 Ti with PyTorch 1.7.1, CUDA 11.1, and Kornia 0.5.1. The image file was taken from the Cell-DETR repo.

Cheers, Christoph

[Images: original image | image augmented (Kornia) | image augmented (ref.)]
"""
This script tests the performance of the Kornia elastic transform implementation against the approach of the elasticdeform
package reimplemented in PyTorch/Kornia.
Source of the kornia implementation: https://github.com/kornia/kornia/blob/3606cf9c3d1eb3aabd65ca36a0e7cb98944c01ba/kornia/geometry/transform/elastic_transform.py
"""

from typing import Tuple

import os

# Machine-specific: expose only the second GPU. Must be set before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import kornia
from kornia.filters.kernels import get_gaussian_kernel2d

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def elastic_transform2d_kornia(image: torch.Tensor,
                               noise: torch.Tensor,
                               kernel_size: Tuple[int, int] = (3, 3),
                               sigma: Tuple[float, float] = (4., 4.),
                               alpha: Tuple[float, float] = (32., 32.),
                               align_corners: bool = False,
                               mode: str = 'bilinear') -> torch.Tensor:
    """
    Source: https://github.com/kornia/kornia/blob/3606cf9c3d1eb3aabd65ca36a0e7cb98944c01ba/kornia/geometry/transform/elastic_transform.py
    """
    # Get Gaussian kernel for 'y' and 'x' displacement
    kernel_x: torch.Tensor = get_gaussian_kernel2d(kernel_size, (sigma[0], sigma[0]))[None]
    kernel_y: torch.Tensor = get_gaussian_kernel2d(kernel_size, (sigma[1], sigma[1]))[None]

    # Convolve over a random displacement matrix and scale them with 'alpha'
    disp_x: torch.Tensor = noise[:, :1]
    disp_y: torch.Tensor = noise[:, 1:]

    disp_x = kornia.filters.filter2D(disp_x, kernel=kernel_y, border_type='constant') * alpha[0]
    disp_y = kornia.filters.filter2D(disp_y, kernel=kernel_x, border_type='constant') * alpha[1]

    # Stack the displacements and move channels last: (B, 2, H, W) -> (B, H, W, 2)
    disp = torch.cat([disp_x, disp_y], dim=1).permute(0, 2, 3, 1)

    # Warp image based on displacement matrix
    b, c, h, w = image.shape
    grid = kornia.utils.create_meshgrid(h, w, device=image.device).to(image.dtype)
    warped = F.grid_sample(
        image, (grid + disp).clamp(-1, 1), align_corners=align_corners, mode=mode)

    return warped

def elastic_transform2d_elasticdeform(image: torch.Tensor,
                                      course_grid: torch.Tensor,
                                      align_corners: bool = False,
                                      mode: str = 'bilinear') -> torch.Tensor:
    """
    PyTorch implementation of the elastic transform approach of the elasticdeform package.
    """
    # Get image shape
    b, c, h, w = image.shape
    # Upsample coarse grid to size of image
    disp: torch.Tensor = F.interpolate(course_grid, size=(h, w), mode="bilinear", align_corners=True)
    # Reshape displacements
    disp = disp.permute(0, 2, 3, 1).view(b, h, w, 2)
    # Warp image based on displacement matrix
    grid = kornia.utils.create_meshgrid(h, w, device=image.device).to(image.dtype)
    warped = F.grid_sample(
        image, (grid + disp).clamp(-1, 1), align_corners=align_corners, mode=mode)
    return warped

if __name__ == '__main__':
    # Set device
    device = "cuda:0"
    # Load image and upscale image
    image = torch.load("1.pt")[None, None].float().to(device)
    image = F.interpolate(image, size=(2048, 2048), mode="bilinear", align_corners=False)

    # Apply kornia elastic transformation
    noise = 2. * torch.rand(1, 2, 2048, 2048).to(device) - 1.
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    image_kornia = elastic_transform2d_kornia(image=image,
                                              noise=noise,
                                              alpha=(1, 1),
                                              kernel_size=(63, 63),
                                              sigma=(32, 32))
    end.record()
    torch.cuda.synchronize()
    print("Kornia runtime:", start.elapsed_time(end), "ms")

    # Apply reference implementation
    course_grid = (2. * torch.rand(1, 2, 32, 32).to(device) - 1.) * 0.025
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    image_elasticdeform = elastic_transform2d_elasticdeform(image=image, course_grid=course_grid)
    end.record()
    torch.cuda.synchronize()
    print("Reference implementation runtime: ", start.elapsed_time(end), "ms")

    plt.imshow(image[0, 0].cpu().detach(), cmap="gray")
    plt.show()
    plt.imshow(image_kornia[0, 0].cpu().detach(), cmap="gray")
    plt.show()
    plt.imshow(image_elasticdeform[0, 0].cpu().detach(), cmap="gray")
    plt.show()
Kornia runtime: 590.91162109375 ms
Reference implementation runtime:  1.0226240158081055 ms
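
The numbers above come from one timed call per implementation. A single call can also include one-time costs (context creation, kernel selection, memory allocation), so a common refinement is to run a few untimed warm-up calls and report the median of several repeats. The helper below is a minimal, framework-agnostic sketch of that idea; `benchmark` and its parameters are hypothetical names, not part of Kornia or PyTorch.

```python
import time
from statistics import median
from typing import Callable, Optional


def benchmark(fn: Callable[[], object],
              warmup: int = 3,
              repeats: int = 10,
              sync: Optional[Callable[[], None]] = None) -> float:
    """Return the median runtime of fn() in milliseconds (hypothetical helper)."""

    def _sync() -> None:
        # For asynchronous backends, wait until all queued work has finished.
        if sync is not None:
            sync()

    # Untimed warm-up calls absorb one-time startup costs.
    for _ in range(warmup):
        fn()
    _sync()

    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        _sync()  # make sure the work is done before stopping the clock
        timings.append((time.perf_counter() - start) * 1e3)
    return median(timings)


if __name__ == "__main__":
    # CPU-only toy workload, just to show the call pattern.
    ms = benchmark(lambda: sum(i * i for i in range(10_000)))
    print(f"median runtime: {ms:.3f} ms")
```

On a GPU, passing `sync=torch.cuda.synchronize` and wrapping each of the two transform calls from the script in a lambda should give a steadier comparison than a single measurement.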

edgarriba commented 3 years ago

Given those numbers, I would say let's go for this update. The only concern is that we would break backward compatibility of the API, which I believe was originally designed so that the sigma of the Gaussian could be optimized. /cc @IssamLaradji

shijianjian commented 3 years ago

Is it possible to keep both implementations? We could add another parameter that selects the method by name, and include a user warning that the faster version will become the default in the future.
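
The suggestion above could look roughly like the sketch below: a `method` argument selects the implementation, and a `FutureWarning` is raised only when the caller relies on the default. Every name here (`elastic_offsets`, the `"gaussian"` and `"coarse_grid"` keys, the placeholder functions) is hypothetical and chosen for illustration; none of it is Kornia's actual API, and the placeholders return strings instead of tensors to keep the example self-contained.

```python
import warnings
from typing import Callable, Dict, Optional


def _offsets_gaussian() -> str:
    # Placeholder for the current Gaussian-smoothed noise approach.
    return "gaussian"


def _offsets_coarse_grid() -> str:
    # Placeholder for the upsampled coarse-grid approach.
    return "coarse_grid"


_METHODS: Dict[str, Callable[[], str]] = {
    "gaussian": _offsets_gaussian,
    "coarse_grid": _offsets_coarse_grid,
}


def elastic_offsets(method: Optional[str] = None) -> str:
    """Dispatch to an offset generator by name (hypothetical API)."""
    if method is None:
        # Warn only when the caller relies on the implicit default.
        warnings.warn(
            "The default method may change to 'coarse_grid' in a future release; "
            "pass method='gaussian' explicitly to keep the current behaviour.",
            FutureWarning,
            stacklevel=2,
        )
        method = "gaussian"
    if method not in _METHODS:
        raise ValueError(f"Unknown method {method!r}; expected one of {sorted(_METHODS)}")
    return _METHODS[method]()
```

Using `None` as the default keeps the warning away from users who already pass a method explicitly, so only code that depends on the implicit default is nudged before the switch.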

ChristophReich1996 commented 3 years ago

I could open a pull request, if desired, with a clean implementation.

edgarriba commented 3 years ago

that would be great

edgarriba commented 3 years ago

I was planning to add it as an augmentation layer too, but I'll wait until we refactor.

shijianjian commented 3 years ago

Hi @ChristophReich1996, how's the refactor going?

ChristophReich1996 commented 3 years ago

Hey @shijianjian, I still need to adapt the tests. Should be finished tomorrow I guess :)

edgarriba commented 3 years ago

@ChristophReich1996 Any progress on that? Do you think you could make it by the end of next week, in time for our bi-weekly release?