tknopp / SparsityOperators.jl


Maximum level transform for Wavelet operator in non 2^n matrix #12

Open aTrotier opened 2 years ago

aTrotier commented 2 years ago

Hi,

In the MRIReco example we have to deal with a 220x160 matrix, where the maximum transform levels are 2 and 5 respectively. To make the code work I made a PR in Wavelets.jl which performs only an L=2 transform.

My guess is that a higher number of transform levels gives better sparsity. Do you think we should zero-fill to a square matrix of the nearest 2^n size above?

Something like this (here I just zero-filled the matrix to be square, not 2^n):

function WaveletOp_custom(T::Type, shape, wt=wavelet(WT.db2))
    # zero-fill the rectangular image to a square of side maximum(shape)
    shape_z = (maximum(shape), maximum(shape))
    L = minimum(Wavelets.Util.maxtransformlevels.(shape))
    @info "Number of wavelet transform levels = $L"

    # embed the image into the zero-filled square
    function zeroS(im)
        im = reshape(im, shape)
        im_z = zeros(eltype(im), shape_z)
        im_z[1:shape[1], 1:shape[2]] .= im
        return im_z
    end

    # crop back to the original rectangular shape
    function cropS(im_z)
        return im_z[1:shape[1], 1:shape[2]]
    end

    # adjoint: inverse transform on the square grid, then crop
    function EH(W_im)
        return vec(cropS(idwt(reshape(W_im, shape_z), wt, L)))
    end

    # forward: zero-fill, then transform on the square grid
    function E(im)
        return vec(dwt(zeroS(reshape(im, shape)), wt, L))
    end

    return LinearOperator{T}(prod(shape_z), prod(shape), false, false
              , wrapProd(x -> E(x))
              , nothing
              , wrapProd(y -> EH(y)))
end
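
For context, a minimal usage sketch of the operator above (hypothetical call with a real-valued test image, assuming wrapProd and LinearOperator are available as in the package):

img = rand(220, 160)                        # real-valued test image of the size discussed above
W   = WaveletOp_custom(Float64, (220, 160))
w   = W * vec(img)                          # length 220*220: zero-fill to a 220x220 square, then dwt
rec = W' * w                                # crop(idwt(...)) recovers vec(img): idwt inverts dwt, cropping undoes the padding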
aTrotier commented 1 year ago

I have a working example that uses the maximum transform level along each dimension (I think this is what BART does). Should I implement it in this package?

using ImageUtils
using Wavelets
using Plots

ph = shepp_logan(128)
heatmap(ph)

W_ph = dwt(ph,wavelet(WT.db2))
heatmap(abs.(W_ph),clim=(0,1))

## perform the wavelet transform on a rectangular matrix with different maximum depth levels per dimension

ph2 = ph[:,1:80]
sph2 = size(ph2)
MaxL = Int[]
for (i,val) in enumerate(sph2)
    push!(MaxL,maxtransformlevels(val))
end
display(MaxL)

# plain dwt on the rectangular matrix uses only the minimum of the per-dimension maximum levels -> 4
W_ph = dwt(ph2,wavelet(WT.db2))
heatmap(abs.(W_ph),clim=(0,1))

# instead we can perform dwt along the second axis (length 80, max transform level = 4)
res = zeros(Float64,sph2)
for i in 1:sph2[1]
    res[i,:] = dwt(ph2[i,:],wavelet(WT.db2))
end

# and then along the first axis (length 128, max transform level = 7)
for j in 1:sph2[2]
    res[:,j] = dwt(res[:,j],wavelet(WT.db2))
end
heatmap(res)

## get back the original image by inverting in reverse order: first axis, then second axis
im = zeros(Float64,sph2)
for j in 1:sph2[2]
    im[:,j] = idwt(res[:,j],wavelet(WT.db2))
end

for i in 1:sph2[1]
    im[i,:] = idwt(im[i,:],wavelet(WT.db2))
end
heatmap(im)
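
A quick sanity check (not in the original snippet): the two per-dimension passes round-trip exactly up to floating-point error.

maximum(abs.(im .- ph2))   # should be on the order of machine precision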
tknopp commented 1 year ago

Interesting. Ideally this would be implemented in Wavelets.jl. But I would be pragmatic here and support first using your implementation in SparsityOperators.jl. It would be good, however, if this could be implemented using the in-place variant dwt!. Not sure if that works though.

aTrotier commented 1 year ago

I'll share the example in an issue for Wavelets.jl and see if I can put it directly into that package.


aTrotier commented 1 year ago

Something like this?

ph3 = copy(ph2)
tmp = similar(ph3)
# transform along the second axis, writing into the scratch buffer tmp
@views for i in 1:sph2[1]
    dwt!(tmp[i,:], ph3[i,:], wavelet(WT.db2))
end
heatmap(tmp)
# and then along the first axis, writing back into ph3
@views for j in 1:sph2[2]
    dwt!(ph3[:,j], tmp[:,j], wavelet(WT.db2))
end
heatmap(ph3)

I can't do dwt!(ph3[i,:], ph3[i,:], wavelet(WT.db2)); it returns an error from this line: https://github.com/JuliaDSP/Wavelets.jl/blob/3451b210e0e817281b8c0b1647ea8a51630a5979/src/mod/transforms_filter.jl#L32

aTrotier commented 1 year ago

The choice of implementation seems more complex than I thought. Maybe we should stick to a specific MRI implementation (see https://github.com/JuliaDSP/Wavelets.jl/issues/80) and bring the discussion back to MRIReco, because it will be linked to the specifics of FISTA / ADMM at some point.

tknopp commented 1 year ago

I don't know. Ideally we get the desired feature into Wavelets.jl, but that will certainly be a longer process. The fact that there are different possibilities just means that it needs to be configurable.

In the meantime you can prototype it within SparsityOperators.jl. I don't see a direct link to FISTA / ADMM and think that this is the correct place to implement it. It would be good to have some option to switch between implementations, i.e. by default we could keep the current behavior but change it when passing certain options, as in the sketch below.
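
A rough sketch of what such a switch could look like (purely hypothetical; the keyword levelStrategy, its symbols, and the helper names WaveletOp_minLevel / WaveletOp_perDimension are made up and do not exist in the package):

function WaveletOp(T::Type, shape, wt=wavelet(WT.db2); levelStrategy=:minimum)
    if levelStrategy == :minimum          # current behavior: L = minimum over all dimensions
        return WaveletOp_minLevel(T, shape, wt)
    elseif levelStrategy == :perDimension # BART-like: maximum level along each axis separately
        return WaveletOp_perDimension(T, shape, wt)
    elseif levelStrategy == :zeroPad      # zero-fill to a square / 2^n grid before transforming
        return WaveletOp_custom(T, shape, wt)
    else
        error("unknown levelStrategy: $levelStrategy")
    end
end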

aTrotier commented 1 year ago

We had a discussion with @JeffFessler about the possibility of zero-padding the data before the wavelet transform, and he thought it might be a good option for ADMM but not for proximal algorithms.

https://github.com/MagneticResonanceImaging/MRIReco.jl/issues/78#issuecomment-1152188154

I don't have any experience with zero-padding prior to a wavelet transform. The DWT is a unitary transform that simplifies proximal operators, whereas the product of the DWT with a zero-padding matrix is not unitary, so that seems likely to complicate fast proximal optimization methods, but for splitting methods like ADMM it would be no problem. Zero-padding of axial slices that have air all around seems benign. But for coronal or sagittal slices (or in the z dimension for 3D imaging) I would not recommend zero padding because of possible discontinuities at the top/bottom slices, unless a slab excitation is used that makes those end slices close to zero anyway.

We have a lot of possible cases and I really don't know their impact on CS reconstruction.

I can implement some of them here, without optimization, and try to benchmark the effect:

tknopp commented 1 year ago

Well, I totally agree with Jeff that one first has to give some thought to what it means to put a non-unitary transform into FISTA / ADMM. But having the option to play around with different wavelet transforms in SparsityOperators.jl would still be very nice. In particular it would be very helpful to have the variant that BART implements, because we could then compare directly.

JeffFessler commented 1 year ago

Let me add a comment to balance the theoretical and practical considerations here. For a unitary transform, the proximal operator for the 1-norm with that transform in it is simply to transform, soft threshold, and then un-transform, as is well known. In the era of deep learning it is common to replace that type of proximal-based denoising step with a deep network, at which point the lovely convergence theory disappears but the results can still be practically useful. Likewise, it might be quite practically useful to have the option to do pad,transform,threshold,un-transform,un-pad as a denoising operator. However that sequence of steps does not correspond to optimizing any cost function (to the best of my knowledge) so the results might depend on the initial image and the particular iterative algorithm. It would be fine with me to provide that recipe as an option for users, but I would recommend that it come with a warning (@warn?) so that the user is fully aware that they have departed from the assurances of convex optimization. My preference though would be to find a general solution in Wavelets.jl if possible...
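
To make the two recipes concrete, here is a minimal Julia sketch (illustration only; soft, prox_unitary, denoise_padded, and λ are ad-hoc names, and the input is assumed to be a plain array on which dwt/idwt apply directly):

using Wavelets

soft(w, λ) = sign(w) * max(abs(w) - λ, 0)

# exact prox of λ‖W·‖₁ when W is unitary: transform, soft-threshold, un-transform
prox_unitary(x, wt, λ) = idwt(soft.(dwt(x, wt), λ), wt)

# pad / transform / threshold / un-transform / un-pad: practically useful as a denoiser,
# but no longer the prox of any cost function, so convergence guarantees do not carry over
function denoise_padded(x, wt, λ, padded_size)
    xp = zeros(eltype(x), padded_size)
    xp[axes(x)...] .= x
    return idwt(soft.(dwt(xp, wt), λ), wt)[axes(x)...]
end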

tknopp commented 1 year ago

Yes, those are very good points. I also immediately thought of plug-and-play denoisers, and we are actually about to implement just that in the near future (see this paper: https://link.springer.com/chapter/10.1007/978-3-031-17247-2_11). Therefore, I am currently thinking about how to get them into RegularizedLeastSquares.

But still, to make this clear: this package SparsityOperators.jl aka LinearOperatorCollection.jl has no knowledge about convex optimization, and therefore it is perfectly fine to implement any non-unitary transformation here. This issue needs to be addressed when building up the optimization problem. In our case this happens in MRIReco.jl in this line https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/src/Reconstruction/RecoParameters.jl#L107 and the sparseTrafo lines directly above. Later in the reconstruction we then plug the sparsifying transform into the regularizer: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/src/Reconstruction/IterativeReconstruction.jl#L37. And that is exactly the line where we can ask the SparseOp "Are you unitary?" and then do something about it.
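
In code that could be as simple as a query like the following when the regularizer is assembled (a hypothetical sketch; an isUnitary trait does not exist yet):

if !isUnitary(sparseTrafo)   # hypothetical trait on the sparsifying operator
    @warn "Sparsifying transform is not unitary; the proximal step is only an approximation."
end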

My preference though would be to find a general solution in Wavelets.jl if possible...

Yes indeed. It would then need someone from us to dive deeper into that package. I know that this can be a major barrier, and therefore I did not want to slow down @aTrotier if it is easier for him to first start a prototype here. But this would just be a stop-gap solution; it definitely belongs in Wavelets.jl.