microsoft / otdd

Optimal Transport Dataset Distance
MIT License

Parameter setup for the *MNIST+USPS distance #15

Open chenmzh opened 2 years ago

chenmzh commented 2 years ago

Thank you for your attention! I am trying to reproduce the pairwise distances among the *MNIST+USPS datasets (Figure 4 of Geometric Dataset Distances via Optimal Transport), using the parameter setup provided in example.py. However, the distances I measure are much smaller than those in Figure 4, even when I use the upper bound as the inner OT method, and some of the pairwise relationships between datasets differ from the figure. Is there a detailed parameter setup for reproducing Figure 4? Thank you very much!

chenmzh commented 2 years ago

I ran the following code but got an MNIST-USPS distance of around 930 instead of the 1260 reported in the paper:

from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# Load datasets
loaders_src = load_torchvision_data('MNIST', resize=28, maxsize=1000)[0]
loaders_tgt = load_torchvision_data('USPS', resize=28, maxsize=1000)[0]

# Instantiate the distance
dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
                       inner_ot_method='exact',  # gaussian_approx, exact, jdot and naive_upperbound
                       p=2, entreg=1e-2, inner_ot_entreg=1e-2,
                       device='cpu')
d = dist.distance(maxsamples=10000)
print(f'OTDD(MNIST,USPS)={d}')
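To see where `p`, the entropic regularization and the inner OT step enter the number, here is a toy NumPy sketch of the OTDD construction as I read it from the paper. The ground cost between samples (x, y) and (x', y') is ||x - x'||^p plus W_p^p between the label-conditional feature distributions, and the dataset distance is the OT cost under that ground cost.

This is not otdd's implementation, and all names here are hypothetical. It makes the following assumptions:

- Plain log-domain Sinkhorn with uniform weights is used for both the inner and the outer problems.
- The same `eps` is used for both problems.
- The function returns the transport cost itself, which is the p-th power of a distance. I have not checked whether otdd reports this quantity, its p-th root, or a rescaled version.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn_cost(C, eps, n_iter=500):
    """Entropic OT between uniform weights; returns <P, C> for the Sinkhorn plan P."""
    n, m = C.shape
    log_a = np.full(n, -np.log(n))  # uniform source weights (assumption)
    log_b = np.full(m, -np.log(m))  # uniform target weights (assumption)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iter):
        # Log-domain updates avoid underflow when C / eps is large.
        f = eps * (log_a - logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - C) / eps, axis=0))
    P = np.exp((f[:, None] + g[None, :] - C) / eps)
    return float(np.sum(P * C))


def pairwise_cost(X1, X2, p):
    """Matrix of ||x - x'||^p between the rows of X1 and X2."""
    diff = X1[:, None, :] - X2[None, :, :]
    return np.linalg.norm(diff, axis=-1) ** p


def toy_otdd_cost(Xs, ys, Xt, yt, p=2, eps=0.05):
    """Hypothetical toy OTDD-style cost: feature cost plus label-to-label W_p^p."""
    src_labels, tgt_labels = np.unique(ys), np.unique(yt)
    # Inner OT: (entropic) W_p^p between every source class and every target class.
    W = np.zeros((len(src_labels), len(tgt_labels)))
    for i, c in enumerate(src_labels):
        for j, k in enumerate(tgt_labels):
            W[i, j] = sinkhorn_cost(pairwise_cost(Xs[ys == c], Xt[yt == k], p), eps)
    # Map each sample's label to its row/column index in W.
    si = np.searchsorted(src_labels, ys)
    ti = np.searchsorted(tgt_labels, yt)
    # Outer OT on the label-augmented ground cost ||x - x'||^p + W_p^p(alpha_y, alpha_y').
    C = pairwise_cost(Xs, Xt, p) + W[si][:, ti]
    return sinkhorn_cost(C, eps)


rng = np.random.default_rng(0)
Xs = rng.normal(size=(60, 2))
ys = rng.integers(0, 3, size=60)
print(toy_otdd_cost(Xs, ys, Xs, ys))        # small: only the entropic blur remains
print(toy_otdd_cost(Xs, ys, Xs + 3.0, ys))  # much larger: features and classes are shifted
```

In this sketch the value depends on `eps`, on `p`, and on how many samples each class contributes. By analogy, any of `entreg`, `p`, `maxsize` or `maxsamples` in the real call could plausibly move the absolute number.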

The gaussian_approx method gets an even smaller distance. Thanks!
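The `gaussian_approx` option presumably models each label-conditional distribution as a Gaussian and uses the closed-form 2-Wasserstein distance between Gaussians in place of the exact inner OT; I have not checked otdd's code to confirm this. The standard closed form is W_2^2 = ||mu1 - mu2||^2 + tr(S1 + S2 - 2 (S1^{1/2} S2 S1^{1/2})^{1/2}).

Below is a standalone NumPy sketch of that formula. The function names are mine, not otdd's.

```python
import numpy as np


def sqrtm_psd(S):
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(S)
    vals = np.clip(vals, 0.0, None)  # clip tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_w2_sq(mu1, S1, mu2, S2):
    """Closed-form squared 2-Wasserstein distance between N(mu1, S1) and N(mu2, S2)."""
    r1 = sqrtm_psd(S1)
    cross = sqrtm_psd(r1 @ S2 @ r1)
    return float(np.sum((mu1 - mu2) ** 2) + np.trace(S1 + S2 - 2.0 * cross))


# Usage: compare two point clouds through their fitted means and covariances only.
rng = np.random.default_rng(0)
A = rng.normal(size=(500, 3))
B = 2.0 * rng.normal(size=(500, 3)) + 1.0
print(gaussian_w2_sq(A.mean(0), np.cov(A, rowvar=False),
                     B.mean(0), np.cov(B, rowvar=False)))
```

For p = 2, this Gaussian value is a lower bound (Gelbrich's bound) on the exact W_2^2 between any two distributions with the same means and covariances. If otdd's `gaussian_approx` works as assumed above, that would be consistent with it giving smaller distances than `exact`.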