ddbourgin / numpy-ml

Machine learning, in numpy
https://numpy-ml.readthedocs.io/
GNU General Public License v3.0

fix: multi dimension update for covariance in gmm #22

Closed. WuZhuoran closed this 5 years ago

WuZhuoran commented 5 years ago

- What bug I fixed

As @jjjjohnson pointed out in #16, the covariance update in the GMM can be generalized to data with more than two dimensions.

This pull request fixes #16.

- How I fixed it

  1. Change the covariance dimension from 2 to self.d.
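For reference, a minimal sketch of what this change amounts to. It assumes data `X` of shape (N, d), responsibilities `Q` of shape (N, C), and means `mu` of shape (C, d); the function name `update_covariances` and its arguments are illustrative and are not the repository's actual code.

```python
import numpy as np

def update_covariances(X, Q, mu):
    """Responsibility-weighted covariance for each mixture component.

    X: (N, d) data, Q: (N, C) responsibilities, mu: (C, d) means.
    """
    N, d = X.shape
    C = Q.shape[1]
    # One (d, d) matrix per component instead of a hard-coded (2, 2)
    sigma = np.zeros((C, d, d))
    for c in range(C):
        diff = X - mu[c]                     # (N, d) deviations from the mean
        weighted = Q[:, c][:, None] * diff   # scale each row by its responsibility
        sigma[c] = weighted.T @ diff / Q[:, c].sum()
    return sigma

# Toy 3-dimensional example with 2 components
rng = np.random.RandomState(0)
X = rng.randn(50, 3)
Q = rng.dirichlet(np.ones(2), size=50)       # each row sums to 1
mu = (Q.T @ X) / Q.sum(axis=0)[:, None]
sigma = update_covariances(X, Q, mu)
```

With d = 3 the result has shape (2, 3, 3); a fixed 2 x 2 buffer would only work for two-dimensional data.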

- How you can verify it

The tests passed before the change, but after it they fail with the output below. @jjjjohnson, could you take a look?

/Users/zhuoran/Documents/git/numpy-ml/gmm/gmm.py:66: RuntimeWarning: invalid value encountered in double_scalars
  if np.isnan(vlb) or np.abs((vlb - prev_vlb) / prev_vlb) <= tol:
Singular matrix: components collapsed
Components collapsed; Refitting
[The two lines above repeat 39 more times.]
/Users/zhuoran/Documents/git/numpy-ml/gmm/gmm.py:116: RuntimeWarning: invalid value encountered in true_divide
  self.mu[ix, :] = num / den
/Users/zhuoran/Documents/git/numpy-ml/gmm/gmm.py:43: RuntimeWarning: divide by zero encountered in log
  log_pi_k = np.log(pi_k)
/usr/local/lib/python3.7/site-packages/numpy/linalg/linalg.py:1817: RuntimeWarning: invalid value encountered in slogdet
  sign, logdet = _umath_linalg.slogdet(a, signature=signature)
Singular matrix: components collapsed
Components collapsed; Refitting
Traceback (most recent call last):
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/tests.py", line 111, in <module>
    plot()
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/tests.py", line 100, in plot
    ax = plot_clusters(G, X, ax)
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/tests.py", line 52, in plot_clusters
    rv = multivariate_normal(model.mu[c], model.sigma[c], allow_singular=True)
  File "/usr/local/lib/python3.7/site-packages/scipy/stats/_multivariate.py", line 363, in __call__
    seed=seed)
  File "/usr/local/lib/python3.7/site-packages/scipy/stats/_multivariate.py", line 736, in __init__
    self.cov_info = _PSD(self.cov, allow_singular=allow_singular)
  File "/usr/local/lib/python3.7/site-packages/scipy/stats/_multivariate.py", line 156, in __init__
    s, u = scipy.linalg.eigh(M, lower=lower, check_finite=check_finite)
  File "/usr/local/lib/python3.7/site-packages/scipy/linalg/decomp.py", line 374, in eigh
    a1 = _asarray_validated(a, check_finite=check_finite)
  File "/usr/local/lib/python3.7/site-packages/scipy/_lib/_util.py", line 239, in _asarray_validated
    a = toarray(a)
  File "/usr/local/lib/python3.7/site-packages/numpy/lib/function_base.py", line 1233, in asarray_chkfinite
    "array must not contain infs or NaNs")
ValueError: array must not contain infs or NaNs
jjjjohnson commented 5 years ago

Could you run it again? I did not encounter the error after the change.

It could be that den = 0 in self.mu[ix, :] = num / den.
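To illustrate that guess, here is a small sketch of a mean update that refuses to divide when a component has essentially no responsibility mass. The helper name `safe_mean_update` and the `eps` threshold are assumptions made for this example, not part of the repository.

```python
import numpy as np

def safe_mean_update(X, q_c, eps=1e-12):
    """Weighted mean of X under responsibilities q_c.

    Returns None when the total responsibility is (numerically) zero,
    so the caller can re-initialize or refit instead of producing NaNs.
    """
    den = q_c.sum()
    if den < eps:
        return None
    num = q_c @ X          # (N,) @ (N, d) -> (d,)
    return num / den

X = np.array([[0.0, 0.0], [2.0, 4.0]])
ok = safe_mean_update(X, np.array([0.5, 0.5]))
collapsed = safe_mean_update(X, np.zeros(2))
```

Without the guard, the second call would compute 0 / 0 and fill the mean with NaNs, which is consistent with the "invalid value encountered in true_divide" warning above.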

WuZhuoran commented 5 years ago

Now I get another error:

Traceback (most recent call last):
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/tests.py", line 110, in <module>
    plot()
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/tests.py", line 89, in plot
    ret = _G.fit(max_iter=75, verbose=False)
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/gmm.py", line 59, in fit
    self._E_step()
  File "/Users/zhuoran/Documents/git/numpy-ml/gmm/gmm.py", line 102, in _E_step
    assert_allclose(np.sum(q_i), 1, err_msg="{}".format(np.sum(q_i)))
  File "/usr/local/lib/python3.7/site-packages/numpy/testing/nose_tools/utils.py", line 1398, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/usr/local/lib/python3.7/site-packages/numpy/testing/nose_tools/utils.py", line 781, in assert_array_compare
    raise AssertionError(msg)
AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0
2.0
(mismatch 100.0%)
 x: array(2.)
 y: array(1)
WuZhuoran commented 5 years ago

Then I ran it again, and it passed all the tests.

It seems that the tests are not stable.

ddbourgin commented 5 years ago

@jjjjohnson - your suggestion in #16 is correct; I must have forgotten to change the np.zeros((2, 2)) after testing the model for the 2D case!

Anyway, it's pretty unreasonable for me to consider plots as any sort of test at all. I renamed the tests.py file to plot.py in the above commits.

RE: the instability @WuZhuoran observes - this was occurring because we never seeded the RNG in the GMM, so different runs had different initializations, and thus different results. I've fixed this in the above commits by adding a seed parameter to the GMM class and using it within plot.py to ensure reproducibility across runs.
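As a rough illustration of why seeding removes the run-to-run variation, the sketch below draws the initial means from a private, seeded generator. The function `init_means` and the choice of sampling rows of X as initial means are assumptions for this example; the actual initialization in the commits may differ.

```python
import numpy as np

def init_means(X, C, seed=None):
    """Pick C distinct rows of X as initial means using a seeded RNG."""
    rng = np.random.RandomState(seed)   # private generator, no global state
    idx = rng.choice(X.shape[0], size=C, replace=False)
    return X[idx]

X = np.arange(20.0).reshape(10, 2)
a = init_means(X, 3, seed=12345)
b = init_means(X, 3, seed=12345)        # same seed, same initialization
```

Two calls with the same seed return identical initial means, so any failure that depends on the initialization becomes repeatable.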

I think the assertion errors you saw earlier were due to instances in which a mixture component collapsed and wasn't caught properly. I've added a few small checks in the code to guard against this - I think that should address it, though please check and let me know.
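One possible form such a check could take, sketched under the assumption that the E-step works with a matrix of unnormalized log weights (one row per data point, one column per component). The helper `normalize_log_weights` is hypothetical and is not necessarily what the commits do.

```python
import numpy as np

def normalize_log_weights(log_w):
    """Convert unnormalized log weights into responsibilities.

    Raises FloatingPointError if any entry is NaN or infinite, which is
    one symptom of a collapsed component.
    """
    if not np.all(np.isfinite(log_w)):
        raise FloatingPointError(
            "non-finite log weights; a component may have collapsed"
        )
    # Subtract the row maximum before exponentiating (log-sum-exp trick)
    shifted = log_w - log_w.max(axis=1, keepdims=True)
    w = np.exp(shifted)
    return w / w.sum(axis=1, keepdims=True)

q = normalize_log_weights(np.array([[-1000.0, -1001.0], [0.0, 0.0]]))
```

Each row of `q` sums to 1 by construction, so a check like the failing `assert_allclose(np.sum(q_i), 1)` holds whenever the inputs are finite, and a non-finite input is reported immediately rather than propagating into later steps.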

Finally, as I went over the code, I noticed that X was being passed during init rather than as a parameter to the fit method. I've updated the code to address this.
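A hypothetical skeleton of that interface, showing only where the data and the seed enter. The class name `TinyGMM`, its attributes, and the initialization are placeholders for illustration and do not reproduce the real class.

```python
import numpy as np

class TinyGMM:
    """Illustrative skeleton: hyperparameters at init, data at fit time."""

    def __init__(self, C=3, seed=None):
        self.C = C
        self.seed = seed              # no data is stored here

    def fit(self, X):
        rng = np.random.RandomState(self.seed)
        self.N, self.d = X.shape      # dimensions are inferred from the data
        self.mu = X[rng.choice(self.N, size=self.C, replace=False)]
        self.sigma = np.array([np.eye(self.d) for _ in range(self.C)])
        # The EM iterations are omitted from this sketch
        return self

data = np.random.RandomState(1).randn(30, 4)
model = TinyGMM(C=2, seed=0).fit(data)
```

Passing X to `fit` lets one configured model be refit on different datasets, and the per-component covariance shape follows the data dimension automatically.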