jeanfeydy / geomloss

Geometric loss functions between point clouds, images and volumes

Padding #33

Open korunosk opened 4 years ago

korunosk commented 4 years ago

I am aware that GeomLoss supports batching, but my distributions have different numbers of samples, and the only way I could think of to build the batches was to pad the tensors. Obviously, though, the results are incorrect this way.

Any suggestions to overcome this problem?

Self-contained example:

import torch
import torch.nn.functional as F
from geomloss import SamplesLoss

N = 20  # number of rows to pad each target to (the size of the largest s_i)

# I want to compute the loss between d and every s_i.
# Instead of iterating: [loss(d, s1), loss(d, s2), loss(d, s3)],
# I want to compute it with batching: loss([d, d, d], [s1, s2, s3]).
# The problem is that the numbers of samples differ between the s_i,
# so I pad the tensors; however, the results are obviously wrong.

# Original tensors
d = torch.rand(size=(80, 100))

s1 = torch.rand(size=(10, 100))
s2 = torch.rand(size=(15, 100))
s3 = torch.rand(size=(20, 100))

# Padded tensors
d_ = d.unsqueeze(0).repeat(3, 1, 1)

s_ = torch.stack([
    F.pad(s1, (0, 0, 0, N - s1.shape[0])),
    F.pad(s2, (0, 0, 0, N - s2.shape[0])),
    F.pad(s3, (0, 0, 0, N - s3.shape[0]))
])

# Sinkhorn
loss = SamplesLoss()

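# This assertion fails: the padded zeros are treated as real samples with uniform weight.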
assert(loss(d_, s_)[0] == loss(d, s1))

tanglef commented 4 years ago

Hi,

If you want to pad the tensors, you need to give the weights explicitly (otherwise, as you noticed, the padded zeros are taken into account in the automatically generated weights). For uniform weights, you can first reproduce the function that the package uses internally when computing the loss, and then pad.

def get_weights(sample):  # uniform weights, matching what the package generates internally
    if sample.dim() == 2:  # a single distribution of shape (N, D)
        N = sample.shape[0]
        return torch.ones(N).type_as(sample) / N
    elif sample.dim() == 3:  # a batch of distributions of shape (B, N, D)
        B, N, _ = sample.shape
        return torch.ones(B, N).type_as(sample) / N

beta = get_weights(s1)
gamma = get_weights(s2)
delta = get_weights(s3)
padding = max(s1.shape[0], s2.shape[0], s3.shape[0])  # no need to pad more than the largest target

alpha_ = get_weights(d_)
weights_ = torch.stack([  # padded weights: the extra entries carry zero mass
    F.pad(beta, (0, padding - s1.shape[0])),
    F.pad(gamma, (0, padding - s2.shape[0])),
    F.pad(delta, (0, padding - s3.shape[0]))
])

padded_loss = loss(alpha_, d_, weights_, s_)  # using the tensorized backend
print("Is loss close for the 1st target?", torch.isclose(loss(d, s1), padded_loss[0], atol=1e-3).item())
print("Is loss close for the 2nd target?", torch.isclose(loss(d, s2), padded_loss[1], atol=1e-3).item())
print("Is loss close for the 3rd target?", torch.isclose(loss(d, s3), padded_loss[2], atol=1e-3).item())

But this can quickly become "heavy". I made a pull request, #35, to make this more user-friendly: it lets you pass a list of targets, with or without a list of weights, for the batches.

Best regards, Tanguy

korunosk commented 4 years ago

Hi Tanguy,

I ended up using an approach similar to yours, and the results are as expected. Anyway, thank you for your answer!

Best, Mladen

GilgameshD commented 8 months ago

Hi @korunosk, I am having the same problem on my side. Would you mind sharing your solution? Thanks!