lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Proximal operator for a custom prior #437

Closed: matinaz closed this issue 1 year ago

matinaz commented 1 year ago

We want to solve an inverse problem with a custom prior, so we created a new Functional with our prior as its evaluation function. However, the ADMM solver needs a proximal operator, and our prior is complicated, so it is not clear how to derive its prox analytically. Is there a way to evaluate it using SCICO, or is there any solver implemented in SCICO apart from ADMM that does not require a proximal operator?
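For reference, the operator in question is the standard proximal operator of the prior $g$, $$\operatorname{prox}_{\lambda g}(v) = \arg\min_x \left\{ g(x) + \frac{1}{2\lambda} \lVert x - v \rVert_2^2 \right\},$$ which ADMM must evaluate at every iteration.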

bwohlberg commented 1 year ago

There are a number of algorithms other than ADMM in scico.optimize, but they're all proximal algorithms that depend on your functional having an efficiently computable proximal operator. In principle you could use scico.solver.minimize (a wrapper for scipy.optimize.minimize) to compute the proximal operator, but that wouldn't be feasible for anything other than a really small-scale problem.
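As a rough illustration of that idea, a minimal sketch follows; the name approx_prox and the toy usage are illustrative placeholders, and only scico.solver.minimize itself is the library API:

```python
# Sketch: numerically approximating prox_{alpha g}(v) via scico.solver.minimize.
# Only feasible for very small-scale problems, as noted above.
import scico.numpy as snp
from scico import solver

def approx_prox(g, v, alpha):
    # prox_{alpha g}(v) = argmin_x  g(x) + (1 / (2 alpha)) ||x - v||_2^2
    obj = lambda x: g(x) + 0.5 / alpha * snp.sum((x - v) ** 2)
    return solver.minimize(obj, x0=v, method="L-BFGS-B").x

# Toy usage with a smooth stand-in functional (not the prior from this issue)
v = snp.ones((4, 4))
x = approx_prox(lambda u: snp.sum(u ** 2), v, 0.5)
```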

Without knowing more about your prior functional, the best I can suggest would be to try a minimizer like jaxopt.LBFGS from the jaxopt package applied to the entire inverse problem functional.
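In sketch form, assuming the data fidelity and prior can both be written as JAX-differentiable functions (the toy f, g, data, and shapes below are placeholders, not the functional from this issue):

```python
# Sketch: minimizing the full objective f(x) + lam * g(x) with jaxopt.LBFGS.
import jax.numpy as jnp
import jaxopt

y = jnp.ones((8, 8))  # toy data
lam = 0.1             # toy prior weight

def f(x):             # data fidelity
    return 0.5 * jnp.sum((x - y) ** 2)

def g(x):             # smooth toy prior
    return jnp.sum(jnp.log1p(x ** 2))

solver = jaxopt.LBFGS(fun=lambda x: f(x) + lam * g(x), maxiter=200)
x_opt, state = solver.run(jnp.zeros((8, 8)))
```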

bwohlberg commented 1 year ago

Closing as resolved. Feel free to re-open if you have further questions.

matinaz commented 1 year ago

I'm afraid I have to re-open the issue, since the proposed solution didn't work. We tried jaxopt.LBFGS, but it doesn't seem to work with operators. A second thought was to use scico.optimize to minimize the whole function h = f + λg, instead of evaluating only the prior's prox, but the problem is large-scale, so the run kept going until it was killed. Do you have any suggestion for solving the problem? To help, here are some details on the functional and the way we implement it: f = loss.SquaredL2Loss(y=Y, A=X), where Y is a jax array and X is a linear operator implemented with linop; g is our prior, implemented as follows:

import cv2
import numpy as np
import jax
import jax.numpy as jnp
import sys
import scico
import scico.numpy as snp

block_size = 9
target_value = 5.0
n = 7
lam = 1000


def mse_similarity(block1, block2):
    """Calculate the Mean Squared Error (MSE) similarity between two blocks."""
    diff = block1.astype("float") - block2.astype("float")
    squared_diff = diff ** 2
    mse = jnp.mean(squared_diff)
    return mse


def calculate_block_similarity(image, x1, y1, x2, y2, block_size):
    """Calculate the similarity norm between two blocks in an image."""
    # Extract the two blocks from the image
    block1 = image[y1:y1 + block_size, x1:x1 + block_size]
    block2 = image[y2:y2 + block_size, x2:x2 + block_size]
    # Calculate the similarity using the MSE metric
    similarity = mse_similarity(block1, block2)
    return similarity


def calculate_sumsimilarity(image, xi, yi):
    """Calculate the sum of similarity norms in an image."""
    block_size = 9
    Ax = 0
    x1 = xi
    y1 = yi
    for i in range(1, 6):
        # Diagonal neighbor
        x2 = x1 + i * block_size
        y2 = y1 + i * block_size
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
        # Vertical neighbor
        x2 = x1
        y2 = y1 + i * block_size
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
        # Horizontal neighbor
        x2 = x1 + i * block_size
        y2 = y1
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
        # Anti-diagonal neighbor
        x2 = x1 - i * block_size
        y2 = y1 - i * block_size
        similarity_norm = calculate_block_similarity(image, x1, y1, x2, y2, block_size)
        Ax = Ax + (1 + lam * similarity_norm) ** (-(n + 1) / 2) / 20
    return Ax


def solve_for_x(image):
    """Find the pixel value 'x' that satisfies h(x) = log(B)."""
    height, width = image.shape
    B = 0.0
    for y in range(60, height - 60, block_size):
        for x in range(60, width - 60, block_size):
            sumsimilarity = calculate_sumsimilarity(image, x, y)
            B += jnp.log(sumsimilarity)
    res = B / (height * width)
    return abs(res)

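For context, wrapping this prior as a SCICO functional presumably looks something like the following sketch (the class name is a placeholder; the has_eval/has_prox pattern follows the scico.functional.Functional docs):

```python
# Sketch: wrapping solve_for_x as a custom SCICO functional with no prox.
from scico import functional

class BlockSimilarityPrior(functional.Functional):
    has_eval = True    # the functional can be evaluated on an array ...
    has_prox = False   # ... but no proximal operator is defined

    def _eval(self, x):
        return solve_for_x(x)

g = BlockSimilarityPrior()
```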
bwohlberg commented 1 year ago

The g functional is solve_for_x?

matinaz commented 1 year ago

Yes, it is.

bwohlberg commented 1 year ago

I'm afraid I don't see any simple solutions. Most proximal algorithms will use a proximal operator for g, but it's going to be very slow if the prox is not efficiently implemented. It may be worth a careful look at the function to see if there is any way of designing a relatively efficient prox, but my guess is that it will be difficult or impossible.

matinaz commented 1 year ago

This is our opinion too. Thank you so much for trying; we will look into it a little bit more, and if we figure out an efficient solution we'll let you know.

bwohlberg commented 1 year ago

One option that may not be possible, but that is perhaps worth considering: can you express your functional in the form $$g(x) = h(A(x))$$ where $h$ is a simple prox-friendly functional, and $A$ is a non-linear operator implemented in JAX so that it can be auto-differentiated? If so, you should be able to use the scico Non-linear Proximal ADMM solver to solve the problem.
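In sketch form, with toy stand-ins for $h$ and $A$ (an L1 norm and an elementwise tanh), and with untuned placeholder values for the rho/mu/nu algorithm parameters; the constructor arguments follow the NonLinearPADMM documentation:

```python
# Sketch: splitting g(x) = h(A(x)) via the constraint H(x, z) = A(x) - z = 0,
# then solving min f(x) + h(z) s.t. H(x, z) = 0 with NonLinearPADMM.
import jax.numpy as jnp
import scico.numpy as snp
from scico import functional, linop, loss
from scico.function import Function
from scico.optimize import NonLinearPADMM

shape = (16, 16)
y = snp.zeros(shape)                                  # toy data
f = loss.SquaredL2Loss(y=y, A=linop.Identity(shape))  # data fidelity
h = functional.L1Norm()                               # prox-friendly h
A = lambda x: jnp.tanh(x)                             # toy differentiable non-linear A

# Constraint function coupling x and the splitting variable z
H = Function((shape, shape), output_shape=shape, eval_fn=lambda x, z: A(x) - z)

solver = NonLinearPADMM(
    f=f, g=h, H=H,
    rho=1.0, mu=2.0, nu=1.0,                          # placeholder parameters
    x0=snp.zeros(shape), z0=snp.zeros(shape), u0=snp.zeros(shape),
    maxiter=50,
)
x_hat = solver.solve()
```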

matinaz commented 1 year ago

Well, we managed to express g(x) = h(A(x)), but we evaluate h(A(x)) in a single step, since we don't have A(x) separately, and we get an image as a result, not a proximal operator. So, is there any possibility this might be used in ADMM? I can provide the code, of course, if that would help.

bwohlberg commented 1 year ago

I'm afraid I can't think of a way of making this work. The approach I suggested depends on being able to separately compute $h(\cdot)$ and $A(\cdot)$ so that the Non-linear Proximal ADMM solver can be configured so that it uses the prox of $h$ and the gradient of $A$.

matinaz commented 1 year ago

This is what I expected... thank you very much. I'll try to find a way to make it work.