martinnormark / neural-mesh-simplification

Unofficial and WIP implementation of the Neural Mesh Simplification paper
MIT License

Implement Probabilistic Surfaces Distance loss #29

Closed: martinnormark closed this issue 3 months ago

martinnormark commented 3 months ago

From the paper:

The Probabilistic Surfaces Distance (PSD) is introduced to avoid having triangles in regions that don't exist in the original mesh and to penalize the presence of surface holes. It's a Chamfer-inspired loss that measures the distance between a ground truth and a probabilistic surface.

Mathematical Formulation:

The PSD loss consists of two terms: a forward term and a reverse term.

  1. Forward term: d_{\mathcal{S}, \mathcal{S}_s}^{f} = \sum_{\hat{\mathbf{b}} \in \mathcal{S}_s} p_{\hat{\mathbf{b}}} \min_{\mathbf{b} \in \mathcal{S}} \| \hat{\mathbf{b}} - \mathbf{b} \|^2

    Where:

    • S_s is the generated surface
    • S is the ground truth surface
    • b̂ are the barycenters of the generated triangles
    • b are the barycenters of the ground truth triangles
    • p_b̂ is the probability of the generated triangle
  2. Reverse term: d_{\mathcal{S}, \mathcal{S}_s}^{r} = \sum_{\mathbf{y} \in \mathcal{S}_s} \left[ p_{\mathbf{y}} \min_{\mathbf{x} \in \mathcal{S}} \| \mathbf{x} - \mathbf{y} \|^2 + (1 - p_{\mathbf{y}}) \frac{1}{k} \sum_{k} p_{t_k} \| \mathbf{x}_{t_k} - \mathbf{y} \|^2 \right]

    Where:

    • y is a point from the generated surface S_s
    • x is a point from the ground truth surface S
    • p_y is the probability of point y
    • t_k are the k-nearest triangles to y in S_s
    • x_{t_k} is a point on triangle t_k
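
The paper combines these two terms into the final PSD loss. As a rough starting point, here is a minimal PyTorch sketch of the reverse term under two stated assumptions: the two parts are simply added per point, and x_{t_k} is approximated by the barycenter of triangle t_k. All function and argument names are illustrative, not from the repository:

import torch

def reverse_term_sketch(
    gt_points: torch.Tensor,        # (N, 3) points sampled from the ground truth surface S
    gen_points: torch.Tensor,       # (M, 3) points y sampled from the generated surface S_s
    point_probs: torch.Tensor,      # (M,) probability p_y of each generated point
    tri_barycenters: torch.Tensor,  # (T, 3) barycenters of the generated triangles
    tri_probs: torch.Tensor,        # (T,) probability p_t of each generated triangle
    k: int = 3,
) -> torch.Tensor:
    # First part: p_y * min_x ||x - y||^2
    dists = torch.cdist(gen_points, gt_points).pow(2)            # (M, N)
    first = point_probs * dists.min(dim=1).values

    # Second part: (1 - p_y) * (1/k) * sum_k p_{t_k} ||x_{t_k} - y||^2,
    # approximating x_{t_k} by the barycenter of each of the k nearest triangles
    tri_dists = torch.cdist(gen_points, tri_barycenters).pow(2)  # (M, T)
    knn_dists, knn_idx = tri_dists.topk(k, dim=1, largest=False)
    second = (1 - point_probs) * (tri_probs[knn_idx] * knn_dists).mean(dim=1)

    return (first + second).sum()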
martinnormark commented 3 months ago

The forward term is implemented as follows:

Let's break down the forward term of the Probabilistic Surfaces Distance as described in the paper and implement it step by step.

Step 1: Understanding the Equation

The equation provided in the paper is:

d_{\mathcal{S}, \mathcal{S}_s}^{f} = \sum _{\hat{\mathbf{b}} \in \mathcal{S}_s} p_{\hat{\mathbf{b}}} \min _{\mathbf{b} \in \mathcal{S}} \| \hat{\mathbf{b}} - \mathbf{b} \|^2

The goal of this term is to sum the minimum squared distances between each barycenter in the simplified mesh and the closest barycenter in the original mesh, weighted by the probabilities of the barycenters in the simplified mesh.
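
For intuition: a simplified barycenter with probability 0.8 whose nearest original barycenter lies 0.1 away contributes 0.8 · 0.1² = 0.008 to the sum, while a zero-probability triangle contributes nothing, no matter how far away it is.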

Step 2: Implementation of the Forward Term

Let's translate this into code:

  1. Compute the Barycenters: First, we need to compute the barycenters for both the original and simplified meshes.

  2. Calculate the Squared Distances: For each barycenter in the simplified mesh, compute the squared distance to every barycenter in the original mesh.

  3. Find the Minimum Distance: For each simplified barycenter, find the minimum distance to any original barycenter.

  4. Weight by Probability and Sum: Multiply these minimum distances by the corresponding probabilities and sum them up.

Here’s how this could be implemented in PyTorch:

import torch
import torch.nn as nn

class ProbabilisticSurfaceDistanceLoss(nn.Module):
    def __init__(self, k: int = 3, num_samples: int = 100, epsilon: float = 1e-8):
        super().__init__()
        self.k = k                      # reserved for the reverse term (k-nearest triangles)
        self.num_samples = num_samples  # reserved for the reverse term (surface sampling)
        self.epsilon = epsilon          # small constant for numerical stability

    def forward(
        self,
        original_vertices: torch.Tensor,
        original_faces: torch.Tensor,
        simplified_vertices: torch.Tensor,
        simplified_faces: torch.Tensor,
        face_probabilities: torch.Tensor,
    ) -> torch.Tensor:
        if original_vertices.shape[0] == 0 or simplified_vertices.shape[0] == 0:
            return torch.tensor(0.0, device=original_vertices.device)

        # Step 1: Compute barycenters of both original and simplified meshes
        original_barycenters = self.compute_barycenters(original_vertices, original_faces)
        simplified_barycenters = self.compute_barycenters(simplified_vertices, simplified_faces)

        # Step 2: Calculate the squared distances between each simplified barycenter and all original barycenters
        distances = self.compute_squared_distances(simplified_barycenters, original_barycenters)

        # Step 3: Find the minimum distance for each simplified barycenter
        min_distances, _ = distances.min(dim=1)

        # Step 4: Weight by face probabilities and sum
        weighted_distances = face_probabilities * min_distances
        total_loss = weighted_distances.sum()

        return total_loss

    @staticmethod
    def compute_barycenters(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
        # vertices[faces] has shape (num_faces, 3, 3); the mean over dim=1
        # averages the three corner vertices of each face
        return vertices[faces].mean(dim=1)

    @staticmethod
    def compute_squared_distances(barycenters1: torch.Tensor, barycenters2: torch.Tensor) -> torch.Tensor:
        # barycenters1: (num_faces1, 3)
        # barycenters2: (num_faces2, 3)

        num_faces1 = barycenters1.size(0)
        num_faces2 = barycenters2.size(0)

        # Expand dimensions to compute pairwise differences
        barycenters1_exp = barycenters1.unsqueeze(1).expand(num_faces1, num_faces2, 3)
        barycenters2_exp = barycenters2.unsqueeze(0).expand(num_faces1, num_faces2, 3)

        # Compute squared Euclidean distances
        distances = torch.sum((barycenters1_exp - barycenters2_exp) ** 2, dim=2)

        return distances
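
As a side note, materializing the expanded (num_faces1, num_faces2, 3) tensors can get memory-heavy for large meshes. torch.cdist computes the same pairwise distances without that intermediate, so the body of compute_squared_distances could equivalently be:

# Equivalent: squared pairwise Euclidean distances, without the expanded intermediate
distances = torch.cdist(barycenters1, barycenters2).pow(2)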

Explanation:

  1. compute_barycenters: Computes the barycenters of the faces by averaging the vertices of each face.
  2. compute_squared_distances: Computes the squared Euclidean distances between each pair of barycenters from the simplified and original meshes.
  3. min on distances: For each barycenter in the simplified mesh, finds the minimum squared distance to any barycenter in the original mesh.
  4. Weighting and Summation: Multiplies these minimum distances by the corresponding face probabilities and sums them up to get the total loss.
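
To see the forward term in action, here is a minimal usage sketch on a single toy triangle pair (the values are illustrative):

import torch

loss_fn = ProbabilisticSurfaceDistanceLoss()

# One original triangle and a uniformly shifted copy of it
original_vertices = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
original_faces = torch.tensor([[0, 1, 2]])
simplified_vertices = original_vertices + 0.05
simplified_faces = torch.tensor([[0, 1, 2]])
face_probabilities = torch.tensor([0.9])

loss = loss_fn(
    original_vertices, original_faces,
    simplified_vertices, simplified_faces,
    face_probabilities,
)
# Barycenters differ by 0.05 in each coordinate: 0.9 * 3 * 0.05^2 = 0.00675
print(loss.item())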

Step 3: Testing the Implementation

Two tests pass, verifying that identical meshes yield zero loss and that the loss increases when vertices are displaced:

import torch
import pytest
from losses.surface_distance_loss import ProbabilisticSurfaceDistanceLoss

@pytest.fixture
def loss_fn():
    return ProbabilisticSurfaceDistanceLoss(k=3, num_samples=100)

@pytest.fixture
def simple_cube_data():
    vertices = torch.tensor(
        [
            [0, 0, 0],
            [1, 0, 0],
            [0, 1, 0],
            [1, 1, 0],
            [0, 0, 1],
            [1, 0, 1],
            [0, 1, 1],
            [1, 1, 1],
        ],
        dtype=torch.float32,
    )

    faces = torch.tensor(
        [
            [0, 1, 2],
            [1, 3, 2],
            [4, 5, 6],
            [5, 7, 6],
            [0, 4, 1],
            [1, 4, 5],
            [2, 3, 6],
            [3, 7, 6],
            [0, 2, 4],
            [2, 6, 4],
            [1, 5, 3],
            [3, 5, 7],
        ],
        dtype=torch.long,
    )

    return vertices, faces

def test_loss_zero_for_identical_meshes(loss_fn, simple_cube_data):
    vertices, faces = simple_cube_data
    face_probs = torch.ones(faces.shape[0], dtype=torch.float32)

    loss = loss_fn(vertices, faces, vertices, faces, face_probs)
    print(f"Loss for identical meshes: {loss.item()}")
    assert loss.item() < 1e-5

def test_loss_increases_with_vertex_displacement(loss_fn, simple_cube_data):
    vertices, faces = simple_cube_data
    face_probs = torch.ones(faces.shape[0], dtype=torch.float32)

    displaced_vertices = vertices.clone()
    displaced_vertices[0] += torch.tensor([0.1, 0.1, 0.1])

    loss_original = loss_fn(vertices, faces, vertices, faces, face_probs)
    loss_displaced = loss_fn(vertices, faces, displaced_vertices, faces, face_probs)

    print(
        f"Original loss: {loss_original.item()}, Displaced loss: {loss_displaced.item()}"
    )
    assert loss_displaced > loss_original
    assert not torch.isclose(loss_displaced, loss_original, atol=1e-5)