sjrothfuss / torch-rotate-image

Rotate images with torch
BSD 3-Clause "New" or "Revised" License

Possible 2d rotation error #4

Closed: sjrothfuss closed this issue 2 months ago

sjrothfuss commented 3 months ago

Description

test_rotate_image_2d_rotation() is failing because there appears to be an error in the rotation. For example, I would expect a 180° rotation to produce a flipped image, but it does not. This may be an error in my interpretation of the code; I still need to get to the bottom of it. My current guess is that it has to do with the torch_image_lerp.sample_image_2d call in rotate_image_2d.

What I Did

This is from the contents of tests.test_rotate_image_2d_rotation:

image = torch.zeros(28, 28)
image[:14, :] = 1  # image is half black, half white
angles = torch.tensor([0.0, 90.0, 180.0])
rotated_image = rotate_image_2d(image, angles) # (3, 28, 28)

image: [image]

rotated_image[0]: [image]

rotated_image[1]: [image]

rotated_image[2]: [image]
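For comparison, here is a minimal sketch of the expected outputs using exact grid rotations. It assumes the intended convention is a counterclockwise rotation about the image center, so that torch.rot90 is the right reference. The repo does not state this. On a square pixel grid, a 180° rotation is exactly a flip of both axes.

```python
import torch

# Same test image as above: top half ones, bottom half zeros.
image = torch.zeros(28, 28)
image[:14, :] = 1

# Exact grid rotations for comparison (no interpolation involved).
expected_90 = torch.rot90(image, k=1, dims=(0, 1))  # top half moves to the left half
expected_180 = torch.rot90(image, k=2, dims=(0, 1))

# A 180° rotation is the same as flipping both axes.
assert torch.equal(expected_180, torch.flip(image, dims=(0, 1)))
print(expected_180[:14].sum().item(), expected_180[14:].sum().item())  # 0.0 392.0
```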

sjrothfuss commented 3 months ago

Reopening to track progress; I'll return to this when I can. PR #5 got us closer, but it's still not quite done. Artifacts appear on both sides when using an odd-sized image.

Updated code from test_rotate_image_2d_rotation:

image = torch.linspace(0.5, 1, steps=28).repeat(28, 1)
rotated_image = rotate_image_2d(image=image, angles=180.0)

image:

[image]

rotated_image (note the different scale):

[image]

alisterburt commented 3 months ago

Remember, there will always be artefacts.

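from __future__ import annotations

import torch


def dft_center(
    image_shape: tuple[int, ...],
    rfft: bool,
    fftshifted: bool,
    device: torch.device | None = None,
) -> torch.LongTensor:
    """Return the position of the DFT center in an fftshifted DFT for a given input shape."""
    # Without fftshift, the DC component sits at index 0 in every dimension
    fft_center = torch.zeros(size=(len(image_shape),), device=device)
    image_shape = torch.as_tensor(image_shape).float()
    if rfft is True:
        # rfft_shape() is not part of this snippet; it is posted further down
        image_shape = torch.tensor(rfft_shape(image_shape))
    if fftshifted is True:
        # fftshift moves DC to n // 2 along each dimension; keep the requested device
        fft_center = torch.divide(image_shape, 2, rounding_mode='floor').to(fft_center.device)
    if rfft is True:
        # The last (halved) dimension of an rfft is not fftshifted, so DC stays at 0
        fft_center[-1] = 0
    return fft_center.long()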
alisterburt commented 3 months ago

this looks like it works as expected to me - everything sampled out of bounds is zero, everything in bounds the values go from 0.5 -> 1

sjrothfuss commented 2 months ago

Thanks, @alisterburt, couple questions.

When you say:

  • we might want to update the rotation center code to explicitly be the position of the DC component in an fftshifted DFT

Do you mean to use dft_center() to define the center of the image? What's the advantage to that? I don't see how it yields anything different from the (h//2, w//2) we use now. Granted, I don't understand the rfft argument; see below.


rfft_shape is not defined. Should the image_shape defined here:

if rfft is True:
   image_shape = torch.tensor(rfft_shape(image_shape))

instead be something like torch.tensor(torch.fft.rfft(image_shape).shape)? That doesn't yield proper results either but I'm not sure what else it would be. 😃


this looks like it works as expected to me - everything sampled out of bounds is zero, everything in bounds the values go from 0.5 -> 1

I was concerned that we're "losing" one row and column off the edge to the bottom right. I guess I assumed we'd get a perfect flip with a 180° rotation, but you're right, there'll always be artifacts, and I can see that this could be "close enough" for most uses. I trust your assessment of what is acceptable for subsequent applications and what is reasonable to achieve with our tools.
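Here is a toy illustration of where that row and column go, as a minimal sketch. rotate_nearest_2d is a hypothetical stand-in for rotate_image_2d, not the repo's code. It rotates about (h // 2, w // 2) and fills out-of-bounds samples with zero. It samples by nearest neighbour instead of the linear interpolation in torch_image_lerp.sample_image_2d.

Under those assumptions, a 180° rotation maps index i to 2 * (n // 2) - i:

- For odd n this is n - 1 - i, an exact flip.
- For even n it is n - i. Input row 0 and column 0 would land at index n, just past the bottom-right edge. The output's first row and column sample out of bounds and come back as zero.

```python
import math

import torch


def rotate_nearest_2d(image: torch.Tensor, angle_deg: float) -> torch.Tensor:
    """Rotate a (h, w) image by angle_deg about (h // 2, w // 2).

    Hypothetical sketch, not the repo's rotate_image_2d: inverse mapping with
    nearest-neighbour sampling; samples outside the image are set to zero.
    """
    h, w = image.shape
    center = torch.tensor([h // 2, w // 2], dtype=torch.float64)
    theta = math.radians(angle_deg)
    rot = torch.tensor(
        [[math.cos(theta), -math.sin(theta)],
         [math.sin(theta), math.cos(theta)]],
        dtype=torch.float64,
    )
    # (row, col) of every output pixel, relative to the rotation center
    rows, cols = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    coords = torch.stack([rows, cols], dim=-1).to(torch.float64) - center
    # Inverse mapping: (row vector) @ rot applies rot.T, a rotation by -theta,
    # so each output pixel looks up the input pixel it came from
    src = torch.round(coords @ rot + center).long()
    r, c = src[..., 0], src[..., 1]
    inside = (r >= 0) & (r < h) & (c >= 0) & (c < w)
    out = torch.zeros_like(image)
    out[inside] = image[r[inside], c[inside]]
    return out


# Even size (28 x 28): first output row and column sample out of bounds
even = rotate_nearest_2d(torch.linspace(0.5, 1, steps=28).repeat(28, 1), 180.0)
print(even[0].abs().sum().item(), even[:, 0].abs().sum().item())  # 0.0 0.0

# Odd size (27 x 27): 180° is an exact flip of both axes
odd = torch.rand(27, 27)
print(torch.equal(rotate_nearest_2d(odd, 180.0), torch.flip(odd, dims=(0, 1))))  # True
```

This reproduces only the index bookkeeping, not the repo's interpolation, so the exact edge values there may differ.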

alisterburt commented 2 months ago

Do you mean to use dft_center() to define the center of the image? What's the advantage to that?

Yes, the advantage is that it's very explicit about what our convention for the rotation center is 🙂
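To make the convention concrete, here is a small sketch that uses only torch.fft; the helper name dc_position is hypothetical. For a constant image, all of the signal is in the DC term, so the argmax of the fftshifted full 2D DFT is the DC position. For both even and odd sizes it lands at (h // 2, w // 2), the same numbers as the current center. What dft_center(..., rfft=False, fftshifted=True) adds is stating that convention explicitly.

```python
import torch


def dc_position(shape: tuple[int, int]) -> tuple[int, int]:
    """Index of the DC term in torch.fft.fftshift(torch.fft.fft2(...)) for a given shape."""
    constant = torch.ones(shape)  # all energy ends up in the DC term
    spectrum = torch.fft.fftshift(torch.fft.fft2(constant)).abs()
    flat_index = int(torch.argmax(spectrum))
    return divmod(flat_index, shape[1])  # (row, col)


for shape in [(10, 10), (9, 9), (28, 27)]:
    print(shape, dc_position(shape), (shape[0] // 2, shape[1] // 2))
```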

rfft_shape is not defined

sorry, rfft_shape is here - worth checking it against a few values

from typing import Sequence, Tuple


def rfft_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
    """Get the output shape of an rfft on an input with input_shape."""
    rfft_shape = list(input_shape)
    # An rfft keeps n // 2 + 1 entries along the last dimension
    rfft_shape[-1] = int((rfft_shape[-1] / 2) + 1)
    return tuple(rfft_shape)
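One way to do that check, as a sketch, is to compare against the shape torch.fft.rfftn actually returns. The FFT has to be taken of an array with the given shape, not of a tensor holding the shape. That is why torch.fft.rfft(image_shape) above gives the wrong answer. The function below is a copy of rfft_shape, with the return annotation widened to Tuple[int, ...], so the snippet runs on its own.

```python
from typing import Sequence, Tuple

import torch


def rfft_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
    """Copy of the helper above: the last dimension becomes n // 2 + 1."""
    rfft_shape = list(input_shape)
    rfft_shape[-1] = int((rfft_shape[-1] / 2) + 1)
    return tuple(rfft_shape)


# Compare against the shape torch actually produces for arrays of each shape
for shape in [(10, 10), (10, 9), (28, 28), (27, 27), (4, 5, 6)]:
    expected = tuple(torch.fft.rfftn(torch.zeros(shape)).shape)
    print(shape, rfft_shape(shape), expected)
    assert rfft_shape(shape) == expected
```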

I was concerned that we're "losing" one row and column off the edge to the bottom right

Right,

sjrothfuss commented 2 months ago

rfft_shape appears to return the correct values, but the way _get_dft_center handles those values is incorrect, at least for real-space images. Is the intent that rfft=True is for use when rotating Fourier-space images, or is this a bug?

>>> _rfft_shape((10,10)) == torch.fft.rfft(torch.rand(10,10)).shape == (10,6)
True
>>> _get_dft_center((10,10), rfft=True)
tensor([5, 0])
>>> _get_dft_center((10,10), rfft=False)
tensor([5, 5])

Thanks.

alisterburt commented 2 months ago

@sjrothfuss heh - good question!

as a result we want to use rfft=False here to get the position of the DC component in the full-size FFT
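Here is a minimal sketch of where DC sits in each case, using a constant 10 x 10 image. The half spectrum from torch.fft.rfft2 has shape (10, 6). Here only its first dimension is fftshifted, so DC stays in column 0. That layout is an assumption, chosen to match fft_center[-1] = 0 in dft_center. It reproduces the tensor([5, 0]) above. The full FFT gives the (5, 5) position used when rotating a real-space image.

```python
import torch

image = torch.ones(10, 10)  # constant image: all energy in the DC term

# Full 2D DFT, fftshifted over both dimensions.
full = torch.fft.fftshift(torch.fft.fft2(image), dim=(-2, -1)).abs()
# Half (rfft) DFT: only the non-halved dimension is fftshifted.
half = torch.fft.fftshift(torch.fft.rfft2(image), dim=(-2,)).abs()


def argmax_2d(x: torch.Tensor) -> tuple[int, int]:
    """(row, col) of the largest entry of a 2D tensor."""
    return divmod(int(torch.argmax(x)), x.shape[-1])


print(tuple(full.shape), argmax_2d(full))  # (10, 10) (5, 5)
print(tuple(half.shape), argmax_2d(half))  # (10, 6) (5, 0)
```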

alisterburt commented 2 months ago

this function was not written with rotations in mind, just does what it says on the tin 🙂

alisterburt commented 2 months ago

closing