hahnec / torchimize

numerical optimization in pytorch
https://hahnec.github.io/torchimize/
GNU General Public License v3.0

Frequent torch lstsq errors during fits #4

Closed: pkienzle closed this issue 11 months ago

pkienzle commented 1 year ago

While attempting to fit a set of simple Gaussians, I'm encountering frequent torch lstsq errors, presumably because the matrix is ill-conditioned near the minimum.

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-4-bde2f34b56b2> in <cell line: 69>()
     67 extra = dict(meth = "mar", ftol = 0, ptol = 1e-8, gtol = 0)
     68 
---> 69 coeffs = opt(
     70     p = p0,
     71     function = lm_fun,

/usr/local/lib/python3.10/dist-packages/torchimize/functions/parallel/lma_fun_parallel.py in lsq_lma_parallel(p, function, jac_function, args, wvec, ftol, ptol, gtol, tau, meth, rho1, rho2, beta, gama, max_iter)
     96         D = lm_dg_step(H, D)
     97         Hu = H+u[:, None, None]*D
---> 98         h = -torch.linalg.lstsq(Hu.double(), g.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)
     99         f_h = fun(p+h)
    100         rho_nom = torch.einsum('bcp,bci->bc', f, f).sum(1) - torch.einsum('bcp,bci->bc', f_h, f_h).sum(1)

RuntimeError: false INTERNAL ASSERT FAILED at "../aten/src/ATen/native/BatchLinearAlgebra.cpp":1538, please report a bug to PyTorch. torch.linalg.lstsq: (Batch element 0): Argument 4 has illegal value. Most certainly there is a bug in the implementation calling the backend library.

torch version: 2.0.1+cu118
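To test the ill-conditioning hypothesis before calling lstsq, one could compute the condition number of each batch element of the damped matrix Hu and flag the bad ones. A minimal sketch; the helper name `flag_ill_conditioned` and the `max_cond` threshold are hypothetical, not part of torchimize.

```python
import torch

def flag_ill_conditioned(H, max_cond=1e8):
    """Return (flags, cond) for a batch of square matrices H of shape (batch, n, n).

    flags[b] is True when H[b] has non-finite entries or its 2-norm condition
    number (largest / smallest singular value) is not below max_cond.
    max_cond is an arbitrary, illustrative threshold.
    """
    finite = H.isfinite().all(dim=-1).all(dim=-1)
    cond = torch.full(H.shape[:-2], float("inf"), dtype=H.dtype, device=H.device)
    if finite.any():
        # only run the SVD-based condition number on finite matrices
        cond[finite] = torch.linalg.cond(H[finite])
    flags = ~(cond < max_cond)  # also flags NaN condition numbers
    return flags, cond
```

In an LM loop one might then route only the flagged batch elements to a more robust (but slower) solver.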

The following runs in a Colab cell. You may need to run it several times, or increase n, to make it fail. Other optimizers fail similarly.

[Edit: Δy values were too small in some cases. Code modified to remove that problem.]

%pip install torchimize

n = 10

import torch
import numpy as np

sqrt2π = torch.sqrt(torch.tensor(2*torch.pi))
def lm_fx(p, x):
    A, μ, σ = p.T
    x, A, μ, σ = x[None, :], A[:, None], μ[:, None], σ[:, None]
    fx = A*torch.exp(((x - μ)/σ)**2/-2)
    return fx

def lm_fun(p, x, y, dy=1):
    fx = lm_fx(p, x)
    cost = ((fx - y)/dy)**2
    return torch.unsqueeze(cost, 1)  # Add an empty multi-cost dimension

def lm_jac(p, x, y, dy=1):
    A, μ, σ = p.T
    x, A, μ, σ = x[None, :], A[:, None], μ[:, None], σ[:, None]
    shift = x - μ
    scale = shift/σ
    scalesq = scale**2
    G = torch.exp(scalesq/-2)
    fx = A * G
    dfdA = G
    dfdμ = fx * (scale/σ)
    dfdσ = fx * (scalesq/σ)

    Δ = (fx - y)/dy
    #cost = Δ**2
    partials = torch.stack([dfdA, dfdμ, dfdσ], dim=2)
    #print(partials.shape, Δ.shape)
    jac = (2*Δ/dy)[..., None]*partials  # chain rule: d(Δ²)/dp = 2Δ·(∂f/∂p)/dy
    return torch.unsqueeze(jac, 1) # Add an empty multi-cost dimension

def gen(n, noise=0):
    x = torch.linspace(0, 50, 30)
    mid = (x[0] + x[-1])/2
    width = (x[-1] - x[0])
    μ, σ, A = torch.randn(n)*width/3 + mid, torch.rand(n)*width/5, 10**(1+torch.rand(n))
    minwidth = 5
    σ[σ < minwidth] = minwidth
    #A = torch.ones_like(A)
    pars = torch.stack([A, μ, σ]).T
    y = lm_fx(pars, x)
    if noise > 0:
        dy = noise*y
        dy[dy < noise] = noise
        y += torch.randn(*y.shape)*dy
    else:
        dy = 1
    return x, y, dy, pars

x, y, dy, target = gen(n, noise=0.05)

# Generate initial guess from the data
np_y, np_x = y.numpy(), x.numpy()
μ = np_x[np_y.argmax(axis=1)]
# Cheat for now and use target value rather than finding FWHM/2.35
σ = target[:, 2].numpy()
A = np_y.max(axis=1)
p0 = torch.tensor(np.vstack([A, μ, σ]).T)

# Bigger cheat: Use the target values as the initial guess
#p0 = target

from torchimize.functions.parallel.lma_fun_parallel import lsq_lma_parallel as opt
extra = dict(meth = "mar", ftol = 0, ptol = 1e-8, gtol = 0)

coeffs = opt(
    p = p0,
    function = lm_fun,
    jac_function = lm_jac,
    args = (x.to(dtype=p0.dtype), y.to(dtype=p0.dtype), dy),
    max_iter = 99,
    **extra,
)
print("num steps", len(coeffs))
best = coeffs[-1]

from matplotlib import pyplot as plt
%matplotlib inline
fx = lm_fx(best, x).detach()
for k in range(min(y.shape[0], 10)):
    h, = plt.plot(x, y[k], '.')
    plt.plot(x, fx[k], '-', color=h.get_color())
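A side observation on this repro: `lsq_lma_parallel` appears to square the values returned by `function` itself (the `rho_nom` line in the traceback multiplies `f` by itself), so returning `((fx - y)/dy)**2` from `lm_fun` would minimize a sum of fourth powers, and the matching Jacobian `2Δ·(∂f/∂p)/dy` goes to zero as the residuals vanish. That alone would tend to make JᵀJ singular near a good fit. A sketch of the plain residual form, assuming the solver expects residuals; `gauss`, `residual_fun` and `residual_jac` are hypothetical names mirroring `lm_fx`, `lm_fun` and `lm_jac`.

```python
import torch

def gauss(p, x):
    # p: (batch, 3) rows of (A, mu, sigma); x: (n_points,) -> (batch, n_points)
    A, mu, sigma = p.T
    x, A, mu, sigma = x[None, :], A[:, None], mu[:, None], sigma[:, None]
    return A * torch.exp(-0.5 * ((x - mu) / sigma) ** 2)

def residual_fun(p, x, y, dy=1):
    # Weighted residuals r = (f(x) - y) / dy, NOT their squares.
    # Assumption: the LM solver forms sum(r**2) itself.
    r = (gauss(p, x) - y) / dy
    return r.unsqueeze(1)  # (batch, 1 cost, n_points)

def residual_jac(p, x, y, dy=1):
    # dr/dp = (df/dp) / dy; unlike d(r**2)/dp this does not vanish at r = 0
    A, mu, sigma = p.T
    x, A, mu, sigma = x[None, :], A[:, None], mu[:, None], sigma[:, None]
    scale = (x - mu) / sigma
    G = torch.exp(-0.5 * scale ** 2)
    fx = A * G
    partials = torch.stack([G, fx * scale / sigma, fx * scale ** 2 / sigma], dim=2)
    dy = torch.as_tensor(dy, dtype=p.dtype, device=p.device).unsqueeze(-1)  # broadcast over params
    return (partials / dy).unsqueeze(1)  # (batch, 1 cost, n_points, 3)
```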
pkienzle commented 1 year ago

After fixing the problem with the data simulator, there are far fewer errors, but still too many for very large datasets with millions of fits.

For now, my workaround is to limit the number of iterations and add the following before the call to lstsq:

if not Hu.isfinite().all():
    print("WARNING: non-finite values in Hu matrix")
    Hu[~Hu.isfinite()] = 0

It doesn't seem to harm the fitted curves much, and in any case I need to check χ² for each pixel before using the results, because the fit sometimes fails to find the correct parameters.
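A sketch of such a per-fit χ² check, assuming weighted residuals (f(x) − y)/dy and a hypothetical acceptance cut on the reduced χ²; `reduced_chisq`, `accept_fits` and `max_redchi` are illustrative names, not torchimize API.

```python
import torch

def reduced_chisq(fx, y, dy, n_params):
    """Reduced chi-square per fit; fx and y have shape (batch, n_points)."""
    n_points = fx.shape[-1]
    chisq = (((fx - y) / dy) ** 2).sum(dim=-1)
    return chisq / (n_points - n_params)

def accept_fits(fx, y, dy, n_params, max_redchi=3.0):
    # max_redchi is an illustrative cut-off, to be tuned per dataset
    redchi = reduced_chisq(fx, y, dy, n_params)
    return torch.isfinite(redchi) & (redchi < max_redchi), redchi
```

With the names from the repro, this might be called as `ok, redchi = accept_fits(lm_fx(best, x), y, dy, n_params=3)`; non-finite fits are rejected along with poor ones.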

pkienzle commented 1 year ago

Further refinement to deal with singular matrices:

try:
    h = -torch.linalg.lstsq(Hu.double(), g.double(), rcond=None, driver=None)[0].to(dtype=p.dtype)
except Exception:
    print("WARNING: falling back to SVD for Hu matrix")
    U, S, Vh = torch.linalg.svd(Hu, full_matrices=False)
    SVinv = Vh.mT.conj() / S[..., None, :]  # divide column k of V by S[k]
    Uy = U.mT.conj() @ g[..., None]
    h = - (SVinv @ Uy)[..., 0]
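One way to validate the broadcasting in an SVD fallback like this is to compare it against `torch.linalg.solve` on well-conditioned matrices. The solution is x = V·diag(1/S)·Uᴴ·g, so S must divide the columns of V (`S[..., None, :]`); dividing by `S[..., None]` scales the rows instead and gives a different answer whenever the singular values differ. A sketch; `svd_solve` is a hypothetical helper name.

```python
import torch

def svd_solve(H, g):
    """Solve H x = g for square H of shape (batch, n, n) and g of shape (batch, n)."""
    U, S, Vh = torch.linalg.svd(H, full_matrices=False)
    V_over_S = Vh.mT.conj() / S[..., None, :]  # column k of V divided by S[k]
    Uy = U.mT.conj() @ g[..., None]            # (batch, n, 1)
    return (V_over_S @ Uy)[..., 0]
```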
pkienzle commented 1 year ago

Note that you can avoid division by zero on the diagonal using the following, with tol=1e-8:

S[S<tol] = tol
hahnec commented 1 year ago

Hey Paul, thanks for your valuable feedback. I currently have limited time. If you submit your suggestions as a PR, I am happy to include these changes.

pkienzle commented 1 year ago

I'll create some PRs after the current crunch. For now I'm getting things working, and noting problems as I go. Thank you for your patience.

pkienzle commented 1 year ago

My first impression is that SVD gives better-quality fits, albeit more slowly (2 min vs. 1.5 min for 20 iterations in one example). I should check the performance of QR as well: it should have stability similar to SVD, but depending on the implementation it may be faster.
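For comparison, a QR-based solve needs one orthogonal factorization and a triangular back-substitution. A sketch, assuming square, nonsingular Hu; unlike the SVD it does not expose singular values to truncate or regularize. `qr_solve` is a hypothetical helper name.

```python
import torch

def qr_solve(H, g):
    """Solve H x = g for a batch of square matrices via H = QR, x = R^{-1} Q^H g."""
    Q, R = torch.linalg.qr(H)          # reduced QR, batched over leading dims
    rhs = Q.mT.conj() @ g[..., None]   # (batch, n, 1)
    # back-substitution with the upper-triangular R instead of forming R^{-1}
    return torch.linalg.solve_triangular(R, rhs, upper=True)[..., 0]
```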

tvercaut commented 12 months ago

I am also interested in a stable version of LM for pytorch.

@pkienzle In the SVD path, you should not add a threshold to S before inverting it; rather, take the pseudo-inverse of S: https://en.wikipedia.org/wiki/Singular_value_decomposition#Pseudoinverse

That is, zero (in practice, small) singular values should be mapped to 0 rather than to infinity (in practice, a very large value), along the lines of:

Spinv = torch.zeros_like(S)
Spinv[S>tol] = 1/S[S>tol]
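Following this suggestion, a sketch of a complete batched pseudo-inverse solve, cross-checked against `torch.linalg.pinv`, whose `rtol` likewise zeroes singular values below `rtol` times the largest one. `solve_pinv` and its default `rtol` are illustrative choices, not torchimize API.

```python
import torch

def solve_pinv(H, g, rtol=1e-10):
    """x = pinv(H) g via SVD; singular values below rtol * S_max map to 0, not to 1/S."""
    U, S, Vh = torch.linalg.svd(H, full_matrices=False)
    tol = rtol * S[..., :1]  # relative cutoff, one per batch element
    Spinv = torch.zeros_like(S)
    keep = S > tol
    Spinv[keep] = 1 / S[keep]
    return ((Vh.mT.conj() * Spinv[..., None, :]) @ (U.mT.conj() @ g[..., None]))[..., 0]
```

For an exactly singular H this returns the minimum-norm least-squares step instead of an overflowing one.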
pkienzle commented 11 months ago

Testing on CUDA, the QR decomposition is 10x slower than SVD, while SVD is only a little slower than lstsq. I suggest always using SVD.
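Timings like these depend on batch size, matrix size and hardware, so a small harness helps make the comparison repeatable. A sketch using `torch.utils.benchmark.Timer`; the random SPD test matrices, sizes and helper names are assumptions, not the data behind the numbers above.

```python
import torch
from torch.utils import benchmark

def qr_solve(H, g):
    Q, R = torch.linalg.qr(H)
    return torch.linalg.solve_triangular(R, Q.mT @ g[..., None], upper=True)[..., 0]

def svd_solve(H, g):
    U, S, Vh = torch.linalg.svd(H, full_matrices=False)
    return ((Vh.mT / S[..., None, :]) @ (U.mT @ g[..., None]))[..., 0]

def lstsq_solve(H, g):
    return torch.linalg.lstsq(H, g[..., None]).solution[..., 0]

def time_solvers(batch=2000, n=3, dtype=torch.float64, repeats=5):
    """Median seconds per call for each solver on a random SPD batch (real matrices)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    A = torch.randn(batch, n, n, dtype=dtype, device=device)
    H = A @ A.mT + torch.eye(n, dtype=dtype, device=device)
    g = torch.randn(batch, n, dtype=dtype, device=device)
    results = {}
    for name, fn in [("lstsq", lstsq_solve), ("svd", svd_solve), ("qr", qr_solve)]:
        timer = benchmark.Timer(stmt="fn(H, g)", globals={"fn": fn, "H": H, "g": g})
        results[name] = timer.timeit(repeats).median
    return results
```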

In practice, I'm not seeing any difference in the resulting fits between regularization methods.

Even better would be to apply the SVD directly to the Jacobian without forming H. This is much better conditioned, since the condition number of H = JᵀJ is the square of that of J.
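A sketch of what solving on the Jacobian could look like, assuming plain Levenberg damping (JᵀJ + μI)h = −Jᵀf with real J. With J = U·diag(S)·Vᵀ this gives h = −V·diag(S/(S² + μ))·Uᵀf, so JᵀJ is never formed. torchimize's `meth="mar"` path uses a scaled damping matrix rather than I, which this sketch does not handle; `lm_step_from_jacobian` is a hypothetical name.

```python
import torch

def lm_step_from_jacobian(J, f, mu):
    """Levenberg step h solving (J^T J + mu I) h = -J^T f without forming J^T J.

    J: (batch, m, n) real Jacobian, f: (batch, m) residuals, mu: (batch,) damping > 0.
    """
    U, S, Vh = torch.linalg.svd(J, full_matrices=False)
    D = S / (S**2 + mu[:, None])   # filter factors, finite even as S -> 0
    Utf = U.mT @ f[..., None]      # (batch, k, 1)
    return -(Vh.mT @ (D[..., None] * Utf))[..., 0]
```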

For now I'll post a PR with Tikhonov regularization (solve_svd_reg below).

[Edit: fixed indexing in the code below]

import torch

def solve_svd_cutoff(Hu, g, tol=1e-5):
    U, S, Vh = torch.linalg.svd(Hu, full_matrices=False)
    alpha = S[...,0:1]*tol
    S = S.maximum(alpha)
    SVinv = Vh.mT.conj() / S[:, None, :]  # was S[..., None]
    Uy = U.mT.conj() @ g[..., None]
    return (SVinv @ Uy)[..., 0]

def solve_svd_reg(Hu, g, tol=1e-5):
    U, S, Vh = torch.linalg.svd(Hu, full_matrices=False)
    alpha = S[...,0:1]*tol
    D = S / (S**2 + alpha**2)
    SVinv = Vh.mT.conj() * D[:, None, :] # was D[..., None]
    Uy = U.mT.conj() @ g[..., None]
    return (SVinv @ Uy)[..., 0]

def solve_svd_zero(Hu, g, tol=1e-5):
    U, S, Vh = torch.linalg.svd(Hu, full_matrices=False)
    alpha = S[...,0:1]*tol
    D = 0*S
    D[S>=alpha] = 1/S[S>=alpha]
    SVinv = Vh.mT.conj() * D[:, None, :] # was D[..., None]
    Uy = U.mT.conj() @ g[..., None]
    return (SVinv @ Uy)[..., 0]
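For reference, the three variants differ only in the filter factor d_k applied to each singular triplet. Writing Hu = Σ σ_k u_k v_kᵀ (real case) with σ_1 the largest singular value and α = tol·σ_1, as in the code:

```latex
x = \sum_k d_k \, v_k \, (u_k^\top g), \qquad \alpha = \mathrm{tol}\cdot\sigma_1

d_k^{\text{cutoff}} = \frac{1}{\max(\sigma_k, \alpha)}, \qquad
d_k^{\text{reg}} = \frac{\sigma_k}{\sigma_k^2 + \alpha^2}, \qquad
d_k^{\text{zero}} = \begin{cases} 1/\sigma_k & \sigma_k \ge \alpha \\ 0 & \sigma_k < \alpha \end{cases}

% the Tikhonov factor is the exact minimiser of a penalised least-squares problem
x^{\text{reg}} = \arg\min_x \|Hx - g\|_2^2 + \alpha^2 \|x\|_2^2 = (H^\top H + \alpha^2 I)^{-1} H^\top g
```

The Tikhonov factor varies smoothly with σ_k, whereas the zero variant jumps from 1/α to 0 at σ_k = α.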
tvercaut commented 11 months ago

Apologies for the off-topic mention here, but I just stumbled onto Theseus: https://github.com/facebookresearch/theseus

Have you tried their LM implementation?

pkienzle commented 11 months ago

Because of a misplaced indexing element, my previous solver code was giving incorrect results. With the code below, torchimize passes its unit tests.

The torchimize code based on lstsq is now working, and fits to my dataset no longer give errors; I have no explanation for this. Given that the runtime with SVD is considerably longer (contrary to what my previous timing had shown), I'm going to hold off on creating a PR.

Note: detailed testing shows that torch's single-precision linear algebra libraries perform poorly on ill-conditioned matrices, so the code below uses double precision.
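A small illustration of that precision effect, using the Hilbert matrix as a standard ill-conditioned test case (an assumption; it is not the H from these fits): the same solve loses far more accuracy in float32 than in float64. `hilbert` and `solve_rel_error` are hypothetical helper names.

```python
import torch

def hilbert(n, dtype):
    # H[i, j] = 1 / (i + j + 1), built in float64 and then cast
    i = torch.arange(n, dtype=torch.float64)
    return (1.0 / (i[:, None] + i[None, :] + 1.0)).to(dtype)

def solve_rel_error(n, dtype):
    """Relative error of solving H x = H @ ones in the given dtype."""
    H64 = hilbert(n, torch.float64)
    x_true = torch.ones(n, 1, dtype=torch.float64)
    b = H64 @ x_true
    x = torch.linalg.solve(H64.to(dtype), b.to(dtype)).double()
    return ((x - x_true).norm() / x_true.norm()).item()
```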

I will close the issue for now and reopen when I come across a fitting problem that requires it.

import torch
from warnings import warn

def solve_svd_reg(H, g, tol=1e-5):
    dtype = H.dtype
    H, g = H.double(), g.double()
    if not H.isfinite().all():
        warn("non-finite values in H matrix")
        # replace out of place so a caller's double-precision H is not modified
        H = torch.where(H.isfinite(), H, torch.zeros_like(H))
    U, S, Vh = torch.linalg.svd(H, full_matrices=False)
    alpha = S[..., 0:1]*tol
    #D = 1 / S # No protection against singular matrices
    D = S / (S**2 + alpha**2)  # Tikhonov regularization
    #D = 1 / torch.maximum(S, alpha)  # clip S to alpha when S<alpha
    #D = 0*S; D[S>=alpha] = 1/S[S>=alpha] # clip 1/S to zero when S < alpha
    SVinv = Vh.mT.conj() * D[:, None, :]
    Uy = U.mT.conj() @ g[:, :, None]
    return (SVinv @ Uy)[..., 0].to(dtype=dtype)