lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Setting up Sum of Weighted Convolutions Forward Operator #441

Closed shnaqvi closed 1 year ago

shnaqvi commented 1 year ago

I'm trying to implement the space-"variant" deconvolution algorithm presented in Flicker & Rigaut (2005) and reviewed here. The optimization problem can be written as follows, where the forward model is a sum of weighted convolutions:

[Image: the optimization problem, whose forward model is a sum of weighted convolutions]

The kernels, U, and weights, W, are given by the right and left singular vectors, respectively, of the matrix formed by stacking the flattened PSFs sampled over the field of view.
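The kernel/weight decomposition described above can be sketched with a truncated SVD. This is a minimal sketch under a few assumptions. The PSFs are stacked as rows, which is the arrangement in which the right singular vectors have the size of a PSF. The singular values are absorbed into the weights, though they could just as well go into the kernels. `psf_svd_basis` is a hypothetical helper name.

```python
import numpy as np

def psf_svd_basis(psfs, k):
    """Truncated SVD of a stack of PSFs sampled over the field of view.

    psfs: array of shape (P, kh, kw), one PSF per sample position.
    Returns kernels of shape (k, kh, kw), built from the right singular vectors,
    and weights of shape (P, k), built from the left singular vectors scaled by
    the singular values (an assumed convention).
    """
    P, kh, kw = psfs.shape
    M = psfs.reshape(P, kh * kw)                  # one flattened PSF per row
    left, s, right_t = np.linalg.svd(M, full_matrices=False)
    kernels = right_t[:k].reshape(k, kh, kw)      # right singular vectors -> kernels U
    weights = left[:, :k] * s[:k]                 # left singular vectors -> weights W
    return kernels, weights
```

The weights come out only at the P PSF sample positions. The forward model needs a full-resolution weight map for each kernel, so they would still have to be interpolated over the image grid, for example with `scipy.interpolate.griddata`.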

The closest example I found on scico's website is the Convolutional Sparse Coding example, here.

I understand that we have to define operators and the corresponding functionals. However, I'm struggling to define the forward operator, possibly as a standalone function, that takes the image x, kernels U, and weights W, and implements the forward model using the CircularConvolve() operator. Would you please help?

bwohlberg commented 1 year ago

I think what you want is a spatially-weighted sum of convolutions of a single image, whereas the CSC problem involves a sum of convolutions of a number of matched filter-image pairs. As in the CSC problem example that you referenced, you should be able to construct the forward operator you need by composing CircularConvolve and Sum linear operators. There is currently no fast ADMM subproblem solver for this kind of problem, but you might be able to solve it using one of the other optimization algorithms, such as PDHG or ProximalADMM.
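The composition described above (weight the image, circularly convolve each weighted copy with its kernel, sum the results) can be sketched in plain NumPy. This version is independent of scico and can be useful for checking a scico operator against. It makes three assumptions:

- The model is y = Σ_k u_k ∗ (w_k ⊙ x), with weighting applied before convolution.
- The weights are full-resolution maps of shape (K, H, W).
- Each kernel's center pixel acts as the origin, which is the same idea as passing `h_center` to `CircularConvolve`.

`kernel_dft`, `forward` and `adjoint` are hypothetical helper names. The adjoint is included because the solvers mentioned above need A^T as well as A.

```python
import numpy as np

def kernel_dft(h, shape):
    """DFT of kernel h zero-padded to `shape`, with its center pixel
    (h.shape[0] // 2, h.shape[1] // 2) shifted to the origin."""
    hp = np.zeros(shape)
    hp[: h.shape[0], : h.shape[1]] = h
    hp = np.roll(hp, (-(h.shape[0] // 2), -(h.shape[1] // 2)), axis=(0, 1))
    return np.fft.fft2(hp)

def forward(x, kernels, weights):
    """y = sum_k u_k (*) (w_k . x) with circular convolution.
    x: (H, W); kernels: (K, kh, kw); weights: (K, H, W)."""
    Hf = np.stack([kernel_dft(u, x.shape) for u in kernels])   # (K, H, W)
    Xw = np.fft.fft2(weights * x[None], axes=(-2, -1))          # DFT of each weighted copy
    return np.real(np.fft.ifft2(Hf * Xw, axes=(-2, -1))).sum(axis=0)

def adjoint(y, kernels, weights):
    """A^T y = sum_k w_k . (u_k circularly correlated with y)."""
    Hf = np.stack([kernel_dft(u, y.shape) for u in kernels])
    corr = np.real(np.fft.ifft2(np.conj(Hf) * np.fft.fft2(y)[None], axes=(-2, -1)))
    return (weights * corr).sum(axis=0)
```

NumPy is used here for readability. A `jax.numpy` port would need `hp = hp.at[...].set(h)` in place of the in-place slice assignment, since JAX arrays are immutable.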

shnaqvi commented 1 year ago

Edit: I was wrong about getting the correct result with PDHG. My cache had not been cleared, so I was seeing the results from ADMM. I've updated this post to first get the simple "space-invariant" case working with PDHG.

Thanks @bwohlberg. What about PGM? Do you know which of these three solvers would converge faster in this case, or is that something I would just need to experiment with?
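When comparing the solvers, it may help to write down how the same TV-regularized deconvolution problem maps onto the form each one expects. The sketch below assumes the problem forms documented for scico's solvers:

- PGM: f(x) + g(x), with f smooth.
- PDHG: f(x) + g(Cx).
- ProximalADMM: f(x) + g(z) subject to Ax + Bz = c.

It is one possible mapping, not the only one. H denotes the forward operator. In the space-invariant case it is the single circular convolution C; in the space-variant case it is the sum of weighted convolutions, assuming weighting is applied before convolution.

```latex
% Problem: \min_x \tfrac{1}{2}\lVert y - Hx \rVert_2^2 + \lambda \lVert Dx \rVert_{2,1}
% Space-variant case (assumed): H = \sum_k C_{u_k} \operatorname{diag}(w_k)

% PGM: \min_x f(x) + g(x), f smooth
f(x) = \tfrac{1}{2}\lVert y - Hx \rVert_2^2, \qquad
\nabla f(x) = H^{\top}(Hx - y), \qquad
g(x) = \lambda \lVert Dx \rVert_{2,1}

% PDHG: \min_x f(x) + g(Cx)
f(x) = \tfrac{1}{2}\lVert y - Hx \rVert_2^2, \qquad
g(z) = \lambda \lVert z \rVert_{2,1}, \qquad C = D

% ProximalADMM: \min_{x,z} f(x) + g(z) \;\text{s.t.}\; Ax + Bz = c
f = 0, \qquad
g(z_0, z_1) = \tfrac{1}{2}\lVert y - z_0 \rVert_2^2 + \lambda \lVert z_1 \rVert_{2,1}, \qquad
A = \begin{bmatrix} H \\ D \end{bmatrix}, \quad B = -I, \quad c = 0
```

In the PGM form, g contains D. The TV term λ‖Dx‖_{2,1} has no closed-form proximal operator, so PGM would need an inner iterative solver at every step.

The ProximalADMM splitting above only needs the proximal operators of λ‖·‖_{2,1} (block soft thresholding) and of ½‖y − ·‖², and both have closed forms. The PDHG form puts H inside f. The prox of f is cheap via FFTs when H is a single circular convolution, but not, in general, for a sum of weighted convolutions. In that case, taking f = 0, C = [H; D] and the same separable g as in the ProximalADMM form is a mathematically equivalent alternative.

Which solver converges fastest in practice is still best checked empirically.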

I see that the way the optimization problem is set up for PDHG or PGM is similar to ADMM, but quite different from ProximalADMM, where we minimize f(x) + g(z) subject to Ax + Bz = c. So let me start with PDHG, as you suggested.
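For reference, the standard PDHG (Chambolle-Pock) iteration for the space-invariant problem min_x ½‖y − h ∗ x‖² + λ‖Dx‖_{2,1} can be sketched in plain NumPy. It may be handy as a baseline when checking scico's PDHG output. It is not scico's implementation, and it rests on several assumptions:

- Boundaries are circular, and D uses forward differences.
- The prox of f is computed exactly with FFTs, which is only valid because H here is a single circular convolution.
- The step sizes τ = σ = 0.3 give τσ‖D‖² ≤ 0.72 < 1, using ‖D‖² ≤ 8 for 2-D forward differences.

All function names are hypothetical.

```python
import numpy as np

def grad2(x):
    """Circular forward differences along both axes, stacked as (2, H, W)."""
    return np.stack([np.roll(x, -1, axis=0) - x, np.roll(x, -1, axis=1) - x])

def grad2_T(z):
    """Adjoint of grad2."""
    return (np.roll(z[0], 1, axis=0) - z[0]) + (np.roll(z[1], 1, axis=1) - z[1])

def kernel_dft(h, shape):
    """DFT of kernel h zero-padded to `shape`, center pixel shifted to the origin."""
    hp = np.zeros(shape)
    hp[: h.shape[0], : h.shape[1]] = h
    hp = np.roll(hp, (-(h.shape[0] // 2), -(h.shape[1] // 2)), axis=(0, 1))
    return np.fft.fft2(hp)

def pdhg_tv_deconv(y, h, lam, tau=0.3, sigma=0.3, maxiter=200):
    """PDHG sketch for min_x 0.5*||y - h*x||^2 + lam*||D x||_{2,1}."""
    Hf = kernel_dft(h, y.shape)
    Yf = np.fft.fft2(y)
    x = np.zeros_like(y)
    z = np.zeros((2,) + y.shape)
    for _ in range(maxiter):
        # Primal step: prox of tau*f, i.e. solve (I + tau H^T H) x = v + tau H^T y in Fourier space
        v = x - tau * grad2_T(z)
        x_new = np.real(np.fft.ifft2(
            (np.fft.fft2(v) + tau * np.conj(Hf) * Yf) / (1.0 + tau * np.abs(Hf) ** 2)))
        # Dual step: prox of sigma*g^*, i.e. projection onto {||z_i||_2 <= lam} per pixel
        z = z + sigma * grad2(2.0 * x_new - x)
        norm = np.sqrt(np.sum(z**2, axis=0))
        z = z / np.maximum(1.0, norm / lam)
        x = x_new
    return x
```

The dual projection is used here because the conjugate of λ‖·‖_{2,1} is the indicator of the set where each pixel's gradient vector has ℓ2 norm at most λ.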

With PDHG, I started by getting the simple problem of space-invariant convolution with a single kernel to work:

[Image: the space-invariant deconvolution problem with a single kernel]

I did that by simply defining the forward operator using CircularConvolve, inspired by the example here:

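import jax
import matplotlib.pyplot as plt
from scico import functional, linop, loss
from scico.optimize import PDHG
from scico.util import device_info

# im_s (image) and psf2_cropped (PSF) are arrays defined earlier in the notebook
im_jx = jax.device_put(im_s)
psf_jx = jax.device_put(psf2_cropped)

# Space-invariant blur: circular convolution with the PSF, centered on its middle pixel
C = linop.CircularConvolve(h=psf_jx, input_shape=im_jx.shape, h_center=[psf_jx.shape[0] // 2, psf_jx.shape[1] // 2])
Cx = C(im_jx)  # simulated blurred measurement

# PDHG: minimize f(x) + g(D x)
f = loss.SquaredL2Loss(y=Cx, A=C)  # data fidelity 0.5 * ||Cx - C x||_2^2
lbd = 5e-1  # TV (L21 norm) regularization parameter; 50 was also tried
g = lbd * functional.L21Norm()
D = linop.FiniteDifference(input_shape=im_jx.shape, circular=True)

maxiter = 50
tau, sigma = PDHG.estimate_parameters(D, factor=1.5)
solver_pdhg = PDHG(
    f=f,
    g=g,
    C=D,
    tau=tau,
    sigma=sigma,
    maxiter=maxiter,
    itstat_options={"display": True, "period": 10},
)

print(f"Solving on {device_info()}\n")
x = solver_pdhg.solve()
hist = solver_pdhg.itstat_object.history(transpose=True)  # was `solver`, which is undefined

plt.subplot(121); plt.imshow(Cx); plt.title('Blurred image')
plt.subplot(122); plt.imshow(x); plt.title(f'Recovered Image; PDHG: lambda: {lbd}, iter: {maxiter}, tau: {tau}');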

However, I'm getting a TypeError, "Operation __rmul__ not defined between and .", raised during solver_pdhg.solve(). I get this error when I pass the forward operator, A, to loss.SquaredL2Loss(). Can you please check whether I'm doing anything obviously wrong in setting up the PDHG instance?

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[47], line 39
     28 solver_pdhg = PDHG(
     29     f=f,
     30     g=g,
   (...)
     35     itstat_options={"display": True, "period": 10},
     36 )
     38 print(f"Solving on {device_info()}\n")
---> 39 x = solver_pdhg.solve()
     40 hist = solver.itstat_object.history(transpose=True)
     42 plt.subplot(121); plt.imshow(im_jx); plt.title('Blurred image')

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/scico/optimize/_common.py:198, in Optimizer.solve(self, callback)
    196 self.timer.start()
    197 for self.itnum in range(self.itnum, self.itnum + self.maxiter):
--> 198     self.step()
    199     if self.nanstop and not self._working_vars_finite():
    200         raise ValueError(
    201             f"NaN or Inf value encountered in working variable in iteration {self.itnum}."
    202             ""
    203         )

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/scico/optimize/_primaldual.py:227, in PDHG.step(self)
...
     50 if np.isscalar(b) or isinstance(b, jax.core.Tracer):
     51     return func(a, b)
---> 53 raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")

TypeError: Operation __rmul__ not defined between  and .
bwohlberg commented 1 year ago

Discussion moved to #443; closing this issue.