lanl / scico

Scientific Computational Imaging COde
BSD 3-Clause "New" or "Revised" License

Problem with custom denoiser in ADMM #509

Closed: matinaz closed this issue 2 months ago

matinaz commented 2 months ago

We modified the super-resolution example to work with a custom denoiser, but something must be wrong. The result gets worse with each iteration, and when we tried a different denoiser inside MyFunctional, the results stayed the same. Could something be wrong with the implementation of MyFunctional that prevents ADMM from taking the denoiser's output into account?

Below is the main code:

import jax
import jax.numpy as jnp
from jax import jit, grad
from jax.scipy.ndimage import map_coordinates
from jax.image import resize
from jax import device_get
import math
from scipy import misc
import jax.scipy as jsp
import cv2
import skimage
from skimage.metrics import structural_similarity as ssim
import scico
import scico.numpy as snp
import scico.random
import scico.examples
import os
import numpy as np
from numpy import array
import matplotlib.pyplot as plt
from scico.metric import mse
from scico import denoiser, functional, linop, loss, metric, plot
from scico.data import kodim23
from scico.optimize.admm import ADMM, LinearSubproblemSolver
from scico.solver import cg, minimize
from scico.util import device_info
from scico.functional import Functional
from scipy.optimize import minimize
import chandenoised2
from scico.optimize import LinearizedADMM
from PIL import Image
import matplotlib.pyplot as plt
import imageio
from skimage.util import random_noise

plot.config_notebook_plotting()

rate = 4  # downsampling rate
σ = 2e-2  # noise standard deviation

σ=0

lambda_val=1.0

m = 0; variance = 0.01;

class MyFunctional(scico.functional.Functional):

    has_eval = True
    has_prox = True

    #def _eval( z: jnp.array) -> jnp.array:
    #     return chandenoised.main(jnp.array(z))

    #def __init__(self):
        #super().__init__()

    def __call__(self, z):
        # Placeholder value: this g has no explicit expression
        return 0.0

    def prox(self, z: jax.Array, lambda_val: float = 1.0, **kwargs) -> jax.Array:
        # **kwargs accepts any extra keyword arguments the solver passes to prox;
        # lambda_val is not used, since the denoiser strength is fixed inside chandenoised2
        #z_np = np.array(device_get(z))  # Ensure z is a NumPy array
        #result_np = chandenoised.main(z_np)  # Denoise using chandenoised
        #result = jnp.array(result_np)
        #return jnp.array(result)
        result=chandenoised2.main(z)
        return jnp.array(result)

def warp_flow(flow, img2):

    # Calculate the coordinates for motion compensation
    x, y = np.meshgrid(np.arange(img2.shape[1]), np.arange(img2.shape[0]))
    # OpenCV flow layout: flow[..., 0] is the x (column) shift, flow[..., 1] the y (row) shift
    x=np.clip(x -0.5*flow[..., 0],0,img2.shape[1]-1)
    y=np.clip(y -0.5*flow[..., 1],0,img2.shape[0]-1)
    coords = [y, x]

    # Map the coordinates to the input frame2
    compensated_frame2 = map_coordinates(img2, coords, order=1, mode='nearest')

    compensated_frame2=snp.array(compensated_frame2)
    compensated_frame2=snp.roll(compensated_frame2,719,1)
    return compensated_frame2

def motionestimation(frame1, frame2):

    # Compute the Fourier transforms of the two frames
    frame1=np.array(frame1)
    frame2=np.array(frame2)
    frame1 = cv2.cvtColor(frame1, cv2.COLOR_RGB2GRAY)
    frame2 = cv2.cvtColor(frame2, cv2.COLOR_RGB2GRAY)

    f1 = np.fft.fft2(frame1)
    f2 = np.fft.fft2(frame2)

    # Compute the cross power spectrum and the phase correlation matrix
    cps = np.multiply(np.conj(f1), f2)
    phase_corr = np.fft.ifft2(np.divide(cps, np.absolute(cps)))

    # Find the peak in the phase correlation matrix (excluding the DC component)
    height, width = phase_corr.shape
    mid_height = int(height / 2)
    mid_width = int(width / 2)
    phase_corr[mid_height, mid_width] = 0
    peak_loc = np.unravel_index(np.argmax(phase_corr), phase_corr.shape)

    # Compute the displacement vector
    x=peak_loc[1]
    #if x>mid_height:
    #   x=x-height
    y=peak_loc[0]
    #if y>mid_width:
    #   y=y-width
    return x,y

def downsample_image(img, rate):

    print (img.shape)

    img = img.reshape((img.shape[0]//rate, rate, img.shape[1]//rate, rate)).mean(axis=(1,3))
    #print (img2.shape)
    #img=img[::rate,::rate]
    return img

def roll1(x, flow, a):
    if (a==1):
        r=warp_flow(flow, x)
        d=downsample_image(r, rate)
    else:
        d=downsample_image(x, rate)
        d=snp.roll(d,719,1)
    return d

for ii in range (1,2):

    frames = []
    for i in range(ii,ii+3):
        frame = cv2.imread("/home/matina/Επιφάνεια εργασίας/scico/Vid4/calendar/0%i.png" %i)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = skimage.img_as_float(frame)
        frames.append(frame)

    frames = np.array(frames)
    frames = jax.device_put(frames)
    #print (frames.shape)

    img=cv2.imread("/home/matina/Επιφάνεια εργασίας/scico/Vid4/calendar/0%i.png" %(ii+1))
    img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    img= skimage.img_as_float(img)
    img = np.array(img)
    img_down = downsample_image(img, rate//2)
    img_down=np.clip(img_down, 0, 1)

    Nx, Ny=img.shape
    q = np.zeros((Nx, Ny))
    dimension = (Ny, Nx)
    img_inter=cv2.resize(img_down, dimension, interpolation = cv2.INTER_CUBIC)
    img_inter=np.clip(img_inter, 0, 1)
    min_val = np.min(img_inter)
    max_val = np.max(img_inter)
    print (min_val, max_val)
    imageio.imwrite('img_inter.png', img_inter)
    #ave=np.mean(img)
    #img=img-ave

    noise, key = scico.random.randn(img.shape, seed=0)
    #img2=random_noise(img, mode='gaussian', mean=m, var=variance, clip=True)
    img2=img+σ*noise
    img2= np.clip(img2, 0, 1)
    min_val = np.min(img2)
    max_val = np.max(img2)
    #print (min_val, max_val)
    imageio.imwrite('img2.png', img2)
    img = jax.device_put(img)

    print(img.shape)

    Afn = lambda x: downsample_image(x, rate=rate)
    s = Afn(img)
    input_shape = img.shape
    output_shape = s.shape
    noise, key = scico.random.randn(s.shape, seed=0)
    sn = s+noise*σ

    #frame1 = img
    frame2=np.array(frames[2])

    frame0=np.array(frames[0])

    frame1=np.array(frames[1])

    frame3=np.array(frames[4])  # NOTE: frames holds only 3 frames (indices 0-2) as loaded above

    flow1=cv2.calcOpticalFlowFarneback(frame2.mean(-1), frame0.mean(-1), None, 0.5, 3, 15, 3, 5, 1.2, 0)
    flow2=cv2.calcOpticalFlowFarneback(frame2.mean(-1), frame1.mean(-1), None, 0.5, 3, 15, 3, 5, 1.2, 0)
    flow3=cv2.calcOpticalFlowFarneback(frame2.mean(-1), frame2.mean(-1), None, 0.5, 3, 15, 3, 5, 1.2, 0)

    flow4=cv2.calcOpticalFlowFarneback(frame2.mean(-1), frame3.mean(-1), None, 0.5, 3, 15, 3, 5, 1.2, 0)

    Afn1= lambda x: roll1(x,flow1, a=1)
    s1=downsample_image(frame0, rate=rate)

    Afn2= lambda x: roll1(x,flow2, a=1)
    s2=downsample_image(frame1, rate=rate)

    Afn3= lambda x: roll1(x,flow3, a=1)
    s3=downsample_image(frame2, rate=rate)

    Afn4= lambda x: roll1(x,flow4, a=1)

    s4=downsample_image(frame3, rate=rate)

    noise, key = scico.random.randn(s1.shape, seed=0)
    sn1=s1+noise*σ
    noise, key = scico.random.randn(s2.shape, seed=0)
    sn2=s2+noise*σ
    noise, key = scico.random.randn(s3.shape, seed=0)
    sn3=s3+noise*σ

    noise, key = scico.random.randn(s4.shape, seed=0)

    sn4=s4+ σ * noise

    A = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn)
    A1 = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn1)
    A2 = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn2)
    A3 = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn3)

    A4 = linop.LinearOperator(input_shape=input_shape, output_shape=output_shape, eval_fn=Afn4)

    X= scico.linop.VerticalStack((A1, A2, A3), collapse=False)
    Y= snp.blockarray([s1, s2, s3])

    # Chantas

    #λ = 10 # L1 norm regularization parameter
    f = loss.SquaredL2Loss(y=Y, A=X)
    C = linop.Identity(input_shape=input_shape)

    g=MyFunctional()
    #g=functional.DnCNN("17M")

    g=h  # NOTE: h is not defined in the posted code; this line would replace g = MyFunctional()

    xpinv, info = cg(A.T @ A, A.T @ s, snp.zeros(input_shape))
    dncnn = denoiser.DnCNN("17M")
    denoised=chandenoised2.main(img_inter)
    denoised=np.clip(denoised, 0, 1)
    imageio.imwrite('den_image.png', denoised)
    xden = img_inter
    #xden=img
    #xden=frames[1]
    xden=np.clip(xden, 0, 1)
    imageio.imwrite('my_image.png', xden)
    min_val = np.min(xden)
    max_val = np.max(xden)
    print (min_val, max_val)

    ρ = 3.4e-2  # ADMM penalty parameter
    maxiter = 12  # number of ADMM iterations
    solver = ADMM(
        f=f,
        g_list=[g],
        C_list=[C],
        rho_list=[ρ],
        x0=xden,
        maxiter=maxiter,
        subproblem_solver=LinearSubproblemSolver(cg_kwargs={"tol": 1e-3, "maxiter": 10}),
        itstat_options={"display": True},
    )

    print(f"\nChantas Solving on {device_info()}\n")

    xppp = solver.solve()
    min_val = np.min(xppp)
    max_val = np.max(xppp)
    #xppp=xppp+ave

    xppp=np.array(xppp)
    #xppp=(xppp - min_val) / (max_val - min_val)
    xppp=np.clip(xppp, 0, 1)
    print (min_val, max_val)

    file_name = 'Chanoutputcalendar{:03d}.png'.format(ii+1)
    imageio.imwrite(file_name, xppp)

    fig, ax = plot.subplots(nrows=1, ncols=2, figsize=(15, 5))
    plot.imview(img, title="Ground truth", fig=fig, ax=ax[0])
    plot.imview(xppp, title="Deconvolved image: %.2f (dB)" % metric.psnr(img, xppp), fig=fig, ax=ax[1])
    file_name = 'Chanoutputcalendar{:03d}.svg'.format(ii+1)
    fig.savefig(file_name)

And the code of the denoiser:

import cv2
import numpy as np
from scipy.special import gamma
from skimage.metrics import structural_similarity as ssim
from PIL import Image
from statistics import mean
from numpy.fft import fft,fft2,ifft,ifft2,fftshift
from numpy import real,conj,sort
import numpy
from numpy.linalg import norm
import math
from scico import denoiser, functional, linop, loss, metric, plot
from skimage.transform import resize
from skimage.util import random_noise
import imageio
import skimage

def myimresize(x, Nx, Ny, filtDFTcoef, decFac, extras):
    Nx = Nx + 2 * extras
    Ny = Ny + 2 * extras

    # if len(x.shape) == 3:
    #    x=x[:, :, 0]
    x_ = np.pad(x, ((extras, extras), (extras, extras)), mode='symmetric')
    xfilt = np.real(np.fft.ifft2(fft2(x_) * filtDFTcoef))
    y = xfilt[extras:Nx - extras:decFac, extras:Ny - extras:decFac]
    return y

def myimresizeTranspose(x, Nx, Ny, filtDFTcoef, decFac, extras):
    Nx = Nx + 2 * extras
    Ny = Ny + 2 * extras
    x_ = np.zeros((Nx, Ny))
    x_[0:Nx:decFac, 0:Ny:decFac] = np.pad(x, [(extras // decFac, extras // decFac), (extras // decFac, extras // decFac)], mode='symmetric')
    y_ = np.real(np.fft.ifft2(fft2(x_) * np.conj(filtDFTcoef)))
    y = y_[extras:Nx - extras, extras:Ny - extras]
    return y

def conjGradients(f, x0, b, iw, rw, rw2, maxiter, rtol):
    x = x0
    r = b - Amat(x0, iw, rw, rw2)
    p = r
    rnormStart = np.sum(r ** 2)
    rnorm = rnormStart
    for iter in range(1, maxiter + 1):
        Ap = Amat(p, iw, rw, rw2)
        r_norm_sq = np.sum(r ** 2)
        p_Ap = np.sum(p * Ap)

        a = r_norm_sq / p_Ap

        if iter % 20 == 0:
            if a * np.linalg.norm(p) ** 2 / np.linalg.norm(x) ** 2 < rtol:
                return x

        x = x + a * p
        r = r - a * Ap

        rnormprev = rnorm
        rnorm = np.sum(r ** 2)
        beta = rnorm / rnormprev
        p = r + beta * p

    print('Exit due to reaching the maximum number of iterations')
    return x

def Amat(x, iw, rw, rw2):
    Nx = iw[4]
    Ny = iw[5]
    maxX = iw[6]
    maxY = iw[7]

    P1 = np.array(rw).shape
    P = P1[0]
    decFactor = iw[1]
    H1 = iw[2]
    extras = iw[3]
    coord = np.zeros((P, 2))

    for k in range(0, P, 2):
        coord[k, :] = rw[k]
        coord[k, 0] = np.mod(coord[k, 0] + Nx // 2, Nx) - Nx // 2
        coord[k, 1] = np.mod(coord[k, 1] + Ny // 2, Ny) - Ny // 2

    x = x.reshape(Nx, Ny)
    y_ = myimresize(x, Nx, Ny, H1, decFactor, extras)
    y = myimresizeTranspose(y_, Nx, Ny, H1, decFactor, extras)
    maxmaxX = np.max(maxX).astype(int)
    maxmaxY = np.max(maxY).astype(int)
    x_ = np.pad(x, ((maxmaxX,maxmaxX), (maxmaxY,maxmaxY )), mode='wrap')
    nx_ = np.arange(1, Nx + 1) + maxmaxX
    nx_=nx_[0]
    ny_ = np.arange(1, Ny + 1) + maxmaxY
    ny_=ny_[0]

    for k in range(0, P, 2):
        nx1 = np.arange(1, Nx + 1) + maxX[(k + 1) // 2]
        nx1=nx1[0].astype(int)
        ny1 = np.arange(1, Ny + 1) + maxY[(k + 1) // 2]
        ny1=ny1[0].astype(int)
        if k + 1 < len(rw2):  # Check if k + 1 is within bounds of rw2
            temp1 = x_[nx_, ny_]
            temp2 = x_[nx_ - coord[k, 0].astype(int), ny_ - coord[k, 1].astype(int)]
            temp = rw2[k + 1] * (temp1 - temp2)
            temp = np.pad(temp, ((maxX[(k + 1) // 2].astype(int), maxY[(k + 1) // 2].astype(int))), mode='wrap')
            y = y + (temp[nx1, ny1] - temp[nx1 + coord[k, 0].astype(int), ny1 + coord[k, 1].astype(int)])

    return y.flatten()

def stat_rest(alpha, ssigma, g, Q, H1, Nx, Ny):
    NN = Nx * Ny
    G = fft2(g)
    Hf1 = np.abs(H1) ** 2
    Qf = np.abs(Q) ** 2

    p = 100

    for k in range(p):
        Mfg = np.conj(H1) * G / (Hf1 + ssigma * alpha * Qf)
        Cfg = ssigma / (Hf1 + ssigma * alpha * Qf)
        s3 = np.sum(np.sum((Cfg + (np.abs(Mfg) ** 2) / NN) * Qf))
        alpha = (NN - 1) / s3
        s1 = np.sum(np.sum((np.abs(H1) ** 2) * (Cfg + (np.abs(Mfg) ** 2) / NN)))
        s2 = np.sum(np.sum((np.abs(G) ** 2 - 2 * np.real(np.conj(G) * H1 * Mfg))) / NN)
        ssigma = (s1 + s2) / NN

    fhat = np.real(np.fft.ifft2(Mfg))

    return alpha, ssigma, fhat

def main(img2):
    imNumber = 1
    decFactor = 1

    pathHR = '/home/matina/Επιφάνεια εργασίας/scico/Vid4/city/030.png'

    image = cv2.imread("/home/matina/Επιφάνεια εργασίας/scico/Vid4/calendar/02.png")
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    image= skimage.img_as_float(image)
    f=np.array(image)
    min_val = np.min(f)
    max_val = np.max(f)
    range_val = max_val - min_val
    #image = ((image - min_val) / range_val) * 255  # Steps 2 and 3: Normalize and scale
    #image = image.astype(np.uint8)
    #average_intensity = np.mean(image)
    #image=image-average_intensity
    Nx, Ny=f.shape
    extras=3*4;#expand by $extras pixels the boundary of the image by replicating the bounds
    #f=f_[:Nx,:Ny];

    g=img2
    Nx, Ny = f.shape
    q = np.zeros((Nx, Ny))
    q[0, 0] = -4
    q[0, 1] = q[1, 0] = q[-1, 0] = q[0, -1] = 1
    Q = fft2(q)
    Qp = np.abs(fft2(Q))**2
    ssigma = (10 / 255) ** 2
    alpha = 1000
    alpha, ssigma, x = stat_rest(alpha, ssigma, g, Q, np.ones((Nx, Ny)), Nx, Ny)

    g2 = g.copy()
    #alpha

    Nx = Nx + 2 * extras
    Ny = Ny + 2 * extras

    hh1 = np.zeros((Nx, Ny))
    Hcubic = np.ones_like(hh1)
    Nx = Nx - 2 * extras
    Ny = Ny - 2 * extras

    norm(myimresize(f,Nx,Ny,Hcubic,decFactor,extras)-resize(f, (Nx, Ny)),'fro')**2/(Nx*Ny)

    expNum = 400
    coord = np.zeros((expNum, 2))

    indx2 = 0
    indx = 1
    optcoord = np.zeros((Nx, Ny))
    optcoord[Nx // 2, Ny // 2] = 1

    for exper in range(1, expNum + 1):
        nu = 7.0001
        c2 = 1000*alpha
        A = np.zeros(expNum)
        pof = 1 / expNum
        Qp = np.zeros((Nx, Ny))
        E = np.zeros((Nx, Ny)) / expNum

    indx2 = 0
    indx = 1
    optcoord = np.zeros((Nx, Ny))
    optcoord[Nx // 2, Ny // 2] = 1

    Fmask = np.zeros((Nx, Ny))
    Fmask[Nx // 2:2:Nx // 2 + 91, Ny // 2 - 46:2:Ny // 2 + 46] = 1
    Fmask = Fmask + np.roll(Fmask, (1, 1))
    Fmask[Nx // 2:Nx // 2 + 9, Ny // 2 - 4:Ny // 2 + 4] = 1
    Fmask[Nx // 2, Ny // 2:Ny] = 0

    g3 = np.zeros((Nx * 2, Ny * 2))
    g3[:Nx, :Ny] = g2
    fcorr = np.fft.fftshift(np.real(np.fft.ifft2(np.abs(np.fft.fft2(g3 - np.mean(g3))) ** 2)))
    fcorr = fcorr[Nx // 2:Nx * 3 // 2, Ny // 2:Ny * 3 // 2]
    fcorrMax = fcorr[Nx // 2, Ny // 2]
    fcorr = fcorr * (Fmask)

    # Sort fcorr and get indices
    fcorr_sorted = np.sort(fcorr.ravel())[::-1]
    I = np.argsort(fcorr.ravel())[::-1]

    k1, l1 = np.unravel_index(I, (Nx, Ny))

    expNum = 1
    indx = 0

    for j1 in range(1, 120):
        indx = j1
        if k1[indx] - Nx // 2 - 1 % Nx != 0 or l1[indx] - Ny // 2 - 1 % Ny != 0:
            coord[expNum - 1, 0] = (k1[indx] - Nx // 2 - 1) % Nx
            coord[expNum - 1, 1] = (l1[indx] - Ny // 2 - 1) % Ny
            optcoord[k1[indx], l1[indx]] = -1
            expNum += 1
        #else:
            #print((k1[indx] - Nx // 2 - 1) % Nx, (l1[indx] - Ny // 2 - 1) % Ny)
        if fcorr[k1[indx], l1[indx]] < fcorrMax * 0.60:
            break

    expNum -= 1

    sigblurx = 0.25
    sigblury = 0.25
    hh = np.zeros((Nx, Ny))
    for i in range(Nx):
        for j in range(Ny):
            hh[i, j] = np.exp(-(abs(i - np.floor(Nx // 2) - 1) ** 3.0) * sigblurx) * \
                       np.exp(-(abs(j - np.floor(Ny // 2) - 1) ** 3.0) * sigblury)

    S=1
    hh = S * (hh) / np.sum(np.sum(hh))
    HQ = np.fft.fft2(np.fft.fftshift(hh))

    E = 0.0;
    rw = np.zeros((expNum, 2))
    Z = np.ones((Nx, Ny)) / expNum
    A = np.zeros((Nx, Ny)) / expNum
    Qp = np.zeros((Nx, Ny))
    rw2 = np.zeros((expNum, Nx, Ny))
    PW = np.ones((Nx, Ny)) / expNum
    for exper in range(1, expNum + 1):
        rw[exper-1][0] = coord[exper - 1, 0]
        rw[exper-1][1] = coord[exper - 1, 1]

    ssigma = ssigma

    maxX = [np.max(np.abs(np.mod((coord[exper - 1, 0] + Nx // 2), Nx) - Nx // 2)) + 1 for exper in range(1, expNum + 1)]
    maxY = [np.max(np.abs(np.mod((coord[exper - 1, 1] + Ny // 2), Ny) - Ny // 2)) + 1 for exper in range(1, expNum + 1)]
    maxmaxX = np.max(maxX)
    maxmaxY = np.max(maxY)

    for iter in range(1, 25):
        ZALL = np.zeros((Nx, Ny))
        for exper in range(1, expNum + 1):
            J = (S / 2 + nu / 2)
            E = real(ifft2(fft2((x-numpy.roll(x, rw[exper-1].astype(int),axis=[0,1]))**2+ Qp+numpy.roll(Qp,rw[exper-1].astype(int),axis=[0,1])) *conj(HQ)));
            A= (nu+S)/(c2*E+nu)
            Z = ((math.gamma((nu + S) / 2) / math.gamma(nu / 2)) *(c2 / nu) ** 0.5 *(1 + c2 * E / nu) ** (-(nu + S) / 2))
            rw2[exper - 1] = A * Z
            ZALL = ZALL + Z

        Qp = np.zeros((Nx, Ny))
        for exper in range(1, expNum + 1):
            B = rw2[exper - 1]
            B = B / ZALL
            B = c2 * np.real(np.fft.ifft2(np.fft.fft2(B) * HQ)) / S
            rw2[exper - 1] = ssigma * B
            Qp = Qp + B + np.roll(B, rw[exper - 1].astype(int), axis=(0, 1))

        Imask = np.ones((Nx, Ny))
        Imask[0::decFactor, 0::decFactor] = 0
        Imask = np.logical_not(Imask)
        hcubatrous = np.fft.fftshift(np.real(np.fft.ifft2(Hcubic)))[extras:Nx + extras, extras:Ny + extras]
        hcubatrous = hcubatrous * Imask
        Qp = Qp + Imask * np.sum(np.sum(hcubatrous * hcubatrous)) / ssigma
        Qp = 1.0 / Qp

        iw = [1, decFactor, Hcubic, extras, Nx, Ny, maxX, maxY]

        g1 = myimresizeTranspose(g, Nx, Ny, Hcubic, decFactor, extras)
        #g1=g
        x_prev = x
        b = g1
        x = conjGradients(Amat, x.reshape(-1), b.reshape(-1), iw, rw, rw2, 1000, 1e-10)
        x = x.reshape(Nx, Ny)

        x=np.array(x)
        min_val = np.min(x)
        max_val = np.max(x)
        print (min_val, max_val)

        if np.linalg.norm(x.flatten() - x_prev.flatten())**2 / np.linalg.norm(x.flatten())**2 < 1e-7:
            break
        #print("Iter:", iter)
        #imageio.imwrite('denoised.png', x)
    return (x)

if __name__ == "__main__":
    main()  # NOTE: main() takes an img2 argument, so this call fails as written
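In `main`, the first call is `stat_rest(alpha, ssigma, g, Q, np.ones((Nx, Ny)), Nx, Ny)`. Because `H1 = 1` there, each pass of its loop applies a Laplacian-regularized (Tikhonov/Wiener-type) filter in the Fourier domain, `G / (1 + ssigma * alpha * |Q|^2)`. The sketch below reproduces a single such pass with `alpha` and `ssigma` held fixed; the real function re-estimates both over 100 passes. The helper name `laplacian_filter`, the synthetic test image and the noise level are hypothetical and chosen only for illustration.

```python
import numpy as np

def laplacian_filter(g, alpha, ssigma):
    """One pass of the stat_rest estimate with H1 = 1:
    fhat = real(ifft2(G / (1 + ssigma * alpha * |Q|^2))),
    where Q is the DFT of the periodic 5-point Laplacian kernel."""
    Nx, Ny = g.shape
    q = np.zeros((Nx, Ny))
    q[0, 0] = -4
    q[0, 1] = q[1, 0] = q[-1, 0] = q[0, -1] = 1
    Qf = np.abs(np.fft.fft2(q)) ** 2  # |Q|^2 is zero at DC, so the image mean is preserved
    G = np.fft.fft2(g)
    return np.real(np.fft.ifft2(G / (1.0 + ssigma * alpha * Qf)))

# Hypothetical test image: smooth pattern plus white noise with std 10/255
n = 64
rng = np.random.default_rng(0)
xx, yy = np.meshgrid(np.arange(n), np.arange(n))
clean = 0.5 + 0.25 * np.sin(2 * np.pi * 2 * xx / n) * np.cos(2 * np.pi * 3 * yy / n)
noisy = clean + (10 / 255) * rng.standard_normal((n, n))

est = laplacian_filter(noisy, alpha=1000.0, ssigma=(10 / 255) ** 2)
mse_noisy = np.mean((noisy - clean) ** 2)
mse_est = np.mean((est - clean) ** 2)
print(mse_noisy, mse_est)
```

High spatial frequencies, where `|Q|^2` is large, are attenuated strongly. Smooth content passes almost unchanged.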

bwohlberg commented 2 months ago

The iterations aren't guaranteed to converge for arbitrary denoisers. Some suggestions for now:

  1. Try reducing the ρ value.
  2. Reduce the effect of the denoiser by replacing its output with a linear combination of that output and the input.
  3. Substitute one of the standard denoisers to confirm that the remainder of the implementation is working as expected.
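Suggestion 2 can be sketched as a wrapper around any denoiser. This is a minimal sketch: `relaxed_denoiser` and the blending weight `beta` are hypothetical names, and the mean-filling function is only a toy stand-in for `chandenoised2.main`.

```python
import jax.numpy as jnp

def relaxed_denoiser(denoise, beta):
    """Return v -> beta * denoise(v) + (1 - beta) * v.

    beta = 1 gives the original denoiser and beta = 0 the identity map,
    so intermediate values weaken the denoiser's effect.
    """
    def wrapped(v):
        return beta * denoise(v) + (1.0 - beta) * v
    return wrapped

# Toy stand-in denoiser: replace every pixel by the image mean
mean_fill = lambda v: jnp.full_like(v, v.mean())

v = jnp.arange(6.0).reshape(2, 3)  # mean is 2.5
out = relaxed_denoiser(mean_fill, 0.25)(v)
print(out)
```

Inside `MyFunctional.prox`, this would amount to returning something like `relaxed_denoiser(lambda u: jnp.asarray(chandenoised2.main(u)), beta)(z)`, with `beta` tuned by hand.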
matinaz commented 2 months ago

Reducing the ρ value didn't help, but I'll try the other two suggestions. So there is no problem with the MyFunctional implementation? Is it OK that the `__call__` method returns 0.0?

matinaz commented 2 months ago

I'm coming back to this because it seems that the prox method in MyFunctional receives the same input at every iteration, and does not take into account the denoised output from the previous iteration. Looking at the code, could you tell me whether this might be due to an incorrect implementation of MyFunctional or an incorrect call to ADMM?

bwohlberg commented 2 months ago

If the prox of g is seeing the same input on each iteration, then my first guess would be that ρ is far too small, so that the z variable update is almost the same on each iteration (it would be useful to take a look at the mathematical expression for the updates if this isn't clear).
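For reference, these are the standard scaled-form ADMM updates for minimizing $f(\mathbf{x}) + \sum_i g_i(C_i \mathbf{x})$. This is the problem the `ADMM` call above is set up for, with a single $g$ and $C = I$.

```latex
\begin{aligned}
\mathbf{x}^{(k+1)} &= \mathop{\mathrm{arg\,min}}_{\mathbf{x}} \; f(\mathbf{x})
  + \sum_i \frac{\rho_i}{2} \big\| \mathbf{z}_i^{(k)} - \mathbf{u}_i^{(k)} - C_i \mathbf{x} \big\|_2^2 \\
\mathbf{z}_i^{(k+1)} &= \mathop{\mathrm{arg\,min}}_{\mathbf{z}_i} \; g_i(\mathbf{z}_i)
  + \frac{\rho_i}{2} \big\| \mathbf{z}_i - C_i \mathbf{x}^{(k+1)} - \mathbf{u}_i^{(k)} \big\|_2^2
  = \mathrm{prox}_{g_i/\rho_i}\big( C_i \mathbf{x}^{(k+1)} + \mathbf{u}_i^{(k)} \big) \\
\mathbf{u}_i^{(k+1)} &= \mathbf{u}_i^{(k)} + C_i \mathbf{x}^{(k+1)} - \mathbf{z}_i^{(k+1)}
\end{aligned}
```

With $C = I$, the prox input at iteration $k$ is $\mathbf{x}^{(k+1)} + \mathbf{u}^{(k)}$. In the $\mathbf{x}$-update, $\rho$ weights the coupling term. When $\rho$ is very small relative to $f$, $\mathbf{x}^{(k+1)}$ therefore depends only weakly on the denoiser output $\mathbf{z}^{(k)}$.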

matinaz commented 2 months ago

This could indeed be a problem, but I think something else is going on, since I get the same results even when I run ADMM using DnCNN as the denoiser inside the prox of MyFunctional. Could something in the implementation of MyFunctional be preventing z from being updated?
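One way to separate a wiring problem from a parameter problem is to run the same scaled-form loop by hand and record every argument passed to the prox. In scico's `ADMM` that argument is `C x + u`. The sketch below assumes a toy problem, `f(x) = 0.5 * ||x - y||^2` with `C = I`, so the x-update has a closed form. It also uses a hypothetical 5-point averaging filter in place of the real denoiser. The names `pnp_admm` and `smooth` are illustrative only.

```python
import jax.numpy as jnp
import numpy as np

def smooth(v):
    # Toy stand-in denoiser: average each pixel with its 4 periodic neighbours
    return (v + jnp.roll(v, 1, 0) + jnp.roll(v, -1, 0)
            + jnp.roll(v, 1, 1) + jnp.roll(v, -1, 1)) / 5.0

def pnp_admm(y, denoise, rho, maxiter):
    """Scaled-form ADMM for f(x) = 0.5*||x - y||^2 with C = I, where the
    prox of g is replaced by a denoiser. Returns x and all prox inputs."""
    x = y
    z = y
    u = jnp.zeros_like(y)
    prox_inputs = []
    for _ in range(maxiter):
        # x-update: argmin_x 0.5||x - y||^2 + rho/2 ||x - (z - u)||^2, closed form here
        x = (y + rho * (z - u)) / (1.0 + rho)
        v = x + u              # argument handed to the prox of g
        prox_inputs.append(v)
        z = denoise(v)         # z-update: the denoiser plays the role of the prox
        u = u + x - z          # scaled dual update
    return x, prox_inputs

y = jnp.asarray(np.random.default_rng(0).standard_normal((32, 32)), dtype=jnp.float32)
x, inputs = pnp_admm(y, smooth, rho=0.5, maxiter=5)
changes = [float(jnp.linalg.norm(b - a)) for a, b in zip(inputs, inputs[1:])]
print(changes)  # all-zero changes would mean the prox input is stuck
```

The same kind of recording can be added inside `MyFunctional.prox`, for example by appending `z` to a list. Comparing consecutive inputs then shows directly whether they change from one ADMM iteration to the next.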

bwohlberg commented 2 months ago

That does suggest it's an implementation issue rather than parameters or some more abstract algorithmic issue.
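If the problem is in the implementation, the custom forward operators are also worth checking. With `SquaredL2Loss` and `LinearSubproblemSolver`, the x-update solves a linear system that uses each operator and its adjoint. No adjoint is supplied to `linop.LinearOperator` here, so it has to be derived from `eval_fn`, which is only meaningful if `eval_fn` is linear in the image.

A dot-product test, $\langle A x, y\rangle = \langle x, A^T y\rangle$, is a quick check. This sketch applies it to the block-average downsampling used in `downsample_image` and builds the adjoint with `jax.linear_transpose`. The sizes and random test vectors are arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np

def downsample(img, rate):
    # Block averaging, same reshape/mean trick as downsample_image
    h, w = img.shape
    return img.reshape(h // rate, rate, w // rate, rate).mean(axis=(1, 3))

rate = 4
rng = np.random.default_rng(0)
x = jnp.asarray(rng.standard_normal((16, 24)), dtype=jnp.float32)  # image-space test vector
y = jnp.asarray(rng.standard_normal((4, 6)), dtype=jnp.float32)    # data-space test vector

A = lambda v: downsample(v, rate)
AT = jax.linear_transpose(A, x)  # adjoint built by JAX; calling it returns a tuple

lhs = jnp.sum(A(x) * y)      # <A x, y>
rhs = jnp.sum(x * AT(y)[0])  # <x, A^T y>
print(float(lhs), float(rhs))  # should agree up to float32 rounding
```

The same test could be pointed at `Afn1`, `Afn2`, `Afn3` from the posted code. A large mismatch there would indicate that the warped operators are not behaving as linear maps.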

bwohlberg commented 2 months ago

Any progress on this? Do you still want to keep the issue open?

matinaz commented 2 months ago

Problem solved, so I'm closing the issue. Your help was definitely crucial. Thank you!