jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License
570 stars 57 forks source link

ValueError: Maximum allowed size exceeded when both inputs are a single identical point #74

Open tbrugere opened 8 months ago

tbrugere commented 8 months ago

In the degenerate case where a and b each contain only one sample, and it is the same sample, SamplesLoss fails with:

ValueError: Maximum allowed size exceeded

Minimal example:

In: from geomloss import SamplesLoss
In: loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05) 
In: a = b = Tensor([[1, 2]])
In: loss(a, b)
Out: RuntimeWarning: divide by zero encountered in log                              
  p * np.log(diameter), p * np.log(blur), p * np.log(scaling)                                                                                                                      
ValueError                                Traceback (most recent call last)                                                                                                        
Cell In[7], line 1                                                                                                                                                                 
----> 1 loss(a, b)                                                                                                                                                                 

File ~/.conda/envs/default/lib/python3.11/site-packages/torch/nn/modules/module.py, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in                                                                                                     
   1497 # this function, and just call forward.                                                                                                                                    
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks                                                                 
   1499         or _global_backward_pre_hooks or _global_backward_hooks                                                                                                            
   1500         or _global_forward_hooks or _global_forward_pre_hooks):                                                                                                            
-> 1501     return forward_call(*args, **kwargs)                                                                                                                                   
   1502 # Do not call functions when jit is used                                                                                                                                   
   1503 full_backward_hooks, non_full_backward_hooks = [], []                                                                                                                      

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/samples_loss.py, in SamplesLoss.forward(self, *args)
    262     α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)                                                                                            
    264 # Run --------------------------------------------------------------------------------                                                                                     
--> 265 values = routines[self.loss][backend](                                                                                                                                     
    266     α,                                                                                                                                                                     
    267     x,                                                                                                                                                                     
    268     β,                                                                                                                                                                     
    269     y,                                                                                                                                                                     
    270     p=self.p,                                                                                                                                                              
    271     blur=self.blur,
    272     reach=self.reach,
    273     diameter=self.diameter,
    274     scaling=self.scaling,
    275     truncate=self.truncate,
    276     cost=self.cost,
    277     kernel=self.kernel,
    278     cluster_scale=self.cluster_scale,
    279     debias=self.debias,
    280     potentials=self.potentials,
    281     labels_x=l_x,
    282     labels_y=l_y,
    283     verbose=self.verbose,
    284 )
    286 # Make sure that the output has the correct shape ------------------------------------
    287 if (
    288     self.potentials
    289 ):  # Return some dual potentials (= test functions) sampled on the input measures

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_samples.py, in sinkhorn_tensorized(a, x, b, y, p, blur, reach, diameter, scaling, cost, debias, potentials, **kwargs)
    186 C_yy = cost(y, y.detach()) if debias else None  # (B,M,M) torch Tensor
    188 # Compute the relevant values of the diameter of the configuration,
    189 # target temperature epsilon, temperature schedule across itereations
    190 # and strength of the marginal constraints:
--> 191 diameter, eps, eps_list, rho = scaling_parameters(
    192     x, y, p, blur, reach, diameter, scaling
    193 )
    195 # Use an optimal transport solver to retrieve the dual potentials:
    196 f_aa, g_bb, g_ab, f_ba = sinkhorn_loop(
    197     softmin_tensorized,
    198     log_weights(a),
    206     debias=debias,
    207 )

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_divergence.py, in scaling_parameters(x, y, p, blur, reach, diameter, scaling)
    161 eps = blur ** p
    162 rho = None if reach is None else reach ** p
--> 163 eps_list = epsilon_schedule(p, diameter, blur, scaling)
    164 return diameter, eps, eps_list, rho

File ~/.conda/envs/default/lib/python3.11/site-packages/geomloss/sinkhorn_divergence.py, in epsilon_schedule(p, diameter, blur, scaling)
    116 def epsilon_schedule(p, diameter, blur, scaling):
    117     r"""Creates a list of values for the temperature "epsilon" across Sinkhorn iterations.
    119     We use an aggressive strategy with an exponential cooling
    140         list of float: list of values for the temperature epsilon.
    141     """
    142     eps_list = (
    143         [diameter ** p]
    144         + [
    145             np.exp(e)
--> 146             for e in np.arange(
    147                 p * np.log(diameter), p * np.log(blur), p * np.log(scaling)
    148             )
    149         ]
    150         + [blur ** p]
    151     )
    152     return eps_list

ValueError: Maximum allowed size exceeded

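The root cause is that the diameter of the point configuration is 0 when all samples coincide. As a result, `np.log(diameter)` evaluates to `-inf` (hence the "divide by zero" warning), and `np.arange(-inf, p * np.log(blur), p * np.log(scaling))` would have to produce infinitely many values. This could be solved by checking for a zero diameter and returning a loss of 0 in that case.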