Closed: heth27 closed this issue 2 weeks ago.
Hi! Thanks for your contribution! Great first issue!
spatial procrustes alignment, a similarity test for two data sets
This sounds good! Would you be interested in adding it to TM (TorchMetrics)? You could create a draft PR and then we can help you finish it.
Hi @heth27, I took a stab at a batched version of your implementation:
```python
import torch


def procrustus_batch(data1, data2):
    if data1.shape != data2.shape:
        raise ValueError("data1 and data2 must have the same shape")
    # Accept unbatched (N, D) inputs by adding a batch dimension.
    if data1.ndim == 2:
        data1 = data1[None, :, :]
        data2 = data2[None, :, :]
    # Center each point set and scale it to unit Frobenius norm.
    # Out-of-place ops, so the caller's tensors are left untouched.
    data1 = data1 - data1.mean(dim=1, keepdim=True)
    data2 = data2 - data2.mean(dim=1, keepdim=True)
    data1 = data1 / torch.linalg.norm(data1, dim=(1, 2), keepdim=True)
    data2 = data2 / torch.linalg.norm(data2, dim=(1, 2), keepdim=True)
    # Orthogonal Procrustes: SVD of data1^T @ data2, as in scipy.linalg.orthogonal_procrustes.
    try:
        u, w, vh = torch.linalg.svd(
            torch.matmul(data2.transpose(1, 2), data1).transpose(1, 2), full_matrices=False
        )
    except RuntimeError as err:  # torch.linalg.LinAlgError subclasses RuntimeError
        raise ValueError("SVD did not converge") from err
    rotation = torch.matmul(u, vh)
    scale = w.sum(1, keepdim=True)
    # Rotate and scale data2 onto data1, then sum the squared residuals per sample.
    data2 = scale[:, None] * torch.matmul(data2, rotation.transpose(1, 2))
    disparity = (data1 - data2).square().sum(dim=(1, 2))
    return disparity


coords_3d_true = torch.rand(2, 10, 3)
coords_3d_prediction = torch.rand(2, 10, 3)
p2 = procrustus_batch(coords_3d_true, coords_3d_prediction)
print(p2)

# Reference: SciPy, one sample at a time.
from scipy.spatial import procrustes as procrustes_scipy

for i in range(2):
    mtx1, mtx2, disparity = procrustes_scipy(
        coords_3d_true[i].numpy(), coords_3d_prediction[i].numpy()
    )
    print(disparity)
```
For random inputs, it seems to match SciPy. Are you interested in sending a PR, or do you want me to take over?
Hi @SkafteNicki, thank you, please feel free to create a PR. How do you feel about also returning the rotation matrix or the transformed coordinates? Many human pose estimation tasks use them downstream to compute the Procrustes-aligned mean per joint position error.
@heth27 I would be fine with that. Maybe it makes sense to add an argument like `return_all_stats`, or something similar, to indicate whether this additional information should be returned.
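A minimal sketch of how a `return_all_stats` flag could look, assuming the same normalization as SciPy's `procrustes`. The function name `procrustes_batch_with_stats` and its return tuple are hypothetical, not an existing TorchMetrics API. One design choice worth flagging: instead of returning the aligned prediction in the normalized space (like SciPy's `mtx2`), the sketch maps it back into `data1`'s original frame and units, which is what a Procrustes-aligned MPJPE is usually computed on. Like SciPy, it does not forbid reflections; some pose-estimation codebases add a determinant sign correction for that.

```python
import torch


def procrustes_batch_with_stats(
    data1: torch.Tensor, data2: torch.Tensor, return_all_stats: bool = False
):
    """Hypothetical batched Procrustes analysis for (B, N, D) or (N, D) float tensors."""
    if data1.shape != data2.shape:
        raise ValueError("data1 and data2 must have the same shape")
    if data1.ndim == 2:
        data1, data2 = data1[None], data2[None]

    # Center both point sets and scale them to unit Frobenius norm.
    mean1 = data1.mean(dim=1, keepdim=True)
    mean2 = data2.mean(dim=1, keepdim=True)
    norm1 = torch.linalg.norm(data1 - mean1, dim=(1, 2), keepdim=True)
    norm2 = torch.linalg.norm(data2 - mean2, dim=(1, 2), keepdim=True)
    x1 = (data1 - mean1) / norm1
    x2 = (data2 - mean2) / norm2

    # Orthogonal Procrustes: the SVD of x1^T x2 yields the rotation and the optimal scale.
    u, w, vh = torch.linalg.svd(x1.transpose(1, 2) @ x2, full_matrices=False)
    rotation = u @ vh  # (B, D, D); orthogonal, may contain a reflection
    scale = w.sum(dim=1)  # (B,)
    x2_aligned = scale[:, None, None] * (x2 @ rotation.transpose(1, 2))
    disparity = (x1 - x2_aligned).square().sum(dim=(1, 2))
    if not return_all_stats:
        return disparity

    # Undo data1's normalization so the aligned prediction lives in data1's units.
    aligned = x2_aligned * norm1 + mean1
    return disparity, aligned, rotation, scale


# Usage: a PA-MPJPE-style error from the aligned coordinates (hypothetical shapes).
gt = torch.rand(4, 17, 3)
pred = torch.rand(4, 17, 3)
disparity, aligned, rotation, scale = procrustes_batch_with_stats(gt, pred, return_all_stats=True)
pa_mpjpe = (aligned - gt).norm(dim=-1).mean(dim=-1)  # shape (4,)
```

Keeping the flag off by default preserves the plain disparity output for callers who only need the similarity score.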
This metric does not fit under any of our current subdomains. Do you have a recommendation for a new domain it could go under?
Wikipedia suggests shape analysis. When I have more time, I also plan to add metrics like Mean Per Joint Position Error (MPJPE) and Percentage of Correct Keypoints (PCK), which would fit the same domain. There is also a new feature request for Hausdorff distance: https://github.com/Lightning-AI/torchmetrics/issues/1990. I'm not sure whether a general domain name (shape analysis) or the most common use case (human pose estimation, in the case of Procrustes) would make these metrics easier to find.
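For reference, both keypoint metrics mentioned above reduce to a few tensor operations. This is a minimal sketch assuming predictions and targets of shape (B, J, D) in the same units; `mpjpe` and `pck` are hypothetical helper names. Conventions such as aligning root joints before MPJPE, or normalizing the PCK threshold by head or torso size, vary between benchmarks and are left out here.

```python
import torch


def mpjpe(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean per joint position error for (B, J, D) tensors; returns shape (B,)."""
    # Euclidean distance per joint, averaged over joints.
    return (pred - target).norm(dim=-1).mean(dim=-1)


def pck(pred: torch.Tensor, target: torch.Tensor, threshold: float) -> torch.Tensor:
    """Fraction of joints within `threshold` of the target; returns shape (B,)."""
    dist = (pred - target).norm(dim=-1)  # (B, J)
    return (dist <= threshold).float().mean(dim=-1)


# Usage: per-joint distances here are 0, 5, 1 and 2.
target = torch.zeros(1, 4, 3)
pred = torch.tensor([[[0.0, 0.0, 0.0], [3.0, 4.0, 0.0], [0.0, 0.0, 1.0], [0.0, 2.0, 0.0]]])
print(mpjpe(pred, target))  # tensor([2.])
print(pck(pred, target, threshold=1.5))  # tensor([0.5000])
```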
🚀 Feature
Spatial Procrustes alignment, a similarity test for two data sets
Motivation
Procrustes alignment is a staple when calculating metrics for 3D human pose estimation, but there seems to be no library that offers it for PyTorch, so presumably everyone maintains their own version.
Pitch
There is a variant in SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.procrustes.html
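To illustrate what the SciPy function measures: the disparity is close to zero when one point set is a translated, rotated and uniformly scaled copy of the other, and grows for unrelated sets. The random data, the 90° rotation about the z-axis and the shift below are arbitrary choices for the example.

```python
import numpy as np
from scipy.spatial import procrustes

rng = np.random.default_rng(0)
a = rng.random((10, 3))

# b is a similarity transform of a: rotate 90 degrees about z, scale by 3, then shift.
rot_z = np.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
b = 3.0 * a @ rot_z.T + np.array([1.0, -2.0, 0.5])
mtx1, mtx2, disparity = procrustes(a, b)
print(disparity)  # close to 0

# An unrelated point set gives a clearly positive disparity.
c = rng.random((10, 3))
other_disparity = procrustes(a, c)[2]
print(other_disparity)
```

Because both inputs are standardized to unit norm first, the disparity lies between 0 and 1 regardless of the original units.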
Alternatives
Additional context
The implementation I'm using (I don't know if it is any good):
Example usage:
The problem with this version is that it does not work on batches.