microsoft / otdd

Optimal Transport Dataset Distance
MIT License

Questions regarding using the "exact" method and default "gaussian_approx" method #24

Status: Open · opened by chenmzh 2 years ago

chenmzh commented 2 years ago

Dear author,

I am trying to calculate the distance between a dataset and itself (the USPS training set). According to the paper, $d_{\text{OT-N}} \le d_{\text{OT}}$, so I computed the distance with both the `exact` method and the default `gaussian_approx` method. However, the Gaussian-approximation result is slightly larger than the exact result, which is almost 0 in this case. This seems to contradict that statement.
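
A minimal sketch of the inequality in question, assuming it rests on Gelbrich's bound: the W2 distance between Gaussians matched to the means and covariances of two distributions never exceeds the exact W2 between the distributions themselves. This uses NumPy/SciPy, not otdd; `exact_w2` and `gaussian_w2` are hypothetical helpers, and the exact solver assumes both sample sets have the same size, so that the optimal transport plan is a permutation.

```python
import numpy as np
from scipy.linalg import sqrtm
from scipy.optimize import linear_sum_assignment


def exact_w2(X, Y):
    """Exact W2 between two uniform empirical measures of equal size.

    With equal sizes and uniform weights, an optimal plan is a permutation,
    so the OT problem reduces to a linear assignment problem.
    """
    cost = ((X[:, None, :] - Y[None, :, :]) ** 2).sum(axis=-1)
    rows, cols = linear_sum_assignment(cost)
    return np.sqrt(cost[rows, cols].mean())


def gaussian_w2(X, Y):
    """Closed-form (Bures-Wasserstein) W2 between Gaussians fitted to X and Y."""
    m1, m2 = X.mean(axis=0), Y.mean(axis=0)
    # bias=True: use the empirical measures' own covariances (divide by n)
    S1 = np.cov(X, rowvar=False, bias=True)
    S2 = np.cov(Y, rowvar=False, bias=True)
    s1 = np.real(sqrtm(S1))
    cross = np.real(sqrtm(s1 @ S2 @ s1))
    bures_sq = np.trace(S1) + np.trace(S2) - 2.0 * np.trace(cross)
    # Rounding in sqrtm can push a true zero slightly negative; clip it.
    return np.sqrt(np.sum((m1 - m2) ** 2) + max(bures_sq, 0.0))


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))
Y = rng.exponential(size=(200, 3))

print(gaussian_w2(X, Y) <= exact_w2(X, Y))  # Gelbrich bound: Gaussian W2 <= exact W2
print(exact_w2(X, X), gaussian_w2(X, X))    # exactly 0 vs. 0 up to rounding
```

On identical inputs the exact value is 0, while the Gaussian value is zero only up to floating-point error from the matrix square roots. This is one possible source of small positive values, not a confirmed explanation of the behavior reported here.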

Here are the settings for the two experiments:

```python
# d_OT: exact inner OT between label distributions
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# Load the USPS training set twice (source and target are the same data)
loaders_src = load_torchvision_data('USPS', valid_size=0, resize=28, maxsize=20000)[0]
loaders_tgt = load_torchvision_data('USPS', valid_size=0, resize=28, maxsize=20000)[0]

# Instantiate distance with the exact inner OT method
dist_exact = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
                             inner_ot_method='exact',
                             inner_ot_debiased=True,
                             device='cpu')

d_exact = dist_exact.distance(maxsamples=20000)
print(f'OTDD exact (src,tgt)={d_exact}')

# d_OT-N: Gaussian approximation of the label distributions

# Instantiate distance; 'gaussian_approx' is the default, passed explicitly here
dist_gauss = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
                             inner_ot_method='gaussian_approx',
                             inner_ot_debiased=True,
                             device='cpu')

d_gauss = dist_gauss.distance(maxsamples=20000)
print(f'OTDD gaussian_approx (src,tgt)={d_gauss}')
```
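
For the self-distance experiment above, here is a sketch of what the Gaussian approximation computes at the label level. It assumes, as the paper describes, that each class-conditional distribution is replaced by a Gaussian with its empirical mean and covariance, and that classes are compared with the closed-form Bures-Wasserstein formula. `gaussian_label_distances` is a hypothetical helper run on synthetic data, not otdd's implementation, which may differ (for example in covariance estimation or regularization).

```python
import numpy as np
from scipy.linalg import sqrtm


def bures_wasserstein(m1, S1, m2, S2):
    """Closed-form W2 between N(m1, S1) and N(m2, S2)."""
    s1 = np.real(sqrtm(S1))
    cross = np.real(sqrtm(s1 @ S2 @ s1))
    sq = np.sum((m1 - m2) ** 2) + np.trace(S1) + np.trace(S2) - 2.0 * np.trace(cross)
    # Clip tiny negative values caused by rounding before the square root
    return np.sqrt(max(sq, 0.0))


def gaussian_label_distances(X1, y1, X2, y2):
    """Hypothetical helper: Gaussian-approximated W2 between every pair of
    class-conditional distributions of two labelled datasets."""
    stats = []
    for X, y in ((X1, y1), (X2, y2)):
        classes = np.unique(y)
        # One (mean, covariance) pair per class
        stats.append([(X[y == c].mean(axis=0),
                       np.cov(X[y == c], rowvar=False, bias=True)) for c in classes])
    D = np.zeros((len(stats[0]), len(stats[1])))
    for i, (m1, S1) in enumerate(stats[0]):
        for j, (m2, S2) in enumerate(stats[1]):
            D[i, j] = bures_wasserstein(m1, S1, m2, S2)
    return D


rng = np.random.default_rng(1)
y = rng.integers(0, 3, size=300)
X = rng.normal(size=(300, 4)) + y[:, None]  # class c is shifted by c in every coordinate
D = gaussian_label_distances(X, y, X, y)    # dataset compared with itself
print(np.round(D, 3))
print(np.abs(np.diag(D)).max())             # near 0, but only up to rounding
```

When both datasets are the same, the diagonal of `D` is zero only up to floating-point error, while off-diagonal entries reflect the real gaps between classes.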