Closed 2019211753 closed 1 month ago
Additionally, about the code at lines 80-97 of https://github.com/locuslab/ect/blob/main/training/ct_training_loop.py: why are the latents multiplied by t_steps[0] (always 80?) instead of adding t_steps[0] * randn_like(x)?
def generator_fn(
    net, latents, class_labels=None,
    t_max=80, mid_t=None
):
    # Time step discretization.
    mid_t = [] if mid_t is None else mid_t
    t_steps = torch.tensor([t_max] + list(mid_t), dtype=torch.float64, device=latents.device)

    # t_0 = T, t_N = 0
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # [80, 0]

    # Sampling steps
    x = latents.to(torch.float64) * t_steps[0]
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x = net(x, t_cur, class_labels).to(torch.float64)
        if t_next > 0:
            x = x + t_next * torch.randn_like(x)
    return x
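On the "multiply vs. add" question above: the sampler starts from the VE SDE prior, approximately N(0, T^2 I). Since the latents are already standard normal, multiplying by t_steps[0] = 80 gives exactly that distribution, while adding 80 * randn to the latents would give N(0, (1 + T^2) I), which at T = 80 is nearly indistinguishable. A minimal numpy sketch (an illustration, not code from the repo) comparing the two:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 80.0
n = 100_000

latents = rng.standard_normal(n)   # latents ~ N(0, 1)
eps = rng.standard_normal(n)

scaled = T * latents               # what generator_fn does: ~ N(0, T^2)
added = latents + T * eps          # the proposed alternative: ~ N(0, 1 + T^2)

# At T = 80 the extra unit of variance changes the std by only ~0.008%
print(scaled.std(), added.std())
```

Both draws have a standard deviation of essentially 80, so scaling the standard-normal latents by T is the simpler and exact way to sample from the prior.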
Additionally, is ECT a combination of diffusion pretraining and consistency tuning? If so, I haven't come across any code for loading the pretrained weights, such as model.load_state_dict(torch.load('pretrained_weights.pth')). Could you please help me locate it? Thank you in advance for your response, and I apologize for any oversight on my part.
Hi @2019211753 ,
I'm sorry for getting back to you late. Regarding your questions,
- Could you help me understand why D_yr might contain NaN values?
  - The timestep fed into the network is r.log() / 4, following EDM. If a batch contains samples with r = 0, the network output will be NaN. (These NaN cannot be zeroed out by c_out before the network output in c_skip * x_r + c_out * net(x_r, r).) So, we set up a mask and force the model outputs at r = 0 to be x_0.
  - In your own settings, you need to take care of the boundary condition. As long as your model doesn't map the time embedding through the log transformation or other transformations that lead to NaN, you would not need this. Make sure your function has a well-defined boundary.
- Why are the latents multiplied by t_steps[0] (always 80?) instead of adding t_steps[0] * randn_like(x)?
  - Diffusion models' sampling starts from the noise distribution (t → ∞ for the VE SDE used in EDM). In practice, you can approximate this with a large t value; we use T = 80 here.
  - For your own setup, follow the sampler used by the pretrained diffusion model and reduce the NFEs to construct a sampler for your CMs. If it is not a VE SDE, you don't need T = 80.
- Code for loading weights
  - See here.
  - It loads the network snapshot (following EDM's checkpoint format).
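To make the r = 0 boundary issue concrete, here is a minimal numpy sketch. toy_net is a hypothetical stand-in for the real network (not the repo's code); it only reproduces the failure mode where log(r) / 4 is -inf at r = 0 and the downstream embedding becomes NaN, and shows how a mask restores the boundary condition f(x, 0) = x_0:

```python
import numpy as np

def toy_net(x_r, r):
    # hypothetical stand-in: the timestep embedding uses log(r) / 4 (as in EDM);
    # log(0) = -inf, and a sinusoidal embedding of -inf is NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        emb = np.log(r) / 4
        return x_r * np.cos(emb)

r = np.array([0.0, 0.5, 1.0])
x_r = np.ones_like(r)
x0 = np.zeros_like(r)              # clean targets at the boundary

raw = toy_net(x_r, r)              # raw[0] is NaN; note c_out * NaN is still NaN
out = np.where(r == 0, x0, raw)    # mask: force f(x, r = 0) = x0
```

The mask is needed precisely because multiplying NaN by c_out (even c_out = 0) still yields NaN, so the bad value must be overwritten rather than scaled away.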
Best, Zhengyang
Thank you very much for your assistance. It has resolved many questions I had. This is excellent work, and I believe it will become a renowned project.
I noticed that the code between lines 86-95 in the file at https://github.com/locuslab/ect/blob/main/training/loss.py may produce NaN values:
Could you help me understand why D_yr might contain NaN values? Thank you for your assistance!