atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Non-Regularized OTPlanSampler: Duplicate and Missing Entries #142

Closed CaioDaumann closed 4 days ago

CaioDaumann commented 1 week ago

Hi, I’ve been experimenting with the OTPlanSampler for the batchOT and SF2M methods, but I encountered some results using the non-regularized OT plan sampler that I can’t understand.

For example, my very simple code is as follows:

batch_size = 6
x0 = sample_8gaussians(batch_size)
x1 = sample_moons(batch_size)

ot_sampler = OTPlanSampler(method="exact")
x0__,x1__  = ot_sampler.sample_plan(x0, x1)
print( x0, x0__ )

And the input tensor is:

tensor([[ 2.2023, -2.9085],
        [ 3.7234,  4.4048],
        [ 2.9529, -3.4727],
        [-2.7239, -3.5034],
        [-3.2309,  4.0971],
        [ 4.9099,  0.5921]]) 

While the output is:

tensor([[ 3.7234,  4.4048],
        [ 2.9529, -3.4727],
        [-2.7239, -3.5034],
        [ 2.2023, -2.9085],
        [ 2.9529, -3.4727],
        [-3.2309,  4.0971]])

As you can see, one of the entries in x0__ is repeated. I would naively assume that when using the exact method without any regularization, each x0 point would be matched to exactly one x1 point, with none repeated or dropped. However, in this case, the point [4.9099, 0.5921] is missing from the sampled plan.

Am I missing something about how this should work?

atong01 commented 1 week ago

Hi!

Our ot_sampler.sample_plan literally samples from the plan in all cases (with repeats). This was originally done for simplicity (as the same sampler can also be used for regularized plans). However, you are correct: in the exact case, you can skip this step. I think this is probably worth implementing.
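
To illustrate why sampling the plan produces repeats, here is a minimal NumPy sketch. It is not the library's code: `sample_from_plan` is a hypothetical helper, and the plan `pi` is a hand-built permutation matrix scaled by 1/n, standing in for what an exact solver would typically return for two uniform batches of equal size.

```python
import numpy as np


def sample_from_plan(pi, batch_size, rng, replace=True):
    """Draw (i, j) index pairs with probability proportional to pi[i, j]."""
    p = pi.flatten()
    p = p / p.sum()
    flat = rng.choice(pi.size, size=batch_size, p=p, replace=replace)
    # Convert flat indices back to (row, column) indices.
    return np.divmod(flat, pi.shape[1])


n = 6
rng = np.random.default_rng(0)

# Assumed stand-in for an exact OT plan: mass 1/n on one entry per row.
perm = rng.permutation(n)
pi = np.zeros((n, n))
pi[np.arange(n), perm] = 1.0 / n

i, j = sample_from_plan(pi, n, rng)

# Every sampled pair lies on the support of the plan ...
assert np.array_equal(perm[i], j)
# ... but the n draws are independent, so some rows can appear twice
# while others are never drawn.
print(sorted(i.tolist()))
```

Each sampled pair is a valid OT pair, so the marginals are correct in expectation, but a single batch of n draws is generally not a permutation of the inputs.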

--Alex

CaioDaumann commented 1 week ago

Hmm, okay I am not sure if I understand what is going on. I will educate myself better on how the sampling plan is derived and try implementing a 1 to 1 plan with no repeated pairs.

Thanks for now! I will close the issue when I understand the problem better and come up with a solution.

atong01 commented 1 week ago

Sounds good! Maybe a useful hint. I think it should be something like this:

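# Assume batches x0 and x1 [Batch, Dim]
M = torch.cdist(x0, x1) ** 2
# Rows of M index x0 and columns index x1, so x0[row_ind[k]] is matched with x1[col_ind[k]]
row_ind, col_ind = scipy.optimize.linear_sum_assignment(M)
x0, x1 = x0[row_ind], x1[col_ind]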
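
A self-contained variant of this hint, written with NumPy arrays instead of torch tensors so it runs without the rest of the training code. The random batches and the names `x0_matched` and `x1_matched` are illustrative assumptions. Since rows of `M` index `x0` and columns index `x1`, the matched pairs are `x0[row_ind]` and `x1[col_ind]`.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
x0 = rng.normal(size=(6, 2))        # illustrative source batch
x1 = rng.normal(size=(6, 2)) + 3.0  # illustrative target batch

# M[i, j] is the squared Euclidean distance between x0[i] and x1[j].
M = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)

# Minimum-cost one-to-one assignment between rows (x0) and columns (x1).
row_ind, col_ind = linear_sum_assignment(M)
x0_matched, x1_matched = x0[row_ind], x1[col_ind]

matched_cost = M[row_ind, col_ind].sum()
identity_cost = np.trace(M)  # cost of keeping the original pairing
print(matched_cost, identity_cost)
```

Every point is used exactly once on each side, and the total cost is never larger than that of the original pairing.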

CaioDaumann commented 1 week ago

I guess it’s working wonderfully now.

Using your script, every single point of x0 and x1 is now connected (the right panel has fewer connections because some are repeated). See the code and validation plot below:

import os

import matplotlib.pyplot as plt
import scipy.optimize
import torch
import torchsde
from torchdyn.core import NeuralODE
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.models import MLP
from torchcfm.utils import sample_8gaussians, sample_moons, torch_wrapper

batch_size = 10
x0 = sample_8gaussians(batch_size)
x1 = sample_moons(batch_size)

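# Assume batches x0 and x1 [Batch, Dim]
M = torch.cdist(x0, x1) ** 2
# x0[row_ind[k]] is matched with x1[col_ind[k]]; indexing both batches
# with col_ind would only reorder the original (unmatched) pairs
row_ind, col_ind = scipy.optimize.linear_sum_assignment(M)
x0_ = x0[row_ind]
x1_ = x1[col_ind]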

# Batch OT exact sampling!
ot_sampler = OTPlanSampler(method="exact")
x0__,x1__  = ot_sampler.sample_plan(x0, x1)

pairs = [
    (x0_, x1_),
    (x0__, x1__)
]

plot_multiple_points_with_connections(pairs, titles=['Alexander Script', 'Exact OT'])

And the plot looks like:

(PS: Note that the X and Y axes are not exactly the same.)

[Image: x0_x1_connections_tuple]

Would you like this merged, or would you like any further performance validation? I’d be happy to help with anything.

Thanks for the help!

kilianFatras commented 4 days ago

This can also be controlled in the sample_map function by setting replace=False. See here https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py#L99.
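
As a rough illustration of what sampling without replacement does, here is a NumPy sketch under the assumption that the exact plan is a permutation matrix scaled by 1/n. It is not the library's code, and the plan `pi` is built by hand. With exactly n nonzero entries and n draws without replacement, every pair is selected once.

```python
import numpy as np

n = 6
rng = np.random.default_rng(0)

# Assumed stand-in for an exact OT plan: mass 1/n on one entry per row.
perm = rng.permutation(n)
pi = np.zeros((n, n))
pi[np.arange(n), perm] = 1.0 / n

p = pi.flatten()
# Draw n flat indices without replacement, weighted by the plan.
flat = rng.choice(pi.size, size=n, p=p / p.sum(), replace=False)
i, j = np.divmod(flat, pi.shape[1])

# Each row and each column is drawn exactly once.
print(sorted(i.tolist()), sorted(j.tolist()))
```

Note that this only works when the plan has at least as many nonzero entries as the number of draws.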

As it is solved, I am closing the issue.