jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License

Some questions about IGSO3 and score function. #9

Closed longlongman closed 1 year ago

longlongman commented 1 year ago

Great work! I learned a lot from it. However, some points about the IGSO3 density and the score function confuse me.

  1. Different IGSO3 PDF functions. In Proposition 3.3, the paper says that the PDF of the rotation angle $\omega$ is $$f(\omega)=\sum_{l \in \mathbb{N}}(2l+1)e^{-l(l+1)t/2}\frac{\sin((l+0.5)\omega)}{\sin(\omega/2)}.$$ However, in Leach's paper, the PDF is defined as $$f(\omega)=\frac{1-\cos\omega}{\pi} \sum_{l \in \mathbb{N}}(2l+1)e^{-l(l+1)t/2}\frac{\sin((l+0.5)\omega)}{\sin(\omega/2)}.$$ What is the difference between these two PDF functions?
    Moreover, in the source code, the actual PDF function is the second form. I wonder whether the score function in Proposition 3.4 still holds in this case (it seems the term $\frac{1-\cos \omega}{\pi}$ is ignored?).

https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/data/so3_diffuser.py#L127-L128 https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/data/so3_diffuser.py#L34-L50

  2. Implementation of the score function. In Proposition 3.4, the score function is $$\nabla p_{t|0}(r^{(t)}|r^{(0)}) = \frac{r^{(t)}}{\omega^{(t)}} \log \{r^{(0, t)}\} \frac{\partial_{\omega}f(\omega^{(t)}, t)}{f(\omega^{(t)}, t)}.$$ However, in the source code, when dealing with the ground truth rotation score, it seems $r^{(t)}$ is missing?

When dealing with the predicted rotation score, init_rots (rot_t) is applied to rot_score. https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/model/ipa_pytorch.py#L648-L663

Meanwhile, when dealing with the ground truth, rot_t is not applied to rot_score. https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/data/so3_diffuser.py#L286-L288

https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/data/so3_diffuser.py#L219-L240

https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/data/so3_diffuser.py#L134-L135

https://github.com/jasonkyuyim/se3_diffusion/blob/2c6405fb5a1c4a23354cc697c18ff478dff8546b/data/so3_diffuser.py#L53-L84

blt2114 commented 1 year ago

Hello -- thanks for the interest! We worked on our implementation at the same time as we were working out our derivations, which led to some disagreements between the naming of variables and object types in our code vs. in our preprint. We have refactored our code to agree with the descriptions in the paper in another branch, which I have just added (https://github.com/jasonkyuyim/se3_diffusion/tree/unsupported_refactor). Please take a look there to see if this clears up your questions. We have confirmed that this branch works for inference (and matches the main branch when using the trained weight checkpoint), and will merge this branch into main once we have finished debugging training.

I'll now reply to the specific questions:

  1. The difference is that the top expression is a density on SO(3) while the bottom expression is a density on the angle of rotation on the interval [0, pi). More formally, the first expression characterizes the Radon-Nikodym derivative of the measure implied by the transition kernel of the Brownian motion with respect to the Haar measure / volume form on SO(3). It is this first quantity (not the second) which is needed for applying the Riemannian time reversal theorem.

  2. This should be clearer from the new branch I have pointed to. Please let us know if it is not.
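To make the distinction in point 1 concrete, here is a minimal numerical sketch (not code from the repository). It evaluates a truncated version of the series quoted in the question and checks that only the version multiplied by $\frac{1-\cos\omega}{\pi}$ integrates to 1 over the angle. The function names, the truncation at `num_terms`, and the choice `t = 1.0` are assumptions made for illustration.

```python
import numpy as np

def igso3_density_so3(omega, t, num_terms=200):
    """Truncated series for the density w.r.t. the Haar measure on SO(3)."""
    omega = np.atleast_1d(np.asarray(omega, dtype=float))
    l = np.arange(num_terms)[:, None]  # shape (num_terms, 1)
    terms = ((2 * l + 1) * np.exp(-l * (l + 1) * t / 2)
             * np.sin((l + 0.5) * omega[None, :]) / np.sin(omega[None, :] / 2))
    return terms.sum(axis=0)

def igso3_density_angle(omega, t, num_terms=200):
    """Density of the rotation angle: Haar angle weight times the SO(3) density."""
    omega = np.atleast_1d(np.asarray(omega, dtype=float))
    return (1 - np.cos(omega)) / np.pi * igso3_density_so3(omega, t, num_terms)

def trapezoid(y, x):
    """Plain trapezoidal rule, written out to avoid NumPy version differences."""
    return float(np.sum((y[1:] + y[:-1]) / 2 * np.diff(x)))

t = 1.0                                    # illustrative noise level
omega = np.linspace(1e-4, np.pi, 20001)    # start slightly above 0 to avoid 0/0

angle_mass = trapezoid(igso3_density_angle(omega, t), omega)
so3_mass = trapezoid(igso3_density_so3(omega, t), omega)
print("integral of angle density over the angle:", angle_mass)
print("integral of SO(3) density over the angle:", so3_mass)
```

The second integral is not expected to be 1, because the first expression is normalized against the Haar measure on SO(3) rather than against Lebesgue measure on the angle.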

meneshail commented 1 year ago

Thanks for the brilliant and informative work! I have a related question on this topic. Could you present a more detailed derivation of the scaling factor connecting the density on SO(3) and the density on [0, pi), or suggest an English-language reference that I could learn from? The reference cited in Leach's paper is in Russian, and I can only find English papers discussing the IGSO3 density on SO(3), but not this scaling factor and the connection to [0, pi).

blt2114 commented 1 year ago

Sure thing. This reference gives a nice derivation of the density function on (0, pi). Hope this helps!

Rummler, Hansklaus. "On the distribution of rotation angles: how great is the mean rotation angle of a random rotation?" The Mathematical Intelligencer 24.4 (2002). https://core.ac.uk/download/pdf/159153606.pdf
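As a complement to the derivation in that reference, the factor $\frac{1-\cos\omega}{\pi}$ can also be checked empirically. The following is a rough Monte Carlo sketch, assuming uniform rotations are drawn as normalized 4D Gaussian vectors (random unit quaternions); the sample size, seed, and bin count are arbitrary choices. It compares a histogram of rotation angles with $\frac{1-\cos\omega}{\pi}$ and with the mean $\pi/2 + 2/\pi$ that follows from integrating $\omega$ against that density.

```python
import numpy as np

rng = np.random.default_rng(0)

# Uniform (Haar) rotations via normalized 4D Gaussian samples (unit quaternions).
q = rng.normal(size=(200_000, 4))
q /= np.linalg.norm(q, axis=1, keepdims=True)

# Rotation angle in [0, pi]; taking |w| identifies q with -q (same rotation).
omega = 2 * np.arccos(np.clip(np.abs(q[:, 0]), 0.0, 1.0))

# Empirical angle density versus the Haar angle density (1 - cos(omega)) / pi.
hist, edges = np.histogram(omega, bins=50, range=(0.0, np.pi), density=True)
centers = (edges[:-1] + edges[1:]) / 2
haar_angle_density = (1 - np.cos(centers)) / np.pi
max_abs_err = float(np.max(np.abs(hist - haar_angle_density)))

empirical_mean = float(omega.mean())
analytic_mean = np.pi / 2 + 2 / np.pi  # integral of omega * (1 - cos(omega)) / pi

print("max histogram error:", max_abs_err)
print("mean angle (empirical, analytic):", empirical_mean, analytic_mean)
```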

meneshail commented 1 year ago

Cool! That's really what I need! Thanks :)

longlongman commented 1 year ago

@blt2114 Hi, I have checked the updated code and found that it aligns well with the preprint. I wonder whether the checkpoint currently provided was trained with the updated code. I have tried to reproduce the benchmark results by re-training but cannot get the expected numbers. Is code inconsistency a possible reason? (image attached)

p.s. In the preprint, the paper says that it will take 1 week to train the model with 2 A100 GPUs in section 5. However, in appendix J.1, the paper also says that 2 weeks are needed for training the model under the same setting. Is this a typo?

p.p.s. In the base.yml config, the max epoch is 50, and I find that I can get 7~8 epochs each day with 2 A100 GPUs. So I guess that the total training time is 1 week.

jasonkyuyim commented 1 year ago

Hi, thanks for reporting this mismatch. It looks like there is a typo. The config is incorrect. It should be 95 epochs. I just pushed a fix. https://github.com/jasonkyuyim/se3_diffusion/commit/6bf12ed7b59b8b917223de15f17ca33feda8abcf

You should be able to continue training by setting warm_start: <ckpt_dir> in base.yml. The numbers look like they're getting there, so hopefully this fixes it, but please report back and we'll help out.

Thank you for pointing out the inconsistency in the paper. It should be ~2 weeks on 2 A100s, which was an estimate, but I'm going to change it to be more specific: "95 epochs over 10 days". We'll update the paper very soon.
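For reference, the training-time figures in this exchange can be lined up with a quick back-of-the-envelope calculation. This sketch only uses numbers quoted in the thread (7 to 8 epochs per day on 2 A100 GPUs, 50 vs. 95 epochs, and "95 epochs over 10 days"); the helper name `days_needed` is made up for illustration.

```python
# Throughput reported earlier in the thread for 2 A100 GPUs (epochs per day).
epochs_per_day_observed = (7, 8)

def days_needed(num_epochs, epochs_per_day):
    """Training time in days at a constant epoch rate."""
    return num_epochs / epochs_per_day

for max_epochs in (50, 95):
    slow = days_needed(max_epochs, epochs_per_day_observed[0])
    fast = days_needed(max_epochs, epochs_per_day_observed[1])
    print(f"{max_epochs} epochs: between {fast:.2f} and {slow:.2f} days")

# Rate implied by "95 epochs over 10 days".
implied_rate = 95 / 10
print("implied epochs per day:", implied_rate)
```

At the reported throughput, 50 epochs lands at roughly one week and 95 epochs at roughly 12 to 14 days, while "95 epochs over 10 days" corresponds to a somewhat higher rate of 9.5 epochs per day.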

jasonkyuyim commented 1 year ago

Here are some plots of what the training losses should look like.

[Plots: rotation_loss, translation_loss, num_epochs]

longlongman commented 1 year ago

@blt2114 @jasonkyuyim Thanks for your quick reply! One more question: in your master branch code, I found that there is a difference between the calculation of the ground truth score and the predicted score. Specifically, the correct formula of the score is

$$\nabla p_{t|0}(r^{(t)}|r^{(0)}) = \frac{r^{(t)}}{\omega^{(t)}} \log \{r^{(0, t)}\} \frac{\partial_{\omega}f(\omega^{(t)}, t)}{f(\omega^{(t)}, t)}.$$

However, when dealing with the ground truth score, the implementation is

https://github.com/jasonkyuyim/se3_diffusion/blob/6bf12ed7b59b8b917223de15f17ca33feda8abcf/data/so3_diffuser.py#L273-L288

https://github.com/jasonkyuyim/se3_diffusion/blob/6bf12ed7b59b8b917223de15f17ca33feda8abcf/data/so3_diffuser.py#L219-L240

I believe that np.interp(omega, self.discrete_omega, self._score_norms[self.t_to_idx(t)] )[:, None] is $\frac{\partial_{\omega}f(\omega^{(t)}, t)}{f(\omega^{(t)}, t)}$, vec is $\log \{r^{(0, t)}\}$, and omega[:, None] + eps is $\omega^{(t)}$. It seems that $r^{(t)}$ is missing, and I cannot find $r^{(t)}$ in the succeeding code (se3_diffuser.py and pdb_data_loader.py) either.

When dealing with the predicted score, the implementation is
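The mapping described in the paragraph above can be sketched in a few lines. This is not the repository's implementation: it assumes a truncated series for $f$, replaces the precomputed, interpolated `_score_norms` table with a central finite difference for $\partial_{\omega}f / f$, and uses hypothetical names (`igso3_density_so3`, `dlog_f_domega`, `score_vec`).

```python
import numpy as np

def igso3_density_so3(omega, t, num_terms=200):
    """Truncated series for f(omega, t), the density w.r.t. the Haar measure."""
    omega = np.atleast_1d(np.asarray(omega, dtype=float))
    l = np.arange(num_terms)[:, None]
    terms = ((2 * l + 1) * np.exp(-l * (l + 1) * t / 2)
             * np.sin((l + 0.5) * omega[None, :]) / np.sin(omega[None, :] / 2))
    return terms.sum(axis=0)

def dlog_f_domega(omega, t, h=1e-5):
    """Central finite difference for (d f / d omega) / f (stand-in for a lookup table)."""
    f_plus = igso3_density_so3(omega + h, t)
    f_minus = igso3_density_so3(omega - h, t)
    return (f_plus - f_minus) / (2 * h * igso3_density_so3(omega, t))

def score_vec(vec, t, eps=1e-6):
    """Rotation-vector form of the score: vec / omega * (d_omega f / f)."""
    omega = np.linalg.norm(vec, axis=-1)       # rotation angle, shape (N,)
    return vec / (omega[:, None] + eps) * dlog_f_domega(omega, t)[:, None]

# Hypothetical rotation vector playing the role of log{r^(0,t)}; its angle is 1.3 rad.
vec = np.array([[0.3, -0.4, 1.2]])
score = score_vec(vec, t=1.0)
print("score vector:", score)
```

In this form the result is a plain 3-vector of coefficients parallel to `vec`; no rotation matrix $r^{(t)}$ appears until the vector is mapped to a tangent vector at $r^{(t)}$.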

https://github.com/jasonkyuyim/se3_diffusion/blob/6bf12ed7b59b8b917223de15f17ca33feda8abcf/model/ipa_pytorch.py#L648-L663

https://github.com/jasonkyuyim/se3_diffusion/blob/6bf12ed7b59b8b917223de15f17ca33feda8abcf/data/se3_diffuser.py#L119-L125

https://github.com/jasonkyuyim/se3_diffusion/blob/6bf12ed7b59b8b917223de15f17ca33feda8abcf/data/so3_diffuser.py#L242-L267

At the beginning (torch_score function), the $r^{(t)}$ is also missing. But at the end of ipa_pytorch.py, it seems that rot_score = init_rots.apply(rot_score) gives us the $r^{(t)}$.

In summary, I found that $r^{(t)}$ is missing in the ground truth score but exists in the predicted score. Is this a bug or something? Will it affect the performance significantly?

p.s. I have checked the refactor, and $r^{(t)}$ exists in both scores.

longlongman commented 1 year ago

Sorry, I made a mistake here. I did not notice that the code is based on so(3) (the Lie algebra), while the paper is based on SO(3) (the group). I guess that when the score is written in so(3), we do not need $r^{(t)}$. As for rot_score = init_rots.apply(rot_score) in ipa_pytorch.py, I guess that its purpose is to make the model rotation invariant.
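The distinction drawn here, between score coefficients in so(3) and a tangent vector at $r^{(t)}$, can be illustrated with a small sketch. It assumes the left-multiplication convention of the formula quoted earlier in the thread (an so(3) element $A$ becomes the tangent vector $r^{(t)}A$); the rotation and the coefficient values are hypothetical, and the sketch does not settle which convention the repository uses.

```python
import numpy as np

def hat(v):
    """Map a 3-vector to the corresponding skew-symmetric matrix in so(3)."""
    x, y, z = v
    return np.array([[0.0, -z, y],
                     [z, 0.0, -x],
                     [-y, x, 0.0]])

def rot_z(angle):
    """Rotation by `angle` radians about the z-axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0],
                     [s, c, 0.0],
                     [0.0, 0.0, 1.0]])

score_so3_vec = np.array([0.1, -0.2, 0.3])   # hypothetical score coefficients in so(3)
r_t = rot_z(0.7)                             # hypothetical noised rotation r^(t)

tangent_at_r = r_t @ hat(score_so3_vec)      # element of the tangent space at r^(t)
pulled_back = r_t.T @ tangent_at_r           # the same score expressed in so(3) again

print("tangent vector at r^(t):\n", tangent_at_r)
print("pulled back to so(3):\n", pulled_back)
```

The two objects carry the same information; they differ only by the factor $r^{(t)}$, which is why a score stored as so(3) coefficients shows no explicit $r^{(t)}$.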

jasonkyuyim commented 1 year ago

@longlongman I've pushed a commit that should fix training. https://github.com/jasonkyuyim/se3_diffusion/commit/2ac1ce9be866120047cfcb506e628729fd8eced8 You were right to point out we were missing an $r^{(t)}$ in the ground truth score. This got lost when we were trying to refactor the score framework to be more readable. I've now pushed the invariant predicted rotation score (it becomes equivariant when we take the reverse Euler step). It was our bad for not testing the refactor well enough... We'll get around to refactoring soon. Unfortunately this means the score math won't match one-to-one with the paper.

I've also corrected training hyperparameters that I missed.

The training curves on my end match the published model after a day of training but I'll be monitoring to make sure it can recover the true training behavior. Sorry again for the confusion and thanks so much for using our code and pointing out our bugs!