jmschrei / pomegranate

Fast, flexible and easy to use probabilistic modelling in Python.
http://pomegranate.readthedocs.org/en/latest/
MIT License

[BUG] Normal(covariance_type='sphere') crashes when fitting #1036

Status: closed (gerwang closed this issue 10 months ago)

gerwang commented 1 year ago

Describe the bug

Fitting currently does not handle Normal distributions with covariance_type='sphere' correctly. With covariance_type='sphere', the covs attribute becomes a 0-d (scalar) torch tensor, which needs special care: code that iterates over covs, such as a call to Python's built-in any(), fails on it.
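The failure can be reproduced with torch alone. In this minimal sketch, covs is a hand-built scalar tensor standing in for the fitted spherical variance (an assumed value, not taken from pomegranate). Because it has zero dimensions, Python's built-in any(), which has to iterate over its argument, raises a TypeError.

```python
import torch

# Hand-built stand-in for the fitted 'sphere' variance (assumed value):
# one variance shared by all dimensions, stored as a 0-d tensor.
covs = torch.tensor(0.5)
print(covs.dim())  # 0

# The built-in any() iterates over its argument, and iterating over
# a 0-d tensor raises a TypeError.
error_message = None
try:
    any(covs < 0)
except TypeError as err:
    error_message = str(err)

print(error_message)  # iteration over a 0-d tensor
```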

To Reproduce

Running the following code:

```python
import torch
from pomegranate.distributions import Normal

# Fit a spherical-covariance Normal to 10 random 3-dimensional samples.
d = Normal(covariance_type='sphere').fit(torch.randn((10, 3)))
```

raises the following error:

```
File .../pomegranate/pomegranate/distributions/normal.py:169, in Normal._reset_cache(self)
    166     self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi)
    167     self.register_buffer("_inv_two_sigma", _inv_two_sigma)
--> 169 if any(self.covs < 0):
    170     raise ValueError("Variances must be positive.")
...
TypeError: iteration over a 0-d tensor
```
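One possible way to make this positivity check work in the scalar case is torch.any, which reduces over all elements of a tensor instead of iterating over it. The sketch below uses a hypothetical helper, check_variances. It is not pomegranate's code or the fix that closed this issue. It assumes covs is either a 0-d scalar (the 'sphere' case) or a 1-d vector of per-dimension variances. It should not be applied to a full covariance matrix, whose off-diagonal entries can legitimately be negative.

```python
import torch


def check_variances(covs: torch.Tensor) -> None:
    """Hypothetical helper: raise if any variance in `covs` is negative.

    torch.any reduces over every element instead of iterating, so the
    check behaves the same for a 0-d scalar and a 1-d vector.
    """
    if torch.any(covs < 0):
        raise ValueError("Variances must be positive.")


# A 0-d 'sphere' variance and a 1-d vector of variances both pass.
check_variances(torch.tensor(0.5))
check_variances(torch.tensor([0.5, 1.0, 2.0]))

# A negative scalar variance is still rejected.
try:
    check_variances(torch.tensor(-1.0))
except ValueError as err:
    print(err)  # Variances must be positive.
```

Unlike the built-in any(), torch.any returns a single boolean tensor whatever the input's dimensionality, so the if statement never has to iterate over a 0-d tensor.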