locuslab / ect

Consistency Models Made Easy

A question about the code #4

Closed 2019211753 closed 1 month ago

2019211753 commented 1 month ago

I noticed that the code between lines 86-95 in the file at https://github.com/locuslab/ect/blob/main/training/loss.py may produce NaN values:

        if r.max() > 0:
            torch.cuda.set_rng_state(rng_state)
            with torch.no_grad():
                D_yr = net(y + eps_r, r, labels, augment_labels=augment_labels)

            mask = r > 0
            D_yr = torch.nan_to_num(D_yr)
            D_yr = mask * D_yr + (~mask) * y
        else:
            D_yr = y

Could you help me understand why D_yr might contain NaN values? Thank you for your assistance!

2019211753 commented 1 month ago

Additionally, regarding the code between lines 80-97 in the file at https://github.com/locuslab/ect/blob/main/training/ct_training_loop.py: why are the latents multiplied by t_steps[0] (always 80?) instead of adding t_steps[0] * randn_like(x)?

def generator_fn(
    net, latents, class_labels=None, 
    t_max=80, mid_t=None
):
    # Time step discretization.
    mid_t = [] if mid_t is None else mid_t
    t_steps = torch.tensor([t_max]+list(mid_t), dtype=torch.float64, device=latents.device)

    # t_0 = T, t_N = 0
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # [80, 0]

    # Sampling steps 
    x = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x = net(x, t_cur, class_labels).to(torch.float64)
        if t_next > 0:
            x = x + t_next * torch.randn_like(x) 
    return x
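
For reference, a minimal invocation sketch of this sampler (net is assumed to be a loaded network snapshot; the mid_t value below is illustrative, not a repo default):

    import torch

    # Assumed: `net` exposes round_sigma() and the EDM-style call signature used above.
    latents = torch.randn(16, 3, 32, 32, device='cuda')  # CIFAR-10-shaped noise

    # 1-step generation: x = D(T * z, T) with T = t_steps[0] = 80
    x_1step = generator_fn(net, latents)

    # 2-step generation: re-noise the 1-step sample to an intermediate mid_t,
    # then denoise once more (the mid_t value here is illustrative).
    x_2step = generator_fn(net, latents, mid_t=[0.8])
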
2019211753 commented 1 month ago

Additionally, is ECT a combination of Diffusion Pretraining and Consistency Tuning? If so, I haven't come across any code for loading pretrained weights, such as model.load_state_dict(torch.load('pretrained_weights.pth')). Could you please help me locate it? Thank you in advance for your response, and I apologize for any oversight on my part.

Gsunshine commented 1 month ago

Hi @2019211753 ,

I'm sorry for getting back to you late. Regarding your questions,

  1. Could you help me understand why D_yr might contain NaN values?

    • The timestep fed into the network is r.log() / 4, following EDM. If a batch contains samples with $r=0$, the network output will be NaN: log(0) is -inf, and these NaNs cannot be zeroed out by c_out in the parameterization c_skip * x_r + c_out * net(x_r, r). So we set up a mask and force the model output at $r=0$ to be $\mathbf{x}_0$. (A small standalone sketch follows this list.)
    • In your own setting, you need to take care of the boundary condition. As long as your model doesn't pass the timestep through a log transformation (or any other transformation that produces NaN at the boundary), you won't need this. Make sure your function has a well-defined boundary.
  2. Why are the latents multiplied by t_steps[0] (always 80?) instead of adding t_steps[0] * randn_like(x)?

    • Diffusion models' sampling starts from the noise distribution ($t \to \infty$ for VE SDE used in EDM). In practice, you can approximate this value by a large $t$ value, where we use $T=80$ here.
    • For your own setup, follow the sampler used by the pretrained diffusion models and reduce the NFEs to construct a sampler for your CMs. If it is not a VE SDE, you don't need this $T=80$.
  3. Code for loading weights

    • See here.
    • It loads the network snapshot (following EDM's checkpoint format).
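
For concreteness, a small standalone sketch of point 1 (the tensors here are illustrative; in the repo, the log transform happens inside the network's EDM-style preconditioning and the masking in training/loss.py):

    import torch

    # EDM-style conditioning: the timestep fed to the network is r.log() / 4.
    r = torch.tensor([2.5, 0.0])     # batch where one sample has r = 0
    c_noise = r.log() / 4            # log(0) = -inf
    emb = torch.sin(c_noise)         # any sin/cos embedding of -inf is NaN
    print(emb)                       # tensor([0.2271, nan])

    # NaN survives the c_skip * x_r + c_out * F(x_r, r) parameterization,
    # so loss.py masks it out and forces D(x_r, r=0) = x_0:
    y = torch.randn(2, 3)                                   # clean data x_0
    D_yr = torch.stack([torch.randn(3),
                        torch.full((3,), float('nan'))])    # net output; the r = 0 row is NaN
    mask = (r > 0).unsqueeze(-1)
    D_yr = torch.nan_to_num(D_yr)
    D_yr = mask * D_yr + (~mask) * y                        # rows with r = 0 fall back to y
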

Best, Zhengyang

2019211753 commented 1 month ago

Thank you very much for your assistance. It has resolved many questions I had. This is excellent work, and I believe it will become a renowned project.