microsoft / otdd

Optimal Transport Dataset Distance
MIT License

Problem when using the "exact" method on large samples #23

Status: Open. chenmzh opened this issue 2 years ago

chenmzh commented 2 years ago

When I try to compute the "exact" distance between two large datasets (2 labels, more than 10,000 samples in total), the following error is raised:

TypeError                                 Traceback (most recent call last)
~/otdd/otdd/pytorch/wasserstein.py in pwdist_exact(X1, Y1, X2, Y2, symmetric, loss, cost_function, p, debias, entreg, device)
    335         try:
--> 336             D[i, j] = distance(X1[Y1==c1[i]].to(device), X2[Y2==c2[j]].to(device)).item()
    337         except:

~/test/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

~/test/lib/python3.7/site-packages/geomloss-0.2.4-py3.7.egg/geomloss/samples_loss.py in forward(self, *args)
    283             labels_y=l_y,
--> 284             verbose=self.verbose,
    285         )

~/test/lib/python3.7/site-packages/geomloss-0.2.4-py3.7.egg/geomloss/sinkhorn_samples.py in sinkhorn_online(α, x, β, y, p, blur, reach, diameter, scaling, cost, debias, potentials, **kwargs)
    142     softmin = partial(
--> 143         softmin_online, log_conv=keops_lse(cost, D, dtype=str(x.dtype)[6:])
    144     )

~/test/lib/python3.7/site-packages/geomloss-0.2.4-py3.7.egg/geomloss/sinkhorn_samples.py in keops_lse(cost, D, dtype)
    107 #         "( B - (P * " + cost + " ) )",
--> 108         "( B - (P * " + cost+ " ) )",
    109         "A = Vi(1)",
TypeError: can only concatenate str (not "function") to str

Is it possible to use the "exact" method with a large number of samples? Thanks!
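The last frame of the traceback builds a KeOps formula by plain string concatenation, `"( B - (P * " + cost + " ) )"`, so that code path only works when `cost` is a formula string. Passing a Python callable (a function or lambda) produces exactly this `TypeError`. A minimal stdlib-only sketch of the failure mode follows; `build_lse_formula` is a hypothetical name that only mimics that one line and is not part of geomloss or otdd, and the formula string `"SqDist(X,Y) / IntCst(2)"` is an assumed example of a KeOps squared-distance cost.

```python
def build_lse_formula(cost):
    """Hypothetical helper mimicking the concatenation in geomloss's keops_lse.

    The online (KeOps) path needs `cost` to be a formula *string*; a Python
    callable cannot be concatenated with str and raises TypeError.
    """
    return "( B - (P * " + cost + " ) )"


def squared_euclidean(x, y):
    # A callable cost, standing in for the function otdd hands to geomloss.
    return sum((a - b) ** 2 for a, b in zip(x, y)) / 2


# A formula string is accepted and spliced into the KeOps expression.
print(build_lse_formula("SqDist(X,Y) / IntCst(2)"))

# A callable reproduces the error from the traceback.
try:
    build_lse_formula(squared_euclidean)
except TypeError as err:
    print(err)  # can only concatenate str (not "function") to str
```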

xlcbingo1999 commented 1 year ago

I ran into the same issue. Did you manage to solve it?

chenmzh commented 1 year ago

If I remember correctly, there are some hard-coded limitations inside the code. I haven't tested it since last April, so things may be different now.
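One plausible candidate for such a hard-coded limitation is the size-based backend switch in geomloss's `SamplesLoss` when `backend="auto"`. The rule below is assumed from the geomloss 0.2.x source and should be checked against your installed version: problems with at most 5000² point pairs use the "tensorized" backend, which accepts a Python callable as `cost`, while larger ones go to the "online" KeOps backend, whose formula-building code expects a string. The sketch mirrors that rule in plain Python; `pick_backend` and `TENSORIZED_MAX_PAIRS` are hypothetical names, and the special "multiscale" case is ignored.

```python
# Assumed threshold, based on geomloss's SamplesLoss(backend="auto") logic.
TENSORIZED_MAX_PAIRS = 5000 ** 2


def pick_backend(n_source, n_target):
    """Hypothetical sketch of the size-based backend choice.

    Returns "tensorized" (dense cost matrix, accepts a callable cost) or
    "online" (KeOps, expects a formula string). Ignores "multiscale".
    """
    if n_source * n_target <= TENSORIZED_MAX_PAIRS:
        return "tensorized"
    return "online"


# Per the traceback, pwdist_exact compares one source class with one target
# class at a time, so per-class sizes matter, not the whole dataset size.
print(pick_backend(3000, 3000))  # tensorized
print(pick_backend(5000, 5000))  # tensorized (exactly 5000**2 pairs)
print(pick_backend(5001, 5001))  # online: a callable cost breaks here
```

If each class holds roughly 5,000 samples or more, each class-pair problem sits at or above this assumed threshold, which would be consistent with the error appearing only for large inputs.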

xlcbingo1999 commented 1 year ago


OK, thanks for your help! I changed the lambda expression, which fixed the error. However, computing the distance between the two datasets is now very slow.
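A minimal sketch of the kind of change described here, assuming the fix is to give geomloss a KeOps formula string instead of a Python callable whenever the online backend is used. The formula strings are assumed to match geomloss's defaults for p = 1 and p = 2, and `cost_for`, `exact_cost_entries` and `sq_euclid` are hypothetical helper names. The second helper gives a rough sense of the slowdown: each (source class, target class) pair is its own OT problem whose n_i × m_j cost entries are evaluated (on the fly, in the online backend) at every Sinkhorn iteration, so the work grows quadratically with class size.

```python
# Assumed KeOps formula strings (believed to match geomloss's defaults).
COST_FORMULAS = {
    1: "Norm2(X-Y)",
    2: "(SqDist(X,Y) / IntCst(2))",
}


def cost_for(backend, p, callable_cost):
    """Hypothetical helper: return a cost usable by the given backend.

    The tensorized backend accepts a Python callable; the online (KeOps)
    backend concatenates `cost` into a formula string, so it needs a str.
    """
    if backend == "online":
        if p not in COST_FORMULAS:
            raise ValueError(f"no formula string assumed for p={p}")
        return COST_FORMULAS[p]
    return callable_cost


def exact_cost_entries(source_class_sizes, target_class_sizes):
    """Sum of n_i * m_j over all (source class, target class) pairs."""
    return sum(n * m for n in source_class_sizes for m in target_class_sizes)


def sq_euclid(x, y):
    # Stand-in for a callable cost such as the original lambda.
    return sum((a - b) ** 2 for a, b in zip(x, y)) / 2


print(cost_for("online", 2, sq_euclid))                   # (SqDist(X,Y) / IntCst(2))
print(cost_for("tensorized", 2, sq_euclid) is sq_euclid)  # True
# Two classes of 5000 samples on each side: 4 class pairs x 25M entries.
print(exact_cost_entries([5000, 5000], [5000, 5000]))     # 100000000
```

Because the count is quadratic in class size, halving every class (for example by subsampling) cuts the number of cost entries per Sinkhorn iteration by a factor of four, at the price of a noisier distance estimate.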