rapidsai / cuml

cuML - RAPIDS Machine Learning Library
https://docs.rapids.ai/api/cuml/stable/
Apache License 2.0

[DOC] More guidance about parallel usage for many small problems on single GPU #3588

Open frankier opened 3 years ago

frankier commented 3 years ago

Report needed documentation

Firstly, thanks for the project! I'm really looking forward to being able to combine these classical ML approaches with deep learning all on the GPU.

There should be some kind of example of solving multiple small problems, e.g. with LogisticRegression, concurrently from Python. Currently it's not clear how best to do this, or to what extent it is even supported. As an example, say I have 1000 such problems and want 8 solvers running concurrently at any one time. I think this would be a reasonable application for cuML, since running them serially is unlikely to use the full capacity of a large GPU.

A seemingly straightforward approach is to enqueue work on multiple CUDA streams, which the developer guide suggests is possible internally to cuML (https://github.com/rapidsai/cuml/blob/branch-0.19/wiki/python/DEVELOPER_GUIDE.md#asynchronous-operations-and-stream-ordering) -- however, it seems that LogisticRegression.fit(...) blocks until the GPU work is finished, so this won't work. For it to work we would need an API like fit_queue()/fit_sync() that lets us queue up our work and sync everything at the end. This would be nice if possible, because then we would only need to worry about CUDA concurrency rather than CUDA concurrency plus Python concurrency.
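
To be concrete, here is a purely hypothetical sketch of what I mean -- fit_queue()/fit_sync() do not exist in cuML, and the names and semantics are only a suggestion:

import cupy as cp
from cuml.linear_model import LogisticRegression

# 1000 small (x, y) problems already resident on the GPU
problems = [(cp.random.rand(4, 2, dtype=cp.float32),
             (cp.random.rand(4) > 0.5).astype(cp.float32))
            for _ in range(1000)]

models = []
for x, y in problems:
    model = LogisticRegression()
    model.fit_queue(x, y)   # hypothetical: only enqueue work on the model's stream
    models.append(model)

for model in models:
    model.fit_sync()        # hypothetical: block until this model's queued fit finishes
    print(model.coef_, model.intercept_)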

The other approach, which already seems like it should be possible, is to use multithreading or multiprocessing and queue work up on separate CUDA streams.

Here's what I've got so far.

import cupy
import torch
from cuml.linear_model import LogisticRegression
from cuml.raft.common import Handle, Stream
from joblib import Parallel, delayed

def logit(x, y):
    stream = Stream()
    handle = Handle()
    handle.setStream(stream)
    model = LogisticRegression(handle=handle)
    model = model.fit(cupy.asarray(x), cupy.asarray(y))
    del handle
    del stream
    return model.coef_, model.intercept_

all_delayed = []
for _ in range(100):
    x = torch.rand((4, 2), dtype=torch.float).cuda()
    y = torch.rand((4,), dtype=torch.float).cuda()

    all_delayed.append(delayed(logit)(x, y))

with Parallel(n_jobs=16, backend="threading") as parallel:
    for result in parallel(all_delayed):
        print("result", result)

My questions are:

1. Is it ok/best to use threading? Would there be a problem using multiprocessing (the loky backend from joblib) instead? Is it supported, even if it's not necessary?
2. Is it necessary to create a new handle manually as in this example, or do we automatically get a new stream per thread? (One possible per-thread reuse pattern is sketched below.)
3. Is it necessary to manually clean up handles, as is done here, to avoid them overlapping with later calls in the worker threads?
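
For question 2, one pattern I have in mind (a sketch only, assuming the same Handle/Stream API as in the snippet above) is to create one handle per worker thread via threading.local and reuse it, rather than creating and deleting a handle on every call:

import threading

import cupy
from cuml.linear_model import LogisticRegression
from cuml.raft.common import Handle, Stream

# One handle (with its own stream) per worker thread, created lazily and reused
_tls = threading.local()

def get_handle():
    if not hasattr(_tls, "handle"):
        stream = Stream()
        handle = Handle()
        handle.setStream(stream)
        _tls.stream = stream   # keep the stream alive for as long as the handle
        _tls.handle = handle
    return _tls.handle

def logit(x, y):
    model = LogisticRegression(handle=get_handle())
    model.fit(cupy.asarray(x), cupy.asarray(y))
    return model.coef_, model.intercept_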

The documentation could give the most minimal possible example code, explain why it was done that way, and outline any pitfalls or unsupported ways of doing parallelism. Some guidance on how to pick the number of workers/streams per GPU could also be useful.

frankier commented 3 years ago

Looks like someone was trying to do this before here: https://github.com/rapidsai/cuml/issues/1320

frankier commented 3 years ago

I've tried a few things now that https://github.com/rapidsai/cuml/issues/3587 has been fixed. The first thing I tried was as above: threading using joblib:

import torch
from cuml.linear_model import LogisticRegression
from cuml.raft.common import Handle, Stream
from joblib import Parallel, delayed

def logit(x, y):
    with open("loggy", "a") as loggy:
        print("start", x, y, file=loggy, flush=True)
    stream = Stream()
    handle = Handle()
    handle.setStream(stream)
    model = LogisticRegression(handle=handle)
    model = model.fit(x, y)
    with open("loggy", "a") as loggy:
        print("pre del", file=loggy, flush=True)
    del stream
    del handle
    with open("loggy", "a") as loggy:
        print("done", model.coef_, model.intercept_, file=loggy, flush=True)
    return model.coef_, model.intercept_

cuda_dev = torch.device("cuda")
all_delayed = []
for idx in range(1024):
    print("idx", idx, flush=True)
    x = torch.rand(4, 2, device=cuda_dev, dtype=torch.float)
    y = torch.rand(4, device=cuda_dev, dtype=torch.float)

    all_delayed.append(delayed(logit)(x, y))

print("starting", flush=True)
with Parallel(n_jobs=16, backend="threading") as parallel:
    print("pool started", flush=True)
    for result in parallel(all_delayed):
        print("result", result, flush=True)

In this case after waiting several minutes I get the following on stdout:

idx 0
...
idx 1023
starting
pool started

And the contents of loggy typically look something like this:

$ grep -c done loggy
8
$ grep -c start loggy
68

So it appears this approach gets livelocked/deadlocked somehow (high CPU usage on one core, zero GPU usage).

Here is the view from py-spy:

  0.00% 1800.00%   0.000s     3531s   _bootstrap (threading.py:890)
  0.00% 1800.00%   0.000s     3531s   _bootstrap_inner (threading.py:932)
  0.00% 1800.00%   0.000s     3531s   run (threading.py:870)
  0.00% 1600.00%   0.000s     3139s   <listcomp> (joblib/parallel.py:262)
  0.00% 1600.00%   0.000s     3139s   worker (multiprocessing/pool.py:125)
  0.00% 1600.00%   0.000s     3139s   __call__ (joblib/_parallel_backends.py:595)
  0.00% 1600.00%   0.000s     3139s   __call__ (joblib/parallel.py:262)
200.00% 1300.00%   392.3s     2550s   inner_with_setters (cuml/internals/api_decorators.py:409)
  0.00% 1300.00%   0.000s     2550s   logit (testy.py:14)
  0.00% 300.00%   0.000s    588.5s   full (cuml/common/array.py:328)
300.00% 300.00%   588.5s    588.5s   full (cupy/_creation/basic.py:272)
  0.00% 300.00%   0.000s    588.5s   ones (cuml/common/array.py:360)
  0.00% 300.00%   0.000s    588.5s   inner (cuml/internals/api_decorators.py:360)
  0.00% 300.00%   0.000s    588.5s   cupy_rmm_wrapper (cuml/common/memory_utils.py:93)
200.00% 200.00%   392.3s    392.3s   unique (cupy/_manipulation/add_remove.py:106)
100.00% 200.00%   196.2s    392.3s   array (cupy/_creation/from_data.py:41)
200.00% 200.00%   392.3s    392.3s   rmm_cupy_allocator (rmm/rmm.py:198)
200.00% 200.00%   392.3s    392.3s   logit (testy.py:15)
  0.00% 200.00%   0.000s    392.3s   input_to_cuml_array (cuml/common/input_utils.py:341)
100.00% 100.00%   196.2s    196.2s   unique (cupy/_manipulation/add_remove.py:116)
100.00% 100.00%   196.2s    196.2s   logit (testy.py:8)
  0.00% 100.00%   0.000s    196.2s   unique (cupy/_manipulation/add_remove.py:112)
  0.00% 100.00%   0.000s    196.2s   synchronize (numba/cuda/cudadrv/driver.py:1858)
  0.00% 100.00%   0.000s    196.2s   as_cuda_array (numba/cuda/api.py:75)
100.00% 100.00%   196.2s    196.2s   _handle_results (multiprocessing/pool.py:576)
  0.00% 100.00%   0.000s    196.2s   from_cuda_array_interface (numba/cuda/api.py:53)
100.00% 100.00%   196.2s    196.2s   unique (cupy/_manipulation/add_remove.py:118)
  0.00% 100.00%   0.000s    196.2s   _require_cuda_context (numba/cuda/cudadrv/devices.py:224)
100.00% 100.00%   196.2s    196.2s   safe_cuda_api_call (numba/cuda/cudadrv/driver.py:299)
  0.00% 100.00%   0.000s    196.2s   convert_dtype (cuml/common/input_utils.py:550)
100.00% 100.00%   196.2s    196.2s   _handle_tasks (multiprocessing/pool.py:528)
  0.00% 100.00%   0.000s    196.2s   input_to_cuml_array (cuml/common/input_utils.py:296)

It is stuck in LogisticRegression.fit(...), doing dtype conversion?

The second approach I tried was torch.multiprocessing, following the guidance in the PyTorch docs on safely sharing CUDA tensors. Here is that example:

import torch
from cuml.linear_model import LogisticRegression
from torch import multiprocessing
from torch.cuda import empty_cache

def usage():
    import resource

    r = resource.getrusage(resource.RUSAGE_SELF)
    return f"ru_maxrss: {r.ru_maxrss}\tru_ixrss: {r.ru_ixrss}\tru_idrss: {r.ru_idrss}"

def logit(xy):
    x, y = xy
    with open("loggy3", "a") as loggy:
        print("start", x, y, usage(), file=loggy, flush=True)
    model = LogisticRegression()
    with open("loggy3", "a") as loggy:
        print("constructed logit", usage(), file=loggy, flush=True)
    model = model.fit(x, y)
    with open("loggy3", "a") as loggy:
        print("pre del", usage(), file=loggy, flush=True)
    del xy
    del x
    del y
    empty_cache()
    with open("loggy3", "a") as loggy:
        print("done", usage(), model.coef_, model.intercept_, file=loggy, flush=True)
    return model.coef_, model.intercept_

def iter_examples():
    cuda_dev = torch.device("cuda")
    for idx in range(1024):
        print("idx", idx, flush=True)
        x = torch.rand(4, 2, device=cuda_dev, dtype=torch.float)
        y = torch.rand(4, device=cuda_dev, dtype=torch.float)
        yield x, y

def main():
    multiprocessing.set_start_method("spawn", force=True)
    print("start", usage())
    with multiprocessing.Pool(processes=4) as pool:
        result = pool.imap_unordered(logit, iter_examples())
        print("after imap", usage())
        for res in result:
            print("result", res, usage(), flush=True)

if __name__ == "__main__":
    main()

This gets some results printed, but eventually halts (with 100% CPU for each worker and 100% GPU usage).

$ grep -c start loggy3
28
$ grep -c done loggy3
24

Attaching py-spy to one of the worker processes gives:

  %Own   %Total  OwnTime  TotalTime  Function (filename:line)                                                                                                                                                      
100.00% 100.00%   51.98s    52.00s   inner_with_setters (cuml/internals/api_decorators.py:409)
  0.00%   0.00%   0.010s    0.010s   array (cupy/_creation/from_data.py:41)
  0.00%   0.00%   0.010s    0.010s   pop_all (cuml/internals/api_context_managers.py:132)
  0.00%   0.00%   0.000s    0.010s   __enter__ (cuml/internals/api_context_managers.py:239)
  0.00% 100.00%   0.000s    52.00s   spawn_main (multiprocessing/spawn.py:116)
  0.00%   0.00%   0.000s    0.010s   inner (cuml/internals/api_decorators.py:360)
  0.00% 100.00%   0.000s    52.00s   <module> (<string>:1)
  0.00% 100.00%   0.000s    52.00s   run (multiprocessing/process.py:108)
  0.00% 100.00%   0.000s    52.00s   _bootstrap (multiprocessing/process.py:315)
  0.00% 100.00%   0.000s    52.00s   worker (multiprocessing/pool.py:125)
  0.00%   0.00%   0.000s    0.010s   inner (cuml/internals/api_decorators.py:359)
  0.00%   0.00%   0.000s    0.010s   input_to_cuml_array (cuml/common/input_utils.py:341)
  0.00% 100.00%   0.000s    52.00s   _main (multiprocessing/spawn.py:129)
  0.00% 100.00%   0.000s    52.00s   logit (testy3.py:21)

I think this mainly tells us we're stuck inside non-Python code.

I'm not sure whether either of these is a bug, because it's not clear what is supported, hence I'm adding them here.

frankier commented 3 years ago

Relevant previous issue: https://github.com/rapidsai/cuml/issues/3237

BTW, the above results are both on the 0.19 nightly.

JohnZed commented 3 years ago

Thanks for the detailed issue, @frankier! A couple of things here:

import cupy as cp
from cuml.linear_model import LogisticRegression
from cuml.raft.common import Handle, Stream
from joblib import Parallel, delayed
import cuml.datasets
import time

def logit(x, y):
    with open("loggy", "a") as loggy:
        print("start", x, y, file=loggy, flush=True)
    stream = Stream()
    handle = Handle()
    # handle.setStream(stream)

    model = LogisticRegression(handle=handle)
    model = model.fit(x, y)
    with open("loggy", "a") as loggy:
        print("pre del", file=loggy, flush=True)
    del stream
    del handle
    with open("loggy", "a") as loggy:
        print("done", model.coef_, model.intercept_, file=loggy, flush=True)
    return model.coef_, model.intercept_

all_delayed = []
n_est = 1024
n_rows = 100
n_cols = 4

for idx in range(n_est):
    # print("idx", idx, flush=True)
    # x = cp.random.rand(4,2, dtype=cp.float32)
    # y = cp.random.rand(4, dtype=cp.float32)
    x, y = cuml.datasets.make_classification(n_rows, n_cols)
    y = y.astype(cp.float32)
    all_delayed.append(delayed(logit)(x, y))

print("starting", flush=True)
t0 = time.time()
with Parallel(n_jobs=16, backend="threading") as parallel:
    print("pool started", flush=True)
    for result in parallel(all_delayed):
        print("result", result, flush=True)

elapsed = time.time() - t0
print(f"Build {n_est} estimators with {n_rows} in {elapsed} s")

So please stay tuned for future updates on using streams! Unfortunately, we don't have a timeline planned in detail yet, but it's definitely a topic of interest.

frankier commented 3 years ago

Thanks for the help. I managed to get both the threading and multiprocessing approaches working and have made a small benchmark based on the type of workload I am trying to run: https://github.com/frankier/batchlogit . The benchmark may have some problems, e.g. it is also benchmarking make_classification at the moment, but I think it still gives some indication of relative performance. What it appears to show is this:

github-actions[bot] commented 3 years ago

This issue has been labeled inactive-30d due to no recent activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed. This issue will be labeled inactive-90d if there is no activity in the next 60 days.

github-actions[bot] commented 2 years ago

This issue has been labeled inactive-90d due to no recent activity in the past 90 days. Please close this issue if no further response or action is needed. Otherwise, please respond with a comment indicating any updates or changes to the original issue and/or confirm this issue still needs to be addressed.