alexander-held closed this 8 months ago
Another autodiff issue to self-assign, @Saransh-cpp. Thanks!
Thanks for the tags, self-assigned!
Hi @alexander-held, I've been looking at this issue, and it seems to be more of a new feature request for the JAX backend.
The implementation (_impl) of ak.mean calls the _impl of ak.count, but ak.count is not implemented for the JAX backend. I traced back the history of this file, and it looks like this intentional error has always been there.
@jpivarski, were count, argmin, argmax, and count_nonzero not implemented for the JAX backend because they were not feasible for some reason, or were they just left for the future? Would adding their implementations be a good starting point for my project? I can work on these this week.
I vaguely remember some past conversations about which kinds of derivatives we might want to be able to evaluate and which might not be as useful. For discrete things like ak.count, I believe we would need some relaxation to define a derivative, but we might not need to support cases where the number of elements over which we take the mean changes. Since the mean is the sum over the elements divided by the count, the derivative I was ultimately after in the code snippet above is just the derivative of the sum followed by a division by a constant.
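Written out, assuming each element x_i depends on some parameter θ (notation introduced here only for illustration) while the count n does not, the derivative of the mean is the derivative of the sum divided by the constant n:

```latex
\bar{x}(\theta) = \frac{1}{n}\sum_{i=1}^{n} x_i(\theta)
\qquad\Longrightarrow\qquad
\frac{\partial \bar{x}}{\partial \theta}
  = \frac{1}{n}\,\frac{\partial}{\partial \theta}\sum_{i=1}^{n} x_i(\theta)
  = \frac{1}{n}\sum_{i=1}^{n} \frac{\partial x_i}{\partial \theta}
```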
As soon as the number of elements changes (which might happen in practice with selection cuts), handling that seems like a broader issue that the user might need to take care of externally. I'll point some more people towards this issue to invite other opinions on what might be most useful.
@jpivarski, were count, argmin, argmax, and count_nonzero not implemented for the JAX backend because they were not feasible for some reason, or were they just left for the future?
These seem to be fundamentally non-differentiable, at least in my understanding. ak.count in particular depends only on the array's structure, and we don't even represent the array structure (the ak.index.Index parts) using JAX because you can't differentiate through that. It would be equivalent to differentiating through variables that are used in an if predicate. It's less obvious for ak.count_nonzero (the delta function is differentiable in Lebesgue measure but not Riemannian measure; I don't think that's relevant) and ak.argmin/ak.argmax.
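To illustrate how plain JAX treats these discrete operations, here is a minimal sketch using ordinary jax.numpy arrays (not Awkward's JAX backend; the helper names are hypothetical). A count built from a comparison is piecewise constant, so jax.grad returns a zero gradient, while argmax returns an integer index, which jax.grad refuses to differentiate at all.

```python
import jax
import jax.numpy as jnp

def count_nonzero_float(x):
    # Number of nonzero entries, cast to float so that jax.grad accepts the output.
    # The comparison x != 0 has a zero derivative, so the result is piecewise constant.
    return jnp.sum((x != 0).astype(x.dtype))

def argmax_index(x):
    # Returns an integer index; there is no meaningful derivative to propagate.
    return jnp.argmax(x)

x = jnp.array([0.0, 1.5, -2.0, 0.0])

count = count_nonzero_float(x)
count_grad = jax.grad(count_nonzero_float)(x)  # all zeros
print(count, count_grad)

argmax_grad_error = None
try:
    jax.grad(argmax_index)(x)
except TypeError as err:
    # jax.grad only accepts real- or complex-valued (inexact) outputs.
    argmax_grad_error = err
print("argmax differentiable:", argmax_grad_error is None)
```

A zero gradient is technically defined but carries no information, which is one way to see why a relaxation would be needed before such derivatives become useful.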
However, differentiable arrays should carry through some of this. For instance, ak.count should be treated as a constant, so that ak.mean can be implemented by using the implementation of ak.sum and dividing the result by the constant that comes out of ak.count. (The derivative of ak.mean is a scaled version of the derivative of ak.sum.) I don't know about the others, ak.count_nonzero and ak.argmin/ak.argmax, though.
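A minimal JAX sketch of the "treat the count as a constant" idea, assuming a padded 2D array plus a boolean mask as a hypothetical stand-in for a jagged array (this is not how Awkward's JAX backend stores data, and masked_row_mean is an illustrative helper, not an Awkward function). The mean is the masked sum divided by a count that depends only on the mask:

```python
import jax
import jax.numpy as jnp

def masked_row_mean(values, mask):
    """Mean of each row, counting only the slots where mask is True."""
    # Sum of the real entries; padded slots contribute 0 and receive 0 gradient.
    total = jnp.sum(jnp.where(mask, values, 0.0), axis=1)
    # The count depends only on the structure (the mask), never on the values,
    # so it is constant with respect to them; stop_gradient makes that explicit.
    count = jax.lax.stop_gradient(jnp.sum(mask, axis=1).astype(values.dtype))
    return total / count

# Two jagged rows, [1, 2, 3] and [4, 5], padded to length 3.
values = jnp.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 0.0]])
mask = jnp.array([[True, True, True],
                  [True, True, False]])

def loss(v):
    # Scalar objective so that jax.grad applies: the sum of the row means.
    return jnp.sum(masked_row_mean(v, mask))

means = masked_row_mean(values, mask)
grads = jax.grad(loss)(values)
print(means)
print(grads)
```

Each real entry receives a gradient of 1/n for its row, i.e. the derivative of the sum scaled by the constant count, and the padded slot receives 0.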
Version of Awkward Array
main branch
Description and code to reproduce
This is a follow-up to #2591 with a slightly simpler setup. It should be conceptually possible to differentiate through taking a mean, but currently this does not work.
Reproducer:
Result:
A standalone JAX version of taking a mean works fine: