tanaylab / metacells

Metacells - Single-cell RNA Sequencing Analysis
MIT License

Thread explosion #50

Open sophie-xhonneux opened 1 year ago

sophie-xhonneux commented 1 year ago

On powerful machines it is easy to run into an error like the following (see also issue #24):

...
OpenBLAS blas_thread_init: pthread_create failed for thread 20 of 56: Resource temporarily unavailable
OpenBLAS blas_thread_init: RLIMIT_NPROC 4096 current, 2062711 max
OpenBLAS blas_thread_init: pthread_create failed for thread 21 of 56: Resource temporarily unavailable
OpenBLAS blas_thread_init: RLIMIT_NPROC 4096 current, 2062711 max
OpenBLAS blas_thread_init: pthread_create failed for thread 22 of 56: Resource temporarily unavailable
OpenBLAS blas_thread_init: RLIMIT_NPROC 4096 current, 2062711 max
...
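For context, OpenBLAS sizes its thread pool from the environment when the library is first loaded, so a minimal standalone mitigation is to cap the pools before numpy is imported. This is only a sketch, independent of the fix below; the variable names are the standard OpenBLAS/OpenMP/MKL environment knobs, not metacells settings:

import os

# Must run before numpy/scipy (and thus OpenBLAS) are imported.
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"

import numpy as np  # noqa: E402  # picks up the limits above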

My fix was to rewrite metacells/utilities/parallel.py as below. Importantly, this was actually faster on a beefy HPC cluster node (50 parallel piles, 32 CPUs, 220GB of RAM). I therefore wanted to ask if we could at least get a flag so that metacells runs code like the version below?

The changes below are the lines commented out in set_processors_count() and _invocation():

import ctypes
import os
import sys
from math import ceil
from multiprocessing import Value
from multiprocessing import get_context
from threading import current_thread
from typing import Any
from typing import Callable
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar

import psutil  # type: ignore
from threadpoolctl import threadpool_limits  # type: ignore

import metacells.utilities.documentation as utd
import metacells.utilities.logging as utl
import metacells.utilities.progress as utp
import metacells.utilities.timing as utm

if "sphinx" not in sys.argv[0]:
    import metacells.extensions as xt  # type: ignore

__all__ = [
    "is_main_process",
    "set_processors_count",
    "get_processors_count",
    "parallel_map",
]

PROCESSORS_COUNT = 0

MAIN_PROCESS_PID = os.getpid()

IS_MAIN_PROCESS: Optional[bool] = True

MAP_INDEX = 0
PROCESS_INDEX = 0

PROCESSES_COUNT = 0
NEXT_PROCESS_INDEX = Value(ctypes.c_int32, lock=True)
PARALLEL_FUNCTION: Optional[Callable[[int], Any]] = None

def is_main_process() -> bool:
    """
    Return whether this is the main process, as opposed to a sub-process spawned by
    :py:func:`parallel_map`.
    """
    return bool(IS_MAIN_PROCESS)

def set_processors_count(processors: int) -> None:
    """
    Set the (maximal) number of processors to use in parallel.

    The default value of ``0`` means using all the available physical processors. Note that if
    hyper-threading is enabled, this would be less than (typically half of) the number of logical
    processors in the system. This is intentional, as there's no value - actually, negative
    value - in running multiple heavy computations on hyper-threads of the same physical processor.

    Otherwise, the value is the actual (positive) number of processors to use. Override this by
    setting the ``METACELLS_PROCESSORS_COUNT`` environment variable or by invoking this function
    from the main thread.
    """
    assert IS_MAIN_PROCESS

    if processors == 0:
        processors = psutil.cpu_count(logical=False)

    assert processors > 0

    global PROCESSORS_COUNT
    PROCESSORS_COUNT = processors

    #threadpool_limits(limits=PROCESSORS_COUNT)
    #xt.set_threads_count(PROCESSORS_COUNT)
    #os.environ["OMP_NUM_THREADS"] = str(PROCESSORS_COUNT)
    #os.environ["MKL_NUM_THREADS"] = str(PROCESSORS_COUNT)

if "sphinx" not in sys.argv[0]:
    set_processors_count(int(os.environ.get("METACELLS_PROCESSORS_COUNT", "0")))

def get_processors_count() -> int:
    """
    Return the number of processors we are allowed to use.
    """
    assert PROCESSORS_COUNT > 0
    return PROCESSORS_COUNT

T = TypeVar("T")

@utd.expand_doc()
def parallel_map(
    function: Callable[[int], T],
    invocations: int,
    *,
    max_processors: int = 0,
    hide_from_progress_bar: bool = False,
) -> List[T]:
    """
    Execute ``function``, in parallel, ``invocations`` times. Each invocation is given the invocation's index as its
    single argument.

    For our simple pipelines, only the main process is allowed to execute functions in parallel processes, that is, we
    do not support nested ``parallel_map`` calls.

    This uses :py:func:`get_processors_count` processes. If ``max_processors`` (default: {max_processors}) is zero, use
    all available processors. Otherwise, further reduces the number of processes used to at most the specified value.

    If this ends up using a single process, runs the function serially. Otherwise, fork new processes to execute the
    function invocations (using ``multiprocessing.get_context('fork').Pool.map``).

    The downside is that this is slow, and you need to set up **mutable** shared memory (e.g. for large results) in
    advance. The upside is that each of these processes starts with a shared memory copy(-on-write) of the full Python
    state, that is, all the inputs for the function are available "for free".

    If a progress bar is active at the time of invoking ``parallel_map``, and ``hide_from_progress_bar`` is not set,
    then it is assumed the parallel map will cover all of the current (slice of the) progress bar, and progress is
    reported into it in increments of ``1/invocations``.

    .. todo::

        It is currently only possible to invoke :py:func:`parallel_map` from the main application thread (that is, it
        does not nest).
    """
    if invocations == 0:
        return []

    assert function.__is_timed__  # type: ignore

    global IS_MAIN_PROCESS
    assert IS_MAIN_PROCESS

    global PROCESSES_COUNT
    PROCESSES_COUNT = min(PROCESSORS_COUNT, invocations)
    if max_processors != 0:
        assert max_processors > 0
        PROCESSES_COUNT = min(PROCESSES_COUNT, max_processors)

    if PROCESSES_COUNT == 1:
        return [function(index) for index in range(invocations)]

    NEXT_PROCESS_INDEX.value = 0  # type: ignore

    global PARALLEL_FUNCTION
    assert PARALLEL_FUNCTION is None

    global MAP_INDEX
    MAP_INDEX += 1

    # Since PROCESSES_COUNT <= invocations, this ratio rounds up to exactly 1:
    # each forked worker is limited to a single BLAS/OpenMP thread.
    num_threads = str(ceil(PROCESSES_COUNT / invocations))
    os.environ["OMP_NUM_THREADS"] = num_threads
    os.environ["MKL_NUM_THREADS"] = num_threads

    PARALLEL_FUNCTION = function
    IS_MAIN_PROCESS = None
    try:
        results: List[Optional[T]] = [None] * invocations
        utm.flush_timing()
        with utm.timed_step("parallel_map"):
            utm.timed_parameters(index=MAP_INDEX, processes=PROCESSES_COUNT)
            with get_context("fork").Pool(PROCESSES_COUNT) as pool:
                for index, result in pool.imap_unordered(_invocation, range(invocations)):
                    if utp.has_progress_bar() and not hide_from_progress_bar:
                        utp.did_progress(1 / invocations)
                    results[index] = result
        return results  # type: ignore
    finally:
        IS_MAIN_PROCESS = True
        PARALLEL_FUNCTION = None
        os.environ["OMP_NUM_THREADS"] = str(PROCESSES_COUNT)
        os.environ["MKL_NUM_THREADS"] = str(PROCESSES_COUNT)

def _invocation(index: int) -> Tuple[int, Any]:
    global IS_MAIN_PROCESS
    # if IS_MAIN_PROCESS is None:
    #     IS_MAIN_PROCESS = os.getpid() == MAIN_PROCESS_PID
    #     assert not IS_MAIN_PROCESS

    #     global PROCESS_INDEX
    #     with NEXT_PROCESS_INDEX:
    #         PROCESS_INDEX = NEXT_PROCESS_INDEX.value  # type: ignore
    #         NEXT_PROCESS_INDEX.value += 1  # type: ignore

    #     current_thread().name = f"#{MAP_INDEX}.{PROCESS_INDEX}"
    #     utm.in_parallel_map(MAP_INDEX, PROCESS_INDEX)

    #     global PROCESSORS_COUNT
    #     start_processor_index = int(round(PROCESSORS_COUNT * PROCESS_INDEX / PROCESSES_COUNT))
    #     stop_processor_index = int(round(PROCESSORS_COUNT * (PROCESS_INDEX + 1) / PROCESSES_COUNT))
    #     PROCESSORS_COUNT = stop_processor_index - start_processor_index

    #     assert PROCESSORS_COUNT > 0
    #     utl.logger().debug("PROCESSORS: %s", PROCESSORS_COUNT)
    #     threadpool_limits(limits=PROCESSORS_COUNT)
    #     xt.set_threads_count(PROCESSORS_COUNT)
    #     os.environ["OMP_NUM_THREADS"] = str(PROCESSORS_COUNT)
    #     os.environ["MKL_NUM_THREADS"] = str(PROCESSORS_COUNT)

    assert PARALLEL_FUNCTION is not None
    result = PARALLEL_FUNCTION(index)
    return index, result
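To make concrete what the commented-out partitioning in _invocation originally did, here is its fair-share arithmetic pulled out as a standalone sketch (the name fair_share is hypothetical):

def fair_share(total_processors: int, process_index: int, processes_count: int) -> int:
    """How many processors worker `process_index` of `processes_count` gets."""
    start = int(round(total_processors * process_index / processes_count))
    stop = int(round(total_processors * (process_index + 1) / processes_count))
    return stop - start

# E.g. 32 physical processors split across 32 worker processes -> 1 each:
print([fair_share(32, index, 32) for index in range(32)])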
orenbenkiki commented 1 year ago

TL;DR: the last thing I'd expect is that commenting out the code you did would prevent an explosion in the number of used threads. That is, I can understand that my code failed to prevent a thread explosion or failures, given the vagaries of the OpenBLAS / OpenMP implementations. But I can't see how this code, which only reduces the number of used threads, causes an increase in the number of used threads.

Gory details:

The whole point of the commented-out lines is to reduce the amount of parallelism used in nested Python processes. That is, we start with the top-level Python process, and when we call a parallel loop, we fork worker sub-processes. Because we are forking, all the (typically very large) arrays are available "for free" (well, except for the cost of a fork, of course). However, anything we compute in the worker processes has to be copied back to the main process (the multiprocessing shared-memory arrays in Python were very flaky when I tried to use them to avoid this - a saga all its own).
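A self-contained sketch of that pattern (illustrative only, not the metacells API): arrays created before the fork are readable in every worker for free, while results travel back by pickling:

from multiprocessing import get_context

import numpy as np

BIG = np.arange(10_000_000)  # created before the fork: shared copy-on-write

def work(index: int) -> float:
    # Reading BIG costs nothing extra in a forked worker; only the returned
    # value is pickled and copied back to the main process.
    return float(BIG[index::100].sum())

if __name__ == "__main__":
    with get_context("fork").Pool(4) as pool:
        results = pool.map(work, range(8))
    print(results)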

This is clunky, but ended up being the least-bad approach I could find to get performance given the code is written in Python.

Naively, since each forked worker is a brand-new standalone process, if you use parallelism inside it (e.g. OpenMP or whatever), it will try to take over all the cores in the machine. So on a machine with N cores you will end up with N^2 active threads, also known as "having a very bad day". This seems to be what you are seeing.

(As a side note, N should be the number of physical rather than logical cores - you get no performance boost from the hyper-threads for dense computation code. In fact, you typically lose performance, because you are increasing cache pressure and memory footprint.)
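As a quick check of that distinction, psutil reports both counts; this is the same psutil call set_processors_count uses above:

import psutil

print("physical cores:", psutil.cpu_count(logical=False))
print("logical CPUs:  ", psutil.cpu_count(logical=True))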

At any rate, the intent of the code you commented out is to say "if I am in a worker process, restrict the amount of parallelism I am using to only my fair share of the machine". It also gives the worker processes nice names for logging etc. but that's a side benefit.

This is somewhat fragile because OpenMP really hates the idea of (1) having a process use parallelism in OpenMP; (2) calling fork (while not using OpenMP); (3) using OpenMP again in the child process. The problem is that OpenMP creates its worker thread pool, and when forking, the child doesn't have these threads, which makes sense. However, because we forked, the in-memory data structures OpenMP uses think the threads still exist, and hilarity ensues. A simple workaround would have been a way to reset OpenMP in the worker process, but OpenMP, in its infinite wisdom, does not provide such a function.
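One way to defuse this in practice is to re-limit the pools inside each forked child before any BLAS/OpenMP work runs. In the sketch below, threadpool_limits is the real threadpoolctl API; the rest is illustrative:

from multiprocessing import get_context

import numpy as np
from threadpoolctl import threadpool_limits

def child_work(index: int) -> float:
    # Cap BLAS/OpenMP threads in this child before touching dense math.
    with threadpool_limits(limits=1):
        a = np.random.rand(200, 200)
        return float((a @ a).sum())

if __name__ == "__main__":
    with get_context("fork").Pool(4) as pool:
        print(sum(pool.map(child_work, range(8))))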

Therefore, in my own C++ extension code I do not use OpenMP. I manually use C++ threads, and work around the problem by spinning up new threads for each parallel loop (yes, this is less efficient than it could be, but it seems to work OK in practice).
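A rough Python analog of that design: creating a fresh pool per parallel loop means no long-lived worker threads exist at fork time (the parallel_loop helper is hypothetical):

from concurrent.futures import ThreadPoolExecutor

def parallel_loop(work, items, threads=4):
    # A fresh pool per loop: threads are spun up here and torn down on exit,
    # so a later fork never inherits stale worker-thread bookkeeping.
    with ThreadPoolExecutor(max_workers=threads) as pool:
        return list(pool.map(work, items))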

So much for intent - I'm trying to figure out what is happening in practice in your case. Looking at https://groups.google.com/g/openblas-users/c/W6ehBvPsKTw I see that whether OpenBLAS uses OpenMP or pthreads "depends". Seeing issues like https://github.com/xianyi/OpenBLAS/issues/294 and https://github.com/xianyi/OpenBLAS/issues/240 it seems people are aware of the problem and that they have some workarounds that work sometimes? Hard to say.

To figure things out, one would need to dynamically track the tree of processes and threads being created, and understand who is creating all these threads (OpenMP? OpenBLAS? Someone else?), and what mysterious reason causes asking for fewer threads (in the commented-out code) to actually trigger the creation of more threads. I'd start with strace, tracking the creation of processes and threads, but that would probably also require looking at the source code of at least OpenBLAS and possibly OpenMP as well.
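Short of strace, a quick first look is possible from Python itself; psutil.Process, children, and num_threads are real psutil API:

import psutil

root = psutil.Process()  # run this from the main metacells process
for proc in [root] + root.children(recursive=True):
    print(proc.pid, proc.name(), proc.num_threads())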

Bottom line... I've no idea what is happening in your case. My takeaway from this is that using parallelism in Python is a losing proposition. Perhaps the "just write it in Julia and call it from a sequential Python or R wrapper" approach is the least insane option after all. Sigh.