Closed: francesco-vaselli closed this issue 8 months ago.
Hi,
For example, in the notebook `conditional_mnist.ipynb`, this code sample generates samples $x_1 \sim q(x_1)$, where $q(x_1)$ is the target distribution, starting from $x_0 \sim p(x_0)$, where $p(x_0)$ is the base distribution:
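```python
import torch
import torchdiffeq

# `model`, `generated_class_list` and `device` are assumed to be defined as in the notebook.
traj = torchdiffeq.odeint(
    lambda t, x: model.forward(t, x, generated_class_list),
    torch.randn(100, 1, 28, 28).to(device),  # x_0 ~ p(x_0): a batch of Gaussian noise
    torch.linspace(0, 1, 2).to(device),  # integrate from t = 0 (base) to t = 1 (target)
    atol=1e-4,
    rtol=1e-4,
    method="dopri5",
)
```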
In this case, `traj` is a tensor of shape `(2, 100, 1, 28, 28)`: the leading 2 is the number of time points in `torch.linspace(0, 1, 2)`, and the remaining dimensions match the shape we specified for the noise. So by reversing the order of `traj` along the first dimension, you get the trajectory from the target to the base distribution:
```python
traj_target_to_base = traj.flip(dims=[0])
```
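To make the time-first layout and the effect of `traj.flip(dims=[0])` concrete without a trained model, here is a minimal pure-Python sketch. The explicit-Euler integrator `odeint_euler` and the toy field `v` are hypothetical stand-ins for `torchdiffeq.odeint` and the trained model; like `odeint`, the sketch returns one state per requested time point, so the first index is time.

```python
from typing import Callable, List

Vector = List[float]
Field = Callable[[float, Vector], Vector]


def odeint_euler(f: Field, y0: Vector, ts: List[float], substeps: int = 1000) -> List[Vector]:
    """Hypothetical stand-in for torchdiffeq.odeint (explicit Euler).

    Like odeint, it returns one state per entry of ts, so index 0 is time."""
    y = list(y0)
    traj = [list(y)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = (t1 - t0) / substeps
        for k in range(substeps):
            t = t0 + k * h
            y = [yi + h * fi for yi, fi in zip(y, f(t, y))]
        traj.append(list(y))
    return traj


def v(t: float, x: Vector) -> Vector:
    # Toy constant field standing in for the trained model:
    # it shifts every coordinate by 2 over the interval [0, 1].
    return [2.0 for _ in x]


x0 = [0.5, -1.0, 2.0]  # a "noise" sample at t = 0
traj = odeint_euler(v, x0, [0.0, 0.5, 1.0])  # 3 time points, so len(traj) == 3
traj_target_to_base = traj[::-1]  # list analogue of traj.flip(dims=[0])
```

Note that reversing only reorders the stored states of a trajectory that started from noise; it does not map a new data point to noise.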
Hope this helps :slightly_smiling_face:
Dear Imahn, thank you so much for the help!
From what I understand, your approach:
However, what I would like to do is to:
Do you know if something like this is possible? Thanks again for taking the time to help. Best, Francesco
Dear Francesco,
You are right. :slightly_smiling_face: At first I was unsure how to do what you asked for, but having looked at the function `torchdiffeq.odeint()` again, I would do it like this:
First, train your CNF (continuous normalizing flow) with the CFM (conditional flow matching) objective, e.g. as in cells 3 and 4 of this notebook.
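For reference, the CFM objective can be sketched as a Monte Carlo loss. This is a minimal pure-Python sketch, not the notebook's code: it assumes the straight-line path $x_t = (1 - t)\,x_0 + t\,x_1$ with target velocity $u_t = x_1 - x_0$ and no added noise ($\sigma = 0$), and it omits the class conditioning the notebook's model uses. `cfm_loss` is a hypothetical helper name. In real training $x_0$ is drawn independently of $x_1$; the toy check below uses a deterministic pairing only so the exact minimizer is known.

```python
import random
from typing import Callable, List, Tuple

Vector = List[float]
Model = Callable[[float, Vector], Vector]


def cfm_loss(model: Model, pairs: List[Tuple[Vector, Vector]], rng: random.Random) -> float:
    """Monte Carlo estimate of the CFM loss for the straight-line path
    x_t = (1 - t) x0 + t x1 with target velocity u_t = x1 - x0 (sigma = 0)."""
    total = 0.0
    for x0, x1 in pairs:  # x0 ~ p(x0) (noise), x1 ~ q(x1) (data)
        t = rng.random()  # t ~ U(0, 1)
        xt = [(1 - t) * a + t * b for a, b in zip(x0, x1)]
        ut = [b - a for a, b in zip(x0, x1)]
        vt = model(t, xt)
        total += sum((p - q) ** 2 for p, q in zip(vt, ut))
    return total / len(pairs)


# Toy check: if every data point is its noise point shifted by (1, -2),
# the constant field (1, -2) is the exact minimizer and the loss vanishes.
rng = random.Random(0)
pairs = []
for _ in range(64):
    x0 = [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)]
    pairs.append((x0, [x0[0] + 1.0, x0[1] - 2.0]))

loss_exact = cfm_loss(lambda t, x: [1.0, -2.0], pairs, rng)
loss_zero = cfm_loss(lambda t, x: [0.0, 0.0], pairs, rng)
```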
Then sample an $x_1 \sim q(x_1)$ (or a batch of them), i.e. samples from the real target distribution, say MNIST images of shape `(100, 1, 28, 28)`, where the first dimension is the batch dimension, and do trajectory inference by integrating backward in time from $t = 1$ to $t = 0$:
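```python
import torch
import torchdiffeq

# `mnist_images`: real MNIST images x_1 ~ q(x_1), e.g. a tensor of shape (100, 1, 28, 28).
# For this conditional model, `generated_class_list` should be the labels of these images (assumed here).
traj = torchdiffeq.odeint(
    lambda t, x: model.forward(t, x, generated_class_list),
    mnist_images.to(device),
    torch.linspace(1, 0, 2).to(device),  # decreasing times: from data (t = 1) back to noise (t = 0)
    atol=1e-4,
    rtol=1e-4,
    method="dopri5",
)
noise = traj[-1]  # the images mapped to the base distribution, i.e. u = f(x)
```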
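The idea that integrating the learned field backward inverts sampling can be checked on a toy problem. This minimal pure-Python sketch uses a hypothetical fixed-step RK4 integrator `odeint_rk4` in place of `torchdiffeq.odeint` and a hypothetical smooth field `v` in place of the trained model. It maps a "data" point from $t = 1$ to $t = 0$ (the analogue of $u = f(x)$), then integrates forward again and recovers the original point up to integration error.

```python
import math
from typing import Callable, List

Vector = List[float]
Field = Callable[[float, Vector], Vector]


def odeint_rk4(f: Field, y0: Vector, ts: List[float], substeps: int = 200) -> List[Vector]:
    """Hypothetical fixed-step RK4 stand-in for torchdiffeq.odeint.

    Returns one state per entry of ts; ts may be decreasing (then h < 0)."""
    y = list(y0)
    traj = [list(y)]
    for t0, t1 in zip(ts[:-1], ts[1:]):
        h = (t1 - t0) / substeps
        for k in range(substeps):
            t = t0 + k * h
            k1 = f(t, y)
            k2 = f(t + h / 2, [yi + h / 2 * a for yi, a in zip(y, k1)])
            k3 = f(t + h / 2, [yi + h / 2 * a for yi, a in zip(y, k2)])
            k4 = f(t + h, [yi + h * a for yi, a in zip(y, k3)])
            y = [
                yi + h / 6 * (a + 2 * b + 2 * c + d)
                for yi, a, b, c, d in zip(y, k1, k2, k3, k4)
            ]
        traj.append(list(y))
    return traj


def v(t: float, x: Vector) -> Vector:
    # Toy nonlinear field standing in for model.forward(t, x, labels).
    return [math.sin(xi) + t for xi in x]


x1 = [0.3, -1.2, 2.5]  # a "data" sample at t = 1
encoded = odeint_rk4(v, x1, [1.0, 0.0])[-1]  # data -> base, analogue of u = f(x)
decoded = odeint_rk4(v, encoded, [0.0, 1.0])[-1]  # base -> data, analogue of x = f^{-1}(u)
```

The round trip is exact only up to solver error, which is why the real code passes `atol`/`rtol` to an adaptive solver such as `dopri5`.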
Hope this helps.
Yep, this is the solution I would use too.
Great answer @ImahnShekhzadeh.
Closing this issue as it has been solved.
Dear all, Thanks for the great package!
I am writing to ask for guidance on a question I have. With discrete flows, during training we learn a transformation from data space to, say, Gaussian space ( $u = f(x)$ ) and then invert it to transform from Gaussian to data ( $x = f^{-1}(u)$ ). We can always take new data to the Gaussian space by using $u = f(x)$ .
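For comparison, a discrete flow layer is invertible in closed form, so $u = f(x)$ and $x = f^{-1}(u)$ are both direct function evaluations. The minimal sketch below uses a two-dimensional RealNVP-style affine coupling layer as an example of such a layer; `scale` and `shift` are hypothetical placeholders for learned networks.

```python
import math
from typing import List

Vector = List[float]


def scale(a: float) -> float:
    # Hypothetical conditioner network; any function of x_a keeps the layer invertible.
    return math.tanh(a)


def shift(a: float) -> float:
    # Hypothetical conditioner network for the additive part.
    return 0.5 * a ** 2


def f(x: Vector) -> Vector:
    """Data -> latent, u = f(x): x_a passes through, x_b is scaled and shifted."""
    xa, xb = x
    return [xa, xb * math.exp(scale(xa)) + shift(xa)]


def f_inv(u: Vector) -> Vector:
    """Latent -> data, x = f^{-1}(u): exact because u_a = x_a is available."""
    ua, ub = u
    return [ua, (ub - shift(ua)) * math.exp(-scale(ua))]


x_new = [0.7, -1.3]
u = f(x_new)
x_back = f_inv(u)
```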
In the case of this package, it is clear how to use the models to start from the (Gaussian) noise space: draw the initial conditions of the ODE there and solve it with torchdiffeq or torchdyn. The question is how I can use the model to compute the reverse trajectory, i.e. go from new data back to the (Gaussian) noise space. How do we compute $u = f(x)$ ?
Do I have to reverse the order of the time steps (from 1 to 0) and use the data as the initial conditions of the ODE?
I would be really grateful for any guidance or code example you could provide.
Best regards, Francesco