Closed by mhavu 1 year ago.
I'm not sure what a "circular distribution" is, but you can create custom distributions: https://github.com/jmschrei/pomegranate/blob/master/tutorials/C_Feature_Tutorial_5_Custom_Distributions.ipynb
Circular distributions are ones where the lower and upper bounds of the support are glued together to form a circle, as with angles, where -π and π are the same point. Following the example, it was easy to bake my own von Mises distribution:
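import numpy as np
from scipy.special import ive


class VonMisesDistribution():
    def __init__(self, mu, kappa):
        self.mu = mu
        self.kappa = kappa
        self.parameters = (self.mu, self.kappa)
        self.d = 1
        # Sufficient statistics: [sum of w, sum of w*cos(x), sum of w*sin(x)]
        self.summaries = np.zeros(3)

    def probability(self, X):
        return np.exp(self.log_probability(X))

    def log_probability(self, X):
        # log(2*pi*I0(kappa)) via the scaled Bessel function
        # ive(0, k) = I0(k) * exp(-k), so large kappa does not overflow
        return (self.kappa * np.cos(X - self.mu)
                - np.log(2 * np.pi * ive(0, self.kappa)) - self.kappa)

    def summarize(self, X, w=None):
        X = np.asarray(X, dtype=float).reshape(-1)
        if w is None:
            w = np.ones(X.shape[0])
        w = np.asarray(w, dtype=float)
        self.summaries[0] += w.sum()
        self.summaries[1] += np.cos(X).dot(w)
        self.summaries[2] += np.sin(X).dot(w)

    def from_summaries(self, inertia=0.0):
        # Angles cannot be fitted with the linear mean and standard deviation;
        # use the weighted mean resultant vector (C, S) instead.
        C = self.summaries[1] / self.summaries[0]
        S = self.summaries[2] / self.summaries[0]
        self.mu = np.arctan2(S, C)
        # Mean resultant length, kept below 1 so that kappa stays finite
        r = min(np.hypot(C, S), 1 - 1e-9)
        # Standard piecewise approximation to the kappa solving I1(k)/I0(k) = r
        if r < 0.53:
            self.kappa = 2 * r + r ** 3 + 5 * r ** 5 / 6
        elif r < 0.85:
            self.kappa = -0.4 + 1.39 * r + 0.43 / (1 - r)
        else:
            self.kappa = 1 / (r ** 3 - 4 * r ** 2 + 3 * r)
        self.parameters = (self.mu, self.kappa)
        self.clear_summaries()

    def clear_summaries(self, inertia=0.0):
        self.summaries = np.zeros(3)

    @classmethod
    def from_samples(cls, X, weights=None):
        d = cls(0, 0)
        d.summarize(X, weights)
        d.from_summaries()
        return d

    @classmethod
    def blank(cls):
        return cls(0, 0)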
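For reference, maximum-likelihood fitting of a von Mises distribution uses the weighted resultant of the angles rather than their linear mean and variance: with C = Σ wᵢ cos xᵢ, S = Σ wᵢ sin xᵢ and W = Σ wᵢ, the estimates are μ̂ = atan2(S, C) and the κ̂ that solves I₁(κ̂)/I₀(κ̂) = √(C² + S²)/W. Below is a minimal standalone sketch (not part of pomegranate; `fit_von_mises` is a hypothetical helper name) that solves the κ̂ equation numerically with SciPy's `brentq` root finder.

```python
import numpy as np
from scipy.optimize import brentq
from scipy.special import ive


def fit_von_mises(x, w=None):
    """Hypothetical helper: weighted maximum-likelihood fit, returns (mu, kappa)."""
    x = np.asarray(x, dtype=float).reshape(-1)
    w = np.ones_like(x) if w is None else np.asarray(w, dtype=float).reshape(-1)

    # Weighted resultant vector of the angles
    c = np.dot(w, np.cos(x))
    s = np.dot(w, np.sin(x))
    mu = np.arctan2(s, c)
    r_bar = np.hypot(c, s) / w.sum()

    if r_bar < 1e-12:
        # Resultant length ~0: the data look uniform on the circle
        return mu, 0.0
    # Cap below 1: identical samples would otherwise need kappa = infinity
    r_bar = min(r_bar, 1.0 - 1e-9)

    # A(kappa) = I1(kappa)/I0(kappa); the exp(-kappa) scaling of ive cancels
    def mean_resultant_gap(kappa):
        return ive(1, kappa) / ive(0, kappa) - r_bar

    # A rises monotonically from 0 towards 1, so grow the upper bracket
    # until the gap changes sign, then solve A(kappa) = r_bar exactly
    hi = 1.0
    while mean_resultant_gap(hi) < 0:
        hi *= 2.0
    kappa = brentq(mean_resultant_gap, 0.0, hi)
    return mu, kappa
```

For example, with `rng = np.random.default_rng(0)`, `fit_von_mises(rng.vonmises(1.0, 4.0, size=20000))` should return estimates close to (1.0, 4.0). Solving the equation numerically is exact up to the root finder's tolerance; a closed-form approximation is cheaper if the fit runs inside every EM iteration.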
Seems I got exactly what I was looking for:
import matplotlib
import matplotlib.pyplot as plt
d1 = VonMisesDistribution(0, 8)
d2 = VonMisesDistribution(np.pi / 2, 1)
theta = np.arange(-np.pi, np.pi, 0.01)
matplotlib.rcParams.update({'font.family': 'stixgeneral'})
plt.plot(theta, d1.probability(theta), label="von Mises, 𝜇=0, 𝜅=8")
plt.plot(theta, d2.probability(theta), label="von Mises, 𝜇=𝜋/2, 𝜅=1")
plt.xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi],
["-𝜋", "-𝜋/2", "0", "𝜋/2", "𝜋"])
plt.ylabel("Probability density", fontsize=12)
plt.legend(fontsize=12)
plt.show()
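To check curves like these, one can confirm that each density integrates to 1 over a full turn and agrees with SciPy's `scipy.stats.vonmises`, whose `loc` parameter plays the role of μ. A minimal sketch, where `von_mises_pdf` is a hypothetical standalone helper for exp(κ cos(x - μ)) / (2π I₀(κ)). Because the density is smooth and 2π-periodic, a plain average over an evenly spaced grid (endpoint excluded) gives the integral to near machine precision.

```python
import numpy as np
from scipy.special import ive
from scipy.stats import vonmises


def von_mises_pdf(x, mu, kappa):
    # exp(kappa*cos(x - mu)) / (2*pi*I0(kappa)), rewritten with the scaled Bessel
    # function ive(0, k) = I0(k)*exp(-k) so that large kappa does not overflow
    return np.exp(kappa * (np.cos(x - mu) - 1.0)) / (2 * np.pi * ive(0, kappa))


# Evenly spaced grid over one full turn; pi is excluded because it equals -pi on the circle
theta = np.linspace(-np.pi, np.pi, 2000, endpoint=False)

for mu, kappa in [(0.0, 8.0), (np.pi / 2, 1.0)]:
    p = von_mises_pdf(theta, mu, kappa)
    # Rectangle rule: mean density times the circumference 2*pi
    total_mass = p.mean() * 2 * np.pi
    # Largest pointwise disagreement with SciPy's von Mises density
    max_abs_diff = np.max(np.abs(p - vonmises.pdf(theta, kappa, loc=mu)))
    print(f"mu={mu:.3f}, kappa={kappa}: mass={total_mass:.12f}, max diff={max_abs_diff:.2e}")
```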
Oh, that's cool. Do you mind if I add your implementation, rewritten slightly in torch, to https://github.com/jmschrei/torchegranate/tree/main/torchegranate ?
No, not at all! My contribution to the example is close to zero, anyway. 😄
Is it possible to represent circular distributions (e.g. wrapped normal distribution) in pomegranate? How would one go about building one?
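The thread does not show a wrapped normal, but the same custom-distribution route should apply. Its density is the normal density summed over all 2π-shifted copies, f(θ) = Σₖ N(θ + 2πk; μ, σ²), which has no closed form, so a common approach is to truncate the sum. Below is a minimal NumPy sketch of the density only (fitting is not shown); `wrapped_normal_pdf` and its `n_terms` truncation parameter are hypothetical names chosen for illustration.

```python
import numpy as np


def wrapped_normal_pdf(theta, mu, sigma, n_terms=10):
    """Wrapped normal density, truncated to the copies k = -n_terms..n_terms.

    The truncation is an approximation: it is very accurate for sigma up to
    a few radians, and n_terms should grow with sigma beyond that.
    """
    theta = np.asarray(theta, dtype=float)
    k = np.arange(-n_terms, n_terms + 1)
    # Last axis holds one wrapped copy of the normal density per k
    shifted = theta[..., None] + 2 * np.pi * k - mu
    dens = np.exp(-0.5 * (shifted / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
    return dens.sum(axis=-1)


theta = np.linspace(-np.pi, np.pi, 4000, endpoint=False)
p = wrapped_normal_pdf(theta, mu=0.5, sigma=1.2)
# Total mass over one turn (rectangle rule on a periodic grid), should be ~1
print(p.mean() * 2 * np.pi)
```

Taking the log of such a function inside `log_probability` of a class like the one above would give a wrapped normal with the same interface. The von Mises is often used in its place because its density has a closed form.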