JuliaMath / NFFT.jl

Julia implementation of the Non-equidistant Fast Fourier Transform (NFFT)

Support AD engines #50

Closed: kiranshila closed this issue 1 year ago

kiranshila commented 3 years ago

First off, thank you for this excellent library! It has been invaluable in my work on radio astronomy imaging.

I am working on some iterative imaging methods, and I would like to use the rest of the Julia ecosystem for gradient-based optimization. I have tried ForwardDiff, ReverseDiff, and Zygote (seemingly the three most popular AD engines), and none of them seem to work with this library. I think this might come from the requirement that the element type of the plan must match the type of the data: if I'm differentiating w.r.t. the data, the dual-number wrapper produces a different type, but I'm not quite sure how to rectify that. I think Zygote should get around that problem, but I'm not as familiar with the workings of source-to-source methods.
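To illustrate the type issue, here is a minimal sketch of a direct (O(NM)) 1D NDFT that uses no precomputed plan, so its output element type is promoted from the inputs instead of being fixed in advance. `ndft` is a hypothetical helper, not an NFFT.jl function, and a direct sum is much slower than a real NFFT; the point is only that generic code lets wrapper number types (`BigFloat` below, dual numbers in ForwardDiff) flow through.

```julia
# Direct 1D NDFT: fhat[j] = sum_n f[n] * exp(-2πi k_n x_j), with k_n in -N/2:N/2-1.
# Hypothetical helper for illustration only; NFFT.jl computes this approximately and fast.
function ndft(x::AbstractVector{<:Real}, f::AbstractVector{<:Number})
    N = length(f)
    # Promote node and image types, then make the result complex, so that
    # e.g. BigFloat (or a dual number type) propagates into the output.
    T = complex(promote_type(eltype(x), real(eltype(f))))
    fhat = zeros(T, length(x))
    for (j, xj) in enumerate(x)
        for (n, fn) in enumerate(f)
            k = n - 1 - N ÷ 2          # frequency index
            fhat[j] += fn * cis(-2π * k * xj)
        end
    end
    return fhat
end

ndft([-0.3, 0.1, 0.25], [1.0, 2.0, 3.0, 4.0])   # Vector{ComplexF64}
ndft(big.([-0.3, 0.1]), [1.0, 2.0])             # Vector{Complex{BigFloat}}
```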

kiranshila commented 3 years ago

I'm not quite sure if this makes sense, but this is doing something haha

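```julia
using NFFT
using Zygote: @adjoint   # brings only the macro into scope, not the name `Zygote`

# nfft is linear in `image`, so its pullback is the adjoint transform
@adjoint function nfft(x, image)
    return nfft(x, image), function (Δ)
        return (nothing, nfft_adjoint(x, size(image), Δ))
    end
end
```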
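The rule above uses the adjoint transform as the pullback. For a linear map A this is correct exactly when A' is the adjoint of A, which the standard dot test checks: ⟨A f, Δ⟩ = ⟨f, A' Δ⟩. Here is a minimal sketch with a dense 1D NDFT matrix standing in for `nfft`/`nfft_adjoint` (`ndft_matrix` is a hypothetical helper; the NFFT approximates this matrix).

```julia
using LinearAlgebra

# Dense 1D NDFT matrix A[j, n] = exp(-2πi k_n x_j), k_n in -N/2:N/2-1 (hypothetical helper)
ndft_matrix(x, N) = [cis(-2π * k * xj) for xj in x, k in -(N ÷ 2):(N - N ÷ 2 - 1)]

x = rand(20) .- 0.5                 # nonequispaced nodes in [-0.5, 0.5)
N = 8
A = ndft_matrix(x, N)

image = randn(ComplexF64, N)
Δ = randn(ComplexF64, length(x))    # incoming cotangent in data space

# Dot test: if it holds, Δ ↦ A' * Δ is a valid pullback for image ↦ A * image
lhs = dot(A * image, Δ)
rhs = dot(image, A' * Δ)
```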
tknopp commented 3 years ago

@kiranshila: Thanks for your feedback. It is much appreciated when the software we developed has real use cases.

Regarding the adjoint: could you please have a look at this: https://github.com/MagneticResonanceImaging/MRIReco.jl/blob/master/src/Operators/NFFTOp.jl It is basically an operator version of the nfft and it already has an adjoint. I would say Zygote should be happy with that.
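The operator idea can be sketched without LinearOperators: a small type that stores the nodes, defines `*`, and returns a flagged copy of itself from `adjoint`. `NDFTOp` and `freqs` are hypothetical names, the products are direct sums rather than an NFFT, and this is not how MRIReco's NFFTOp is implemented; it only shows how an operator with a defined adjoint makes `A' * Δ` available to generic code and AD rules.

```julia
# Hypothetical minimal operator wrapper (not MRIReco's NFFTOp)
struct NDFTOp
    x::Vector{Float64}   # nodes in [-0.5, 0.5)
    N::Int               # number of image pixels
    adj::Bool            # true if this represents the adjoint operator
end
NDFTOp(x, N) = NDFTOp(x, N, false)

freqs(N) = -(N ÷ 2):(N - N ÷ 2 - 1)   # frequency indices of the image grid

# `op'` lowers to `adjoint(op)`: flip the flag, keep the geometry
Base.adjoint(A::NDFTOp) = NDFTOp(A.x, A.N, !A.adj)

function Base.:*(A::NDFTOp, v::AbstractVector)
    if !A.adj
        # forward: image (length N) -> samples (length M)
        return [sum(v[n] * cis(-2π * k * xj) for (n, k) in enumerate(freqs(A.N))) for xj in A.x]
    else
        # adjoint: samples (length M) -> image (length N)
        return [sum(v[j] * cis(2π * k * xj) for (j, xj) in enumerate(A.x)) for k in freqs(A.N)]
    end
end
```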

kiranshila commented 3 years ago

Will do! I'll keep you posted, and if it works, I'll submit a PR to add in Zygote support.

tknopp commented 3 years ago

Cool. NFFTOp is currently in a different package, but it would actually fit into NFFT.jl itself. It would pull in a dependency on LinearOperators, but that should not be a big deal.

kiranshila commented 3 years ago

Somewhat related to this, and to our discussion in the other issue on the existence of the inverse: is it a side effect of this algorithm that the magnitudes of the Fourier components in a "there and back" calculation are orders of magnitude bigger? In the example in your documentation, I would expect fHat and g to be identical, but the magnitudes of g are much higher. If I use the adjoint to make an image, the relative magnitudes are identical, but the result is scaled.

The reason this is important to me is that I am writing an optimization that matches the Fourier components of the data using a learned image with some regularization. However, my starting image is just the nfft_adjoint of the data. I would expect the difference between the Fourier components of this image and the data to be zero, but because of this scaling, it is not.

It seems as though handing the adjoint to Zygote does indeed work; now the problem is this scaling.
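For the direct (non-fast) transform the scaling has a simple explanation: every matrix entry has modulus one, so each diagonal entry of A^H A equals the number of samples M, and a round trip A^H (A f) is inflated by roughly that factor. Here is a minimal sketch with a dense 1D NDFT matrix standing in for the NFFT (`ndft_matrix` is a hypothetical helper; the NFFT approximates this matrix).

```julia
using LinearAlgebra

# Dense 1D NDFT matrix A[j, n] = exp(-2πi k_n x_j), k_n in -N/2:N/2-1 (hypothetical helper)
ndft_matrix(x, N) = [cis(-2π * k * xj) for xj in x, k in -(N ÷ 2):(N - N ÷ 2 - 1)]

M, N = 50, 16
x = rand(M) .- 0.5
A = ndft_matrix(x, N)

# Each diagonal entry of A'A is a sum of M terms |exp(...)|^2 = 1, i.e. exactly M
G = A' * A

f = randn(ComplexF64, N)
y = A * f
image0 = A' * y   # "starting image" from the adjoint alone
# image0 differs from f: it is scaled by about M, and the off-diagonal
# entries of A'A (nonuniform sampling) also mix the pixels.
```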

tknopp commented 3 years ago

Yes, you need to take scalings into account, and it often makes sense to properly "normalize" an operator. To give you an example: here we do exactly that in order to make the FFT unitary, which it is not in the standard definition: https://github.com/tknopp/SparsityOperators.jl/blob/master/src/FFTOp.jl#L41

But for the NFFT this is more complicated, as it involves what is usually called density compensation. Let A be the NFFT matrix. Then A^H A is in general not the identity matrix. Instead, what you want/need is A^H W A, where W is a diagonal matrix containing the squared density weights. It holds that A^H W A ≈ I if the weights are chosen appropriately. I therefore recommend that you use A^H W^{1/2} and W^{1/2} A as your transformation pair.
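In 1D the condition A^H W A ≈ I reads Σ_j w_j exp(2πi (k - l) x_j) ≈ δ_kl, a quadrature of ∫ exp(2πi (k - l) x) dx over [-0.5, 0.5), so the diagonal of W plays the role of quadrature weights. Here is a minimal sketch using a dense NDFT matrix and simple periodic Voronoi weights. Both `ndft_matrix` and `voronoi_weights` are hypothetical helpers, and Voronoi weighting is a swapped-in textbook technique, not necessarily what NFFT.jl's `sdc` computes.

```julia
using LinearAlgebra

# Dense 1D NDFT matrix A[j, n] = exp(-2πi k_n x_j), k_n in -N/2:N/2-1 (hypothetical helper)
ndft_matrix(x, N) = [cis(-2π * k * xj) for xj in x, k in -(N ÷ 2):(N - N ÷ 2 - 1)]

# Periodic 1D Voronoi weights on [-0.5, 0.5): half the distance between the left
# and right neighbours of each node, so the weights sum to 1 (illustrative only).
function voronoi_weights(x)
    p = sortperm(x)
    xs = x[p]
    M = length(xs)
    left  = [xs[mod1(j - 1, M)] - (j == 1 ? 1 : 0) for j in 1:M]   # wrap at -0.5
    right = [xs[mod1(j + 1, M)] + (j == M ? 1 : 0) for j in 1:M]   # wrap at +0.5
    w = similar(xs)
    w[p] = (right .- left) ./ 2      # scatter back to the original node order
    return w
end

M, N = 400, 16
x = rand(M) .- 0.5
A = ndft_matrix(x, N)
w = voronoi_weights(x)

E = A' * Diagonal(w) * A    # diagonal is exactly sum(w) = 1
residual = opnorm(E - I)    # nonzero for nonuniform nodes; typically shrinks with denser sampling

# Normalised pair: B = W^{1/2} A and B' = A^H W^{1/2}, so B' * B equals A^H W A
B = Diagonal(sqrt.(w)) * A
```

For equispaced nodes the weights are all 1/N and A^H W A is exactly the identity, which is the unitary-FFT case mentioned above.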

We actually also have a method in NFFT.jl to automatically calculate the density weights (called sdc I think).

kiranshila commented 3 years ago

Oh perfect! Thank you! That makes sense.

tknopp commented 3 years ago

Please have a look at this article: https://downloads.hindawi.com/journals/ijbi/2007/024727.pdf It is on MRI reconstruction but actually the initial formulas touch exactly what I described.

kiranshila commented 3 years ago

Yeah, it seems that MRI reconstruction is very similar to radio telescope imaging.

tknopp commented 2 years ago

Hi @roflmaostc,

I would like to make NFFT.jl AD-friendly but unfortunately don't yet have a deep understanding of how to write custom chain rules. I have seen that you were involved in the chain rules for AbstractFFTs: https://github.com/JuliaDiff/ChainRules.jl/issues/127 Could you help me with NFFT.jl?

If I get it right, this could probably be done at the level of AbstractNFFTs, allowing all implementations to benefit from it.

JeffFessler commented 2 years ago

I don't know about ChainRulesCore, but it is about 1 line of code using @adjoint in Flux / Zygote; see this example: https://jefffessler.github.io/SPECTrecon.jl/stable/generated/examples/6-dl/#Custom-backpropagation

Now that I RTFM I see that they recommend using ChainRulesCore: https://fluxml.ai/Zygote.jl/dev/adjoints/

I usually wrap my NFFT calls in a LinearMap so I think I will look into a chain rule for that. (I looked at LinearOperators.jl and didn't see a dependency on ChainRules there BTW.)

tknopp commented 2 years ago

Yes I read that manual as well and with AbstractFFTs they did basically the same: https://github.com/JuliaMath/AbstractFFTs.jl/pull/58

If I look at the LinearMaps example, it could look something like this:

```julia
using ChainRulesCore
using AbstractNFFTs: AbstractNFFTPlan

function ChainRulesCore.rrule(::typeof(*), A::AbstractNFFTPlan, x::AbstractArray)
    y = A * x
    function pullback(dy)
        DY = unthunk(dy)
        # Because A is an abstract map, the product is only differentiable w.r.t. the input
        return NoTangent(), NoTangent(), @thunk(A' * DY)
    end
    return y, pullback
end
```
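Whether the pullback `A' * DY` gives the right gradient can be checked numerically without any AD package. This sketch uses a dense NDFT matrix as a stand-in for the plan (`ndft_matrix` is a hypothetical helper), the real loss `sum(abs2, A * f)`, and the ChainRules convention that the gradient of a real loss w.r.t. a complex input is ∂L/∂Re + i ∂L/∂Im. It chains the `abs2` cotangent (2y) through `A' * dy` and compares against central finite differences.

```julia
using LinearAlgebra

# Dense 1D NDFT matrix A[j, n] = exp(-2πi k_n x_j), k_n in -N/2:N/2-1 (hypothetical helper)
ndft_matrix(x, N) = [cis(-2π * k * xj) for xj in x, k in -(N ÷ 2):(N - N ÷ 2 - 1)]

A = ndft_matrix(rand(12) .- 0.5, 6)
f = randn(ComplexF64, 6)

loss(f) = sum(abs2, A * f)      # real-valued loss of the transform output

# Reverse mode as in the rrule: the cotangent of sum(abs2, y) is 2y,
# and the pullback through y = A * f is A' * dy.
y = A * f
grad = A' * (2 .* y)

# Central finite differences on the real and imaginary part of each pixel
h = 1e-6
fd = map(1:length(f)) do n
    e = zeros(ComplexF64, length(f))
    e[n] = 1
    d_re = (loss(f + h * e) - loss(f - h * e)) / (2h)
    d_im = (loss(f + im * h * e) - loss(f - im * h * e)) / (2h)
    complex(d_re, d_im)
end
```

Because the loss is quadratic, central differences have no truncation error here, so the two results agree up to rounding.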

I am not sure if the same works for mul!.

I like that the LinearMaps PR has a test for the chain rules. Flux is a pretty heavy dependency; I'm not sure if that is really needed, or if there are easier ways to test the rrule.

migrosser commented 1 year ago

This issue should be fixed by commit 8552b281341e68ab101fd9b0555e4218b7e562d9, where we implemented ChainRulesCore.frule and ChainRulesCore.rrule. Therefore, I will close it for now.