pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

procrustes alignment for pytorch #8577

Open heth27 opened 1 month ago

heth27 commented 1 month ago

🚀 The feature

Orthogonal procrustes alignment

Motivation, pitch

Procrustes alignment is a staple when calculating metrics for 3D human pose estimation, but there seems to be no library that offers this function for PyTorch, so I guess everyone just maintains their own version.

There is a variant in scipy https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
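
For comparison, a minimal sketch of what the SciPy function linked above returns. The point sets are made up for illustration. `scipy.spatial.procrustes` standardizes both inputs and returns the two transformed matrices plus a disparity value rather than the scale, rotation and translation themselves, and it also allows reflections, so it is not a drop-in replacement for a rotation-only alignment.

```python
import numpy as np
from scipy.spatial import procrustes

# Made-up example data: b is a rotated, scaled and shifted copy of a.
a = np.array([[0.0, 0.0, 0.0],
              [1.0, 0.0, 0.0],
              [0.0, 2.0, 0.0],
              [0.0, 0.0, 3.0]])
rot_z = np.array([[0.0, -1.0, 0.0],
                  [1.0, 0.0, 0.0],
                  [0.0, 0.0, 1.0]])  # 90 degrees about the z axis
b = 2.5 * a @ rot_z.T + np.array([1.0, -2.0, 0.5])

# mtx1 is a standardized copy of a, mtx2 is b fitted onto it, and
# disparity is the sum of squared differences between the two.
mtx1, mtx2, disparity = procrustes(a, b)
print(disparity)  # close to zero here, since b is an exact similarity transform of a
```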

Alternatives

No response

Additional context

The implementation I'm using, don't know if it is any good.

import torch


def procrustes(pts1: torch.Tensor, pts2: torch.Tensor):
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.shape[-1] == 3 and len(pts1.shape) == 2, f"{pts1.shape}"
    # estimate a sim3 transformation (scale s, rotation R, translation)
    # that aligns pts2 to pts1: pts1 ~ s * (pts2 - t2) @ R.T + t1
    # center both point clouds
    t1 = pts1.mean(dim=0)
    t2 = pts2.mean(dim=0)
    pts1 = pts1 - t1[None, :]
    pts2 = pts2 - t2[None, :]

    # normalize by the RMS distance to the centroid
    s1 = pts1.square().sum(dim=-1).mean().sqrt()
    s2 = pts2.square().sum(dim=-1).mean().sqrt()
    pts1 = pts1 / s1
    pts2 = pts2 / s2
    try:
        # torch.linalg.svd returns Vh (V transposed), unlike the deprecated Tensor.svd()
        U, _, Vh = torch.linalg.svd((pts1.T @ pts2).double())
    except RuntimeError:
        print("Procrustes failed: SVD did not converge!")
        # fall back to the identity transform, with the same 4-tuple as below
        eye = torch.eye(3, device=pts1.device, dtype=pts1.dtype)
        return pts1.new_tensor(1.0), eye, torch.zeros_like(t1), torch.zeros_like(t2)
    # build rotation matrix; if U @ Vh is a reflection, flip the singular
    # direction with the smallest singular value before multiplying
    if torch.linalg.det(U @ Vh) < 0:
        U = torch.cat([U[:, :-1], -U[:, -1:]], dim=1)
    R = (U @ Vh).to(pts1.dtype)
    # ratio of RMS scales; this matches the least-squares scale only
    # when the two clouds align exactly
    s = s1 / s2

    # use as s * R @ (p - t2) + t1 for a point p of pts2,
    # or as s * R @ p + t with the single translation t = t1 - s * t2 @ R.T
    return s, R, t1, t2

Example usage:

s, R, mean_1, mean_2 = procrustes(coords_3d_true, coords_3d_prediction)
procrustes_aligned = torch.einsum("jd, od -> jo", coords_3d_prediction - mean_2, s * R) + mean_1
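
Since the motivation is pose metrics, here is a minimal self-contained sketch of a Procrustes-aligned mean per-joint position error (often called PA-MPJPE), with a synthetic check. It is not the function above: `align_similarity` and `pa_mpjpe` are hypothetical names, the data is random, and this variant assumes the least-squares scale taken from the singular values instead of the ratio of RMS scales.

```python
import math

import torch


def align_similarity(target: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """Map source onto target with a best-fit scale, rotation and translation.

    Both inputs have shape (J, 3). Hypothetical helper, for illustration only.
    """
    mu_t = target.mean(dim=0)
    mu_s = source.mean(dim=0)
    t0 = (target - mu_t).double()
    s0 = (source - mu_s).double()

    # SVD of the 3x3 cross-covariance; torch.linalg.svd returns Vh, not V.
    U, S, Vh = torch.linalg.svd(t0.T @ s0)

    # If the best orthogonal fit is a reflection, flip the last singular direction.
    d = torch.ones(3, dtype=torch.float64, device=target.device)
    if torch.linalg.det(U @ Vh) < 0:
        d[-1] = -1.0
    R = U @ torch.diag(d) @ Vh

    # Least-squares scale for the chosen rotation.
    scale = (S * d).sum() / s0.square().sum()

    aligned = scale * s0 @ R.T + mu_t.double()
    return aligned.to(target.dtype)


def pa_mpjpe(target: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """Mean per-joint Euclidean error after similarity alignment."""
    return (align_similarity(target, source) - target).norm(dim=-1).mean()


# Synthetic check: the "prediction" is a rotated, scaled, shifted ground truth.
torch.manual_seed(0)
gt = torch.randn(17, 3)
c, s = math.cos(math.pi / 3), math.sin(math.pi / 3)
rot = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
pred = 0.5 * gt @ rot.T + torch.tensor([0.1, -0.2, 0.3])
print(pa_mpjpe(gt, pred))  # should be close to zero, up to float32 rounding
```

A mirrored prediction should not align to zero error with this sketch, because the determinant check restricts the fit to proper rotations.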
heth27 commented 1 month ago

If this is better suited for e.g. torchmetrics (https://lightning.ai/docs/torchmetrics/stable/) this would also be good to know

NicolasHug commented 1 month ago

Hi @heth27 and thank you for the feature request. Torchvision doesn't really have holistic support for 3D data in general, so I'm not sure Procrustes alignment would be in scope. We typically add such metrics when they directly relate to one of the CV tasks that torchvision supports (classification, detection, etc.), but 3D human pose is not yet in scope. Thank you for providing a snippet, I hope it can be useful to users looking for this exact feature.

NicolasHug commented 1 month ago

> If this is better suited for e.g. torchmetrics (lightning.ai/docs/torchmetrics/stable) this would also be good to know

It might be in scope for torchmetrics, although note that this isn't owned by the pytorch org, so we don't have any weight in the decision process over there.