atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Provide demos for conditional generation #30

Closed: bugsz closed this issue 1 year ago

bugsz commented 1 year ago

Hi, I noticed that you have provided a demo on unconditional generation on the MNIST dataset. Do you plan to provide some demos on conditional generation? That would be of great help to me. Thanks again for your valuable work and codebase!

gevmin94 commented 1 year ago

I have extended the provided MNIST example to class-conditional generation of specific digits. During inference, I ran into a challenge passing additional conditional inputs, specifically the digit labels, through NeuralODE from torchdyn. Instead, I used the odeint function from torchdiffeq in the following manner:

import torch
import torchdiffeq

# `model`, `digit_label`, and `device` are defined as in the MNIST example
l = torch.tensor([digit_label]).to(device)
out = torchdiffeq.odeint(
    lambda t, x: model.forward(t, x, l),  # close over the label so the field is (t, x) -> v
    torch.randn(1, 1, 28, 28).to(device),
    torch.linspace(0, 1, 2).to(device),
    atol=1e-4,
    rtol=1e-4,
    method='dopri5',
)
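
A possible way to keep using torchdyn instead (a sketch, not from the original thread) is to wrap the conditional model in a small adapter that stores the label, so the model matches the (t, x) vector-field interface NeuralODE expects. ConditionalWrapper is a hypothetical helper name, not part of torchcfm or torchdyn:

import torch
import torch.nn as nn
from torchdyn.core import NeuralODE

class ConditionalWrapper(nn.Module):
    """Hypothetical adapter: stores the label so torchdyn only sees (t, x)."""
    def __init__(self, model, label):
        super().__init__()
        self.model = model
        self.label = label

    def forward(self, t, x, *args, **kwargs):
        return self.model(t, x, self.label)

l = torch.tensor([digit_label]).to(device)
node = NeuralODE(
    ConditionalWrapper(model, l), solver="dopri5", sensitivity="adjoint", atol=1e-4, rtol=1e-4
)
traj = node.trajectory(
    torch.randn(1, 1, 28, 28).to(device),
    t_span=torch.linspace(0, 1, 2).to(device),
)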

Meta's recent Voicebox paper also mentions using torchdiffeq. The results obtained with the odeint function were quite similar to the unconditional case. However, I observed that with torchdiffeq the solver does not stop exactly at time 1.0 of the given t_span; it slightly exceeds the specified time, reaching 1.1, for example. When I set t_span to [0, 2.0], the solver overshoots significantly, reaching 11.1. This suggests that some of the solver's advanced parameters may need further exploration.
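
This overshoot appears to come from the adaptive solver itself: torchdiffeq's dopri5 may step past the last requested time and interpolate the solution back onto t_span, so the vector field can be queried at t > 1 even though the returned values are reported at the requested times. A small probe makes this visible (a sketch, assuming model, l, and device from the snippet above):

import torch
import torchdiffeq

seen_t = []  # times at which the solver queries the vector field

def probed_field(t, x):
    seen_t.append(float(t))
    return model(t, x, l)

_ = torchdiffeq.odeint(
    probed_field,
    torch.randn(1, 1, 28, 28).to(device),
    torch.linspace(0, 1, 2).to(device),
    atol=1e-4,
    rtol=1e-4,
    method='dopri5',
)
# max(seen_t) can exceed 1.0 even though the returned trajectory
# is reported exactly at t = 0 and t = 1
print(max(seen_t))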

Upon experimenting with different label-conditioning settings, I found that the model did not generate digits of the given class label; instead, it produced random digits, just like the unconditional case. In the original UNet implementation, there is an option to condition on label embeddings by summing them with the timestep embeddings. I tried both concatenation and summation for conditioning, and I also tried summing the label embeddings directly into the noise input after changing the embedding dimension to 28*28. I experimented with various hyperparameters such as the learning rate, increased the number of residual blocks, adjusted the model dimensions, and trained for additional epochs. In every setting, however, the model behaved as in the unconditional case, generating random digits and ignoring the class label. Even when I simplified the distribution to just two classes, [7, 8], the model still failed to generate the desired results.

I wonder if I missed something crucial or if there's a need to modify the loss computation or any other aspect of the model.

bugsz commented 1 year ago

Thanks for your detailed explanation! I will look into it and see if I can find something new!

atong01 commented 1 year ago

@gevmin94 This is very interesting! Are you willing to share this code? Would love to take a look. I'm busy with NeurIPS rebuttals through this week but will definitely be looking into this next week.

The model training needs to be modified a bit for the conditional generation to work, but this should not be too challenging.

I'm not super surprised that the solver slightly exceeds the specified time, given the adaptive solver used, but 11.1 seems quite a bit larger than I would expect.

gevmin94 commented 1 year ago

I have opened a PR with my current changes. Additionally, here is a summary of the changes.

Training Loop:

for epoch in range(n_epochs):
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        x1 = data[0].to(device)  # images
        y = data[1].to(device)   # class labels
        x0 = torch.randn_like(x1)  # samples from the noise source
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
        vt = model(t, xt, y)  # label-conditioned vector field
        loss = torch.mean((vt - ut) ** 2)
        loss.backward()
        optimizer.step()
        print(f'epoch: {epoch}, steps: {i}, loss: {loss.item():.4}', end='\r')

Inference:

# assumes: model and device from training, a torchdyn NeuralODE `node` wrapping
# the model (for the else branch), and make_grid / ToPILImage / plt imports
USE_TORCH_DIFFEQ = True
digit_label = 8
l = torch.tensor([digit_label] * 100).to(device)
with torch.no_grad():
    if USE_TORCH_DIFFEQ:
        traj = torchdiffeq.odeint(
            lambda t, x: model.forward(t, x, l),
            torch.randn(100, 1, 28, 28).to(device),
            torch.linspace(0, 1, 2).to(device),
            atol=1e-4,
            rtol=1e-4,
            method='dopri5',
        )
    else:
        traj = node.trajectory(
            torch.randn(100, 1, 28, 28).to(device),
            t_span=torch.linspace(0, 1, 2).to(device),
        )
grid = make_grid(
    traj[-1, :100].view([-1, 1, 28, 28]).clip(-1, 1), value_range=(-1, 1), padding=0, nrow=10
)
img = ToPILImage()(grid)
plt.imshow(img)
plt.show()

Is there anything else that needs to be modified in the training process besides passing labels to the model?

bugsz commented 1 year ago

I tried the same changes as @gevmin94 described above, but I notice that the result is still not that satisfying, and I'm not sure what is missing here.

Besides, I tried another task: completing the digit when part of the original image is masked. I concatenated the condition (the masked image) with the random noise along the channel dimension and applied a linear projection before going into the U-Net. This resembles Voicebox (though it is still quite different: Voicebox concatenates all of the hidden features before projecting, while I work on the raw image). Below is the modification.

On the model side

# in __init__: a per-pixel linear projection from 2*C channels back to C
C, H, W = dim
self.ctx_proj = nn.Parameter(torch.randn(H, W, C * 2, C))

# in forward:
if ctx is not None:
    assert x.shape == ctx.shape
    x = torch.cat((x, ctx), dim=1)  # [B, 2, 28, 28]
    # permute to [H, W, B, 2C], project per pixel, permute back to [B, C, H, W]
    x = torch.matmul(x.permute(2, 3, 0, 1), self.ctx_proj).permute(2, 3, 0, 1)  # [B, 1, 28, 28]

On the training and inference side, we only need to provide a conditioning image as ctx:

import numpy as np

for epoch in range(n_epochs):
    total_losses = []
    print(f"Training epoch {epoch}")
    for i, data in tqdm(enumerate(train_loader), total=len(train_loader)):
        optimizer.zero_grad()
        x1 = data[0].to(device)
        x0 = torch.randn_like(x1)

        cond = gen_cond(x1)  # masked version of x1, used as the condition
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)
        vt = model(t, xt, ctx=cond)
        loss = torch.mean((vt - ut) ** 2)
        loss.backward()

        total_losses.append(loss.item())
        optimizer.step()
    print(f"Loss: {np.mean(total_losses)}")

For preprocessing

def gen_cond(data):
    # zero out a rectangular region to create the masked conditioning image
    B, _, H, W = data.shape
    cond = data.clone()
    cond[:, :, H // 4: H // 2, W // 4: W // 2] = 0
    return cond

However, the output image is still a random one. Here is an example: the first two images are the raw image and the masked image, respectively, and the remaining images show the generation process. [image]

atong01 commented 1 year ago

This is great! @gevmin94, there is one subtle bug in the training loop. I'll make a change to fix it and post here when I have it working. Basically, when sampling, the OT sampler scrambles the rows of the xt batch relative to the y variables, so we need to permute the condition in the same way. Note that both of your PRs should work as-is with the ConditionalFlowMatcher instead of the ExactOptimalTransportConditionalFlowMatcher. I'm adding a way to permute the conditions accordingly and will post here when it's done.
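
For reference, a minimal sketch of the corrected loop, assuming the guided sampling variant that torchcfm added for this purpose (guided_sample_location_and_conditional_flow, which permutes the labels with the same minibatch OT plan); if your version lacks it, the plain ConditionalFlowMatcher avoids the issue entirely, since the independent coupling never permutes the batch:

from torchcfm.conditional_flow_matching import ExactOptimalTransportConditionalFlowMatcher

FM = ExactOptimalTransportConditionalFlowMatcher(sigma=0.0)
for epoch in range(n_epochs):
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        x1, y = data[0].to(device), data[1].to(device)
        x0 = torch.randn_like(x1)
        # y1 comes back permuted with the same OT plan applied to (x0, x1),
        # so each xt stays paired with its own label
        t, xt, ut, _, y1 = FM.guided_sample_location_and_conditional_flow(x0, x1, y1=y)
        vt = model(t, xt, y1)
        loss = torch.mean((vt - ut) ** 2)
        loss.backward()
        optimizer.step()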

I think what @bugsz is trying to do with image generation will also be easy with this modification.

atong01 commented 1 year ago

See updates to PR #33

gevmin94 commented 1 year ago

I now understand what was previously unclear. Thank you for the clarification and for resolving the issue. Before diving into the code, I should have read your impressive work thoroughly: I was not aware of the distinction between conditional OT and marginal OT, and without reading the paper I mistakenly believed OT-CFM was the same as Lipman's FM algorithm, which is why I never suspected that OT-CFM changes the order of minibatch items. The terminology around OT was perplexing to me.

Thanks again for your assistance; this saves me a lot of time.