qazwsxal / diffusion-extensions

Reference implementation of diffusion models on SO(3)

About the use of so3_lerp and so3_scale #1

Open · Foruck opened 2 years ago

Foruck commented 2 years ago

Thanks for sharing your marvelous project! I have some questions about the use of so3_lerp and so3_scale. Based on the paper, it seems that the following code should use so3_scale rather than so3_lerp when interpolating toward the identity rotation: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/diffusion.py#L286

Is there anything I missed or misunderstood? Or does that mean so3_lerp could approximate so3_scale? Your response would help a lot.

qazwsxal commented 2 years ago

Hi Foruck,

Apologies for the quality of the code; it is in need of some cleanup! You're right that so3_scale should be used here. Fortunately, because we're interpolating between the identity and x_start, converting to axis-angle format and back means we still choose a mean on the geodesic between them, so we get the same result.
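A minimal NumPy sketch of the point above, under stated assumptions: `axis_angle_to_matrix` and `geodesic_angle` are hypothetical helpers (not functions from the repository), and `x_start` is an arbitrary example rotation. Keeping the axis of x_start fixed and scaling its angle by t yields a rotation whose distances to the identity and to x_start add up to the full distance, i.e. it lies on the geodesic between them.

```python
import numpy as np

def axis_angle_to_matrix(axis, angle):
    """Rodrigues' formula: rotation by `angle` radians about `axis` (normalized here)."""
    x, y, z = axis / np.linalg.norm(axis)
    K = np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)

def geodesic_angle(R1, R2):
    """Angle (radians) of the relative rotation R1^T R2: the geodesic distance on SO(3)."""
    cos_theta = (np.trace(R1.T @ R2) - 1.0) / 2.0
    return np.arccos(np.clip(cos_theta, -1.0, 1.0))

# Hypothetical x_start: 120 degrees about an arbitrary axis.
axis = np.array([1.0, 2.0, 2.0])
theta = np.deg2rad(120.0)
x_start = axis_angle_to_matrix(axis, theta)

# Axis-angle interpolation from the identity toward x_start:
# keep the axis, scale the angle by t.
t = 0.3
R_t = axis_angle_to_matrix(axis, t * theta)

d_total = geodesic_angle(np.eye(3), x_start)
d_first = geodesic_angle(np.eye(3), R_t)
d_second = geodesic_angle(R_t, x_start)
# d_first + d_second == d_total, so R_t sits on the geodesic from I to x_start.
print(np.round(np.degrees([d_total, d_first, d_second]), 6))  # approximately 120, 36, 84
```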

so3_lerp is a more general function that interpolates along the geodesic between two arbitrary points in SO(3). You can think of so3_scale as a special case that interpolates along the geodesic between the identity and another point in SO(3), using the Lie-group formula (matrix logarithm and exponential) rather than the axis-angle formula.
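To make the distinction concrete, here is a hedged NumPy sketch. `scale_sketch` and `lerp_sketch` are hypothetical stand-ins written from the description above, not the repository's so3_scale and so3_lerp; the matrix exponential is computed with Rodrigues' formula, and the logarithm is the principal one (angle below pi). Setting R0 to the identity makes the general formula reduce to the identity-anchored one.

```python
import numpy as np

def hat(w):
    """3-vector -> skew-symmetric matrix (an element of so(3))."""
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])

def exp_so3(w):
    """Matrix exponential of hat(w) via Rodrigues' formula."""
    theta = np.linalg.norm(w)
    if theta < 1e-12:
        return np.eye(3) + hat(w)  # first-order approximation near the identity
    K = hat(w / theta)
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def log_so3(R):
    """Principal matrix logarithm as an axis-angle vector (assumes angle < pi)."""
    theta = np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0))
    if theta < 1e-12:
        return np.zeros(3)
    vee = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return theta / (2.0 * np.sin(theta)) * vee

def scale_sketch(R, t):
    """Identity-anchored interpolation exp(t log R): t=0 gives I, t=1 gives R."""
    return exp_so3(t * log_so3(R))

def lerp_sketch(R0, R1, t):
    """Geodesic interpolation between arbitrary R0 and R1: R0 exp(t log(R0^T R1))."""
    return R0 @ exp_so3(t * log_so3(R0.T @ R1))

x_start = exp_so3(np.array([0.3, -0.5, 0.8]))
t = 0.4
# With R0 = I the two formulas coincide.
print(np.allclose(scale_sketch(x_start, t), lerp_sketch(np.eye(3), x_start, t)))  # True

# The general version hits both endpoints for arbitrary R0, R1.
R0 = exp_so3(np.array([0.1, 0.2, -0.3]))
R1 = exp_so3(np.array([0.5, -0.4, 0.2]))
print(np.allclose(lerp_sketch(R0, R1, 0.0), R0), np.allclose(lerp_sketch(R0, R1, 1.0), R1))  # True True
```

The identity-anchored version needs one log and one exp; the general version also needs the relative rotation R0^T R1, which is what lets it work between any two points.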

The Lie-group interpolation formula would have difficulty in the more general case, because the matrix logarithm is non-unique, and the straight line in R^(3x3) between the matrix logarithms we've calculated may not be the geodesic in SO(3). For example, consider two rotations about the same axis, one by 170 degrees and the other by -170 degrees. They are only 20 degrees apart in SO(3), but their calculated matrix logarithms would be far apart in R^(3x3), and the line between them would pass through the identity in SO(3): definitely not the shortest path between them!
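The 170/-170 degree example can be checked numerically. This sketch uses hypothetical NumPy helpers and rotations about the z axis: the principal logarithms are about 340 degrees apart as vectors, the rotations themselves are only 20 degrees apart, the midpoint of the straight line between the logarithms is the identity, and the geodesic midpoint is the 180 degree rotation.

```python
import numpy as np

def rot_z(deg):
    """Rotation by `deg` degrees about the z axis."""
    c, s = np.cos(np.radians(deg)), np.sin(np.radians(deg))
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def hat(w):
    return np.array([[0.0, -w[2], w[1]], [w[2], 0.0, -w[0]], [-w[1], w[0], 0.0]])

def exp_so3(w):
    """Rodrigues' formula for the matrix exponential of hat(w)."""
    theta = np.linalg.norm(w)
    if theta < 1e-12:
        return np.eye(3) + hat(w)
    K = hat(w / theta)
    return np.eye(3) + np.sin(theta) * K + (1.0 - np.cos(theta)) * (K @ K)

def log_so3(R):
    """Principal logarithm as an axis-angle vector (valid below 180 degrees)."""
    theta = np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0))
    if theta < 1e-12:
        return np.zeros(3)
    vee = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    return theta / (2.0 * np.sin(theta)) * vee

def angle_deg(R):
    """Rotation angle of R in degrees."""
    return np.degrees(np.arccos(np.clip((np.trace(R) - 1.0) / 2.0, -1.0, 1.0)))

A, B = rot_z(170.0), rot_z(-170.0)
wA, wB = log_so3(A), log_so3(B)  # roughly +170 and -170 degrees about z

print(np.degrees(np.linalg.norm(wA - wB)))  # ~340: far apart as vectors
print(angle_deg(A.T @ B))                   # ~20: close together on SO(3)

# Midpoint of the straight line between the logs: the zero vector, i.e. the identity.
naive_mid = exp_so3(0.5 * (wA + wB))
# Midpoint along the geodesic: A exp(0.5 log(A^T B)), a 180 degree rotation about z.
geodesic_mid = A @ exp_so3(0.5 * log_so3(A.T @ B))

print(angle_deg(naive_mid), angle_deg(geodesic_mid))  # ~0 and ~180
```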

Foruck commented 2 years ago

Thanks a lot for your response! Based on it, I'm trying to replace so3_scale with so3_lerp, mainly because the matrix_exp function used in so3_scale is quite time-consuming. However, I ran into a problem: the loss becomes NaN. Do you have any clues? Does that mean so3_scale is more numerically stable?
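The thread does not say what caused the NaN. As an assumption only, two frequent culprits in axis-angle code are arccos receiving (trace - 1) / 2 slightly outside [-1, 1] because of rounding, and the factor theta / (2 sin theta) evaluating to 0/0 at the identity; under autograd, the derivative of arccos, -1/sqrt(1 - x^2), also diverges at ±1. The sketch below uses hypothetical NumPy helpers (not the repository's code) to show one way to guard both: clamp strictly inside the domain, and switch to a Taylor series near zero.

```python
import numpy as np

def naive_angle(cos_theta):
    # Unguarded: rounding can push (trace - 1) / 2 slightly above 1 or below -1.
    return np.arccos(cos_theta)

def safe_angle(cos_theta, eps=1e-7):
    # Clamp strictly inside (-1, 1) so arccos, and its derivative under
    # autograd, stay finite.
    return np.arccos(np.clip(cos_theta, -1.0 + eps, 1.0 - eps))

def safe_log(R, eps=1e-7):
    """Axis-angle log with a small-angle series instead of 0/0.

    Still not accurate near 180 degrees, where the vee term vanishes.
    """
    theta = safe_angle((np.trace(R) - 1.0) / 2.0, eps)
    vee = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    # theta / (2 sin theta) = 1/2 + theta^2 / 12 + O(theta^4) near zero.
    factor = 0.5 + theta**2 / 12.0 if theta < 1e-3 else theta / (2.0 * np.sin(theta))
    return factor * vee

with np.errstate(invalid="ignore"):
    print(naive_angle(1.0 + 1e-12))  # nan: arccos is undefined above 1
print(safe_angle(1.0 + 1e-12))       # small but finite
print(safe_log(np.eye(3)))           # zero vector rather than nan
```

Because the vee term carries the magnitude for small rotations, the clamp only perturbs the factor (which stays close to 1/2), so small-angle logs remain accurate.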