ClementPinard / Pytorch-Correlation-extension

Custom implementation of Correlation Module
MIT License
413 stars 77 forks

Issue with Computing Complete Pixel-to-Pixel Correlation #103

Closed jasam-sheja closed 11 months ago

jasam-sheja commented 11 months ago

I'm attempting to calculate the correlation between all pixels of one input and all pixels of the second input. Given two HxW single-channel inputs, the expected behavior is that for a patch size of 1 (a one-pixel neighbourhood), the correlation should be the product of each pixel from the first input with each pixel from the second. Essentially, this should result in the outer product of the flattened inputs (referred to as the cross-product below), which can be visualized as an HWxHW matrix.

However, the result is not as expected when adapting the example provided in the readme file. Specifically, the outcome appears to be the diagonal of the expected cross-product, replicated across rows.

Code

import torch
from spatial_correlation_sampler import spatial_correlation_sample 

device = "cpu"
batch_size = 1
channel = 1
H = 2
W = 2
dtype = torch.float32

# seed
torch.manual_seed(42)

input1 = torch.randint(1, 5, (batch_size, channel, H, W), dtype=dtype, device=device)
input2 = torch.randint_like(input1, 5, 9)

# Complete correlation
out = spatial_correlation_sample(input1,
                                input2,
                                kernel_size=1,
                                patch_size=(H,W),
                                stride=1,
                                padding=0,
                                dilation=0,
                                dilation_patch=0)
# cross product
cross = torch.einsum("bchw,bcij->bijhw", input1, input2)
print(f'\ninput1:\n{input1[0,0]}')
print(f'\ninput2:\n{input2[0,0]}')
print(f'\nComplete correlation:\n{out[0].view(H*W,H*W)}')
print(f'\nCross product:\n{cross[0].view(H*W,H*W)}')

Output:

input1:
tensor([[3., 4.],
        [1., 3.]])

input2:
tensor([[7., 8.],
        [5., 5.]])

Complete correlation:
tensor([[21., 32.,  5., 15.],
        [21., 32.,  5., 15.],
        [21., 32.,  5., 15.],
        [21., 32.,  5., 15.]])

Cross product:
tensor([[21., 28.,  7., 21.],
        [24., 32.,  8., 24.],
        [15., 20.,  5., 15.],
        [15., 20.,  5., 15.]])

Would appreciate any guidance or corrections to achieve the intended result.

ClementPinard commented 11 months ago

Hi, you are making two mistakes here:

  1. The dilation parameters behave like the stride: their default value is 1, and adding extra dilation means setting them above 1. With a dilation of 0, the second input is never shifted, which is why the values are identical within each column of your output (every row is the same).
  2. You assume that the shift only goes down and right for upper-left values, and up and left for bottom-right values. Under that assumption, a patch size of (2, 2) would indeed test pixel (0, 0) against (0, 0), (1, 0), (0, 1) and (1, 1). But the patch size is actually the diameter of the search area, centered on the pixel of the first input, so it would instead test (-1, -1), (-1, 0), (0, -1) and (0, 0) (almost all of which will be 0).

See how it works here: https://github.com/ClementPinard/Pytorch-Correlation-extension/blob/master/Correlation_Module/correlation_cuda_kernel.cu#L51. We compute the radius of the patch, patchRad, and then subtract it from the sample coordinates of input2.
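To make the two points concrete, here is a minimal pure-Python sketch of the behaviour described above for kernel_size=1. It does not call the extension; `correlate_k1` is a hypothetical helper written for illustration, and it assumes that out-of-bounds samples of input2 count as zero. How the radius rounds for even patch sizes is also an assumption of this sketch, so the centred example uses an odd patch size.

```python
def correlate_k1(input1, input2, patch_size, dilation_patch=1):
    """Reference model (not the extension itself) of a kernel_size=1 correlation.

    out[ph][pw][h][w] = input1[h][w] * input2[h + dy][w + dx], where (dy, dx)
    is the patch offset after subtracting the patch radius. Samples of input2
    that fall outside the image are assumed to be zero.
    """
    H, W = len(input1), len(input1[0])
    patch_h, patch_w = patch_size
    # Radius of the search area; it is subtracted so the patch is centred.
    rad_h = dilation_patch * (patch_h - 1) // 2
    rad_w = dilation_patch * (patch_w - 1) // 2
    out = [[[[0.0] * W for _ in range(H)] for _ in range(patch_w)]
           for _ in range(patch_h)]
    for ph in range(patch_h):
        for pw in range(patch_w):
            dy = ph * dilation_patch - rad_h
            dx = pw * dilation_patch - rad_w
            for h in range(H):
                for w in range(W):
                    i, j = h + dy, w + dx
                    if 0 <= i < H and 0 <= j < W:
                        out[ph][pw][h][w] = input1[h][w] * input2[i][j]
    return out


# Values taken from the issue above.
input1 = [[3.0, 4.0], [1.0, 3.0]]
input2 = [[7.0, 8.0], [5.0, 5.0]]

# dilation_patch=0: every offset collapses to (0, 0), so each patch entry is
# the same elementwise product [[21, 32], [5, 15]].
unshifted = correlate_k1(input1, input2, (2, 2), dilation_patch=0)

# dilation_patch=1 with an odd patch: offsets run from -1 to +1 around the pixel.
centred = correlate_k1(input1, input2, (3, 3), dilation_patch=1)
print(centred[1][1])  # offset (0, 0): elementwise product
print(centred[0][0])  # offset (-1, -1): mostly zeros
```

With a dilation of 0 every entry of `unshifted` is the same plane, which matches the repeated rows in the reported output. With the default dilation, only the centre entry of `centred` is fully populated; the corner entries are mostly zero because the shifted samples fall outside input2.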

If you actually want the cross product, you will have to use a patch size twice the size of your inputs, and then crop the output with the right value for every output pixel, which is arguably wasteful.
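The crop can be sketched in pure Python as follows. This is an illustration rather than the extension's API: both helper names are hypothetical, out-of-bounds samples are assumed to be zero, and the patch size is taken as (2H-1, 2W-1), which is the smallest centred patch that reaches every pixel of input2 from every pixel of input1.

```python
def centred_patch_products(input1, input2, rad_h, rad_w):
    """Model of a kernel_size=1 correlation with a centred patch.

    out[dy + rad_h][dx + rad_w][h][w] = input1[h][w] * input2[h + dy][w + dx],
    with samples outside input2 assumed to be zero.
    """
    H, W = len(input1), len(input1[0])

    def sample(i, j):
        return input2[i][j] if 0 <= i < H and 0 <= j < W else 0.0

    return [[[[input1[h][w] * sample(h + dy, w + dx) for w in range(W)]
              for h in range(H)]
             for dx in range(-rad_w, rad_w + 1)]
            for dy in range(-rad_h, rad_h + 1)]


def crop_to_full_product(out, H, W):
    """Pick, for each pixel pair, the patch entry whose offset links them.

    Row i*W + j indexes input2[i][j]; column h*W + w indexes input1[h][w].
    """
    rad_h, rad_w = H - 1, W - 1
    return [[out[i - h + rad_h][j - w + rad_w][h][w]
             for h in range(H) for w in range(W)]
            for i in range(H) for j in range(W)]


# Values taken from the issue above.
input1 = [[3.0, 4.0], [1.0, 3.0]]
input2 = [[7.0, 8.0], [5.0, 5.0]]
H, W = 2, 2

wide = centred_patch_products(input1, input2, H - 1, W - 1)  # patch (3, 3)
full = crop_to_full_product(wide, H, W)
for row in full:
    print(row)

# How much of the wide output is actually used by the crop.
computed = (2 * H - 1) * (2 * W - 1) * H * W
used = (H * W) ** 2
print(f"used {used} of {computed} computed values")
```

For the 2x2 inputs from the issue, `full` reproduces the expected HWxHW matrix, and only 16 of the 36 computed values survive the crop. That ratio is the waste mentioned above; a direct `torch.einsum`, as in the original post, avoids it.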

jasam-sheja commented 11 months ago

Thank you for the detailed explanation! It cleared up my confusion about the patch_size parameter. I used the cross-product initially to help me understand the parameter behavior. With your explanation, I now have a much better understanding of how everything works. Thank you!