iris-hep / func_adl

Construct hierarchical data queries using SQL-like concepts in python
MIT License

Bad parsing of a lambda in the ttbar analysis notebook #113

Closed gordonwatts closed 1 year ago

gordonwatts commented 1 year ago

The code below (which is a test case) causes an assert failure: no lambda is found.

def test_parse_multiline_lambda_with_comment():
    "Comment in the middle of things"

    found = []

    class my_obj:
        def Where(self, x: Callable):
            found.append(parse_as_ast(x))
            return self

        def Select(self, x: Callable):
            found.append(parse_as_ast(x))
            return self

        def AsAwkwardArray(self, stuff: str):
            return self

        def value(self):
            return self

    source = my_obj()
    # fmt: off
    r = source.Where(lambda e:
        # == 1 lep
        e.electron_pt.Where(lambda pT: pT > 25).Count() + e.muon_pt.Where(lambda pT: pT > 25).Count()== 1
        )\
        .Where(lambda e:\
            # >= 4 jets
            e.jet_pt.Where(lambda pT: pT > 25).Count() >= 4
        )\
        .Where(lambda e:\
            # >= 1 jet with pT > 25 GeV and b-tag >= 0.5
            {"pT": e.jet_pt, "btag": e.jet_btag}.Zip().Where(lambda jet: jet.btag >= 0.5 and jet.pT > 25).Count() >= 1
        )\
        .Select(lambda e:\
            # return columns
            {
                "electron_e": e.electron_e,
                "electron_pt": e.electron_pt,
                "muon_e": e.muon_e,
                "muon_pt": e.muon_pt,
                "jet_e": e.jet_e,
                "jet_pt": e.jet_pt,
                "jet_eta": e.jet_eta,
                "jet_phi": e.jet_phi,
                "jet_btag": e.jet_btag,
                "numbermuon": e.numbermuon,
                "numberelectron": e.numberelectron,
                "numberjet": e.numberjet,
            }
        )
    # fmt: on

    assert "electron_pt" in ast.dump(found[0])

gordonwatts commented 1 year ago

Turns out it is the following lambda that is causing the problem. Here is a smaller expression that still triggers the exception:

    r = source.Where(lambda e:
        # == 1 lep
        e.electron_pt.Where(lambda pT: pT > 25).Count() + e.muon_pt.Where(lambda pT: pT > 25).Count()== 1
        )\
        .Where(lambda e:\
            # >= 4 jets
            e.jet_pt.Where(lambda pT: pT > 25).Count() >= 4
        )

The lambda that fails to be found is, I think, the one in the first Where.

gordonwatts commented 1 year ago

If you let black format this, then it works correctly (I've been using black as the "must work" standard for lambdas).

gordonwatts commented 1 year ago

Even more condensed - this fails:

    r = source.Where(lambda e:
        e.electron_pt.Where(lambda pT: pT > 25).Count() + e.muon_pt.Where(lambda pT: pT > 25).Count()== 1)\
        .Where(lambda e:\
            e.jet_pt.Where(lambda pT: pT > 25).Count() >= 4
        )

but this passes:

    r = source.Where(lambda e:
        e.electron_pt.Where(lambda pT: pT > 25).Count() + e.muon_pt.Where(lambda pT: pT > 25).Count()== 1)\
        .Where(lambda e:
            e.jet_pt.Where(lambda pT: pT > 25).Count() >= 4
        )

(note the unnecessary continuation character after "lambda e:" in the second Where of the failing version).

gordonwatts commented 1 year ago

Ok - this comes down to the fact that we aren't handling line continuation characters properly. Python hands us source code with the line continuation characters still in it - which is crazy! So we need to take them out.
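
For illustration, here is a minimal sketch of what "taking them out" could look like. This is not the actual func_adl fix: strip_line_continuations is a hypothetical helper, and the source string is a cut-down stand-in for the failing query above. It assumes the source text is already available as a string, and it is deliberately naive (it would also rewrite a backslash-newline that sits inside a string literal).

```python
import ast
import re


def strip_line_continuations(source: str) -> str:
    """Hypothetical helper: replace each backslash-newline pair with a space.

    Naive on purpose: it does not check whether the backslash sits inside
    a string literal, so treat it as a sketch rather than a full fix.
    """
    return re.sub(r"\\\r?\n", " ", source)


# Cut-down stand-in for the failing query: the first Where ends with ")\"
# and the second Where has the unnecessary "\" after "lambda e:".
src = (
    "r = source.Where(lambda e:\n"
    "    e.n == 1)\\\n"
    "    .Where(lambda e:\\\n"
    "        e.m >= 4\n"
    "    )\n"
)

cleaned = strip_line_continuations(src)

# The cleaned text is still valid Python and both lambdas are present.
tree = ast.parse(cleaned)
lambdas = [node for node in ast.walk(tree) if isinstance(node, ast.Lambda)]
```

After this, cleaned contains no backslashes, so whatever locates the lambda text no longer has to deal with continuation characters. A more careful version would probably work on tokens rather than a regex.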