jiaor17 / DiffCSP

[NeurIPS 2023] The implementation for the paper "Crystal Structure Prediction by Joint Equivariant Diffusion"
MIT License

Predictor-corrector sampling reversed (?) #15

Open fedeotto opened 1 month ago

fedeotto commented 1 month ago

I have a small doubt about the sampling procedure in DiffCSP (looking specifically at diffusion_w_type.py). I'm not sure I'm interpreting it correctly, but comparing the DiffCSP paper with the code, it seems to me that the roles of predictor and corrector are reversed (?). Below is the specific block of code in question:

# Corrector
rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

step_size = step_lr * (sigma_x / self.sigma_scheduler.sigma_begin) ** 2
std_x = torch.sqrt(2 * step_size)

pred_l, pred_x, pred_t = self.decoder(time_emb, t_t, x_t, l_t, batch.num_atoms, batch.batch)
pred_x = pred_x * torch.sqrt(sigma_norm)

x_t_minus_05 = x_t - step_size * pred_x + std_x * rand_x if not self.keep_coords else x_t
l_t_minus_05 = l_t
t_t_minus_05 = t_t

# Predictor
rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

adjacent_sigma_x = self.sigma_scheduler.sigmas[t-1] 
step_size = (sigma_x ** 2 - adjacent_sigma_x ** 2)
std_x = torch.sqrt((adjacent_sigma_x ** 2 * (sigma_x ** 2 - adjacent_sigma_x ** 2)) / (sigma_x ** 2))   

pred_l, pred_x, pred_t = self.decoder(time_emb, t_t_minus_05, x_t_minus_05, l_t_minus_05, batch.num_atoms, batch.batch)
pred_x = pred_x * torch.sqrt(sigma_norm)

x_t_minus_1 = x_t_minus_05 - step_size * pred_x + std_x * rand_x if not self.keep_coords else x_t
l_t_minus_1 = c0 * (l_t_minus_05 - c1 * pred_l) + sigmas * rand_l if not self.keep_lattice else l_t
t_t_minus_1 = c0 * (t_t_minus_05 - c1 * pred_t) + sigmas * rand_t

traj[t - 1] = {
    'num_atoms' : batch.num_atoms,
    'atom_types' : t_t_minus_1,
    'frac_coords' : x_t_minus_1 % 1.,
    'lattices' : l_t_minus_1              
}

It seems to me that x_t_minus_05 is obtained via a Langevin dynamics step (which would then be playing the role of the predictor here?), and the final x_t_minus_1 is computed via the update rule involving adjacent_sigma_x. This doesn't seem to be in line with the original paper by Song et al., so I was wondering whether exchanging the roles of predictor and corrector is a deliberate choice.
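To make the two orderings concrete, here is a minimal toy sketch, not the DiffCSP implementation. It assumes a 1D Gaussian data distribution whose perturbed score is known in closed form, ignores the periodicity of fractional coordinates, and uses the score directly, so the drift term enters with a plus sign, whereas the snippet above subtracts step_size * pred_x. The names `score`, `predictor_step`, `corrector_step`, `sample`, and `sigma_ref` are hypothetical; `sigma_ref` stands in for the reference level that appears as `sigma_begin` in the quoted step-size formula. With `predictor_first=True` the loop follows the ordering of the predictor-corrector sampler in Song et al. as I understand it (predictor to the next noise level, then Langevin correction at that level); with `predictor_first=False` it follows the ordering of the quoted block.

```python
import math
import random


def score(x, sigma, data_std=1.0):
    """Exact score of N(0, data_std^2) data after adding N(0, sigma^2) noise."""
    return -x / (data_std ** 2 + sigma ** 2)


def predictor_step(x, sigma_t, sigma_prev, rng):
    """Ancestral-style update from noise level sigma_t down to sigma_prev."""
    step = sigma_t ** 2 - sigma_prev ** 2
    std = math.sqrt(sigma_prev ** 2 * step / sigma_t ** 2)
    return x + step * score(x, sigma_t) + std * rng.gauss(0.0, 1.0)


def corrector_step(x, sigma, sigma_ref, step_lr, rng):
    """One Langevin step at a fixed noise level sigma."""
    eps = step_lr * (sigma / sigma_ref) ** 2
    return x + eps * score(x, sigma) + math.sqrt(2 * eps) * rng.gauss(0.0, 1.0)


def sample(sigmas, predictor_first, step_lr=1e-2, seed=0):
    """Run the toy sampler; sigmas is increasing, sigmas[-1] is the largest level."""
    rng = random.Random(seed)
    x = rng.gauss(0.0, sigmas[-1])  # start from the (approximate) prior
    for t in range(len(sigmas) - 1, 0, -1):
        if predictor_first:
            # Predict to level t-1, then correct at level t-1.
            x = predictor_step(x, sigmas[t], sigmas[t - 1], rng)
            x = corrector_step(x, sigmas[t - 1], sigmas[-1], step_lr, rng)
        else:
            # Correct at level t first, then predict down to level t-1.
            x = corrector_step(x, sigmas[t], sigmas[-1], step_lr, rng)
            x = predictor_step(x, sigmas[t], sigmas[t - 1], rng)
    return x


if __name__ == "__main__":
    levels = [0.1, 0.5, 1.0, 2.0]
    print(sample(levels, predictor_first=True))
    print(sample(levels, predictor_first=False))
```

The only difference between the two branches is which noise level the Langevin step sees and whether it runs before or after the level change, which is exactly the point raised above.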

jiaor17 commented 1 month ago

Hi, this issue is similar to #7. Please see the discussion there for more details.

fedeotto commented 1 month ago

Apologies for the oversight! After reading the answer provided in #7, I still have some doubts about the step_size definition of the corrector: it is stated there that $\sigma_{t-1}$ is used, but what I see in the code is the following:

# Corrector

rand_l = torch.randn_like(l_T) if t > 1 else torch.zeros_like(l_T)
rand_t = torch.randn_like(t_T) if t > 1 else torch.zeros_like(t_T)
rand_x = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T)

step_size = step_lr * (sigma_x / self.sigma_scheduler.sigma_begin) ** 2
std_x = torch.sqrt(2 * step_size)

Since sigma_x is defined earlier as sigma_x = self.sigma_scheduler.sigmas[t], I don't see where $\sigma_{t-1}$ is used for the corrector (?)
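To illustrate why the index matters, here is a small sketch comparing the corrector step size evaluated at $\sigma_t$ and at $\sigma_{t-1}$. It reuses the form step_lr * (sigma / sigma_ref) ** 2 from the snippet above, with `sigma_ref` a hypothetical stand-in for `sigma_begin`. The geometric schedule, `step_lr`, and `t` are illustrative assumptions, not values taken from the DiffCSP configuration.

```python
def corrector_step_size(step_lr, sigma, sigma_ref):
    """Langevin step size scaled by the squared noise level (form of the quoted code)."""
    return step_lr * (sigma / sigma_ref) ** 2


# Hypothetical geometric noise schedule: each level doubles the previous one.
sigmas = [0.01 * 2 ** i for i in range(5)]
step_lr = 1e-5  # illustrative value only
t = 3

at_t = corrector_step_size(step_lr, sigmas[t], sigmas[0])
at_t_minus_1 = corrector_step_size(step_lr, sigmas[t - 1], sigmas[0])
ratio = at_t / at_t_minus_1  # equals the squared ratio of adjacent noise levels

if __name__ == "__main__":
    print(at_t, at_t_minus_1, ratio)
```

Under these assumptions the two choices differ by the squared ratio of adjacent noise levels, so for a finely spaced schedule the gap is small, while for a coarse one it can change the effective Langevin step noticeably.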