atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

ExactOptimalTransportConditionalFlowMatcher with multiple conditions #132

Open lukasschmit opened 3 months ago

lukasschmit commented 3 months ago

Hi all, huge fan of your work and this library. We've had great success training latent flow matching models.

Our model relies on multiple conditions to generate x1. It looks like the optimal transport class only supports a single condition for the prior (y0) and a single condition for the target distribution (y1).

If our model needs two conditions for x1 (ya_1, yb_1), I'd assume this would just require ot_sampler.sample_plan_with_labels() to accept, for example, a list[Tensor] for y1 and to index every element with the same sampled indices:

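def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=True):
    # compute the OT plan pi between the two minibatches
    pi = self.get_map(x0, x1)
    # draw (source, target) index pairs from pi
    i, j = self.sample_map(pi, x0.shape[0], replace=replace)

    # reorder the target conditions with the same indices j as x1
    # (y1_pi stays None when no target condition is given)
    y1_pi = None
    if isinstance(y1, torch.Tensor):
        # single condition: one tensor
        y1_pi = y1[j]
    elif isinstance(y1, list):
        # multiple conditions: index every tensor in the list
        y1_pi = [_y1[j] for _y1 in y1]

    return (
        x0[i],
        x1[j],
        # ...
        y1_pi
    )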

I think something like this would allow the model to correctly accept multiple conditions along with the sampled xt?
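The same indexing idea can be generalized so that y0 and y1 may each be None, a single tensor, a list/tuple of tensors, or a dict of named conditions. This is a minimal sketch, assuming the indices i, j have already been drawn from the plan; `gather_conditions` is a hypothetical helper, not part of TorchCFM. The demo uses NumPy arrays, but torch.Tensor supports the same integer-array indexing.

```python
import numpy as np


def gather_conditions(y, idx):
    """Reorder one condition or a (nested) collection of conditions by idx.

    Hypothetical helper: works for any array-like supporting integer-array
    indexing (torch.Tensor, np.ndarray). None passes through unchanged.
    """
    if y is None:
        return None
    if isinstance(y, list):
        # several conditions, e.g. [ya_1, yb_1]
        return [gather_conditions(t, idx) for t in y]
    if isinstance(y, tuple):
        return tuple(gather_conditions(t, idx) for t in y)
    if isinstance(y, dict):
        # named conditions keep their keys
        return {k: gather_conditions(v, idx) for k, v in y.items()}
    # a single tensor/array: index along the batch dimension
    return y[idx]


# two conditions for x1 (ya_1, yb_1), reordered by sampled target indices j
j = np.array([2, 0, 1])
ya_1 = np.array([10, 11, 12])
yb_1 = np.array([[0.0], [1.0], [2.0]])
ya_j, yb_j = gather_conditions([ya_1, yb_1], j)
# ya_j is [12, 10, 11]; yb_j is [[2.0], [0.0], [1.0]]
```

Inside the sampler this would amount to returning `gather_conditions(y0, i)` and `gather_conditions(y1, j)` alongside `x0[i]` and `x1[j]`, so every condition stays aligned with the pair it belongs to.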

Also, it could be helpful to return the OT map pi for cases where downstream logic depends on the batch order (e.g. logging or loss aggregation that is done per batch sample).
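The per-sample bookkeeping can be sketched as follows, assuming a hypothetical extension in which the sampler also returns the sampled target indices j. Strictly speaking, it is these indices (drawn from pi) rather than pi itself that encode the batch order. `per_target_mean` is a hypothetical helper, and NumPy stands in for torch.

```python
import numpy as np


def per_target_mean(values, j, n_targets):
    """Average per-pair values (e.g. losses) back onto the original x1 order.

    values[k] belongs to the k-th sampled pair, whose target is x1[j[k]].
    With replace=True a target can be drawn several times or not at all,
    so repeated indices are accumulated and unsampled targets get NaN.
    """
    sums = np.zeros(n_targets)
    counts = np.zeros(n_targets)
    np.add.at(sums, j, values)  # unbuffered: every repeated index counts
    np.add.at(counts, j, 1)
    out = np.full(n_targets, np.nan)
    np.divide(sums, counts, out=out, where=counts > 0)
    return out


loss = np.array([1.0, 2.0, 3.0, 4.0])  # per-pair losses in sampled order
j = np.array([2, 0, 2, 1])             # x1[2] drawn twice, x1[3] never
means = per_target_mean(loss, j, n_targets=4)
# means: 2.0, 4.0, 2.0, and NaN for the unsampled x1[3]
```

`np.add.at` is used instead of `sums[j] += values` because the buffered form would count a repeated index only once.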

atong01 commented 3 months ago

Yep, this looks correct to me. Happy to consider proposed changes to the interface! You may want to consider other batching strategies (depending on your conditions), but I haven't looked into this deeply.
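One possible reading of "other batching strategies" for discrete conditions is to solve OT separately within each condition group, so every pair shares a condition. This is a minimal sketch under several assumptions: one discrete label per x1 sample, an unconditional x0 (e.g. i.i.d. Gaussian noise, so any slice of it is still a valid draw), equal batch sizes, and uniform weights. Under those assumptions an optimal plan inside each group is a permutation, so `scipy.optimize.linear_sum_assignment` is swapped in for the library's own OT solver. This gives a deterministic assignment rather than sampling pairs from pi. `pair_within_groups` is a hypothetical name.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def pair_within_groups(x0, x1, labels):
    """Exact OT pairing of x0 -> x1 restricted to targets sharing a label.

    x0, x1: (n, d) arrays; labels: (n,) discrete condition of each x1 row.
    Returns index arrays i, j such that (x0[i], x1[j]) are the pairs.
    """
    if len(x0) != len(x1):
        raise ValueError("this sketch assumes equal batch sizes")
    i_parts, j_parts = [], []
    start = 0
    for label in np.unique(labels):
        idx1 = np.flatnonzero(labels == label)      # targets with this label
        idx0 = np.arange(start, start + len(idx1))  # next unused noise samples
        start += len(idx1)
        # squared Euclidean cost between the two groups
        cost = ((x0[idx0][:, None, :] - x1[idx1][None, :, :]) ** 2).sum(axis=-1)
        rows, cols = linear_sum_assignment(cost)    # exact assignment
        i_parts.append(idx0[rows])
        j_parts.append(idx1[cols])
    return np.concatenate(i_parts), np.concatenate(j_parts)


# toy batch: two conditions, two samples each
x1 = np.array([[0.0, 0.0], [10.0, 10.0], [0.0, 1.0], [10.0, 11.0]])
labels = np.array([0, 1, 0, 1])
x0 = np.array([[0.0, 1.1], [0.0, 0.1], [10.0, 11.0], [10.0, 10.0]])
i, j = pair_within_groups(x0, x1, labels)
# every pair (x0[i[k]], x1[j[k]]) stays inside one label group,
# and the conditions for the pairs are labels[j]
```

The design trade-off is that each group only sees a fraction of the batch, so with many distinct conditions the per-group OT problems become small and the coupling is closer to random pairing.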

--Alex