Open myoresh opened 2 years ago
The following code causes a crash. It is the degenerate case of the loss between a Dirac distribution and itself:
import torch
from geomloss import SamplesLoss

loss_func = SamplesLoss("sinkhorn")
t1_samples = torch.tensor([[1.0]])
t2_samples = torch.tensor([[1.0]])
loss_value = loss_func(t1_samples, t2_samples)