Lightning-AI / torchmetrics

Machine learning metrics for distributed, scalable PyTorch applications.
https://lightning.ai/docs/torchmetrics/
Apache License 2.0

procrustes alignment #2691

Closed. heth27 closed this issue 2 weeks ago.

heth27 commented 2 months ago

🚀 Feature

spatial procrustes alignment, a similarity test for two data sets

Motivation

Procrustes alignment is a staple when calculating metrics for 3D human pose estimation, but there seems to be no library that offers this function for PyTorch, so I guess everyone just maintains their own version.

Pitch

There is a variant in scipy https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
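For reference, a minimal sketch of how the SciPy variant behaves. The point set, rotation angle, scale and offset below are made up for illustration. Note that `scipy.spatial.procrustes` returns the two standardized point sets and the disparity, not the scale or rotation themselves.

```python
import numpy as np
from scipy.spatial import procrustes

# Made-up reference points (10 points in 3D); seed chosen arbitrarily.
rng = np.random.default_rng(0)
reference = rng.random((10, 3))

# Build a rotated, scaled and shifted copy of the reference points.
theta = np.pi / 6
rot_z = np.array([
    [np.cos(theta), -np.sin(theta), 0.0],
    [np.sin(theta), np.cos(theta), 0.0],
    [0.0, 0.0, 1.0],
])
moved = 2.5 * reference @ rot_z.T + np.array([1.0, -2.0, 0.5])

# Both inputs are centered and scaled to unit Frobenius norm, then the
# second one is rotated onto the first. The disparity is the remaining
# sum of squared differences.
std_reference, std_moved, disparity = procrustes(reference, moved)
print(disparity)  # numerically zero, since `moved` is a similarity transform of `reference`
```

Because only the standardized matrices come back, anything that needs the actual transform (for example aligning a prediction in its original units) has to recompute it.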

Alternatives

Additional context

This is the implementation I'm using; I don't know if it is any good.

import torch


def procrustes(pts1: torch.Tensor, pts2: torch.Tensor):
    assert pts1.shape == pts2.shape, f"{pts1.shape} != {pts2.shape}"
    assert pts1.shape[-1] == 3 and len(pts1.shape) == 2, f"{pts1.shape}"
    # estimate a sim3 transformation to align two point clouds
    # find M = argmin ||P1 - M @ P2||
    t1 = pts1.mean(dim=0)
    t2 = pts2.mean(dim=0)
    pts1 = pts1 - t1[None, :]
    pts2 = pts2 - t2[None, :]

    s1 = pts1.square().sum(dim=-1).mean().sqrt()
    s2 = pts2.square().sum(dim=-1).mean().sqrt()
    pts1 = pts1 / s1
    pts2 = pts2 / s2
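    try:
        # pts1.T @ pts2 = U @ diag(S) @ V.T, computed in double precision
        U, _, V = (pts1.T @ pts2).double().svd()
    except RuntimeError:
        print("Procrustes failed: SVD did not converge!")
        # fall back to the identity transform, keeping the 4-value return signature
        zero = torch.zeros_like(t1)
        return 1.0, torch.eye(3, device=pts1.device), zero, zero
    # build rotation matrix; if U @ V.T is a reflection, flip the last column of U
    # (the direction of the smallest singular value) to get the closest proper rotation
    if (U @ V.T).det() < 0:
        U[:, 2] *= -1
    R = (U @ V.T).float()
    s = s1 / s2
    # translation of the combined transform s * R @ pts2 + t (computed but not returned)
    t = t1 - s * t2 @ R.T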

    # use as mat4: [sR, t] @ pts2
    # or as s * R @ pts2 + t

    # s, R, mean_1, mean_2 = procrustes(pts1, pts2)
    #
    # procrustes_aligned = torch.einsum("jd, od -> jo", coords3d_pred_rel_dataset_format[index_in_batch] - mean_2,
    #                                               s * R) + mean_1
    return s, R, t1, t2

example usage:

s, R, mean_1, mean_2 = procrustes(coords_3d_true, coords_3d_prediction)
procrustes_aligned = torch.einsum("jd, od -> jo", coords_3d_prediction - mean_2, s * R) + mean_1

The problem with this version is that it does not work on batches.

github-actions[bot] commented 2 months ago

Hi! Thanks for your contribution! Great first issue!

Borda commented 2 months ago

spatial procrustes alignment, a similarity test for two data sets

This sounds good. Would you be interested in adding it to TM? You could create a draft PR and then we can help you finish it... 👼

SkafteNicki commented 2 months ago

Hi @heth27, I took a stab at implementing a batched version of your implementation:

import torch

def procrustus_batch(data1, data2):
    if data1.shape != data2.shape:
        raise ValueError("data1 and data2 must have the same shape")
    if data1.ndim == 2:
        data1 = data1[None, :, :]
        data2 = data2[None, :, :]

    # out-of-place ops so the caller's tensors are not modified
    data1 = data1 - data1.mean(dim=1, keepdim=True)
    data2 = data2 - data2.mean(dim=1, keepdim=True)
    data1 = data1 / torch.linalg.norm(data1, dim=[1,2], keepdim=True)
    data2 = data2 / torch.linalg.norm(data2, dim=[1,2], keepdim=True)

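    try:
        # torch.linalg.svd returns V already transposed, so `v` here is V^T
        u, w, v = torch.linalg.svd(torch.matmul(data2.transpose(1, 2), data1).transpose(1,2), full_matrices=False)
    except RuntimeError as err:
        # torch.linalg.LinAlgError is a subclass of RuntimeError
        raise ValueError("SVD did not converge") from err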
    rotation = torch.matmul(u, v)
    scale = w.sum(1, keepdim=True)
    data2 = scale[:,None] * torch.matmul(data2, rotation.transpose(1,2))
    disparity = (data1 - data2).square().sum(dim=[1,2])
    return disparity

coords_3d_true = torch.rand(2, 10, 3)
coords_3d_prediction = torch.rand(2, 10, 3)

p2 = procrustus_batch(coords_3d_true.clone(), coords_3d_prediction.clone())
print(p2)

from scipy.spatial import procrustes as procrustes_scipy
for i in range(2):
    mtx1, mtx2, disparity = procrustes_scipy(coords_3d_true[i].clone(), coords_3d_prediction[i].clone())
    print(disparity)

For random inputs, the results seem to match scipy. Are you interested in sending a PR, or do you want me to take over?

heth27 commented 2 months ago

Hi @SkafteNicki thank you, please feel free to create a PR. How do you feel about returning the rotation matrix, or the transformed coordinates as well? They are used for downstream calculation of procrustes-aligned mean per joint position error in a lot of human pose estimation tasks.
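To make the downstream use concrete, here is a rough sketch of a Procrustes-aligned MPJPE built on a batched alignment that returns the transformed coordinates together with the scale and rotation. The names `procrustes_align` and `pa_mpjpe` are hypothetical and not part of torchmetrics. The sketch assumes inputs of shape (batch, joints, dims) and uses the least-squares optimal scale, which is not the same as the `s1 / s2` ratio in the snippet from the issue description.

```python
import math

import torch


def procrustes_align(pred: torch.Tensor, target: torch.Tensor):
    """Align `pred` to `target` with a similarity transform (hypothetical helper).

    Both inputs have shape (batch, joints, dims). Returns the aligned
    prediction, the scale (batch,) and the rotation (batch, dims, dims).
    """
    mu_pred = pred.mean(dim=1, keepdim=True)
    mu_target = target.mean(dim=1, keepdim=True)
    x = pred - mu_pred
    y = target - mu_target

    # Cross-covariance between target and prediction, shape (batch, dims, dims).
    cov = y.transpose(1, 2) @ x
    u, s, vh = torch.linalg.svd(cov)

    # If u @ vh is a reflection, flip the direction of the smallest singular
    # value so that the result is a proper rotation (det = +1).
    correction = torch.ones_like(s)
    correction[:, -1] = torch.sign(torch.linalg.det(u @ vh))
    rotation = (u * correction[:, None, :]) @ vh

    # Least-squares optimal scale for the centered point sets.
    scale = (s * correction).sum(dim=1) / x.square().sum(dim=(1, 2))

    aligned = scale[:, None, None] * (x @ rotation.transpose(1, 2)) + mu_target
    return aligned, scale, rotation


def pa_mpjpe(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean per joint position error after alignment, one value per sample."""
    aligned, _, _ = procrustes_align(pred, target)
    return (aligned - target).norm(dim=-1).mean(dim=-1)


# Made-up data: the prediction is a shrunk, rotated and shifted copy of the target.
torch.manual_seed(0)
target = torch.rand(2, 17, 3, dtype=torch.float64)
angle = 0.7
rot_z = torch.tensor(
    [
        [math.cos(angle), -math.sin(angle), 0.0],
        [math.sin(angle), math.cos(angle), 0.0],
        [0.0, 0.0, 1.0],
    ],
    dtype=torch.float64,
)
pred = 0.5 * target @ rot_z.T + torch.tensor([1.0, -2.0, 0.3], dtype=torch.float64)

aligned, scale, rotation = procrustes_align(pred, target)
print(pa_mpjpe(pred, target))  # close to zero for both samples
```

Returning the aligned coordinates directly saves callers from rebuilding the transform with an einsum, while the scale and rotation remain available for anyone who wants to apply the transform elsewhere.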

SkafteNicki commented 2 months ago

@heth27 I would be fine with that. Maybe it makes sense to add an additional argument like return_all_stats or something similar to indicate if this additional information should be returned.
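One way such a flag could look, as a sketch only: the function name `procrustes_disparity` and the exact return tuple are assumptions, and only the `return_all_stats` name comes from the comment above. The normalization follows the scipy-style batched version posted earlier in the thread, so reflections are allowed just as in scipy.

```python
from typing import Tuple, Union

import torch
from torch import Tensor


def procrustes_disparity(
    data1: Tensor, data2: Tensor, return_all_stats: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor, Tensor]]:
    """Scipy-style Procrustes disparity for (batch, points, dims) inputs.

    Hypothetical signature: with `return_all_stats=True` the scale and the
    orthogonal matrix used for the alignment are returned as well.
    """
    if data1.shape != data2.shape:
        raise ValueError("data1 and data2 must have the same shape")
    if data1.ndim == 2:
        data1 = data1[None, :, :]
        data2 = data2[None, :, :]

    # Center each point set and scale it to unit Frobenius norm.
    data1 = data1 - data1.mean(dim=1, keepdim=True)
    data2 = data2 - data2.mean(dim=1, keepdim=True)
    data1 = data1 / torch.linalg.norm(data1, dim=(1, 2), keepdim=True)
    data2 = data2 / torch.linalg.norm(data2, dim=(1, 2), keepdim=True)

    # Orthogonal Procrustes: best orthogonal map (may include a reflection).
    u, w, vh = torch.linalg.svd(data1.transpose(1, 2) @ data2, full_matrices=False)
    rotation = u @ vh
    scale = w.sum(dim=1)

    aligned = scale[:, None, None] * (data2 @ rotation.transpose(1, 2))
    disparity = (data1 - aligned).square().sum(dim=(1, 2))

    if return_all_stats:
        return disparity, scale, rotation
    return disparity


# Made-up inputs, same shapes as in the comparison against scipy above.
torch.manual_seed(0)
a = torch.rand(2, 10, 3)
b = torch.rand(2, 10, 3)

print(procrustes_disparity(a, b))  # disparity only, shape (2,)
disparity, scale, rotation = procrustes_disparity(a, b, return_all_stats=True)
print(scale.shape, rotation.shape)  # shapes (2,) and (2, 3, 3)
```

Keeping the default return a single tensor means the common case stays simple, and the extra statistics are opt-in for pose estimation users.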

This metric does not fit under any of our current subdomains. Do you have a recommendation for which new domain it would fit under?

heth27 commented 2 months ago

Wikipedia suggests shape analysis. I plan on adding things like Mean Per Joint Position Error (MPJPE) and Percentage of Correct Keypoints (PCK) as well when I have more time. Those would also fit the domain. There is also a new feature request for Hausdorff distance (https://github.com/Lightning-AI/torchmetrics/issues/1990). I'm not sure whether a more general domain (shape analysis) or the most common usage (human pose estimation, in the case of Procrustes) would make it easier to find.
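For readers unfamiliar with the two metrics mentioned here, a minimal sketch of how they are commonly computed. The function names `mpjpe` and `pck` and the absolute distance threshold are assumptions for illustration. In practice PCK thresholds are often expressed relative to a reference length such as head or torso size.

```python
import torch


def mpjpe(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean Euclidean distance between predicted and true joints.

    Inputs have shape (batch, joints, dims); the result is a scalar.
    """
    return (pred - target).norm(dim=-1).mean()


def pck(pred: torch.Tensor, target: torch.Tensor, threshold: float) -> torch.Tensor:
    """Fraction of joints whose distance to the ground truth is below `threshold`."""
    distances = (pred - target).norm(dim=-1)
    return (distances < threshold).float().mean()


# Made-up example: four joints that are off by 0, 1, 2 and 3 units along x.
target = torch.zeros(1, 4, 3)
pred = torch.zeros(1, 4, 3)
pred[0, :, 0] = torch.tensor([0.0, 1.0, 2.0, 3.0])

print(mpjpe(pred, target))                # tensor(1.5000)
print(pck(pred, target, threshold=1.5))   # tensor(0.5000)
```

Both functions assume the prediction has already been aligned if a Procrustes-aligned variant is wanted.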