I have two tensors `x` and `y` of shape (B, N, 3), where B is the batch size and N is fixed. For each batch index b and channel i, `x[b, :, i]` is a probability distribution of length N. I want to use a fixed cost matrix `C`, where `C[i, j]` is the cost of moving mass from index i in one distribution to index j in another, for all computations in the batch. `C` therefore has shape (N, N).
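To make the layout concrete: each `x[b, :, i]` is one histogram over the same N support points, so the batch can be viewed as `B * 3` histograms that all share the one `(N, N)` matrix `C`. A minimal sketch of that flattening (the helper name `to_histograms` is my own, hypothetical):

```python
import torch

def to_histograms(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, N, K) -> (B * K, N), one histogram per row."""
    B, N, K = t.shape
    # Move the channel axis next to the batch axis, then merge the two axes.
    return t.permute(0, 2, 1).reshape(B * K, N)

B, N, K = 4, 30, 3
x = torch.rand(B, N, K)
x = x / x.sum(dim=1, keepdim=True)  # each x[b, :, i] sums to 1

a = to_histograms(x)
print(a.shape)  # torch.Size([12, 30])
# Row b * K + i of `a` is exactly the histogram x[b, :, i]
print(torch.equal(a[1 * K + 2], x[1, :, 2]))  # True
```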
Can I use `SamplesLoss` to compute the Sinkhorn distances between the two batches? I initially thought I could do this through the `cost` init argument, by setting it to a function that expands the fixed cost matrix to match the batch size, but this returns the same 0 value for every sample in the batch:
```python
from geomloss import SamplesLoss
import torch

# Compute example cost matrix
# I want to be able to use arbitrary C here
N = 30
points = torch.rand(N, 2)
C = torch.cdist(points, points, p=2).pow(2)

def custom_cost(x, y):
    # x and y are the input batches with shapes (B, N, 3)
    batch_size, _, _ = x.shape
    return C.unsqueeze(0).expand(batch_size, -1, -1)

def generate_probability_distributions(batch_size, num_points, dim):
    # Generate random values
    random_tensor = torch.rand(batch_size, num_points, dim)
    # Normalize along dim=1 (the N axis) so that each [b, :, i] slice sums to 1
    probability_tensor = random_tensor / random_tensor.sum(dim=1, keepdim=True)
    return probability_tensor

# Generate example tensors x and y
x = generate_probability_distributions(32, 30, 3)
y = generate_probability_distributions(32, 30, 3)

# Define loss
criterion = SamplesLoss('sinkhorn', blur=0.05, scaling=0.9, cost=custom_cost, backend="tensorized")
loss = criterion(x, y)
print(loss)
```
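A possible explanation for the zeros, based on my reading of the GeomLoss docs rather than a confirmed diagnosis: with the two-argument call `criterion(x, y)`, `SamplesLoss` treats `x` and `y` as point clouds (B batches of N points in dimension 3, each point with uniform weight 1/N), not as histograms. Since `custom_cost` only reads the batch size, the cross cost and the self costs used for debiasing would all be the same matrix, so the two measures are indistinguishable and the debiased divergence OT(α, β) - OT(α, α)/2 - OT(β, β)/2 comes out as 0. The check below only exercises `custom_cost` itself:

```python
import torch

N = 30
points = torch.rand(N, 2)
C = torch.cdist(points, points, p=2).pow(2)

def custom_cost(x, y):
    # Same function as in the snippet: only x.shape is read, never the values
    batch_size, _, _ = x.shape
    return C.unsqueeze(0).expand(batch_size, -1, -1)

x = torch.rand(32, N, 3)
y = torch.rand(32, N, 3)

# The cross cost equals both self costs ...
same = torch.equal(custom_cost(x, y), custom_cost(x, x)) and torch.equal(
    custom_cost(x, y), custom_cost(y, y)
)
# ... and stays unchanged for completely different input values
blind = torch.equal(custom_cost(x, y), custom_cost(torch.zeros_like(x), torch.ones_like(y)))
print(same, blind)  # True True
```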
Hello, thanks a lot for all your work!
Output:
Any idea what I'm doing wrong? Any help would be much appreciated. Thanks.
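As a sanity check that does not depend on GeomLoss, here is a minimal plain-PyTorch sketch of a batched, debiased, log-domain Sinkhorn divergence in which the fixed matrix `C` is shared across the batch and the histograms are passed explicitly as weights. The names `sinkhorn_ot` and `sinkhorn_divergence`, `eps=0.1` and the fixed iteration count are my own assumptions; this is not GeomLoss's implementation (no ε-scaling, no convergence test), and `eps` here is not the same parameter as GeomLoss's `blur`.

```python
import torch

def sinkhorn_ot(a, b, C, eps=0.1, n_iters=500):
    """Entropic OT between histograms a, b of shape (B, N) on a shared
    support, with one fixed (N, N) cost matrix C for the whole batch."""
    log_a, log_b = a.log(), b.log()
    Cb = C.unsqueeze(0)  # (1, N, N), broadcast over the batch
    f = torch.zeros_like(a)  # dual potential on the a side
    g = torch.zeros_like(b)  # dual potential on the b side
    for _ in range(n_iters):
        # f_i = -eps * log sum_j b_j exp((g_j - C_ij) / eps)
        f = -eps * torch.logsumexp(log_b.unsqueeze(1) + (g.unsqueeze(1) - Cb) / eps, dim=2)
        # g_j = -eps * log sum_i a_i exp((f_i - C_ij) / eps)
        g = -eps * torch.logsumexp(log_a.unsqueeze(2) + (f.unsqueeze(2) - Cb) / eps, dim=1)
    # After the g update the plan has unit mass, so the dual objective
    # reduces to <a, f> + <b, g>.
    return (a * f).sum(dim=1) + (b * g).sum(dim=1)

def sinkhorn_divergence(a, b, C, eps=0.1, n_iters=500):
    """Debiased OT(a, b) - OT(a, a)/2 - OT(b, b)/2, one value per row."""
    return (sinkhorn_ot(a, b, C, eps, n_iters)
            - 0.5 * sinkhorn_ot(a, a, C, eps, n_iters)
            - 0.5 * sinkhorn_ot(b, b, C, eps, n_iters))

torch.manual_seed(0)
B, N, K = 4, 30, 3
points = torch.rand(N, 2)
C = torch.cdist(points, points, p=2).pow(2)  # placeholder: any fixed (N, N) matrix

x = torch.rand(B, N, K)
x = x / x.sum(dim=1, keepdim=True)
y = torch.rand(B, N, K)
y = y / y.sum(dim=1, keepdim=True)

# One histogram per (sample, channel): (B, N, K) -> (B * K, N)
a = x.permute(0, 2, 1).reshape(B * K, N)
b = y.permute(0, 2, 1).reshape(B * K, N)
loss = sinkhorn_divergence(a, b, C).reshape(B, K)
print(loss.shape)  # torch.Size([4, 3])

# Uniform weights on both sides (what, as I understand it, the two-argument
# SamplesLoss call uses) give zero divergence for every row.
uniform = torch.full((B * K, N), 1.0 / N)
zero = sinkhorn_divergence(uniform, uniform, C)
```

With uniform weights the divergence vanishes, matching the symptom above, while the real histograms give nonzero values. On the GeomLoss side, `SamplesLoss` also documents a weighted call `loss(a, x, b, y)`, which looks like the natural place to pass these flattened histograms; I have not checked how its ε-scaling behaves when the positions are only placeholders for a custom cost.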