Describe the bug
Currently, Normal distributions with covariance_type='sphere' are not handled correctly during fitting. When covariance_type='sphere', covs becomes a scalar (0-d) torch tensor, which needs special handling.
To Reproduce
Running the following code
from pomegranate.distributions import Normal
import torch
d = Normal(covariance_type='sphere').fit(torch.randn((10, 3)))
raises the following error:
File .../pomegranate/pomegranate/distributions/normal.py:169, in Normal._reset_cache(self)
166 self.register_buffer("_log_sigma_sqrt_2pi", _log_sigma_sqrt_2pi)
167 self.register_buffer("_inv_two_sigma", _inv_two_sigma)
--> 169 if any(self.covs < 0):
170 raise ValueError("Variances must be positive.")
...
TypeError: iteration over a 0-d tensor
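The traceback points at the built-in any(), which iterates over its argument and therefore cannot handle a 0-d tensor. Below is a minimal sketch of the failure and one possible workaround, assuming the negativity check is the only place that breaks. has_negative is a hypothetical helper written for illustration, not part of pomegranate.

```python
import torch


def has_negative(covs: torch.Tensor) -> bool:
    """Hypothetical helper: True if any entry of covs is negative."""
    # torch.any reduces over all elements and accepts 0-d tensors,
    # unlike the built-in any(), which tries to iterate over the tensor.
    return bool(torch.any(covs < 0))


scalar_cov = torch.tensor(1.5)        # 0-d tensor, as with covariance_type='sphere'
diag_cov = torch.tensor([1.0, -0.2])  # 1-d tensor, one variance per feature

# Reproduce the failure from the traceback in isolation.
builtin_error = None
try:
    any(scalar_cov < 0)               # built-in any() iterates, so this raises
except TypeError as exc:
    builtin_error = str(exc)

print(builtin_error)
print(has_negative(scalar_cov), has_negative(diag_cov))
```

With this kind of check, the same code path would work for scalar and per-feature variances. Whether other parts of fitting also need changes for the 'sphere' case is not covered here.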