CoffeaTeam / coffea

Basic tools and wrappers for enabling not-too-alien syntax when running columnar Collider HEP analysis.
https://coffeateam.github.io/coffea/
BSD 3-Clause "New" or "Revised" License

Nanoevents + jax looks for _mass2_kernel in the wrong spot #874

Closed alexander-held closed 1 week ago

alexander-held commented 1 year ago

Describe the bug: When combining nanoevents and jax, invariant mass calculations no longer work.

It is not clear to me whether this is an issue in awkward or elsewhere, but _mass2_kernel comes from coffea, so I figured I'd start here.

To Reproduce: A full example is at https://github.com/alexander-held/agc-autodiff/blob/9bcad94a689063b130829bd33fff12e17dd43c36/nanoevents_plus_jax.ipynb. This uses coffea==2023.7.0rc0.

import awkward as ak
from coffea.nanoevents import NanoEventsFactory, NanoAODSchema

ak.jax.register_and_check()
NanoAODSchema.warn_missing_crossrefs = False # silences warnings about branches we will not use here

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

events = NanoEventsFactory.from_root({ttbar_file: "Events"}, schemaclass=NanoAODSchema).events()
events = ak.to_backend(events, "jax")

evtfilter = ak.to_backend(ak.num(events.Jet.pt) >= 2, "jax")  # backend call is needed here!
jets = events.Jet[evtfilter]

(jets[:, 0] + jets[:, 1]).mass

Expected behavior: The mass calculation succeeds, just as it does without jax.

Output

AttributeError: module 'jax.numpy' has no attribute '_mass2_kernel'

Desktop: n/a

Additional context: n/a

lgray commented 1 year ago

This looks like a behavior dispatch issue in awkward. Can you try boiling this down to not include coffea?
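A minimal, stdlib-only model of the failure, assuming (consistent with the error message above, but not confirmed in this thread) that awkward's jax backend translates a NumPy ufunc by looking up its __name__ on jax.numpy. The fake module, dispatch_by_name, and mass2_kernel below are hypothetical stand-ins, not awkward or coffea APIs. A standard name such as sqrt resolves, while a custom kernel named _mass2_kernel has nothing to resolve to.

```python
import math
import types

# Hypothetical stand-in for jax.numpy: it only exposes standard ufunc names.
fake_jax_numpy = types.ModuleType("jax.numpy")
fake_jax_numpy.sqrt = math.sqrt
fake_jax_numpy.add = lambda a, b: a + b


def mass2_kernel(t, x, y, z):
    """Squared mass from (t, x, y, z), as in the patched kernel later in the thread."""
    return t * t - x * x - y * y - z * z


# Give the kernel the same __name__ as coffea's, so the lookup fails the same way.
mass2_kernel.__name__ = "_mass2_kernel"


def dispatch_by_name(ufunc, backend_module, *args):
    """Assumed dispatch rule: find the backend function with the ufunc's name."""
    backend_impl = getattr(backend_module, ufunc.__name__)  # AttributeError if missing
    return backend_impl(*args)


print(dispatch_by_name(math.sqrt, fake_jax_numpy, 16.0))  # 4.0

try:
    dispatch_by_name(mass2_kernel, fake_jax_numpy, 5.0, 3.0, 0.0, 0.0)
except AttributeError as err:
    print(err)  # module 'jax.numpy' has no attribute '_mass2_kernel'
```

Under this assumption the fix has to happen on the dispatch side (or the kernel has to stop being a ufunc), since jax.numpy will never gain a function named after a coffea helper.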

alexander-held commented 1 year ago

I have another example in https://github.com/scikit-hep/awkward/issues/2591, which was (partially) fixed by https://github.com/scikit-hep/awkward/pull/2592 (it now only hits RuntimeError: Cannot differentiate through count_zero, which is an independent issue). I'm no expert on behaviors, but that example performs the addition and mass calculation successfully with the Momentum4D behavior (cc @agoose77 for that part).

Is this qualitatively similar? If so, I guess I'd have to interpolate between that example and the nanoevents behavior, simplifying the latter in coffea until it works?

lgray commented 1 year ago

I'd import some things with uproot directly, build the array you want, and then apply the behaviors from coffea by hand.

You don't have to use the full machinery of nanoevents to replicate this, for sure.

alexander-held commented 1 year ago

Here's another reproducer that only applies PtEtaPhiECandidate behavior:

import awkward as ak
from coffea.nanoevents.methods import candidate
import numpy as np
import uproot

ak.jax.register_and_check()
ak.behavior.update(candidate.behavior)

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

with uproot.open(ttbar_file) as f:
    arr = f["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
                              "Electron_mass", "Electron_charge"])

px = arr.Electron_pt * np.cos(arr.Electron_phi)
py = arr.Electron_pt * np.sin(arr.Electron_phi)
pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)

evtfilter = ak.num(arr["Electron_pt"]) >= 2

els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
              "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
els = ak.to_backend(els, "jax")

(els[:, 0] + els[:, 1]).mass

which results in the same problem. So perhaps this is an awkward issue after all? I'll open one there.

This does work with "Momentum4D" from vector (requiring vector.register_awkward()).
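For reference, the arithmetic behind (els[:, 0] + els[:, 1]).mass can be sketched on plain floats. This is a hedged stdlib illustration, not coffea's implementation: to_cartesian and invariant_mass are hypothetical helpers that repeat the px, py, pz, E conversion from the reproducer above and then apply m² = E² − px² − py² − pz² to the summed four-vector, the same expression as _mass2_kernel.

```python
import math


def to_cartesian(pt, eta, phi, mass):
    """Convert (pt, eta, phi, mass) to (E, px, py, pz), mirroring the reproducer above."""
    px = pt * math.cos(phi)
    py = pt * math.sin(phi)
    pz = pt * math.sinh(eta)
    energy = math.sqrt(mass**2 + px**2 + py**2 + pz**2)
    return energy, px, py, pz


def invariant_mass(p1, p2):
    """Invariant mass of the sum of two (pt, eta, phi, mass) candidates."""
    e1, x1, y1, z1 = to_cartesian(*p1)
    e2, x2, y2, z2 = to_cartesian(*p2)
    # Add the four-vectors component by component, like candidate addition.
    e, x, y, z = e1 + e2, x1 + x2, y1 + y2, z1 + z2
    m2 = e * e - x * x - y * y - z * z  # same expression as _mass2_kernel
    # Clamp tiny negative values caused by floating-point rounding.
    return math.sqrt(max(m2, 0.0))


# Two massless, back-to-back particles with pt = 45 each: m is about 90.
print(invariant_mass((45.0, 0.0, 0.0, 0.0), (45.0, 0.0, math.pi, 0.0)))
```

Clamping a negative m² to zero is a choice of this sketch; coffea may handle that case differently.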

alexander-held commented 11 months ago

I believe the underlying issue is the clash of numba and jax as described in https://github.com/scikit-hep/awkward/issues/2603#issuecomment-1748380547. The setup in the previous comment can be patched with this snippet:

from coffea.nanoevents.methods import candidate, vector

# Pure-Python replacement for coffea's numba-compiled _mass2_kernel
def _mass2_kernel(t, x, y, z):
    return t * t - x * x - y * y - z * z

class PatchedLorentzVector(vector.LorentzVector):
    @property
    def mass2(self):
        """Squared `mass`"""
        return _mass2_kernel(self.t, self.x, self.y, self.z)

# Re-parent Candidate so it (and its subclasses) use the patched mass2
candidate.Candidate.__bases__ = (PatchedLorentzVector,)

It is not a good idea to remove numba for everyone just to make this work. Most of the time people are presumably not using jax, so this is not really something that can be merged.
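The last line of the patch relies on Python allowing a class's __bases__ to be reassigned at runtime. The sketch below shows that mechanism with hypothetical stand-in classes (LorentzVector, Candidate, and Electron here are placeholders, not coffea's real classes): after the swap, Candidate and any subclass that already exists resolve mass2 through PatchedLorentzVector first.

```python
class LorentzVector:
    """Hypothetical stand-in for coffea's vector.LorentzVector."""

    def __init__(self, t, x, y, z):
        self.t, self.x, self.y, self.z = t, x, y, z

    @property
    def mass2(self):
        # Stand-in for the numba-backed path that fails under the jax backend.
        raise RuntimeError("numba kernel not usable on this backend")


class Candidate(LorentzVector):
    """Hypothetical stand-in for candidate.Candidate."""


class Electron(Candidate):
    """A subclass defined before the patch is applied."""


def _mass2_kernel(t, x, y, z):
    return t * t - x * x - y * y - z * z


class PatchedLorentzVector(LorentzVector):
    @property
    def mass2(self):
        return _mass2_kernel(self.t, self.x, self.y, self.z)


# Re-parent Candidate: its MRO, and that of existing subclasses, now finds
# PatchedLorentzVector.mass2 before LorentzVector.mass2.
Candidate.__bases__ = (PatchedLorentzVector,)

print(Electron(5.0, 3.0, 0.0, 0.0).mass2)  # 16.0
```

Because the swap is process-wide, every object built on Candidate loses the numba path, which is exactly the concern about removing numba for everyone.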

Saransh-cpp commented 6 months ago

I tried running the code sample on the use_scikithep_vector branch, and it still errors, but the code runs fine using Momentum4D behavior from scikit-hep/vector. Since coffea is removing its vector module entirely in the upcoming months, this issue will likely be solved automatically on the coffea side. I am looking into the problem on awkward's side and will start a discussion on the linked issue.

Edit: this should be solved by https://github.com/scikit-hep/awkward/pull/3025.

Saransh-cpp commented 6 months ago

This can be closed now. The issue has been resolved on the main branch of awkward:

In [1]: import awkward as ak
   ...: from coffea.nanoevents.methods import candidate
   ...: import numpy as np
   ...: import uproot
   ...: 
   ...: ak.jax.register_and_check()
   ...: ak.behavior.update(candidate.behavior)
   ...: 
   ...: ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
   ...:     "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
   ...: 
   ...: with uproot.open(ttbar_file) as f:
   ...:     arr = f["Events"].arrays(["Electron_pt", "Electron_eta", "Electron_phi",
   ...:                               "Electron_mass", "Electron_charge"])
   ...: 
   ...: px = arr.Electron_pt * np.cos(arr.Electron_phi)
   ...: py = arr.Electron_pt * np.sin(arr.Electron_phi)
   ...: pz = arr.Electron_pt * np.sinh(arr.Electron_eta)
   ...: E = np.sqrt(arr.Electron_mass**2 + px**2 + py**2 + pz**2)
   ...: 
   ...: 
   ...: evtfilter = ak.num(arr["Electron_pt"]) >= 2
   ...: 
   ...: els = ak.zip({"pt": arr.Electron_pt, "eta": arr.Electron_eta, "phi": arr.Electron_phi,
   ...:               "energy": E, "charge": arr.Electron_charge}, with_name="PtEtaPhiECandidate")[evtfilter]
   ...: els = ak.to_backend(els, "jax")
   ...: 
   ...: (els[:, 0] + els[:, 1]).mass
Out[1]: <Array [86.903534, 97.60412, ..., 62.408997, 50.49058] type='5 * float32'>

lgray commented 6 months ago

@Saransh-cpp lemme know the version of awkward this will correspond to and we will close it with a pin adjustment to coffea.

Saransh-cpp commented 6 months ago

Great! I can create a PR once the fix is out in a release.

Saransh-cpp commented 5 months ago

@lgray - the fix is a part of awkward 2.6.3 (#1068), so it should be safe to close this now.

lgray commented 5 months ago

@alexander-held can you test this with coffea main to see if your issue is resolved?

alexander-held commented 4 months ago

I'm trying to remind myself of the setup here. I can confirm that the example in https://github.com/CoffeaTeam/coffea/issues/874#issuecomment-1662407683 now works with latest coffea + awkward releases.

The example in the original issue at the top runs into a NotImplementedError at events = ak.to_backend(events, "jax"), raised from awkward/_dispatch.py, line 56, in dispatch. That is certainly a different issue from before. Is there something about the full events object that would affect the to_backend call, compared with just sending over an electron object?

Saransh-cpp commented 4 months ago

I think this is happening because events in the original example is a dask_awkward array. The snippet works if .compute() is called first:

In [1]: import awkward as ak
   ...: from coffea.nanoevents import NanoEventsFactory, NanoAODSchema
   ...: 
   ...: ak.jax.register_and_check()
   ...: NanoAODSchema.warn_missing_crossrefs = False # silences warnings about branches we will not use here
   ...: 
   ...: ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
   ...:     "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"
   ...: 
   ...: events = NanoEventsFactory.from_root({ttbar_file: "Events"}, schemaclass=NanoAODSchema).events()
   ...: events = ak.to_backend(events.compute(), "jax")  # compute is required for switching backends
   ...: 
   ...: evtfilter = ak.to_backend(ak.num(events.Jet.pt) >= 2, "jax")  # backend call is needed here!
   ...: jets = events.Jet[evtfilter]
   ...: 
   ...: (jets[:, 0] + jets[:, 1]).mass
Out[1]: <Array [157.21956, 81.92088, ..., 32.363174, 223.94753] type='140 * float32'>

alexander-held commented 4 months ago

I see. As far as I'm aware, mixing Dask and jax is something we don't support at the moment, so it makes sense that this does not work. From my side, we can close this as fixed. Thank you!

Saransh-cpp commented 1 week ago

@lgray a gentle bump on closing this 🙂