chenmzh opened 2 years ago
I ran the following code but got an MNIST-USPS distance of around 930 instead of the 1260 reported in the paper:
```python
from otdd.pytorch.datasets import load_torchvision_data
from otdd.pytorch.distance import DatasetDistance

# Load datasets
loaders_src = load_torchvision_data('MNIST', resize=28, maxsize=1000)[0]
loaders_tgt = load_torchvision_data('USPS', resize=28, maxsize=1000)[0]

# inner_ot_method options: gaussian_approx, exact, jdot, naive_upperbound
dist = DatasetDistance(loaders_src['train'], loaders_tgt['train'],
                       inner_ot_method='exact',
                       p=2, entreg=1e-2, inner_ot_entreg=1e-2,
                       device='cpu')
d = dist.distance(maxsamples=10000)
print(f'OTDD(MNIST,USPS)={d}')
```
With `inner_ot_method='gaussian_approx'` the distance is even smaller. Thanks!
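A minimal sketch for comparing the inner OT methods side by side on the same data, assuming otdd is installed and that `load_torchvision_data` and `DatasetDistance` accept the same arguments as in the snippet above. The helper name `sweep_inner_ot_methods` is hypothetical, and seeding the global `numpy`/`torch` RNGs is an assumption about how otdd draws its `maxsize` subsample, not something confirmed by the library.

```python
def sweep_inner_ot_methods(methods=("exact", "gaussian_approx", "naive_upperbound"),
                           maxsize=1000, maxsamples=10000, seed=0):
    """Hypothetical helper: OTDD(MNIST, USPS) for each inner_ot_method.

    All other settings are held fixed, so differences between the returned
    values mainly reflect the inner OT method. Imports are local so this
    module can be imported without otdd installed.
    """
    import numpy as np
    import torch
    from otdd.pytorch.datasets import load_torchvision_data
    from otdd.pytorch.distance import DatasetDistance

    # Assumption: seeding the global RNGs makes the maxsize subsample repeatable.
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load each dataset once and reuse the same loaders for every method.
    src = load_torchvision_data('MNIST', resize=28, maxsize=maxsize)[0]['train']
    tgt = load_torchvision_data('USPS', resize=28, maxsize=maxsize)[0]['train']

    results = {}
    for method in methods:
        dist = DatasetDistance(src, tgt, inner_ot_method=method,
                               p=2, entreg=1e-2, inner_ot_entreg=1e-2,
                               device='cpu')
        # float() handles both a Python number and a one-element tensor.
        results[method] = float(dist.distance(maxsamples=maxsamples))
    return results
```

Calling `sweep_inner_ot_methods()` returns a dict keyed by method name, which makes it easy to see how far `gaussian_approx` and `naive_upperbound` drift from `exact` under identical settings.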
Thank you for your attention! I was trying to reproduce the pairwise distances among the *NIST+USPS datasets (Figure 4 of "Geometric Dataset Distances via Optimal Transport"), using the parameter setup from example.py. The distances I measure are much smaller than those in Figure 4, even with the upper bound (`naive_upperbound`) as the inner OT method, and some of the relationships between dataset pairs also differ. Is there a detailed parameter setup for reproducing Figure 4? Thank you very much!
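A toy illustration in plain NumPy (not otdd's code path) of why regularization settings must match before distances are compared: for entropic OT, the transport cost `<P_eps, C>` of the regularized plan is never below the exact OT cost and decreases toward it as `eps` shrinks. otdd's `entreg` corresponds to `eps` only loosely, since its outer loss is computed differently from `<P_eps, C>`, so this shows the direction of the effect rather than an explanation of the 930 vs. 1260 gap. `exact_ot_cost` and `sinkhorn_plan` are illustrative helpers, not otdd functions.

```python
import itertools

import numpy as np


def exact_ot_cost(C):
    """Exact OT cost between two uniform point clouds of equal size n.

    With uniform weights an optimal coupling can be taken to be a permutation
    matrix scaled by 1/n (Birkhoff), so brute force is exact for tiny n.
    """
    n = C.shape[0]
    best = min(sum(C[i, p[i]] for i in range(n))
               for p in itertools.permutations(range(n)))
    return best / n


def _logsumexp(A, axis):
    """Numerically stable log(sum(exp(A))) along one axis."""
    m = A.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(A - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def sinkhorn_plan(C, eps, n_iter=2000):
    """Log-domain Sinkhorn iterations for uniform marginals; returns the plan P."""
    n, m = C.shape
    log_a = np.full(n, -np.log(n))  # log of uniform source weights
    log_b = np.full(m, -np.log(m))  # log of uniform target weights
    f, g = np.zeros(n), np.zeros(m)  # dual potentials
    for _ in range(n_iter):
        f = eps * (log_a - _logsumexp((g[None, :] - C) / eps, axis=1))
        g = eps * (log_b - _logsumexp((f[:, None] - C) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - C) / eps)


rng = np.random.default_rng(0)
x = rng.random((6, 2))        # 6 source points in the unit square
y = rng.random((6, 2)) + 0.5  # 6 shifted target points
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost (p = 2)

exact = exact_ot_cost(C)
eps_values = (1.0, 0.3, 0.1)
costs = [float((sinkhorn_plan(C, eps) * C).sum()) for eps in eps_values]
for eps, cost in zip(eps_values, costs):
    print(f"eps={eps:g}: <P_eps, C> = {cost:.4f} (exact OT = {exact:.4f})")
```

The brute-force exact solver only keeps the sketch dependency-free; it is factorial in n, so a real comparison would use a linear-programming OT solver instead.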