jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License
570 stars 57 forks source link

Sinkhorn biased results #36

Open Cy-r0 opened 3 years ago

Cy-r0 commented 3 years ago

Hello, I wrote a small test where I calculate the Sinkhorn distance between a target distribution (univariate normal with mean 0 and variance 1) and a bunch of sample distributions (univariate normals with different means and variances), and the results I get are quite surprising.

If my sampled distribution is a normal with mean 0 and variance < 1, the Sinkhorn distance from the target distribution is lower than if my sampled distribution has variance 1. This seems to indicate that the loss is biased, as I would expect Sinkhorn to be at a minimum when the two distributions have same mean and variance. I played a bit with the parameters of SamplesLoss but the result didn't change significantly.

I attached a script below that does a mean and variance sweep and saves a bunch of plots with the corresponding Sinkhorn distance values. I'd be grateful if you could shed some light on this.

import matplotlib.pyplot as plt
import numpy as np
import torch

from geomloss import SamplesLoss

np.random.seed(31)
torch.manual_seed(31)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

def normal(imgshape, mu, sigma):
    """
    Sample a random tensor from a normal
    """  
    return torch.randn(imgshape) * sigma + mu 

def main():

    N = 10000

    mus = np.arange(-5, 5, 0.5)
    sigmas = np.arange(0.1, 2, 0.1)

    target = normal((1, N), mu=0, sigma=1)

    musweep = np.zeros(len(mus))
    sigmasweep = np.zeros(len(sigmas))

    sinkhorn_loss = SamplesLoss(loss="sinkhorn", p=1, blur=.05)

    # Cycle thru all mus
    for i, mu in enumerate(mus):
        sample = normal((1, N), mu=mu, sigma=1)

        sinkhorn_D = sinkhorn_loss(sample, target)
        musweep[i] = sinkhorn_D

        print("mu:", mu, "sinkhorn:", sinkhorn_D)

        # Log histograms
        fig = plt.figure(figsize=(5, 4), dpi=500)
        fig.tight_layout(pad=0)
        ax = fig.add_subplot(111)
        plt.subplots_adjust(wspace=0, hspace=0)
        ax.margins(0)
        plt.hist(sample.flatten().cpu().numpy(), bins=np.arange(-4, 4 + 0.05, 0.05), alpha=0.5)
        plt.hist(target.flatten().cpu().numpy(), bins=np.arange(-4, 4 + 0.05, 0.05), alpha=0.5)
        plt.title(f'sample vs target')
        plt.legend(["sample", "target"])
        plt.xlim(-4, 4)
        fig.text(.5, 0, 
            f"sinkhorn={float(sinkhorn_D):.4}", ha="center")
        plt.savefig(f"mu_{mu}.png")
        plt.close()

    # Cycle thru all sigmas
    for i, sigma in enumerate(sigmas):
        sample = normal((1, N), mu=0, sigma=sigma)

        sinkhorn_D = sinkhorn_loss(sample, target)
        sigmasweep[i] = sinkhorn_D

        print("sigma:", sigma, "sinkhorn:", sinkhorn_D)

        # Log histograms
        fig = plt.figure(figsize=(5, 4), dpi=500)
        fig.tight_layout(pad=0)
        ax = fig.add_subplot(111)
        plt.subplots_adjust(wspace=0, hspace=0)
        ax.margins(0)
        plt.hist(sample.flatten().cpu().numpy(), bins=np.arange(-4, 4 + 0.05, 0.05), alpha=0.5)
        plt.hist(target.flatten().cpu().numpy(), bins=np.arange(-4, 4 + 0.05, 0.05), alpha=0.5)
        plt.title(f'sample vs target')
        plt.legend(["sample", "target"])
        plt.xlim(-4, 4)
        fig.text(.5, 0, 
            f"sinkhorn={float(sinkhorn_D):.4}", ha="center")
        plt.savefig(f"sigma_{sigma:.2}.png")
        plt.close()

        im = plt.imread(f"sigma_{sigma:.2}.png")

    # Plot sinkhorn vs mu
    fig = plt.figure(figsize=(5, 4), dpi=500)
    fig.tight_layout(pad=0)
    ax = fig.add_subplot(111)
    plt.subplots_adjust(wspace=0, hspace=0)
    ax.margins(0)
    plt.plot(mus, musweep)
    plt.title(f'sinkhorn vs mu')
    plt.xlim(-4, 4)
    plt.savefig(f"sinkhorn_vs_mu.png")
    plt.close()

    # Plot sinkhorn vs sigma
    fig = plt.figure(figsize=(5, 4), dpi=500)
    fig.tight_layout(pad=0)
    ax = fig.add_subplot(111)
    plt.subplots_adjust(wspace=0, hspace=0)
    ax.margins(0)
    plt.plot(sigmas, sigmasweep)
    plt.title(f'sinkhorn vs sigma')
    plt.xlim(-4, 4)
    plt.savefig(f"sinkhorn_vs_sigma.png")
    plt.close()

if __name__ == "__main__":
    main()