sgkit-dev / sgkit

Scalable genetics toolkit
https://sgkit-dev.github.io/sgkit
Apache License 2.0

Dask 2024.8.1 and later is very slow #1267

Open tomwhite opened 1 month ago

tomwhite commented 1 month ago

This was originally reported in #1247, and a temporary pin was introduced in #1248. I've opened this to track the issue so we can remove the pin.

tomwhite commented 1 month ago

I've opened https://github.com/dask/dask/issues/11416

tomwhite commented 3 weeks ago

Unfortunately, it looks like Dask 2024.10.0 doesn't fix this. See https://github.com/sgkit-dev/sgkit/actions/runs/11551276595, which takes 19 minutes to run rather than 6 (with Dask 2024.08.0).

tomwhite commented 3 weeks ago

On further investigation, the problem is that when a locally defined function that wraps a Numba function is passed to Dask map_blocks, the Numba function is recompiled every time the (genomics) method is called. For example, in pbs:

https://github.com/sgkit-dev/sgkit/blob/9dd940e2de95edbc917a947b4ecc52193bf46e1e/sgkit/stats/popgen.py#L598-L600

The lambda function calls a Numba function that is recompiled each time.

In most cases it's fairly easy to rewrite the code to avoid the use of locally defined functions. For PBS we can just do:

-    p = da.map_blocks(
-        lambda t: _pbs_cohorts(t, ct), t, chunks=shape, new_axis=3, dtype=np.float64
-    )
+    p = da.map_blocks(_pbs_cohorts, t, ct, chunks=shape, new_axis=3, dtype=np.float64)
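
The same rewrite can be tried in isolation. The following is a minimal sketch, not the sgkit code: `_scale_by_factors` is a hypothetical plain NumPy stand-in for a module-level kernel such as `_pbs_cohorts`, and `factors` plays the role of `ct`. It only shows that passing the extra argument positionally to `map_blocks` gives the same result as closing over it in a lambda; it does not measure compilation time.

```python
import dask.array as da
import numpy as np


def _scale_by_factors(block, factors):
    """Hypothetical module-level kernel standing in for _pbs_cohorts."""
    # Append a new trailing axis with one entry per factor.
    return block[..., np.newaxis] * factors


t = da.ones((4, 6), chunks=(2, 6))
factors = np.array([1.0, 2.0, 3.0])

# Output chunks: keep the input chunking and add one chunk for the new axis.
out_chunks = (t.chunks[0], t.chunks[1], (factors.size,))

# Before: a lambda captures `factors` in a closure.
before = da.map_blocks(
    lambda b: _scale_by_factors(b, factors),
    t,
    chunks=out_chunks,
    new_axis=2,
    dtype=np.float64,
)

# After: pass the module-level function directly and hand `factors` over as
# a positional argument; map_blocks forwards non-Dask arguments to the function.
after = da.map_blocks(
    _scale_by_factors, t, factors, chunks=out_chunks, new_axis=2, dtype=np.float64
)
```

Both arrays compute to the same values, so the change is purely in how the function reaches Dask.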

The distance metrics code is more dynamic though, so it's not a simple fix:

https://github.com/sgkit-dev/sgkit/blob/9dd940e2de95edbc917a947b4ecc52193bf46e1e/sgkit/distance/api.py#L111-L143

tomwhite commented 3 weeks ago

I've fixed the non-distance functions in this commit: https://github.com/sgkit-dev/sgkit/pull/1261/commits/e83b52cdf1ef1b305eefdd8bcaca55b437cc4e4b

I'm not sure what to do about the distance functions at this point.

jeromekelleher commented 2 weeks ago

There are only two possible metrics right now ('euclidean' or 'correlation'), so I vote we make the code less clever and just code in the function names directly for those two cases?
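
As a rough illustration of what hard-coding the two cases could look like, here is a small sketch using plain NumPy. All names (`_euclidean`, `_correlation`, `pairwise_distance`) are hypothetical and do not reflect the sgkit distance API; the correlation kernel is assumed to be the usual correlation distance (one minus the Pearson correlation). The point is only that each kernel is a module-level function selected by an explicit branch, with no lambdas or dynamic lookups.

```python
import numpy as np


def _euclidean(x, y):
    """Euclidean distance between two 1-D vectors."""
    return float(np.sqrt(np.sum((x - y) ** 2)))


def _correlation(x, y):
    """Correlation distance (1 - Pearson r) between two 1-D vectors."""
    xc = x - x.mean()
    yc = y - y.mean()
    r = np.sum(xc * yc) / np.sqrt(np.sum(xc**2) * np.sum(yc**2))
    return float(1.0 - r)


def pairwise_distance(x, metric="euclidean"):
    """Pairwise distances between the rows of x for a hard-coded metric."""
    # Explicit branches instead of a dynamic lookup or a locally built lambda.
    if metric == "euclidean":
        kernel = _euclidean
    elif metric == "correlation":
        kernel = _correlation
    else:
        raise ValueError(f"Unknown metric: {metric!r}")

    n = x.shape[0]
    out = np.zeros((n, n), dtype=np.float64)
    for i in range(n):
        for j in range(i + 1, n):
            out[i, j] = out[j, i] = kernel(x[i], x[j])
    return out
```

Adding a third metric later would mean adding one more module-level kernel and one more branch.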

tomwhite commented 2 weeks ago

That's what I thought too - but there is another wrinkle. In this diff

https://github.com/tomwhite/sgkit/commit/e1119ca68f979cafec8ead9bd0c829de2e6e4d8e

previously metric_param was initialized outside the function to avoid Dask serialization/deserialization overhead (see the comment).

I suppose we could have a map of (shared) empty arrays keyed by dtype - but that doesn't seem very thread-safe. Or we could initialize in the function, and leave a comment about how this previously caused a Dask slowdown. Another option would be to remove the code!

jeromekelleher commented 2 weeks ago

Ah, I see. I'm reluctant to remove the code as we put quite a lot of effort in and it's our main usage of GPUs...

Perhaps @aktech would like to comment here? Is there an easy way to avoid using lambdas?