dask-contrib / dask-awkward

Native Dask collection for awkward arrays, and the library to use it.
https://dask-awkward.readthedocs.io
BSD 3-Clause "New" or "Revised" License

Conditional branching and __bool__ for arrays #538

Closed: alexander-held closed this issue 4 days ago

alexander-held commented 2 weeks ago

In discussions with @jpivarski and @pfackeldey at pyhep.dev we identified a few aspects of behavior that I believe could be improved.

Short version

bool() on a dask-awkward array

import dask_awkward as dak

arr = dak.from_lists([[0, 0]])
bool(arr)

evaluates as True, seemingly independent of the content of the array. The example above should raise an exception. If it is already known that the bool() does not make sense, as in this case where the argument is not a scalar, I believe an exception like the one numpy provides makes sense:

import dask.array as da
import numpy as np

arr = da.array(np.zeros(2))
bool(arr)

-> ValueError: The truth value of a Array is ambiguous. Use a.any() or a.all().
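For reference, this is the same exception that plain numpy raises; dask.array mirrors it:

import numpy as np

bool(np.zeros(2))

-> ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()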

Automatic computation for branching

As a consequence of the behavior shown above, where the array evaluates as True, branching based on the content of a dask-awkward array does not work correctly.

import uproot
import dask
import awkward as ak

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

evts = uproot.dask({ttbar_file: "Events"})
evts["meta"] = ak.zeros_like(evts.run, dtype="float")

def f(evts):
    if ak.sum(evts["meta"]) > 100:
        res = {"num_electrons": ak.sum(ak.num(evts.Electron_pt, axis=1))}
    else:
        res = {"num_muons": ak.sum(ak.num(evts.Muon_pt, axis=1))}
    return res

dask.compute(f(evts))

This returns ({'num_electrons': 69},), because ak.sum(evts["meta"]) > 100 evaluates as True during graph building, despite

bool((ak.sum(evts["meta"]) > 100).compute())

being False when actually computed. The behavior here departs from what Dask does. Consider the following variation with dask.array:

import uproot
import dask
import awkward as ak
import numpy as np
import dask.array as da

ttbar_file = "https://github.com/scikit-hep/scikit-hep-testdata/"\
    "raw/main/src/skhep_testdata/data/nanoAOD_2015_CMS_Open_Data_ttbar.root"

evts = uproot.dask({ttbar_file: "Events"})
meta_arr = da.array(np.zeros(200))

def f(evts, meta_arr):
    if np.sum(meta_arr) > 100:
        res = {"num_electrons": ak.sum(ak.num(evts.Electron_pt, axis=1))}
    else:
        res = {"num_muons": ak.sum(ak.num(evts.Muon_pt, axis=1))}
    return res

dask.compute(f(evts, meta_arr))

In this case the last line evaluates as ({'num_muons': 41},), which is the correct answer. What happens, as far as I understand, is that a .compute() is triggered during graph building when if np.sum(meta_arr) > 100: is encountered (dask/array/core.py#L1861; the same holds for scalars, dask/array/core.py#L1869).
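For context, dask.array's __bool__ behaves roughly as follows (a paraphrase of the linked code, not a verbatim copy):

def __bool__(self):
    if self.size > 1:
        # more than one element: refuse, with the ValueError shown above
        raise ValueError(
            f"The truth value of a {type(self).__name__} is ambiguous. "
            "Use a.any() or a.all()."
        )
    # exactly one element: compute eagerly and convert
    return bool(self.compute())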

The behavior of dask.array ultimately causes the expected result to be produced; however, the way in which this happens might be surprising. A possibly expensive computation can happen during what would otherwise just be graph building, in a way that is not explicitly communicated to the user.

The automatic .compute() in certain cases is also the reason why

import dask.array as da
import numpy as np
bool(da.array(np.zeros(1))), bool(da.array(np.ones(1)))

evaluates to (False, True) without any .compute() visible.

Thoughts on expected behavior

I find the automatic .compute() very unintuitive. Users coming from pure Dask might be used to it, so I can see arguments for following the same pattern. I would, however, propose raising an exception whenever a bool() evaluation on a dask-awkward array would happen during tracing, informing the user that they might want to call .compute() manually. In the first example, this would look like the following:

if (ak.sum(evts["meta"]) > 100).compute():

instead of:

if ak.sum(evts["meta"]) > 100:
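A minimal sketch of what such a trace-time guard could look like on the dask-awkward Array class (hypothetical, not existing dask-awkward code):

class Array:  # simplified stand-in for dask_awkward.Array
    def __bool__(self):
        # always refuse during graph construction and point at .compute()
        raise ValueError(
            "The truth value of a dask_awkward.Array is ambiguous during "
            "graph construction. If an eager boolean is really intended, "
            "call .compute() explicitly first."
        )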

In high energy physics applications, the typical use cases I have in mind where branching is needed are different processing pipelines for different types of input files. In these cases I believe it is desirable not to evaluate all possible paths through the function, which is, in contrast, what np.where does (evaluate both branches and combine afterwards). Avoiding this helps prevent potentially expensive calculations. Whenever possible, such branching should probably be done with static or already evaluated metadata, as opposed to non-evaluated in-file metadata, to prevent complications from the behavior listed above (a sketch follows below).
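As a sketch of that pattern, assuming a hypothetical dataset_info dictionary filled from a dataset catalog before any graph building:

import awkward as ak

dataset_info = {"process": "ttbar"}  # hypothetical, eagerly available catalog entry
do_electrons = dataset_info["process"] == "ttbar"  # plain Python bool

def f(evts):
    # branching on a plain bool is safe at graph-construction time
    if do_electrons:
        return {"num_electrons": ak.sum(ak.num(evts.Electron_pt, axis=1))}
    return {"num_muons": ak.sum(ak.num(evts.Muon_pt, axis=1))}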

This behavior might catch users off guard, especially if they port awkward code over to dask-awkward. I am not sure where best to communicate this as a possible "gotcha", but I think it would be useful to spread awareness.

lgray commented 2 weeks ago

The implicit .compute()s are extremely unintuitive; we already see this routinely with folks using dak.to_parquet and uproot.dask_write. I would honestly suggest we change everything to be lazy by default as much as possible.
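For comparison, dask.dataframe's to_parquet can already defer the write via compute=False; a lazy-by-default writer would behave like this without being asked (sketch with dask.dataframe, whose documented API includes this flag; no claim is made here about the dask-awkward writers' signatures):

import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"x": range(10)}), npartitions=2)
# compute=False returns a lazy task instead of writing immediately
write_task = ddf.to_parquet("out.parquet", compute=False)
write_task.compute()  # the actual write happens only here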

The metadata issue comes down to opinions about organizational practices and how those impact coding and data preparation.

I think some of the opinions are better justified than others, but they heavily depend on how you were taught to do data analysis and on assumptions about how metadata is organized and presented to you.

I would argue that statements like:

if (ak.sum(evts["meta"]) > 100).compute():

should be avoided in columnar analysis scenarios, since this pattern amounts to doing one more step of cheap metadata processing.

i.e., using the metadata area available in nanoevents during dataset preparation:

events.metadata["do_muons"] = (ak.sum(whatever_populates_evts["meta"]).compute() > 100)
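The stored flag can then be branched on as a plain Python bool inside the processing function (hypothetical continuation of the snippet above):

def process(events):
    # events.metadata["do_muons"] was computed once during dataset preparation
    if events.metadata["do_muons"]:
        return {"num_muons": ak.sum(ak.num(events.Muon_pt, axis=1))}
    return {"num_electrons": ak.sum(ak.num(events.Electron_pt, axis=1))}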

Since this is something that everyone would do in order to process electrons vs. muons, it implies that some of this information should already have been synthesized when the dataset you get events from was created. I imagine this is probably actually about ATLAS's PHYSLITE MC cross sections varying event by event, or something like that. In that case what you really need is a lookup from process barcode to cross section, which correctionlib handily provides; one just needs to build the appropriate lookup table and provide it centrally.

On CMS we avoid that pattern by keeping the pt-binned datasets separate all the way to the end; they can then be processed in parallel by the same code, and later augmented and combined with appropriately normalized cross-section weights, or just processed and filled into histograms normally. These are merely different ways of letting the user do what they want, but one of them reduces code complexity and variation at the cost of slightly more end-user bookkeeping.

martindurant commented 2 weeks ago

Some of this is also about complying with various other standards. One concrete case: Python's len() is required to return an int, so there is no choice there.
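A quick illustration of that constraint, with no dask involved:

class Lazy:
    def __len__(self):
        # len() insists on an int; returning anything else fails
        return 3.0

len(Lazy())

-> TypeError: 'float' object cannot be interpreted as an integer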

You get an exception if you try to call bool() on an array or dataframe (dask or otherwise), although an array with exactly one value is a special case.

pfackeldey commented 2 weeks ago

JAX will instead raise a TracerBoolConversionError at trace time if you try to call bool() on a jax.Array. To me this feels like the more natural way. In Dask, the error message may want to tell the user explicitly that this is possible, but that one then needs to consciously add a .compute().
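A small demonstration of the JAX behavior (assuming a recent jax version, where this error class exists):

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # bool() on a traced value is rejected at trace time
    if jnp.sum(x) > 100:
        return x * 2
    return x / 2

f(jnp.zeros(4))

-> jax.errors.TracerBoolConversionError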