ahbarnett opened 9 months ago
I think this would be great. To me, one of the selling points of NUFFTs (well, NUDFTs) is that one can do things like this relatively easily since we have an exact formula for what we're trying to compute. So having a forward model based on NUFFTs means that we get all of the derivatives “for free”.
So just to make sure I'm following: what you're saying is that for type 1 we provide the NU points along with displacements (directions) for each of those NU points, then return the value of the transform on the grid (the type 1 transform) along with its directional derivative (so 2N values if we have N modes). That makes sense. I'm not following the adjoint, though. Shouldn't we similarly get the type 2 transform along with the directional derivative at each of the target points (so 2M values if we have M points)? How does this become f_j, fx_j and fy_j?
One way to avoid a big interface redesign (adding new functions, etc.) is to add gradient flags to the opts and then adopt a convention where we interleave values with their gradients (or directional displacements). So instead of specifying the NU points in an array of size M, we'd have an array of size 4M ordered as [x_1, y_1, dx_1, dy_1, x_2, y_2, dx_2, dy_2, …]. Similarly, the output for type 1 would be of size 2N and ordered as [f_1, df_1, f_2, df_2, …]. Of course, the Matlab/Octave/Python/Julia interfaces can make this all more digestible to high-level users. The alternative would be to add new functions such as setpts_with_grad and execute_with_grad, which is maybe not so bad. All the other functions could stay the same.
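A minimal numpy sketch of the interleaved convention proposed above, assuming 2D and real-valued point data; the helper names are hypothetical and not part of any FINUFFT interface. It shows that packing and unpacking are just reshapes, and that the unpacked arrays come back as strided views, which is the access pattern the C side would have to cope with.

```python
import numpy as np

def interleave_points_with_directions(x, y, dx, dy):
    """Pack per-point data into one length-4M array ordered
    [x_1, y_1, dx_1, dy_1, x_2, y_2, dx_2, dy_2, ...] (hypothetical helper)."""
    return np.stack([x, y, dx, dy], axis=1).reshape(-1)

def deinterleave(packed, width):
    """Split an interleaved array into `width` separate arrays, e.g. width=2
    for [f_1, df_1, f_2, df_2, ...]. The results are strided views, not copies."""
    return tuple(packed.reshape(-1, width).T)

packed = interleave_points_with_directions(
    np.array([1.0, 2.0]), np.array([3.0, 4.0]), np.array([5.0, 6.0]), np.array([7.0, 8.0]))
f, df = deinterleave(np.array([10.0, 11.0, 20.0, 21.0]), 2)
```

The trade-off this makes visible: interleaving keeps the existing function signatures, while separate entry points such as setpts_with_grad and execute_with_grad could keep each array contiguous.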
Hi Joakim, actually for type 1 there'd be only one output (since there's one FFT). You supply values, dx and dy (i.e. 3M strength data). This corresponds to computing the model plus a given vector dotted with the gradient of the model, e.g. in the pytorch setting (actually I don't know what pytorch needs :). The adjoint (type 2) takes one uniform grid, does one FFT, and returns 3M data (vals, dx, dy) at the NU pts.
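A brute-force reference makes the one-output/3M-output distinction concrete. The sketch below is a direct O(NM) 2D NUDFT, not FINUFFT; the function names, the sign convention (+i for type 1, -i for type 2) and the centred mode ordering are assumptions chosen for illustration. Type 1 takes (c, dx, dy) and returns one grid; its adjoint takes one grid and returns the value and gradient at each NU point.

```python
import numpy as np

def mode_grid(N1, N2):
    """Integer frequencies k1 in [-N1/2, N1/2) and k2 in [-N2/2, N2/2), shape (N1, N2)."""
    k1 = np.arange(-(N1 // 2), N1 - N1 // 2)
    k2 = np.arange(-(N2 // 2), N2 - N2 // 2)
    return np.meshgrid(k1, k2, indexing="ij")

def type1_with_dipoles(x, y, c, dx, dy, N1, N2):
    """One output grid: F_k = sum_j (c_j + i(k1 dx_j + k2 dy_j)) exp(i(k1 x_j + k2 y_j))."""
    K1, K2 = mode_grid(N1, N2)
    K1, K2 = K1[..., None], K2[..., None]        # broadcast against the M points
    phase = np.exp(1j * (K1 * x + K2 * y))       # shape (N1, N2, M)
    return np.sum((c + 1j * (K1 * dx + K2 * dy)) * phase, axis=-1)

def type2_with_grad(x, y, F):
    """3M outputs: u_j = sum_k F_k exp(-i k.x_j) and its gradient (ux_j, uy_j) at x_j."""
    K1, K2 = mode_grid(*F.shape)
    K1, K2 = K1[..., None], K2[..., None]
    terms = F[..., None] * np.exp(-1j * (K1 * x + K2 * y))  # shape (N1, N2, M)
    u = terms.sum(axis=(0, 1))
    ux = (-1j * K1 * terms).sum(axis=(0, 1))     # d/dx_j brings down -i k1
    uy = (-1j * K2 * terms).sum(axis=(0, 1))
    return u, ux, uy

# Adjoint check: <F, A(c, dx, dy)> equals <u, c> + <ux, dx> + <uy, dy>.
rng = np.random.default_rng(0)
M, N1, N2 = 7, 6, 5
x, y = rng.uniform(-np.pi, np.pi, (2, M))
c, dx, dy = rng.standard_normal((3, M)) + 1j * rng.standard_normal((3, M))
F = rng.standard_normal((N1, N2)) + 1j * rng.standard_normal((N1, N2))
u, ux, uy = type2_with_grad(x, y, F)
lhs = np.vdot(F, type1_with_dipoles(x, y, c, dx, dy, N1, N2))
rhs = np.vdot(u, c) + np.vdot(ux, dx) + np.vdot(uy, dy)
assert np.isclose(lhs, rhs)
```

Calling type1_with_dipoles with zero values and dipoles (c_j δx_j, c_j δy_j) gives the directional derivative of the plain type 1 output along the point displacements (δx_j, δy_j), which is the "model plus vector dot grad" case.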
I like your idea of adding an opt which causes it to read 3M instead of M location data.
@eickenberg @wardbrian this is what we discussed over coffee; any comments on whether the above would satisfy the pytorch needs are welcome. This is something to implement later this year.
I think there would certainly be users if this were implemented. In the meantime, @eickenberg and I still plan on writing out the derivatives in terms of other transforms, but if I understand correctly, these are usually about 3x as large as the one you're describing here.
I've just seen this because Alex mentioned it over on the jax-finufft repo. It's not immediately obvious that this would help the JAX implementation because of JAX's opinions about autodiff, but I'll have a think.
@WardBrian, I'd be happy to chat about how we've implemented our backprop for JAX. We batch the operations so that the backward pass requires only a single call into finufft, plus a single call in the forward pass. The implementation is here, but it's definitely a little opaque, especially if you're not so familiar with JAX internals!
@dfm I think that is similar to what we currently do in pytorch_finufft, except in the case where you are requesting derivatives with respect to both the NU points and the values, in which case our reverse pass does 2 transforms.
@eickenberg and I are going to chat soon about our type 3 reverse code, so it may also be a good time to take a harder look at sharing work for type 1 and type 2. It's somewhat clear to me how we could do this for type 1, but type 2 seems a bit trickier, at least as we currently have it implemented. It would be good to chat more.
Leslie and I realised that a quite useful enhancement would be a new interface that computes directional derivatives w.r.t. NU source or target locations. This would involve a few more spread/interp flops, but no increase in FFT effort. If the FFT dominates, it would be roughly (d+1) times faster in dimension d than computing d+1 separate transforms (or vectorized transforms, which are admittedly faster), which is currently the only way to get such derivatives. This would speed up (by up to a factor of d+1) the pytorch/jax/TF interfaces, which need these gradients w.r.t. NU pt locations. Leslie would also use it for double-layer evals in heat-split methods for Laplace BVPs. The disadvantage is a more general plan interface; the original would be kept, but some new wrappers would be made available to advanced users (in C/Fort to begin with).
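A back-of-envelope cost model, which is our assumption rather than a measurement, shows where the (d+1) figure comes from. Let t_FFT be the FFT time and t_sp the spread (or interp) time of one ordinary transform, and let α be the fractional extra spread/interp work of the gradient version. The speedup over d+1 separate transforms is then

```latex
S \;=\; \frac{(d+1)\,\bigl(t_{\mathrm{FFT}} + t_{\mathrm{sp}}\bigr)}{t_{\mathrm{FFT}} + (1+\alpha)\,t_{\mathrm{sp}}}
\;\longrightarrow\; d+1 \quad \text{as } t_{\mathrm{sp}}/t_{\mathrm{FFT}} \to 0,
\qquad
S \;\longrightarrow\; \frac{d+1}{1+\alpha} \quad \text{as } t_{\mathrm{FFT}}/t_{\mathrm{sp}} \to 0 .
```

For d = 2 the FFT-dominated limit is 3, the figure to be tested in 2D.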
Consider 2D type 1. The input strengths c_j would be supplemented by dipole strength components dx_j and dy_j; the outputs would be the same. The spread step evaluates c_j p(x) p(y) + dx_j p'(x) p(y) + dy_j p(x) p'(y), where p(.) is the spreading kernel. This needs Horner to also spit out the derivative of the poly p(x), which is just a few more flops. The writing to the fine grid is the same, so the overall speed should be similar.
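A minimal sketch of that spread step for a single source, assuming a toy polynomial kernel p(t) = (1 - t^2)^2 in place of FINUFFT's actual kernel, a periodic fine grid of spacing h, and hypothetical helper names. The offset is taken as source minus grid point, so p' is the derivative with respect to the source location with no sign flip.

```python
import numpy as np

# Toy kernel (hypothetical): p(t) = (1 - t^2)^2 on |t| <= 1, coefficients highest degree
# first. FINUFFT's actual kernel and its polynomial approximation are different.
KERNEL_COEFFS = (1.0, 0.0, -2.0, 0.0, 1.0)
HALF_WIDTH = 2  # support radius in grid spacings (hypothetical)

def horner_with_derivative(coeffs, t):
    """Evaluate p(t) and p'(t) together in one Horner pass (one extra multiply-add per term)."""
    p = np.zeros_like(t, dtype=float)
    dp = np.zeros_like(t, dtype=float)
    for a in coeffs:
        dp = dp * t + p  # derivative update uses p from the previous step
        p = p * t + a
    return p, dp

def kernel_1d(xj, h, n):
    """Wrapped grid indices near xj, kernel values there, and their d/dxj derivatives."""
    i0 = int(np.floor(xj / h))
    idx = np.arange(i0 - HALF_WIDTH + 1, i0 + HALF_WIDTH + 1)
    scale = HALF_WIDTH * h
    t = (xj - idx * h) / scale  # source-minus-grid offset mapped to [-1, 1]
    p, dp = horner_with_derivative(KERNEL_COEFFS, t)
    inside = np.abs(t) <= 1.0
    return idx % n, p * inside, dp * inside / scale  # chain rule: dt/dxj = 1/scale

def spread_point_with_dipole(grid, xj, yj, c, dx, dy, h):
    """Add c p(x)p(y) + dx p'(x)p(y) + dy p(x)p'(y) for one source to grid, in place."""
    ix, px, dpx = kernel_1d(xj, h, grid.shape[0])
    iy, py, dpy = kernel_1d(yj, h, grid.shape[1])
    block = c * np.outer(px, py) + dx * np.outer(dpx, py) + dy * np.outer(px, dpy)
    rows, cols = np.meshgrid(ix, iy, indexing="ij")
    # Same fine-grid writes as plain spreading; add.at accumulates repeated wrapped indices.
    np.add.at(grid, (rows.ravel(), cols.ravel()), block.ravel())
```

Only the 1D kernel evaluations change; the block written to the grid has the same footprint as a plain spread. A finite-difference check (a unit x-dipole against two monopoles shifted by ±ε) is a quick way to validate a real implementation.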
Its adjoint is a 2D type 2 that spits out f_j, fx_j, and fy_j, the value and grad at each target. (Anna-Karin has asked for this before, I think.) Again, this would avoid the need for a stack of ntransf=3 transforms.
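The matching interpolation step can be sketched the same way, under the same assumptions (toy kernel, periodic grid, hypothetical helper names). In a full type 2 the fine grid would first be filled by deconvolution and an FFT; the sketch covers only the interp stage, which reads each nearby fine-grid value once and produces all three outputs f_j, fx_j and fy_j.

```python
import numpy as np

# Toy kernel (hypothetical): p(t) = (1 - t^2)^2 on |t| <= 1, highest degree first.
KERNEL_COEFFS = (1.0, 0.0, -2.0, 0.0, 1.0)
HALF_WIDTH = 2  # support radius in grid spacings (hypothetical)

def horner_with_derivative(coeffs, t):
    """Evaluate p(t) and p'(t) together in one Horner pass."""
    p = np.zeros_like(t, dtype=float)
    dp = np.zeros_like(t, dtype=float)
    for a in coeffs:
        dp = dp * t + p
        p = p * t + a
    return p, dp

def kernel_1d(xj, h, n):
    """Wrapped grid indices near xj, kernel values there, and their d/dxj derivatives."""
    i0 = int(np.floor(xj / h))
    idx = np.arange(i0 - HALF_WIDTH + 1, i0 + HALF_WIDTH + 1)
    scale = HALF_WIDTH * h
    t = (xj - idx * h) / scale  # target-minus-grid offset mapped to [-1, 1]
    p, dp = horner_with_derivative(KERNEL_COEFFS, t)
    inside = np.abs(t) <= 1.0
    return idx % n, p * inside, dp * inside / scale  # chain rule: dt/dxj = 1/scale

def interp_point_with_grad(grid, xj, yj, h):
    """Return (f_j, fx_j, fy_j): interpolated value and its gradient w.r.t. (xj, yj)."""
    ix, px, dpx = kernel_1d(xj, h, grid.shape[0])
    iy, py, dpy = kernel_1d(yj, h, grid.shape[1])
    local = grid[np.ix_(ix, iy)]  # the same fine-grid reads as plain interpolation
    return px @ local @ py, dpx @ local @ py, px @ local @ dpy
```

Compared with three stacked transforms, the grid reads are shared and only the per-point kernel contractions are repeated.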
An overall flag would switch this on/off, determining whether to read the dx_j, dy_j arrays (type 1) or write to the fx_j, fy_j arrays (type 2). It would switch between the existing spread/interp inner functions and new ones with the gradient built in.
I dread to think what a type-3 version would look like (but ML interfaces to type 3 have the same complexity: grad_source and grad_target outputs would be expected).
It should take a couple of days to try this out in 2D and get a sense of whether the speedup is close to 3x.
Do we think there'd be enough users of this?