MagneticResonanceImaging / MRIReco.jl

Julia Package for MRI Reconstruction
https://magneticresonanceimaging.github.io/MRIReco.jl/latest/

Fista algorithm : accuracy problem #114

Open aTrotier opened 2 years ago

aTrotier commented 2 years ago

Compressed sensing is now really fast with FISTA: https://github.com/MagneticResonanceImaging/MRIReco.jl/discussions/102#discussioncomment-3655548

But there is still an issue with accuracy (https://github.com/MagneticResonanceImaging/MRIReco.jl/discussions/102#discussioncomment-3623535). We observed it on a phantom: https://atrotier.github.io/MRIRecoVsBART_Benchmark/test_bart.html

I also have an issue on real MP2RAGE data (https://atrotier.github.io/EDUC_JULIA_CS_MP2RAGE/): I am not able to reach the same image quality as BART.

If I use ADMM, it works well!

andrewwmao commented 2 years ago

Hi aTrotier, I looked at the notebooks that you linked, and it seems that the second notebook using the in vivo MP2RAGE data is using an old version of 'RegularizedLeastSquares'. I contributed a bugfix to the FISTA algorithm in comparison to BART that should have fixed the accuracy problem. Can you try updating this to v0.8.7?

For the first notebook, my feeling is that there may be an issue with the regularization strength/implementation of the regularizer. Lambdas won't be comparable between Julia/BART esp. when both packages are scaling lambda in different ways. Perhaps you can try testing both BART/MRIReco FISTA implementations setting lambda close to zero to see if they give the same results.

tknopp commented 2 years ago

I don't know if it's just the different scaling but @migrosser also reported that he got better results by just taking another regularization parameter. It would be great if someone could have a deeper look. But probably we should keep this on hold and first look that RegularizedLeastSquares is consistent.

If somebody wants to give this a go: MRIRecoBenchmarks now includes the example script: https://github.com/MagneticResonanceImaging/MRIRecoBenchmarks/tree/master/benchmark2 It generates timings and output images. Errors are currently not calculated, but that would be easy to add.
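An error metric for such a benchmark could look like the following minimal sketch. The function name `nrmse` and the choice to fit a global complex scale factor before comparing are assumptions made for this illustration (the fit compensates for the different data scaling of the two toolboxes); it is not code from MRIRecoBenchmarks.

```julia
using LinearAlgebra

# Normalized root-mean-square error between a reconstruction and a reference.
# A least-squares scale factor α is fitted first so that a purely global
# (complex) scaling difference between toolboxes does not count as an error.
function nrmse(img::AbstractArray, ref::AbstractArray)
    x, r = vec(img), vec(ref)
    α = dot(x, r) / dot(x, x)      # minimizes ‖α*x - r‖₂ over complex α
    return norm(α .* x .- r) / norm(r)
end
```

With this definition, an image that differs from the reference only by a constant factor has an error of (numerically) zero, so the metric responds to artifacts and noise rather than to scaling conventions.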

Regarding the scaling: In my opinion this is primarily a documentation issue. We need to document somewhere how to translate a regularization parameter scaled with BART to MRIReco.

JakobAsslaender commented 2 years ago

So if I understand it correctly, the return value of BART depends on whether it terminates based on the maximum number of iterations, in which case it returns after the gradient step, i.e. halfway through Eq. 4.1 in the FISTA paper:

https://github.com/mrirecon/bart/blob/5428c0ae9f6cdb1667b549323802682ce1171bd9/src/iter/italgos.c#L252

but if the residual is smaller than the termination threshold, they return after adding the momentum (ravine) step, i.e. after Eq. 4.3 in the FISTA paper:

https://github.com/mrirecon/bart/blob/5428c0ae9f6cdb1667b549323802682ce1171bd9/src/iter/italgos.c#L241

Not sure why they distinguish between these two cases, but maybe @JeffFessler has thoughts?

We currently return after applying the proximal operator, i.e. after Eq 4.1. I think it would be easy to change the behavior to termination after the gradient step by moving line

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/22d58db104e374232a4e7d99b8863cd5a3ac36af/src/FISTA.jl#L138

below

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/22d58db104e374232a4e7d99b8863cd5a3ac36af/src/FISTA.jl#L150

or after the momentum step by moving it below

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/22d58db104e374232a4e7d99b8863cd5a3ac36af/src/FISTA.jl#L146

but I am not sure that the done function supports a distinction based on why we are terminating.

Do you, @JeffFessler or @andrewwmao have comments on what is the "right" termination point?
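To make the three candidate return points concrete, here is a minimal, self-contained FISTA sketch for 1/2‖Ax - b‖² + λ‖x‖₁. It is neither the RegularizedLeastSquares.jl nor the BART code; the function name `fista_sketch` and the `return_point` keyword are hypothetical and exist only to illustrate the discussion. The equation numbers in the comments refer to the FISTA paper as cited above.

```julia
using LinearAlgebra

# Soft-thresholding: the proximal operator of τ‖x‖₁.
soft(x, τ) = sign.(x) .* max.(abs.(x) .- τ, 0)

# Minimal FISTA with fixed step size 1/L.
# `return_point` (hypothetical) selects which intermediate quantity is returned.
function fista_sketch(A, b, λ; L = opnorm(A)^2, iters = 50, return_point = :prox)
    x = zeros(eltype(b), size(A, 2))   # latest prox output
    y = copy(x)                        # extrapolated (momentum) point
    g = copy(x)                        # output of the plain gradient step
    t = 1.0
    for _ in 1:iters
        g = y - (A' * (A * y - b)) / L               # gradient step
        xnew = soft(g, λ / L)                        # prox step, completes Eq. 4.1
        tnew = (1 + sqrt(1 + 4t^2)) / 2              # Eq. 4.2
        y = xnew + ((t - 1) / tnew) * (xnew - x)     # momentum step, Eq. 4.3
        x, t = xnew, tnew
    end
    return return_point === :gradient ? g :
           return_point === :momentum ? y : x
end
```

Only the `:prox` output has actually been soft-thresholded. With `A` the identity and `L = 1`, for example, the `:gradient` output is just `b` itself, with all of its small coefficients still present.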

tknopp commented 2 years ago

Intuitively, I would have said after the prox step because only then the solution lies in the desired subspace (as Jeff has said). Making it dependent on maxiter seems to be suboptimal. We should not replicate that.

JakobAsslaender commented 2 years ago

Well, in this case the current implementation is correct :). @aTrotier: Have you played with lambda and maxiter to see if either of those resolves the issue? Note that there is a factor of 2 difference in lambda between ADMM and FISTA, as @andrewwmao pointed out to me. We should probably resolve this to make the algorithms more comparable.
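One common way such a factor of 2 arises is a different convention for the data-fidelity term. This is an assumption for illustration, not something verified against either solver's source. If one solver includes the conventional 1/2 and the other does not, the two problems are

```latex
\hat{x}_1 = \arg\min_x \; \tfrac{1}{2}\,\|Ax - b\|_2^2 + \lambda_1 \|x\|_1,
\qquad
\hat{x}_2 = \arg\min_x \; \|Ax - b\|_2^2 + \lambda_2 \|x\|_1 .
```

Multiplying the first cost function by 2 does not change its minimizer, so $\hat{x}_1 = \hat{x}_2$ exactly when $\lambda_2 = 2\lambda_1$.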

JeffFessler commented 2 years ago

I have trouble understanding the BART code, but if @JakobAsslaender's reading of it is correct then I think it is a "bug" in BART. For FISTA, the proper thing to return (whenever one stops) is the output of the prox update. (Whether this is "x" or "y" depends on the paper's notation BTW.)

I've double-checked that we do this properly in MIRT: https://github.com/JeffFessler/MIRT.jl/blob/26cfbc2a26b26814fa85739dbe01b5f1b8be5e21/src/algorithm/general/pogm_restart.jl#L262

BTW, I'd recommend POGM instead of FISTA for the problem at hand. See Fig. 3 of my survey paper on optimization for MRI: http://doi.org/10.1109/MSP.2019.2943645

There is code for reproducing that figure here: https://github.com/JeffFessler/mirt-demo/blob/main/isbi-19/01-recon.ipynb and a Documenter-type example of it here: https://juliaimagerecon.github.io/Examples/generated/mri/2-cs-wl-l1-2d/

The main issue is that currently that pogm code is buried in MIRT.jl which is too large, though I am working on paring it down into smaller packages. In the long run I should contribute POGM to RegLS it seems, or to some optimization package somewhere. It would benefit from some optimization like in-place ops that I haven't had time to do...

JakobAsslaender commented 2 years ago

@JeffFessler: I would loooooovvvvveeeee to throw POGM at our data! But so far I have been hesitant because of the different interface (compared to RLS.jl) and the lack of in-place operations etc. Do I understand correctly that the algorithm builds on FISTA? Maybe the easiest way to do those optimizations would be to copy the RLS.jl implementation of FISTA and turn it into POGM? We can, of course, also think about converting the existing FISTA implementation into a super-function similar to yours, but I am not sure about the best compromise of code duplication vs. speed and readability. Let me know if I can be of help; I am just not sure that I have enough knowledge about POGM to do the job...

aTrotier commented 2 years ago

@JakobAsslaender @JeffFessler After playing around, I think you are right. I am able to get something close to the BART implementation of FISTA, with a bit too much noise, which might be related to https://github.com/MagneticResonanceImaging/MRIReco.jl/discussions/102#discussioncomment-3623535 In my first tests, I wanted to reduce the noise by increasing the lambda value, which creates the thresholding effect that is supposed to happen (BART misled me in this case).

Maybe @uecker can give some advice about the BART implementation and why it does not return the image after soft-thresholding.

Something to mention: with BART, most of the time I don't have to play much with the parameters to make FISTA work (lambda is generally close to 0.01). I guess the pre-scaling operation helps.

tknopp commented 2 years ago

> I guess the pre-scaling operation helps.

The option params[:normalizeReg] = true is actually supposed to make things independent of the input data, but it is probably not enough. So there is certainly a TODO item left.

aTrotier commented 2 years ago

Ok I think BART is doing that the other way, they scale the input image (https://github.com/MagneticResonanceImaging/MRIReco.jl/issues/92) rather than the lambda value.
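For a pure L1 penalty the two approaches are interchangeable in principle: scaling the data by 1/s, solving with λ, and undoing the scaling gives the same result as solving on the unscaled data with s·λ. The toy check below uses the identity as forward operator, where the L1-regularized solution is plain soft-thresholding. The scale factor `s` (here the peak magnitude) is a hypothetical choice and not the value BART computes.

```julia
# Soft-thresholding: the closed-form solution of
# argmin_x 1/2‖x - b‖² + τ‖x‖₁ (identity forward operator).
soft(x, τ) = sign.(x) .* max.(abs.(x) .- τ, 0)

b = [4.0, -1.5, 0.3, -6.0]   # toy "data"
λ = 0.5
s = maximum(abs, b)          # hypothetical scale factor: peak magnitude

x_scaled_data   = s .* soft(b ./ s, λ)   # scale the data, threshold, undo the scaling
x_scaled_lambda = soft(b, s * λ)         # keep the data, scale λ instead

@assert x_scaled_data ≈ x_scaled_lambda
```

The practical difference is therefore not the solution itself but which quantity the user tunes: with data scaling, a fixed λ such as 0.01 carries over between datasets more easily.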

edit : results after playing with parameters https://atrotier.github.io/EDUC_JULIA_CS_MP2RAGE/

JeffFessler commented 1 year ago

> throw POGM

Thanks for the encouragement. I will first write a version of pogm that is streamlined (putting pogm and fista and pgm in one function is too messy) and uses in-place ops etc. It will be for a general composite cost function and I will illustrate how to use it using a regularized LS problem. Then we can decide if/how to make a wrapper in RLS.jl to call it just as easily as you call FISTA.

aTrotier commented 1 year ago

Just a remark: the accuracy issue can also be linked to the wavelet implementation. BART does a full decomposition along each axis, whereas Wavelets.jl determines the minimum level of decomposition along each axis and uses that.

tknopp commented 1 year ago

Then we probably want to have options to refine the wavelet transform. Don't know how we should approach this but one would first start making the WaveletOp more general and then introduce some new high-level parameters.

aTrotier commented 1 year ago

Actually, it cannot be the issue in the benchmark because the 3D dataset dimensions are the same along each axis. But it is something that might impact the image quality for non-square matrices like in the example: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/examples/mridataorg/example.jl
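The difference only shows up when the axes allow different decomposition depths. The sketch below counts how often each dimension can be halved and compares a per-axis depth with a single shared (minimum) depth. The helper names are hypothetical, and the code reproduces neither BART's nor Wavelets.jl's actual level selection; it only illustrates why equal dimensions hide the effect.

```julia
# Number of times a dimension of length n can be halved exactly.
# (Hypothetical helper written for this illustration.)
function dyadic_levels(n::Integer)
    levels = 0
    while n > 1 && iseven(n)
        n ÷= 2
        levels += 1
    end
    return levels
end

per_axis_levels(sz) = map(dyadic_levels, sz)         # one depth per axis
common_levels(sz)   = minimum(per_axis_levels(sz))   # a single shared depth

sz = (128, 96, 64)
per_axis_levels(sz)   # (7, 5, 6)
common_levels(sz)     # 5
```

For a cubic volume such as (128, 128, 128) both strategies give the same depth on every axis, which is consistent with the benchmark not being affected.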

andrewwmao commented 1 year ago

@aTrotier have you run this notebook recently with the latest version of RLS? And also the FISTA recon with params2[:ρ] = 0.95? I am having trouble getting your binder to work.

aTrotier commented 1 year ago

No, I think it is outdated (before the splitting of the package). I will update it

aTrotier commented 1 year ago

I created a quick test with your pogm branch of RegularizedLeastSquares.jl and my PR for MRIReco, which gives these results: compare.pdf

Results from BART and MRIReco can be close. However, if I increase the number of iterations, BART seems to converge whereas MRIReco (FISTA and even ADMM) increases the noise level. For FISTA, the noise amplification grows really fast with the number of iterations.

By the way @andrewwmao, POGM and OptISTA work but also give the same results when increasing the number of iterations.

Maybe, we should work on the benchmark test rather than my real MP2RAGE datasets for quantitative metrics : https://github.com/MagneticResonanceImaging/MRIRecoBenchmarks/tree/master/benchmark2

> Just a remark: the accuracy issue can also be linked to the wavelet implementation. BART does a full decomposition along each axis, whereas Wavelets.jl determines the minimum level of decomposition along each axis and uses that.

At least it does not seem to be really related to the wavelet implementation.

@tknopp @JeffFessler Do you have any thoughts about that?

using MRIReco, MRIFiles, MRICoilSensitivities
using BartIO, QuantitativeMRI
using CairoMakie
include("utils_MP2RAGE.jl")
## Setup BartIO and Global variable
set_bart_path("/usr/local/bin/bart")

slice = 25 # slice to show

## load data
b = BrukerFile("data/LR_3T_CS4")
raw = RawAcquisitionData_MP2RAGE_CS(b); # create an object with function in utils_MP2RAGE.jl
acq = AcquisitionData(raw,OffsetBruker = true)

## plot the mask 

begin # check mask
    mask = zeros(acq.encodingSize[1], acq.encodingSize[2], acq.encodingSize[3])
    # mark every sampled k-space position (linear indices) with 1
    for i in 1:length(acq.subsampleIndices[1])
        mask[acq.subsampleIndices[1][i]] = 1
    end
    heatmap!(Axis(Figure()[1,1], aspect=1), mask[64,:,:], colormap=:grays)
    current_figure()
end

## CoilSensitivities
sens = espirit(acq)

imMP_MRIReco_fista = Vector{Array{Float32,3}}()
imMP_MRIReco_admm = Vector{Array{Float32,3}}()
imMP_pics = Vector{Array{Float32,3}}()
iter_vec = (1,5,10,15,20,30,50)
for iter in iter_vec
    # Then Wavelet
    params2 = Dict{Symbol, Any}()
    params2[:reco] = "multiCoil"
    params2[:reconSize] = acq.encodingSize
    params2[:senseMaps] = sens

    params2[:solver] = "fista"
    params2[:sparseTrafoName] = "Wavelet"
    params2[:regularization] = "L1"
    params2[:λ] = 0.01 # 5.e-2
    params2[:iterations] = iter
    params2[:normalize_ρ] = true
    params2[:ρ] = 0.95
    params2[:normalizeReg] = true

    I_wav = reconstruction(acq, params2)
    push!(imMP_MRIReco_fista, mp2rage_comb(I_wav.data[:,:,:,:,1,1]))

    # same parameters, but with the ADMM solver
    params2[:solver] = "admm"
    I_wav = reconstruction(acq, params2)
    push!(imMP_MRIReco_admm, mp2rage_comb(I_wav.data[:,:,:,:,1,1]))
    #heatmap(imMP_wav[:,:,slice],colormap=:grays,axis= (;title="MRIReco wav, iter = $iter"))

    ## compare to bart

    k_bart = kDataCart(acq)
    k_bart = permutedims(k_bart, (1,2,3,4,6,5))
    size(k_bart)

    im_pics = bart(1, "pics -e -S -i $iter -R W:7:0:0.01", k_bart, sens)
    im_pics = permutedims(im_pics, (1,2,3,6,4,5))
    im_pics = im_pics[:,:,:,:,:,1]
    push!(imMP_pics, mp2rage_comb(im_pics[:,:,:,:,1]))
end

f = Figure(resolution=(400,600))
ga = f[1,1] = GridLayout()
asp = 128/96
for i in 1:length(imMP_pics)

  ax1 = Axis(ga[i,1],aspect=asp)
  hidedecorations!(ax1)
  heatmap!(ax1,imMP_MRIReco_fista[i][:,:,slice],colormap=:grays)

  ax2 = Axis(ga[i,2],aspect=asp)
  hidedecorations!(ax2)
  heatmap!(ax2,imMP_MRIReco_admm[i][:,:,slice],colormap=:grays)

  ax3 = Axis(ga[i,3],aspect=asp)
  hidedecorations!(ax3)
  heatmap!(ax3,imMP_pics[i][:,:,slice],colormap=:grays)

  Label(ga[i,0],"iter = $(iter_vec[i])",tellheight = false)

  if i == 1
    ax1.title = "MRIReco \n fista"
    ax2.title = "MRIReco \n admm"
    ax3.title = "bart \n fista"
  end
  rowsize!(ga,i,75)
end
rowgap!(ga,0)
f

save("compare.pdf",f)

JakobAsslaender commented 1 year ago

To what degree did you fine tune lambda? As discussed earlier, I don't think we can assume that the same lambda is optimal for BART and Julia. But it would maybe be nice to match the implementations, i.e. re-create the BART normalization in Julia instead of the current norm_reg implementation. Thoughts, @tknopp?

Regarding the comparison between FISTA and ADMM: What does "iteration" mean in the MRIReco interface (sorry, I never use that interface)? When you run both FISTA and ADMM long enough with the right lambda, they should converge to the same solution, assuming that the norm_reg is doing the same thing in both algorithms.

Last, I would suggest benchmarking it with a different regularization (e.g. TV) to avoid the known difference in the wavelet implementation.

andrewwmao commented 1 year ago

My guess is something is probably wrong with params2[:normalize_ρ] = true, i.e. in the calculation of the Lipschitz constant. That would explain why FISTA/POGM both seem to fail whereas ADMM seems to give a good result. But since I am also not using the high-level interface it is difficult for me to say where this problem occurs. This option is also 'false' in the above-mentioned benchmark, where FISTA appears to be working fine.

For ADMM it's difficult at a glance to say what's going on there. Certainly there is a factor 2 difference in the appropriate lambda to use w.r.t. FISTA/POGM, and the parameter rho also has a different meaning there. This could probably be fixed with some appropriate tuning.

aTrotier commented 1 year ago

Stuff I am aware of that is different in the Bart implementation :

I will try to code an example with the low level interface.

For ADMM I think we have the same issue, but the convergence and the subsequent noise amplification seem slower (I will increase the number of iterations just to check that hypothesis).

aTrotier commented 1 year ago

For the scaling issue I can also force bart to scale the data with 1

JeffFessler commented 1 year ago

I was also going to guess that the issue is either a mismatched regularizer or an incorrect or inconsistent Lipschitz constant somewhere.

aTrotier commented 1 year ago

> For the scaling issue I can also force bart to scale the data with 1

Indeed, it was the scaling / fine-tuning of the lambda value: if I force the scaling of the data to 1 in BART (option -w 1), I get similar noise amplification: compare.pdf

edit: Weirdly, if I inversely scale acq.kdata by the value calculated by BART, the results converge for BART but not for MRIReco. I have to try with the low-level interface.

JakobAsslaender commented 1 year ago

> I was also going to guess that the issue is either a mismatched regularizer or an incorrect or inconsistent Lipschitz constant somewhere.

@JeffFessler: I didn't have the Lipschitz constant on my radar in this context. Would you mind double checking that this is calculated correctly?

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/FISTA.jl#L68

And the called function can be found here:

https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/Utils.jl#L294

tknopp commented 1 year ago

So, I am not really sure how to move forward here. I am not sure if this is possible but I wonder if it makes sense to first improve our test cases in RegularizedLeastSquares to be sure that there are no issues on that level? Furthermore, we might want to improve the documentation and define more clearly the semantics of the algorithms. So that we clearly define what optimization problem is being solved and what normalizations are being done. Does that make sense?

On the other side it seems that it might be worth to translate the MRIReco.jl reconstruction code from @aTrotier to a low-level interface so that it becomes clearer for @JakobAsslaender and @andrewwmao?

These are just ideas. I don't have so much coding capacities right now unfortunately.

aTrotier commented 1 year ago

I have a benchmark with a Shepp-Logan phantom for MRIReco high-level / low-level and BART reconstruction here: https://github.com/aTrotier/MRIReco_Accuray_fista which gives the following results:

From metrics and qualitative evaluation, the high- and low-level interfaces give approximately the same results: there are still some residual artifacts (visible in the images, not in the RMSE metric) compared to the BART reconstruction. compare_metrics.pdf compare_img.pdf

JeffFessler commented 1 year ago

> you mind double checking

If the data is (possibly under-sampled) Cartesian, and if the encoding matrix uses the unitary DFT (with no B0 correction), and if the sensitivity maps are normalized so that the SSoS = 1, then the Lipschitz constant is 1 and there is no need to run the power iteration. I have seen situations where code set it to 1 but one of those three "ifs" was not satisfied, leading to problems. I didn't realize that here we are always (?) using the power iteration, so that way should always be safe.

Probably we could use a smarter initial guess than randn to reduce iterations a bit, because often the principal eigenvector is quite smooth (e.g., when low frequencies are heavily sampled like in radial), but otherwise the code looks fine. https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/Utils.jl#L295
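For readers who want to experiment with the initial guess, here is a minimal power-iteration sketch that estimates the largest eigenvalue of AᴴA, i.e. the Lipschitz constant of the gradient of 1/2‖Ax - b‖². It is not the RegularizedLeastSquares.jl implementation linked above; the function name, the stopping rule, and the tolerance are assumptions made for this illustration. The start vector is an argument so that `randn` and `ones` can be compared.

```julia
using LinearAlgebra

# Power iteration on AᴴA starting from x0.
# Returns the eigenvalue estimate and the number of iterations used.
function power_iteration(A, x0; maxiter = 1000, rtol = 1e-8)
    x = x0 / norm(x0)
    σ = 0.0
    for k in 1:maxiter
        y = A' * (A * x)          # apply the normal operator
        σnew = norm(y)            # current eigenvalue estimate
        x = y / σnew              # renormalize the iterate
        abs(σnew - σ) <= rtol * σnew && return σnew, k
        σ = σnew
    end
    return σ, maxiter
end

A = [2.0 0.0; 0.0 1.0]                          # toy operator, largest eigenvalue of AᴴA is 4
σ_ones, n_ones = power_iteration(A, ones(2))    # constant start vector
σ_rand, n_rand = power_iteration(A, randn(2))   # random start vector
```

Comparing `n_ones` and `n_rand` on a realistic encoding operator would be one way to test whether a smooth start vector actually saves iterations; on this toy matrix the comparison carries no weight.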

JakobAsslaender commented 1 year ago

Thanks for looking into this! To answer your question: we are running power iterations if the flag normalize_ρ=true, which it is by default. @aTrotier: Can you check that the high-level wrappers don't overwrite the default?

Would ones be a better initialization? Or did you have something else in mind?

JeffFessler commented 1 year ago

> Would ones be a better initialization? Or did you have something else in mind?

I was thinking ones but I've never done any serious testing of it so caveat emptor...

aTrotier commented 1 year ago

> @aTrotier: Can you check that the high-level wrappers don't overwrite the default?

It does not. Anyway, I forced it to true: https://github.com/aTrotier/MRIReco_Accuray_fista/blob/c6ed10cbe97477086f8e214c8a127f7b6a5d73bb/Accuracy_fista.jl#L81