scikit-hep / awkward

Manipulate JSON-like data with NumPy-like idioms.
https://awkward-array.org
BSD 3-Clause "New" or "Revised" License

Differentiating through an ak.mean #2638

Closed: alexander-held closed this issue 8 months ago

alexander-held commented 1 year ago

Version of Awkward Array

main branch

Description and code to reproduce

This is a follow-up to #2591 with a slightly simpler setup. It should be conceptually possible to differentiate through taking a mean, but currently this does not work.

Reproducer:

import awkward as ak
import jax
import uproot

ak.jax.register_and_check()

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

def mean_jet_pt(jets):
    return ak.mean(jets.pt)

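with uproot.open(ttbar_file) as f:
    # Read the jet kinematics and keep only events with at least two jets
    arr = f["Events"].arrays(["Jet_pt", "Jet_eta", "Jet_phi", "Jet_mass"])
    evtfilter = ak.num(arr["Jet_pt"]) >= 2
    # Rename the fields to pt/eta/phi/mass and zip them into jet records
    jets = ak.zip(dict(zip(["pt", "eta", "phi", "mass"], ak.unzip(arr))), with_name="Momentum4D")[evtfilter]
    # Move the array to the JAX backend so that it can be differentiated
    jets = ak.to_backend(jets, "jax")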

jax.value_and_grad(mean_jet_pt, argnums=0)(jets)

Result:

RuntimeError: Cannot differentiate through count_zero

This error occurred while calling

    ak.mean(
        <Array [[...], [...], ..., [...], [...]] type='140 * var * float32'>
    )

A standalone JAX version of taking a mean works fine:

import jax
import jax.numpy as jnp

def mean(j):
    return jnp.mean(j)

data = jnp.array([1, 7, 3, 5], dtype=float)

jax.value_and_grad(mean, argnums=0)(data)

jpivarski commented 10 months ago

Another autodiff issue to self-assign, @Saransh-cpp. Thanks!

Saransh-cpp commented 10 months ago

Thanks for the tags, self-assigned!

Saransh-cpp commented 9 months ago

Hi @alexander-held, I've been looking at this issue, and it seems to be more of a new feature request for the JAX backend.

The implementation (_impl) of ak.mean calls the _impl of ak.count:

https://github.com/scikit-hep/awkward/blob/c1e4f9f6e992cf0524756df46ba4f4167ea29239/src/awkward/operations/ak_mean.py#L195

but ak.count is not implemented for the JAX backend:

https://github.com/scikit-hep/awkward/blob/c1e4f9f6e992cf0524756df46ba4f4167ea29239/src/awkward/_connect/jax/reducers.py#L91-L114

I traced back the history of this file, and it looks like the intentional error has always been there.

@jpivarski, were count, argmin, argmax, and count_nonzero not implemented for the JAX backend because they were not feasible for some reason, or were they just left for the future? Would adding their implementations be a good starting point for my project? I can work on these this week.

alexander-held commented 9 months ago

I vaguely remember some past conversations about which kinds of derivatives we might want to be able to evaluate and which might not be as useful. For discrete quantities like ak.count, I believe we would need some relaxation to define a derivative, but we might not need to support cases where the number of elements over which we take the mean changes. Since the mean is the sum over the elements divided by the count, the derivative I was ultimately after in the code snippet above is just the derivative of the sum, divided by a constant.
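
The "derivative of the sum divided by a constant" relationship can be checked in plain JAX on a flat array (not an Awkward Array): with N elements, every element of the gradient of the mean is 1/N, exactly the gradient of the sum scaled by 1/N. A minimal sketch:

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 7.0, 3.0, 5.0])
n = data.shape[0]  # element count: fixed by the array's shape, not by its values

# Gradient of the mean, computed directly by JAX.
grad_mean = jax.grad(jnp.mean)(data)

# Gradient of the sum, then divided by the (constant) count.
grad_sum_scaled = jax.grad(jnp.sum)(data) / n

print(grad_mean)        # each entry is 1/n = 0.25
print(grad_sum_scaled)  # same values
```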

As soon as the number of elements changes (which can happen in practice with selection cuts), handling that seems like a broader issue that the user might need to take care of externally. I'll point some more people towards this issue to invite other opinions on what might be most useful.
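
One way a user might handle a selection cut externally is to apply it as a mask, so the count becomes the number of selected elements. In JAX, comparisons have a zero derivative, so the mask (and the count built from it) acts as a constant; cut elements get zero gradient and selected ones get 1/N_selected. This is a sketch in plain JAX under that assumption; the function masked_mean and the threshold pt_cut are hypothetical, and the discontinuity at the threshold itself is simply ignored rather than relaxed:

```python
import jax
import jax.numpy as jnp

def masked_mean(pt, pt_cut=20.0):
    # Selection cut; the comparison has a zero derivative in JAX,
    # so the mask (and therefore the count) is treated as constant.
    mask = (pt > pt_cut).astype(pt.dtype)
    count = jnp.sum(mask)
    return jnp.sum(pt * mask) / count

pt = jnp.array([15.0, 25.0, 40.0, 10.0])
value, grad = jax.value_and_grad(masked_mean)(pt)
print(value)  # (25 + 40) / 2 = 32.5
print(grad)   # 0 for the cut elements, 1/2 for the two selected ones
```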

jpivarski commented 9 months ago

> @jpivarski, were count, argmin, argmax, and count_nonzero not implemented for the JAX backend because they were not feasible for some reason, or were they just left for the future?

These seem to be fundamentally non-differentiable, in my understanding at least. ak.count in particular depends only on the array's structure, and we don't even represent the array structure (the ak.index.Index parts) using JAX, because you can't differentiate through it. Differentiating through it would be equivalent to differentiating through variables that are used in an if predicate. It's less obvious for ak.count_nonzero (the delta function is differentiable in Lebesgue measure but not in Riemannian measure; I don't think that's relevant) and for ak.argmin/ak.argmax.
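
For the argmin/argmax case, plain JAX (not Awkward's reducers) shows the distinction: the integer index returned by jnp.argmax cannot itself be differentiated, but when it is only used to select a value, JAX treats the index as a constant and the gradient lands on the selected element. The helper max_via_argmax below is hypothetical, for illustration only:

```python
import jax
import jax.numpy as jnp

def max_via_argmax(x):
    # jnp.argmax returns an integer index; no gradient flows through it,
    # it only decides which element of x is selected.
    i = jnp.argmax(x)
    return x[i]

x = jnp.array([1.0, 5.0, 3.0])
g = jax.grad(max_via_argmax)(x)
print(g)  # one-hot at the position of the maximum

# Differentiating the index itself fails, because jax.grad needs a
# floating-point output and argmax returns an integer.
argmax_grad_failed = False
try:
    jax.grad(jnp.argmax)(x)
except TypeError:
    argmax_grad_failed = True
print(argmax_grad_failed)
```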

However, differentiable arrays should still carry some of these quantities through. For instance, ak.count should be treated as a constant, so that ak.mean can be implemented by taking the implementation of ak.sum and dividing it by the constant that ak.count returns. (The derivative of ak.mean is a scaled version of the derivative of ak.sum.) I don't know about the others, ak.count_nonzero and ak.argmin/ak.argmax, though.
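
The "count as a constant" idea can be sketched outside of Awkward with a jagged array stored as a flat JAX content buffer plus NumPy offsets, loosely mimicking a list-offset layout. The counts come from the offsets alone, so they are ordinary constants, and only the per-list sums are differentiated. The layout and the helper names jagged_mean and total are illustrative assumptions, not Awkward's internal implementation:

```python
import jax
import jax.numpy as jnp
import numpy as np

# Structure: three lists of lengths 2, 3, 1, kept in NumPy (not traced by JAX).
offsets = np.array([0, 2, 5, 6])
counts = np.diff(offsets)  # [2, 3, 1]: depends on structure only
segment_ids = np.repeat(np.arange(len(counts)), counts)

def jagged_mean(content):
    # Differentiable part: per-list sums of the content buffer.
    sums = jax.ops.segment_sum(content, segment_ids, num_segments=len(counts))
    # Non-differentiable part: the counts, used as a constant divisor.
    return sums / jnp.asarray(counts, dtype=content.dtype)

def total(content):
    # Scalar objective so that jax.grad applies.
    return jnp.sum(jagged_mean(content))

content = jnp.array([1.0, 3.0, 2.0, 4.0, 6.0, 5.0])
print(jagged_mean(content))      # [2. 4. 5.]
print(jax.grad(total)(content))  # each element gets 1/count of its own list
```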