SCIInstitute / ShapeWorks


Analysis - Add Mapping Error Metric #2213

Open jadie1 opened 4 months ago


One qualitative metric we use to analyze correspondence is the consistency of particle neighborhoods across shapes. This can be captured quantitatively by the mapping error, as described in Point2SSM (Equation 2). The more consistent the distances between particles and their neighbors are across shapes, the smaller the mapping error.
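For reference, here is the quantity that the code in this issue computes, written out. It is transcribed from the code, so its normalization may not match Equation 2 in Point2SSM exactly. $X = \{x_1, \dots, x_M\}$ is the source particle set and $Y = \{y_1, \dots, y_M\}$ is a target set with corresponding particles. $\mathcal{N}_k(i)$ is the set of $k-1$ nearest neighbors of $x_i$ in $X$; the k-NN query also returns $x_i$ itself, and that entry is dropped. $t$ is `GAUSSIAN_HEAT_KERNEL_T`, and $N$ is the number of shapes.

```latex
\mathrm{ME}(X \to Y) = \frac{1}{M(k-1)} \sum_{i=1}^{M} \sum_{j \in \mathcal{N}_k(i)}
  \exp\!\left(-\frac{\lVert x_i - x_j \rVert^2}{t}\right) \lVert y_i - y_j \rVert^2,
\qquad
\overline{\mathrm{ME}} = \frac{1}{N(N-1)} \sum_{a \neq b} \mathrm{ME}(X_a \to X_b)
```

The kernel weight is close to 1 for tight source neighborhoods. The metric therefore mostly penalizes close source neighbors whose corresponding target particles land far apart.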

Below is Python/PyTorch code for calculating the mapping error. Note that it uses Euclidean distance. Since we have access to the shape surfaces, we could also use geodesic distance, which may be slower to compute but would provide a better metric. The code also averages the mapping error over all ordered pairs of particle sets in the cohort. Alternatively, we could compute only the mapping error of each particle set relative to a set of reference particles, such as the median shape.

```python
import os
import glob
import argparse

import numpy as np
import torch
from torch_cluster import knn
from pytorch3d.ops import knn_gather

# Bandwidth t of the Gaussian heat kernel exp(-d^2 / t) that weights source neighbors.
# Assumed placeholder: the original snippet used this name without defining it;
# choose it relative to the squared neighbor spacing of your particles.
GAUSSIAN_HEAT_KERNEL_T = 8.0


def get_average_mapping_error(particles, k):
    """
    particles: particle system numpy array with shape (N, M, 3), where N is the number
               of shapes and M is the number of particles per shape
    k: number of neighbors to use (includes the particle itself, so k - 1 true neighbors)
    Returns ME: the average mapping error over all ordered pairs of particle sets
    """
    N = particles.shape[0]  # number of shapes
    M = particles.shape[1]  # number of particles per shape
    particles = torch.as_tensor(particles, dtype=torch.float32)
    # k nearest neighbors of every particle within its own shape. Row 1 of each edge
    # index holds the neighbor indices, assumed ordered by distance so that column 0
    # is the particle itself; the result is an (N, M, k) index tensor.
    edge_index = [knn(particles[i], particles[i], k) for i in range(N)]
    neigh_idxs = torch.stack([edge_index[i][1].reshape(M, -1) for i in range(N)])
    MEs = []
    for source_index in range(N):
        for target_index in range(N):
            if source_index != target_index:
                MEs.append(calculate_mapping_error(
                    particles[source_index].unsqueeze(0),
                    neigh_idxs[source_index].unsqueeze(0),
                    particles[target_index].unsqueeze(0),
                    k))
    avg_ME = torch.stack(MEs).mean().item()
    return avg_ME


def calculate_mapping_error(source, source_neighs, target, k):
    """
    Calculates the mapping error between source and target point sets.
    source, target: (1, M, 3) tensors of corresponding particles
    source_neighs: (1, M, k) tensor of neighbor indices computed on the source
    """
    # Source neighbors, (1, M, k, 3); remove the first grouped element, as it is the seed point itself
    source_grouped = knn_gather(source.contiguous(), source_neighs)
    source_diff = source_grouped[:, :, 1:, :] - torch.unsqueeze(source, 2)
    source_square = torch.sum(source_diff ** 2, dim=-1)

    # Corresponding target particles, grouped with the *source* neighborhoods
    target_cr_grouped = knn_gather(target.contiguous(), source_neighs)
    target_cr_diff = target_cr_grouped[:, :, 1:, :] - torch.unsqueeze(target, 2)
    target_cr_square = torch.sum(target_cr_diff ** 2, dim=-1)

    # Weight each squared target distance by how close the pair is in the source
    gaussian_heat_kernel = torch.exp(-source_square / GAUSSIAN_HEAT_KERNEL_T)
    ME_per_neigh = gaussian_heat_kernel * target_cr_square

    ME = torch.mean(ME_per_neigh)
    return ME


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Calculates mapping error')
    parser.add_argument('-p', '--particle_dir', help='Directory with particle files', required=True)
    parser.add_argument('-k', '--k', help='Number of neighbors to use', required=True, type=int)
    args = parser.parse_args()

    # Get numpy array of particles with shape (N, M, 3)
    particle_files = sorted(glob.glob(os.path.join(args.particle_dir, '*_world.particles')))
    particles = []
    for particle_file in particle_files:
        particles.append(np.loadtxt(particle_file))  # one "x y z" row per particle
    particles = np.array(particles)

    ME = get_average_mapping_error(particles, args.k)
    print('Mapping error:', ME)
```
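For a dependency-free cross-check that needs neither torch_cluster nor pytorch3d, here is a NumPy sketch of the same pairwise average. It also includes the reference-based alternative mentioned above. The sketch makes several assumptions:

- k-NN is brute force rather than torch_cluster's `knn`.
- `t=8.0` stands in for the undefined `GAUSSIAN_HEAT_KERNEL_T`.
- The "median" reference is read as the medoid, meaning the shape with the smallest summed Euclidean distance to all others.

All helper names are hypothetical. Results should agree with the PyTorch version up to float precision and k-NN tie-breaking.

```python
import numpy as np


def neighbor_indices(points, k):
    """(M, k) indices of each point's k nearest neighbors; column 0 is the point itself
    (assuming no duplicate particles, since it is then the only zero distance)."""
    d2 = np.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=-1)
    return np.argsort(d2, axis=1)[:, :k]


def mapping_error(source, target, neigh, t):
    """Heat-kernel-weighted mapping error from source to target, both (M, 3) arrays."""
    nbrs = neigh[:, 1:]  # drop the seed point itself
    src_sq = np.sum((source[nbrs] - source[:, None, :]) ** 2, axis=-1)  # (M, k-1)
    tgt_sq = np.sum((target[nbrs] - target[:, None, :]) ** 2, axis=-1)  # (M, k-1)
    return float(np.mean(np.exp(-src_sq / t) * tgt_sq))


def average_mapping_error(particles, k, t=8.0):
    """Average over all ordered pairs of shapes, N * (N - 1) comparisons."""
    n = len(particles)
    neighs = [neighbor_indices(p, k) for p in particles]
    errors = [mapping_error(particles[s], particles[d], neighs[s], t)
              for s in range(n) for d in range(n) if s != d]
    return float(np.mean(errors))


def medoid_index(particles):
    """Hypothetical 'median' shape: smallest summed Euclidean distance to all others."""
    flat = particles.reshape(len(particles), -1)
    dists = np.linalg.norm(flat[:, None, :] - flat[None, :, :], axis=-1)
    return int(np.argmin(dists.sum(axis=1)))


def reference_mapping_error(particles, k, t=8.0):
    """Mean mapping error from the medoid to every other shape, N - 1 comparisons."""
    ref = medoid_index(particles)
    neigh = neighbor_indices(particles[ref], k)
    errors = [mapping_error(particles[ref], particles[i], neigh, t)
              for i in range(len(particles)) if i != ref]
    return float(np.mean(errors)), ref


rng = np.random.default_rng(0)
base = rng.normal(size=(50, 3))
# Five translated copies: neighborhoods are identical in every shape
consistent = np.stack([base + np.array([s, 0.0, 0.0]) for s in range(5)])
# Same shapes, each with its own random particle ordering (broken correspondence)
shuffled = np.stack([shape[rng.permutation(50)] for shape in consistent])
print(average_mapping_error(consistent, k=6), average_mapping_error(shuffled, k=6))
print(reference_mapping_error(consistent, k=6))
```

Shuffling the particle order should raise the error substantially, because source neighbors then map to unrelated target particles. The reference-based variant needs only N - 1 comparisons instead of N(N - 1), which matters for large cohorts. It measures the error in one direction only, from the reference neighborhoods to each shape.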