jasonkyuyim / se3_diffusion

Implementation for SE(3) diffusion model with application to protein backbone generation
https://arxiv.org/abs/2302.02277
MIT License

math in SE(3) diffusion model #37

Open knightzzz9w opened 7 months ago

knightzzz9w commented 7 months ago

Dear author, I have read your paper and your code; it's very impressive. I have some questions.

1. How did you design the forward process? Why not consider a formula like `H_t = Exp( sqrt(1 - alpha_bar_t) * z_t + sqrt(alpha_bar_t) * Log(H_0) )`, i.e. add noise on se(3) and then map it to SE(3)? Also, what is the role of the scale factor lambda?
2. Since the forward process involves computing Log and Exp, does the backward process q(x_{t-1} | x_t, x_0) follow the DDPM posterior formula? If I use the DDIM formula instead, can it still work?
3. If the training loss is not the one in your paper but a CD loss or HD loss instead, can it still work?
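
Written out, the forward process I have in mind for question 1 would be the following (just my proposal restated in standard DDPM notation, not the process from the paper):

```latex
H_t = \mathrm{Exp}\!\left( \sqrt{1 - \bar{\alpha}_t}\; z_t + \sqrt{\bar{\alpha}_t}\; \mathrm{Log}(H_0) \right),
\qquad z_t \sim \mathcal{N}(0, I) \ \text{in } \mathfrak{se}(3).
```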

Recently, I have tried to use the method from your paper for an SMPL pose-recovery task. The input is 2D keypoints, t, and x_t1; using the DDIM process, the output is x_t2. However, it doesn't work well: the training loss is small, but the loss at inference is high. I also added an MSE loss on the final predicted joints. I guess it may just not work with SE(3) diffusion?

I really can't figure out these problems and hope for your reply!

jasonkyuyim commented 7 months ago

Hi, thanks for the interest.

Answer to question 1. We follow the formulation in Riemannian Score-Based Generative Modelling. Taking se(3) as the tangent space, we do indeed add noise in the tangent space and then map it back onto SE(3); see Algorithm 1 of De Bortoli et al. So I'm not sure what you mean with your first question.
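
A minimal sketch of that tangent-space noising idea for the rotation component, using scipy rather than the repo's code (an illustration only: the paper's actual forward kernel is the IGSO(3) density, and translations are perturbed with ordinary Euclidean Gaussian noise):

```python
# Illustration only: add Gaussian noise in the tangent space so(3) (axis-angle
# coordinates) and use the exponential map to land back on SO(3).
import numpy as np
from scipy.spatial.transform import Rotation


def perturb_rotation(R0: Rotation, sigma_t: float, rng=None) -> Rotation:
    """Perturb a rotation by a tangent-space Gaussian with scale sigma_t."""
    rng = np.random.default_rng() if rng is None else rng
    tangent_noise = sigma_t * rng.standard_normal(3)  # element of so(3)
    return R0 * Rotation.from_rotvec(tangent_noise)   # Exp map, composed with R0


# Example: perturb the identity rotation with noise scale 0.5.
Rt = perturb_rotation(Rotation.identity(), sigma_t=0.5)
print(Rt.as_matrix())
```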

Answer to question 2. We follow the VE-SDE formulation, whereas DDPM is known as the VP-SDE formulation. There is a paper exploring DDPM on SO(3) that you can find here: Denoising Diffusion Probabilistic Models on SO(3) for Rotational Alignment. The VE-SDE formulation in De Bortoli et al. makes the most sense to me since it converges to a uniform prior over SO(3). I haven't thought about or looked into using DDIM in this setting.
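
For context, the Euclidean perturbation kernels of the two families (from Song et al., 2021) are roughly as follows; the SO(3) version in the paper swaps the Gaussian for the IGSO(3) density but keeps the VE structure, which is why the marginal tends to the uniform distribution on SO(3):

```latex
% Standard Euclidean forms, for reference only (not the paper's SO(3) derivation).
\begin{align*}
  \text{VE-SDE:} \quad & q(x_t \mid x_0) = \mathcal{N}\!\big(x_t;\ x_0,\ \sigma_t^2 I\big),
    \qquad \sigma_t \ \text{growing in } t, \\
  \text{VP-SDE (DDPM):} \quad & q(x_t \mid x_0) = \mathcal{N}\!\big(x_t;\ \sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\, I\big).
\end{align*}
```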

Answer to question 3. I'm not sure what you mean by a CD loss or HD loss.

Answer to pose recovery. "The input is 2D keypoints, t, and x_t1; using the DDIM process, the output is x_t2." Isn't this deviating from our work? Our model's outputs are the scores, which get used during the generative (reverse) diffusion process. I can't offer more help without knowing your set-up, and unfortunately I'm low on bandwidth. I would recommend debugging by overfitting on a few examples and then gradually increasing the dataset size to make sure your implementation is correct.
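
If it helps, the overfitting check could look something like the sketch below; `model`, `loss_fn`, and `batch` are hypothetical placeholders, not names from this repo or your code:

```python
# Hypothetical sanity check: overfit one small, fixed batch. If the training loss
# does not drive close to zero, or sampling afterwards cannot reproduce those few
# examples, the pipeline (noise schedule, score scaling, etc.) likely has a bug.
import torch


def overfit_check(model, loss_fn, batch, steps=2000, lr=1e-4):
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    for step in range(steps):
        opt.zero_grad()
        loss = loss_fn(model, batch)  # use the exact same loss as full training
        loss.backward()
        opt.step()
        if step % 200 == 0:
            print(f"step {step}: loss {loss.item():.4f}")
```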

Hope that helps.