Closed CaioDaumann closed 4 days ago
Hi!
Our ot_sampler.sample_plan literally samples the plan in all cases (with repeats). This was originally done for simplicity (as the same sampler can also be used for regularized plans). However, you are correct: in the exact case, you can skip this step. I think this is probably worth implementing.
--Alex
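To illustrate why sampling the plan can produce repeats even in the exact case, here is a minimal sketch in plain Python. It assumes two equal-size batches with uniform weights, for which an exact plan can be taken to be a permutation matrix scaled by 1/n. The permutation `perm` and the variable names are made up for illustration and are not taken from torchcfm.

```python
import random
from collections import Counter

n = 5
# Hypothetical exact matching for illustration: x0[i] is paired with x1[perm[i]].
perm = [2, 0, 4, 1, 3]
# The corresponding plan: mass 1/n on each matched cell, zero elsewhere.
plan = [[1.0 / n if perm[i] == j else 0.0 for j in range(n)] for i in range(n)]

cells = [(i, j) for i in range(n) for j in range(n)]
weights = [plan[i][j] for i, j in cells]

# Draw n index pairs independently (with replacement), weighted by the plan.
rng = random.Random(0)
sampled = rng.choices(cells, weights=weights, k=n)

counts = Counter(sampled)
missing = [(i, perm[i]) for i in range(n) if (i, perm[i]) not in counts]
print("sampled pairs:", sampled)
print("pairs drawn more than once:", [p for p, c in counts.items() if c > 1])
print("matched pairs never drawn:", missing)
```

Every sampled pair lies on the support of the plan, but because the draws are independent, some matched pairs are typically drawn twice while others are not drawn at all. Reading the matching directly off the plan avoids this.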
Hmm, okay, I am not sure I understand what is going on. I will read up on how the sampling plan is derived and try implementing a 1-to-1 plan with no repeated pairs.
Thanks for now! I will close the issue once I understand the problem better and come up with a solution.
Sounds good! Maybe a useful hint. I think it should be something like this:
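import scipy.optimize
import torch
# Assume batches x0 and x1 with shape [Batch, Dim]
M = torch.cdist(x0, x1) ** 2  # squared Euclidean cost, M[i, j] = ||x0[i] - x1[j]||^2
# Row i of x0 is matched to column col_ind[i] of x1
_, col_ind = scipy.optimize.linear_sum_assignment(M.detach().cpu().numpy())
x1 = x1[col_ind]  # reorder x1 so that (x0[i], x1[i]) are the matched pairs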
I guess it’s working wonderfully now.
Using your script, every single point of x0 and x1 is now connected (the plot on the right has fewer connections because some pairs are repeated). See the code and validation plot below:
import os
import matplotlib.pyplot as plt
import scipy.optimize
import torch
import torchsde
from torchdyn.core import NeuralODE
from tqdm import tqdm
from torchcfm.conditional_flow_matching import *
from torchcfm.models import MLP
from torchcfm.utils import sample_8gaussians, sample_moons, torch_wrapper
batch_size = 10
x0 = sample_8gaussians(batch_size)
x1 = sample_moons(batch_size)
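# Assume batches x0 and x1 [Batch, Dim]
M = torch.cdist(x0, x1) ** 2
# linear_sum_assignment matches row i (x0[i]) to column col_ind[i] (x1[col_ind[i]])
_, col_ind = scipy.optimize.linear_sum_assignment(M.detach().cpu().numpy())
x0_ = x0  # keep x0 in its original order
x1_ = x1[col_ind]  # reorder only x1; indexing both with col_ind would restore the original pairing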
# Batch OT exact sampling!
ot_sampler = OTPlanSampler(method="exact")
x0__,x1__ = ot_sampler.sample_plan(x0, x1)
pairs = [
(x0_, x1_),
(x0__, x1__)
]
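# plot_multiple_points_with_connections is the poster's own plotting helper (definition not shown in this thread)
plot_multiple_points_with_connections(pairs, titles=['Alexander Script', 'Exact OT'])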
And the plot looks like:
(P.S. Note that the X and Y axes are not exactly the same.)
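A numerical check can complement a visual comparison of this kind. The sketch below uses NumPy arrays drawn from Gaussians as stand-ins for the two batches (an assumption made to keep it self-contained; it does not call torchcfm). It verifies that the assignment is 1-to-1 and that its total cost is no larger than that of the unmatched pairing (x0[i], x1[i]).

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

rng = np.random.default_rng(0)
x0 = rng.normal(size=(10, 2))        # stand-in for the source batch
x1 = rng.normal(size=(10, 2)) + 3.0  # stand-in for the target batch

# Squared Euclidean cost: M[i, j] = ||x0[i] - x1[j]||^2
M = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)

# x0[row_ind[k]] is matched to x1[col_ind[k]]
row_ind, col_ind = linear_sum_assignment(M)

# Every x1 index should appear exactly once in the matching.
is_one_to_one = sorted(col_ind.tolist()) == list(range(len(x1)))
matched_cost = M[row_ind, col_ind].sum()
identity_cost = np.trace(M)  # cost of pairing x0[i] with x1[i]

print("one-to-one:", is_one_to_one)
print("matched cost:", matched_cost, "identity cost:", identity_cost)
```

If the matched cost equals the identity cost on random data, that is a sign that both batches were reindexed with the same permutation rather than only one of them.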
Would you like this merged, or is any other performance validation needed? I’d be happy to help with anything.
Thanks for the help!
This can also be controlled in the sample_map function by setting replace=False. See here: https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/optimal_transport.py#L99.
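As a rough illustration of what drawing without replacement does, the sketch below assumes that the sampler draws flat indices from the normalized plan with a weighted choice and then converts them back to (row, column) pairs. That is an assumption about the implementation, not a copy of it, and the permutation `perm` is hypothetical.

```python
import numpy as np

n = 5
perm = np.array([2, 0, 4, 1, 3])  # hypothetical exact matching: x0[i] -> x1[perm[i]]
pi = np.zeros((n, n))
pi[np.arange(n), perm] = 1.0 / n  # permutation plan with uniform mass

# Flatten the plan into a probability vector over all (i, j) cells.
p = pi.flatten()
p = p / p.sum()

# Draw n distinct cells; only the n matched cells have nonzero probability.
rng = np.random.default_rng(0)
choices = rng.choice(pi.size, size=n, p=p, replace=False)

# Convert flat indices back to (row, column) index arrays.
i, j = np.divmod(choices, pi.shape[1])
print(list(zip(i.tolist(), j.tolist())))
```

Because the plan has exactly n nonzero cells and n distinct cells are drawn, each matched pair appears exactly once, in random order.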
As it is solved, I am closing the issue.
Hi, I’ve been experimenting with the OTPlanSampler for the batchOT and SF2M methods, but I encountered some results using the non-regularized OT plan sampler that I can’t understand.
For example, my very simple code is as follows:
And the input tensor is:
While the output is:
As you can see, one of the entries in x0__ is repeated. I would naively assume that when using the exact method without any regularization, each x0 point would be matched to exactly one x1 point, with no repeats. However, in this case, the point [4.9099, 0.5921] is missing from the sampled plan.
Am I missing something about how this should work?