G-U-N / Phased-Consistency-Model

[NeurIPS 2024] Boosting the performance of consistency models with PCM!
https://g-u-n.github.io/projects/pcm/
Apache License 2.0

[Inference Strategy] Will inconsistent sampling steps cause suboptimal performance? #14

Open Luciennnnnnn opened 3 months ago

Luciennnnnnn commented 3 months ago

Hi, I noticed that you use the built-in DDIM scheduler at inference, but its discretization differs from the one used during training.

Specifically, the endpoints that split the time interval into sub-trajectories are [0, 239, 499, 739, 999] in training, whereas the DDIM scheduler uses [0, 249, 499, 749, 999] at inference.

G-U-N commented 3 months ago

Thanks for the reminder. It might be caused by the num_ddim_timesteps in the training configs. Using num_ddim_timesteps=50 should make it consistent.

Luciennnnnnn commented 3 months ago

@G-U-N num_ddim_timesteps=50 is exactly what causes the problem. You can check this with the following code:

import numpy as np
import torch

ddim_timesteps = 50
multiphase = 4

step_ratio = 1000 // ddim_timesteps

ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * step_ratio
        ).round().astype(np.int64) - 1
ddim_timesteps = torch.from_numpy(ddim_timesteps).long()

inference_indices = np.linspace(
    0, len(ddim_timesteps), num=multiphase, endpoint=False
)
inference_indices = np.floor(inference_indices).astype(np.int64)
inference_indices = (
    torch.from_numpy(inference_indices).long().to(ddim_timesteps.device)
)

print(ddim_timesteps) # tensor([..., 719, 739, 759, 779, ...])
print(inference_indices) # tensor([ 0, 12, 25, 37])

print(ddim_timesteps[inference_indices]) # tensor([ 19, 259, 519, 759])

step_ratio = 1000 / 4
timesteps = np.round(np.arange(1000, 0, -step_ratio)).astype(np.int64)
timesteps -= 1

print(timesteps) # [999 749 499 249]

In training, the DDIM step preceding 759 is 739, so for every timestep from 759 upward the model learns to jump to 739. At inference, however, the scheduler jumps from 999 to 749, which causes the inconsistency.
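To make the mismatch concrete, here is a small NumPy-only sketch (not code from the repository) that, for each timestep the trailing DDIM scheduler visits, reports the endpoint the model was trained to reach alongside the timestep the scheduler actually steps to. It assumes the DDIMSolver grid and phase split from the snippet above, plus the diffusers DDIM rule prev_timestep = t - num_train_timesteps // num_inference_steps. The helpers train_jump_target and compare are hypothetical. Note that t=749 is not on the 20-step training grid, so the sketch simply assigns it to the phase whose range contains it.

```python
import numpy as np

T = 1000     # num_train_timesteps
N_DDIM = 50  # num_ddim_timesteps in the training config
N_PHASE = 4  # multiphase in training / num_inference_steps at inference

# DDIMSolver grid [19, 39, ..., 999] and its "previous timestep" table
grid = np.arange(1, N_DDIM + 1) * (T // N_DDIM) - 1
grid_prev = np.concatenate([[0], grid[:-1]])

# Phase start indices [0, 12, 25, 37] -> start timesteps [19, 259, 519, 759]
phase_idx = np.floor(np.linspace(0, N_DDIM, num=N_PHASE, endpoint=False)).astype(np.int64)
phase_start_t = grid[phase_idx]
phase_target = grid_prev[phase_idx]  # [0, 239, 499, 739]


def train_jump_target(t):
    """Hypothetical helper: endpoint learned in training for timestep t,
    i.e. the target of the last phase starting at or below t."""
    k = max(int(np.searchsorted(phase_start_t, t, side="right")) - 1, 0)
    return int(phase_target[k])


def compare():
    """Hypothetical helper: (visited t, trained target, scheduler's next t)."""
    # Trailing spacing visits [999, 749, 499, 249]; DDIM then steps to
    # t - T // N_PHASE, where -1 means "use final_alpha_cumprod".
    visited = np.round(np.arange(T, 0, -T / N_PHASE)).astype(np.int64) - 1
    return [(int(t), train_jump_target(t), int(t) - T // N_PHASE) for t in visited]


for t, trained, scheduled in compare():
    print(f"t={t}: trained to reach {trained}, scheduler steps to {scheduled}")
# Expected output:
# t=999: trained to reach 739, scheduler steps to 749
# t=749: trained to reach 499, scheduler steps to 499
# t=499: trained to reach 239, scheduler steps to 249
# t=249: trained to reach 0, scheduler steps to -1
```

Under these assumptions, the first and third jumps disagree by 10 timesteps (739 vs 749, 239 vs 249), while the 749 to 499 jump lines up. The last row is benign when the scheduler's final alpha is alphas_cumprod[0] (set_alpha_to_one=False).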

G-U-N commented 3 months ago

Did you print end_timesteps during training? I remember it printing [0 249 499 749].

G-U-N commented 3 months ago

You can make a PR if you find it is indeed wrong.

Luciennnnnnn commented 3 months ago

end_timesteps is [0, 239, 499, 739], according to the minimal reproducible example below.

import numpy as np
import torch

from diffusers import DDPMScheduler

def extract_into_tensor(a, t, x_shape):
    b, *_ = t.shape
    out = a.gather(-1, t)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))

class DDIMSolver:
    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        self.step_ratio = timesteps // ddim_timesteps
        self.ddim_timesteps = (
            np.arange(1, ddim_timesteps + 1) * self.step_ratio
        ).round().astype(np.int64) - 1
        self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
        self.ddim_timesteps_prev = np.asarray([0] + self.ddim_timesteps[:-1].tolist())
        self.ddim_alpha_cumprods_prev = np.asarray(
            [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
        )
        self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
        self.ddim_timesteps_prev = torch.from_numpy(self.ddim_timesteps_prev).long()
        self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(self.ddim_alpha_cumprods_prev)

    def to(self, device):
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)

        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        alpha_cumprod_prev = extract_into_tensor(
            self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape
        )
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev

    def ddim_style_multiphase_pred(self, timestep_index, multiphase):
        inference_indices = np.linspace(
            0, len(self.ddim_timesteps), num=multiphase, endpoint=False
        )
        inference_indices = np.floor(inference_indices).astype(np.int64)
        inference_indices = (
            torch.from_numpy(inference_indices).long().to(self.ddim_timesteps.device)
        )
        expanded_timestep_index = timestep_index.unsqueeze(1).expand(
            -1, inference_indices.size(0)
        )
        valid_indices_mask = expanded_timestep_index >= inference_indices
        last_valid_index = valid_indices_mask.flip(dims=[1]).long().argmax(dim=1)
        last_valid_index = inference_indices.size(0) - 1 - last_valid_index
        timestep_index = inference_indices[last_valid_index]
        return self.ddim_timesteps_prev[timestep_index]

if __name__ == '__main__':
    ddim_timesteps = 50
    multiphase = 4

    # 1. Create the noise scheduler and the desired noise schedule.
    noise_scheduler = DDPMScheduler.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        subfolder="scheduler"
    )

    # The scheduler calculates the alpha and sigma schedule for us
    alpha_schedule = torch.sqrt(noise_scheduler.alphas_cumprod)
    sigma_schedule = torch.sqrt(1 - noise_scheduler.alphas_cumprod)
    solver = DDIMSolver(
        noise_scheduler.alphas_cumprod.numpy(),
        timesteps=noise_scheduler.config.num_train_timesteps,
        ddim_timesteps=ddim_timesteps,
    )

    index = torch.arange(ddim_timesteps)

    end_timesteps = solver.ddim_style_multiphase_pred(index, multiphase)

    print(end_timesteps)
    # tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0, 239, 239,
    #     239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 239, 499, 499, 499,
    #     499, 499, 499, 499, 499, 499, 499, 499, 499, 739, 739, 739, 739, 739,
    #     739, 739, 739, 739, 739, 739, 739, 739])
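Since ddim_style_multiphase_pred only manipulates indices, the alpha schedule plays no role in end_timesteps. The following NumPy-only re-derivation (a sketch; multiphase_end_timesteps is a hypothetical name, and searchsorted replaces the original flip/argmax trick) reproduces the tensor above without loading any scheduler, and makes it easy to try other ddim_timesteps values.

```python
import numpy as np


def multiphase_end_timesteps(num_train_timesteps=1000, ddim_timesteps=50, multiphase=4):
    """Hypothetical NumPy re-derivation of DDIMSolver.ddim_style_multiphase_pred
    evaluated at every DDIM index 0..ddim_timesteps-1."""
    step_ratio = num_train_timesteps // ddim_timesteps
    grid = np.arange(1, ddim_timesteps + 1) * step_ratio - 1   # ddim_timesteps
    grid_prev = np.concatenate([[0], grid[:-1]])               # ddim_timesteps_prev
    phase_idx = np.floor(
        np.linspace(0, ddim_timesteps, num=multiphase, endpoint=False)
    ).astype(np.int64)                                         # inference_indices
    index = np.arange(ddim_timesteps)
    # For each index, pick the last phase start index that is <= index
    phase = np.searchsorted(phase_idx, index, side="right") - 1
    return grid_prev[phase_idx[phase]]


ends = multiphase_end_timesteps()
print(sorted(set(ends.tolist())))  # [0, 239, 499, 739]
print(sorted(set(multiphase_end_timesteps(ddim_timesteps=40).tolist())))  # [0, 249, 499, 749]
```

With ddim_timesteps=40 the same indexing logic lands on [0, 249, 499, 749], because each phase boundary then falls exactly on the DDIM grid.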

Luciennnnnnn commented 3 months ago

I'm happy to make a PR once we agree on the proper fix.

I think a suitable division of the time interval would be:

endpoints = np.linspace(
    -1, noise_scheduler.config.num_train_timesteps - 1, num=num_phase + 1, endpoint=True
)  # [-1., 249., 499., 749., 999.] for num_phase=4
# endpoints[0] = 0  # letting the start point be 0 is also fine
endpoints = np.floor(endpoints).astype(np.int64)
# noise_scheduler, num_phase and start_timesteps come from the training script
endpoints = torch.from_numpy(endpoints).long().to(start_timesteps.device)

In training, every timestep in (endpoints[i], endpoints[i + 1]] is forced to jump to endpoints[i]. This division is consistent with the following DDIM scheduler configuration:

DDIMScheduler(
    timestep_spacing="trailing",
    set_alpha_to_one=True,
)
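The claimed consistency can be checked numerically. The sketch below is illustrative only (proposed_jump_target is a hypothetical helper): it applies the proposed "(endpoints[i], endpoints[i + 1]] jumps to endpoints[i]" rule to the timesteps a trailing-spaced 4-step schedule visits, and compares the result with where the diffusers DDIM step goes next, assuming prev_timestep = t - num_train_timesteps // num_inference_steps and that a negative prev_timestep means the final alpha (1.0 when set_alpha_to_one=True).

```python
import numpy as np

T = 1000       # num_train_timesteps
NUM_PHASE = 4  # number of phases / inference steps

# Proposed phase boundaries: [-1, 249, 499, 749, 999]
endpoints = np.floor(
    np.linspace(-1, T - 1, num=NUM_PHASE + 1, endpoint=True)
).astype(np.int64)


def proposed_jump_target(t):
    """Hypothetical helper: t in (endpoints[i], endpoints[i + 1]] jumps to endpoints[i]."""
    i = np.searchsorted(endpoints, t, side="left") - 1
    return int(endpoints[i])


# Trailing spacing, computed as in the first snippet: [999, 749, 499, 249]
trailing = np.round(np.arange(T, 0, -T / NUM_PHASE)).astype(np.int64) - 1
# DDIM steps from t to t - T // NUM_PHASE; -1 stands for the final alpha
sched_next = trailing - T // NUM_PHASE

targets = [proposed_jump_target(t) for t in trailing]
print(targets)              # [749, 499, 249, -1]
print(sched_next.tolist())  # [749, 499, 249, -1]
```

Under these assumptions every training target coincides with the scheduler's next timestep, including the final jump to the clean sample at -1.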

Furthermore, every endpoint other than the first should belong to solver.ddim_timesteps, so that the model learns to jump to the endpoints directly. This requires num_train_timesteps / num_phase to be a multiple of num_train_timesteps / ddim_timesteps, so num_phase=4 with ddim_timesteps=40, or num_phase=5 with ddim_timesteps=50, are feasible choices.
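A quick way to test a configuration against this condition is sketched below (endpoints_on_ddim_grid is a hypothetical helper, not part of the repository). It builds the proposed endpoints and the DDIMSolver grid and checks membership of every endpoint after the leading -1. For the values tried here, where num_train_timesteps is divisible by both numbers, the check reduces to ddim_timesteps being a multiple of num_phase.

```python
import numpy as np


def endpoints_on_ddim_grid(num_train_timesteps, num_phase, ddim_timesteps):
    """Hypothetical check: True if every proposed endpoint except the leading -1
    lies on the DDIMSolver grid [step_ratio - 1, 2 * step_ratio - 1, ..., T - 1]."""
    step_ratio = num_train_timesteps // ddim_timesteps
    grid = np.arange(1, ddim_timesteps + 1) * step_ratio - 1
    endpoints = np.floor(
        np.linspace(-1, num_train_timesteps - 1, num=num_phase + 1, endpoint=True)
    ).astype(np.int64)
    return bool(np.isin(endpoints[1:], grid).all())


for num_phase, n_ddim in [(4, 50), (4, 40), (5, 50), (4, 100)]:
    print(num_phase, n_ddim, endpoints_on_ddim_grid(1000, num_phase, n_ddim))
# 4 50 False
# 4 40 True
# 5 50 True
# 4 100 True
```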

JaySimple commented 1 month ago

I think you're right. I ran the same test and got the same results. The mismatch between the training and inference timesteps may therefore lead to suboptimal results. Have you tried your revised version, and how did it perform? Thanks.

Luciennnnnnn commented 1 month ago

@JaySimple I have tested my revised version, and it works as intended. However, I'm currently working on a different configuration, so I don't have a direct comparison with the original implementation. I'm confident that my implementation at least maintains, if not improves, the existing performance.

G-U-N commented 1 month ago

Great, I will check it in time.