comp-imaging / ProxImaL

A domain-specific language for image optimization.
MIT License

Problem having 17 individual `sum_squares(conv(x))` takes too long to group/absorb. Image deconvolution with spatially varying PSFs #83

Closed shnaqvi closed 10 months ago

shnaqvi commented 1 year ago

I'm trying to implement the space-variant deconvolution algorithm presented in Flicker & Rigaut (2005) and reviewed here.

I have PSFs sampled over the field; they are centered and stacked into a matrix, whose SVD gives the weights and kernels that form a low-rank approximation of the spatially varying blur forward model. (The image, im, each weight map (a row of W) and each kernel (a column of U) is a vector in R^n with n ≈ 1e7, i.e. an image of ~10 million pixels.)
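The SVD step described above can be sketched in NumPy roughly as follows. This is a minimal illustration, assuming each sampled PSF is a small p × q array; psf_modes_from_samples is a hypothetical helper, not part of ProxImaL. The SVD only yields each mode's coefficient at the PSF sample positions; turning those into full-resolution weight maps (the rows of W in this post) requires interpolating across the field, which is not shown.

```python
import numpy as np

def psf_modes_from_samples(psfs, k):
    """Hypothetical helper: low-rank PSF decomposition via the SVD.

    psfs: (m, p, q) stack of centered PSFs sampled at m field positions.
    Returns (k, p, q) spatially invariant kernels and (k, m) coefficients,
    so that psfs[j] ~= sum_i coeffs[i, j] * kernels[i].
    """
    m, p, q = psfs.shape
    # One flattened PSF per column: shape (p*q, m).
    P = psfs.reshape(m, p * q).T
    modes, s, vt = np.linalg.svd(P, full_matrices=False)
    kernels = modes[:, :k].T.reshape(k, p, q)  # PSF modes (kernels)
    coeffs = s[:k, None] * vt[:k, :]           # weight of each mode at each sample
    return kernels, coeffs

# Synthetic check: 40 PSFs mixed from 3 underlying shapes (exactly rank 3).
rng = np.random.default_rng(0)
basis = rng.standard_normal((3, 15, 15))
mix = rng.standard_normal((40, 3))
psfs = np.einsum("mj,jpq->mpq", mix, basis)

kernels, coeffs = psf_modes_from_samples(psfs, k=3)
approx = np.einsum("km,kpq->mpq", coeffs, kernels)
print(np.allclose(approx, psfs))  # True
```

For real PSFs the singular values decay rather than dropping to zero, so k trades accuracy against the number of convolutions per forward evaluation.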

I construct my problem as follows, but it has not returned even after hundreds of minutes of runtime. It does not print any debug information either, despite the verbose flag, so I'm not sure what is going on.

P.S. I'm running Python 3.10.1 on macOS Ventura on an M1 Max with 64 GB RAM.

import proximal as px

def forward(im, G, W):
    for i in range(15):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im += px.conv(psf_mode, px.mul_elemwise(weight, im))
    return im

def cost_function(x, im, U, W, mu): 
    im_blur = forward(x, U, W)
    data_term = px.sum_squares(px.subsample(im_blur, steps=2) - im) 
    grad_term = mu*px.norm1(px.grad(im_blur)) + (1-mu)*px.sum_squares(px.grad(im_blur))
    return data_term + grad_term + px.nonneg(im_blur)

x = px.Variable(im.shape)
prob = px.Problem(cost_function(x, im, U, W, mu=0.01))
prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())

Can you please tell me:

  1. whether I'm doing anything obviously wrong in setting up the problem?
  2. how I can get debug info printed, to get a sense of where it is in the optimization?
antonysigma commented 1 year ago

Tagging @SteveDiamond , the owner of the ProxImaL project.

At first glance, range(15) probably overwhelms the solver. Internally, it is busy copying the pixels of x.value 15 * 2 = 30 times, ahead of mul_elemwise and conv. The same goes for the adjoint operation, which copies the im pixels 30 times before updating x. In total, that is 60+ redundant data copies per solver iteration.
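For a sense of scale, here is a back-of-the-envelope estimate of that copy traffic, assuming float64 pixels, exactly 60 copies, and the 2748 × 3840 test image used later in this thread:

```python
# Rough memory traffic from redundant copies, per solver iteration.
height, width = 2748, 3840      # test image size used later in the thread
bytes_per_pixel = 8             # assumption: float64 pixels
copies_per_iteration = 60       # 15 modes x 2 copies, for forward + adjoint

bytes_per_copy = height * width * bytes_per_pixel
total_bytes = copies_per_iteration * bytes_per_copy
print(f"{bytes_per_copy / 1e6:.1f} MB per copy, {total_bytes / 1e9:.2f} GB per iteration")
# 84.4 MB per copy, 5.07 GB per iteration
```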

Since you are not using Halide-accelerated compute right now, you could try hand-rolling your own "plane-wise" 2D convolution via the BlackBox API, like this:

from proximal.lin_ops.black_box import LinOpFactory

# Placeholders you supply: my_weight and psf_mode are (n_psf, H, W) arrays,
# and my_2d_convolute(kernel, image) is your 2D convolution routine.

def my_forward_operation(source_image, simulated_blur, n_psf=15, weights=my_weight):
    assert source_image.shape == simulated_blur.shape
    simulated_blur[:] = 0
    for i in range(n_psf):
        weighted_image = source_image * weights[i]
        simulated_blur[:] += my_2d_convolute(psf_mode[i], weighted_image)

def my_adjoint_operation(actual_blur, back_propagated_image, n_psf=15, weights=my_weight):
    # Compute the matched filter of "psf_mode[i]", then use it to compute a
    # plane-wise 2D convolution, followed by element-wise multiplication with the weights.
    raise NotImplementedError  # to be implemented

my_blur_operation = LinOpFactory(
    im.shape,
    im.shape,  # the op preserves the image shape; subsampling is applied separately
    my_forward_operation,
    my_adjoint_operation,
)

Next, use the custom linear operation to define the problem:

x = px.Variable(im.shape)
prob = px.Problem(
    px.sum_squares(px.subsample(
        my_forward_operation(x),
        2,
    ))
)
# ...
antonysigma commented 1 year ago

Off-topic question for @SteveDiamond: is ProxImaL capable of inferring the commutation properties between mul_elemwise and conv, assuming a cyclic image boundary condition? Should it?

I am asking because I am not sure whether @shnaqvi's problem can utilize ProxImaL's absorption of diagonal matrices into the L2 norm to optimize the solver, like this:

\begin{align*}
& \arg\min_{u} \left\Vert \sum_{i=1}^N \mathrm{conv}[h_i, \mathrm{Diag}(w_i)\, u] - b \right\Vert_2^2 + \alpha\Vert \nabla u \Vert_F^2 \\
\approx & \arg\min_{u} \sum_{i=1}^N \left\Vert \mathrm{Diag}(w_i)\, \mathrm{conv}[h_i, u] - \mathrm{Diag}(w_i)\, b \right\Vert_2^2 + \alpha\Vert \nabla u \Vert_F^2 \\
= & \arg\min_{u} \sum_{i=1}^N f_i(\mathbf{K}_i u) + \alpha\Vert \nabla u \Vert_F^2, \\
& \quad f_i(v) = \Vert \mathrm{Diag}(w_i)\,(v - b) \Vert_2^2, \\
& \quad \mathbf{K}_i u = \mathrm{conv}[h_i, u]
\end{align*}

(On second thought, maybe not... The ProxImaL user should lead the problem formulation, not the ProxImaL compiler.)
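A quick numerical illustration of why the rewrite above is only approximate: under a cyclic boundary, convolution commutes with an element-wise multiplication when the weight map is constant, but not, in general, when the weight varies across the field. A minimal NumPy sketch, where circ_conv is a hypothetical helper doing FFT-based circular convolution (not a ProxImaL function):

```python
import numpy as np

def circ_conv(h, x):
    """2D circular (cyclic-boundary) convolution via the FFT; h and x share a shape."""
    return np.real(np.fft.ifft2(np.fft.fft2(h) * np.fft.fft2(x)))

rng = np.random.default_rng(1)
h = rng.standard_normal((32, 32))   # stand-in for one PSF mode
x = rng.standard_normal((32, 32))   # stand-in for the image u

w_const = np.full((32, 32), 0.7)          # spatially invariant weight
w_vary = rng.uniform(0.0, 1.0, (32, 32))  # spatially varying weight

print(np.allclose(circ_conv(h, w_const * x), w_const * circ_conv(h, x)))  # True
print(np.allclose(circ_conv(h, w_vary * x), w_vary * circ_conv(h, x)))    # False
```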

SteveDiamond commented 1 year ago

Good question! This property is not encoded in the "compiler" but it could be. Some similar transformations are encoded: https://github.com/comp-imaging/ProxImaL/blob/master/proximal/algorithms/absorb.py
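One reason absorbing a diagonal weight into the L2 norm is attractive: the proximal operator of f(v) = ||Diag(d)(v - b)||² remains element-wise and closed-form, prox_{τf}(v) = (v + 2τ d² b) / (1 + 2τ d²), obtained by setting the gradient of f(z) + ||z - v||² / (2τ) to zero. A minimal sketch that checks this closed form numerically; the function name is hypothetical, and this is not claimed to be how absorb.py implements it:

```python
import numpy as np

def prox_weighted_sum_squares(v, d, b, tau):
    """Hypothetical helper: prox of tau * ||diag(d) (z - b)||_2^2 at v, element-wise."""
    return (v + 2.0 * tau * d**2 * b) / (1.0 + 2.0 * tau * d**2)

rng = np.random.default_rng(2)
v, d, b = rng.standard_normal((3, 1000))
tau = 0.5
z = prox_weighted_sum_squares(v, d, b, tau)

# First-order optimality of ||diag(d)(z - b)||^2 + ||z - v||^2 / (2 tau):
grad = 2.0 * d**2 * (z - b) + (z - v) / tau
print(np.max(np.abs(grad)))  # on the order of machine precision
```

With d = 0 the weighted term vanishes and the prox returns v unchanged, which is a handy sanity check.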

shnaqvi commented 1 year ago

Thanks for the insight, @antonysigma, and sorry, I'm really new to matrix-free optimization. From your response, I understand that I need to do the following:

  1. Replace the built-in operators of mul_elemwise and conv with my own.
  2. Define the adjoint of the forward model and wrap both in a subclass of LinOp (I see that LinOpFactory does this under the hood with the BlackBox class).

I've incorporated these below. Question 1: please let me know whether the adjoint operation looks correct.

import numpy as np
from scipy.signal import fftconvolve

def forward(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim[:] += fftconvolve(psf_mode, im_src*weight, mode='same')
    return im_sim

def adjoint(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim += fftconvolve(im_src, np.flipud(np.fliplr(psf_mode)), mode='same')* weight 
    return im_sim
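On Question 1: a standard way to check a hand-written adjoint is the dot-product test, which compares ⟨A x, y⟩ with ⟨x, Aᵀ y⟩ for random x and y; the two should agree up to rounding. The sketch below re-implements the forward/adjoint above with the PSF modes and weights pre-stacked as (n_psf, H, W) arrays (an assumption for brevity; make_sv_blur and dot_product_test are hypothetical helpers, not ProxImaL functions). Under scipy's mode='same' cropping, this flipped-kernel adjoint should be exact when every array dimension is odd but shifted by one pixel when a dimension is even (as with the 2748 × 3840 test image), so running the check on the real shapes is worthwhile.

```python
import numpy as np
from scipy.signal import fftconvolve

def make_sv_blur(psf_modes, weights):
    """psf_modes, weights: arrays of shape (n_psf, H, W). Returns (forward, adjoint)."""
    def forward(im_src, im_sim):
        im_sim[:] = 0
        for psf, w in zip(psf_modes, weights):
            im_sim += fftconvolve(psf, im_src * w, mode="same")
        return im_sim

    def adjoint(im_src, im_sim):
        im_sim[:] = 0
        for psf, w in zip(psf_modes, weights):
            # Correlate with the PSF (convolve with it flipped), then re-apply the weight.
            im_sim += fftconvolve(im_src, psf[::-1, ::-1], mode="same") * w
        return im_sim

    return forward, adjoint

def dot_product_test(forward, adjoint, shape, seed=0):
    """Relative mismatch between <A x, y> and <x, A^T y> for random x, y."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)
    y = rng.standard_normal(shape)
    lhs = np.vdot(forward(x, np.zeros(shape)), y)
    rhs = np.vdot(x, adjoint(y, np.zeros(shape)))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs))

rng = np.random.default_rng(42)
for shape in [(33, 33), (32, 32)]:
    fwd, adj = make_sv_blur(rng.standard_normal((3, *shape)), rng.random((3, *shape)))
    print(shape, dot_product_test(fwd, adj, shape))
```

A relative mismatch near machine precision indicates a correct adjoint; a mismatch of order one means the adjoint does not match the forward operator.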

Question 2: In your implementation though, shouldn't you be then using the my_blur_operation in the solver Problem()? And if you do so, how do you pass the arguments that are needed by the my_forward_operation, i.e. source_image, simulated_blur, psf_mode, weights, n_psf? Do we need to define a new class altogether that inherits from LinOp and then call its _forward() method, passing in the arguments?

I've attached an end-to-end Python implementation below. Would you mind critiquing it to make it work?

import numpy as np
import proximal as px
from proximal.lin_ops.lin_op import LinOp
from proximal.lin_ops.black_box import LinOpFactory
from scipy.signal import fftconvolve

def forward(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim[:] += fftconvolve(psf_mode, im_src*weight, mode='same')
    return im_sim

def adjoint(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim += fftconvolve(im_src, np.flipud(np.fliplr(psf_mode)), mode='same')* weight 
    return im_sim

blur_op = LinOpFactory(im.shape, im.shape, forward, adjoint)

def cost_function(x, im, U, W, mu=0.01): 
    data_term = px.sum_squares(blur_op(x) - im) 
    grad_term = mu*px.norm1(px.grad(im_blur)) + (1-mu)*px.sum_squares(px.grad(im_blur))
    return data_term + grad_term + px.nonneg(im_blur)

im = np.random.randn(2748, 3840)
W = np.random.randn(20, np.prod(im.shape))
U = np.random.randn(np.prod(im.shape), 20)
x = px.Variable(im.shape)
prob = px.Problem(cost_function(x, im, U, W))
prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())
antonysigma commented 1 year ago

In your implementation though, shouldn't you be then using the my_blur_operation in the solver Problem()?

Yes, I had a typo. my_blur_operation should be used.

And if you do so, how do you pass the arguments that are needed by the my_forward_operation?

We don't pass arguments. ProxImaL does not distinguish between constants and parameters, but we do accept contributions that would enable a parameter1 = px.Parameter(shape) syntax.

To hardcode the constants, write a Python factory function:

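import numpy as np
from scipy.signal import fftconvolve
from proximal.lin_ops.black_box import LinOpFactory

def my_factory(G, W, im_shape, n_psf=15):
    # G: (n_pixels, n_modes) kernels; W: (n_modes, n_pixels) weights.
    assert n_psf <= G.shape[1] and n_psf <= W.shape[0]
    # Reshape once here; forward/adjoint capture these arrays as closures.
    precomputed_weights = W[:n_psf].reshape((n_psf, *im_shape))
    precomputed_psf_mode = G[:, :n_psf].T.reshape((n_psf, *im_shape))

    def forward(im_src, im_sim):
        im_sim[:] = 0
        for weight, psf_mode in zip(precomputed_weights, precomputed_psf_mode):
            im_sim += fftconvolve(psf_mode, im_src * weight, mode='same')

    def adjoint(im_src, im_sim):
        # Correlate with each PSF mode (convolve with it flipped), then weight.
        im_sim[:] = 0
        for weight, psf_mode in zip(precomputed_weights, precomputed_psf_mode):
            im_sim += fftconvolve(im_src, psf_mode[::-1, ::-1], mode='same') * weight

    return LinOpFactory(im_shape, im_shape, forward, adjoint)

blur_op = my_factory(G, W, im.shape, n_psf=15)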
antonysigma commented 1 year ago

I've attached an end-to-end Python implementation below. Would you mind critiquing it to make it work?

I am not sure what you mean... If the current code works on your computer, great! Otherwise, please submit a draft PR; I am more effective reviewing code in the PR panel.

shnaqvi commented 1 year ago

Thanks. I tried setting up the function as you described, with forward and adjoint as nested functions that capture G, W and n_psf from the outer function. However, I'm still getting an error: ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all(). Do you know what could be causing it?

Also, in this pattern, will LinOpFactory call forward and adjoint with the im_src and im_sim arguments?

import numpy as np
import proximal as px
from proximal.lin_ops.lin_op import LinOp
from proximal.lin_ops.black_box import LinOpFactory
from scipy.signal import fftconvolve

def my_factory(G, W, n_psf):

    def forward(im_src, im_sim):
        assert np.all(im_src.shape == im_sim.shape)
        im_sim[:] = 0
        for i in range(n_psf):
            weight = W[i,:].reshape(*im.shape)
            psf_mode = G[:,i].reshape(*im.shape)
            im_sim[:] += fftconvolve(psf_mode, im_src*weight, mode='same')
        return im_sim

    def adjoint(im_src, im_sim):
        assert np.all(im_src.shape == im_sim.shape)
        im_sim[:] = 0
        for i in range(n_psf):
            weight = W[i,:].reshape(*im.shape)
            psf_mode = G[:,i].reshape(*im.shape)
            im_sim += fftconvolve(im_src, np.flipud(np.fliplr(psf_mode)), mode='same')* weight 
        return im_sim

    return LinOpFactory(im.shape, im.shape, forward, adjoint)

def cost_function(x, im, U, W, mu=0.01): 
    blur_op = my_factory(U, W, 15)
    data_term = px.sum_squares(blur_op(x) - im) 
    grad_term = mu*px.norm1(px.grad(im_blur)) + (1-mu)*px.sum_squares(px.grad(im_blur))
    return data_term + grad_term + px.nonneg(im_blur)

# Sample Data
im = np.random.randn(2748, 3840)
W = np.random.randn(20, np.prod(im.shape))
U = np.random.randn(np.prod(im.shape), 20)

# Solve
x = px.Variable(im.shape)
prob = px.Problem(cost_function(x, im, U, W))
prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[154], line 62
     60 x = px.Variable(im.shape)
     61 prob = px.Problem(cost_function(x, im, U, W))
---> 62 prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/proximal/algorithms/problem.py:106, in Problem.solve(self, solver, test_adjoints, test_norm, show_graph, *args, **kwargs)
    104 # Merge prox fns.
    105 if self.merge:
--> 106     prox_fns = merge.merge_all(prox_fns)
    107 # Absorb offsets.
    108 prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns]

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/proximal/algorithms/merge.py:17, in merge_all(prox_fns)
     14 for i in range(len(prox_fns)):
     15     for j in range(i + 1, len(prox_fns)):
     16         if prox_fns[i] not in merged and prox_fns[j] not in merged and \
---> 17            can_merge(prox_fns[i], prox_fns[j]):
     18             no_merges = False
     19             merged += [prox_fns[i], prox_fns[j]]

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/proximal/algorithms/merge.py:32, in can_merge(lh_prox, rh_prox)
     29 """Can lh_prox and rh_prox be merged into a single function?
     30 """
...
---> 32 if lh_prox.lin_op == rh_prox.lin_op:
     33     if type(lh_prox) == zero_prox or type(rh_prox) == zero_prox:
     34         return True

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
antonysigma commented 1 year ago

I am not sure why the grad term doesn't involve the variable x: it uses im_blur, which is not defined inside cost_function.

Please save the script in the proximal/examples folder and then submit a draft pull request. I will study it.

shnaqvi commented 1 year ago

Thanks @antonysigma for continuing to support this. I've just created a pull request with the test_deconv_sv_psf.py script. Kindly review it.

antonysigma commented 10 months ago

Summary: px.sum(px.conv(kernel, x) + px.conv(...) + ...) takes significant time to parse/compile in ProxImaL before the solver can be generated. After that, the solver converges quite well (PR #84). The time spent on compilation is reported in the first message printed by Problem.solve(verbose=2).