lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Deblurring with simple Gaussian kernel using ADMM not converging #440

Closed: shnaqvi closed this issue 1 year ago

shnaqvi commented 1 year ago

I've set up a synthetic image and blurred it with an anisotropic Gaussian kernel. I started off using simple ADMM, following the example here, to solve the inverse problem. However, the solver rapidly diverges within a few iterations (the objective and residuals become nan). Could you please check my code below and see whether there is anything obviously wrong with the setup of the problem? Also, how do I nest two operators, say FiniteDifference and L21Norm, to get the TV loss?

P.S. I'm on an M1 Mac with Python 3.10.1.

import numpy as np
import matplotlib.pyplot as plt
import jax
from scico import linop  # needed for linop.CircularConvolve below

#======
# SYNTHETIC DATA

# Create Synthetic Horizontal Stripes Pattern Image
im_s = np.zeros((2748, 3840)).astype(float)
stripe_width, stripe_gap, stripe_start, stripe_end = 50, 50, 500, 500
for y in range(0, im_s.shape[0]-(stripe_start+stripe_end), stripe_width + stripe_gap):
    im_s[stripe_start+y : stripe_start+y + stripe_width, :] = .8

xx, yy = np.mgrid[0:im_s.shape[0], 0:im_s.shape[1]]
im_ctr = (np.array(im_s.shape)/2).astype(int)
r = np.sqrt((xx - im_ctr[0])**2 + (yy - im_ctr[1])**2)
mask = np.zeros_like(im_s).astype(float)
mask[r < 1000] = 1
im_s *= mask
plt.subplot(131); plt.imshow(im_s); plt.title('Ground Truth');

# Create Gaussian Kernel
from scipy.stats import multivariate_normal
x, y = np.mgrid[0:im_s.shape[0], 0:im_s.shape[1]]
pos = np.dstack((x, y))
rv = multivariate_normal.pdf(pos, im_ctr, [[200, 0], [300, 500]])
psf2 = rv/np.max(rv)
psf2_cropped = psf2[im_ctr[0]-100:im_ctr[0]+101, im_ctr[1]-100:im_ctr[1]+101]
plt.subplot(132); plt.imshow(psf2_cropped); plt.title('PSF zoomed in');
# Convolve and Create Blurred Image
im_jx = jax.device_put(im_s)  # ground-truth image (was the undefined name im_sb)
psf_jx = jax.device_put(psf2_cropped)  

C = linop.CircularConvolve(h=psf_jx, input_shape=im_jx.shape, h_center=[psf_jx.shape[0] // 2, psf_jx.shape[1] // 2])
Cx = C(im_jx)
plt.subplot(133); plt.imshow(Cx); plt.title('Blurred Image')
[Figure: Ground Truth | PSF zoomed in | Blurred Image]
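For reference, the operation that C applies can be reproduced with plain NumPy FFTs. This is a minimal sketch, not SCICO's implementation: it assumes h_center marks the kernel pixel that is moved to the origin (which is how SCICO documents CircularConvolve) and only handles integer centers. It also makes one scaling detail visible: psf2 in the issue is normalized to a peak of 1, and for a Gaussian this wide the kernel sum is far larger than 1, so Cx ends up on a much larger scale than im_s. That shifts the weight of the data term relative to lbd, and therefore which lbd and rho work well. Dividing by the sum instead (unit-sum PSF) is the more common convention, but it is a modeling choice. The kernel size and width below are arbitrary.

```python
import numpy as np

def circular_convolve(x, h, h_center):
    """Circular 2D convolution of x with kernel h, with h[h_center] mapped to the origin."""
    # Zero-pad the kernel to the image shape.
    hp = np.zeros(x.shape)
    hp[: h.shape[0], : h.shape[1]] = h
    # Shift the kernel so that its center pixel sits at index (0, 0).
    hp = np.roll(hp, shift=(-h_center[0], -h_center[1]), axis=(0, 1))
    # Convolution theorem: multiply the DFTs, transform back, keep the real part.
    return np.real(np.fft.ifft2(np.fft.fft2(x) * np.fft.fft2(hp)))

# A small Gaussian normalized two ways (size and width are arbitrary choices).
yy, xx = np.mgrid[-10:11, -10:11]
g = np.exp(-(xx**2 + yy**2) / (2 * 3.0**2))
h_peak = g / g.max()   # peak of 1, like psf2 in the issue
h_unit = g / g.sum()   # unit sum, preserves total intensity
img = np.zeros((64, 64))
img[20:40, :] = 0.8
for name, h in (("peak-normalized", h_peak), ("unit-sum", h_unit)):
    out = circular_convolve(img, h, (10, 10))
    # Total intensity is multiplied by h.sum(), so the two printed numbers agree.
    print(f"{name}: kernel sum = {h.sum():.2f}, "
          f"output/input intensity = {out.sum() / img.sum():.2f}")
```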
from scico import linop, loss, functional
from scico.optimize.admm import ADMM, CircularConvolveSolver
from scico.util import device_info

#======
# SOLVER

f = loss.SquaredL2Loss(y=Cx, A=C)
lbd = 2e-1  
D = linop.FiniteDifference(input_shape=im_jx.shape, circular=True)
g = lbd * functional.L21Norm()

rho = 1.0e-2  
maxiter = 20  

solver = ADMM(
    f=f,
    g_list=[g],
    C_list=[D],
    rho_list=[rho],
    x0=Cx,
    maxiter=maxiter,
    subproblem_solver=CircularConvolveSolver(),
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x = solver.solve()
hist = solver.itstat_object.history(transpose=True)

plt.imshow(x)
Solving on CPU

Iter  Time      Objective  Prml Rsdl  Dual Rsdl
-----------------------------------------------
   0  6.97e-01  3.696e+09  6.496e+03  4.452e+04
  10  1.06e+01        nan        nan        nan
  12  1.26e+01        nan        nan        nan
bwohlberg commented 1 year ago

The ADMM penalty parameter is too small. Try setting it significantly larger. Setting rho = 1e0 would be a good place to start, but you'll need to experiment to find the value that gives the best convergence.
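To see where the penalty parameter enters, here are the standard scaled-form ADMM iterations (as in Boyd et al.'s ADMM tutorial) for the problem in the question, with C the blur, D the finite-difference operator, and the auxiliary variable z = Dx. SCICO's ADMM documentation describes the same general scheme, but its residuals may be normalized differently, so treat r and s as the textbook analogues of the Prml Rsdl and Dual Rsdl columns rather than their exact definitions. The x-update is the kind of linear system that CircularConvolveSolver solves in the DFT domain.

```latex
\begin{aligned}
&\min_{x,\,z}\ \tfrac{1}{2}\|y - Cx\|_2^2 + \lambda\|z\|_{2,1} \quad \text{s.t.} \quad z = Dx \\
x^{k+1} &= \left(C^T C + \rho D^T D\right)^{-1}\left(C^T y + \rho D^T (z^k - u^k)\right) \\
z^{k+1} &= \operatorname{prox}_{(\lambda/\rho)\|\cdot\|_{2,1}}\left(D x^{k+1} + u^k\right) \\
u^{k+1} &= u^k + D x^{k+1} - z^{k+1} \\
r^{k+1} &= \|D x^{k+1} - z^{k+1}\|_2, \qquad s^{k+1} = \rho\,\|D^T (z^{k+1} - z^k)\|_2
\end{aligned}
```

Here ρ scales the coupling term in the x-update and sets the shrinkage threshold λ/ρ in the z-update. A large ρ penalizes violations of z = Dx heavily and tends to give small primal residuals, while a small ρ tends to reduce the dual residual at the expense of primal feasibility; a good value balances the two.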

shnaqvi commented 1 year ago

Thanks @bwohlberg. I tried various values of rho (1, 10, 1000), and with rho = 1000 I get a reasonable image recovery. However, the solver doesn't seem to apply the TV regularizer norm1(gradient()), no matter which value of lambda I try (1, 10, 100). I was expecting higher contrast in the recovered image, with sharp edges and blocky content. It seems to me that I haven't set up the solver properly for TV regularization. I've put the norm in g_list and FiniteDifference in C_list, but does the solver actually nest them inside the functional?

Do you know how to get the desired deblurring results?
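On the nesting question: SCICO's ADMM class is documented as solving problems of the form argmin_x f(x) + Σ_i g_i(C_i x), so pairing g_list=[g] with C_list=[D] already composes them into the regularizer g(Dx) = lbd ||Dx||_{2,1}, i.e. isotropic TV; no explicit nesting is needed. The NumPy sketch below spells out the quantity that lbd * L21Norm() applied to D(x) represents, under stated assumptions: that FiniteDifference with circular=True stacks the per-axis differences along a new leading axis, that L21Norm takes the l2 norm over axis 0 (its default l2_axis) before summing, and that forward differences are used (backward differences group pixels slightly differently). The helper name isotropic_tv is hypothetical.

```python
import numpy as np

def isotropic_tv(x):
    """Isotropic TV of a 2D image with circular boundary handling (hypothetical helper)."""
    # Forward differences along each axis, stacked on a new leading axis: shape (2, H, W).
    # Assumed to mirror the layout of FiniteDifference(..., circular=True) output.
    grad = np.stack([np.roll(x, -1, axis=0) - x, np.roll(x, -1, axis=1) - x])
    # l2 norm across the stacked axis 0, then sum (l1) over all pixels: an l2,1 norm.
    return np.sum(np.sqrt(np.sum(grad**2, axis=0)))

lbd = 5e-1
img = np.zeros((6, 6))
img[:3, :] = 1.0  # one horizontal edge, plus its circular wrap-around edge
# Two edges of 6 pixels each, with unit jumps, scaled by lbd.
print(lbd * isotropic_tv(img))  # -> 6.0
```

Because the penalty is the l2 norm of the gradient summed over pixels, it favors piecewise-constant images, which is where the expected "blocky" look comes from once lbd is large enough relative to the data term.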

bwohlberg commented 1 year ago

The result looks reasonable to me. Try lbd = 5e-1, rho = 5e0, and maxiter = 50.

shnaqvi commented 1 year ago

Wow, this worked beautifully, @bwohlberg. Could you help me understand the role rho plays here? I'm curious what symptoms you were seeing that motivated these choices. Can I change rho and lambda to make it converge faster?

bwohlberg commented 1 year ago

lambda determines the strength of the regularization in the functional you want to minimize, while rho plays a major role in how quickly the ADMM algorithm converges to the minimizer of that functional. So, in principle at least, you should first choose lambda for the best results and then choose rho for the best convergence. Convergence will be very slow if rho is either too small or too large, but it's often difficult to know in advance what the right choice is. For problems like this, the best choice for rho is typically somewhere between 10 and 100 times lambda.
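The separation of roles can be checked on a toy problem. The sketch below is a self-contained 1D analogue of the issue's setup (circular Gaussian blur, circular finite differences, anisotropic l1 TV instead of the 2D l2,1 norm), solved with a hand-written scaled-form ADMM whose x-update is done with FFTs. It is similar in spirit to CircularConvolveSolver but is not SCICO's code, and the signal, kernel width, lam and the rho/lam ratios are arbitrary illustrative choices. lam stays fixed and only rho varies: the iteration count should change with rho, while the objective reached should not, since the minimizer is determined by lam alone.

```python
import numpy as np

def circ_apply(h_hat, v, adjoint=False):
    """Apply a circulant operator given by its DFT h_hat (or its adjoint) to v."""
    H = np.conj(h_hat) if adjoint else h_hat
    return np.real(np.fft.ifft(H * np.fft.fft(v)))

def tv_deblur_admm(y, a_hat, lam, rho, maxiter=5000, tol=1e-6):
    """Scaled-form ADMM for 0.5*||a*x - y||^2 + lam*||Dx||_1 (all circular, 1D)."""
    n = y.size
    d = np.zeros(n)
    d[0], d[1] = 1.0, -1.0  # (Dx)[i] = x[i] - x[i-1]
    d_hat = np.fft.fft(d)
    # The x-update matrix A^T A + rho D^T D is diagonal in the DFT domain.
    denom = np.abs(a_hat) ** 2 + rho * np.abs(d_hat) ** 2
    aty = circ_apply(a_hat, y, adjoint=True)
    x, z, u = np.zeros(n), np.zeros(n), np.zeros(n)
    for k in range(1, maxiter + 1):
        rhs = aty + rho * circ_apply(d_hat, z - u, adjoint=True)
        x = np.real(np.fft.ifft(np.fft.fft(rhs) / denom))
        dx = circ_apply(d_hat, x)
        z_old = z
        v = dx + u
        z = np.sign(v) * np.maximum(np.abs(v) - lam / rho, 0.0)  # soft threshold
        u = u + dx - z
        r = np.linalg.norm(dx - z)  # primal residual
        s = rho * np.linalg.norm(circ_apply(d_hat, z - z_old, adjoint=True))  # dual residual
        if r < tol and s < tol:
            return x, k, True
    return x, maxiter, False

def objective(x, y, a_hat, lam):
    """0.5*||a*x - y||^2 + lam * circular anisotropic TV(x)."""
    return 0.5 * np.sum((circ_apply(a_hat, x) - y) ** 2) + lam * np.sum(
        np.abs(x - np.roll(x, 1))
    )

# Toy piecewise-constant signal blurred by a circular Gaussian (arbitrary choices).
n = 128
truth = np.zeros(n)
truth[30:60], truth[80:95] = 0.8, 0.4
offsets = np.arange(-4, 5)
w = np.exp(-(offsets**2) / (2 * 1.5**2))
kernel = np.zeros(n)
kernel[offsets % n] = w / w.sum()  # unit-sum kernel centered at index 0
a_hat = np.fft.fft(kernel)
y = circ_apply(a_hat, truth)

lam = 0.02
results = {}
for ratio in (1, 10, 100, 1000):  # rho / lam
    x, iters, ok = tv_deblur_admm(y, a_hat, lam, rho=ratio * lam)
    results[ratio] = (x, iters, ok)
    print(f"rho = {ratio:>4} * lam: {iters:>5} iterations, converged={ok}, "
          f"objective={objective(x, y, a_hat, lam):.6f}")
```

What counts as a good ratio depends on how the data term is scaled (for example, by the PSF normalization), so the 10 to 100 range is a starting point for a search rather than a guarantee.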