ZFTurbo / volumentations

Library for 3D augmentations
MIT License

Rotation operation shifts the 3D matrix in undesired ways because of the recenter function. #8

Open flora-sun-zhixin opened 2 years ago

Problem description

I rotated a 3D matrix of shape (512, 512, 10) about the z axis, so on each slice img[:, :, i] we should expect to see an in-plane rotation and nothing else. Instead, the last slice img[:, :, 9] lost all of its information. The output has 9 slices carrying the original content (actually, it is interpolated) and a 10th slice that is a constant plane if the mode is "constant".

The issue behind this

Here is what I think triggers the issue.
In functionals.py, your rotate3D calls generate_coords, which maps the center of the 3D matrix to (0, 0, 0) using

coords[d] -= ((np.array(shape).astype(float) - 1) / 2)[d]

Then your recenter_coords function, which maps the top-left corner of the 3D matrix back to (0, 0, 0), uses

coords[d] += int(np.round(coords.shape[d+1]/2))

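The -1 is present in the first step but missing in the second. For an axis of length 10, the first step subtracts 4.5 and the second adds 5, so the z coordinate for img[:, :, 9] changes from 9 to 9.5. Every slice therefore gets interpolated along z, which is not supposed to happen.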
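The arithmetic can be checked in isolation with NumPy alone. The sketch below applies only the two quoted lines to a single axis; the function names quoted_round_trip and symmetric_round_trip are made up for illustration and are not part of volumentations. The second function is just one possible way to make the two steps cancel, under the assumption that no rotation is applied along that axis, and is not necessarily the fix the library should adopt.

```python
import numpy as np

def quoted_round_trip(n):
    """Apply the two quoted steps to the coordinates of one axis of length n."""
    coords = np.arange(n, dtype=float)
    coords -= (n - 1) / 2            # centering step quoted from generate_coords
    coords += int(np.round(n / 2))   # shift-back step quoted from recenter_coords
    return coords

def symmetric_round_trip(n):
    """Hypothetical alternative: shift back by the same center that was subtracted."""
    coords = np.arange(n, dtype=float)
    center = (n - 1) / 2
    coords -= center
    coords += center
    return coords

z_quoted = quoted_round_trip(10)
z_symmetric = symmetric_round_trip(10)

print(z_quoted)     # every coordinate is 0.5 too large, the last one is 9.5
print(z_symmetric)  # coordinates come back as 0, 1, ..., 9
```

With n = 10 the last coordinate lands at 9.5, outside the valid index range 0..9, which matches the constant last slice described above.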

Details:

volumentations: a local clone of your code (I cloned it to my laptop because of the resize issue), so the import below uses the local volumentations.
python: 3.7.4 numpy: 1.21.5

How to reproduce:

from vp.volumentations import * # vp is where I put my local copy of your 3D augmentation code
import numpy as np
import random
import matplotlib.pyplot as plt

# generate some pattern
img = np.arange(512)
img = np.repeat(img, 512).reshape(512, 512)
img = np.repeat(img[..., np.newaxis], 10, axis=-1)
print("The shape of image is :", img.shape) # 512, 512, 10

# just rotate the 512*512*10 matrix about the z axis
def get_augmentation(patch_size):
    return Compose([
        Rotate((0, 0), (0, 0), (10, 10), p=1, value=1, interpolation=1),
    ], p=1.0)

# apply the rotation about z only, so z coords should stay unchanged while x, y coords rotate
aug = get_augmentation((512, 512, 10))
np.random.seed(1234)
random.seed(1234)
data = {"image": img}
augData = aug(**data)
augImg = augData["image"]

# visualize the result
mats = [img, augImg]
modes = ["Original", "After Rotation"]
for i in range(2):
    fig, axes = plt.subplots(2, 5, figsize=(8, 3))
    mat = mats[i]
    mode = modes[i]
    for j in range(2):
        for k in range(5):
            ax = axes[j, k] 
            ax.imshow(mat[:, :, j * 5 + k], cmap="hot", interpolation="none") 
            ax.axis("off")
            ax.set_title(f"slice {j * 5 + k} -- {mode}", fontsize=9)
    plt.tight_layout()
plt.show()
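The plots can be complemented with a numeric check. The helper below is a small sketch, and constant_slices is a hypothetical name, not something provided by volumentations. It lists the z-slices in which every voxel equals a given fill value. It is shown here on a synthetic stand-in volume so that it runs on its own.

```python
import numpy as np

def constant_slices(volume, fill_value):
    """Return the indices of z-slices in which every voxel equals fill_value."""
    # Flatten the two in-plane axes so each column holds one z-slice.
    flat = volume.reshape(-1, volume.shape[-1])
    return [int(k) for k in np.flatnonzero(np.all(flat == fill_value, axis=0))]

# Synthetic stand-in for the rotated volume: only the last slice is the fill value.
demo = np.arange(4 * 4 * 10, dtype=float).reshape(4, 4, 10) + 2.0
demo[:, :, -1] = 1.0

print(constant_slices(demo, fill_value=1.0))  # [9]
```

On the script above, the corresponding call would be constant_slices(augImg, fill_value=1), since value=1 is passed to Rotate. If the behaviour is as described in this issue, it should report slice 9.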

BTW, I like how you code, and thanks for sharing.