Open aTrotier opened 2 years ago
Hi aTrotier, I looked at the notebooks that you linked, and it seems that the second notebook using the in vivo MP2RAGE data is using an old version of 'RegularizedLeastSquares'. I contributed a bugfix to the FISTA algorithm in comparison to BART that should have fixed the accuracy problem. Can you try updating this to v0.8.7?
For the first notebook, my feeling is that there may be an issue with the regularization strength/implementation of the regularizer. Lambdas won't be comparable between Julia/BART esp. when both packages are scaling lambda in different ways. Perhaps you can try testing both BART/MRIReco FISTA implementations setting lambda close to zero to see if they give the same results.
I don't know if it's just the different scaling but @migrosser also reported that he got better results by just taking another regularization parameter. It would be great if someone could have a deeper look. But probably we should keep this on hold and first check that RegularizedLeastSquares is consistent.
If somebody wants to give this a go: MRIRecoBenchmarks now includes the example script: https://github.com/MagneticResonanceImaging/MRIRecoBenchmarks/tree/master/benchmark2 It generates timings and output images. Errors are currently not calculated but that would be easy to add.
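As a rough illustration of how an error metric could be added, here is a minimal sketch. The helper name `scaled_nrmse` is hypothetical and not part of MRIRecoBenchmarks; it fits a global complex scale factor before computing the normalized RMSE, on the assumption that the two toolboxes may return images with different overall scaling.

```julia
using LinearAlgebra

# Hypothetical helper (not from MRIRecoBenchmarks): NRMSE of `x` against `ref`
# after fitting the complex scale factor α that minimizes norm(α*x - ref).
function scaled_nrmse(x, ref)
    α = dot(x, ref) / dot(x, x)   # least-squares scale; dot conjugates its first argument
    return norm(α .* x .- ref) / norm(ref)
end

ref = ComplexF64[1, 2im, 3, -1]
err_scaled  = scaled_nrmse(0.5im .* ref, ref)  # differs from ref only by a global scale
err_shifted = scaled_nrmse(ref .+ 0.1, ref)    # differs by more than a global scale
println((err_scaled, err_shifted))
```

Whether a global scale should be fitted at all is a design choice: it hides pure scaling differences, which seems reasonable here because the scaling conventions are exactly what differs between the packages.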
Regarding the scaling: In my opinion this is primarily a documentation issue. We need to document somewhere how to translate a regularization parameter scaled with BART to MRIReco.
So if I understand it correctly, the return value of BART depends on whether it terminates based on the maximum number of iterations, in which case it returns after the gradient step, i.e. halfway through Eq. 4.1 in the FISTA paper:
but if the residual is smaller than the termination threshold, they return after adding the momentum (ravine) step, i.e. after Eq. 4.3 in the FISTA paper:
Not sure why they distinguish between these two cases, but maybe @JeffFessler has thoughts?
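For reference, the iteration that these equation numbers point to in Beck and Teboulle's FISTA paper can be written as follows, for a cost f + g with smooth f, non-smooth g, and Lipschitz constant L of the gradient of f (the prox notation is a restatement of the paper's p_L map):

```latex
\begin{align}
x_k &= p_L(y_k) = \operatorname{prox}_{g/L}\!\left(y_k - \tfrac{1}{L}\nabla f(y_k)\right) \tag{4.1}\\
t_{k+1} &= \frac{1 + \sqrt{1 + 4 t_k^2}}{2} \tag{4.2}\\
y_{k+1} &= x_k + \frac{t_k - 1}{t_{k+1}}\,(x_k - x_{k-1}) \tag{4.3}
\end{align}
```

In this notation, "halfway through Eq. 4.1" is the argument of the prox (the gradient step), and "after Eq. 4.3" is the extrapolated point y_{k+1}.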
We currently return after applying the proximal operator, i.e. after Eq 4.1. I think it would be easy to change the behavior to termination after the gradient step by moving line
below
or after the momentum step by moving it below
but I am not sure that the done function supports a distinction based on why we are terminating.
Do you, @JeffFessler or @andrewwmao have comments on what is the "right" termination point?
Intuitively, I would have said after the prox step because only then the solution lies in the desired subspace (as Jeff has said). Making it dependent on maxiter seems to be suboptimal. We should not replicate that.
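To make the three candidate return points concrete, here is a minimal real-valued FISTA sketch for 0.5*norm(A*x - b)^2 + λ*norm(x, 1). It is an illustration only, not the RegularizedLeastSquares.jl or BART implementation; the function names are hypothetical, and the Lipschitz constant is computed exactly with `opnorm` for simplicity.

```julia
using LinearAlgebra

# Soft-thresholding: proximal operator of t*norm(x, 1)
soft(x, t) = sign.(x) .* max.(abs.(x) .- t, 0)

# Hypothetical minimal FISTA for 0.5*norm(A*x - b)^2 + λ*norm(x, 1)
function fista_sketch(A, b, λ; iterations=100)
    L = opnorm(A)^2            # Lipschitz constant of the gradient of the data term
    x = zeros(size(A, 2))      # output of the prox step
    y = copy(x)                # extrapolated (momentum) point
    t = 1.0
    for _ in 1:iterations
        x_old = x
        g = y - (A' * (A * y - b)) / L            # gradient step (first half of Eq. 4.1)
        x = soft(g, λ / L)                        # prox step (completes Eq. 4.1)
        t_new = (1 + sqrt(1 + 4t^2)) / 2          # Eq. 4.2
        y = x + ((t - 1) / t_new) * (x - x_old)   # momentum step (Eq. 4.3)
        t = t_new
    end
    return x   # prox output, regardless of why the loop stopped
end

x_hat = fista_sketch(Matrix(1.0I, 3, 3), [3.0, -0.5, 1.0], 1.0)
println(x_hat)
```

Returning `g` or `y` instead of `x` would give a point that has not been thresholded, so entries that the regularizer would set exactly to zero are generally non-zero.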
Well in this case the current implementation is correct :). @aTrotier: Have you played with lambda and maxiter to see if either of those resolve the issue? Note that there is a factor 2 difference in lambda between ADMM and FISTA as @andrewwmao pointed out to me. We should probably resolve this to make the algorithms more comparable.
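One way such a factor of 2 can arise is when one solver minimizes 0.5*norm(A*x - b)^2 + λ*norm(x, 1) and the other minimizes norm(A*x - b)^2 + λ*norm(x, 1). Which convention each solver actually uses is an assumption here and is not confirmed by this thread. The scalar sketch below shows that the second form behaves like the first with λ/2:

```julia
# Scalar soft-thresholding
soft(x, t) = sign(x) * max(abs(x) - t, 0)

b, λ = 3.0, 1.0
x_half = soft(b, λ)       # minimizer of 0.5*(x - b)^2 + λ*abs(x)
x_full = soft(b, λ / 2)   # minimizer of     (x - b)^2 + λ*abs(x)

# Brute-force check of both minimizers on a fine grid
xs = range(-5, 5; length=100_001)
grid_half = xs[argmin([0.5 * (x - b)^2 + λ * abs(x) for x in xs])]
grid_full = xs[argmin([(x - b)^2 + λ * abs(x) for x in xs])]
println((x_half, grid_half, x_full, grid_full))
```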
I have trouble understanding the BART code, but if @JakobAsslaender's reading of it is correct then I think it is a "bug" in BART. For FISTA, the proper thing to return (whenever one stops) is the output of the prox update. (Whether this is "x" or "y" depends on the paper's notation BTW.)
I've double-checked that we do this properly in MIRT: https://github.com/JeffFessler/MIRT.jl/blob/26cfbc2a26b26814fa85739dbe01b5f1b8be5e21/src/algorithm/general/pogm_restart.jl#L262
BTW, I'd recommend POGM instead of FISTA for the problem at hand. See Fig. 3 of my survey paper on optimization for MRI: http://doi.org/10.1109/MSP.2019.2943645
There is code for reproducing that figure here: https://github.com/JeffFessler/mirt-demo/blob/main/isbi-19/01-recon.ipynb and a Documenter-type example of it here: https://juliaimagerecon.github.io/Examples/generated/mri/2-cs-wl-l1-2d/
The main issue is that currently that pogm code is buried in MIRT.jl which is too large, though I am working on paring it down into smaller packages. In the long run I should contribute POGM to RegLS it seems, or to some optimization package somewhere. It would benefit from some optimization like in-place ops that I haven't had time to do...
@JeffFessler : I would loooooovvvvveeeee to throw POGM at our data! But so far I have been hesitant because of the different interface (compared to RLS.jl) and the lack of in-place operations etc. Do I understand correctly that the algorithm builds on FISTA? Maybe the easiest way to do those optimizations would be to copy the RLS.jl implementation of FISTA and turn it into POGM? We can, of course, also think about converting the existing FISTA implementation into a super-function similar to yours, but I am not sure about the best compromise of code duplication vs. speed and readability. Let me know if I could be of help, just not sure that I have enough knowledge about POGM to do the job...
@JakobAsslaender @JeffFessler After playing around I think you are right. I am able to get something close to the BART implementation of FISTA, with a bit too much noise, which might then be related to https://github.com/MagneticResonanceImaging/MRIReco.jl/discussions/102#discussioncomment-3623535 In my first tests, I wanted to reduce the noise by increasing the lambda value, which creates the threshold effect that is supposed to happen (BART misled me in this case).
Maybe @uecker can give some advice about BART implementation and why they don't send the image after soft-thresholding.
Something to mention: with BART, most of the time I don't have to play a lot with the parameters to make FISTA work (lambda is generally close to 0.01). I guess the pre-scaling operation helps.
> I guess the pre-scaling operation helps.
The option params[:normalizeReg] = true is actually supposed to make things independent of the input data but probably it is not enough. So there is certainly a TODO item left.
Ok I think BART is doing that the other way, they scale the input image (https://github.com/MagneticResonanceImaging/MRIReco.jl/issues/92) rather than the lambda value.
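The two approaches are related by a simple identity: for 0.5*norm(A*x - b)^2 + λ*norm(x, 1), scaling the data by s > 0 while keeping λ fixed gives s times the solution obtained from the unscaled data with λ/s. The sketch below checks this for the special case A = I, where the minimizer is plain soft-thresholding. The scale factor `s` is a made-up value; how BART or MRIReco actually compute their scale factors is not covered here.

```julia
# Soft-thresholding: closed-form minimizer of 0.5*norm(x - b)^2 + λ*norm(x, 1)
soft(x, t) = sign.(x) .* max.(abs.(x) .- t, 0)

b = [3.0, -0.5, 1.2]
λ = 0.4
s = 10.0   # hypothetical data scale factor, for illustration only

x_scaled_data   = soft(s .* b, λ)        # scale the data, keep λ
x_scaled_lambda = s .* soft(b, λ / s)    # keep the data, scale λ, rescale the result
println(x_scaled_data ≈ x_scaled_lambda)
```

Under these assumptions, a λ tuned for one package can be translated to the other once the effective data scale factor is known.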
edit : results after playing with parameters https://atrotier.github.io/EDUC_JULIA_CS_MP2RAGE/
> throw POGM
Thanks for the encouragement. I will first write a version of pogm that is streamlined (putting pogm and fista and pgm in one function is too messy) and uses in-place ops etc. It will be for a general composite cost function and I will illustrate how to use it using a regularized LS problem. Then we can decide if/how to make a wrapper in RLS.jl to call it just as easily as you call FISTA.
Just a remark: the accuracy issue can also be linked to the wavelet implementation. BART does a full decomposition along each axis, whereas Wavelets.jl determines the minimum decomposition level across the axes and uses that for all of them.
Then we probably want to have options to refine the wavelet transform. Don't know how we should approach this but one would first start making the WaveletOp more general and then introduce some new high-level parameters.
Actually it cannot be the issue in the benchmark because the 3D dataset has the same dimension along each axis. But it is something that might impact the image quality for non-square matrices like in the example: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/examples/mridataorg/example.jl
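To see why only non-square matrices are affected, one can count how many dyadic decomposition levels each axis supports. The helper `dyadic_levels` below is hypothetical, and counting levels by exact halving is an assumption; neither BART nor Wavelets.jl is claimed to count levels exactly this way.

```julia
# Hypothetical helper: number of times n can be halved exactly
function dyadic_levels(n::Integer)
    levels = 0
    while n > 1 && iseven(n)
        n ÷= 2
        levels += 1
    end
    return levels
end

dims = (128, 96)                       # example non-square matrix size
per_axis = map(dyadic_levels, dims)    # levels available along each axis
shared   = minimum(per_axis)           # single level used for all axes
println((per_axis, shared))
```

With equal dimensions along every axis, the per-axis and shared levels coincide, so the two strategies cannot differ.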
@aTrotier have you run this notebook recently with the latest version of RLS? And also the FISTA recon with params2[:ρ] = 0.95? I am having trouble getting your binder to work.
No, I think it is outdated (before the splitting of the package). I will update it
I created a quick test with your pogm branch of RegularizedLeastSquares.jl and my PR for MRIReco, which gives these results: compare.pdf
Results between BART and MRIReco can be close. However, if I increase the number of iterations, BART seems to converge whereas MRIReco (FISTA and even ADMM) increases the noise level. For FISTA, the noise amplification grows very quickly with the number of iterations.
By the way @andrewwmao, pogm and optista work but show the same behavior when increasing the number of iterations.
Maybe, we should work on the benchmark test rather than my real MP2RAGE datasets for quantitative metrics : https://github.com/MagneticResonanceImaging/MRIRecoBenchmarks/tree/master/benchmark2
> Just a remark : the accuracy issue can also be linked to the wavelet implementation. Bart is doing a full decomposition along each axis whereas as Wavelet.jl determine the minimum level of decomposition along each axis and used that.
At least it does not seem to be really related to the wavelet implementation.
@tknopp @JeffFessler Do you have some thoughts about that ?
using MRIReco, MRIFiles, MRICoilSensitivities
using BartIO, QuantitativeMRI
using CairoMakie
include("utils_MP2RAGE.jl")
## Setup BartIO and Global variable
set_bart_path("/usr/local/bin/bart")
slice = 25 # slice to show
## load data
b = BrukerFile("data/LR_3T_CS4")
raw = RawAcquisitionData_MP2RAGE_CS(b); # create an object with function in utils_MP2RAGE.jl
acq = AcquisitionData(raw,OffsetBruker = true)
## plot the mask
begin # check mask
    mask = zeros(acq.encodingSize[1],acq.encodingSize[2],acq.encodingSize[3]);
    for i =1:length(acq.subsampleIndices[1]);
        mask[acq.subsampleIndices[1][i]]=1;
    end
    heatmap!(Axis(Figure()[1,1],aspect=1), mask[64,:,:,1],colormap = :grays)
    current_figure()
end
## CoilSensitivities
sens = espirit(acq)
imMP_MRIReco_fista = Vector{Array{Float32,3}}()
imMP_MRIReco_admm = Vector{Array{Float32,3}}()
imMP_pics = Vector{Array{Float32,3}}()
iter_vec = (1,5,10,15,20,30,50)
for iter in iter_vec
    # Then Wavelet
    params2 = Dict{Symbol, Any}()
    params2[:reco] = "multiCoil"
    params2[:reconSize] = acq.encodingSize
    params2[:senseMaps] = sens;
    params2[:solver] = "fista"
    params2[:sparseTrafoName] = "Wavelet"
    params2[:regularization] = "L1"
    params2[:λ] = 0.01 # 5.e-2
    params2[:iterations] = iter
    params2[:normalize_ρ] = true
    params2[:ρ] = 0.95
    params2[:normalizeReg] = true
    I_wav = reconstruction(acq, params2);
    push!(imMP_MRIReco_fista,mp2rage_comb(I_wav.data[:,:,:,:,1,1]))
    params2[:solver] = "admm"
    I_wav = reconstruction(acq, params2);
    push!(imMP_MRIReco_admm,mp2rage_comb(I_wav.data[:,:,:,:,1,1]))
    #heatmap(imMP_wav[:,:,slice],colormap=:grays,axis= (;title="MRIReco wav, iter = $iter"))
    ## compare to bart
    k_bart = kDataCart(acq)
    k_bart = permutedims(k_bart,(1,2,3,4,6,5))
    size(k_bart)
    im_pics = bart(1,"pics -e -S -i $iter -R W:7:0:0.01",k_bart,sens);
    im_pics = permutedims(im_pics,(1,2,3,6,4,5));
    im_pics = im_pics[:,:,:,:,:,1];
    push!(imMP_pics,mp2rage_comb(im_pics[:,:,:,:,1]))
end
f = Figure(resolution=(400,600))
ga = f[1,1] = GridLayout()
asp = 128/96
for i in 1:length(imMP_pics)
    ax1 = Axis(ga[i,1],aspect=asp)
    hidedecorations!(ax1)
    heatmap!(ax1,imMP_MRIReco_fista[i][:,:,slice],colormap=:grays)
    ax2 = Axis(ga[i,2],aspect=asp)
    hidedecorations!(ax2)
    heatmap!(ax2,imMP_MRIReco_admm[i][:,:,slice],colormap=:grays)
    ax3 = Axis(ga[i,3],aspect=asp)
    hidedecorations!(ax3)
    heatmap!(ax3,imMP_pics[i][:,:,slice],colormap=:grays)
    Label(ga[i,0],"iter = $(iter_vec[i])",tellheight = false)
    if i == 1
        ax1.title = "MRIReco \n fista"
        ax2.title = "MRIReco \n admm"
        ax3.title = "bart \n fista"
    end
    rowsize!(ga,i,75)
end
rowgap!(ga,0)
f
save("compare.pdf",f)
To what degree did you fine tune lambda? As discussed earlier, I don't think we can assume that the same lambda is optimal for BART and Julia. But it would maybe be nice to match the implementations, i.e. re-create the BART normalization in Julia instead of the current norm_reg implementation. Thoughts, @tknopp?
Regarding the comparison between FISTA and ADMM: What does "iteration" mean in the MRIReco interface (sorry, I never use that interface)? When you run both FISTA and ADMM long enough with the right lambda, they should converge to the same solution, assuming that the norm_reg is doing the same thing in both algorithms.
Last, I would suggest to benchmark it with a different regularization (e.g. TV) to avoid the known difference in the wavelet implementation.
My guess is something is probably wrong with params2[:normalize_ρ] = true, i.e. in the calculation of the Lipschitz constant. That would explain why ISTA/POGM both seem to fail whereas ADMM seems to give a good result. But since I am also not using the high-level interface it is difficult for me to say where this problem occurs.
This option is also 'false' in the above mentioned benchmark, where FISTA appears to be working fine.
For ADMM it's difficult at a glance to say what's going on there. Certainly there is a factor 2 difference in the appropriate lambda to use w.r.t. FISTA/POGM, and the parameter rho also has a different meaning there. This could probably be fixed with some appropriate tuning.
Stuff I am aware of that is different in the Bart implementation :
I will try to code an example with the low level interface.
For ADMM I think we have the same issue; the convergence and the subsequent noise amplification just seem slower (I will increase the number of iterations to check that hypothesis).
For the scaling issue I can also force bart to scale the data with 1
I was also going to guess that the issue is either a mismatched regularizer or an incorrect or inconsistent Lipschitz constant somewhere.
> For the scaling issue I can also force bart to scale the data with 1
Indeed, it was the scaling / fine-tuning of the lambda value: if I force the scaling of the data to 1 in BART (option -w 1), I get similar noise amplification: compare.pdf
edit : Weirdly, If I inverse scale acq.kdata by the value calculated by BART, the results converge for BART but not for MRIReco. I have to try with the low level interface
> I was also going to guess that the issue is either a mismatched regularizer or an incorrect or inconsistent Lipschitz constant somewhere.
@JeffFessler: I didn't have the Lipschitz constant on my radar in this context. Would you mind double checking that this is calculated correctly?
And the called function can be found here:
So, I am not really sure how to move forward here. I am not sure if this is possible but I wonder if it makes sense to first improve our test cases in RegularizedLeastSquares to be sure that there are no issues on that level? Furthermore, we might want to improve the documentation and define more clearly the semantics of the algorithms. So that we clearly define what optimization problem is being solved and what normalizations are being done. Does that make sense?
On the other hand, it might be worth translating the MRIReco.jl reconstruction code from @aTrotier to the low-level interface so that it becomes clearer for @JakobAsslaender and @andrewwmao?
These are just ideas. Unfortunately, I don't have much coding capacity right now.
I have a benchmark with the Shepp-Logan phantom for MRIReco high level / low level and BART reconstruction here: https://github.com/aTrotier/MRIReco_Accuray_fista which gives the following results:
From the metrics and a qualitative evaluation, the high- and low-level interfaces give approximately the same results: there are still some residual artefacts (visible in the images but not in the RMSE metrics) compared to the BART reconstruction. compare_metrics.pdf compare_img.pdf
> you mind double checking
If the data is (possibly under-sampled) Cartesian, and if the encoding matrix uses the unitary DFT (with no B0 correction), and if the sensitivity maps are normalized so that the SSoS = 1, then the Lipschitz constant is 1 and there is no need to run the power iteration. I have seen situations where code set it to be 1 but one of those three "ifs" was not satisfied, leading to problems. I didn't realize that here we are always (?) using the power iteration, so that way should always be safe.
Probably we could use a smarter initial guess than randn to reduce iterations a bit, because often the principal eigenvector is quite smooth (e.g., when low frequencies are heavily sampled like in radial) but otherwise the code looks fine.
https://github.com/tknopp/RegularizedLeastSquares.jl/blob/19e50e83a85bdf9006a6433ce3512aec230422a2/src/Utils.jl#L295
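A minimal sketch of the idea, assuming a toy 1D multi-coil Cartesian model built from explicit matrices: it runs a power iteration on A'A with both a random and an all-ones initial guess and compares the result with the exact value. The function `power_iteration` is written for this illustration and is not the RegularizedLeastSquares.jl routine linked above. For an under-sampled mask the largest eigenvalue can be below 1, so 1 remains a valid (possibly conservative) Lipschitz constant.

```julia
using LinearAlgebra, Random

# Power iteration for the largest eigenvalue of A'A, i.e. a Lipschitz constant
# of the gradient of 0.5*norm(A*x - b)^2. `x0` is the initial guess.
function power_iteration(A, x0; iterations=100)
    x = x0 / norm(x0)
    L = 0.0
    for _ in 1:iterations
        y = A' * (A * x)
        L = real(dot(x, y))   # Rayleigh quotient, since norm(x) == 1
        x = y / norm(y)
    end
    return L
end

# Toy model: unitary DFT, sampling mask, coil maps normalized to SSoS = 1
Random.seed!(0)
N, ncoils = 16, 4
F = [exp(-2π * im * (j - 1) * (k - 1) / N) / sqrt(N) for j in 1:N, k in 1:N]
S = randn(ComplexF64, N, ncoils)
S ./= sqrt.(sum(abs2, S; dims=2))           # sum over coils of |s|^2 equals 1
keep = [1, 2, 3, 4, 6, 9, 13, 15, 16]       # arbitrary under-sampling pattern
M = Matrix{Float64}(I, N, N)[keep, :]       # selects the sampled k-space rows
A = vcat([M * F * Diagonal(S[:, c]) for c in 1:ncoils]...)

L_randn = power_iteration(A, randn(ComplexF64, N))
L_ones  = power_iteration(A, ones(ComplexF64, N))
L_true  = opnorm(A)^2
println((L_randn, L_ones, L_true))
```

This toy problem is too small to say anything about whether ones converges faster than randn in practice; it only shows how the two initializations could be compared.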
Thanks for looking into this! To answer your question: we are running power iterations if the flag normalize_ρ=true, which it is by default. @aTrotier: Can you check that the high-level wrappers don't overwrite the default?
Would ones be a better initialization? Or did you have something else in mind?
> Would ones be a better initialization? Or did you have something else in mind?
I was thinking ones but I've never done any serious testing of it so caveat emptor...
> @aTrotier: Can you check that the high-level wrappers don't overwrite the default?
It does not. Anyway, I forced it to true: https://github.com/aTrotier/MRIReco_Accuray_fista/blob/c6ed10cbe97477086f8e214c8a127f7b6a5d73bb/Accuracy_fista.jl#L81
Compressed sensing is now really fast with FISTA : https://github.com/MagneticResonanceImaging/MRIReco.jl/discussions/102#discussioncomment-3655548
But there is still an issue with accuracy (https://github.com/MagneticResonanceImaging/MRIReco.jl/discussions/102#discussioncomment-3623535). We observed it on phantom : https://atrotier.github.io/MRIRecoVsBART_Benchmark/test_bart.html
And I also have an issue on real MP2RAGE data : https://atrotier.github.io/EDUC_JULIA_CS_MP2RAGE/ I am not able to reach the same image quality as BART.
If I use ADMM, it works well !