arviz-devs / arviz

Exploratory analysis of Bayesian models with Python
https://python.arviz.org
Apache License 2.0

Parallelizing ArviZ: Dask compatibility #786

Open OriolAbril opened 5 years ago

OriolAbril commented 5 years ago

Currently, ArviZ does not try to parallelize its functions to improve performance. It would be great to combine the speed-ups from Numba with optional parallel computation using Dask. Quoting the xarray docs:

xarray integrates with Dask to support parallel computations and streaming computation on datasets that don’t fit into memory.

Currently, Dask is an entirely optional feature for xarray. However, the benefits of using Dask are sufficiently strong that Dask may become a required dependency in a future version of xarray.

Both arguments, parallelization and arrays that do not fit in memory, are relevant to ArviZ. Moreover, Dask may become a requirement for xarray, and writing our code to be Dask-compatible could save a lot of time if this eventually happens.

I propose modifying all functions that use wrap_xarray_ufunc so that they accept kwargs and pass them on to xr.apply_ufunc. This way, if the datasets in the InferenceData object contain dask arrays, Dask's capabilities can be used to parallelize the code. In addition, passing kwargs to xr.open_dataset as well would make it possible to load InferenceData objects from netCDF files directly with Dask (for example, with chunks if desired).
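A minimal sketch of the proposed kwargs pass-through, assuming a hypothetical, simplified wrapper `wrap_ufunc_sketch` (ArviZ's real `wrap_xarray_ufunc` has a different signature): any extra keyword arguments are forwarded unchanged to `xr.apply_ufunc`, so the same call works for NumPy-backed data and, with `dask="parallelized", output_dtypes=[float]`, for dask-backed data. `mean_over_sample` is an illustrative function, not an ArviZ diagnostic.

```python
import numpy as np
import xarray as xr


def wrap_ufunc_sketch(func, dataset, core_dims, func_kwargs=None, **apply_kwargs):
    """Hypothetical, simplified stand-in for ArviZ's wrap_xarray_ufunc.

    ``core_dims`` are moved to the end of each array before ``func`` is
    called; extra ``apply_kwargs`` (e.g. dask="parallelized",
    output_dtypes=[float]) are passed straight to xr.apply_ufunc.
    """
    return xr.apply_ufunc(
        func,
        dataset,
        input_core_dims=[list(core_dims)],
        kwargs=func_kwargs or {},
        **apply_kwargs,
    )


def mean_over_sample(ary):
    # The two core dims (chain, draw) are the last two axes at this point.
    return ary.mean(axis=(-2, -1))


posterior = xr.Dataset(
    {"var": (("chain", "draw", "dim1"), np.random.default_rng(0).random((4, 100, 3)))}
)
# NumPy-backed data: apply_ufunc's default dask="forbidden" is fine here.
eager = wrap_ufunc_sketch(mean_over_sample, posterior, ["chain", "draw"])
# For dask-backed data (e.g. posterior.chunk({"dim1": 1})) one would add
# dask="parallelized", output_dtypes=[float] to the same call.
```

The same idea applies to loading: `xr.open_dataset` accepts a `chunks` argument, so a call like `xr.open_dataset(path, group="posterior", chunks={"draw": 500})` should return dask-backed variables (the group name and chunk sizes are only examples).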

Below are some experiments showing the possible benefits of these changes (keep in mind that I don't really know Dask; I have only read its tutorial and the xarray docs on it):

import arviz as az
import numpy as np
import logging

# hide logged warning due to wrong shape; it is already fixed in development and irrelevant
# https://github.com/pydata/xarray/issues/3168
logging.getLogger("arviz").setLevel(logging.ERROR)

idata = az.from_dict(posterior={"var": np.random.random(size=(10, 3000, 2000))}, dims={"var": ["dim1"]})

print("###########   xr.Dataset containing numpy array   ###########")
%timeit ess = az.ess(idata)
ess_orig = az.ess(idata)

print("###########   xr.Dataset containing dask array   ###########")
# chunking an xr.Dataset automatically converts its variables to dask arrays
idata.posterior = idata.posterior.chunk({"dim1": 100})

# here there may be another warning https://github.com/pydata/xarray/issues/2928
# that can be solved updating to xarray's latest version
%timeit ess_parallelized = az.ess(idata, dask="parallelized", output_dtypes=[float]).compute()
ess_parallelized = az.ess(idata, dask="parallelized", output_dtypes=[float]).compute()

print("###########   check results    ###########")
print((ess_orig == ess_parallelized).all())

which outputs:

###########   xr.Dataset containing numpy array   ###########
15.2 s ± 89.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

###########   xr.Dataset containing dask array   ###########
6.09 s ± 182 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

###########   check results    ###########
<xarray.Dataset>
Dimensions:  ()
Data variables:
    var      bool True

As you can see, even this vanilla solution gives a substantial speed-up (6.09 s instead of 15.2 s, roughly 2.5 times faster). There is probably room for further improvement using pure Dask, but I do not think this should be an immediate concern (or a concern at all; maybe an idea for next year's GSoC?). Using Dask via xarray should be simple enough and provide good speed-ups for large data.

Implementation ideas

To sum up, we can use Dask via xarray to handle parallelization with very little work (passing kwargs to wrap_xarray_ufunc and to xr.open_dataset should be enough). However, some things must be taken into account:

I do not have a clear idea of how to pass these Dask-related kwargs to all relevant functions. For ess it can be quite straightforward, but what about summary? Should there be an option to allow different Dask kwargs for ess and for rhat? Also, should ArviZ return unevaluated Dask expressions so users can run compute() whenever they want, or should there be an option to compute results optionally? For summary, for instance, it could be better to compute ess and rhat at the same time. Could summary even work without evaluating the values, or will it always trigger computation via the .values attribute?
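On computing ess and rhat at the same time, one option, sketched here with a hypothetical `combine_diagnostics` helper, is to put both lazy results into one `xr.Dataset` and call `.compute()` once. xarray then hands all dask-backed variables to Dask in a single call, so tasks shared by both graphs should only run once. Passing `compute=False` returns the unevaluated Dataset instead. The ESS and R-hat numbers below are placeholder inputs, not real diagnostics.

```python
import numpy as np
import xarray as xr


def combine_diagnostics(ess, rhat, compute=True):
    """Hypothetical helper: bundle ESS and R-hat results into one Dataset.

    With dask-backed inputs, a single .compute() on the combined Dataset
    evaluates both variables in one Dask call; with compute=False the
    unevaluated Dataset is returned so the caller decides when to compute.
    """
    combined = xr.Dataset({"ess": ess, "r_hat": rhat})
    return combined.compute() if compute else combined


# Placeholder values for illustration only (NumPy-backed, so compute is cheap).
ess = xr.DataArray(np.array([950.0, 1010.0]), dims=["dim1"])
rhat = xr.DataArray(np.array([1.001, 1.002]), dims=["dim1"])
summary = combine_diagnostics(ess, rhat)
```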

It would probably be good to eventually add some tests on InferenceData objects containing dask arrays too, but not too many, to avoid segmentation faults like the ones we would get if tests were not run in "eager" mode.
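For such tests, a small, hypothetical helper could compare a function's output on eager (NumPy-backed) data with its output on the same data chunked into dask arrays, and skip the check when Dask is not installed. Only `Dataset.chunk`, `.compute()` and `xr.testing.assert_allclose` are real xarray APIs here; `check_dask_matches_eager` and the mean-over-samples lambda are illustrative.

```python
import importlib.util

import numpy as np
import xarray as xr

HAS_DASK = importlib.util.find_spec("dask") is not None


def check_dask_matches_eager(func, dataset, chunks):
    """Hypothetical test helper: compare func on eager vs chunked (dask) data.

    Returns True if the results agree, or None when Dask is not installed.
    Raises AssertionError (from xr.testing.assert_allclose) on a mismatch.
    """
    if not HAS_DASK:
        return None
    eager = func(dataset)
    lazy = func(dataset.chunk(chunks))  # variables are now dask arrays
    xr.testing.assert_allclose(eager, lazy.compute())
    return True


posterior = xr.Dataset(
    {"var": (("chain", "draw", "dim1"), np.random.default_rng(2).random((2, 50, 4)))}
)
outcome = check_dask_matches_eager(
    lambda ds: ds.mean(dim=("chain", "draw")), posterior, {"dim1": 2}
)
```

Keeping the number of such checks small, as suggested above, limits how much of the test suite depends on the Dask code path.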

ahartikainen commented 5 years ago

rcParam for dask?

A function to convert idata to dask

A function to set the dask rcParam -> also set whether to return the dask graph or evaluate it

Also, see the summary function; it already contains the main calculations in one function.

For rhat, ess, etc., we can call .compute() based on rcParams

OriolAbril commented 4 years ago

Possible reference: http://xarray.pydata.org/en/stable/examples/apply_ufunc_vectorize_1d.html
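Following the pattern in that xarray example, a 1-D function can be applied along one core dimension with `vectorize=True`, and `dask="parallelized"` lets the same call run chunk-wise when the data are dask arrays. `lag1_autocorr` is an illustrative per-series statistic chosen for this sketch, not an ArviZ function.

```python
import numpy as np
import xarray as xr


def lag1_autocorr(x):
    """Lag-1 autocorrelation of a single 1-D series (illustrative statistic)."""
    x = x - x.mean()
    return np.dot(x[:-1], x[1:]) / np.dot(x, x)


posterior = xr.Dataset(
    {"var": (("chain", "draw", "dim1"), np.random.default_rng(1).normal(size=(2, 200, 3)))}
)
result = xr.apply_ufunc(
    lag1_autocorr,
    posterior,
    input_core_dims=[["draw"]],  # func sees one 1-D "draw" series at a time
    vectorize=True,  # loop over the remaining dims (chain, dim1)
    dask="parallelized",  # only takes effect if the data are dask arrays
    output_dtypes=[float],
)
```

With NumPy-backed data, as here, the `dask` setting has no effect. `vectorize=True` is convenient but loops in Python over the non-core dimensions, so a natively vectorized function is usually faster.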