connorlee77 / pytorch-mutual-information

Mutual Information in Pytorch

RuntimeError: The size of tensor a (3) must match the size of tensor b (256) at non-singleton dimension 2 #3

Closed by bach05 1 year ago

bach05 commented 1 year ago

I am trying to compute the mutual information between two batches of images of shape [12,3,256,256]. I set it up as follows:

self.mi = MutualInformation(num_bins=256, sigma=0.1, normalize=True, device=self.args.device).to(self.args.device)

Note that I added the device parameter since it was not defined in __init__. But I get this error:

residuals = values - self.bins.unsqueeze(0).unsqueeze(0)
RuntimeError: The size of tensor a (3) must match the size of tensor b (256) at non-singleton dimension 2

That makes sense, since values is a [12,65536,3] tensor while bins is broadcast as a [1,1,256] tensor. How can I fix it? Does it make sense to change the shape of bins to match values?

connorlee77 commented 1 year ago

Ahh this is because the code expects grayscale image inputs which are single channel and you are passing a 3 channel input. Check the example in README.md.
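
The shape mismatch can be reproduced in isolation. The following is a minimal sketch, not the repository's code: the `bins` tensor here is a hypothetical stand-in for the module's bin centers, and the batch and pixel counts are shrunk to keep the example light. It shows that a trailing dimension of 1 broadcasts against the 256 bins, while a trailing dimension of 3 does not.

```python
import torch

num_bins = 256
# Hypothetical bin centers; only the shape [256] matters for this demo.
bins = torch.linspace(0, 255, num_bins)

# Small stand-ins for the flattened images (batch=2, 16 pixels each).
gray_values = torch.rand(2, 16, 1)  # single channel: last dim is 1
rgb_values = torch.rand(2, 16, 3)   # three channels: last dim is 3

# A last dim of 1 broadcasts against 256: one residual per (pixel, bin).
residuals = gray_values - bins.unsqueeze(0).unsqueeze(0)
print(residuals.shape)  # torch.Size([2, 16, 256])

# A last dim of 3 cannot broadcast against 256, which raises the reported error.
try:
    rgb_values - bins.unsqueeze(0).unsqueeze(0)
    rgb_failed = False
except RuntimeError as err:
    rgb_failed = True
    print(err)
```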

bach05 commented 1 year ago

OK, I see. So either I have to convert to grayscale or modify the script so that it works for RGB images. Do you have any suggestions for the latter case?

daniela997 commented 1 year ago

@connorlee77 I have the same question: how would you suggest amending the code so that it works for images with more than one channel?

bach05 commented 1 year ago

Hi @daniela997, on my side I just approximated the joint probability needed for the mutual information as the product of the three per-channel joint probabilities. This is an approximation, since the three RGB channels are not independent, but computing the true joint distribution was infeasible.
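
If the joint distribution is assumed to factorize over channels, the mutual information reduces to the sum of the per-channel mutual informations. The sketch below illustrates that idea under stated assumptions: it is not the code used in this thread, the function names `channel_mi` and `rgb_mi_factorized` are hypothetical, intensities are assumed to lie in [0, 1], and it uses hard histogram binning rather than the repository's kernel-based soft histogram, so it is not differentiable.

```python
import torch

def channel_mi(x: torch.Tensor, y: torch.Tensor, num_bins: int = 32) -> torch.Tensor:
    """Mutual information (in nats) between two single-channel images in [0, 1]."""
    # Map each intensity to a bin index in [0, num_bins - 1].
    xi = (x.flatten() * num_bins).long().clamp(0, num_bins - 1)
    yi = (y.flatten() * num_bins).long().clamp(0, num_bins - 1)
    # Count co-occurrences through a flat index, then normalize to a joint pmf.
    joint = torch.bincount(xi * num_bins + yi, minlength=num_bins * num_bins).double()
    joint = joint.view(num_bins, num_bins) / joint.sum()
    px = joint.sum(dim=1, keepdim=True)  # marginal of x, shape [num_bins, 1]
    py = joint.sum(dim=0, keepdim=True)  # marginal of y, shape [1, num_bins]
    nz = joint > 0                       # skip empty cells to avoid log(0)
    return (joint[nz] * (joint[nz] / (px * py)[nz]).log()).sum()

def rgb_mi_factorized(a: torch.Tensor, b: torch.Tensor, num_bins: int = 32) -> torch.Tensor:
    """Sum of per-channel MI for two [C, H, W] images (channels treated as independent)."""
    return sum(channel_mi(a[c], b[c], num_bins) for c in range(a.shape[0]))
```

Because real RGB channels are correlated, this sum counts shared information more than once, so it should be read as a rough score rather than the true mutual information.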

connorlee77 commented 1 year ago

For an RGB image, you would need to compute the probability $P(r, g, b)$ when doing the histogram calculations. This can be generalized to $N$ dimensions as $P(x_1, x_2, ..., x_N)$. The number of histogram bins you'd have would be $(2^\textrm{bits})^N$, where $\textrm{bits}$ is the bit depth per channel. I'm not sure what the easiest way to do this is.
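
As a worked count, assume each channel is quantized to $B = 2^{\textrm{bits}}$ levels (the symbol $B$ is introduced here for illustration). A dense histogram over $N$ channels then has

$$
B^N = 2^{N \cdot \textrm{bits}} \ \text{cells}, \qquad \text{e.g. } 256^3 = 16{,}777{,}216 \ \text{for 8-bit RGB}.
$$

Mutual information between two such images needs the joint histogram over both images, which has $B^{2N}$ cells:

$$
256^{6} = 2^{48} \approx 2.8 \times 10^{14}.
$$

That is far more cells than a pair of $256 \times 256$ images has pixels ($65{,}536$ each), so the dense joint histogram would be both too large to store and almost entirely empty.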

I recommend just converting to a grayscale image, or simply reshaping the image to a 1D array. It shouldn't matter how you reshape it, as long as you do the same for both images.
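
Both suggestions can be sketched as small preprocessing helpers. This is a sketch under assumptions, not code from the repository: the names `rgb_to_gray` and `fold_channels` are hypothetical, the grayscale weights are the common ITU-R BT.601 luma coefficients, and the result is assumed to be passed to a module that accepts [B, 1, H, W] input. The same helper should be applied to both images.

```python
import torch

def rgb_to_gray(images: torch.Tensor) -> torch.Tensor:
    """[B, 3, H, W] -> [B, 1, H, W] weighted sum of R, G, B (BT.601 luma weights)."""
    weights = torch.tensor([0.299, 0.587, 0.114], device=images.device)
    return (images.float() * weights.view(1, 3, 1, 1)).sum(dim=1, keepdim=True)

def fold_channels(images: torch.Tensor) -> torch.Tensor:
    """[B, C, H, W] -> [B, 1, C*H, W]: every channel value becomes one more pixel sample."""
    b, c, h, w = images.shape
    return images.reshape(b, 1, c * h, w)

a = torch.rand(12, 3, 256, 256)
print(rgb_to_gray(a).shape)    # torch.Size([12, 1, 256, 256])
print(fold_channels(a).shape)  # torch.Size([12, 1, 768, 256])
```

Grayscale conversion discards color differences, while folding keeps every value but pools all channels into a single intensity histogram with three times as many samples.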