martinnormark closed this 3 months ago
The forward term is implemented:
Let's break down the forward term of the Probabilistic Surface Distance as described in the paper and implement it step by step.
The equation provided in the paper is:

$$d_{\mathcal{S}, \mathcal{S}_s}^{f} = \sum_{\hat{\mathbf{b}} \in \mathcal{S}_s} p_{\hat{\mathbf{b}}} \min_{\mathbf{b} \in \mathcal{S}} \| \hat{\mathbf{b}} - \mathbf{b} \|^2$$

- $\hat{\mathbf{b}} \in \mathcal{S}_s$: These are the barycenters of the triangles in the simplified mesh $\mathcal{S}_s$.
- $\mathbf{b} \in \mathcal{S}$: These are the barycenters of the triangles in the original mesh $\mathcal{S}$.
- $p_{\hat{\mathbf{b}}}$: This is the probability associated with the barycenter $\hat{\mathbf{b}}$.
- $\| \hat{\mathbf{b}} - \mathbf{b} \|^2$: This is the squared distance between the barycenters $\hat{\mathbf{b}}$ and $\mathbf{b}$.

The goal of this term is to sum the minimum squared distances between each barycenter in the simplified mesh and the closest barycenter in the original mesh, weighted by the probabilities of the barycenters in the simplified mesh.
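As a sanity check on the formula, here is a minimal pure-Python sketch that evaluates the forward term directly on a few hypothetical barycenters and probabilities. The helper names `squared_distance` and `forward_term` are illustrative only, not from the paper or the repository.

```python
Point = tuple[float, float, float]


def squared_distance(a: Point, b: Point) -> float:
    """||a - b||^2 for two 3D points."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def forward_term(
    simplified: list[Point], probs: list[float], original: list[Point]
) -> float:
    """Sum over simplified barycenters b_hat of p(b_hat) * min_b ||b_hat - b||^2."""
    total = 0.0
    for b_hat, p in zip(simplified, probs):
        # Nearest original barycenter for this simplified barycenter
        nearest = min(squared_distance(b_hat, b) for b in original)
        total += p * nearest
    return total


# Hypothetical toy data: two original and two simplified barycenters
original = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
simplified = [(0.0, 0.0, 0.5), (1.0, 0.0, 0.0)]
probs = [0.8, 0.5]

print(forward_term(simplified, probs, original))  # 0.8 * 0.25 + 0.5 * 0.0 = 0.2
```

The first simplified barycenter is 0.5 away from its nearest original barycenter, so it contributes 0.8 × 0.25; the second coincides with an original barycenter and contributes nothing, whatever its probability.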
Let's translate this into code:
1. Compute the barycenters: First, we need to compute the barycenters for both the original and simplified meshes.
2. Calculate the squared distances: For each barycenter in the simplified mesh, compute the squared distance to every barycenter in the original mesh.
3. Find the minimum distance: For each simplified barycenter, find the minimum distance to any original barycenter.
4. Weight by probability and sum: Multiply these minimum distances by the corresponding probabilities and sum them up.
Here’s how this could be implemented in PyTorch:
```python
import torch
import torch.nn as nn


class ProbabilisticSurfaceDistanceLoss(nn.Module):
    def __init__(self, epsilon: float = 1e-8):
        super().__init__()
        self.epsilon = epsilon  # not used by the forward term below

    def forward(
        self,
        original_vertices: torch.Tensor,
        original_faces: torch.Tensor,
        simplified_vertices: torch.Tensor,
        simplified_faces: torch.Tensor,
        face_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        # An empty mesh has nothing to compare against, so the loss is zero
        if original_vertices.shape[0] == 0 or simplified_vertices.shape[0] == 0:
            return torch.tensor(0.0, device=original_vertices.device)

        # Step 1: Compute barycenters of both original and simplified meshes
        original_barycenters = self.compute_barycenters(original_vertices, original_faces)
        simplified_barycenters = self.compute_barycenters(simplified_vertices, simplified_faces)

        # Step 2: Calculate the squared distances between each simplified barycenter and all original barycenters
        distances = self.compute_squared_distances(simplified_barycenters, original_barycenters)

        # Step 3: Find the minimum distance for each simplified barycenter
        min_distances, _ = distances.min(dim=1)

        # Step 4: Weight by face probabilities (one per simplified face) and sum
        weighted_distances = face_probabilities * min_distances
        total_loss = weighted_distances.sum()
        return total_loss

    @staticmethod
    def compute_barycenters(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
        # vertices: (num_vertices, 3), faces: (num_faces, 3) -> (num_faces, 3)
        return vertices[faces].mean(dim=1)

    @staticmethod
    def compute_squared_distances(barycenters1: torch.Tensor, barycenters2: torch.Tensor) -> torch.Tensor:
        # barycenters1: (num_faces1, 3)
        # barycenters2: (num_faces2, 3)
        num_faces1 = barycenters1.size(0)
        num_faces2 = barycenters2.size(0)

        # Expand dimensions to compute pairwise differences
        barycenters1_exp = barycenters1.unsqueeze(1).expand(num_faces1, num_faces2, 3)
        barycenters2_exp = barycenters2.unsqueeze(0).expand(num_faces1, num_faces2, 3)

        # Compute squared Euclidean distances: (num_faces1, num_faces2)
        distances = torch.sum((barycenters1_exp - barycenters2_exp) ** 2, dim=2)
        return distances
```
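`compute_squared_distances` materializes an intermediate tensor of shape `(num_faces1, num_faces2, 3)`, which can get large for dense meshes. A common alternative expands the squared norm as $\|\mathbf{a} - \mathbf{b}\|^2 = \|\mathbf{a}\|^2 - 2\,\mathbf{a} \cdot \mathbf{b} + \|\mathbf{b}\|^2$, so the pairwise matrix only needs a matrix product; in PyTorch, squaring the output of `torch.cdist(x1, x2)` is another option. The plain-Python sketch below (hypothetical helpers `direct` and `expanded`, random test points) only checks that the expansion agrees with the direct formula. The expanded form can lose some precision through cancellation when two points are very close.

```python
import random

Point = tuple[float, float, float]


def direct(a: Point, b: Point) -> float:
    # Sum of squared coordinate differences, as in compute_squared_distances
    return sum((x - y) ** 2 for x, y in zip(a, b))


def expanded(a: Point, b: Point) -> float:
    # ||a||^2 - 2 a.b + ||b||^2
    dot = sum(x * y for x, y in zip(a, b))
    return sum(x * x for x in a) - 2.0 * dot + sum(y * y for y in b)


random.seed(0)
points1 = [tuple(random.uniform(-1, 1) for _ in range(3)) for _ in range(4)]
points2 = [tuple(random.uniform(-1, 1) for _ in range(3)) for _ in range(5)]

# Largest disagreement over all 4 x 5 pairs
max_err = max(abs(direct(a, b) - expanded(a, b)) for a in points1 for b in points2)
print(f"max abs difference: {max_err:.2e}")
```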
- `compute_barycenters`: Computes the barycenter of each face by averaging its three vertices.
- `compute_squared_distances`: Computes the squared Euclidean distance between every pair of barycenters from the simplified and original meshes.
- `min` on `distances`: For each barycenter in the simplified mesh, finds the minimum squared distance to any barycenter in the original mesh.

The two tests pass, verifying that identical meshes have zero loss and that the loss increases when a vertex is displaced:
```python
import torch
import pytest

from losses.surface_distance_loss import ProbabilisticSurfaceDistanceLoss


@pytest.fixture
def loss_fn():
    # The forward-term class above only accepts `epsilon`
    return ProbabilisticSurfaceDistanceLoss()


@pytest.fixture
def simple_cube_data():
    vertices = torch.tensor(
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [1, 1, 0],
            [0, 0, 1],
            [1, 0, 1],
            [0, 1, 1],
            [1, 1, 1],
        ],
        dtype=torch.float32,
    )
    faces = torch.tensor(
        [
            [0, 1, 2],
            [1, 3, 2],
            [4, 5, 6],
            [5, 7, 6],
            [0, 4, 1],
            [1, 4, 5],
            [2, 3, 6],
            [3, 7, 6],
            [0, 2, 4],
            [2, 6, 4],
            [1, 5, 3],
            [3, 5, 7],
        ],
        dtype=torch.long,
    )
    return vertices, faces


def test_loss_zero_for_identical_meshes(loss_fn, simple_cube_data):
    vertices, faces = simple_cube_data
    face_probs = torch.ones(faces.shape[0], dtype=torch.float32)
    loss = loss_fn(vertices, faces, vertices, faces, face_probs)
    print(f"Loss for identical meshes: {loss.item()}")
    assert loss.item() < 1e-5


def test_loss_increases_with_vertex_displacement(loss_fn, simple_cube_data):
    vertices, faces = simple_cube_data
    face_probs = torch.ones(faces.shape[0], dtype=torch.float32)

    # Move vertex 0 off the cube corner
    displaced_vertices = vertices.clone()
    displaced_vertices[0] += torch.tensor([0.1, 0.1, 0.1])

    loss_original = loss_fn(vertices, faces, vertices, faces, face_probs)
    loss_displaced = loss_fn(vertices, faces, displaced_vertices, faces, face_probs)
    print(
        f"Original loss: {loss_original.item()}, Displaced loss: {loss_displaced.item()}"
    )
    assert loss_displaced > loss_original
    assert not torch.isclose(loss_displaced, loss_original, atol=1e-5)
```
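The displaced-vertex test can also be checked by hand. Vertex 0 appears in three faces (`[0, 1, 2]`, `[0, 4, 1]`, `[0, 2, 4]`); moving it by (0.1, 0.1, 0.1) shifts each of their barycenters by 0.1 / 3 per axis, so each contributes 3 × (0.1 / 3)² = 1/300, and with unit probabilities the forward loss should be about 0.01. The plain-Python sketch below recomputes that number from the same cube data; it is a hypothetical re-implementation for checking the arithmetic, not the PyTorch class from the tests.

```python
# Same cube as the `simple_cube_data` fixture
vertices = [
    (0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (1.0, 1.0, 0.0),
    (0.0, 0.0, 1.0), (1.0, 0.0, 1.0), (0.0, 1.0, 1.0), (1.0, 1.0, 1.0),
]
faces = [
    (0, 1, 2), (1, 3, 2), (4, 5, 6), (5, 7, 6), (0, 4, 1), (1, 4, 5),
    (2, 3, 6), (3, 7, 6), (0, 2, 4), (2, 6, 4), (1, 5, 3), (3, 5, 7),
]


def barycenters(verts, tris):
    # Mean of the three vertices of each triangle
    return [tuple(sum(verts[i][axis] for i in tri) / 3.0 for axis in range(3)) for tri in tris]


def forward_term(simplified, probs, original):
    # p * (squared distance to nearest original barycenter), summed
    return sum(
        p * min(sum((x - y) ** 2 for x, y in zip(s, o)) for o in original)
        for s, p in zip(simplified, probs)
    )


displaced = list(vertices)
displaced[0] = (0.1, 0.1, 0.1)  # vertex 0 moved by (0.1, 0.1, 0.1), as in the test

probs = [1.0] * len(faces)
loss = forward_term(barycenters(displaced, faces), probs, barycenters(vertices, faces))
print(loss)  # close to 0.01 (three faces, each 3 * (0.1 / 3) ** 2)
```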
From the paper:
The Probabilistic Surface Distance (PSD) is introduced to avoid having triangles in regions that don't exist in the original mesh and to penalize the presence of surface holes. It's a Chamfer-inspired loss that measures the distance between a ground-truth surface and a probabilistic surface.
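For comparison, the standard (squared, symmetric) Chamfer distance between two point sets sums nearest-neighbour squared distances in both directions, with no probabilities. The sketch below is a plain-Python illustration on hypothetical points, not code from the paper; with every $p_{\hat{\mathbf{b}}} = 1$, the PSD forward term reduces to the first of the two directional sums.

```python
Point = tuple[float, float, float]


def squared_distance(a: Point, b: Point) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b))


def directional(source: list[Point], target: list[Point]) -> float:
    # For every point in `source`, the squared distance to its nearest point in `target`
    return sum(min(squared_distance(s, t) for t in target) for s in source)


def chamfer(a: list[Point], b: list[Point]) -> float:
    # Symmetric Chamfer distance: both directions, unweighted
    return directional(a, b) + directional(b, a)


original = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
simplified = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]

print(directional(simplified, original))  # 0.0: every simplified point lies on an original one
print(directional(original, simplified))  # 1.0: (0, 1, 0) has no nearby simplified point
print(chamfer(original, simplified))      # 1.0
```

The example shows why one direction is not enough: the simplified set misses the region around (0, 1, 0), a "hole", yet its simplified-to-original distance is zero; only the original-to-simplified direction detects the gap.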
Mathematical Formulation:
The PSD loss consists of two terms: a forward term and a reverse term.
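Forward term:

$$d_{\mathcal{S}, \mathcal{S}_s}^{f} = \sum_{\hat{\mathbf{b}} \in \mathcal{S}_s} p_{\hat{\mathbf{b}}} \min_{\mathbf{b} \in \mathcal{S}} \| \hat{\mathbf{b}} - \mathbf{b} \|^2$$

Where $\hat{\mathbf{b}}$ and $\mathbf{b}$ are the triangle barycenters of the simplified mesh $\mathcal{S}_s$ and the original mesh $\mathcal{S}$, respectively, and $p_{\hat{\mathbf{b}}}$ is the probability of the simplified triangle with barycenter $\hat{\mathbf{b}}$.

Reverse term:

$$d_{\mathcal{S}, \mathcal{S}_s}^{r} = \sum_{\mathbf{y} \in \mathcal{S}_s} \left[ p_{\mathbf{y}} \min_{\mathbf{x} \in \mathcal{S}} \| \mathbf{x} - \mathbf{y} \|^2 + (1 - p_{\mathbf{y}}) \frac{1}{k} \sum_{k} p_{t_k} \| \mathbf{x}_{t_k} - \mathbf{y} \|^2 \right]$$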
Where: