y1xiaoc / fwdlap

Forward mode laplacian implemented in JAX tracer
https://github.com/y1xiaoc/fwdlap
Apache License 2.0

Return output not summed #4

Open jwnys opened 2 months ago

jwnys commented 2 months ago

Hi, first of all, thanks a lot for the nice implementation! I was wondering whether the output could also be the individual $\partial^2 f / \partial x_i^2$, rather than only the version that immediately sums over $i$. This would be very useful when computing, e.g., Laplace-Beltrami operators.
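For reference, the quantity fwdlap currently returns and the unsummed vector requested here differ only by the final sum. Each diagonal entry is a second directional derivative along one coordinate axis, with $H_f$ the Hessian of $f$ and $e_i$ the $i$-th unit vector:

```latex
\nabla^2 f(x) = \sum_{i=1}^{n} \frac{\partial^2 f}{\partial x_i^2}(x),
\qquad
\frac{\partial^2 f}{\partial x_i^2}(x) = e_i^\top H_f(x)\, e_i ,
\quad i = 1, \dots, n
```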

y1xiaoc commented 2 months ago

Hi, and thanks for the issue. Yes, that would be possible. Check the kinetic example in the README: the branch that uses inner_size works by calculating each section of the vector $\partial_{ii}f$ and then summing them up. Setting inner_size=1 and removing the summation should do the trick. The key code is the following (note that this is the version with summation):

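import fwdlap
from jax import lax

# f, flat_x, ncoord, inner_size, eye (identity matrix) and zero (zero input
# laplacian) are defined as in the kinetic example in the README
eye = eye.reshape(ncoord//inner_size, inner_size, ncoord)  # chunks of inner_size directions
primals, f_lap_pe = fwdlap.lap_partial(f, (flat_x,), (eye[0],), (zero,))  # primal evaluated once
def loop_fn(i, val):
    (jac, lap) = f_lap_pe((eye[i],), (zero,))  # lap is already summed over this chunk
    val += (jac**2).sum() + lap  # this line sums them up
    return val
laplacian = lax.fori_loop(0, ncoord//inner_size, loop_fn, 0.0)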
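To make the "remove the summation" step concrete, the loop can write each chunk into an output vector of length ncoord instead of accumulating a scalar. The sketch below keeps the same chunked `fori_loop` structure, but to stay self-contained it swaps fwdlap's `f_lap_pe` for plain `jax.jvp` of `jax.grad` per basis vector (an assumption, not fwdlap's propagation). `hessian_diag_chunked` is a hypothetical name. Because fwdlap's `lap` is already summed within a chunk, the fwdlap version needs inner_size=1 as suggested above; here the `vmap` keeps every direction separate, so any inner_size that divides ncoord works.

```python
import jax
import jax.numpy as jnp
from jax import lax


def hessian_diag_chunked(f, flat_x, inner_size=1):
    """Vector of d^2 f / d x_i^2 for a scalar f, computed chunk by chunk.

    Hypothetical helper: plain-JAX stand-in for the fwdlap loop above.
    """
    ncoord = flat_x.shape[0]
    assert ncoord % inner_size == 0, "inner_size must divide ncoord"
    nchunk = ncoord // inner_size
    eye = jnp.eye(ncoord, dtype=flat_x.dtype).reshape(nchunk, inner_size, ncoord)
    grad_f = jax.grad(f)

    def dir2(v):
        # v^T H v via forward-over-reverse; stands in for fwdlap's f_lap_pe (assumption)
        return jnp.vdot(jax.jvp(grad_f, (flat_x,), (v,))[1], v)

    def loop_fn(i, diag):
        chunk = jax.vmap(dir2)(eye[i])  # shape (inner_size,), one entry per direction
        # write the chunk into the output instead of summing it into a scalar
        return lax.dynamic_update_slice(diag, chunk, (i * inner_size,))

    init = jnp.zeros(ncoord, dtype=flat_x.dtype)
    return lax.fori_loop(0, nchunk, loop_fn, init)
```

Summing the returned vector recovers the plain Laplacian, so the summed and unsummed outputs can share one loop.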

In the meantime, if you are looking to compute a general second-order differential operator like $\sum_{ij} a_{ij} \partial_{ij}$, that is also possible with fwdlap (perhaps with a small modification). Check this paper by the authors of the original forward Laplacian paper, Section 2.2, Eqs. 7-9. Note that the positive definite case can already be handled by the current version of fwdlap ($D$ is the identity). Let me know if you need the general version that handles the signs, and I could make a PR for that.
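Why the positive definite case fits the existing machinery: with $A = LL^\top$ (Cholesky), $\sum_{ij} a_{ij} \partial_{ij} f = \operatorname{tr}(A H_f) = \operatorname{tr}(L^\top H_f L) = \sum_k l_k^\top H_f\, l_k$, where $l_k$ is the $k$-th column of $L$. Each term is again a second directional derivative, so presumably the same forward pass applies with the columns of $L$ as input tangents in place of `eye`; that is a reading of the comment above, not something confirmed in this thread. A minimal plain-JAX sketch of the identity, using `jax.grad`/`jax.jvp` rather than fwdlap (the name `second_order_operator` is hypothetical):

```python
import jax
import jax.numpy as jnp


def second_order_operator(f, x, a):
    """sum_ij a_ij d^2 f / dx_i dx_j for symmetric positive definite a.

    Hypothetical helper. Factor a = L L^T, then tr(a H) = sum_k l_k^T H l_k
    with l_k the k-th column of L.
    """
    chol = jnp.linalg.cholesky(a)  # lower-triangular L with a = L @ L.T
    grad_f = jax.grad(f)

    def dir2(v):
        # v^T H v by forward-over-reverse (plain JAX, not fwdlap)
        return jnp.vdot(jax.jvp(grad_f, (x,), (v,))[1], v)

    # rows of chol.T are the columns l_k of chol
    return jnp.sum(jax.vmap(dir2)(chol.T))
```

With `a = jnp.eye(n)` this reduces to the ordinary Laplacian, which matches the "$D$ is the identity" remark.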

jwnys commented 2 months ago

Hi, thanks for responding so quickly. It would indeed be great to just expose this unsummed loop. Does it have any computational disadvantages? What I specifically would like to have is an implementation of the spherical Laplacian (https://en.wikipedia.org/wiki/Laplace%E2%80%93Beltrami_operator#Spherical_Laplacian). For example, the spherical harmonics $Y_l^m$ are eigenfunctions of this operator. So it's not exactly of the form you refer to, but since the gradient is available as well, I guess it could be constructed from that.

PS: it would be sufficient to have the unsummed output, together with the gradients (jac), to reconstruct the spherical Laplacian.
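The PS can be made concrete: in coordinates $(\theta, \phi)$ on the unit sphere, $\Delta_S f = \partial_{\theta\theta} f + \cot\theta\, \partial_\theta f + \partial_{\phi\phi} f / \sin^2\theta$, so only the unsummed second derivatives and one gradient component are needed. A minimal sketch, computing those pieces with `jax.grad`/`jax.jvp` instead of fwdlap (`spherical_laplacian` is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def spherical_laplacian(f, theta, phi):
    """Laplace-Beltrami operator on the unit sphere for a scalar f(theta, phi).

    Hypothetical helper: f_thth + cot(theta) f_th + f_phph / sin(theta)**2.
    """
    x = jnp.array([theta, phi])
    g = lambda y: f(y[0], y[1])
    grad_g = jax.grad(g)
    # unsummed Hessian diagonal e_i^T H e_i (plain JAX stand-in for fwdlap)
    eye = jnp.eye(2, dtype=x.dtype)
    diag = jax.vmap(lambda v: jnp.vdot(jax.jvp(grad_g, (x,), (v,))[1], v))(eye)
    jac = grad_g(x)  # first derivatives (f_th, f_ph)
    return diag[0] + jac[0] / jnp.tan(theta) + diag[1] / jnp.sin(theta) ** 2
```

As a check, spherical harmonics satisfy $\Delta_S Y_l^m = -l(l+1)\, Y_l^m$; for example $3\cos^2\theta - 1$ (proportional to $Y_2^0$) should come back multiplied by $-6$.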

y1xiaoc commented 2 months ago

Yes, the loop version for unsummed output would indeed be a bit slower. In fact, in this case it reduces to JAX's own jet implementation, although fwdlap might still be a bit faster thanks to its handling of symbolic zeros.
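With inner_size=1, each loop step is essentially one second-order Taylor-mode pass along a single coordinate, which is why it reduces to jet. For comparison, here is a sketch of that per-coordinate loop written directly against `jax.experimental.jet` rather than fwdlap (`hessian_diag_jet` is a hypothetical name). It makes one pass per coordinate instead of one vectorized pass over all of them.

```python
import jax.numpy as jnp
from jax.experimental.jet import jet


def hessian_diag_jet(f, x):
    """Unsummed d^2 f / d x_i^2 via one second-order jet pass per coordinate.

    Hypothetical helper. Along the path x + t * e_i, jet's second output
    coefficient is d^2/dt^2 f(x + t e_i) at t = 0, i.e. e_i^T H e_i.
    """
    zeros = jnp.zeros_like(x)
    diag = []
    for i in range(x.shape[0]):
        e_i = zeros.at[i].set(1.0)
        # input series: first derivative e_i, second derivative zero
        _, (_, second) = jet(f, (x,), ((e_i, zeros),))
        diag.append(second)
    return jnp.stack(diag)
```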

That said, I do agree it would be nice to expose this (along with the summed Laplacian) as a high-level API (as mentioned in #2). The current API is rather low level and is mainly meant for developing packages rather than for end users. PRs are welcome as well!