Open fritzo opened 5 years ago
So far this appears to be a deficiency in torch.distributions.constraints.lower_cholesky, whereby the optimized parameters become singular. My immediate workaround was to switch from full-covariance parameters to diagonal-covariance parameters. However, we should look into more stable ways of optimizing positive definite matrices. I'll leave this issue open for discussion.
@fritzo Your point is interesting! I tried to see how that issue arises:
import torch
torch.manual_seed(2)
x = torch.randn(40, 40).tril(-1) + torch.diag_embed(torch.randn(40).exp())
print("tril diag\n", x.diag())
print("cov eigen values\n", x.matmul(x.T).eig()[0][:, 0])
x_inv_stable = torch.eye(40).triangular_solve(x, upper=False).solution
print("tril stable inverse\n", x_inv_stable.diagonal(dim1=-2, dim2=-1))
x_inv = x.inverse()
print("tril inverse\n", x_inv.diagonal(dim1=-2, dim2=-1))
precision = x_inv.transpose(-2, -1).matmul(x_inv)
precision_stable = x_inv_stable.transpose(-2, -1).matmul(x_inv_stable)
print(precision_stable.cholesky())  # singular error
print(precision.cholesky()) # singular error
which returns
tril diag
tensor([ 1.8164, 0.9008, 0.5388, 2.4168, 0.4780, 0.2259, 0.3260, 8.2935,
0.7697, 2.7750, 1.0188, 0.1855, 0.4467, 0.6841, 1.8005, 0.8780,
0.3297, 2.1109, 0.1622, 0.0466, 0.1453, 4.5444, 11.1870, 0.3835,
1.8799, 0.6463, 0.6641, 0.2685, 4.1876, 1.1715, 0.9467, 2.2795,
1.2949, 1.0002, 3.1797, 0.2366, 11.5137, 4.9844, 0.9715, 1.8482])
cov eigen values
tensor([ 1.7721e+02, 1.7453e+02, 1.1969e+02, 1.0392e+02, 9.6891e+01,
8.2600e+01, 7.2500e+01, 6.7973e+01, 5.4561e+01, 4.8541e+01,
4.5381e+01, 3.7133e+01, 2.9066e+01, 2.7167e+01, 1.9756e+01,
1.9204e+01, 1.5435e+01, 1.3110e+01, 1.0481e+01, 9.8482e+00,
8.0088e+00, 7.1055e+00, 6.3819e+00, 4.9923e+00, 4.1197e+00,
2.8929e+00, 2.6459e+00, 2.0562e+00, 1.2387e+00, 9.0811e-01,
5.9837e-01, 3.6503e-01, 2.8105e-01, 1.0950e-01, 1.2662e-02,
4.7462e-04, 5.6004e-05, 1.7590e-05, -1.3351e-06, 2.5402e-06])
tril stable inverse
tensor([ 0.5505, 1.1102, 1.8561, 0.4138, 2.0919, 4.4262, 3.0679, 0.1206,
1.2992, 0.3604, 0.9815, 5.3896, 2.2389, 1.4619, 0.5554, 1.1389,
3.0333, 0.4737, 6.1656, 21.4709, 6.8805, 0.2201, 0.0894, 2.6072,
0.5320, 1.5473, 1.5057, 3.7250, 0.2388, 0.8536, 1.0563, 0.4387,
0.7722, 0.9998, 0.3145, 4.2262, 0.0869, 0.2006, 1.0294, 0.5411])
tril inverse
tensor([ 0.8577, 1.7351, 2.4454, 0.0628, 4.4942, -8.6023, -0.4101,
0.1345, 2.5317, -0.2871, 1.2715, -16.7043, -6.0439, -2.6443,
0.5777, -0.8966, -5.0385, 0.6307, -13.4730, -28.8008, -3.9316,
0.2600, 0.1201, -6.0625, 0.9570, 1.2705, -4.9844, 4.0703,
0.2386, -0.2930, 0.1016, 0.4507, -0.1250, -0.0625, 0.3105,
-9.5000, -0.2500, -0.2344, -1.0000, -2.5000])
We can see that the current implementation of the precision_matrix property is not stable due to torch.inverse (which produces a singular tril matrix). Regardless, the precision matrix is singular in both versions.
Regarding other parameterizations, there are two that might work (a rough sketch of both is below):
1. scale_tril = D @ L, where D is a positive diagonal matrix and L is the Cholesky factor of a correlation matrix. This is somewhat similar to the current lower_cholesky implementation but might help avoid bad scale_tril matrices.
2. cov = W @ W.t + D. This is a perturbed version of your current solution, using an additional low-rank factor matrix. I think we can relax the current assertions checking whether input distributions are MVN/IndependentNormal to also allow LowRankMVN. I can make a PR for it if you want to try this way.
In the meantime, I will try to see whether either of these parameterizations is stable enough.
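Not existing Pyro code, just a hand-rolled sketch of what those two parameterizations could look like (the dimensions, names, and row-normalization trick are mine):
import torch

torch.manual_seed(0)
dim, rank = 40, 5

# (1) scale_tril = D @ L: D is a positive diagonal, L is the Cholesky factor
# of a correlation matrix (lower triangular, unit-norm rows, positive diagonal).
raw = torch.randn(dim, dim).tril()
raw.diagonal().abs_()                         # keep the diagonal positive
L = raw / raw.norm(dim=-1, keepdim=True)      # unit-norm rows -> correlation Cholesky
D = torch.randn(dim).exp()
scale_tril = D.unsqueeze(-1) * L              # equivalent to diag(D) @ L

# (2) cov = W @ W.t + D: the low-rank-plus-diagonal parameterization that
# LowRankMultivariateNormal already supports, so no full Cholesky is needed.
W = torch.randn(dim, rank)
diag = torch.randn(dim).exp()
mvn = torch.distributions.LowRankMultivariateNormal(
    torch.zeros(dim), cov_factor=W, cov_diag=diag)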
@fehiepsi I like the low rank idea. What if we implemented a PositiveDefiniteTransform
as something like
class PositiveDefiniteTransform(Transform):
    ...
    def _call(self, x):
        assert x.size(-1) == x.size(-2)  # assume full rank
        # x @ x.T is positive semidefinite; add norm-scaled jitter to make it
        # safely positive definite.
        y = x.matmul(x.transpose(-1, -2))
        jitter = 1e-6 * y.norm(dim=(-1, -2)).unsqueeze(-1).unsqueeze(-1)
        y = y + jitter * torch.eye(y.size(-1))
        return y
That is, what if we replaced most uses of constraints.lower_cholesky with constraints.positive_definite, since positive definiteness is a little easier to enforce for the actual pd matrix than for a Cholesky factor.
The only downside I see is that forward computations would require an extra cholesky()
and backward computations would require its derivative, and those might be expensive.
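Hypothetical usage (to_positive_definite is just an illustrative free-function version of the transform sketched above, not an existing API): parameterize the MVN by an unconstrained square matrix and let the distribution pay the extra cholesky() internally.
import torch

def to_positive_definite(x, jitter_scale=1e-6):
    # same construction as the sketch above: x @ x.T plus norm-scaled jitter
    y = x.matmul(x.transpose(-1, -2))
    jitter = jitter_scale * y.norm(dim=(-1, -2)).unsqueeze(-1).unsqueeze(-1)
    return y + jitter * torch.eye(y.size(-1))

unconstrained = torch.randn(40, 40, requires_grad=True)
cov = to_positive_definite(unconstrained)
mvn = torch.distributions.MultivariateNormal(torch.zeros(40), covariance_matrix=cov)
loss = -mvn.log_prob(torch.randn(40))  # the extra cholesky() happens in here
loss.backward()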
@fritzo Setting aside the computational overhead of those cholesky decompositions, I think that adding jitter would be fine and is more stable than my other variants. Adding a jitter term proportional to y.norm() makes sense to me, though I don't know whether it is the best choice in practice. At least it has a nice property: the condition number is invariant under scaling the matrix by a constant factor, i.e. condition_number(cA + 1e-6 ||cA|| I) = condition_number(A + 1e-6 ||A|| I).
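A quick numerical check of that scale-invariance claim (random PSD matrix and scale factors chosen arbitrarily):
import torch

torch.manual_seed(0)
A = (lambda z: z @ z.T)(torch.randn(40, 40))  # a possibly ill-conditioned PSD matrix
I = torch.eye(40)

def cond(m):
    s = torch.svd(m)[1]  # singular values
    return (s.max() / s.min()).item()

for c in (1.0, 1e3, 1e-4):
    m = c * A + 1e-6 * (c * A).norm() * I
    print(c, cond(m))  # roughly the same condition number for every c > 0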
I tried to implement the LKJ approach, but it also suffered from the issue that L @ L.t is singular even though L.diag() is positive. The low-rank version is more stable, but for a high-dimensional matrix (I tested with dim=40), a small-rank cov_factor also leads to singularity issues.
That made me rethink the problem. I went back to the Kalman filter literature to see how people deal with singularity problems. There are at least a few alternatives:
In the Kalman filter, to deal with the issue of subtracting two positive definite matrices (e.g. at marginalization), there is a Joseph form which converts the subtraction into a sum. I am able to derive a similar form for the information filter by leveraging the "Joseph" identity
A - A B (Bt A B + R)^-1 Bt A = (I - C Bt) A (I - B Ct) + C R Ct,
where C = A B (Bt A B + R)^-1.
Here, B plays the role of the transition matrix of the affine transform between two variables (e.g. in the matrix_and_mvn_to_gaussian function). However, in our implementation B is consumed into the matrix AB, and the Bt A B term is not available because we have already added it to some other precision matrix. (A quick numerical check of the identity is below.)
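A small numeric sanity check of that identity, with random symmetric positive definite A and R and an arbitrary rectangular B (sizes are arbitrary):
import torch

torch.manual_seed(0)
n, m = 6, 4
A = (lambda z: z @ z.T + torch.eye(n))(torch.randn(n, n))
R = (lambda z: z @ z.T + torch.eye(m))(torch.randn(m, m))
B = torch.randn(n, m)

S = B.t() @ A @ B + R
C = A @ B @ S.inverse()

lhs = A - A @ B @ S.inverse() @ B.t() @ A
rhs = (torch.eye(n) - C @ B.t()) @ A @ (torch.eye(n) - B @ C.t()) + C @ R @ C.t()
print((lhs - rhs).abs().max())  # essentially zero up to round-off: the subtraction becomes a sum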
According to the book Kalman Filtering: Theory and Practice with MATLAB, using square-root forms (Cholesky or LDLt) is a better solution. The formula can be found in the paper you sent me. It uses some sort of "triangularization" (I think via a QR decomposition?). However, I haven't been able to write down the details of this approach yet.
A third method is a combination of the above two approaches, which is more intuitive to me. By allowing a precision matrix to be represented as H = M Mt (with no need for M to be lower triangular), I am able to rescue part of the transition matrix B from the first approach during matrix_and_mvn_to_gaussian. In particular, we have
gaussian.precision = [[B @ mvn_prec @ Bt, -B @ mvn_prec], [-mvn_prec @ Bt, mvn_prec]]
hence
gaussian.precision_sqrt = [[B @ mvn_prec_sqrt], [mvn_prec_sqrt]]
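A small check of that block structure (taking mvn_prec = M @ M.T, sizes arbitrary); note that the lower block of the stacked factor needs a minus sign for the off-diagonal blocks to come out negative:
import torch

torch.manual_seed(0)
n = 4
M = torch.randn(n, n).tril() + 2 * torch.eye(n)  # mvn_prec_sqrt
P = M @ M.t()                                    # mvn_prec
B = torch.randn(n, n)                            # transition matrix

S = torch.cat([B @ M, -M], dim=0)                # stacked factor [B @ M; -M]
precision = torch.cat([
    torch.cat([B @ P @ B.t(), -B @ P], dim=1),
    torch.cat([-P @ B.t(), P], dim=1),
], dim=0)
print((S @ S.t() - precision).abs().max())       # ~0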
This seems to be enough to resolve the issue of subtracting two positive definite matrices mentioned above, while still taking advantage of square-root filters. Indeed, under this representation the trickiest marginalization operator (Paa - Pab @ inv(Pbb) @ Pba) requires us to provide a square root of a matrix of the form
N = I - Ct @ inv(C @ Ct) @ C
where C is an m x n matrix such that C @ Ct is non-singular. Using the Joseph identity, we can see that the matrix N is idempotent, hence its square root is itself! I think this is a nice observation, so I am following up on this approach. I'll make a PR for it soon, with the hope that I didn't make a mistake along the way. :D
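And a quick check that N is symmetric and idempotent, so N @ N.t() recovers N, i.e. N can serve as its own square-root factor:
import torch

torch.manual_seed(0)
m, n = 3, 6
C = torch.randn(m, n)  # m x n with C @ C.t() non-singular

N = torch.eye(n) - C.t() @ (C @ C.t()).inverse() @ C
print((N - N.t()).abs().max())      # symmetric (up to round-off)
print((N @ N - N).abs().max())      # idempotent: N @ N == N
print((N @ N.t() - N).abs().max())  # hence N @ N.t() == N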
@fehiepsi My current thinking is that Gaussian is probably already sufficiently stable, and that we would be better off fixing the bad inputs to mvn_to_gaussian() arising from bad parameters with support constraints.lower_cholesky.
@fritzo Though fixing bad parameters of constraints.lower_cholesky is a good way to make sure precision matrices are positive definite, working with positive definite matrices might lead to the same issues as in the KF literature. The nice property of the square-root form is that we don't need to preserve symmetry and positive definiteness. At the very least, the GaussianS implementation will not involve any cholesky/inverse operator. If MVN allowed a prec_scale_tril input, then we wouldn't need any cholesky/inverse operator on the path from the parameter space to the loss in GaussianHMM. Instead, the corresponding operators in square-root form are qr and triangular_solve (sketched below).
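A minimal toy sketch of that idea (my own example, not the GaussianS code): combine two precision factors with qr instead of forming the precision matrix and calling cholesky, and solve against it with triangular_solve.
import torch

torch.manual_seed(0)
n = 40
M1 = torch.randn(n, n).tril()  # precision factor: P1 = M1 @ M1.t()
M2 = torch.randn(n, 5)         # a second (possibly low-rank) factor

# Sum of precisions in square-root form: P = [M1 M2] @ [M1 M2].t()
S = torch.cat([M1, M2], dim=-1)

# qr of S.t() gives P = R.t() @ R, i.e. R.t() is a lower-triangular square-root
# factor of P, obtained without ever forming P or calling cholesky.
_, R = torch.qr(S.t())
P = M1 @ M1.t() + M2 @ M2.t()
print((R.t() @ R - P).abs().max())  # small (float32 round-off)

# Solving P x = b then needs only two triangular_solve calls.
b = torch.randn(n, 1)
y = torch.triangular_solve(b, R.t(), upper=False).solution
x = torch.triangular_solve(y, R, upper=True).solution
print((P @ x - b).abs().max())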
I'm seeing cholesky errors in _sequential_gaussian_tensordot() when computing GaussianHMM.log_prob(), including precision matrices with very negative eigenvalues. I'd like to figure out how to numerically stabilize gaussian_tensordot().
One issue that makes this difficult to debug is that torch.eig() often segfaults due to a ref count bug https://github.com/pytorch/pytorch/issues/24450
To reproduce: gaussian.pkl.zip