mreineck / ducc

Fork of https://gitlab.mpcdf.mpg.de/mtr/ducc to simplify external contributions
GNU General Public License v2.0

Allow custom parallelization functor (version 2). #8

Closed: cantonios closed this pull request 1 year ago

cantonios commented 1 year ago

This is to support the use of a custom thread pool, like Eigen's or TensorFlow's. This will enable JAX/TF to use DUCC0 for multithreaded FFTs.

This alternative implementation passes more details to the custom threading module for computing chunks of work (namely the input fmav_info, axis, and vlen). It also replaces the execParallel calls with the more generic threading routines (such as execStatic and execDynamic) to allow a custom division of work via Scheduler::getNext(). This would allow TF/Eigen to make use of their cost-based recursive ParallelFor routines.
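The following is a minimal sketch of a caller-supplied parallelization hook, assuming a `std::function`-based interface. The names `ParallelFunctor`, `serial_for`, `chunked_threads` and `reverse_rows` are hypothetical and are not part of ducc0. `reverse_rows` only stands in for an ND routine that performs many independent 1D operations. The library reports how many work items there are, and the caller's functor decides how to split them. That split is where a cost-based ParallelFor from Eigen or TF could plug in.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <thread>
#include <vector>

// Hypothetical hook (not ducc0's API): the worker processes the half-open
// range [lo, hi) of independent work items.
using Worker = std::function<void(std::size_t lo, std::size_t hi)>;
using ParallelFunctor = std::function<void(std::size_t nwork, const Worker&)>;

// Default policy: run all work items on the calling thread.
inline void serial_for(std::size_t nwork, const Worker& w) { w(0, nwork); }

// Example custom policy: contiguous chunks, one std::thread per chunk.
inline ParallelFunctor chunked_threads(std::size_t nthreads) {
  return [nthreads](std::size_t nwork, const Worker& w) {
    const std::size_t nt = std::max<std::size_t>(1, std::min(nthreads, nwork));
    std::vector<std::thread> pool;
    for (std::size_t t = 0; t < nt; ++t) {
      const std::size_t lo = nwork * t / nt, hi = nwork * (t + 1) / nt;
      pool.emplace_back([&w, lo, hi] { w(lo, hi); });
    }
    for (auto& th : pool) th.join();  // all chunks finish before returning
  };
}

// Stand-in for an ND routine: reverses each row of a row-major matrix
// (instead of doing a real FFT per row). Rows are independent work items.
inline void reverse_rows(std::vector<double>& data, std::size_t nrows, std::size_t ncols,
                         const ParallelFunctor& par = serial_for) {
  par(nrows, [&](std::size_t lo, std::size_t hi) {
    for (std::size_t r = lo; r < hi; ++r)
      std::reverse(data.begin() + r * ncols, data.begin() + (r + 1) * ncols);
  });
}
```

Calling `reverse_rows(data, nrows, ncols)` runs serially. Calling `reverse_rows(data, nrows, ncols, chunked_threads(4))` spreads the rows over up to four `std::thread`s. Each range covers disjoint rows, so no synchronization is needed beyond joining the threads.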

mreineck commented 1 year ago

Thanks a lot!

Two things that just came to my mind:

cantonios commented 1 year ago

> if the requested FFT is strictly 1D (i.e. called with 1D arrays), we allow the 1D transforms to do their own parallelization

If you're happy with this approach, we can do something similar for the 1D case. We may want to make the inputs to the *Threading operator a bit more generic so that it can be shared between the two kinds of invocations. I'm not yet sure what those inputs would be.

At minimum, we need the thread-count hint nthreads and the total number of work elements nwork. For ND, nwork is the number of 1D FFTs along a single axis, right? For 1D, it's nvtrans (not sure what this is; I'm not very familiar with FFT implementations in general).

Ideally, TF would also have access to a hint of the compute cycles per work element. For ND, this is the cost of a 1D FFT of length axis_length. I'm not sure what it would be for 1D.
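Here is one possible shape for such a shared input. It is sketched under the assumption that a single cost figure per work item is enough for the scheduler. `WorkDescription`, `fft_cost_hint` and `describe_nd` are illustrative names. The cost uses the common 5 n log2(n) flop estimate for a complex FFT of length n, which is only a rough heuristic and not an exact operation count.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

// Hypothetical description of a parallel region that the ND and 1D code paths
// could share (illustrative only, not ducc0's API).
struct WorkDescription {
  std::size_t nthreads;  // thread-count hint from the caller
  std::size_t nwork;     // number of independent work items
  double cost_per_item;  // estimated floating-point operations per work item
};

// Rough cost hint for one complex 1D FFT of length n (5 n log2 n heuristic).
inline double fft_cost_hint(std::size_t n) {
  return n < 2 ? 0.0 : 5.0 * double(n) * std::log2(double(n));
}

// ND case: nlines independent 1D transforms of length axis_len.
inline WorkDescription describe_nd(std::size_t nthreads, std::size_t nlines,
                                   std::size_t axis_len) {
  return {nthreads, nlines, fft_cost_hint(axis_len)};
}
```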

mreineck commented 1 year ago

> At minimum, we need the thread-count hint nthreads and the total number of work elements nwork. For ND, nwork is the number of 1D FFTs along a single axis, right?

Almost: it's the total size of the array divided by the length of the axis being transformed (potentially also divided by the length of a SIMD vector); for 2D arrays, the two definitions coincide.
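The rule for nwork can be written down directly. This is a sketch using a hypothetical helper `nwork_for`. It assumes that `vlen` lines are bundled into one SIMD work item and that a partial final bundle still counts as one item (rounding up). That rounding may not match ducc0's internal bookkeeping.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Number of work items when transforming along `axis`: total array size
// divided by the axis length (the number of 1D lines), then divided by the
// SIMD bundle size vlen, rounded up.
inline std::size_t nwork_for(const std::vector<std::size_t>& shape, std::size_t axis,
                             std::size_t vlen = 1) {
  const std::size_t total = std::accumulate(shape.begin(), shape.end(), std::size_t(1),
                                            std::multiplies<std::size_t>());
  if (total == 0) return 0;  // empty array: nothing to do (also avoids dividing by 0)
  const std::size_t nlines = total / shape[axis];
  return (nlines + vlen - 1) / vlen;
}
```

Consider a (4, 6, 8) array transformed along axis 1. This gives 192 / 6 = 32 lines, or 8 work items with vlen = 4. For a (5, 7) array along axis 0 it gives 7, which is also the number of 1D FFTs along that axis.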

> For 1D, it's nvtrans (not sure what this is; I'm not very familiar with FFT implementations in general).

Basically, the length of the 1D transform is factored into two numbers n1 and n2 that are as close to each other as possible, and the 1D transform is carried out as a 2D transform of dimensions (n1, n2) with a special twiddle step in between. So we make a detour through 2D to be able to parallelize. nvtrans is n1 or n2 divided by the SIMD length.
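The detour through 2D described above is the classic four-step FFT decomposition. Below is a minimal, unoptimized sketch of it:

- `balanced_factors` picks n1 <= n2 as close as possible.
- Naive O(n^2) DFTs stand in for the real sub-transforms.
- The result is checked against a direct DFT.

All names are hypothetical, and the `nvtrans` helper assumes rounding up. ducc0's actual implementation very likely differs in details such as indexing, twiddle storage and vectorization.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <complex>
#include <cstddef>
#include <utility>
#include <vector>

using cplx = std::complex<double>;

// Split n into n1 * n2 with n1 <= n2 and the factors as close as possible.
inline std::pair<std::size_t, std::size_t> balanced_factors(std::size_t n) {
  std::size_t n1 = 1;
  for (std::size_t d = 1; d * d <= n; ++d)
    if (n % d == 0) n1 = d;
  return {n1, n / n1};
}

// Hypothetical: SIMD work items when m lines are bundled vlen at a time.
inline std::size_t nvtrans(std::size_t m, std::size_t vlen) { return (m + vlen - 1) / vlen; }

// Direct O(n^2) DFT, X[k] = sum_j x[j] * exp(-2*pi*i*j*k/n).
inline std::vector<cplx> naive_dft(const std::vector<cplx>& x) {
  const std::size_t n = x.size();
  const double pi = std::acos(-1.0);
  std::vector<cplx> out(n);
  for (std::size_t k = 0; k < n; ++k)
    for (std::size_t j = 0; j < n; ++j)
      out[k] += x[j] * std::polar(1.0, -2.0 * pi * double((j * k) % n) / double(n));
  return out;
}

// Length-n DFT as an (n1 x n2) problem, using j = n2*j1 + j2 and k = k1 + n1*k2.
inline std::vector<cplx> four_step_dft(const std::vector<cplx>& x) {
  const std::size_t n = x.size();
  const auto [n1, n2] = balanced_factors(n);
  const double pi = std::acos(-1.0);
  std::vector<cplx> y(n);  // y[k1*n2 + j2]
  for (std::size_t j2 = 0; j2 < n2; ++j2) {
    // Step 1: length-n1 DFT over j1 of x[n2*j1 + j2] (independent per j2).
    std::vector<cplx> col(n1);
    for (std::size_t j1 = 0; j1 < n1; ++j1) col[j1] = x[n2 * j1 + j2];
    const std::vector<cplx> f = naive_dft(col);
    // Step 2: twiddle by exp(-2*pi*i*j2*k1/n).
    for (std::size_t k1 = 0; k1 < n1; ++k1)
      y[k1 * n2 + j2] = f[k1] * std::polar(1.0, -2.0 * pi * double(j2 * k1) / double(n));
  }
  std::vector<cplx> out(n);
  for (std::size_t k1 = 0; k1 < n1; ++k1) {
    // Step 3: length-n2 DFT over j2 (independent per k1); output index k1 + n1*k2.
    const std::vector<cplx> row(y.begin() + k1 * n2, y.begin() + (k1 + 1) * n2);
    const std::vector<cplx> g = naive_dft(row);
    for (std::size_t k2 = 0; k2 < n2; ++k2) out[k1 + n1 * k2] = g[k2];
  }
  return out;
}

// Largest elementwise deviation between two equally sized vectors.
inline double max_abs_diff(const std::vector<cplx>& a, const std::vector<cplx>& b) {
  double m = 0.0;
  for (std::size_t i = 0; i < a.size() && i < b.size(); ++i)
    m = std::max(m, std::abs(a[i] - b[i]));
  return m;
}
```

The length-n1 transforms in step 1 are independent of each other, and so are the length-n2 transforms in step 3. Those batches are what can be distributed over threads. For a prime length, `balanced_factors` returns (1, n), so there is nothing to split.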

mreineck commented 1 year ago

I may have an idea for how to address all the issues elegantly, without having to add the extra parameter to the interface of every function that has a parallel region or indirectly calls such a function.

In the current state of ducc, we have a (somewhat hidden) global variable holding the thread pool. For your application, this is insufficient, since in the TF context several different thread pools may be in use at the same time.

How about the following: instead of a single global thread pool, we use a global (and probably also thread-local) stack of "threading environments", where a threading environment is derived from an abstract base class that offers virtual methods like get_max_nthreads, execParallel, execDynamic, etc. There could be implementations of this class for strict single-threading, ducc-like multithreading, TF-like multithreading, etc., and you could (for example) have several TF-like environment objects with different thread pools.

Whenever a parallel code block is encountered, the currently active thread simply asks the top of the threading environment stack to take care of it, and it will just do the right thing.
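Below is a minimal sketch of the proposed threading-environment stack. To keep it short, it assumes a per-thread stack and only one parallel primitive (`execStatic`). The names `ThreadingEnvironment`, `SingleThreaded`, `ChunkedEnv`, `ScopedEnv`, `current_env` and `sum_squares` are hypothetical. `ChunkedEnv` runs its chunks sequentially and merely stands in for an environment backed by a real pool. The point is that `sum_squares` takes no threading parameter and simply asks the top of the stack.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Work = std::function<void(std::size_t lo, std::size_t hi)>;

// Hypothetical abstract environment (illustrative, not ducc0's API).
class ThreadingEnvironment {
 public:
  virtual ~ThreadingEnvironment() = default;
  virtual std::size_t get_max_nthreads() const = 0;
  // Run f over a partition of [0, nwork) into half-open ranges.
  virtual void execStatic(std::size_t nwork, const Work& f) = 0;
};

// Strict single-threading: the whole range runs on the calling thread.
class SingleThreaded : public ThreadingEnvironment {
 public:
  std::size_t get_max_nthreads() const override { return 1; }
  void execStatic(std::size_t nwork, const Work& f) override { f(0, nwork); }
};

// Stand-in for a pool-backed environment: hands out fixed-size chunks
// (sequentially here) and counts them.
class ChunkedEnv : public ThreadingEnvironment {
 public:
  ChunkedEnv(std::size_t nthreads, std::size_t chunk)
      : nthreads_(nthreads), chunk_(std::max<std::size_t>(1, chunk)) {}
  std::size_t get_max_nthreads() const override { return nthreads_; }
  void execStatic(std::size_t nwork, const Work& f) override {
    for (std::size_t lo = 0; lo < nwork; lo += chunk_) {
      f(lo, std::min(lo + chunk_, nwork));
      ++calls;
    }
  }
  std::size_t calls = 0;  // number of chunks handed out so far

 private:
  std::size_t nthreads_, chunk_;
};

// Per-thread stack of environments; the top entry handles parallel regions.
inline std::vector<ThreadingEnvironment*>& env_stack() {
  thread_local std::vector<ThreadingEnvironment*> stack;
  return stack;
}

inline ThreadingEnvironment& current_env() {
  static SingleThreaded fallback;  // used while nothing has been pushed
  auto& s = env_stack();
  return s.empty() ? static_cast<ThreadingEnvironment&>(fallback) : *s.back();
}

// RAII: the environment is on top of the stack for the guard's lifetime.
class ScopedEnv {
 public:
  explicit ScopedEnv(ThreadingEnvironment& e) { env_stack().push_back(&e); }
  ~ScopedEnv() { env_stack().pop_back(); }
  ScopedEnv(const ScopedEnv&) = delete;
  ScopedEnv& operator=(const ScopedEnv&) = delete;
};

// "Library" code: no threading parameter, it just asks the current environment.
inline double sum_squares(const std::vector<double>& v) {
  std::vector<double> partial(v.size());
  current_env().execStatic(v.size(), [&](std::size_t lo, std::size_t hi) {
    for (std::size_t i = lo; i < hi; ++i) partial[i] = v[i] * v[i];
  });
  double s = 0.0;
  for (double p : partial) s += p;
  return s;
}
```

Constructing `ScopedEnv guard(env);` makes `env` responsible for every parallel region started on this thread until `guard` goes out of scope. With an empty stack, the single-threaded fallback is used.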

This avoids the need to introduce additional parameters, which feels like a big advantage to me, but there may be issues I have completely overlooked. Also, introducing this would be a fairly big change and take some time, especially to validate.

If this turns out to be possible, this would be my longer-term goal. But I have absolutely no problem with merging something else in the meantime that works sufficiently well for TF integration.

cantonios commented 1 year ago

> How about the following: instead of a single global thread pool, we use a global (and probably also thread-local) stack of "threading environments", where a threading environment is derived from an abstract base class that offers virtual methods like get_max_nthreads, execParallel, execDynamic, etc.

My first attempt, before putting up version 1, was actually to make Distribution an abstract interface and put it in threading.h. I passed a custom Distribution in as an argument instead of the std::function version I'm using now. It required messing around with the threading module, though, so I thought the std::function approach would be more palatable. I'd be okay with going back to something like that.

> Whenever a parallel code block is encountered, the currently active thread simply asks the top of the threading environment stack to take care of it, and it will just do the right thing.

If we have a global stack, we would need to lock it while performing the FFT to ensure that nobody else modifies the top of the stack. This would prevent parallel calls into the FFT functions. TF has a model-serving mode that runs multiple models in parallel (or instances of a single model on different inputs), so we could get collisions here. In that case, I don't see the advantage of a stack over a single global instance, unless you also pass an identifier into FFT functions like c2c that tells them which member of the stack to use.

A thread_local version might work. The caller would set the thread-local Distribution, call the FFT function, then reset the Distribution. The issue here is that the setting would not carry over to new threads spawned internally during FFT execution; those threads would see the default value. As long as the default is DUCC0_NO_THREADING, this would be alright, but I'm not sure about your existing usages. You might be okay because of the in_parallel_region variable. Other potential users would need to keep this in mind.
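The caveat about new threads can be seen in a few lines. This sketch uses a hypothetical `thread_local int current_distribution` as a stand-in for a thread-local Distribution, with 0 meaning the default. A hypothetical RAII `DistributionGuard` implements the set/call/reset pattern. A `std::thread` started while the guard is active still sees 0, because every thread gets its own freshly initialized copy of a `thread_local` variable.

```cpp
#include <cassert>
#include <thread>

// Hypothetical stand-in for a thread-local Distribution; 0 = default
// (e.g. the DUCC0_NO_THREADING behaviour).
thread_local int current_distribution = 0;

// RAII: set the distribution for the duration of one call, then restore it.
class DistributionGuard {
 public:
  explicit DistributionGuard(int d) : saved_(current_distribution) { current_distribution = d; }
  ~DistributionGuard() { current_distribution = saved_; }
  DistributionGuard(const DistributionGuard&) = delete;
  DistributionGuard& operator=(const DistributionGuard&) = delete;

 private:
  int saved_;
};

inline int distribution_seen_by_caller() { return current_distribution; }

// What a worker thread spawned inside the FFT would observe: its own copy
// of the thread_local, initialized to the default, not the caller's value.
inline int distribution_seen_by_spawned_thread() {
  int seen = -1;
  std::thread t([&seen] { seen = current_distribution; });
  t.join();
  return seen;
}
```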

One final alternative I thought of last night would be to take a more object-oriented approach and put the FFT functions into a class with a virtual execParallel function that can be overridden. You could restore your original public API by using a default instance that forwards calls to ducc0::execParallel.
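Below is a minimal sketch of the object-oriented variant, assuming a single virtual hook. `FftEngine`, `ChunkingEngine` and `for_each_line` are hypothetical names, and `for_each_line` stands in for an actual FFT entry point such as c2c. The free function keeps a plain functional API by forwarding to a default instance. In ducc0, that default hook would presumably forward to ducc0::execParallel rather than run serially as it does here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical class-based API (not ducc0's interface): entry points are
// members, and every parallel region goes through one virtual hook.
class FftEngine {
 public:
  virtual ~FftEngine() = default;

  // Stand-in for an ND transform: applies op to each of nlines independent lines.
  void for_each_line(std::size_t nlines, const std::function<void(std::size_t)>& op) {
    execParallel(nlines, [&](std::size_t lo, std::size_t hi) {
      for (std::size_t i = lo; i < hi; ++i) op(i);
    });
  }

 protected:
  // Default hook: a single range on the calling thread.
  virtual void execParallel(std::size_t nwork,
                            const std::function<void(std::size_t, std::size_t)>& f) {
    f(0, nwork);
  }
};

// A TF/Eigen integration would override the hook; this one splits the work
// into chunks of two (run sequentially) and counts them.
class ChunkingEngine : public FftEngine {
 public:
  std::size_t chunks = 0;

 protected:
  void execParallel(std::size_t nwork,
                    const std::function<void(std::size_t, std::size_t)>& f) override {
    for (std::size_t lo = 0; lo < nwork; lo += 2, ++chunks) f(lo, std::min(lo + 2, nwork));
  }
};

// Free function preserving a functional API via a default instance.
inline void for_each_line(std::size_t nlines, const std::function<void(std::size_t)>& op) {
  static FftEngine default_engine;
  default_engine.for_each_line(nlines, op);
}
```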

mreineck commented 1 year ago

Superseded by #9, closing...