draeloslab / AdaptiveLatents

GNU General Public License v3.0

SVD won't converge #15

Open RioIParsons opened 3 weeks ago

RioIParsons commented 3 weeks ago

pro.updateSVD() sometimes throws an error saying "SVD won't converge".

Jonathan-Gould commented 3 weeks ago

This code replicates the error (summarizing this function) and documents two possible solutions:

import numpy as np
import scipy.linalg

x = np.loadtxt('x.txt')
Q = np.loadtxt('Q.txt')
R = np.loadtxt('R.txt')

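def replicate_error(solution_method='none'):
    # Project x onto the columns of Q and keep the residual.
    x_along = Q.T @ x
    x_orth = x - Q @ x_along
    x_orth_q, x_orth_r = np.linalg.qr(x_orth, mode='reduced')

    # Solution 1: zero out entries of x_orth_r that are below 1e-16 in magnitude.
    if solution_method == 'zero':
        x_orth_r[np.abs(x_orth_r) < 1e-16] = 0

    # Append the new block to R to form the matrix passed to the SVD.
    r_new = np.block([
        [R,                                         x_along],
        [np.zeros((x_orth_r.shape[0], R.shape[1])), x_orth_r]
    ])

    # Solution 2: use SciPy's 'gesvd' LAPACK driver instead of NumPy's default SVD.
    if solution_method != 'gesvd':
        u_high_d, diag_high_d, vh_high_d = np.linalg.svd(r_new, full_matrices=False)
    else:
        u_high_d, diag_high_d, vh_high_d = scipy.linalg.svd(r_new, full_matrices=False, lapack_driver='gesvd')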

if __name__ == '__main__':
    try:
        replicate_error(solution_method='none')
        assert False, 'The above call should raise an error.'
    except np.linalg.LinAlgError:
        pass

    replicate_error(solution_method='zero')
    replicate_error(solution_method='gesvd')

Attachments: x.txt, R.txt, Q.txt
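
In the replication above, x_orth is the residual of x after projecting onto the columns of Q. When Q is square and orthogonal, that residual is zero in exact arithmetic, so only rounding noise should remain in x_orth and x_orth_r. The sketch below illustrates this with synthetic data; the random Q and x are stand-ins for the attached files, not the inputs that trigger the failure.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# Synthetic stand-ins (assumption): a square orthogonal Q and a random x.
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
x = rng.standard_normal((n, 2))

# Same projection steps as in replicate_error.
x_along = Q.T @ x
x_orth = x - Q @ x_along
_, x_orth_r = np.linalg.qr(x_orth, mode='reduced')

# Both should be at the level of floating-point rounding noise.
max_residual = np.abs(x_orth).max()
max_r_entry = np.abs(x_orth_r).max()
print(max_residual, max_r_entry)
```

The exact magnitudes depend on the data, so individual entries may land on either side of a fixed threshold such as 1e-16.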

Jonathan-Gould commented 3 weeks ago

For this specific case, where Q is square (and so doesn't project out any data), we know x_orth_r should be all zeros, which makes the 'zero' solution appealing to me. However, the issue seems sensitive to the exact input data: the error no longer reproduces when the rows or columns of x or r_new (or their sub-matrices) are permuted. This makes it unclear to me whether clipping the values in x_orth_r fixes the problem at an algorithmic level or just changes the input enough for the default algorithm to converge.

The biggest drawback of the 'gesvd' solution is that it's slower, so I think the best course of action would be to use it as a fallback:

try:
    u_high_d, diag_high_d, vh_high_d = np.linalg.svd(r_new, full_matrices=False)
except np.linalg.LinAlgError:
    u_high_d, diag_high_d, vh_high_d = scipy.linalg.svd(r_new, full_matrices=False, lapack_driver='gesvd')
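
One way to package the fallback above is a small wrapper function. This is only a sketch: the name svd_with_fallback is hypothetical and not part of AdaptiveLatents, and the random matrix stands in for r_new.

```python
import numpy as np
import scipy.linalg


def svd_with_fallback(a):
    """Hypothetical helper: try NumPy's default SVD, then fall back to gesvd."""
    try:
        return np.linalg.svd(a, full_matrices=False)
    except np.linalg.LinAlgError:
        # Only reached when the default SVD fails to converge.
        return scipy.linalg.svd(a, full_matrices=False, lapack_driver='gesvd')


# Usage on a well-behaved stand-in matrix (the fallback is not triggered here).
rng = np.random.default_rng(0)
a = rng.standard_normal((6, 4))
u, s, vh = svd_with_fallback(a)
reconstruction_ok = np.allclose((u * s) @ vh, a)
print(reconstruction_ok)
```

With this shape, the slower driver only runs on inputs where the default one raises, so the common case pays no extra cost.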

Jonathan-Gould commented 2 weeks ago

We need to check whether the gesvd solution is significantly slower and, if so, whether it would be easier to do something else, such as adding two data blocks at the same time.
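
A rough way to run that check is to time both LAPACK drivers on the same matrix. The sketch below assumes a random 100x100 matrix as a stand-in for r_new; the real sizes in AdaptiveLatents may differ, so the numbers are only indicative. It uses scipy.linalg.svd for both drivers, with 'gesdd' being the driver NumPy's SVD uses by default.

```python
import timeit

import numpy as np
import scipy.linalg

rng = np.random.default_rng(0)
r_new = rng.standard_normal((100, 100))  # stand-in size (assumption)


def time_driver(driver, repeats=5):
    """Return the best wall-clock time in seconds for one SVD with this driver."""
    def run():
        scipy.linalg.svd(r_new, full_matrices=False, lapack_driver=driver)
    return min(timeit.repeat(run, number=1, repeat=repeats))


t_gesdd = time_driver('gesdd')
t_gesvd = time_driver('gesvd')
print(f"gesdd: {t_gesdd:.6f} s, gesvd: {t_gesvd:.6f} s")

# Sanity check: both drivers should agree on the singular values.
s_gesdd = scipy.linalg.svd(r_new, compute_uv=False, lapack_driver='gesdd')
s_gesvd = scipy.linalg.svd(r_new, compute_uv=False, lapack_driver='gesvd')
values_agree = np.allclose(s_gesdd, s_gesvd)
```

Taking the minimum over several repeats reduces the influence of one-off system noise on the comparison.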