atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Modelling conditional pdfs #100

Closed francesco-vaselli closed 5 months ago

francesco-vaselli commented 5 months ago

Dear all, thanks for the great package!

I am working with some data $x$ following an unknown pdf $p(x|c)$. Here, $c$ is some additional information that we would like to give as input to the model so that it learns the correct correlations and dependencies between the target $x$ and $c$. In this way, when transforming samples $u$ from the latent noise space, we get different $x$ depending on the input $c$ (simplest example: generate only one half of the two moons dataset based on a 0/1 input flag).
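The two-moons example can be made concrete with a small sketch, assuming PyTorch. `conditional_two_moons` is a hypothetical helper (not part of torchcfm) that follows the usual two-moons construction and returns the points $x$ together with a 0/1 context $c$ naming the moon each point belongs to:

```python
import math

import torch


def conditional_two_moons(n, noise=0.05, generator=None):
    """Hypothetical helper: return (x, c), 2-D two-moons points and a 0/1 moon flag."""
    c = torch.randint(0, 2, (n,), generator=generator)
    theta = torch.rand(n, generator=generator) * math.pi
    # Upper moon (c == 0): unit half-circle centred at the origin.
    upper = torch.stack([torch.cos(theta), torch.sin(theta)], dim=1)
    # Lower moon (c == 1): flipped half-circle shifted right and down.
    lower = torch.stack([1.0 - torch.cos(theta), 0.5 - torch.sin(theta)], dim=1)
    x = torch.where(c[:, None] == 0, upper, lower)
    x = x + noise * torch.randn(n, 2, generator=generator)
    # The context is returned as a float column of shape (n, 1) for the model.
    return x, c.float()[:, None]


x, c = conditional_two_moons(1000)
x_clean, c_clean = conditional_two_moons(500, noise=0.0, generator=torch.Generator().manual_seed(0))
```

Training on `(x, c)` and sampling with `c` fixed to 0 or 1 should then produce only the corresponding moon.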

I have implemented a solution for this use case. I would like your feedback on whether it is a reasonable solution, or whether something has already been done to address this type of problem.

The main steps are as follows:

  1. We define the base model to take an additional input (the context), and we create a model wrapper that evolves the data with the base model and assigns zero derivatives to the context (so that it stays constant along the whole trajectory).
  2. We train the model, passing the context as an extra input in the forward call.
  3. We sample using the wrapped model, concatenating the initial conditions and the context so that the wrapper works directly with ODE solvers such as torchdiffeq.
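These steps assume a base model whose forward call accepts `context` and `flow_time` keyword arguments, which the issue does not show. A minimal sketch of such a model, assuming a plain MLP on the concatenated inputs; `ConditionalMLP` and its layer sizes are hypothetical:

```python
import torch
import torch.nn as nn


class ConditionalMLP(nn.Module):
    """Hypothetical base model: predicts dx/dt from (x, context, flow_time)."""

    def __init__(self, x_dim, context_dim, hidden=64):
        super().__init__()
        # Input is the state, the context and one scalar flow time, concatenated.
        self.net = nn.Sequential(
            nn.Linear(x_dim + context_dim + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, x_dim),
        )

    def forward(self, x, context, flow_time):
        # flow_time is expected with shape (batch, 1), as in the training and sampling code.
        return self.net(torch.cat([x, context, flow_time], dim=-1))


model = ConditionalMLP(x_dim=2, context_dim=1)
out = model(torch.randn(8, 2), context=torch.ones(8, 1), flow_time=torch.rand(8, 1))
```

Concatenating the context with the input is the simplest conditioning scheme; other ways of injecting $c$ would work with the same call signature.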

The code is as follows:


import torch
import torch.nn as nn


class ModelWrapper(nn.Module):
    def __init__(self, base_model, context_dim=6):
        """
        Wraps a base model so that only the first part of the input (the data) is evolved
        by the model, while the trailing context entries are held fixed.

        Args:
            base_model (nn.Module): The base model to wrap.
            context_dim (int): Number of trailing input features that hold the context.
        """
        super().__init__()
        self.base_model = base_model.eval()
        self.context_dim = context_dim

    def forward(self, t, x, **kwargs):
        """
        Forward pass of the wrapped model.

        Args:
            t (torch.Tensor): The time tensor.
            x (torch.Tensor): The input tensor: concatenation of [actual input, context].
            **kwargs: Additional keyword arguments.

        Returns:
            torch.Tensor: The output tensor.
        """
        xt, context = x[:, :-self.context_dim], x[:, -self.context_dim:]
        # The ODE solver passes a scalar t; broadcast it to one time value per sample.
        t_broadcasted = t.expand(x.shape[0], 1)
        # Only evolve xt using the model (notice the additional input in the forward).
        dxt_dt = self.base_model(xt, context=context, flow_time=t_broadcasted)

        # Concatenate the derivatives of xt with zeros for context to keep their values unchanged
        zeros_for_context = torch.zeros_like(context)
        dx_dt = torch.cat([dxt_dt, zeros_for_context], dim=-1)

        return dx_dt
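Why the zero derivatives keep the context fixed can be checked without torchdiffeq. The sketch below uses a hypothetical fixed-step Euler integrator `euler_odeint` that mimics the `func(t, y)` convention and the `(time, batch, features)` output layout of `odeint`, together with a toy vector field `augmented_field` (also hypothetical) in the same [data, context] layout as `ModelWrapper`:

```python
import torch


def euler_odeint(func, y0, t_span):
    """Hypothetical fixed-step Euler stand-in for odeint: returns states of shape (T, *y0.shape)."""
    ys = [y0]
    y = y0
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        y = y + (t1 - t0) * func(t0, y)
        ys.append(y)
    return torch.stack(ys)


def augmented_field(t, y, context_dim=1):
    """Toy field in the ModelWrapper layout: the data part moves, the context part does not."""
    x, context = y[:, :-context_dim], y[:, -context_dim:]
    dx_dt = context - x  # hypothetical dynamics: pull x towards its context value
    return torch.cat([dx_dt, torch.zeros_like(context)], dim=-1)


# Two data columns plus one context column per sample.
y0 = torch.cat([torch.zeros(4, 2), torch.tensor([[0.0], [1.0], [2.0], [3.0]])], dim=-1)
traj = euler_odeint(augmented_field, y0, torch.linspace(0, 1, 11))
x_final = traj[-1, :, :2]  # last time step, context columns dropped
```

Each Euler update adds exactly zero to the context columns, so they are carried through the trajectory unchanged, and `traj[-1, :, :2]` recovers the data part, the same slicing used at sampling time.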

Then in the training loop we do something like:

# Assumes model, optimizer, FM (a torchcfm flow matcher), noise_dist, device,
# X_train, Y_train, batch_size, train_loss and pbar are defined elsewhere.
for i in range(0, len(X_train), batch_size):
    X_batch = X_train[i : i + batch_size]
    Y_batch = Y_train[i : i + batch_size]  # NOTE: this is the context

    optimizer.zero_grad()

    x0 = noise_dist(X_batch.shape[0], X_batch.shape[1]).to(device)

    # t has shape (batch,), so it is reshaped to (batch, 1) for the model below
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, X_batch)

    vt = model(xt, context=Y_batch, flow_time=t[:, None])
    loss = torch.mean((vt - ut) ** 2)
    train_loss += loss.item()
    loss.backward()
    optimizer.step()

    # Update the progress bar
    pbar.update(1)
    pbar.set_postfix({"Batch Loss": loss.item()})
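A self-contained sketch of the same training step, without a torchcfm dependency. The hypothetical helper `straight_path` re-implements the simplest conditional path, $x_t = (1 - t)\,x_0 + t\,x_1$ with target velocity $u_t = x_1 - x_0$, which I believe is what torchcfm's basic `ConditionalFlowMatcher` uses when its `sigma` is 0; `train_step`, the linear `net` and the toy data are likewise illustrative:

```python
import torch
import torch.nn as nn


def straight_path(x0, x1):
    """Sample t ~ U(0, 1) per row and return (t, xt, ut) on the straight path from x0 to x1."""
    t = torch.rand(x0.shape[0])
    xt = (1 - t[:, None]) * x0 + t[:, None] * x1
    ut = x1 - x0  # the velocity is constant along a straight path
    return t, xt, ut


def train_step(model, optimizer, x1, context):
    """One conditional flow-matching step; model(xt, context, flow_time) predicts the velocity."""
    optimizer.zero_grad()
    x0 = torch.randn_like(x1)  # Gaussian noise as the source distribution
    t, xt, ut = straight_path(x0, x1)
    vt = model(xt, context, t[:, None])  # flow_time has shape (batch, 1)
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    optimizer.step()
    return loss.item()


# Hypothetical usage: a linear velocity model on 2-D data with a 1-D context flag.
torch.manual_seed(0)
net = nn.Linear(2 + 1 + 1, 2)


def model(x, context, flow_time):
    return net(torch.cat([x, context, flow_time], dim=-1))


optimizer = torch.optim.Adam(net.parameters(), lr=1e-2)
x1 = torch.randn(128, 2)
c = torch.randint(0, 2, (128, 1)).float()
losses = [train_step(model, optimizer, x1, c) for _ in range(50)]
```

The context only enters through the model's input; the path and the regression target are the same as in the unconditional case.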

For sampling new data from noise:

import torch
from torchdiffeq import odeint
from tqdm import tqdm

print("Starting sampling")
model.eval()
samples_list = []
# NOTE the call to the model wrapper
sampler = ModelWrapper(model, context_dim=context_dim)

t_span = torch.linspace(0, 1, timesteps).to(device)
with torch.no_grad():
    with tqdm(
        total=len(X_test) // test_batch_size,
        desc="Sampling",
        dynamic_ncols=True,
    ) as pbar:
        for i in range(0, len(X_test), test_batch_size):
            Y_batch = Y_test[i : i + test_batch_size, :]
            # protection against underflows in the torchdiffeq solver
            while True:
                try:
                    x0_sample = noise_dist(len(Y_batch), X_test.shape[1]).to(device)
                    # NOTE the context is concatenated to the initial conditions
                    # for the wrapper to work
                    initial_conditions = torch.cat([x0_sample, Y_batch], dim=-1)

                    # NOTE we keep only the last timestep and drop the context columns
                    samples = odeint(
                        sampler,
                        initial_conditions,
                        t_span,
                        atol=1e-5,
                        rtol=1e-5,
                        method="dopri5",
                    )[timesteps - 1, :, : X_test.shape[1]]
                    break
                except AssertionError:
                    # Assumed: torchdiffeq reports a step-size underflow with an
                    # AssertionError; draw fresh noise and retry
                    continue
            samples_list.append(samples)
            pbar.update(1)
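The `while True` loop retries forever if the solver keeps failing on a batch. A bounded variant is sketched below, assuming (as the underflow comment suggests) that the solver signals the failure with an `AssertionError`, which is what torchdiffeq's adaptive-step solvers appear to raise; worth confirming for the installed version. `solve_with_retries` and `flaky_solve` are hypothetical names:

```python
def solve_with_retries(draw_and_solve, max_tries=5, retry_on=(AssertionError,)):
    """Call draw_and_solve() until it succeeds; give up after max_tries failed attempts.

    draw_and_solve is expected to draw fresh noise and run the ODE solver, so every
    retry starts from new initial conditions. retry_on lists the exception types
    that count as a recoverable solver failure.
    """
    last_error = None
    for _ in range(max_tries):
        try:
            return draw_and_solve()
        except retry_on as err:
            last_error = err  # remember the failure and try again
    raise RuntimeError(f"ODE solve failed {max_tries} times in a row") from last_error


# Hypothetical usage: a solve that fails twice before succeeding.
attempts = []


def flaky_solve():
    attempts.append(1)
    if len(attempts) < 3:
        raise AssertionError("underflow in dt")
    return "samples"


result = solve_with_retries(flaky_solve)
```

In the sampling loop, the body of the `try` block would become the `draw_and_solve` callable, so a persistently failing batch surfaces as an error instead of hanging.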

This approach works perfectly for our use case. Do you think it's reasonable and efficient enough?

I would appreciate any guidance whatsoever, and if the solution seems interesting, I would be more than happy to work on a pull request!

Best regards,
Francesco

atong01 commented 5 months ago

Hi Francesco!

The solution seems reasonable enough to me. This is fairly similar to the solution I used in this notebook https://github.com/atong01/conditional-flow-matching/blob/main/examples/images/conditional_mnist.ipynb

except that I wrapped the conditional model into an unconditional one before integration. I think this is slightly easier (provided your conditions are constant over time), but your solution should work just fine.

--Alex
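A minimal sketch of the alternative Alex mentions: fixing the context before integration so the solver only sees the data. This is one reading of the idea, not code taken from the linked notebook; `FixedContextField` and `ToyModel` are hypothetical names, and the base model is assumed to accept `context` and `flow_time` keywords as in the issue:

```python
import torch
import torch.nn as nn


class FixedContextField(nn.Module):
    """Hypothetical wrapper: a conditional velocity model with its context fixed in advance."""

    def __init__(self, base_model, context):
        super().__init__()
        self.base_model = base_model
        self.context = context  # (batch, context_dim), constant during integration

    def forward(self, t, x):
        # Solver-style signature func(t, y): only the data x is integrated.
        flow_time = t.expand(x.shape[0], 1)
        return self.base_model(x, context=self.context, flow_time=flow_time)


class ToyModel(nn.Module):
    """Hypothetical stand-in for the base model: pulls x towards its context value."""

    def forward(self, x, context, flow_time):
        return context - x


# Hypothetical usage with fixed Euler steps; torchdiffeq's odeint(field, x0, t_span)
# accepts the same func(t, y) signature.
field = FixedContextField(ToyModel(), context=torch.tensor([[1.0], [-1.0]]))
x = torch.zeros(2, 2)
t_span = torch.linspace(0, 1, 21)
with torch.no_grad():
    for t0, t1 in zip(t_span[:-1], t_span[1:]):
        x = x + (t1 - t0) * field(t0, x)
```

Compared with the concatenation approach, the state holds only the data, so no zero derivatives or output slicing are needed; the trade-off is building a new wrapper for each batch of contexts.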

kilianFatras commented 5 months ago

Closing this issue as it has been solved.