atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Any pointers on custom data training? #22

Closed: jbmaxwell closed this issue 1 year ago

jbmaxwell commented 1 year ago

Hello and thanks for this super interesting work!

I'm curious to try this with custom, image-like data, and just wondering what steps are required to get something going. Any guidance would be awesome.

atong01 commented 1 year ago

Hi @jbmaxwell,

We are planning on updating the code late next week with image experiments. If you need something sooner feel free to email me and we can work something out.

Best, Alex

jbmaxwell commented 1 year ago

Ah, good to hear! No huge rush as there's always plenty to try and plenty to do...

If you could post an update here when it's ready that would be awesome (so I get pinged when it's available).

Thanks!

maulberto3 commented 1 year ago

Hi @jbmaxwell, the following can get you started with CNFs on MNIST. Note that this code is very rough: a lot of enhancements can be applied, and it may even have errors. As you can see, it's an adaptation of the provided notebook examples. Maybe @atong01 can shed some light on it. Let me know if it works for you. EDIT: I made some tweaks to the code and it starts generating digits.

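# MNIST example
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import ot as pot
import torch
import torch.optim as optim
import torchvision.transforms as transforms
from torchdyn.core import NeuralODE
from torchvision import datasets

# MLP and torch_wrapper are taken from the repo's moons notebook (full definitions appear later in this thread)
savedir = "mnist"
os.makedirs(savedir, exist_ok=True)
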
# Load MNIST dataset
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.,), (.2,))])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the model
sigma = 0.1
dim = 28 * 28  # MNIST image size
model = MLP(dim=dim, w=64, time_varying=True)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
lr = 0.001
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
start = time.time()
a, b = pot.unif(batch_size), pot.unif(batch_size)
for k in range(20000):
    optimizer.zero_grad()

    t = torch.rand(batch_size, 1)

    x0 = torch.randn(batch_size, dim)
    x1, _ = next(iter(train_loader))

    x0 = x0.view(x0.size(0), -1)  # Flatten the input images
    x1 = x1.view(x1.size(0), -1)

    # Resample x0, x1 according to transport matrix
    M = torch.cdist(x0, x1) ** 2
    M = M / M.max()
    pi = pot.emd(a, b, M.detach().cpu().numpy())

    # Sample random interpolations on pi
    p = pi.flatten()
    p = p / p.sum()
    choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size)
    i, j = np.divmod(choices, pi.shape[1])
    x0 = x0[i]
    x1 = x1[j]

    ut = (x1 - x0)**2  # NOTE: squared target, discussed below; the standard CFM target is x1 - x0

    mu_t = t * x1 + (1 - t) * x0
    sigma_t = sigma
    x = mu_t + sigma_t * torch.randn(batch_size, dim)
    vt = model(torch.cat([x, t], dim=-1))

    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()

    if k % 500 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(torch_wrapper(model), 
                         solver="dopri5", 
                         sensitivity="adjoint", 
                         atol=1e-4, 
                         rtol=1e-4)
        with torch.no_grad():
            num_samples = 5
            traj = node.trajectory(
                torch.randn(num_samples, dim),
                t_span=torch.linspace(0, 1, 50),
            )
            image = traj[-1, :, :]  # [t, num_samples, dim]
            image = image.reshape(num_samples, 28, 28)
            fig, axs = plt.subplots(1, 5, constrained_layout=True, sharex=True, sharey=True, num=16)
            for i, ax in zip(range(num_samples), axs):
                ax.imshow(image[i, :, :])
            plt.show()
torch.save(model, f"{savedir}/cfm_v1.pt")

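
A possible refinement of the OT step above, sketched under assumptions. With uniform weights a, b and equal batch sizes, an optimal plan for pot.emd exists that is a permutation matrix scaled by 1/batch_size, while np.random.choice draws with replacement, so some pairs can repeat and others get dropped. If you want every sample used exactly once, the same squared-Euclidean matching can be computed as an assignment problem with SciPy's linear_sum_assignment (a swapped-in solver; ot_pairing_indices is a hypothetical helper name):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def ot_pairing_indices(x0, x1):
    """Exact minibatch OT pairing under squared Euclidean cost.

    x0, x1: arrays of shape [batch, dim]. Returns index arrays (i, j) such that
    (x0[i[k]], x1[j[k]]) are matched pairs and every row is used exactly once.
    """
    # Pairwise squared distances, shape [batch, batch]. Dividing by the max,
    # as the example above does, would not change the optimal matching.
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)
    i, j = linear_sum_assignment(cost)
    return i, j
```

In the training loop this would replace the pi sampling lines with something like i, j = ot_pairing_indices(x0.numpy(), x1.numpy()) followed by x0, x1 = x0[i], x1[j].
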
atong01 commented 1 year ago

Thanks @maulberto3!

I'm a bit surprised by the line ut = (x1 - x0)**2; otherwise it looks similar to what we have. I would have expected no square.

I think for high-quality generation you need a bigger network (we use a UNet), but glad it's working!
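
For comparison, a minimal sketch of the conditional path and target without the square: x_t = t * x1 + (1 - t) * x0 + sigma * eps and u_t = x1 - x0, matching the mu_t and sigma_t lines of the example above. cfm_sample is a hypothetical helper, not the torchcfm API; it keeps t broadcastable so the same code works for flattened [batch, dim] tensors and for image-shaped [batch, C, H, W] tensors (e.g. with a UNet):

```python
import torch


def cfm_sample(x0, x1, sigma=0.1):
    """Sample t, x_t and the regression target u_t for conditional flow matching.

    x0: prior samples, x1: data samples, same shape [batch, ...].
    """
    batch = x0.shape[0]
    # One time per example, shaped [batch, 1, ...] so it broadcasts over the rest
    t = torch.rand(batch, *([1] * (x0.dim() - 1)))
    mu_t = t * x1 + (1 - t) * x0               # straight-line interpolation
    xt = mu_t + sigma * torch.randn_like(x0)   # Gaussian smoothing around the path
    ut = x1 - x0                               # velocity of the straight path, no square
    return t, xt, ut
```

For the flattened MLP case the loss would then be torch.mean((model(torch.cat([xt, t], dim=-1)) - ut) ** 2).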

jbmaxwell commented 1 year ago

Thanks for the example, @maulberto3! You mention "what we have", @atong01, suggesting perhaps you have a solution in the works. Will you be adding something to the repo soon?

atong01 commented 1 year ago

Yes, ideally today or tomorrow at least as a branch.

maulberto3 commented 1 year ago

Hello again. Oh yeah, sorry about that: yesterday I was doing a lot of trial and error (lots of tweaks, plotting results), and I had a hunch. I remember that absolute distances are sometimes hard for nets to follow, so I tried squaring them, and to my surprise it worked: within the early steps, the network started drawing digits.

Also, I noticed that results are very sensitive to the transforms.Normalize parameters. I guess that since the network is trying to learn dynamics, the location of the data distribution relative to the prior distribution is key. For example, with (1.0), (.5) it always draws static (no digits), but with (0.,), (.2,) it draws digits.
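
To make the location argument concrete, here is a small calculation of where normalized MNIST pixels sit relative to the N(0, 1) prior. It assumes the commonly quoted MNIST pixel statistics (mean 0.1307, std 0.3081 on the [0, 1] scale), that Normalize((m,), (s,)) computes (x - m) / s, and it reads "(1.0), (.5)" as Normalize((1.0,), (.5,)):

```python
# Commonly quoted MNIST pixel statistics on the [0, 1] scale (assumed values)
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081


def normalized_stats(mean, std):
    """Mean and std of MNIST pixels after transforms.Normalize((mean,), (std,)),
    which maps each pixel x to (x - mean) / std."""
    return (MNIST_MEAN - mean) / std, MNIST_STD / std


for m, s in [(0.0, 0.2), (1.0, 0.5)]:
    mu, sd = normalized_stats(m, s)
    print(f"Normalize(({m},), ({s},)): data mean {mu:.2f}, std {sd:.2f} (prior: mean 0, std 1)")
```

Under these assumptions the setting that drew digits puts the data at mean of about 0.65 with std about 1.54, while (1.0, .5) shifts it to about -1.74 with std about 0.62. These are only summary statistics and do not by themselves explain the difference.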

maulberto3 commented 1 year ago

Just to add some visualizations to what I said:

With ut = (x1 - x0)**2: [image]

With ut = x1 - x0: [image]

jbmaxwell commented 1 year ago

Hi @maulberto3,

I've tried to reproduce your example, just pulling in the imports and also saving PIL images (just simpler, since my machine is remote), but although my losses look correct, my images are just noise. Probably some silly error...

# MNIST example
import numpy as np
import torch
import torchvision.transforms as transforms
from torchvision import datasets
import torch.optim as optim
import time
import ot as pot
from torchdyn.core import NeuralODE
from PIL import Image

class torch_wrapper(torch.nn.Module):
    """Wraps model to torchdyn compatible format."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, t, x, args=None):
        return self.model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))

class MLP(torch.nn.Module):
    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__()
        self.time_varying = time_varying
        if out_dim is None:
            out_dim = dim
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim + (1 if time_varying else 0), w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, w),
            torch.nn.SELU(),
            torch.nn.Linear(w, out_dim),
        )

    def forward(self, x):
        return self.net(x)

to_image = transforms.ToPILImage()

savedir = 'mnist'

# Load MNIST dataset
batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.,), (.2,))])
train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize the model
sigma = 0.1
dim = 28 * 28  # MNIST image size
model = MLP(dim=dim, w=64, time_varying=True)
lr = 0.001
optimizer = optim.Adam(model.parameters(), lr=lr)

# Training loop
start = time.time()
a, b = pot.unif(batch_size), pot.unif(batch_size)
for k in range(20000):
    optimizer.zero_grad()

    t = torch.rand(batch_size, 1)

    x0 = torch.randn(batch_size, dim)
    x1, _ = next(iter(train_loader))

    x0 = x0.view(x0.size(0), -1)  # Flatten the input images
    x1 = x1.view(x1.size(0), -1)

    # Resample x0, x1 according to transport matrix
    M = torch.cdist(x0, x1) ** 2
    M = M / M.max()
    pi = pot.emd(a, b, M.detach().cpu().numpy())

    # Sample random interpolations on pi
    p = pi.flatten()
    p = p / p.sum()
    choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size)
    i, j = np.divmod(choices, pi.shape[1])
    x0 = x0[i]
    x1 = x1[j]

    ut = (x1 - x0)**2

    mu_t = t * x1 + (1 - t) * x0
    sigma_t = sigma
    x = mu_t + sigma_t * torch.randn(batch_size, dim)
    vt = model(torch.cat([x, t], dim=-1))

    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()

    if k % 500 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(torch_wrapper(model), 
                         solver="dopri5", 
                         sensitivity="adjoint", 
                         atol=1e-4, 
                         rtol=1e-4)
        with torch.no_grad():
            num_samples = 5
            traj = node.trajectory(
                torch.randn(num_samples, dim),
                t_span=torch.linspace(0, 1, 50),
            )
            image = traj[-1, :, :]  # [t, num_samples, dim]
            image = image.reshape(num_samples, 28, 28)
            new_im = Image.new('L', (num_samples * 28, 28))
            for n in range(num_samples):
                img_t = image[n]
                img = to_image(img_t)
                img.save(f'{savedir}/{n}.png')
                x_pos = n * 28
                new_im.paste(img, (x_pos, 0))
            new_im.save(f'{savedir}/ex_{k}.png')
torch.save(model, f"{savedir}/cfm_v1.pt")

Output:

1: loss 66.825 time 3.79
501: loss 36.866 time 11.01
1001: loss 34.911 time 11.72
1501: loss 33.202 time 10.88
[...]
12501: loss 25.530 time 10.71
13001: loss 25.652 time 10.91
13501: loss 24.351 time 10.71
14001: loss 25.944 time 10.79
14501: loss 26.947 time 10.77
15001: loss 25.749 time 10.42
[...]

But the resulting image (ex_19500) is just noise.

Maybe a wrong import? I took the torch_wrapper and the MLP from the (moons) notebook. It seems like the sampling step that generates the image is going wrong.

maulberto3 commented 1 year ago

Hi @jbmaxwell, I can't say; nets are always tricky to train. You can try replacing ut = (x1 - x0)**2 with ut = x1 - x0 to follow the original implementation more closely. Also, I didn't use PIL's Image, just plain plt.imshow. You could also check img_t = image[n], just in case it isn't slicing correctly.

jbmaxwell commented 1 year ago

Okay, yeah, I just find it easier to use PIL on a remote/headless server... The model is clearly converging, so it must be an issue with sampling. All the shapes look fine, and I get the same when just saving image[0] to a PIL.Image. Weird. The torch_wrapper is correct, yes? You were referring to the one in the notebook?

UPDATE: Okay, weird... I guess I don't understand pyplot! :) Just using this works (to save an image rather than messing around with remote plotting):

fig, axs = plt.subplots(1, num_samples, constrained_layout=True, sharex=True, sharey=True)
for i, ax in zip(range(num_samples), axs):
    ax.imshow(image[i, :, :])
fig.savefig(f'{savedir}/epoch_{k}.png')
plt.close(fig)  # free the figure; a new one is created at every sampling step

UPDATE 2: Just missed normalizing the tensor before converting to PIL... gulp.
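
For anyone hitting the same issue, a minimal sketch of the missing step: undo the Normalize transform, clip to [0, 1], and convert to 8-bit before building the PIL image. As far as I know, ToPILImage scales float tensors by 255 and casts to uint8 without clipping, so values outside [0, 1] (up to 5.0 after Normalize((0.,), (.2,))) do not map to sensible gray levels. to_pil_grayscale is a hypothetical helper; mean and std must match the Normalize used for training:

```python
import torch
from PIL import Image


def to_pil_grayscale(sample, mean=0.0, std=0.2, size=28):
    """Convert one generated sample (flattened or [size, size]) to an 8-bit PIL image."""
    img = sample.detach().cpu().reshape(size, size)
    img = img * std + mean        # invert Normalize: x = x_norm * std + mean
    img = img.clamp(0.0, 1.0)     # ODE samples can overshoot the valid pixel range
    arr = (img * 255).round().to(torch.uint8).numpy()
    return Image.fromarray(arr)   # a 2-D uint8 array becomes a mode "L" image
```

In the sampling loop above, img = to_pil_grayscale(image[n]) would replace img = to_image(img_t).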

maulberto3 commented 1 year ago

Update 2 reminds me of how I currently think about (C)FM models in general. If I'm not mistaken (@atong01 can freely correct or complement me), the point is to make the net learn many different prior-to-target interpolations over time and map them to the overall prior-to-target transformation. Extra sauce can be things like OT helpers (how to transport between distributions with minimum cost). In other words, the net learns the dynamics going from prior to target without specifying the distribution transformations explicitly (as Normalizing Flows do with Planar, Radial, and similar layers). All this to say that I've seen differences in training depending on where the prior distribution starts relative to where it needs to go (the target distribution), and you can see it in the moons example. Anyway, these types of models interest me as much as NFs did. The loss here is MSE, but I wonder whether we could apply a similar score-based loss, given its reported good properties for training.
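
The "learn the dynamics" view above also describes sampling: draw x0 from the prior and integrate the learned velocity field from t = 0 to t = 1. The examples in this thread use torchdyn's NeuralODE with dopri5; as a swapped-in, simpler alternative, a fixed-step Euler integrator could look like this (euler_sample is a hypothetical helper; velocity(t, x) follows the torch_wrapper calling convention):

```python
import torch


@torch.no_grad()
def euler_sample(velocity, x0, steps=100):
    """Integrate dx/dt = velocity(t, x) from t=0 (prior) to t=1 (data)
    with fixed-step explicit Euler.

    velocity(t, x): t is a scalar tensor, x has shape [batch, dim].
    """
    x = x0.clone()
    dt = 1.0 / steps
    for k in range(steps):
        t = torch.tensor(k * dt)        # current time as a 0-d tensor
        x = x + dt * velocity(t, x)     # one Euler step along the learned field
    return x
```

Usage would be something like samples = euler_sample(torch_wrapper(model), torch.randn(5, 28 * 28), steps=100). Fewer steps are cheaper but less accurate than an adaptive solver such as dopri5.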

jbmaxwell commented 1 year ago

I tried to adapt this to a bigger input using a UNet with residual blocks (instead of the basic MLP), but I'm hitting problems with the dimension mismatch caused by the torch.cat([x, t]) used for the time-varying option (basically, the UNet's up/down tensors differ by that single dimension). Any thoughts on the best fix?

maulberto3 commented 1 year ago

You can always make your net work for you. If dims don't match at some point in the net, use a trick: apply a simple dense layer to the "problematic" input, then the required layer. A trivial example to illustrate: if a layer expects two inputs of exactly 15 dimensions each, but the second input has 16 (because of some earlier layer), pass the 16-dim input through a dense layer that outputs 15 dims; now the required layer works as expected.
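
A minimal sketch of that trick for the UNet case: keep the torch.cat([x, t]) input, project the dim + 1 features back to dim with one Linear layer, and reshape to an image before the convolutional backbone. TimeConcatAdapter is a hypothetical name, and any module mapping [batch, C, H, W] to the same shape can serve as the backbone:

```python
import torch
import torch.nn as nn


class TimeConcatAdapter(nn.Module):
    """Map the flattened [x, t] input (dim + 1 features) back to dim features,
    reshape to an image, run a conv backbone, and flatten the output again."""

    def __init__(self, backbone, channels=1, height=28, width=28):
        super().__init__()
        self.shape = (channels, height, width)
        dim = channels * height * width
        self.proj = nn.Linear(dim + 1, dim)   # absorbs the extra time feature
        self.backbone = backbone

    def forward(self, xt):
        # xt: [batch, dim + 1], as built by torch.cat([x, t], dim=-1)
        h = self.proj(xt).view(-1, *self.shape)
        out = self.backbone(h)
        return out.flatten(1)                 # [batch, dim], same shape as ut
```

The projection is a dense (dim + 1) to dim layer (about 616k weights for 28 x 28), which gets expensive for larger inputs; a common alternative in diffusion-style UNets is to inject t inside the network, for example as a constant extra input channel or through a time embedding.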

On the other hand, I've found that with transforms.Normalize((0.015,), (0.015,)), digits start to be generated, this time using the expected ut = (x1 - x0) and not the (trivially) altered one above. [image]

jbmaxwell commented 1 year ago

Thanks for the tip, @maulberto3. That's actually exactly what I wound up doing: adding a linear layer on the input to adapt the shape. That's interesting about the normalization. I was also curious about adding conditioning info. For training this is obviously pretty simple, but I wasn't sure how to integrate it into the generation/sampling routine (i.e., through the NeuralODE and torch_wrapper).
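
One way to carry conditioning information into sampling, sketched under assumptions: the ODE solver only integrates x, so the condition y (e.g. a one-hot class label) can be held fixed inside the wrapper and concatenated at every function evaluation, in the same order used during training (here [x, y, t]). conditional_torch_wrapper is a hypothetical variant of the torch_wrapper shown earlier in this thread, not a library class:

```python
import torch


class conditional_torch_wrapper(torch.nn.Module):
    """torch_wrapper variant that also feeds a fixed condition y to the model.
    y has shape [batch, cond_dim] and stays constant along the trajectory."""

    def __init__(self, model, y):
        super().__init__()
        self.model = model
        self.y = y

    def forward(self, t, x, args=None):
        tt = t.repeat(x.shape[0])[:, None]                # [batch, 1] time column
        return self.model(torch.cat([x, self.y, tt], 1))  # same order as training
```

Training would then use vt = model(torch.cat([x, y, t], dim=-1)) with MLP(dim=dim + num_classes, out_dim=dim, w=64, time_varying=True), and sampling would pass conditional_torch_wrapper(model, y) to NeuralODE in place of torch_wrapper(model), with one row of y per sample.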

atong01 commented 1 year ago

Hi all,

Thanks for your patience, I've now updated the code to version 1 (Old code is frozen in branch v0). See examples/mnist_example.ipynb for our MNIST training notebook. There is also code in runner to train and test on CIFAR10. Let me know how it works for you!

@maulberto3 this is an interesting finding about the (x1 - x0)^2 and the setting of the initial distribution noise relative to the target. I'm very curious about this. I've heard that a similar trick can boost performance in diffusion models (i.e. increasing the source noise as compared to the data improves FID). Cool to see that it also seems to apply to flow models. I have my suspicions as to why this works, but I'm not certain.

maulberto3 commented 1 year ago

@atong01 My other thought is that ut = (x1 - x0) ** 2 strips out the directions in which the prior distribution has to move to match the target distribution. For example, if pixel 1 differs by .1 and pixel 2 by -.2 between the data sample and the prior sample (this is just how a ut example would behave for a pair of pixels), squaring them before feeding them into the loss gives .01 and .04. The result: no directions (no signs). That could make the net learn dynamics other than the desired ones. A possible solution: add the directions back, i.e. .01 and -.04? Just some brainstorming.
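
The arithmetic of that brainstorm as a tiny pure-Python sketch (squared and signed_square are hypothetical names). Note that a sign-preserving square is still not the straight-line velocity: integrating a constant velocity u from t = 0 to t = 1 moves x0 to x0 + u, which lands on x1 only when u = x1 - x0.

```python
def squared(d):
    """Elementwise square: drops the sign of each difference."""
    return [v * v for v in d]


def signed_square(d):
    """Sign-preserving square: sign(v) * v**2, written as v * |v|."""
    return [v * abs(v) for v in d]


d = [0.1, -0.2]  # per-pixel differences x1 - x0 from the example above
print([round(v, 4) for v in squared(d)])        # [0.01, 0.04]
print([round(v, 4) for v in signed_square(d)])  # [0.01, -0.04]
```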

atong01 commented 1 year ago

Yeah, I don't understand this either. Maybe it's a learning thing? If there's a lot of noise, then this says to focus more on the "large" directions. Maybe this helps early on? It seems very weird that this works haha. I think it's possible that all the x1s are larger than the x0s, and if so, maybe the sign is not a problem? Although your finding that a smaller standard deviation of the data works better seems to contradict this.

maulberto3 commented 1 year ago

Also @atong01, I think the target distribution should not be rescaled to such large values as I did, even though it produced digits. For example, transforms.Normalize((0.015,), (0.015,)) maps pixel values from the [0, 1] range to roughly [-1, 66]. If the net's parameters grow to reach outputs that large, the SELU nonlinearity becomes nearly useless (it is linear for positive inputs), and we end up with an essentially linear combination inside the net...
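
A quick check of those ranges, assuming pixel values in [0, 1] after ToTensor and that Normalize((m,), (s,)) computes (x - m) / s; normalized_range is a hypothetical helper:

```python
def normalized_range(mean, std, lo=0.0, hi=1.0):
    """Output range of transforms.Normalize((mean,), (std,)) for inputs in [lo, hi]."""
    return (lo - mean) / std, (hi - mean) / std


for m, s in [(0.0, 0.2), (1.0, 0.5), (0.015, 0.015)]:
    a, b = normalized_range(m, s)
    print(f"Normalize(({m},), ({s},)): [{a:.2f}, {b:.2f}]")
```

The last setting spans about [-1, 65.7], compared with [0, 5] for Normalize((0.,), (.2,)) and [-2, 0] for Normalize((1.0,), (.5,)).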

atong01 commented 1 year ago

This makes sense. I was misremembering the arguments for Normalize. Maybe this also explains why the squaring works so well: if the data is scaled that large, all the vectors are positive, and the square effectively exaggerates the initial learning rate.

kilianFatras commented 1 year ago

Hello, the library has seen huge improvements and a major refactoring. I invite you to check the examples folder, which contains several example notebooks. I am now closing this issue, as it has been solved by the 1.0+ releases.