jiaor17 / DiffCSP

[NeurIPS 2023] The implementation for the paper "Crystal Structure Prediction by Joint Equivariant Diffusion"

Query about the algorithm of the sample function #7

Closed hspark1212 closed 5 months ago

hspark1212 commented 6 months ago

Thank you for making this great repo. I have a query regarding the implementation of the sample function in diffusion.py: https://github.com/jiaor17/DiffCSP/blob/ee131b03a1c6211828e8054d837caa8f1a980c3e/diffcsp/pl_modules/diffusion.py#L130

According to the paper, Algorithm 2 applies the predictor (Line 7 of Algorithm 2) before the corrector (Lines 9-10 of Algorithm 2). In the sample function implementation, however, the corrector seems to be applied (yielding x_t_minus_0.5) before the predictor. This appears to contradict the sequence described in Algorithm 2.

Could you please clarify if this implementation reflects a deliberate modification from the algorithm described in the paper, or if I might be misinterpreting the code or the algorithm?

Best, Hyunsoo Park

jiaor17 commented 5 months ago

Thank you for your interest in our repository and for your insightful question!

The predictor-corrector sampler is implemented following Score-Based Generative Modeling through Stochastic Differential Equations, with the original code provided in https://github.com/yang-song/score_sde/blob/main/sampling.py. Note that in Line 9 of Algorithm 2, we use $\sigma_{t-1}$ to control the corrector step size at step $t$ (apologies that the square is missing in Line 9; it should read $\sigma_{t-1}^2$), and the algorithm is implemented in the sample function as

for t in range(T, 0, -1):
    corrector(t + 1)  # corrector step for timestep t + 1
    predictor(t)      # predictor (reverse diffusion) step for timestep t
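For context, the corrector here is the annealed Langevin update from the score-SDE paper. As a generic sketch (Song et al.'s standard form, not DiffCSP's exact parameterization), each corrector step performs

$$x \leftarrow x + \epsilon_t\, s_\theta(x, t) + \sqrt{2\epsilon_t}\, z, \qquad z \sim \mathcal{N}(0, I),$$

where $s_\theta$ is the learned score and the step size $\epsilon_t$ is controlled by $\sigma_{t-1}^2$ as noted above.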

This "corrector-predictor" loop sequence results in the following execution order: corrector(T+1) -> predictor(T) -> corrector(T) -> predictor(T-1) -> corrector(T-1) -> ... -> corrector(2) -> predictor(1). Comparing with the sequence in Algorithm 2 unrolled as predictor(T) -> corrector(T) -> predictor(T-1) -> corrector(T-1) -> ... -> predictor(1) -> corrector(1), there exist 2 slight differences:

  1. corrector(1) is omitted, as $\sigma_0$ is defined as 0; a corrector step with zero step size would be a no-op, hence its exclusion.
  2. An additional corrector(T+1) is performed at the beginning, which does not noticeably affect performance in practice.
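
To make the equivalence concrete, here is a small standalone sketch (illustrative only; corrector/predictor steps are represented as tagged tuples rather than the repository's actual functions) that unrolls both orders and checks they match once the two boundary steps are dropped:

T = 4  # small number of steps for illustration

# order produced by the implemented loop: corrector(t + 1), then predictor(t)
implemented = []
for t in range(T, 0, -1):
    implemented += [("c", t + 1), ("p", t)]

# order described by Algorithm 2: predictor(t), then corrector(t)
paper = []
for t in range(T, 0, -1):
    paper += [("p", t), ("c", t)]

# dropping the extra leading corrector(T+1) from the implemented order and the
# no-op trailing corrector(1) from the paper order leaves identical sequences
assert implemented[1:] == paper[:-1]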

Hope the above explanation helps!

hspark1212 commented 5 months ago

Thank you for your kind reply! Your explanation is clear, and I understand now.

May I ask another question, about loss values? In Figure 5, the train loss (MSE) appears to reach about 0.5 at the end of training (DiffCSP with FT). Could you share your thoughts on how low the loss should go for effective training?

Thank you in advance for your reply.

jiaor17 commented 5 months ago

We have not thoroughly explored the relationship between the training loss and model performance. In practice, however, the model should reach a train/val loss below 0.75/0.90 for acceptable performance on MP-20.
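
If you want to monitor this automatically, a hypothetical sketch follows (not part of the DiffCSP codebase; it assumes a PyTorch Lightning setup, which DiffCSP uses, and that the logged metric names are train_loss and val_loss):

from pytorch_lightning.callbacks import Callback

class LossThresholdCheck(Callback):
    """Warn after training if the losses never reached the rough
    thresholds suggested above for MP-20 (train < 0.75, val < 0.90)."""

    def __init__(self, train_threshold=0.75, val_threshold=0.90):
        self.train_threshold = train_threshold
        self.val_threshold = val_threshold

    def on_fit_end(self, trainer, pl_module):
        # callback_metrics holds the most recently logged scalar metrics
        metrics = trainer.callback_metrics
        for name, threshold in [("train_loss", self.train_threshold),
                                ("val_loss", self.val_threshold)]:
            value = metrics.get(name)
            if value is not None and float(value) > threshold:
                print(f"warning: {name} = {float(value):.3f} is above {threshold}")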

hspark1212 commented 5 months ago

Your answer is very helpful for understanding DiffCSP. Thank you for your reply!