jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License

About rotation loss dynamics #18

Closed amorehead closed 1 year ago

amorehead commented 1 year ago

Hello.

I was curious whether you have observed any particular behaviors worth noting with respect to the rotation and translation diffusion losses. Specifically, when you train your own models, do you notice the translation loss being easier for the model to optimize, or are both loss terms equally challenging for the model?

jasonkyuyim commented 1 year ago

Hi, I believe the translation is harder to optimize. Once the model gets the translations right, the rotations become easy. I've done mini experiments of fixing rotations or translations and then trying to learn the other component; rotations were learned immediately. I recommend using the wandb evaluation setup in the codebase. The losses don't tell you much. The best way to see how well you're doing is to run sampling between training epochs and evaluate metrics like steric clashes and secondary structure composition.
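
For concreteness, here is a minimal sketch of that kind of inter-epoch check (assuming wandb is initialized; `sample_fn` and the Ca-only clash metric below are placeholder illustrations, not this repo's actual API):

```python
import numpy as np
import wandb

CLASH_CUTOFF = 3.0  # Angstroms; non-adjacent Ca pairs closer than this count as clashes


def ca_clash_fraction(ca_coords: np.ndarray) -> float:
    """Fraction of non-adjacent Ca-Ca pairs closer than CLASH_CUTOFF."""
    n = len(ca_coords)
    dists = np.linalg.norm(ca_coords[:, None, :] - ca_coords[None, :, :], axis=-1)
    iu = np.triu_indices(n, k=2)  # skip self-pairs and bonded neighbours
    return float(np.mean(dists[iu] < CLASH_CUTOFF))


def evaluate_between_epochs(sample_fn, num_samples: int, length: int, epoch: int):
    """sample_fn(length) is a placeholder returning an (N, 3) array of Ca coordinates."""
    clashes = [ca_clash_fraction(sample_fn(length)) for _ in range(num_samples)]
    wandb.log({"eval/ca_clash_fraction": float(np.mean(clashes))}, step=epoch)
```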

amorehead commented 1 year ago

Hi, @jasonkyuyim. I hope you are doing well. I was curious whether you have also observed (what appear to be) instabilities in the training rot_loss for this model over the course of the two weeks it takes to train to convergence. The first screenshot below shows my rot_loss curve after two days of training on 2 A100 GPUs using DDP. Interestingly, the screenshots after it show all of my other training losses, which appear to be steadily decreasing or converging.

[screenshots: rot_loss curve, followed by the other training loss curves]

In your experience, is this expected behavior? Are the rotation loss dynamics simply much harder to optimize compared to translations or backbone atom positions?

amorehead commented 1 year ago

As a follow-up, it looks like the rotation loss decreases slightly but still maintains a large degree of variance throughout training. I'm not sure whether this indicates a bug somewhere in my training setup:

[screenshot: rot_loss curve over a longer stretch of training]
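
In case it helps anyone debugging similar behavior, here is a hedged, generic diagnostic sketch (not code from this repo): log the unreduced rotation loss bucketed by diffusion time t, which shows whether the variance is concentrated at particular noise levels or spread across all timesteps.

```python
import torch


def rot_loss_by_time(rot_loss_per_sample: torch.Tensor,
                     t: torch.Tensor,
                     num_bins: int = 10) -> dict:
    """rot_loss_per_sample: (B,) unreduced rotation loss; t: (B,) diffusion times in [0, 1]."""
    bins = torch.clamp((t * num_bins).long(), max=num_bins - 1)
    out = {}
    for b in range(num_bins):
        mask = bins == b
        if mask.any():
            out[f"rot_loss/t_{b / num_bins:.1f}"] = rot_loss_per_sample[mask].mean().item()
    return out  # log this dict (e.g., to wandb) each step to localize the instability
```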

jasonkyuyim commented 1 year ago

Hi, we just posted an update that affects the rotation score learning. Please take a look at our README. I'm not sure if it's related to your problem, but it might help. Though I must say, I've never seen the rotation loss get that high. Are you sure you didn't change anything from the repo?

amorehead commented 1 year ago

Thanks, @jasonkyuyim. I will take a closer look at the updates shortly. I appreciate your quick response here. I'll close this issue for now, as I suspect it will be addressed with the rotation loss updates. Great catch, BTW!

P.S. For reference, the above screenshots come from reproducing your experiments in a rewrite of this repository using PyTorch Lightning + Hydra. Hopefully I didn't miss any important details in my implementation along the way (e.g., proper DDP batching with the custom Sampler, handled automatically by Lightning) 😉 . I hope to make this Lightning version of the repository publicly available in the near future.
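
On the custom-Sampler point, here is a hedged sketch of the relevant Lightning setting (assuming Lightning 2.x; older versions call the flag `replace_sampler_ddp`). By default, Lightning wraps the training DataLoader's sampler in a `DistributedSampler` under DDP, so a length-aware custom sampler either needs that wrapping disabled or needs to do its own per-rank sharding.

```python
import lightning as L

# Keep a custom Sampler intact under DDP: disable Lightning's automatic
# DistributedSampler wrapping and let the sampler handle per-rank sharding itself.
trainer = L.Trainer(
    accelerator="gpu",
    devices=2,
    strategy="ddp",
    use_distributed_sampler=False,
)
```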

amorehead commented 1 year ago

For those curious, I can confirm that the latest rotation loss updates have fixed the issue that spawned this thread. Thanks @jasonkyuyim for pushing these fixes out so quickly!

[screenshot: rot_loss curve after the rotation loss updates]

jamaliki commented 10 months ago

Hi @amorehead ,

Did you manage to get everything working with Lightning and your sampler? Is it appreciably faster? If so, would you mind sharing the code?

Best, Kiarash.

jasonkyuyim commented 10 months ago

Hi, the FrameFlow code is probably of interest to you. I started from FrameDiff and refactored it to work with PyTorch Lightning, with DDP support for both training and inference. A future update will add SE(3) diffusion to FrameFlow. https://github.com/microsoft/flow-matching

jamaliki commented 10 months ago

Thanks @jasonkyuyim , I saw this and am taking a look. Great work!

amorehead commented 10 months ago

Agreed. Great work!