jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes

Slightly biased results of Sinkhorn divergence #32

Open · wzm2256 opened this issue 4 years ago

wzm2256 commented 4 years ago

Hi, I'm using this code for density estimation and generative modelling. I think the Sinkhorn divergence is perfect for such tasks because of its unbiased nature. However, in practice, I find that it is slightly biased.

Basically, I have several data points and a Gaussian distribution. I want to match them so that the data points can be regarded as samples from the Gaussian distribution. This is a basic setting in generative modelling. Ideally, this should work well, because the Sinkhorn divergence removes the entropic bias of the regularized Wasserstein distance by adding two self-correlation terms. However, the results are still biased.
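For reference, this is the construction I am relying on (my own summary, not quoted from the GeomLoss docs): the entropic cost is corrected with two self-correlation terms, which is what I assume debias=True computes.

% Debiased Sinkhorn divergence between probability measures \alpha and \beta,
% where \mathrm{OT}_\varepsilon is the entropy-regularized transport cost at temperature \varepsilon:
S_\varepsilon(\alpha, \beta)
    = \mathrm{OT}_\varepsilon(\alpha, \beta)
    - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\alpha, \alpha)
    - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\beta, \beta)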

I present my simple test code here.


import os
import time

import numpy as np
import torch

from geomloss import SamplesLoss

eps = 0.1
sample = 2000
lr = 0.001
epoch = 400
p = 1

# Synthesize the data points.
device = 'cuda'
tdtype = torch.float

rng = np.random.RandomState(0)
Train_all = rng.randn(1000, 2)
train_tensor = torch.tensor(Train_all, device=device, dtype=tdtype).requires_grad_(True)

d=Train_all.shape[1]

# Gaussian model
class Gaussian():
    def __init__(self, Sample_n, u, sigma, eps=0.5, p=1, tdtype=torch.float, device='cuda'):

        self.u = torch.tensor(u, requires_grad=True, device=device, dtype=tdtype)
        self.sigma = torch.tensor(sigma, requires_grad=True, device=device, dtype=tdtype)

        self.L_ab = SamplesLoss('energy', potentials=False, backend='tensorized', scaling=0.9)
        # self.L_ab = SamplesLoss('sinkhorn', p=p, blur=eps, debias=True, potentials=False, backend='tensorized', scaling=0.9)
        # self.L_ab = SamplesLoss('sinkhorn', p=p, blur=eps, debias=False, potentials=False, backend='tensorized', scaling=0.9)

        self.dim = self.u.shape[-1]

        self.n1 = Sample_n 
        self.eps = eps
        self.tdtype = tdtype
        self.device = device

    def logp_x(self, x):
        out = - self.dim / 2 * np.log(2 * np.pi) - torch.sum(torch.log(self.sigma), -1) - 0.5 * torch.sum((torch.unsqueeze(x, 0) - self.u)** 2 / (self.sigma ** 2), -1)
        return out       

    def Sample(self):
        # with torch.no_grad():
        Sample = torch.randn((1, self.n1, self.dim), device=self.device, dtype=self.tdtype)
        tmp = Sample * self.sigma + self.u
        out = torch.reshape(tmp, [-1, self.dim])
        return out

    def loss(self, D):

        x = self.Sample()
        out = self.L_ab(x, D)
        return out

# Init the Gaussian
alpha = np.ones(1) 
u_ = np.zeros((1, 1, d))
S_ = np.ones((1, 1, d))
model = Gaussian(sample, u_, S_, eps=eps, p=1)

a_opt = torch.optim.RMSprop([
                            {'params': model.u},
                            {'params': model.sigma}
                            ] , lr=lr, alpha=0.9)

p_opt = torch.optim.RMSprop([
                            {'params': train_tensor},
                            ] , lr=lr, alpha=0.9)

for epoch in range(epoch):

    lp = model.loss(train_tensor)

    # Uncomment this block to optimize the data points instead.
    # p_opt.zero_grad()
    # lp.backward()
    # p_opt.step()
    # print('Data mean:\t', train_tensor.detach().mean(0).cpu().numpy(), '\t Data std:\t', train_tensor.detach().std(0).cpu().numpy())
    # print('Gaussian mean:\t', model.u[:,0,:].detach().cpu().numpy(), '\t \t \t Gaussian std:\t', model.sigma.detach()[:,0,:].cpu().numpy())

    # Optimize the Gaussian distribution (comment this block out to optimize the data points above instead).
    a_opt.zero_grad()
    lp.backward()
    a_opt.step()
    print('Data mean:\t', train_tensor.detach().mean(0).cpu().numpy(), '\t Data std:\t', train_tensor.detach().std(0).cpu().numpy())
    print('Gaussian mean:\t', model.u[:,0,:].detach().cpu().numpy(), '\t Gaussian std:\t', model.sigma.detach()[:,0,:].cpu().numpy())
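To probe the bias in isolation, here is a minimal sketch (my addition, not part of the script above; the name loss_fn and the 20-resample loop are mine, and the setting mirrors the defaults used above): it averages the gradient of the debiased Sinkhorn loss with respect to sigma at the true parameters u = 0, sigma = 1. If the divergence were unbiased as a matching criterion here, this averaged gradient would be close to zero.

import torch
from geomloss import SamplesLoss

torch.manual_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# "Dataset" and model samples are both drawn from N(0, I), so sigma = 1 is the true scale.
data = torch.randn(1000, 2, device=device)
sigma = torch.ones(1, 2, device=device, requires_grad=True)

loss_fn = SamplesLoss('sinkhorn', p=1, blur=0.1, debias=True, backend='tensorized')

grads = []
for _ in range(20):
    sample = torch.randn(2000, 2, device=device) * sigma
    g, = torch.autograd.grad(loss_fn(sample, data), sigma)
    grads.append(g)

# A systematic offset in this average is the (small) bias discussed in this issue.
print('mean grad w.r.t. sigma:', torch.stack(grads).mean(0).cpu().numpy())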

I present several experimental facts about this code:

wzm2256 commented 4 years ago

Let me explain why this is important: in generative modelling, we need to adjust the data and the Gaussian simultaneously so that they match each other. Even a slight bias will drive the variance to zero.
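To illustrate what the entropic offset looks like in the simplest possible case, here is another short sanity check (again a sketch of mine, not from the library docs): the raw entropic cost of a sample against itself is positive, whereas the debiased divergence of a sample against itself is zero by construction. It is this kind of offset that, once it enters the gradient, shrinks the fitted variance; the question in this issue is whether a small residue of it survives the debiasing.

import torch
from geomloss import SamplesLoss

x = torch.randn(2000, 2)
raw = SamplesLoss('sinkhorn', p=1, blur=0.1, debias=False)   # entropic OT_eps
deb = SamplesLoss('sinkhorn', p=1, blur=0.1, debias=True)    # Sinkhorn divergence S_eps

# OT_eps(x, x) is expected to be positive (the entropic offset),
# while S_eps(x, x) should be ~0 up to the solver's tolerance.
print('OT_eps(x, x) =', raw(x, x).item())
print('S_eps(x, x)  =', deb(x, x).item())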