jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes
MIT License

Can I compute the Sinkhorn distance between 2 batches of probability distributions using a fixed cost matrix defined elsewhere? #81

Open ldswaby opened 6 months ago

ldswaby commented 6 months ago

Hello, thanks a lot for all your work!

I have two tensors x and y of shape (B, N, 3), where B is the batch size and N is fixed. For each batch index b and channel k, x[b, :, k] is a probability distribution over N bins (and likewise for y). I want to use a fixed cost matrix C, where C[i, j] is the cost of moving mass from bin i in one distribution to bin j in another, for every computation in the batch. C therefore has shape (N, N).
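For reference, the computation described above (many pairs of histograms that all share one N x N cost matrix) can be sketched without GeomLoss as a batched log-domain Sinkhorn loop. This is a minimal NumPy sketch under stated assumptions, not GeomLoss's solver: `logsumexp`, `batched_sinkhorn`, `eps` and `n_iters` are hypothetical names, `eps` is the entropic regularization (GeomLoss uses blur**p for this), the 3 channels are folded into the batch axis, and it returns the regularized transport cost <P, C> rather than a debiased divergence.

```python
import numpy as np

def logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along `axis`.
    m = z.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def batched_sinkhorn(a, b, C, eps=0.05, n_iters=200):
    """Log-domain Sinkhorn for B pairs of histograms sharing one cost matrix.

    a, b : (B, N) positive weights, each row summing to 1
    C    : (N, N) fixed cost, C[i, j] = cost of moving mass from bin i to bin j
    Returns the (B,) transport costs <P_k, C> and the (B, N, N) plans P_k.
    """
    log_a, log_b = np.log(a), np.log(b)
    f = np.zeros_like(a)  # dual potential on the source bins, (B, N)
    g = np.zeros_like(b)  # dual potential on the target bins, (B, N)
    Cb = C[None, :, :]    # the single cost matrix, broadcast over the batch
    for _ in range(n_iters):
        # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)
        f = -eps * logsumexp(log_b[:, None, :] + (g[:, None, :] - Cb) / eps, axis=2)
        # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)
        g = -eps * logsumexp(log_a[:, :, None] + (f[:, :, None] - Cb) / eps, axis=1)
    # Plan P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps)
    log_P = log_a[:, :, None] + log_b[:, None, :] + (f[:, :, None] + g[:, None, :] - Cb) / eps
    P = np.exp(log_P)
    return (P * Cb).sum(axis=(1, 2)), P

# Usage with the layout from the question: x, y of shape (B, N, 3), normalized over N.
rng = np.random.default_rng(0)
B, N = 32, 30
points = rng.random((N, 2))
C = ((points[:, None, :] - points[None, :, :]) ** 2).sum(-1)  # (N, N) squared distances
x = rng.random((B, N, 3)); x /= x.sum(axis=1, keepdims=True)
y = rng.random((B, N, 3)); y /= y.sum(axis=1, keepdims=True)

# Fold the 3 channels into the batch axis: (B, N, 3) -> (B * 3, N).
a = x.transpose(0, 2, 1).reshape(-1, N)
b = y.transpose(0, 2, 1).reshape(-1, N)
cost, P = batched_sinkhorn(a, b, C)
cost = cost.reshape(B, 3)  # one value per (batch element, channel)
```

The cost matrix is never copied per batch element; broadcasting `C[None]` against the (B, N) potentials plays the role that `expand` plays in the PyTorch version.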

Can I use SamplesLoss to compute the Sinkhorn distances between the two batches? I initially thought I could do this through the `cost` constructor argument, by setting it to a function that expands the fixed cost matrix to the batch size, but this returns the same value, 0, for every sample in the batch:

from geomloss import SamplesLoss
import torch

# Compute example cost matrix
# I want to be able to use arbitrary C here
N = 30
points = torch.rand(N, 2)
C = torch.cdist(points, points, p=2).pow(2)

def custom_cost(x, y): 
    # x and y are the input batches with shapes (B, N, 3) 
    batch_size, _, _ = x.shape

    return C.unsqueeze(0).expand(batch_size, -1, -1)

def generate_probability_distributions(batch_size, num_points, dim):
    # Generate random values
    random_tensor = torch.rand(batch_size, num_points, dim)

    # Normalize along dim=1 (the N axis) so that each x[b, :, k] sums to 1
    probability_tensor = random_tensor / random_tensor.sum(dim=1, keepdim=True)

    return probability_tensor

# Generate example tensors x and y
x = generate_probability_distributions(32, 30, 3)
y = generate_probability_distributions(32, 30, 3)

# Define loss
criterion = SamplesLoss('sinkhorn', blur=0.05, scaling=0.9, cost=custom_cost, backend="tensorized")
loss = criterion(x, y)
print(loss)

Output:

 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0.])

Any idea what I'm doing wrong? Any help would be much appreciated. Thanks.
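One plausible explanation for the zeros, sketched under assumptions: when called as `criterion(x, y)`, SamplesLoss treats each x[b] as N points in R^3 carrying uniform weights 1/N, so the probability values end up as coordinates rather than as weights. Since `custom_cost` ignores its arguments, the three terms of the debiased Sinkhorn divergence (GeomLoss's default, `debias=True`), S = OT(a, b) - OT(a, a)/2 - OT(b, b)/2, then solve the very same problem (same cost matrix, same uniform weights) and cancel exactly. The NumPy sketch below reproduces that cancellation and shows that using the distributions as weights gives a nonzero value. `lse`, `ot_eps` and `debiased` are hypothetical helpers, and `ot_eps` is a plain single-pair log-domain Sinkhorn, not GeomLoss's multiscale solver.

```python
import numpy as np

def lse(z, axis):
    # Numerically stable log-sum-exp along one axis.
    m = z.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(z - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def ot_eps(a, b, C, eps, n_iters=300):
    # Entropic OT value <a, f> + <b, g> after log-domain Sinkhorn iterations.
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        f = -eps * lse(np.log(b)[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * lse(np.log(a)[:, None] + (f[:, None] - C) / eps, axis=0)
    return a @ f + b @ g

def debiased(a, b, C_ab, C_aa, C_bb, eps):
    # S(a, b) = OT(a, b) - OT(a, a) / 2 - OT(b, b) / 2
    return ot_eps(a, b, C_ab, eps) - 0.5 * ot_eps(a, a, C_aa, eps) - 0.5 * ot_eps(b, b, C_bb, eps)

rng = np.random.default_rng(0)
N = 30
points = rng.random((N, 2))
C = ((points[:, None, :] - points[None, :, :]) ** 2).sum(-1)  # fixed (N, N) cost

def custom_cost(x, y):
    # Like the cost in the question: the inputs are never used.
    return C

# A single batch element: 3 distributions over N bins, normalized over N.
x = rng.random((N, 3)); x /= x.sum(axis=0)
y = rng.random((N, 3)); y /= y.sum(axis=0)

# What criterion(x, y) sees: N points with uniform weights 1/N on both sides.
uniform = np.full(N, 1.0 / N)
S_uniform = debiased(uniform, uniform, custom_cost(x, y), custom_cost(x, x), custom_cost(y, y), eps=0.5)

# What was intended: the distributions themselves as weights on the N bins.
a, b = x[:, 0], y[:, 0]
S_weighted = debiased(a, b, C, C, C, eps=0.5)

print(S_uniform, S_weighted)
```

If this diagnosis is right, the GeomLoss-side change would presumably be to pass the distributions as weights through the four-argument call `criterion(α, x, β, y)`, with α and β of shape (B, N) and placeholder coordinates that the custom cost ignores, folding the 3 channels into the batch axis.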