Hi all, huge fan of your work and this library; we've had great success training latent flow matching models.
Our model relies on multiple conditions to generate x1, but it looks like the optimal transport class only supports a single condition for the prior (y0) and a single condition for the target distribution (y1).
If our model needs two conditions for x1 (ya_1, yb_1), I'd assume this would just require ot_sampler.sample_plan_with_labels() to accept, perhaps, a list[Tensor] for y1 and index each entry with the sampled target indices:
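    def sample_plan_with_labels(self, x0, x1, y0=None, y1=None, replace=True):
        pi = self.get_map(x0, x1)  # OT plan between the two batches
        i, j = self.sample_map(pi, x0.shape[0], replace=replace)  # sampled (source, target) indices
        y1_pi = None
        if isinstance(y1, torch.Tensor):
            # Single condition: index it directly.
            y1_pi = y1[j]
        elif isinstance(y1, list):
            # Several conditions: index each one with the same target indices.
            y1_pi = [_y1[j] for _y1 in y1]
        return (
            x0[i],
            x1[j],
            # ...
            y1_pi,
        )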
I think something like this would allow the model to correctly accept multiple conditions along with the sampled xt?
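To make the indexing step concrete, here is a minimal standalone sketch. The helper name `index_conditions` is hypothetical and not part of the library, and the index tensor `j` is a hand-written stand-in for what `sample_map` would return. It only shows that every condition gets reordered with the same indices, so the conditions stay aligned with the sampled x1.

```python
import torch


def index_conditions(y, idx):
    """Index one condition, or each condition in a list/tuple, with idx.

    Hypothetical helper for illustration; not part of the library.
    """
    if y is None:
        return None
    if isinstance(y, torch.Tensor):
        return y[idx]
    if isinstance(y, (list, tuple)):
        # Apply the same indices to every condition so they stay aligned.
        return type(y)(index_conditions(v, idx) for v in y)
    raise TypeError(f"unsupported condition type: {type(y).__name__}")


# Two toy conditions for a batch of four targets.
ya_1 = torch.arange(4)
yb_1 = torch.arange(4) * 10

# Stand-in for the target indices j that sample_map would return.
j = torch.tensor([2, 0, 3, 1])

ya_pi, yb_pi = index_conditions([ya_1, yb_1], j)
```

Accepting tuples and `None` as well as lists is a design choice of this sketch; the snippet in the post only handles a single tensor or a list.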
Also, it could be helpful to return the OT map pi for cases where downstream logic depends on the batch order (e.g. logging or loss aggregation that is done per batch sample).
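As a rough illustration of why the batch order matters downstream, the sketch below samples index pairs from a plan and uses them to map per-sample losses back to the original target positions. Everything here is an assumption for illustration: `sample_pairs` is a hypothetical function written with `torch.multinomial`, not the library's `sample_map`, and the plan `pi` is a hand-built permutation rather than a real OT solution. Returning the sampled indices (i, j) alongside, or instead of, pi is one possible option, not something the library currently does.

```python
import torch


def sample_pairs(pi, batch_size, replace=True, generator=None):
    """Sample (i, j) index pairs with probability proportional to pi[i, j].

    Hypothetical helper for illustration; not the library's sample_map.
    """
    n_cols = pi.shape[1]
    p = pi.flatten()
    p = p / p.sum()
    flat = torch.multinomial(p, batch_size, replacement=replace, generator=generator)
    i = torch.div(flat, n_cols, rounding_mode="floor")  # source (x0) indices
    j = flat % n_cols  # target (x1) indices
    return i, j


# Toy plan: source k is matched only to target perm[k].
perm = torch.tensor([2, 0, 3, 1])
pi = torch.zeros(4, 4)
pi[torch.arange(4), perm] = 0.25

gen = torch.Generator().manual_seed(0)
i, j = sample_pairs(pi, batch_size=8, generator=gen)

# Per-sample losses come out in sampled order; j maps them back to the
# original target positions for logging or aggregation.
per_sample_loss = torch.ones(8)
loss_per_target = torch.zeros(4).index_add_(0, j, per_sample_loss)
```

With the indices in hand, any per-sample quantity computed on the reordered batch can be attributed to the original batch entries.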
Yep, this looks correct to me. Happy to consider proposed changes to the interface! You may want to consider other batching strategies (depending on your conditions), but I haven't looked into this deeply.