Tagging @SteveDiamond, the owner of the ProxImaL project.
At first glance, the `range(15)` loop probably overwhelms the solver. Internally, it is busy copying the pixels of `x.value` 15 × 2 = 30 times, once ahead of each `mul_elemwise` and `conv`. The same goes for the adjoint operation, which copies the `im` pixels 30 times before the adjoint operations update `x`. In total, that is 60+ redundant data copies per solver iteration.
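A back-of-the-envelope estimate of what those copies cost. It assumes the 2748 × 3840 float64 image from the sample script later in this thread and counts only the 60 redundant copies, not the convolutions themselves:

```python
# Rough estimate of redundant copy traffic per solver iteration.
# Assumptions: a 2748 x 3840 image stored as float64 (np.random.randn's default dtype).
height, width = 2748, 3840
bytes_per_pixel = 8            # float64
n_psf = 15                     # terms in range(15)

copies_forward = n_psf * 2     # one copy ahead of each mul_elemwise and each conv
copies_adjoint = n_psf * 2     # the same again for the adjoint pass
copies_per_iter = copies_forward + copies_adjoint

bytes_per_copy = height * width * bytes_per_pixel
bytes_per_iter = copies_per_iter * bytes_per_copy

print(f"{copies_per_iter} copies x {bytes_per_copy / 1e6:.1f} MB "
      f"= {bytes_per_iter / 1e9:.2f} GB per iteration")
```

Under these assumptions it prints `60 copies x 84.4 MB = 5.07 GB per iteration`, before any actual arithmetic is done.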
As you are not utilizing Halide-accelerated compute right now, you may try hand-rolling your own "plane-wise" 2D convolution algorithm via the `BlackBox` API, like this:
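```python
import numpy as np
from scipy.signal import fftconvolve
from proximal.lin_ops.black_box import LinOpFactory

# Assumes `im`, `psf_mode` and `my_weight` (each n_psf x H x W) are already defined.
def my_2d_convolute(kernel, image):
    # Stand-in for your own plane-wise 2D convolution (zero-padded, same size as image).
    return fftconvolve(image, kernel, mode='same')

def my_forward_operation(source_image, simulated_blur, n_psf=15, weights=my_weight):
    assert source_image.shape == simulated_blur.shape
    simulated_blur[:] = 0
    for i in range(n_psf):
        weighted_image = source_image * weights[i]
        simulated_blur[:] += my_2d_convolute(psf_mode[i], weighted_image)

def my_adjoint_operation(actual_blur, back_propagated_image, n_psf=15, weights=my_weight):
    # Compute the matched filter of psf_mode[i] (flipped in both axes), then use it to
    # compute a plane-wise 2D convolution, followed by elementwise multiplication with weights.
    back_propagated_image[:] = 0
    for i in range(n_psf):
        matched_filter = psf_mode[i][::-1, ::-1]
        back_propagated_image[:] += weights[i] * my_2d_convolute(matched_filter, actual_blur)

my_blur_operation = LinOpFactory(
    im.shape,  # input shape
    im.shape,  # output shape (the 2x subsampling is applied separately below)
    my_forward_operation,
    my_adjoint_operation,
)
```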
Next, use the custom linear operation to define the problem:
```python
x = px.Variable(im.shape)
prob = px.Problem(
    sum_squares(subsample(
        my_forward_operation(x),
        2,
    ))
)
# ...
```
Off-topic question for @SteveDiamond: is ProxImaL capable of inferring the commutative properties between `mul_elemwise` and `conv`, assuming a cyclic image boundary condition? Should it?
I am asking because I am not sure whether @shnaqvi's problem can use ProxImaL's absorption of diagonal matrices into the L2 norm to speed up the solver, like this:
\begin{align*}
& \arg\min_{u} \left\Vert \sum_{i=1}^N \mathrm{conv}[h_i, \mathrm{Diag}(w_i)\, u] - b \right\Vert_2^2 + \alpha\Vert \nabla u \Vert_F^2 \\
\approx & \arg\min_{u} \sum_{i=1}^N \left\Vert \mathrm{Diag}(w_i)\, \mathrm{conv}[h_i, u] - \mathrm{Diag}(w_i)\, b \right\Vert_2^2 + \alpha\Vert \nabla u \Vert_F^2 \\
= & \arg\min_{u} \sum_{i=1}^N f_i(\mathbf{K}_i u) + \alpha\Vert \nabla u \Vert_F^2, \\
& \quad f_i(v) = \Vert \mathrm{Diag}(w_i)\,(v - b) \Vert_2^2, \quad \mathbf{K}_i u = \mathrm{conv}[h_i, u]
\end{align*}
(On second thought, maybe not... The ProxImaL user, not the ProxImaL compiler, should lead the problem formulation.)
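A quick numerical look at the commutativity question, using NumPy only (no ProxImaL) and assuming cyclic boundaries, i.e. circular convolution via the FFT. `circ_conv` and `commute_gap` are hypothetical helper names for this illustration. `Diag(w)` and `conv[h, ·]` commute when `w` is spatially constant, but not for a space-variant `w`, which is why the second line of the derivation is only an approximation:

```python
import numpy as np

def circ_conv(h, u):
    """2-D circular convolution (cyclic boundary) computed with the FFT."""
    return np.real(np.fft.ifft2(np.fft.fft2(h, s=u.shape) * np.fft.fft2(u)))

def commute_gap(h, w, u):
    """Relative gap between conv[h, Diag(w) u] and Diag(w) conv[h, u]."""
    a = circ_conv(h, w * u)
    b = w * circ_conv(h, u)
    return np.linalg.norm(a - b) / np.linalg.norm(a)

rng = np.random.default_rng(0)
u = rng.standard_normal((16, 16))
h = rng.standard_normal((5, 5))

gap_constant = commute_gap(h, np.full(u.shape, 0.7), u)          # spatially constant weight
gap_varying = commute_gap(h, rng.uniform(0.5, 1.5, u.shape), u)  # space-variant weight
print(f"constant w: {gap_constant:.1e}, space-variant w: {gap_varying:.1e}")
```

The first gap sits at rounding level and the second does not, so an automatic rewrite of this kind would have to treat the swap as an approximation rather than an identity.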
Good question! This property is not encoded in the "compiler", but it could be. Some similar transformations are encoded here: https://github.com/comp-imaging/ProxImaL/blob/master/proximal/algorithms/absorb.py
Thanks for the insight @antonysigma, and sorry, I'm really new to matrix-free optimization. From your response, I gather that I need to replace `mul_elemwise` and `conv` with my own `LinOp` (and I see that `LinOpFactory` does this under the hood with the `BlackBox` class). I've incorporated this below.
Question 1: Kindly comment on whether the adjoint operation below looks right.
```python
import numpy as np
from scipy.signal import fftconvolve

def forward(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim[:] += fftconvolve(psf_mode, im_src*weight, mode='same')
    return im_sim

def adjoint(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim += fftconvolve(im_src, np.flipud(np.fliplr(psf_mode)), mode='same')* weight
    return im_sim
```
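One way to answer Question 1 empirically is the standard dot-product (adjoint) test: for random `x` and `y`, `<A x, y>` should equal `<x, A^T y>` up to rounding. The sketch below uses small random arrays and NumPy/SciPy only; `make_ops` and `adjoint_mismatch` are hypothetical helpers that mirror the loop structure of `forward`/`adjoint` above. It also illustrates a caveat of `mode='same'`: convolving with the flipped kernel is the exact adjoint only when the kernel dimensions are odd. With even-sized kernels (as when each PSF mode is reshaped to the even 2748 × 3840 image shape), the result is shifted by one pixel. ProxImaL's `Problem.solve` also takes a `test_adjoints` argument (visible in the traceback later in this thread), which may run a similar check inside the solver.

```python
import numpy as np
from scipy.signal import fftconvolve

def make_ops(psfs, weights):
    """Forward/adjoint pair with the same structure as forward()/adjoint() above."""
    def forward(x):
        # sum_i conv(psf_i, w_i * x)
        return sum(fftconvolve(x * w, h, mode='same') for h, w in zip(psfs, weights))

    def adjoint(y):
        # sum_i w_i * conv(flip(psf_i), y): matched filter first, then the weights
        return sum(w * fftconvolve(y, h[::-1, ::-1], mode='same') for h, w in zip(psfs, weights))

    return forward, adjoint

def adjoint_mismatch(forward, adjoint, shape, seed=0):
    """Relative gap between <A x, y> and <x, A^T y> for random x, y."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(shape)
    y = rng.standard_normal(shape)
    lhs = np.vdot(forward(x), y)
    rhs = np.vdot(x, adjoint(y))
    return abs(lhs - rhs) / max(abs(lhs), abs(rhs))

rng = np.random.default_rng(1)
shape, n_psf = (32, 40), 3
weights = rng.uniform(0.5, 1.5, (n_psf, *shape))
odd_psfs = rng.standard_normal((n_psf, 5, 5))    # odd-sized kernels
even_psfs = rng.standard_normal((n_psf, 4, 4))   # even-sized kernels

odd_gap = adjoint_mismatch(*make_ops(odd_psfs, weights), shape)
even_gap = adjoint_mismatch(*make_ops(even_psfs, weights), shape)
print(f"odd kernels: {odd_gap:.2e}, even kernels: {even_gap:.2e}")
```

The odd-kernel gap should be at rounding level, while the even-kernel gap is typically far from zero. Cropping or padding each PSF mode to odd dimensions is one way to make the flipped-kernel adjoint exact.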
Question 2: In your implementation, though, shouldn't you then be using `my_blur_operation` in the solver's `Problem()`? And if so, how do you pass the arguments needed by `my_forward_operation`, i.e. `source_image`, `simulated_blur`, `psf_mode`, `weights` and `n_psf`? Do we need to define a new class altogether that inherits from `LinOp` and then call its `_forward()` method, passing in the arguments?
I've attached an end-to-end Python implementation below. Would you mind critiquing it to make it work?
```python
import numpy as np
import proximal as px
from proximal.lin_ops.lin_op import LinOp
from proximal.lin_ops.black_box import LinOpFactory
from scipy.signal import fftconvolve

def forward(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim[:] += fftconvolve(psf_mode, im_src*weight, mode='same')
    return im_sim

def adjoint(im_src, im_sim, G, W, n_psf=15):
    assert np.all(im_src.shape == im_sim.shape)
    im_sim[:] = 0
    for i in range(n_psf):
        weight = W[i,:].reshape(*im.shape)
        psf_mode = G[:,i].reshape(*im.shape)
        im_sim += fftconvolve(im_src, np.flipud(np.fliplr(psf_mode)), mode='same')* weight
    return im_sim

blur_op = LinOpFactory(im.shape, im.shape, forward, adjoint)

def cost_function(x, im, U, W, mu=0.01):
    data_term = px.sum_squares(blur_op(x) - im)
    grad_term = mu*px.norm1(px.grad(im_blur)) + (1-mu)*px.sum_squares(px.grad(im_blur))
    return data_term + grad_term + px.nonneg(im_blur)

im = np.random.randn(2748, 3840)
W = np.random.randn(20, np.prod(im.shape))
U = np.random.randn(np.prod(im.shape), 20)
x = px.Variable(im.shape)
prob = px.Problem(cost_function(x, im, U, W))
prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())
```
> In your implementation though, shouldn't you be then using the my_blur_operation in the solver Problem()?

Yes, I had a typo: `my_blur_operation` should be used.
> And if you do so, how do you pass the arguments that are needed by the my_forward_operation?

We don't pass arguments. ProxImaL does not distinguish between constants and parameters. We do, however, accept contributions to enable a `parameter1 = px.Parameter(shape)` syntax.
To hardcode the constants, write a Python factory function:
```python
def my_factory(G, W, n_psf=15):
    assert W.shape[-1] == n_psf
    precomputed_weights = W.T.reshape((n_psf, *im.shape))
    precomputed_psf_mode = G.T.reshape((n_psf, *im.shape))

    def forward(...):
        #...

    def adjoint(...):
        # ...

    return LinOpFactory(...)

blur_op = my_factory(G, W, n_psf=15)
```
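For reference, here is one way the elided bodies of the factory above could be filled in. This is a hypothetical sketch, not the answerer's code. It assumes `G` and `W` are both `(n_pixels, n_psf)` arrays (so `W.shape[-1] == n_psf`, as the assert above expects), takes the image shape as an explicit argument instead of reading the global `im`, and uses `scipy.signal.fftconvolve` as the plane-wise convolution. It returns the two closures so that it runs with NumPy/SciPy alone. They have the `(input, output)` signature used throughout this thread and would be handed to `LinOpFactory(shape, shape, forward, adjoint)`.

```python
import numpy as np
from scipy.signal import fftconvolve

def make_blur_closures(G, W, shape):
    """Hypothetical helper: build forward/adjoint closures over precomputed arrays.

    Assumes G and W are (n_pixels, n_psf); column i of G is PSF mode i and
    column i of W is its per-pixel weight map, both flattened from `shape`.
    """
    assert G.shape == W.shape
    n_psf = G.shape[-1]
    # Reshape once, outside the solver loop; the closures capture these arrays.
    psf_modes = G.T.reshape((n_psf, *shape))
    weights = W.T.reshape((n_psf, *shape))

    def forward(im_src, im_sim):
        # Writes sum_i conv(psf_i, w_i * im_src) into im_sim in place.
        im_sim[:] = 0
        for psf, w in zip(psf_modes, weights):
            im_sim += fftconvolve(im_src * w, psf, mode='same')

    def adjoint(im_src, im_sim):
        # Writes sum_i w_i * conv(flip(psf_i), im_src) into im_sim in place.
        # With mode='same' this is the exact adjoint only for odd-sized PSF modes.
        im_sim[:] = 0
        for psf, w in zip(psf_modes, weights):
            im_sim += w * fftconvolve(im_src, psf[::-1, ::-1], mode='same')

    return forward, adjoint
```

Usage, assuming `im`, `G` and `W` exist: `forward, adjoint = make_blur_closures(G, W, im.shape)`, then `blur_op = LinOpFactory(im.shape, im.shape, forward, adjoint)`. Because the reshaping happens once in the outer function, each solver call only performs the convolutions.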
> I've attached an end-to-end Python implementation below. Would you mind critiquing it to make it work?

I am not sure what you mean... If the current code is working on your computer, great! Otherwise, please submit a draft PR; I am more effective at reviewing code in the PR panel.
Thanks. I tried setting up the function as you described, with `forward` and `adjoint` as nested functions that see `G`, `W` and `n_psf` defined in the outer function. However, I'm still getting an error: `ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()`. Do you know what could be causing it?
Do you also know whether, in this paradigm, `LinOpFactory` will call `forward` and `adjoint` with the `im_src` and `im_sim` arguments?
```python
import numpy as np
import proximal as px
from proximal.lin_ops.lin_op import LinOp
from proximal.lin_ops.black_box import LinOpFactory
from scipy.signal import fftconvolve

def my_factory(G, W, n_psf):
    def forward(im_src, im_sim):
        assert np.all(im_src.shape == im_sim.shape)
        im_sim[:] = 0
        for i in range(n_psf):
            weight = W[i,:].reshape(*im.shape)
            psf_mode = G[:,i].reshape(*im.shape)
            im_sim[:] += fftconvolve(psf_mode, im_src*weight, mode='same')
        return im_sim

    def adjoint(im_src, im_sim):
        assert np.all(im_src.shape == im_sim.shape)
        im_sim[:] = 0
        for i in range(n_psf):
            weight = W[i,:].reshape(*im.shape)
            psf_mode = G[:,i].reshape(*im.shape)
            im_sim += fftconvolve(im_src, np.flipud(np.fliplr(psf_mode)), mode='same')* weight
        return im_sim

    return LinOpFactory(im.shape, im.shape, forward, adjoint)

def cost_function(x, im, U, W, mu=0.01):
    blur_op = my_factory(U, W, 15)
    data_term = px.sum_squares(blur_op(x) - im)
    grad_term = mu*px.norm1(px.grad(im_blur)) + (1-mu)*px.sum_squares(px.grad(im_blur))
    return data_term + grad_term + px.nonneg(im_blur)

# Sample Data
im = np.random.randn(2748, 3840)
W = np.random.randn(20, np.prod(im.shape))
U = np.random.randn(np.prod(im.shape), 20)

# Solve
x = px.Variable(im.shape)
prob = px.Problem(cost_function(x, im, U, W))
prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())
```
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[154], line 62
     60 x = px.Variable(im.shape)
     61 prob = px.Problem(cost_function(x, im, U, W))
---> 62 prob.solve(solver='pc', max_iters=2, verbose=True, x0=im.copy())

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/proximal/algorithms/problem.py:106, in Problem.solve(self, solver, test_adjoints, test_norm, show_graph, *args, **kwargs)
    104 # Merge prox fns.
    105 if self.merge:
--> 106     prox_fns = merge.merge_all(prox_fns)
    107 # Absorb offsets.
    108 prox_fns = [absorb.absorb_offset(fn) for fn in prox_fns]

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/proximal/algorithms/merge.py:17, in merge_all(prox_fns)
     14 for i in range(len(prox_fns)):
     15     for j in range(i + 1, len(prox_fns)):
     16         if prox_fns[i] not in merged and prox_fns[j] not in merged and \
---> 17                 can_merge(prox_fns[i], prox_fns[j]):
     18             no_merges = False
     19             merged += [prox_fns[i], prox_fns[j]]

File ~/.pyenv/versions/3.10.1/lib/python3.10/site-packages/proximal/algorithms/merge.py:32, in can_merge(lh_prox, rh_prox)
     29 """Can lh_prox and rh_prox be merged into a single function?
     30 """
...
---> 32 if lh_prox.lin_op == rh_prox.lin_op:
     33     if type(lh_prox) == zero_prox or type(rh_prox) == zero_prox:
     34         return True

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
```
I am not sure why the grad term doesn't contain the variable `x`.
Please save the script in the `proximal/examples` folder and then submit a draft pull request. I will study it.
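The traceback ends at `if lh_prox.lin_op == rh_prox.lin_op:`. A minimal NumPy-only sketch of how that message arises: `==` on a NumPy array compares elementwise and returns a boolean array, and `if` cannot reduce such an array to a single truth value. We assume, without having verified it against ProxImaL's source, that one of the prox functions ends up holding a raw NumPy array as its `lin_op`, for example because the NumPy array `im_blur` rather than the variable `x` is passed to `px.grad` and `px.nonneg`. The `im_blur` below is only a placeholder array for illustration.

```python
import numpy as np

# Placeholder standing in for a plain NumPy array (e.g. `im_blur`) that reaches
# a comparison where a ProxImaL expression was expected.
im_blur = np.zeros((4, 4))

try:
    # Elementwise comparison yields a 4x4 boolean array, so `if` raises ValueError.
    if im_blur == im_blur:
        pass
    message = "no error"
except ValueError as err:
    message = str(err)

print(message)
```

It prints the same `ValueError` message as the traceback. Building the regularizer and constraint from `x` itself (e.g. `px.grad(x)`, `px.nonneg(x)`) keeps every term attached to the optimization variable, which is presumably what the comment about the grad term is pointing at.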
Thanks @antonysigma for continuing to support this. I've just created a pull request with the `test_deconv_sv_psf.py` script. Kindly review it.
Summary: `px.sum( px.conv(kernel, x) + px.conv(...) + ...)` takes significant time to parse and compile in ProxImaL before the solver can be generated. After that, the solver converges quite well (PR #84). The actual compilation time is reported in the first message printed by `Problem.solve(verbose=2)`.
I'm trying to implement the space-variant deconvolution algorithm presented in Flicker & Rigaut (2005) and reviewed here.
I have PSFs sampled over the field, which are centered and stacked in a matrix. I compute the SVD of this matrix to get the weights and kernels that form a low-rank approximation of the spatially varying blur forward model. (The image `im`, and each weight map in `W` and kernel in `U`, are vectors in R^n with n ≈ 1e7, i.e. an image of ~10 million pixels.)
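The SVD step described above can be sketched as follows. This is a minimal NumPy illustration under our own assumptions, not Flicker & Rigaut's reference implementation: each sampled PSF is a centered `kh × kw` array, the stack is flattened into a matrix with one PSF per column, and `low_rank_psf_model` is a hypothetical helper name. The left singular vectors give the kernels (the `U` of this thread), and the scaled right singular vectors give each kernel's coefficient at every sampled field point.

```python
import numpy as np

def low_rank_psf_model(psf_stack, rank):
    """Low-rank (SVD) decomposition of PSFs sampled across the field.

    psf_stack: (n_samples, kh, kw) centered PSFs, one per field point.
    Returns kernels (rank, kh, kw) and coefficients (n_samples, rank) such that
    psf_stack[j] is approximately sum_i coeffs[j, i] * kernels[i].
    """
    n_samples, kh, kw = psf_stack.shape
    P = psf_stack.reshape(n_samples, kh * kw).T        # one PSF per column
    U, s, Vt = np.linalg.svd(P, full_matrices=False)   # singular values in descending order
    kernels = U[:, :rank].T.reshape(rank, kh, kw)      # dominant PSF modes
    coeffs = (s[:rank, None] * Vt[:rank]).T            # per-sample weight of each mode
    return kernels, coeffs
```

The sampled PSFs are approximated by `np.einsum('ji,ikl->jkl', coeffs, kernels)`, exactly so when `rank` is at least the rank of the stack. The per-pixel weight maps (the `W` of this thread) would then come from interpolating each column of `coeffs` from the sampled field points to every pixel, a step this sketch leaves out.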
I construct my problem like so, but it has not returned after hundreds of minutes of runtime. It does not print any debug information either, despite the verbose flag, so I'm not sure what is going on.
P.S. I'm running Python 3.10.1 on macOS Ventura on an M1 Max with 64 GB of RAM.
Can you please tell me: