pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

Cholesky error in GaussianHMM, _sequential_gaussian_tensordot() #2017

Open fritzo opened 5 years ago

fritzo commented 5 years ago

I'm seeing Cholesky errors in _sequential_gaussian_tensordot() when computing GaussianHMM.log_prob(); the intermediate results include precision matrices with very negative eigenvalues. I'd like to figure out how to numerically stabilize gaussian_tensordot().

One issue that makes this difficult to debug is that torch.eig() often segfaults due to a ref count bug https://github.com/pytorch/pytorch/issues/24450

To reproduce

import torch
from pyro.distributions.hmm import _sequential_gaussian_tensordot
gaussian = torch.load("gaussian.pkl")  # file is attached to this issue
_sequential_gaussian_tensordot(gaussian)

gaussian.pkl.zip

fritzo commented 5 years ago

So far this appears to be a deficiency in torch.distributions.constraints.lower_cholesky whereby the optimized parameters become singular. My immediate workaround was to switch from full-covariance parameters to diagonal-covariance parameters. However, we should look into more stable ways of optimizing positive definite matrices. I'll leave this issue open for discussion.
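
For illustration, here is a minimal sketch of that workaround (the parameter names and shapes are hypothetical, not the actual model in question): parameterize a diagonal covariance with constraints.positive instead of a full factor with constraints.lower_cholesky.

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

dim = 4
loc = pyro.param("loc", torch.zeros(dim))

# full-covariance parameterization, whose optimized factor can drift toward singular:
# scale_tril = pyro.param("scale_tril", torch.eye(dim), constraint=constraints.lower_cholesky)
# obs_dist = dist.MultivariateNormal(loc, scale_tril=scale_tril)

# diagonal-covariance workaround:
scale_diag = pyro.param("scale_diag", torch.ones(dim), constraint=constraints.positive)
obs_dist = dist.Normal(loc, scale_diag).to_event(1)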

fehiepsi commented 5 years ago

@fritzo Your point is interesting! I tried to see how that issue arises:

import torch
torch.manual_seed(2)
x = torch.randn(40, 40).tril(-1) + torch.diag_embed(torch.randn(40).exp())
print("tril diag\n", x.diag())
print("cov eigen values\n", x.matmul(x.T).eig()[0][:, 0])
x_inv_stable = torch.eye(40).triangular_solve(x, upper=False).solution
print("tril stable inverse\n", x_inv_stable.diagonal(dim1=-2, dim2=-1))
x_inv = x.inverse()
print("tril inverse\n", x_inv.diagonal(dim1=-2, dim2=-1))
precision = x_inv.transpose(-2, -1).matmul(x_inv)
precision_stable = x_inv_stable.transpose(-2, -1).matmul(x_inv_stable)
print(precision_stable.cholesky())  # singular error
print(precision.cholesky())  # singular error

which returns

tril diag
 tensor([ 1.8164,  0.9008,  0.5388,  2.4168,  0.4780,  0.2259,  0.3260,  8.2935,
         0.7697,  2.7750,  1.0188,  0.1855,  0.4467,  0.6841,  1.8005,  0.8780,
         0.3297,  2.1109,  0.1622,  0.0466,  0.1453,  4.5444, 11.1870,  0.3835,
         1.8799,  0.6463,  0.6641,  0.2685,  4.1876,  1.1715,  0.9467,  2.2795,
         1.2949,  1.0002,  3.1797,  0.2366, 11.5137,  4.9844,  0.9715,  1.8482])
cov eigen values
 tensor([ 1.7721e+02,  1.7453e+02,  1.1969e+02,  1.0392e+02,  9.6891e+01,
         8.2600e+01,  7.2500e+01,  6.7973e+01,  5.4561e+01,  4.8541e+01,
         4.5381e+01,  3.7133e+01,  2.9066e+01,  2.7167e+01,  1.9756e+01,
         1.9204e+01,  1.5435e+01,  1.3110e+01,  1.0481e+01,  9.8482e+00,
         8.0088e+00,  7.1055e+00,  6.3819e+00,  4.9923e+00,  4.1197e+00,
         2.8929e+00,  2.6459e+00,  2.0562e+00,  1.2387e+00,  9.0811e-01,
         5.9837e-01,  3.6503e-01,  2.8105e-01,  1.0950e-01,  1.2662e-02,
         4.7462e-04,  5.6004e-05,  1.7590e-05, -1.3351e-06,  2.5402e-06])
tril stable inverse
 tensor([ 0.5505,  1.1102,  1.8561,  0.4138,  2.0919,  4.4262,  3.0679,  0.1206,
         1.2992,  0.3604,  0.9815,  5.3896,  2.2389,  1.4619,  0.5554,  1.1389,
         3.0333,  0.4737,  6.1656, 21.4709,  6.8805,  0.2201,  0.0894,  2.6072,
         0.5320,  1.5473,  1.5057,  3.7250,  0.2388,  0.8536,  1.0563,  0.4387,
         0.7722,  0.9998,  0.3145,  4.2262,  0.0869,  0.2006,  1.0294,  0.5411])
tril inverse
 tensor([  0.8577,   1.7351,   2.4454,   0.0628,   4.4942,  -8.6023,  -0.4101,
          0.1345,   2.5317,  -0.2871,   1.2715, -16.7043,  -6.0439,  -2.6443,
          0.5777,  -0.8966,  -5.0385,   0.6307, -13.4730, -28.8008,  -3.9316,
          0.2600,   0.1201,  -6.0625,   0.9570,   1.2705,  -4.9844,   4.0703,
          0.2386,  -0.2930,   0.1016,   0.4507,  -0.1250,  -0.0625,   0.3105,
         -9.5000,  -0.2500,  -0.2344,  -1.0000,  -2.5000])

We can see that the current implementation of the precision_matrix property is not stable due to torch.inverse (which produces a singular tril matrix). Regardless, the precision matrix is singular in both versions.

Regarding other parameterizations, there are two which might work (sketched below):

- an LKJ-style parameterization: positive scales times a unit-diagonal correlation Cholesky factor;
- a low-rank parameterization: cov_factor @ cov_factor.T + diag.

In the meantime, I will try to see if either of these parameterizations is stable enough.
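
For illustration, a minimal sketch of the two candidate parameterizations (the dimension, rank, and variable names are only illustrative):

import torch

dim, rank = 40, 5

# (1) LKJ-style: positive scales times a unit-diagonal correlation Cholesky factor
corr_cholesky = torch.eye(dim)  # in practice this would come from an LKJ-type transform
scales = torch.randn(dim).exp()
scale_tril_lkj = scales.unsqueeze(-1) * corr_cholesky
cov_lkj = scale_tril_lkj.matmul(scale_tril_lkj.t())

# (2) low-rank: cov_factor @ cov_factor.T + diag, as in LowRankMultivariateNormal
cov_factor = torch.randn(dim, rank)
cov_diag = torch.randn(dim).exp()
cov_lowrank = cov_factor.matmul(cov_factor.t()) + torch.diag_embed(cov_diag)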

fritzo commented 5 years ago

@fehiepsi I like the low rank idea. What if we implemented a PositiveDefiniteTransform as something like

import torch
from torch.distributions.transforms import Transform

class PositiveDefiniteTransform(Transform):
    ...
    def _call(self, x):
        assert x.size(-1) == x.size(-2)  # assume full rank
        y = x.matmul(x.transpose(-1, -2))
        # jitter proportional to the matrix norm keeps y numerically positive definite
        jitter = 1e-6 * y.norm(dim=(-1, -2)).unsqueeze(-1).unsqueeze(-1)
        y = y + jitter * torch.eye(y.size(-1))
        return y

That is, what if we replaced most uses of constraints.lower_cholesky with constraints.positive_definite, since positive definiteness is a little easier to enforce on the pd matrix itself than through a Cholesky factor.

The only downside I see is that forward computations would require an extra cholesky() and backward computations would require its derivative, and those might be expensive.
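
For illustration, a minimal self-contained example of where that extra factorization shows up (the matrix here is only illustrative):

import torch
from torch.distributions import MultivariateNormal

y = torch.eye(3) + 0.1 * torch.ones(3, 3)  # a positive definite matrix
mvn = MultivariateNormal(torch.zeros(3), covariance_matrix=y)
# MultivariateNormal factorizes y with cholesky() at construction time; that
# factorization (and its derivative in the backward pass) is the overhead above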

fehiepsi commented 5 years ago

@fritzo Aside from the computational overhead of those Cholesky decompositions, I think that adding jitter would be fine and is more stable than my other variants. Adding a jitter term proportional to y.norm() makes sense to me, but I don't know if it is a good choice in practice. At least it has a nice property: the condition number is invariant under scaling the matrix by a constant factor, i.e. condition_number(cA + 1e-6 ||cA|| I) = condition_number(A + 1e-6 ||A|| I).
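
For illustration, a quick numerical check of that scale-invariance (a minimal sketch; torch.linalg.eigvalsh needs a newer PyTorch than the snippets above):

import torch

torch.manual_seed(0)
a = torch.randn(10, 10)
A = a.matmul(a.t())  # symmetric positive semidefinite test matrix

def cond_with_jitter(M):
    J = M + 1e-6 * M.norm() * torch.eye(M.size(-1))
    eigvals = torch.linalg.eigvalsh(J)
    return (eigvals.max() / eigvals.min()).item()

print(cond_with_jitter(A), cond_with_jitter(3.0 * A))  # approximately equal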

I tried to implement the LKJ way, but it also suffered from the issue that L @ L.T is singular even though L.diag() is positive. The low-rank version is more stable, but for high-dimensional matrices (I tested with dim=40) a small-rank cov_factor also leads to singular issues.

That made me rethink the problem. I read the Kalman filter literature to see how people deal with singularity problems. At least we have some alternatives, for example the square-root forms used in square-root Kalman filters (sketched below).
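
For illustration, a minimal sketch of the square-root idea (shapes and names are only illustrative, and torch.linalg.qr needs a newer PyTorch than the snippets above): two covariance square roots can be combined with a QR decomposition, without ever forming and re-factorizing the full covariance.

import torch

dim = 4
L1 = torch.randn(dim, dim).tril() + torch.eye(dim)  # square root of cov1 = L1 @ L1.T
L2 = torch.randn(dim, dim).tril() + torch.eye(dim)  # square root of cov2 = L2 @ L2.T
stacked = torch.cat([L1.t(), L2.t()], dim=0)        # shape (2 * dim, dim)
_, R = torch.linalg.qr(stacked)                     # stacked.T @ stacked = R.T @ R
L_sum = R.t()                                       # square root of cov1 + cov2
assert torch.allclose(L_sum.matmul(L_sum.t()),
                      L1.matmul(L1.t()) + L2.matmul(L2.t()), atol=1e-5)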

fritzo commented 5 years ago

@fehiepsi My current thinking is that Gaussian is probably already sufficiently stable, and that we would be better off fixing the bad inputs to mvn_to_gaussian() arising from bad parameters with support constraints.lower_cholesky.

fehiepsi commented 5 years ago

@fritzo Though fixing bad parameters of constraints.lower_cholesky is a good solution to make sure precision matrices are positive definite, working with positive definite matrices might lead to the same issues as in the KF literature. The nice property of the square root form is that we don't need to explicitly preserve symmetry and positive definiteness. At least, a GaussianS implementation will not involve any cholesky/inverse operators. If MVN allowed a prec_scale_tril input, then we wouldn't need any cholesky/inverse operators on the path from the parameter space to the loss in GaussianHMM. Instead, the corresponding operators in square root form are qr and triangular_solve.
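
For illustration, a minimal sketch of why a precision square root avoids cholesky/inverse in the density computation (prec_scale_tril here is a hypothetical parameter name, and the matrix is only illustrative):

import math
import torch

dim = 4
U = torch.randn(dim, dim).tril() + 2 * torch.eye(dim)  # hypothetical prec_scale_tril, P = U @ U.T
loc = torch.zeros(dim)
x = torch.randn(dim)

# Mahalanobis term (x - loc)^T P (x - loc) = ||U^T (x - loc)||^2: just a mat-vec
mahalanobis = (x - loc).matmul(U).pow(2).sum()
# log|cov| = -log|P| = -2 * sum(log|diag(U)|) because U is triangular
log_det_cov = -2 * U.diagonal().abs().log().sum()
log_prob = -0.5 * (mahalanobis + log_det_cov + dim * math.log(2 * math.pi))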