jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

Circular distributions #1002

Closed: mhavu closed this issue 1 year ago

mhavu commented 1 year ago

Is it possible to represent circular distributions (e.g. wrapped normal distribution) in pomegranate? How would one go about building one?

jmschrei commented 1 year ago

I'm not sure what a "circular distribution" is, but you can create custom distributions: https://github.com/jmschrei/pomegranate/blob/master/tutorials/C_Feature_Tutorial_5_Custom_Distributions.ipynb

mhavu commented 1 year ago

Circular distributions are ones where the lower and upper bound of the support are sort of glued together to make a circle. Following the example, it was easy to bake my own von Mises distribution:

import numpy as np
from scipy.special import i0e


class VonMisesDistribution:
    def __init__(self, mu, kappa):
        self.mu = mu
        self.kappa = kappa
        self.parameters = (self.mu, self.kappa)
        self.d = 1
        # Sufficient statistics: total weight, sum of w*cos(x), sum of w*sin(x)
        self.summaries = np.zeros(3)

    def probability(self, X):
        return np.exp(self.log_probability(X))

    def log_probability(self, X):
        # log(2*pi*I0(kappa)), using i0e(k) = exp(-k) * I0(k) so that
        # large kappa does not overflow
        log_norm = np.log(2 * np.pi * i0e(self.kappa)) + self.kappa
        return self.kappa * np.cos(np.asarray(X) - self.mu) - log_norm

    def summarize(self, X, w=None):
        X = np.asarray(X, dtype=float).reshape(-1)
        if w is None:
            w = np.ones(X.shape[0])
        w = np.asarray(w, dtype=float)

        self.summaries[0] += w.sum()
        self.summaries[1] += np.cos(X).dot(w)
        self.summaries[2] += np.sin(X).dot(w)

    def from_summaries(self, inertia=0.0):
        # inertia is accepted for interface compatibility but ignored here
        total, c, s = self.summaries
        if total < 1e-8:
            return
        # Circular mean: angle of the weighted resultant vector
        self.mu = np.arctan2(s, c)
        # Mean resultant length r in [0, 1]; kappa uses the closed-form
        # approximation r * (2 - r**2) / (1 - r**2) to the MLE
        r = np.hypot(c, s) / total
        self.kappa = r * (2 - r ** 2) / (1 - r ** 2)
        self.parameters = (self.mu, self.kappa)
        self.clear_summaries()

    def clear_summaries(self, inertia=0.0):
        self.summaries = np.zeros(3)

    @classmethod
    def from_samples(cls, X, weights=None):
        d = cls(0, 0)
        d.summarize(X, weights)
        d.from_summaries()
        return d

    @classmethod
    def blank(cls):
        return cls(0, 0)
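Fitting is where the circular support matters: averaging raw angles treats -π and π as far apart, so the mean direction is taken from the resultant of the unit vectors (cos x, sin x), and the concentration κ from that resultant's length r. A minimal standalone sketch using NumPy only, under stated assumptions: `fit_von_mises` is a hypothetical helper, not pomegranate API, and κ uses the closed-form approximation r(2 - r²)/(1 - r²) instead of solving I₁(κ)/I₀(κ) = r exactly.

```python
import numpy as np


def fit_von_mises(x, w=None):
    """Weighted estimates of the von Mises mean direction and concentration.

    Hypothetical helper (not part of pomegranate). kappa is the closed-form
    approximation r * (2 - r**2) / (1 - r**2), not the exact MLE.
    """
    x = np.asarray(x, dtype=float).ravel()
    w = np.ones_like(x) if w is None else np.asarray(w, dtype=float).ravel()
    # Weighted sums of the unit vectors (cos x, sin x)
    c = np.dot(w, np.cos(x))
    s = np.dot(w, np.sin(x))
    # Mean direction: angle of the resultant vector, in [-pi, pi]
    mu = np.arctan2(s, c)
    # Mean resultant length r in [0, 1]; r close to 1 means tightly clustered
    r = np.hypot(c, s) / w.sum()
    kappa = r * (2 - r ** 2) / (1 - r ** 2)
    return mu, kappa
```

For two angles on either side of the ±π seam, such as π - 0.1 and -π + 0.1, the sketch returns a mean direction of magnitude π, whereas their arithmetic mean is 0.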

Seems I got exactly what I was looking for:

import matplotlib
import matplotlib.pyplot as plt

d1 = VonMisesDistribution(0, 8)
d2 = VonMisesDistribution(np.pi / 2, 1)
theta = np.arange(-np.pi, np.pi, 0.01)

matplotlib.rcParams.update({'font.family': 'stixgeneral'})
plt.plot(theta, d1.probability(theta), label="von Mises, 𝜇=0, 𝜅=8")
plt.plot(theta, d2.probability(theta), label="von Mises, 𝜇=𝜋/2, 𝜅=1")
plt.xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi],
           ["-𝜋", "-𝜋/2", "0", "𝜋/2", "𝜋"])
plt.ylabel("Probability density", fontsize=12)
plt.legend(fontsize=12)
plt.show()

[Figure: pdf plot of the two von Mises densities above, 𝜇=0, 𝜅=8 and 𝜇=𝜋/2, 𝜅=1, over -𝜋 to 𝜋]
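The wrapped normal distribution mentioned in the original question can be handled the same way. Its density wraps a normal density around the circle by summing it over every shift by a multiple of 2π, which is what gluing the ends of the support together means in practice. A minimal NumPy sketch, assuming the infinite sum can be truncated to `n_terms` shifts on each side; `wrapped_normal_pdf` and `n_terms` are hypothetical names, not pomegranate API.

```python
import numpy as np


def wrapped_normal_pdf(theta, mu, sigma, n_terms=10):
    """Density of the wrapped normal distribution on the circle.

    Hypothetical helper. The sum over k is truncated to |k| <= n_terms,
    which is accurate unless sigma is large compared with 2*pi.
    """
    theta = np.asarray(theta, dtype=float)
    k = np.arange(-n_terms, n_terms + 1)
    # Shift each angle by every multiple of 2*pi: shape (..., 2 * n_terms + 1)
    z = (theta[..., None] - mu + 2 * np.pi * k) / sigma
    # Sum the normal densities of all shifted copies
    return np.exp(-0.5 * z ** 2).sum(axis=-1) / (sigma * np.sqrt(2 * np.pi))
```

Such a function could back `log_probability` in a custom class like the one above. For fitting, one simple option is the moment estimate σ = sqrt(-2 ln r), since the mean resultant length of a wrapped normal is exp(-σ²/2); here r is the length of the mean of (cos x, sin x) over the samples.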

jmschrei commented 1 year ago

Oh, that's cool. Do you mind if I add your implementation, rewritten slightly to use torch, to https://github.com/jmschrei/torchegranate/tree/main/torchegranate ?

mhavu commented 1 year ago

No, not at all! My contribution to the example is close to zero, anyway. 😄