qazwsxal / diffusion-extensions

Reference implementation of diffusion models on SO(3)

About the conversion from axis-angle representation to skew-symmetric matrix #2

Open jiaxiang-wu opened 1 year ago

jiaxiang-wu commented 1 year ago

Hi, I have a question about the conversion from the axis-angle representation vector to a skew-symmetric matrix, as implemented here: https://github.com/qazwsxal/diffusion-extensions/blob/f100885dc7f33beda7a5182324eb19f29b29fd47/util.py#L87

In the above implementation, the skew-symmetric matrix $S(v)$ for $v = \left[ x, y, z \right]$ is defined as:

$$S(v) = \begin{bmatrix} 0 & -z & y \\ z & 0 & -x \\ -y & x & 0 \end{bmatrix}$$
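A minimal PyTorch sketch of this convention (often called the "hat" map) and its inverse may help. The names `vec2skew`/`skew2vec` mirror the helpers in `util.py`, but this is an independent illustration, not the repository's code. The defining property of this matrix is that $S(v)\,w = v \times w$:

```python
import torch

def vec2skew(v: torch.Tensor) -> torch.Tensor:
    """Map [..., 3] vectors to [..., 3, 3] skew-symmetric matrices.

    Convention: S(v) = [[0, -z, y], [z, 0, -x], [-y, x, 0]], so S(v) @ w == v x w.
    """
    x, y, z = v.unbind(-1)
    zero = torch.zeros_like(x)
    return torch.stack([
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ], dim=-2)

def skew2vec(s: torch.Tensor) -> torch.Tensor:
    """Inverse of vec2skew: read (x, y, z) from entries (2, 1), (0, 2), (1, 0)."""
    return torch.stack([s[..., 2, 1], s[..., 0, 2], s[..., 1, 0]], dim=-1)

v = torch.tensor([1.0, 2.0, 3.0])
w = torch.tensor([-0.5, 0.25, 2.0])
print(torch.allclose(vec2skew(v) @ w, torch.cross(v, w, dim=-1)))  # True
print(torch.equal(skew2vec(vec2skew(v)), v))  # True
```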

which is different from the definition in your ICLR paper (Denoising Diffusion Probabilistic Models on SO(3) for Rotational Alignment, https://openreview.net/forum?id=BY88eBbkpe5, page 3, six lines above Eq. 6).


Did I misunderstand something? Which definition should be used for this conversion?
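The screenshot from the paper is not reproduced here, so the following assumes the paper's matrix is the transpose of $S(v)$, the other common convention. Under that assumption, the two conventions differ only by the sign of the axis-angle vector ($S(v)^\top = S(-v)$), so exponentiating one gives the inverse rotation of the other. This sketch checks that numerically; `hat` is a hypothetical helper implementing the $S(v)$ above:

```python
import torch

def hat(v: torch.Tensor) -> torch.Tensor:
    # Same convention as S(v) above: hat(v) @ w == cross(v, w)
    x, y, z = v.unbind(-1)
    zero = torch.zeros_like(x)
    return torch.stack([
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ], dim=-2)

v = torch.tensor([0.3, -0.2, 0.5], dtype=torch.float64)
S = hat(v)
S_alt = S.transpose(-1, -2)  # assumed alternative (transposed) convention

R = torch.linalg.matrix_exp(S)
R_alt = torch.linalg.matrix_exp(S_alt)

# The transposed convention is the same as negating the axis-angle vector...
print(torch.allclose(S_alt, hat(-v)))  # True
# ...so it produces the inverse rotation: R_alt == R^T == R^{-1}
print(torch.allclose(R_alt, R.T))  # True
print(torch.allclose(R_alt @ R, torch.eye(3, dtype=torch.float64)))  # True
```

If that assumption holds, either convention is self-consistent as long as the vector-to-matrix map, its inverse, and the logarithm all use the same one; mixing them flips the direction of rotation.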

dexin-wang commented 10 months ago

Hello, I also ran into this problem. In addition, in the `log_rmat()` function, `skew_mat` is computed with the opposite sign to the formula in the paper. Have you found a solution?
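The sign in `log_rmat()` can be checked against Rodrigues' formula $R = I + \sin\theta\,S(n) + (1 - \cos\theta)\,S(n)^2$ for a unit axis $n$. Because $S(n)^2$ is symmetric, $R - R^\top = 2\sin\theta\,S(n)$, so with the $S(v)$ convention from the first post, `r_mat - r_mat.transpose(-1, -2)` is the sign that recovers $\log R = S(\theta n)$. A rough numerical check, using a hypothetical `hat` helper:

```python
import math
import torch

def hat(v: torch.Tensor) -> torch.Tensor:
    # Convention from the first post: [[0, -z, y], [z, 0, -x], [-y, x, 0]]
    x, y, z = v.unbind(-1)
    zero = torch.zeros_like(x)
    return torch.stack([
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ], dim=-2)

theta = 0.7
n = torch.tensor([1.0, 2.0, 2.0], dtype=torch.float64) / 3.0  # unit axis
R = torch.linalg.matrix_exp(hat(theta * n))

skew_mat = R - R.T  # same sign as in log_rmat()
print(torch.allclose(skew_mat, 2 * math.sin(theta) * hat(n)))  # True

# Rescaling by theta / (2 sin theta) recovers the logarithm hat(theta * n)
log_R = theta / (2 * math.sin(theta)) * skew_mat
print(torch.allclose(log_R, hat(theta * n)))  # True
# Flipping the sign instead gives the logarithm of R^T, the inverse rotation
print(torch.allclose(torch.linalg.matrix_exp(-log_R), R.T))  # True
```

The opposite sign, as in the paper's formula, would be consistent if the paper uses the transposed matrix convention; that is an inference from the symptoms described here, not something confirmed by the authors.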

```python
import torch

# skew2vec() and vec2skew() are the helpers defined alongside this function in util.py
def log_rmat(r_mat: torch.Tensor) -> torch.Tensor:
    skew_mat = (r_mat - r_mat.transpose(-1, -2))    # this sign is opposite to the formula in the paper
    sk_vec = skew2vec(skew_mat)
    s_angle = (sk_vec).norm(p=2, dim=-1) / 2    # sin(theta)
    c_angle = (torch.einsum('...ii', r_mat) - 1) / 2    # cos(theta)
    angle = torch.atan2(s_angle, c_angle)
    scale = (angle / (2 * s_angle))
    # if s_angle = 0, i.e. rotation by 0 or pi (180), we get NaNs
    # by definition, scale values are 0 if rotating by 0.
    # This also breaks down if rotating by pi, fix further down
    scale[angle == 0.0] = 0.0
    log_r_mat = scale[..., None, None] * skew_mat

    # Check for NaNs caused by 180deg rotations.
    nanlocs = log_r_mat[..., 0, 0].isnan()
    nanmats = r_mat[nanlocs]
    # We need to use an alternative way of finding the logarithm for nanmats,
    # Use eigendecomposition to discover axis of rotation.
    # By definition, these are symmetric, so use eigh.
    # NOTE: linalg.eig() isn't in torch 1.8,
    #       and torch.eig() doesn't do batched matrices
    eigval, eigvec = torch.linalg.eigh(nanmats)
    # Final eigenvalue == 1, might be slightly off because floats, but other two are -ve.
    # eigh returns eigenvectors as columns, so the axis is the last column
    # (the original eigvec[..., -1, :] takes the last row instead).
    nan_axes = eigvec[..., :, -1]
    nan_angle = angle[nanlocs]
    nan_skew = vec2skew(nan_angle[..., None] * nan_axes)
    log_r_mat[nanlocs] = nan_skew
    return log_r_mat
```
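Separately from the sign question, the 180° branch can be sanity-checked. For a rotation by $\pi$ about a unit axis $n$, $R = 2nn^\top - I$ is symmetric with eigenvalues $-1, -1, 1$, and `torch.linalg.eigh` returns eigenvalues in ascending order with eigenvectors as columns. The axis is therefore the last column, `eigvec[..., :, -1]`; the last row, `eigvec[..., -1, :]`, is not an eigenvector in general. A small sketch (the `hat` helper is hypothetical):

```python
import math
import torch

def hat(v: torch.Tensor) -> torch.Tensor:
    # Convention from the first post: hat(v) @ w == cross(v, w)
    x, y, z = v.unbind(-1)
    zero = torch.zeros_like(x)
    return torch.stack([
        torch.stack([zero, -z, y], dim=-1),
        torch.stack([z, zero, -x], dim=-1),
        torch.stack([-y, x, zero], dim=-1),
    ], dim=-2)

n = torch.tensor([2.0, -1.0, 2.0], dtype=torch.float64) / 3.0  # unit axis
# A 180-degree rotation about n is symmetric: R = 2 n n^T - I
R = 2 * torch.outer(n, n) - torch.eye(3, dtype=torch.float64)

eigval, eigvec = torch.linalg.eigh(R)  # ascending eigenvalues, eigenvectors in columns
axis = eigvec[..., :, -1]              # eigenvector for the eigenvalue 1

# The axis is only defined up to sign; for a pi rotation both signs give the same R
print(torch.allclose(axis.abs(), n.abs()))  # True
print(torch.allclose(torch.linalg.matrix_exp(hat(math.pi * axis)), R))  # True
```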