Open jank324 opened 1 week ago
Good point, I actually have a working batched KDE implementation for another project. I can add this feature if that's needed.
That would probably not be bad. I thought @jp-ga and @roussel-ryan probably also have one from Bmad-X (maybe not batched though?).
# modified from kornia.enhance.histogram
import math
from typing import Optional, Tuple, Union
import matplotlib.pyplot as plt
import torch
from torch import nn, Tensor
from torch.profiler import profile, ProfilerActivity
def marginal_pdf(
    values: torch.Tensor,
    bins: torch.Tensor,
    sigma: torch.Tensor,
    weights: Optional[Union[Tensor, float]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate a Gaussian kernel density estimate of ``values`` at ``bins``.

    Every sample contributes a Gaussian of standard deviation ``sigma``
    centred on the sample; the kernels are evaluated at each bin coordinate
    and summed over the sample dimension. The result is not normalised.

    Args:
        values: samples of shape ``[..., N, 1]``.
        bins: bin coordinates of shape ``[NUM_BINS]``.
        sigma: 0-dim tensor, Gaussian smoothing factor.
        weights: optional sample weights. ``None`` means unit weights, a
            scalar scales every sample equally, and a tensor of shape
            ``[..., N]`` (one weight per sample) is expanded over the bin
            dimension. Any other tensor is used as given and must broadcast
            against ``[..., N, NUM_BINS]``.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
          - summed kernel values of shape ``[..., NUM_BINS]``.
          - per-sample kernel values of shape ``[..., N, NUM_BINS]``.

    Raises:
        TypeError: if ``values``, ``bins`` or ``sigma`` is not a tensor.
        ValueError: if ``values``, ``bins`` or ``sigma`` has an unsupported
            shape.
    """
    if not isinstance(values, torch.Tensor):
        raise TypeError(f"Input values type is not a torch.Tensor. Got {type(values)}")
    if not isinstance(bins, torch.Tensor):
        raise TypeError(f"Input bins type is not a torch.Tensor. Got {type(bins)}")
    if not isinstance(sigma, torch.Tensor):
        raise TypeError(f"Input sigma type is not a torch.Tensor. Got {type(sigma)}")
    if values.dim() < 2 or values.shape[-1] != 1:
        raise ValueError(
            f"Input values must be of the shape [..., N, 1]. Got {values.shape}"
        )
    if not bins.dim() == 1:
        raise ValueError(
            "Input bins must be a of the shape NUM_BINS" " Got {}".format(bins.shape)
        )
    if not sigma.dim() == 0:
        raise ValueError(
            "Input sigma must be a of the shape 1" " Got {}".format(sigma.shape)
        )

    if weights is None:
        weights = 1.0
    elif isinstance(weights, torch.Tensor) and weights.shape == values.shape[:-1]:
        # One weight per sample: add a trailing axis so that it broadcasts
        # over the bin dimension of the kernel tensor.
        weights = weights.unsqueeze(-1)
    # Scalar weights multiply every kernel value directly.

    # [..., N, 1] - [NUM_BINS] broadcasts to [..., N, NUM_BINS].
    residuals = values - bins
    kernel_values = (
        weights
        * torch.exp(-0.5 * (residuals / sigma).pow(2))
        / torch.sqrt(2 * math.pi * sigma**2)
    )

    # Accumulate the contribution of all samples for each bin.
    prob_mass = torch.sum(kernel_values, dim=-2)
    return prob_mass, kernel_values
def joint_pdf(
    kernel_values1: torch.Tensor, kernel_values2: torch.Tensor, epsilon: float = 1e-10
) -> torch.Tensor:
    """Combine two sets of per-sample kernel values into a normalised joint density.

    The kernel values are contracted over the sample dimension and the
    resulting grid is divided by its total (plus ``epsilon``).

    Args:
        kernel_values1: shape [BxNxNUM_BINS].
        kernel_values2: shape [BxNxNUM_BINS].
        epsilon: scalar added to the normalisation for numerical stability.

    Returns:
        shape [BxNUM_BINSxNUM_BINS].

    Raises:
        TypeError: if either input is not a ``torch.Tensor``.
    """
    named_inputs = (
        ("kernel_values1", kernel_values1),
        ("kernel_values2", kernel_values2),
    )
    for label, candidate in named_inputs:
        if not isinstance(candidate, torch.Tensor):
            raise TypeError(
                f"Input {label} type is not a torch.Tensor. Got {type(candidate)}"
            )

    # Contract over the sample axis: (..., BINS, N) @ (..., N, BINS).
    unnormalised = kernel_values1.transpose(-2, -1) @ kernel_values2

    # Total over both bin axes, kept as (..., 1, 1) so the division broadcasts.
    total = unnormalised.sum(dim=(-2, -1), keepdim=True) + epsilon
    return unnormalised / total
def histogram(
    x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor, epsilon: float = 1e-10
) -> torch.Tensor:
    """Estimate the histogram of the input tensor.

    The calculation uses kernel density estimation which requires a bandwidth
    (smoothing) parameter. The returned values are the Gaussian kernel values
    summed over the samples; they are not normalised.

    Args:
        x: Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins: Bin coordinates with shape :math:`(N_{bins})`.
        bandwidth: Gaussian smoothing factor, a 0-dim tensor.
        epsilon: Kept for interface compatibility; currently unused.

    Returns:
        Computed histogram of shape :math:`(B, N_{bins})`.

    Examples:
        >>> x = torch.rand(1, 10)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([1, 128])
    """
    # Do not forward ``epsilon``: the fourth positional parameter of
    # ``marginal_pdf`` is ``weights``, not a stability constant.
    pdf, _ = marginal_pdf(x.unsqueeze(-1), bins, bandwidth)
    return pdf
def histogram2d(
    x1: torch.Tensor,
    x2: torch.Tensor,
    bins1: torch.Tensor,
    bins2: torch.Tensor,
    bandwidth: torch.Tensor,
    weights=None,
) -> torch.Tensor:
    """Estimate the 2d histogram of the input tensors.

    The calculation uses kernel density estimation which requires a bandwidth
    (smoothing) parameter. The result is normalised by ``joint_pdf``.

    Args:
        x1: First coordinate of the samples with shape :math:`(B, D)`.
        x2: Second coordinate of the samples with shape :math:`(B, D)`.
        bins1: Bin coordinates along the first axis, shape :math:`(N_{bins1})`.
        bins2: Bin coordinates along the second axis, shape :math:`(N_{bins2})`.
        bandwidth: Gaussian smoothing factor, a 0-dim tensor.
        weights: Optional sample weights, forwarded to ``marginal_pdf``.

    Returns:
        Computed histogram of shape :math:`(B, N_{bins1}, N_{bins2})`.

    Examples:
        >>> x1 = torch.rand(2, 32)
        >>> x2 = torch.rand(2, 32)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram2d(x1, x2, bins, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([2, 128, 128])
    """
    # Apply the weights to one marginal only. The joint density multiplies the
    # two kernels of each sample, so weighting both would make every sample
    # count with weights**2 instead of weights.
    _, kernel_values1 = marginal_pdf(x1.unsqueeze(-1), bins1, bandwidth, weights)
    _, kernel_values2 = marginal_pdf(x2.unsqueeze(-1), bins2, bandwidth)

    pdf = joint_pdf(kernel_values1, kernel_values2)
    return pdf
if __name__ == "__main__":
    # Demo: smooth 2d histogram of uniformly distributed random samples.
    # Bin coordinates shared by both axes.
    x = torch.linspace(-0.5, 0.5, 100)

    # samples (`n_particles x coord_dim`), drawn uniformly from [0, 1)
    samples = torch.rand(100000, 2)

    # Use the bin spacing as the kernel bandwidth.
    prob_mass = histogram2d(
        samples[..., 0], samples[..., 1], bins1=x, bins2=x, bandwidth=(x[1] - x[0])
    )

    # The normalised histogram should sum to (approximately) one.
    print(prob_mass.sum(dim=[-2, -1]))

    fig, ax = plt.subplots()
    c = ax.imshow(prob_mass)
    fig.colorbar(c)
    plt.show()
Add this disclaimer: "adapted from kornia.enhance.histogram". It works in batched mode, treating the last one (or two) dimensions as the particle coordinates.
I have a similar code:
import numpy as np
import torch
import torch.nn as nn
class GaussianHistogram(nn.Module):
    """Use Gaussian KDE as an approximation of the histogram

    Args:
        bins (int): Number of bins
        min (torch.Tensor) : Minimum value of the histogram, either 0-dim
            (one range shared by every batch element) or of the batch shape
        max (torch.Tensor) : Maximum value of the histogram, same shape as min
        sigma (torch.Tensor): Standard deviation in a Gaussian distribution
            It acts as a smoothing parameter of the Gaussian KDE.
            Either 0-dim or of the batch shape of the input.
    """

    def __init__(
        self,
        bins: int,
        min: torch.Tensor,
        max: torch.Tensor,
        sigma: torch.Tensor,
    ):
        super().__init__()
        self.bins = bins
        assert min.shape == max.shape, "min and max must have the same shape"
        # A trailing axis lets the bounds broadcast over the bin dimension.
        self.min = min.unsqueeze(-1)
        self.max = max.unsqueeze(-1)
        self.sigma = sigma
        self.delta = (self.max - self.min) / bins  # bin width, (..., 1)

        # Bin centres sit half a bin width above each lower edge. Plain
        # broadcasting handles 0-dim bounds (giving 1-D centres) as well as
        # batched ones, and building the index on the dtype/device of the
        # bounds avoids mixing devices.
        half_steps = (
            torch.arange(bins, dtype=self.delta.dtype, device=self.delta.device) + 0.5
        )
        self.centers = self.min + self.delta * half_steps  # (..., bins)

    def forward(self, x: torch.Tensor):
        """Calculate the KDE of the input tensor

        Args:
            x (torch.Tensor): input tensor consisting of the random variables
                shape (..., n_x)

        Returns:
            Computed KDE of the input tensor (..., bins)
        """
        # (bins, 1) or (..., bins, 1); either broadcasts against the
        # (..., 1, n_x) samples below, so no explicit expand is needed.
        centers = self.centers.unsqueeze(-1)

        if self.sigma.dim() == 0:
            sigma = self.sigma
        else:
            assert (
                self.sigma.shape == x.shape[:-1]
            ), "sigma must be 0-dim or match the batch shape of x"
            sigma = self.sigma.unsqueeze(-1).unsqueeze(-1)  # (..., 1, 1)

        residuals = torch.unsqueeze(x, -2) - centers  # (..., bins, n_x)
        delta = self.delta.unsqueeze(-1)

        # Gaussian density at each bin centre times the bin width.
        kernel = (
            torch.exp(-0.5 * (residuals / sigma) ** 2)
            / (sigma * np.sqrt(np.pi * 2))
            * delta
        )  # (..., bins, n_x)
        return kernel.sum(dim=-1)  # (..., bins)
I would maybe place the KDE functionality in another file, as it will not only be used by screens. Is `utils.py` a good option?
Yeah, I think `utils.py` is good enough for now.
... to help differentiability? (see https://arxiv.org/pdf/2404.10853v1)