Open pumplerod opened 10 months ago
It also appears to depend on the number of samples. In my example I was using 100 samples, but if I increase that, the error goes away. I guess I need to keep my n_samples higher than n_components + n_features or something like that. Still tricky to work around.
I have the same issue. It occurs on both CPU and CUDA, and seems to show up when n_components or n_features is high relative to the number of samples.
Is there a way to fix it, or some rule for avoiding it?
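One plausible reason for the sample-count dependence (a general linear-algebra fact, not something confirmed against the gmm-torch code): a covariance estimated from fewer samples than features cannot be full rank, so it has eigenvalues at or near zero. Here is a minimal sketch with hypothetical sizes and plain random data, not the library's own estimator:

```python
import torch

torch.manual_seed(0)

# Hypothetical sizes: fewer samples than features.
n_samples, n_features = 50, 100
x = torch.randn(n_samples, n_features, dtype=torch.float64)

# Sample covariance of the centered data, shape (n_features, n_features).
xc = x - x.mean(dim=0, keepdim=True)
cov = xc.T @ xc / (n_samples - 1)

# Centering removes one degree of freedom, so the rank is at most n_samples - 1.
rank = torch.linalg.matrix_rank(cov).item()
min_eig = torch.linalg.eigvalsh(cov).min().item()
print(f"rank={rank} of {n_features}, smallest eigenvalue={min_eig:.2e}")
```

In a mixture model each component effectively sees only a share of the samples, so this could get worse as n_components grows, which would match what is reported above.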
Hi there,
It has been several months, but maybe changing the dtype from PyTorch's default `float32` to `float64` (double) would help in some cases.
I think the issue comes from some eigenvalues of the `var` used here being close to zero, so the Cholesky factorization runs into numerical issues:
https://github.com/ldeecke/gmm-torch/blob/23eaf64be98239af1da85c4b55a6215a2fcfff2f/gmm.py#L299
So changing to `double` will help alleviate the numerical issues.
I changed `self.mu`, `self.var`, and `self.pi` to double in `_init_params` by adding `.double()` after their initializations, e.g.,
self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features), requires_grad=False)
==>
self.mu = torch.nn.Parameter(torch.randn(1, self.n_components, self.n_features).double(), requires_grad=False)
Also make sure the input to `fit()` is in double as well.
However, this is only a temporary workaround, and it does increase running time and memory usage. Hope this helps.
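To illustrate why the dtype can matter, here is a toy sketch (a hand-built 2x2 matrix, not code from gmm-torch). The matrix is nearly singular, and in float32 the small gap that keeps it positive-definite is lost to rounding. It uses `torch.linalg.cholesky_ex`, which reports failure through an `info` code instead of raising:

```python
import torch

# Hand-built, nearly singular 2x2 "covariance": off-diagonals just below 1.
off = 1.0 - 1e-9
values = [[1.0, off], [off, 1.0]]

cov32 = torch.tensor(values, dtype=torch.float32)  # 1 - 1e-9 rounds to 1.0 here
cov64 = torch.tensor(values, dtype=torch.float64)  # the 1e-9 gap survives

# cholesky_ex returns an `info` code instead of raising; 0 means success.
_, info32 = torch.linalg.cholesky_ex(cov32)
_, info64 = torch.linalg.cholesky_ex(cov64)

print("float32 info:", info32.item())
print("float64 info:", info64.item())
```

A nonzero `info` is the order of the leading minor that is not positive-definite, the same number that shows up in the error message. Double precision only moves the threshold, though: a matrix that is exactly singular will fail in either dtype.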
There seems to be some limit on the combination of n_components and n_features. If I try to create a model with
it will fail with
_LinAlgError: linalg.cholesky: The factorization could not be completed because the input is not positive-definite (the leading minor of order 22 is not positive-definite).
Reducing to `n_features=98` will work, but if I then raise `n_components=2` the error returns.

I am trying to work with many more features and components: potentially 1000+ features and an unknown, but likely high, number of components. Is there any workaround for this?
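One generic workaround, which I have not checked against how gmm-torch builds its covariances, is to add a small value to the diagonal before factorizing ("jitter" or ridge regularization) and grow it until the factorization succeeds. The helper name `cholesky_with_jitter` and its defaults are made up for illustration:

```python
import torch


def cholesky_with_jitter(cov, jitter=1e-6, max_tries=8):
    """Hypothetical helper: retry Cholesky with a growing diagonal offset.

    Returns the lower-triangular factor and the offset that was added.
    """
    eye = torch.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device)
    added = 0.0
    for _ in range(max_tries):
        factor, info = torch.linalg.cholesky_ex(cov + added * eye)
        if int(info.max()) == 0:  # 0 means the factorization succeeded
            return factor, added
        # First retry uses `jitter`; later retries multiply it by 10.
        added = jitter if added == 0.0 else added * 10.0
    raise RuntimeError("covariance still not positive-definite after adding jitter")


# Usage: a rank-1 matrix that plain Cholesky cannot factorize.
cov = torch.ones(3, 3, dtype=torch.float64)
factor, added = cholesky_with_jitter(cov)
print("jitter added:", added)
```

Note that this changes the covariance slightly, so the fitted model is a regularized one. With 1000+ features and many components it may still be necessary to have far more samples per component, or to restrict the covariance structure.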