ldeecke / gmm-torch

Gaussian mixture models in PyTorch.
MIT License
502 stars 85 forks

Add support for full covariance #15

Closed luyvlei closed 3 years ago

ldeecke commented 3 years ago

Hey @luyvlei,

Thanks for working on the full covariance & opening this PR. I wasn't aware that init_params was using kMeans by default in SKL, makes sense to add this here! 👍

I left some comments just now & will review in detail next week.

luyvlei commented 3 years ago

I've found that since PyTorch uses float32 by default (sklearn uses float64 to store the variances), this can lead to inaccuracy and value overflow in some cases. I'll fix it later.
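As a minimal illustration of the dtype gap (my own sketch, not code from this repo): float32 has a 24-bit significand, so integers above 2**24 are no longer exactly representable, while float64 handles them fine; PyTorch's global default dtype can also be switched to match sklearn.

```python
import torch

# float32 has a 24-bit significand: 2**24 + 1 rounds back down to 2**24,
# so accumulated sufficient statistics can silently lose precision.
a = torch.tensor(2.0 ** 24, dtype=torch.float32)
assert a + 1 == a

# float64 (sklearn's choice for the covariances) represents it exactly.
b = torch.tensor(2.0 ** 24, dtype=torch.float64)
assert b + 1 != b

# One way to match sklearn throughout: make float64 the PyTorch default.
torch.set_default_dtype(torch.float64)
assert torch.randn(3).dtype == torch.float64
```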

luyvlei commented 3 years ago

I'm a little confused about this formula in the function `_estimate_log_prob`: `return -.5 * (self.n_features * np.log(2. * pi) + log_p) + log_det`. I wonder whether it should instead be `return -.5 * (self.n_features * np.log(2. * pi) + log_p + log_det)`.

ldeecke commented 3 years ago

Hi @luyvlei, awesome commits! 👍

I looked into the point you raised, the return is meant to match that of sklearn's _estimate_log_prob.

Check out the reference: https://github.com/scikit-learn/scikit-learn/blob/39f37bb63d395dd2b97be7f8231ddd2113825c42/sklearn/mixture/_gaussian_mixture.py#L448

luyvlei commented 3 years ago

> Hi @luyvlei, awesome commits! 👍
>
> I looked into the point you raised, the return is meant to match that of sklearn's _estimate_log_prob.
>
> Check out the reference: https://github.com/scikit-learn/scikit-learn/blob/39f37bb63d395dd2b97be7f8231ddd2113825c42/sklearn/mixture/_gaussian_mixture.py#L448

Thanks, I have checked it. Because it uses the precision matrix instead of the covariance matrix, there is no need to add the minus sign. Maybe I should use the precision matrix to prevent overflow.
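For reference, here is a small self-contained check (my own sketch, not code from the PR) of why the plus sign is correct: when `log_det` is the sum of the log-diagonal of the *precision* Cholesky factor, it equals `-0.5 * log|cov|`, so it enters the log-density with a plus sign, exactly as in sklearn's `_estimate_log_prob`.

```python
import math
import torch

torch.manual_seed(0)
d = 3
x = torch.randn(d, dtype=torch.float64)
mu = torch.randn(d, dtype=torch.float64)
A = torch.randn(d, d, dtype=torch.float64)
cov = A @ A.T + d * torch.eye(d, dtype=torch.float64)  # SPD covariance

# Reference log-density via the covariance form:
#   log N(x|mu,cov) = -0.5 * (d*log(2*pi) + (x-mu)^T cov^{-1} (x-mu) + log|cov|)
diff = x - mu
ref = -0.5 * (d * math.log(2 * math.pi)
              + diff @ torch.linalg.solve(cov, diff)
              + torch.logdet(cov))

# Precision form: with L the Cholesky factor of the precision (cov^{-1} = L L^T),
# the Mahalanobis term is ||L^T (x-mu)||^2 and
# 0.5*log|precision| = sum(log diag(L)) = -0.5*log|cov|,
# so log_det enters with a PLUS sign, mirroring sklearn's
#   -0.5 * (d*log(2*pi) + log_p) + log_det
prec = torch.linalg.inv(cov)
L = torch.linalg.cholesky(prec)
log_p = ((diff @ L) ** 2).sum()
log_det = torch.log(torch.diagonal(L)).sum()
val = -0.5 * (d * math.log(2 * math.pi) + log_p) + log_det

assert torch.allclose(ref, val)
```

Working with the precision Cholesky factor also avoids explicitly inverting or taking the determinant of the covariance, which is where the float32 overflow showed up.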

luyvlei commented 3 years ago

> Let's remove README.md and moon_cluster.png from the PR. Otherwise good to go! 👍

👌

ldeecke commented 3 years ago

Thanks for the hard work @luyvlei! 🦾