desy-ml / cheetah

Fast and differentiable particle accelerator optics simulation for reinforcement learning and optimisation applications.
https://cheetah-accelerator.readthedocs.io
GNU General Public License v3.0

Change `Screen` implementation to KDE? #192

Open · jank324 opened this issue 1 week ago

jank324 commented 1 week ago

... to help differentiability? (see https://arxiv.org/pdf/2404.10853v1)
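
For context: a hard-binned histogram is piecewise constant in the particle coordinates, so its gradient with respect to them is zero almost everywhere, whereas a Gaussian KDE is smooth. A minimal sketch of the difference (bin range, bin count and bandwidth are placeholder values):

# Hard binning vs. Gaussian-KDE "soft" binning (illustrative values only)
import torch

x = torch.randn(1000, requires_grad=True)  # particle positions
bins = torch.linspace(-3.0, 3.0, 64)       # bin centers
bandwidth = bins[1] - bins[0]              # smoothing parameter

# Hard histogram: piecewise constant in x, so no useful gradients w.r.t. x
hard = torch.histc(x.detach(), bins=64, min=-3.0, max=3.0)

# KDE histogram: every particle contributes a Gaussian bump to every bin,
# so the result is smooth in x and gradients flow back to the coordinates
kernel = torch.exp(-0.5 * ((x[:, None] - bins[None, :]) / bandwidth) ** 2)
soft = kernel.sum(dim=0) / (x.shape[0] * bandwidth * (2 * torch.pi) ** 0.5)

soft.sum().backward()
print(x.grad is not None)  # True: gradients reach the particle coordinates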

cr-xu commented 1 week ago

Good point, I actually have a working batched KDE implementation for another project. I can add this feature if that's needed.

jank324 commented 1 week ago

That would probably not be bad. I thought @jp-ga and @roussel-ryan probably also have one from Bmad-X (maybe not batched though?).

roussel-ryan commented 1 week ago
# modified from kornia.enhance.histogram
import math
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import torch
from torch import Tensor

def marginal_pdf(
    values: torch.Tensor,
    bins: torch.Tensor,
    sigma: torch.Tensor,
    weights: Optional[Union[Tensor, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Calculate the marginal probability distribution function of the input tensor based on the number of
    histogram bins.

    Args:
        values: shape [BxNx1].
        bins: shape [NUM_BINS].
        sigma: shape [1], gaussian smoothing factor.
        weights: optional per-particle weights with shape [BxN] or a scalar.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
          - torch.Tensor: shape [BxNUM_BINS].
          - torch.Tensor: shape [BxNxNUM_BINS].
    """

    if not isinstance(values, torch.Tensor):
        raise TypeError(f"Input values type is not a torch.Tensor. Got {type(values)}")

    if not isinstance(bins, torch.Tensor):
        raise TypeError(f"Input bins type is not a torch.Tensor. Got {type(bins)}")

    if not isinstance(sigma, torch.Tensor):
        raise TypeError(f"Input sigma type is not a torch.Tensor. Got {type(sigma)}")

    if not bins.dim() == 1:
        raise ValueError(f"Input bins must have shape [NUM_BINS]. Got {bins.shape}")

    if not sigma.dim() == 0:
        raise ValueError(
            f"Input sigma must be a scalar (0-dim) tensor. Got {sigma.shape}"
        )

    if weights is None:
        weights = 1.0
    elif isinstance(weights, torch.Tensor):
        # per-particle weights of shape [B, N]; add a bin axis for broadcasting
        weights = weights.unsqueeze(-1)

    residuals = values - bins.repeat(*values.shape)
    kernel_values = (
        weights
        * torch.exp(-0.5 * (residuals / sigma).pow(2))
        / torch.sqrt(2 * math.pi * sigma**2)
    )

    prob_mass = torch.sum(kernel_values, dim=-2)
    return prob_mass, kernel_values

def joint_pdf(
    kernel_values1: torch.Tensor, kernel_values2: torch.Tensor, epsilon: float = 1e-10
) -> torch.Tensor:
    """Calculate the joint probability distribution function of the input tensors based on the number of histogram
    bins.

    Args:
        kernel_values1: shape [BxNxNUM_BINS].
        kernel_values2: shape [BxNxNUM_BINS].
        epsilon: scalar, for numerical stability.

    Returns:
        shape [BxNUM_BINSxNUM_BINS].
    """

    if not isinstance(kernel_values1, torch.Tensor):
        raise TypeError(
            f"Input kernel_values1 type is not a torch.Tensor. Got {type(kernel_values1)}"
        )

    if not isinstance(kernel_values2, torch.Tensor):
        raise TypeError(
            f"Input kernel_values2 type is not a torch.Tensor. Got {type(kernel_values2)}"
        )

    joint_kernel_values = torch.matmul(kernel_values1.transpose(-2, -1), kernel_values2)
    normalization = (
        torch.sum(joint_kernel_values, dim=(-2, -1)).unsqueeze(-1).unsqueeze(-1)
        + epsilon
    )
    pdf = joint_kernel_values / normalization

    return pdf

def histogram(
    x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor
) -> torch.Tensor:
    """Estimate the histogram of the input tensor.

    The calculation uses kernel density estimation which requires a bandwidth (smoothing) parameter.

    Args:
        x: Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins: Bin centers with shape :math:`(N_{bins},)`.
        bandwidth: Gaussian smoothing factor with shape [1].

    Returns:
        Computed histogram of shape :math:`(B, N_{bins})`.

    Examples:
        >>> x = torch.rand(1, 10)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([1, 128])
    """

    pdf, _ = marginal_pdf(x.unsqueeze(-1), bins, bandwidth)

    return pdf

def histogram2d(
    x1: torch.Tensor,
    x2: torch.Tensor,
    bins1: torch.Tensor,
    bins2: torch.Tensor,
    bandwidth: torch.Tensor,
    weights=None,
) -> torch.Tensor:
    """Estimate the 2d histogram of the input tensor.

    The calculation uses kernel density estimation which requires a bandwidth (smoothing) parameter.

    Args:
        x1: Input tensor to compute the histogram with shape :math:`(B, D)`.
        x2: Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins1: Bin centers for the first coordinate with shape :math:`(N_{bins},)`.
        bins2: Bin centers for the second coordinate with shape :math:`(N_{bins},)`.
        bandwidth: Gaussian smoothing factor with shape [1].
        weights: Optional per-particle weights with shape :math:`(B, D)`.

    Returns:
        Computed histogram of shape :math:`(B, N_{bins}, N_{bins})`.

    Examples:
        >>> x1 = torch.rand(2, 32)
        >>> x2 = torch.rand(2, 32)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram2d(x1, x2, bins, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([2, 128, 128])
    """

    _, kernel_values1 = marginal_pdf(x1.unsqueeze(-1), bins1, bandwidth, weights)
    _, kernel_values2 = marginal_pdf(x2.unsqueeze(-1), bins2, bandwidth, weights)

    pdf = joint_pdf(kernel_values1, kernel_values2)

    return pdf

if __name__ == "__main__":
    # 2D histogram demo
    x = torch.linspace(-0.5, 0.5, 100)  # bin centers

    # samples (`n_particles x coord_dim`), centred on the bin grid;
    # add a leading batch dimension for batched use
    samples = torch.rand(100000, 2) - 0.5

    prob_mass = histogram2d(
        samples[..., 0], samples[..., 1], bins1=x, bins2=x, bandwidth=(x[1] - x[0])
    )
    print(prob_mass.sum(dim=[-2, -1]))

    fig, ax = plt.subplots()
    c = ax.imshow(prob_mass)
    fig.colorbar(c)
    plt.show()
roussel-ryan commented 1 week ago

Please add the disclaimer "adapted from kornia.enhance.histogram" if you use it. It works in batched mode, assuming that the last one (or two) dimensions are the particle coordinates.
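
For example, a batched call of histogram2d above would look roughly like this (shapes and values are illustrative):

# Illustrative batched usage: leading dimension is the batch, last dimension the particles
batch_size, n_particles = 4, 10_000
samples = torch.rand(batch_size, n_particles, 2)

bins_x = torch.linspace(0.0, 1.0, 64)
bins_y = torch.linspace(0.0, 1.0, 64)

images = histogram2d(
    samples[..., 0],  # (batch_size, n_particles)
    samples[..., 1],  # (batch_size, n_particles)
    bins1=bins_x,
    bins2=bins_y,
    bandwidth=bins_x[1] - bins_x[0],
)
print(images.shape)  # torch.Size([4, 64, 64]), each image normalised to sum to 1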

cr-xu commented 1 week ago

I have similar code:

import numpy as np
import torch
import torch.nn as nn

class GaussianHistogram(nn.Module):
    """Use a Gaussian KDE as an approximation of the histogram.

    Args:
        bins (int): Number of bins
        min (torch.Tensor): Minimum value of the histogram
        max (torch.Tensor): Maximum value of the histogram
        sigma (torch.Tensor): Standard deviation of the Gaussian kernel;
            acts as the smoothing (bandwidth) parameter of the KDE
    """

    def __init__(
        self,
        bins: int,
        min: torch.Tensor,
        max: torch.Tensor,
        sigma: torch.Tensor,
    ):
        super(GaussianHistogram, self).__init__()
        self.bins = bins
        assert min.shape == max.shape
        self.min = min.unsqueeze(-1)
        self.max = max.unsqueeze(-1)
        self.sigma = sigma
        self.delta = (self.max - self.min) / bins
        self.centers = self.min + self.delta * (
            # works for both scalar (0-dim) and batched min/max
            torch.arange(bins).float().expand(*self.min.shape[:-1], -1)
            + 0.5
        )

    def forward(self, x: torch.Tensor):
        """Calculate the KDE of the input tensor

        Args:
            x (torch.Tensor): input tensor consisting of the random variables
                shape (..., n_x)

        Returns:
            Computed KDE of the input tensor (..., bins)
        """
        batch_shape = x.shape[:-1]
        # Broadcast the bin centers over the batch dimensions of x
        if self.centers.dim() == 1:
            centers = (
                torch.unsqueeze(self.centers, 0).expand(*batch_shape, -1).unsqueeze(-1)
            )  # (..., bins, 1)
        else:
            centers = self.centers.unsqueeze(-1)
        if self.sigma.dim() == 0:
            sigma = self.sigma
        else:
            assert self.sigma.shape == x.shape[:-1]
            sigma = self.sigma.unsqueeze(-1).unsqueeze(-1)  # (..., 1, 1)
        x = torch.unsqueeze(x, -2) - centers  # (..., bins, n_x)
        delta = self.delta.unsqueeze(-1)
        x = (
            torch.exp(-0.5 * (x / sigma) ** 2) / (sigma * np.sqrt(np.pi * 2)) * delta
        )  # (..., bins, n_x)
        x = x.sum(dim=-1)  # (..., bins)
        return x
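
For reference, a minimal usage sketch of the class above (bin count, ranges and bandwidth are placeholder values), giving one histogram per batch entry with a shared scalar bandwidth:

# Illustrative usage of GaussianHistogram (placeholder shapes and values)
gauss_hist = GaussianHistogram(
    bins=64,
    min=torch.zeros(3),        # per-batch minimum, shape (batch,)
    max=torch.ones(3),         # per-batch maximum, shape (batch,)
    sigma=torch.tensor(0.02),  # scalar smoothing bandwidth
)
x = torch.rand(3, 10_000)      # (batch, n_particles)
hist = gauss_hist(x)           # (batch, bins)
print(hist.shape)              # torch.Size([3, 64])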
cr-xu commented 1 week ago

I would maybe place the KDE functionality in another file, as it will not only be used by screens. Is `utils.py` a good option?
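
Something along these lines could live there (rough sketch; the name, signature and normalisation are only a suggestion, not existing cheetah API):

# Hypothetical utils.py helper (name and signature are suggestions only)
import torch


def kde_histogram_2d(
    x1: torch.Tensor,
    x2: torch.Tensor,
    bins1: torch.Tensor,
    bins2: torch.Tensor,
    bandwidth: torch.Tensor,
) -> torch.Tensor:
    """Differentiable 2D histogram via Gaussian KDE, batched over the leading dims of x1/x2."""
    # Gaussian weight of every particle for every bin center: (..., n_particles, n_bins)
    k1 = torch.exp(-0.5 * ((x1.unsqueeze(-1) - bins1) / bandwidth) ** 2)
    k2 = torch.exp(-0.5 * ((x2.unsqueeze(-1) - bins2) / bandwidth) ** 2)
    # Sum of per-particle outer products: (..., n_bins1, n_bins2)
    image = torch.matmul(k1.transpose(-2, -1), k2)
    # Normalise each image to sum to one
    return image / (image.sum(dim=(-2, -1), keepdim=True) + 1e-10)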

jank324 commented 1 week ago

Yeah, I think utils.py is good enough for now.