atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

How to compute inverse trajectories #99

Closed: francesco-vaselli closed 8 months ago

francesco-vaselli commented 8 months ago

Dear all, Thanks for the great package!

I am writing to seek guidance on a question I have. With discrete flows, during training we learn a transformation from data space to, say, Gaussian space ($u = f(x)$) and then invert it to map from Gaussian space to data space ($x = f^{-1}(u)$). We can always take new data to the Gaussian space by computing $u = f(x)$.

In the case of this package, it's clear how we can use the models to start from a (Gaussian) noise space, get the initial conditions of the ODE and solve with torchdiffeq or torchdyn. The question is how can I use the model to compute the reverse trajectory, i.e. go from new data back to the (Gaussian) noise space? How do we compute $u = f(x)$ ?

Do I have to reverse the time steps order (from 1 to 0) and give the data as the initial conditions of the ODE?

I would be really grateful for any guidance or code example you could provide.

Best regards, Francesco

ImahnShekhzadeh commented 8 months ago

Hi,

For example, in the notebook conditional_mnist.ipynb, this code sample generates samples $x_1 \sim q(x_1)$, where $q(x_1)$ is the target distribution, starting from $x_0 \sim p(x_0)$, where $p(x_0)$ is the base distribution:

traj = torchdiffeq.odeint(
    lambda t, x: model.forward(t, x, generated_class_list),
    torch.randn(100, 1, 28, 28).to(device),  # x_0: batch of base (Gaussian) noise
    torch.linspace(0, 1, 2).to(device),      # integrate from t=0 to t=1
    atol=1e-4,
    rtol=1e-4,
    method="dopri5",
)

In this case, traj is a Tensor of shape (2, 100, 1, 28, 28): the leading 2 comes from the two time points in torch.linspace(0, 1, 2), and the remaining dimensions match the shape of the noise. By reversing the order of traj along the first dimension, you get the trajectory from the target to the base distribution:

traj__target_to_base = traj.flip(dims=[0]) 
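As a toy aside (NumPy stand-in, not the library API): flipping the trajectory tensor only reverses the order in which the already-computed time points are viewed; it does not re-solve the ODE. A minimal sketch with a fixed-step Euler solver on $dx/dt = -x$:

```python
import numpy as np

def euler_traj(x0, t_grid, vf):
    """Fixed-step Euler solve, returning the state at every time in t_grid."""
    xs = [x0]
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        xs.append(xs[-1] + (t1 - t0) * vf(t0, xs[-1]))
    return np.stack(xs)  # shape (len(t_grid), *x0.shape)

x0 = np.array([1.0, -2.0])                   # starting "noise"
t = np.linspace(0.0, 1.0, 11)
traj = euler_traj(x0, t, vf=lambda t, x: -x)  # toy velocity field
traj_rev = traj[::-1]                         # same idea as traj.flip(dims=[0])

assert np.allclose(traj_rev[0], traj[-1])     # first entry is now the endpoint
assert np.allclose(traj_rev[-1], x0)          # last entry is the starting noise
```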

Hope this helps :slightly_smiling_face:

francesco-vaselli commented 8 months ago

Dear Imahn, Thank you so much for the help!

From what I understand, your approach:

  1. Starts from some noise $u$;
  2. Generates some synthetic data $x$ evolving $u$ along the trajectories of the ODE;
  3. Gets back to the starting noise $u$ by inverting the trajectories.

However, what I would like to do is to:

  1. Start from some novel, real data $x ^\ast$ coming from the same process as the training one;
  2. Find the trajectory for taking this data into the noise $u^\ast$.

Do you know if something like this is possible? Thanks again for taking the time to help, Best Francesco

ImahnShekhzadeh commented 8 months ago

Dear Francesco, You are right. :slightly_smiling_face: I was at first unsure how to do what you asked, but after looking into the function torchdiffeq.odeint() again, I would do it like this:

  1. First train your CNF (continuous normalizing flow) with the CFM (conditional flow matching) objective, e.g. as in cells 3 and 4 of this NB.

  2. Then sample an $x_1 \sim q(x_1)$ (or a batch of them), i.e. samples from the real target distribution, say MNIST images of shape (100, 1, 28, 28), where the first dimension is just the batch dimension, and then do trajectory inference like this:

    traj = torchdiffeq.odeint(
        lambda t, x: model.forward(t, x, generated_class_list),
        mnist_images,  # Tensor of shape `(100, 1, 28, 28)`, e.g.
        torch.linspace(1, 0, 2).to(device),  # reversed time grid: t=1 -> t=0
        atol=1e-4,
        rtol=1e-4,
        method="dopri5",
    )
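To illustrate why the reversed time grid does what you want, here is a self-contained round-trip sketch. It uses NumPy and a hand-rolled RK4 step as a stand-in for torchdiffeq.odeint, with a toy velocity field `vf` playing the role of the trained model (all names here are illustrative, not the library's API):

```python
import numpy as np

def rk4(x, t_grid, vf):
    """Integrate dx/dt = vf(t, x) along t_grid with classic RK4 steps."""
    for t0, t1 in zip(t_grid[:-1], t_grid[1:]):
        h = t1 - t0
        k1 = vf(t0, x)
        k2 = vf(t0 + h / 2, x + h / 2 * k1)
        k3 = vf(t0 + h / 2, x + h / 2 * k2)
        k4 = vf(t1, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x

vf = lambda t, x: -x                 # toy velocity field
x0 = np.array([0.3, -1.2, 2.0])      # the "noise" sample u
t_fwd = np.linspace(0.0, 1.0, 11)    # base -> data, as in linspace(0, 1, ...)
x1 = rk4(x0, t_fwd, vf)              # the "generated data"
u = rk4(x1, t_fwd[::-1], vf)         # data -> base: same field, reversed time
assert np.allclose(u, x0)            # reversed integration recovers the noise
```

The key point is that the second solve starts from the data and traverses the same vector field with time running from 1 to 0, which (up to solver error) inverts the forward map, exactly as passing torch.linspace(1, 0, 2) to torchdiffeq.odeint does above.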

Hope this helps.

atong01 commented 8 months ago

Yep, this is the solution I would use too.

Great answer @ImahnShekhzadeh.

kilianFatras commented 8 months ago

Closing this issue as it has been solved.