In the degenerate case where `a` and `b` each contain only one sample, and it is the same sample, `SamplesLoss` fails with:
ValueError: Maximum allowed size exceeded
Minimum example:
In: from geomloss import SamplesLoss
In: loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)
In: a = b = Tensor([[1, 2]])
In: loss(a, b)
Out:
sinkhorn_divergence.py:147: RuntimeWarning: divide by zero encountered in log
p * np.log(diameter), p * np.log(blur), p * np.log(scaling)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[7], line 1
----> 1 loss(a, b)
File ~/.conda/envs/default/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/samples_loss.py:265, in SamplesLoss.forward(self, *args)
262 α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)
264 # Run --------------------------------------------------------------------------------
--> 265 values = routines[self.loss][backend](
266 α,
267 x,
268 β,
269 y,
270 p=self.p,
271 blur=self.blur,
272 reach=self.reach,
273 diameter=self.diameter,
274 scaling=self.scaling,
275 truncate=self.truncate,
276 cost=self.cost,
277 kernel=self.kernel,
278 cluster_scale=self.cluster_scale,
279 debias=self.debias,
280 potentials=self.potentials,
281 labels_x=l_x,
282 labels_y=l_y,
283 verbose=self.verbose,
284 )
286 # Make sure that the output has the correct shape ------------------------------------
287 if (
288 self.potentials
289 ): # Return some dual potentials (= test functions) sampled on the input measures
File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_samples.py:191, in sinkhorn_tensorized(a, x, b, y, p, blur, reach, diameter, scaling, cost, debias, potentials, **kwargs)
186 C_yy = cost(y, y.detach()) if debias else None # (B,M,M) torch Tensor
188 # Compute the relevant values of the diameter of the configuration,
189 # target temperature epsilon, temperature schedule across itereations
190 # and strength of the marginal constraints:
--> 191 diameter, eps, eps_list, rho = scaling_parameters(
192 x, y, p, blur, reach, diameter, scaling
193 )
195 # Use an optimal transport solver to retrieve the dual potentials:
196 f_aa, g_bb, g_ab, f_ba = sinkhorn_loop(
197 softmin_tensorized,
198 log_weights(a),
(...)
206 debias=debias,
207 )
File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_divergence.py:163, in scaling_parameters(x, y, p, blur, reach, diameter, scaling)
161 eps = blur ** p
162 rho = None if reach is None else reach ** p
--> 163 eps_list = epsilon_schedule(p, diameter, blur, scaling)
164 return diameter, eps, eps_list, rho
File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_divergence.py:146, in epsilon_schedule(p, diameter, blur, scaling)
116 def epsilon_schedule(p, diameter, blur, scaling):
117 r"""Creates a list of values for the temperature "epsilon" across Sinkhorn iterations.
118
119 We use an aggressive strategy with an exponential cooling
(...)
140 list of float: list of values for the temperature epsilon.
141 """
142 eps_list = (
143 [diameter ** p]
144 + [
145 np.exp(e)
--> 146 for e in np.arange(
147 p * np.log(diameter), p * np.log(blur), p * np.log(scaling)
148 )
149 ]
150 + [blur ** p]
151 )
152 return eps_list
ValueError: Maximum allowed size exceeded
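This could be solved by checking for a zero diameter and returning 0 in that case. With a single shared sample the diameter is 0, so `np.log(diameter)` evaluates to `-inf` (hence the "divide by zero encountered in log" warning), and `np.arange` cannot build the epsilon schedule from an infinite start value.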
In the degenerate case where `a` and `b` each contain only one sample, and it is the same sample, `SamplesLoss` fails with a `ValueError: Maximum allowed size exceeded` (see the minimum example and traceback above).
This could be solved by checking for a 0 diameter, and returning 0 in that case.