connorlee77 / pytorch-mutual-information

Mutual Information in Pytorch

Update: Integrated into Kornia

pytorch-mutual-information

Batch computation of mutual information, histograms, and 2D (joint) histograms in PyTorch.

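This implementation uses kernel density estimation (KDE) with a Gaussian kernel to calculate histograms and joint histograms. For the multivariate case we use a diagonal bandwidth matrix, which lets us decompose the multivariate kernel into the product of univariate kernels. From Wikipedia, the multivariate kernel density estimate is

$$\hat{f}_{\mathbf{H}}(\mathbf{x}) = \frac{1}{n}\sum_{i=1}^{n} K_{\mathbf{H}}(\mathbf{x} - \mathbf{x}_i), \qquad K_{\mathbf{H}}(\mathbf{x}) = |\mathbf{H}|^{-1/2}\, K\!\left(\mathbf{H}^{-1/2}\mathbf{x}\right),$$

where the bandwidth matrix is diagonal, $\mathbf{H} = \operatorname{diag}(h_1^2, \dots, h_d^2)$, so that for the Gaussian kernel (with $K$ on the right-hand side the univariate standard normal density)

$$K_{\mathbf{H}}(\mathbf{x}) = \prod_{j=1}^{d} \frac{1}{h_j}\, K\!\left(\frac{x_j}{h_j}\right).$$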
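The product decomposition of the Gaussian kernel under a diagonal bandwidth matrix can be checked numerically. The sketch below evaluates the full multivariate kernel and the product of scaled univariate kernels at one point and compares them. The helper names `gaussian_1d` and `gaussian_nd`, the bandwidths, and the evaluation point are hypothetical choices for illustration, not part of this repository.

```python
import math

import torch


def gaussian_1d(u: torch.Tensor, h: float) -> torch.Tensor:
    # Scaled univariate kernel (1/h) * K(u / h), with K the standard normal density
    return torch.exp(-0.5 * (u / h) ** 2) / (h * math.sqrt(2 * math.pi))


def gaussian_nd(x: torch.Tensor, H: torch.Tensor) -> torch.Tensor:
    # Full multivariate kernel |H|^{-1/2} K(H^{-1/2} x), K the d-dimensional standard normal
    d = x.shape[-1]
    quad = x @ torch.linalg.inv(H) @ x  # quadratic form x^T H^{-1} x
    return torch.exp(-0.5 * quad) / torch.sqrt((2 * math.pi) ** d * torch.linalg.det(H))


h = torch.tensor([0.4, 1.5], dtype=torch.float64)   # per-dimension bandwidths (arbitrary)
H = torch.diag(h ** 2)                               # diagonal bandwidth matrix
x = torch.tensor([0.3, -1.2], dtype=torch.float64)   # arbitrary evaluation point

joint = gaussian_nd(x, H)
product = gaussian_1d(x[0], h[0].item()) * gaussian_1d(x[1], h[1].item())
print(torch.allclose(joint, product))  # True
```

Because the factorization holds, a joint histogram only needs the per-dimension kernel responses, which are already computed for the marginal histograms.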

Example usage


Setup

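import torch
from PIL import Image
from torchvision import transforms

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

img1 = Image.open('grad1.jpg').convert('L')
img2 = Image.open('grad.jpg').convert('L')

# ToTensor() maps 8-bit pixels to floats in [0, 1]; unsqueeze adds a batch dim -> (1, 1, H, W)
img1 = transforms.ToTensor()(img1).unsqueeze(dim=0).to(device)
img2 = transforms.ToTensor()(img2).unsqueeze(dim=0).to(device)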

# Pair of different images, pair of same images
input1 = torch.cat([img2, img2])
input2 = torch.cat([img1, img2])

B, C, H, W = input1.shape   # shape: (2, 1, 300, 300)
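To try the batching convention without the two JPEG files, the sketch below swaps in random grayscale tensors (an assumption: they stand in for grad1.jpg and grad.jpg and use the same [0, 1] range that `ToTensor()` produces). It builds the same two pairs and flattens each batch element into one row of pixels, the (B, H*W) layout passed to the histogram functions.

```python
import torch

# Hypothetical stand-ins for grad1.jpg / grad.jpg: random single-channel images in [0, 1]
torch.manual_seed(0)
H, W = 300, 300
img1 = torch.rand(1, 1, H, W)   # (batch=1, channels=1, H, W)
img2 = torch.rand(1, 1, H, W)

# Row 0 pairs two different images, row 1 pairs an image with itself
input1 = torch.cat([img2, img2])
input2 = torch.cat([img1, img2])
B, C, H, W = input1.shape

# One flattened row of pixels per batch element: (B, H*W)
x1 = input1.view(B, H * W)
x2 = input2.view(B, H * W)
print(x1.shape)  # torch.Size([2, 90000])
```

Note that `view(B, H * W)` only works because C == 1; a multi-channel input would have B*C*H*W elements and the reshape would fail.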

Histogram usage:

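sigma = 0.4  # kernel bandwidth; 0.4 is the value used in the MutualInformation example
# Bins must be on the same device as the input; rescale [0, 1] pixels to the 0..255 bin range
bins = torch.linspace(0, 255, 256, device=device)
hist = histogram(255 * input1.view(B, H*W), bins, sigma)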
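The histogram call can be sketched as a soft, Gaussian-KDE histogram: every sample contributes a Gaussian bump to every bin center, the responses are averaged over samples, and each row is normalized to sum to 1. `soft_histogram` is a hypothetical name for this sketch, which follows the KDE description above but is not necessarily line-for-line what `histogram` in this repository does.

```python
import torch


def soft_histogram(x: torch.Tensor, bins: torch.Tensor, sigma: float, eps: float = 1e-10) -> torch.Tensor:
    """Gaussian-KDE histogram: x is (B, N) samples, bins is (K,) bin centers; returns (B, K)."""
    residuals = x.unsqueeze(-1) - bins          # (B, N, K) sample-to-bin distances
    kernel = torch.exp(-0.5 * (residuals / sigma) ** 2)
    pdf = kernel.mean(dim=1)                    # average kernel response per bin
    return pdf / (pdf.sum(dim=1, keepdim=True) + eps)  # each row sums to 1


bins = torch.linspace(0, 255, 256)
pixels = torch.tensor([[10.0, 10.0, 200.0, 200.0]])  # one row of pixel values on a 0..255 scale

hist = soft_histogram(pixels, bins, sigma=0.4)
print(hist.shape)  # torch.Size([1, 256])

# The same pixels on the [0, 1] scale that ToTensor() produces: against 0..255 bins,
# almost all of the mass collapses into the first few bins
collapsed = soft_histogram(pixels / 255, bins, sigma=0.4)
```

The `collapsed` case shows why the sample values and bin centers need to share a scale. In this sketch the intermediate `residuals` tensor has shape (B, N, K), so memory grows with image size times bin count.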

Histogram 2D usage:

bins = torch.linspace(0, 255, 256, device=device)
hist = histogram2d(255 * input1.view(B, H*W), 255 * input2.view(B, H*W), bins, sigma)
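With a diagonal bandwidth, the joint histogram only needs the univariate kernel responses of each image: for each pixel pair, the 2D kernel is the outer product of the two 1D responses, and summing over pixels becomes one batched matrix multiply. `soft_histogram2d` is a hypothetical name for this sketch, not the repository's function.

```python
import torch


def soft_histogram2d(x1, x2, bins, sigma, eps=1e-10):
    """Gaussian-KDE joint histogram: x1, x2 are (B, N); bins is (K,); returns (B, K, K)."""
    k1 = torch.exp(-0.5 * ((x1.unsqueeze(-1) - bins) / sigma) ** 2)  # (B, N, K)
    k2 = torch.exp(-0.5 * ((x2.unsqueeze(-1) - bins) / sigma) ** 2)  # (B, N, K)
    # Product kernel summed over samples: joint[b, i, j] = sum_n k1[b, n, i] * k2[b, n, j]
    joint = torch.bmm(k1.transpose(1, 2), k2)
    return joint / (joint.sum(dim=(1, 2), keepdim=True) + eps)


bins = torch.linspace(0, 255, 256)
a = torch.tensor([[10.0, 50.0, 200.0]])
joint_same = soft_histogram2d(a, a, bins, sigma=0.4)           # pixels paired with themselves
joint_diff = soft_histogram2d(a, a.flip(-1), bins, sigma=0.4)  # pixels paired out of order

# Fraction of joint mass on the diagonal (where both values fall in the same bin)
diag_same = torch.diagonal(joint_same, dim1=1, dim2=2).sum().item()
diag_diff = torch.diagonal(joint_diff, dim1=1, dim2=2).sum().item()
print(joint_same.shape)  # torch.Size([1, 256, 256])
```

For identical inputs the mass concentrates near the diagonal, which is what drives the mutual information of an image with itself up.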

Mutual Information (of images)

MI = MutualInformation(num_bins=256, sigma=0.4, normalize=True).to(device)
score = MI(input1, input2)
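Given a joint histogram normalized to a joint pmf, mutual information follows from the entropy identity I(X;Y) = H(X) + H(Y) - H(X,Y). The sketch below computes it in nats for a batch of joint pmfs. The function name is hypothetical, and the normalized variant 2 I(X;Y) / (H(X) + H(Y)) is an assumption about what `normalize=True` means here, chosen because it maps to [0, 1].

```python
import math

import torch


def mutual_information_from_joint(p_xy: torch.Tensor, normalize: bool = False, eps: float = 1e-10) -> torch.Tensor:
    """MI in nats from a batch of joint pmfs p_xy of shape (B, K, K), each summing to 1."""
    p_x = p_xy.sum(dim=2)  # marginal of X: sum over columns -> (B, K)
    p_y = p_xy.sum(dim=1)  # marginal of Y: sum over rows -> (B, K)
    h_x = -(p_x * torch.log(p_x + eps)).sum(dim=1)
    h_y = -(p_y * torch.log(p_y + eps)).sum(dim=1)
    h_xy = -(p_xy * torch.log(p_xy + eps)).sum(dim=(1, 2))
    mi = h_x + h_y - h_xy
    if normalize:
        # Assumed normalization: 2 * I(X;Y) / (H(X) + H(Y)), which lies in [0, 1]
        mi = 2 * mi / (h_x + h_y + eps)
    return mi


# Toy 4-bin joint pmfs: Y identical to X (diagonal) vs. X and Y independent (uniform)
p_same = torch.diag(torch.full((4,), 0.25))
p_indep = torch.full((4, 4), 1 / 16)
p = torch.stack([p_same, p_indep])

mi = mutual_information_from_joint(p)
nmi = mutual_information_from_joint(p, normalize=True)
print(math.log(4))  # MI of the identical pair equals H(X) = log 4
```

The identical pair reaches the maximum MI of H(X) (normalized score 1), while the independent pair scores 0, mirroring the "same images" and "different images" rows of the example batch.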

Results

[Figures: Histogram; Joint Histogram]