Closed kiranshila closed 1 year ago
I'm not quite sure if this makes sense, but this is doing something haha
using Zygote: @adjoint

Zygote.@adjoint function nfft(x, image)
    return nfft(x, image), function (Δ)
        return (nothing, nfft_adjoint(x, size(image), Δ))
    end
end
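A minimal usage sketch of the adjoint above, assuming the functional API `nfft(x, f)` / `nfft_adjoint(x, sz, fHat)` shown in the snippet (exact signatures may differ between NFFT.jl versions, and the node layout here is a guess):

```julia
using Zygote, NFFT

# Hypothetical example data; nonequispaced nodes in [-0.5, 0.5)
nodes = collect(range(-0.4, 0.4; length = 16))
img   = randn(ComplexF64, 16)

# With the @adjoint registered, Zygote can differentiate a scalar loss
# through nfft with respect to the image
loss(im) = sum(abs2, nfft(nodes, im))
grad, = Zygote.gradient(loss, img)
```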
@kiranshila: Thanks for your feedback. It is much appreciated to hear that the software we developed has some real use cases.
Regarding the adjoint: could you please have a look at this: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/src/Operators/NFFTOp.jl It is basically an operator version of the NFFT, and it already has an adjoint. I would say Zygote should be happy with that.
Will do! I'll keep you posted, and if it works, I'll submit a PR to add in Zygote support.
Cool. NFFTOp is currently in a different package, but it would actually fit into NFFT.jl itself. It pulls in a dependency on LinearOperators, but that should not be a big deal.
Somewhat related to this, and to our discussion in the other issue on the existence of the inverse: is it a side effect of this algorithm that the magnitudes of the Fourier components in the "there and back" calculation are orders of magnitude bigger? That is, in the example in your documentation, I would expect fHat and g to be identical, but the magnitudes of g are much higher. If I use the adjoint to make an image, the relative magnitudes are identical, but the result is scaled.
The reason this is important to me is that I am writing an optimization to match the Fourier components of data using a learned image with some regularization. However, my starting image is just the nfft_adjoint of the data. I would expect the difference between the Fourier components of this image and the data to be zero, but because of this scaling, it is not.
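The scaling can be seen already with the ordinary equispaced DFT matrix, of which the NFFT is a generalization: the adjoint is not the inverse, so going "there and back" picks up a factor of N. A toy sketch in plain Julia (not NFFT.jl itself):

```julia
using LinearAlgebra

# Plain DFT matrix A[j+1, k+1] = exp(-2πi * j * k / N)
N = 8
A = [exp(-2im * π * j * k / N) for j in 0:N-1, k in 0:N-1]

x = randn(ComplexF64, N)
g = A' * (A * x)      # "there and back": adjoint applied to the forward transform

# The adjoint is N times the inverse, so g is x scaled by N
@assert g ≈ N .* x
```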
It seems as though handing the adjoint to Zygote does indeed work; now the problem is this scaling.
Yes, you need to take scalings into account, and it often makes sense to properly "normalize" an operator. Just to give you an example: here we do exactly that in order to make the FFT unitary, which it is not in the standard definition: https://github.com/tknopp/SparsityOperators.jl/blob/master/src/FFTOp.jl#L41
But for the NFFT this is more complicated, as it involves what is usually called density compensation. Let A be the NFFT matrix. Then A^H A is in general not the identity matrix. Instead, what you want/need is A^H W A, where W is a diagonal matrix containing the squared density weights. It holds that A^H W A \approx I if the weights are chosen appropriately. I then recommend that you use A^H W^{1/2} and W^{1/2} A as your transformation pair.
We actually also have a method in NFFT.jl to automatically calculate the density weights (called sdc, I think).
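The A^H W A \approx I relation can be illustrated with a toy example in plain Julia (not NFFT.jl itself): for equispaced nodes the density weights are simply 1/N, and the relation holds exactly.

```julia
using LinearAlgebra

N  = 8
A  = [exp(-2im * π * j * k / N) for j in 0:N-1, k in 0:N-1]  # DFT matrix
Wh = Diagonal(fill(1 / sqrt(N), N))   # W^{1/2}; uniform density weights are 1/N

# A^H W A ≈ I (exact here because the sampling is uniform)
@assert A' * Wh * Wh * A ≈ I(N)

# The recommended transformation pair: forward W^{1/2} A, adjoint A^H W^{1/2}
x = randn(ComplexF64, N)
@assert (A' * Wh) * ((Wh * A) * x) ≈ x
```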
Oh perfect! Thank you! That makes sense.
Please have a look at this article: https://downloads.hindawi.com/journals/ijbi/2007/024727.pdf It is on MRI reconstruction, but the initial formulas touch on exactly what I described.
Yeah it seems that MRI reconstruction is very similar to radio telescope imaging
Hi @roflmaostc,
I would like to make NFFT.jl AD-friendly but unfortunately don't yet have a deep understanding of how to write custom chain rules. I have seen that you were involved in the chain rules for AbstractFFTs: https://github.com/JuliaDiff/ChainRules.jl/issues/127 Could you help me with NFFT.jl?
If I understand correctly, this could probably be done at the level of AbstractNFFTs, allowing all implementations to benefit from it.
I don't know about ChainRulesCore, but it is about one line of code using @adjoint in Flux / Zygote; see this example: https://jefffessler.github.io/SPECTrecon.jl/stable/generated/examples/6-dl/#Custom-backpropagation
Now that I've RTFM'd, I see that they recommend using ChainRulesCore: https://fluxml.ai/Zygote.jl/dev/adjoints/
I usually wrap my NFFT calls in a LinearMap so I think I will look into a chain rule for that. (I looked at LinearOperators.jl and didn't see a dependency on ChainRules there BTW.)
Yes, I read that manual as well, and with AbstractFFTs they did basically the same: https://github.com/JuliaMath/AbstractFFTs.jl/pull/58
If I look at the LinearMaps example, it could look something like this:
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(*), A::AbstractNFFTPlan, x::AbstractArray)
    y = A * x
    function pullback(dy)
        DY = unthunk(dy)
        # Because A is an abstract map, the product is only differentiable w.r.t. the input
        return NoTangent(), NoTangent(), @thunk(A' * DY)
    end
    return y, pullback
end
I am not sure if the same works for mul!.
I like that the LinearMaps PR has a test for the chain rules. Flux is a pretty heavy dependency; I'm not sure that is really needed, or whether there are easier ways to test the rrule.
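One lighter-weight option (an assumption on my part, not necessarily what the LinearMaps PR uses) is ChainRulesTestUtils.jl, which checks an rrule against finite differences without pulling in Flux:

```julia
using ChainRulesCore, ChainRulesTestUtils

# Toy stand-in for a plan: a plain matrix, whose *-rrule ChainRules provides.
# For a real plan one would mark the plan argument as non-differentiable, e.g.
#   test_rrule(*, plan ⊢ NoTangent(), x)
A = randn(ComplexF64, 4, 3)
x = randn(ComplexF64, 3)
test_rrule(*, A, x)
```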
This issue should be fixed by commit 8552b281341e68ab101fd9b0555e4218b7e562d9, where we implemented ChainRulesCore.frule and ChainRulesCore.rrule. Therefore, I will close it for now.
First off, thank you for this excellent library! It has been invaluable in my work on radio astronomy imaging.
I am working on some imaging methods that are iterative, and I would like to use the rest of the Julia ecosystem to support gradient-based optimization methods. I have tried ForwardDiff, ReverseDiff, and Zygote (seemingly the three most popular AD engines), and none of them seem to work with this library. I think this might come from the requirement that the type of the variables in the plan must be the same as that of the data; if I'm ADing w.r.t. the data, the dual-number wrapper forms a different type, and I'm not quite sure how to rectify that. I think Zygote should get around that problem, but I'm not as familiar with the workings of the source-to-source methods.