ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0
33.04k stars 5.59k forks source link

[Feature] Avoid recompilation using JIT #23664

Open gao462 opened 2 years ago

gao462 commented 2 years ago

Search before asking

Description

Ray remote functions become extremely slow when they call just-in-time (Numba-compiled) functions. The time costs of my test code (attached below) are:

----------- ------------ -
  Implement      Seconds  
----------- ------------ -
    Default 0.1877801418 ✓
      Numba 0.0030972958 ✓
        Ray 0.0730540752 ✓
Numba + Ray 1.3743853569 ✓
----------- ------------ -

The number of CPUs is set with ray.init(num_cpus=4). "Default" is the raw Python implementation; "Numba" uses only JIT (compiled with a dummy input before timing starts); "Ray" parallelizes the function; and "Numba + Ray" parallelizes the JIT function. Numba works fine (extremely fast after compilation), and Ray also works fine (the time drops toward the ideal 1/4, with some overhead). However, Numba + Ray does not work as expected: it is extremely slow, even though it should be faster than Numba alone. I think the reason is that Numba compiles the JIT function again in every Ray process. I wonder whether it is possible to compile the JIT function only once, before the Ray processes start. (I have tried passing the JIT function as an object ref, but it made no difference.) I am aware of ramba (a library combining Ray and Numba), but it seems to be only a parallelized NumPy replacement.

Here is the test code I use. It samples multiple traces from a given HMM.

R"""
"""
#
import numpy as onp
import numpy.typing as onpt
import numba
import ray
import time
from typing import Union, Sequence, Tuple

#
# NumPy array aliases used in the annotations below.
# `onp.bool_` replaces the `onp.bool8` alias: NumPy 1.24 deprecated it and
# NumPy 2.0 removed it, so accessing it there fails at import time.
NPY_FLOATS = onpt.NDArray[onp.float64]
NPY_INTS = onpt.NDArray[onp.int64]
NPY_BOOLS = onpt.NDArray[onp.bool_]
NPY_GENERICS = Union[NPY_FLOATS, NPY_INTS, NPY_BOOLS]
# A group of arrays, and a sequence of such groups.
NPY_GROUP = Sequence[NPY_GENERICS]
NPY_MEMORY = Sequence[NPY_GROUP]

def multi_sample_hmm(
    cdf_initials: NPY_FLOATS, cdf_transitions: NPY_FLOATS,
    cdf_emissions: NPY_FLOATS, trace_num_times: Sequence[int], seed: int,
    /,
) -> Tuple[Sequence[NPY_INTS], Sequence[NPY_INTS]]:
    R"""
    Sample several HMM-like traces sequentially with the pure-Python sampler.

    Trace ``k`` has length ``trace_num_times[k]`` and is drawn with seed
    ``seed + k``.

    Returns ``(trace_hiddens, trace_outputs)``: two lists holding, per trace,
    the hidden-state array and the output array, in input order.
    """
    samples = [
        a_sample_hmm_from_cdf(
            cdf_initials, cdf_transitions, cdf_emissions, length,
            seed + offset,
        )
        for (offset, length) in enumerate(trace_num_times)
    ]
    hidden_traces = [hidden for (hidden, _) in samples]
    output_traces = [output for (_, output) in samples]
    return (hidden_traces, output_traces)

def a_sample_int_from_cdf(cdf: NPY_FLOATS, /) -> int:
    R"""
    Sample an integer from cumulative distribution.

    Draws ``u ~ Uniform[0, 1)`` from NumPy's global random state and returns
    the smallest index ``i`` with ``u <= cdf[i]``.

    If no entry satisfies this, the last index is returned instead of failing
    on an empty reduction. This happens when a floating-point ``cumsum`` of
    normalized probabilities ends slightly below 1.0 and ``u`` lands above it.

    Raises ``ValueError`` if ``cdf`` is empty.
    """
    if len(cdf) == 0:
        raise ValueError("cannot sample from an empty cumulative distribution")
    unif = onp.random.uniform(0.0, 1.0)
    (hits,) = onp.where(unif <= cdf)
    if hits.size == 0:
        # Round-off left the total mass just below `unif`; the draw belongs
        # to the final bucket.
        return len(cdf) - 1
    return onp.min(hits) # type: ignore[no-any-return]

def a_sample_hmm_from_cdf(
    cdf_initials: NPY_FLOATS, cdf_transitions: NPY_FLOATS,
    cdf_emissions: NPY_FLOATS, num_times: int, seed: int,
    /,
) -> Tuple[NPY_INTS, NPY_INTS]:
    R"""
    Sample a random trace of HMM-like from HMM cumulative distributions.

    ``cdf_initials`` is the cumulative distribution of the first hidden state.
    ``cdf_transitions[s]`` and ``cdf_emissions[s]`` are the cumulative
    distributions of the next state and of the emitted output in state ``s``.
    ``num_times`` is the trace length, and ``seed`` reseeds the global
    ``onp.random`` state before sampling.

    Returns ``(hiddens, outputs)``, two int64 arrays of length ``num_times``.
    A length of 0 yields two empty arrays. A negative length makes
    ``onp.zeros`` raise ``ValueError``.
    """
    # Control randomness.
    # Avoid using `RandomState` for numba support.
    onp.random.seed(seed)

    #
    hiddens = onp.zeros((num_times,), dtype=onp.int64)
    outputs = onp.zeros((num_times,), dtype=onp.int64)
    if num_times == 0:
        # Nothing to sample; writing step 0 below would index past the end.
        return (hiddens, outputs)

    # Step 0 draws its state from the initial distribution...
    state = a_sample_int_from_cdf(cdf_initials)
    hiddens[0] = state
    outputs[0] = a_sample_int_from_cdf(cdf_emissions[state])

    # ...every later step draws from the previous state's transition row.
    for t in range(1, num_times):
        state = a_sample_int_from_cdf(cdf_transitions[state])
        hiddens[t] = state
        outputs[t] = a_sample_int_from_cdf(cdf_emissions[state])
    return (hiddens, outputs)

def multi_sample_hmm_njit(
    cdf_initials: NPY_FLOATS, cdf_transitions: NPY_FLOATS,
    cdf_emissions: NPY_FLOATS, trace_num_times: Sequence[int], seed: int,
    /,
) -> Tuple[Sequence[NPY_INTS], Sequence[NPY_INTS]]:
    R"""
    Sample several HMM-like traces sequentially with the Numba sampler.

    The loop over traces runs in Python; each trace is produced by one call
    to ``a_sample_hmm_from_cdf_njit`` with length ``trace_num_times[k]`` and
    seed ``seed + k``.

    Returns ``(trace_hiddens, trace_outputs)``: two lists holding, per trace,
    the hidden-state array and the output array, in input order.
    """
    samples = [
        a_sample_hmm_from_cdf_njit(
            cdf_initials, cdf_transitions, cdf_emissions, length,
            seed + offset,
        )
        for (offset, length) in enumerate(trace_num_times)
    ]
    hidden_traces = [hidden for (hidden, _) in samples]
    output_traces = [output for (_, output) in samples]
    return (hidden_traces, output_traces)

# `cache=True` makes Numba write the compiled code to an on-disk cache (by
# default the `__pycache__` directory next to this source file) and load it
# in later processes instead of recompiling. This targets the per-process
# recompilation suspected in Ray workers.
# NOTE(review): confirm the Ray workers actually hit the cache; they must see
# the same source file and cache directory.
@numba.njit(cache=True) # type: ignore[misc]
def a_sample_int_from_cdf_njit(cdf: NPY_FLOATS, /) -> int:
    R"""
    Sample an integer from cumulative distribution (Numba-compiled).

    Draws ``u ~ Uniform[0, 1)`` and returns the smallest index ``i`` with
    ``u <= cdf[i]``. Inside compiled code ``onp.random`` is Numba's own
    generator, seeded by ``onp.random.seed`` calls made from compiled code.

    If no entry satisfies this, the last index is returned instead of failing
    on an empty reduction. This happens when a floating-point ``cumsum`` of
    normalized probabilities ends slightly below 1.0 and ``u`` lands above it.

    Raises ``ValueError`` if ``cdf`` is empty.
    """
    if len(cdf) == 0:
        raise ValueError("cannot sample from an empty cumulative distribution")
    unif = onp.random.uniform(0.0, 1.0)
    (hits,) = onp.where(unif <= cdf)
    if hits.size == 0:
        # Round-off left the total mass just below `unif`; the draw belongs
        # to the final bucket.
        return len(cdf) - 1
    return onp.min(hits) # type: ignore[no-any-return]

# `cache=True` persists the compiled code on disk so that other processes
# (e.g. Ray workers) can load it instead of recompiling this function.
# NOTE(review): confirm cache hits on the Ray workers in the target setup.
@numba.njit(cache=True) # type: ignore[misc]
def a_sample_hmm_from_cdf_njit(
    cdf_initials: NPY_FLOATS, cdf_transitions: NPY_FLOATS,
    cdf_emissions: NPY_FLOATS, num_times: int, seed: int,
    /,
) -> Tuple[NPY_INTS, NPY_INTS]:
    R"""
    Sample a random trace of HMM-like from HMM cumulative distributions
    (Numba-compiled).

    ``cdf_initials`` is the cumulative distribution of the first hidden state.
    ``cdf_transitions[s]`` and ``cdf_emissions[s]`` are the cumulative
    distributions of the next state and of the emitted output in state ``s``.
    ``num_times`` is the trace length, and ``seed`` reseeds the random state
    used by compiled code.

    Returns ``(hiddens, outputs)``, two int64 arrays of length ``num_times``.
    A length of 0 yields two empty arrays.
    """
    # Control randomness.
    # Avoid using `RandomState` for numba support.
    onp.random.seed(seed)

    #
    hiddens = onp.zeros((num_times,), dtype=onp.int64)
    outputs = onp.zeros((num_times,), dtype=onp.int64)
    if num_times == 0:
        # Nothing to sample; writing step 0 below would index past the end.
        return (hiddens, outputs)

    # Step 0 draws its state from the initial distribution...
    state = a_sample_int_from_cdf_njit(cdf_initials)
    hiddens[0] = state
    outputs[0] = a_sample_int_from_cdf_njit(cdf_emissions[state])

    # ...every later step draws from the previous state's transition row.
    for t in range(1, num_times):
        state = a_sample_int_from_cdf_njit(cdf_transitions[state])
        hiddens[t] = state
        outputs[t] = a_sample_int_from_cdf_njit(cdf_emissions[state])
    return (hiddens, outputs)

@ray.remote
def multi_sample_hmm_ray_proc(
    cdf_initials: NPY_FLOATS, cdf_transitions: NPY_FLOATS,
    cdf_emissions: NPY_FLOATS, num_times: int, seed: int,
    /,
) -> Tuple[NPY_INTS, NPY_INTS]:
    R"""
    Ray task that samples one HMM-like trace with the pure-Python sampler.

    All arguments are forwarded unchanged to ``a_sample_hmm_from_cdf``, and
    its ``(hiddens, outputs)`` pair is returned.
    """
    sampled = a_sample_hmm_from_cdf(
        cdf_initials, cdf_transitions, cdf_emissions, num_times, seed,
    )
    hiddens, outputs = sampled
    return hiddens, outputs

def multi_sample_hmm_ray(
    oref_cdf_initials: NPY_FLOATS, oref_cdf_transitions: NPY_FLOATS,
    oref_cdf_emissions: NPY_FLOATS, trace_num_times: Sequence[int], seed: int,
    /,
) -> Tuple[Sequence[NPY_INTS], Sequence[NPY_INTS]]:
    R"""
    Sample several HMM-like traces in parallel, one Ray task per trace.

    The ``oref_*`` arguments are handed straight to each remote task; the
    caller in this file passes the results of ``ray.put``. Trace ``k`` uses
    length ``trace_num_times[k]`` and seed ``seed + k``.

    Blocks on ``ray.get`` until every task is done. Returns
    ``(trace_hiddens, trace_outputs)`` as lists in submission order.
    """
    submit = multi_sample_hmm_ray_proc.remote
    futures = [
        submit(
            oref_cdf_initials, oref_cdf_transitions, oref_cdf_emissions,
            length, seed + offset,
        )
        for (offset, length) in enumerate(trace_num_times)
    ]
    results = ray.get(futures)

    hidden_traces = [hidden for (hidden, _) in results]
    output_traces = [output for (_, output) in results]
    return (hidden_traces, output_traces)

@ray.remote
def multi_sample_hmm_njit_ray_proc(
    cdf_initials: NPY_FLOATS, cdf_transitions: NPY_FLOATS,
    cdf_emissions: NPY_FLOATS, num_times: int, seed: int,
    /,
) -> Tuple[NPY_INTS, NPY_INTS]:
    R"""
    Ray task that samples one HMM-like trace with the Numba sampler.

    All arguments are forwarded unchanged to ``a_sample_hmm_from_cdf_njit``,
    and its ``(hiddens, outputs)`` pair is returned.
    """
    sampled = a_sample_hmm_from_cdf_njit(
        cdf_initials, cdf_transitions, cdf_emissions, num_times, seed,
    )
    hiddens, outputs = sampled
    return hiddens, outputs

def multi_sample_hmm_njit_ray(
    oref_cdf_initials: NPY_FLOATS, oref_cdf_transitions: NPY_FLOATS,
    oref_cdf_emissions: NPY_FLOATS, trace_num_times: Sequence[int], seed: int,
    /,
) -> Tuple[Sequence[NPY_INTS], Sequence[NPY_INTS]]:
    R"""
    Sample several HMM-like traces in parallel, one Numba-backed Ray task per
    trace.

    The ``oref_*`` arguments are handed straight to each remote task; the
    caller in this file passes the results of ``ray.put``. Trace ``k`` uses
    length ``trace_num_times[k]`` and seed ``seed + k``.

    Blocks on ``ray.get`` until every task is done. Returns
    ``(trace_hiddens, trace_outputs)`` as lists in submission order.
    """
    submit = multi_sample_hmm_njit_ray_proc.remote
    futures = [
        submit(
            oref_cdf_initials, oref_cdf_transitions, oref_cdf_emissions,
            length, seed + offset,
        )
        for (offset, length) in enumerate(trace_num_times)
    ]
    results = ray.get(futures)

    hidden_traces = [hidden for (hidden, _) in results]
    output_traces = [output for (_, output) in results]
    return (hidden_traces, output_traces)

def _timed(func, /, *args):
    # Call `func(*args)` once and return `(result, elapsed_seconds)`.
    # `perf_counter` is monotonic and high-resolution, unlike `time.time`.
    time_start = time.perf_counter()
    result = func(*args)
    return (result, time.perf_counter() - time_start)


def _traces_match(
    trace_hiddens: Sequence[NPY_INTS], trace_outputs: Sequence[NPY_INTS],
    ref_hiddens: Sequence[NPY_INTS], ref_outputs: Sequence[NPY_INTS],
    /,
) -> bool:
    # True iff both trace lists have the reference lengths and every array
    # equals its reference in shape and values. Lengths are checked
    # explicitly because `zip` silently skips missing traces.
    # `array_equal` is used because `!=` would broadcast or fail on arrays
    # of different lengths.
    if len(trace_hiddens) != len(ref_hiddens):
        return False
    if len(trace_outputs) != len(ref_outputs):
        return False
    pairs = zip(
        list(trace_hiddens) + list(trace_outputs),
        list(ref_hiddens) + list(ref_outputs),
    )
    return all(onp.array_equal(array, ref) for (array, ref) in pairs)


def _print_report(rows: Sequence[Tuple[str, float, bool]], /) -> None:
    # Print the results table: right-aligned name, seconds (10 decimals), and
    # a ✓/✗ mark for whether the traces matched the reference.
    maxlen1 = 11
    maxlen2 = 12
    maxlen3 = 1
    row_format = "{:>{:d}s} {:>{:d}s} {:>{:d}s}"
    rule = ("-" * maxlen1, "-" * maxlen2, "-" * maxlen3)
    print(*rule)
    print(
        row_format.format("Implement", maxlen1, "Seconds", maxlen2, "", maxlen3),
    )
    print(*rule)
    for (name, seconds, accurate) in rows:
        print(
            row_format.format(
                name, maxlen1, "{:.10f}".format(seconds), maxlen2,
                "✓" if accurate else "✗", maxlen3,
            ),
        )
    print(*rule)


def test(
    num_hiddens: int, num_outputs: int, num_times_mean: int, num_samples: int,
    seed: int,
    /,
    *,
    num_cpus: int = 4,
) -> None:
    R"""
    Benchmark and cross-check the four HMM-like sampling implementations.

    Builds a random HMM with ``num_hiddens`` hidden states and
    ``num_outputs`` output symbols. Draws ``num_samples`` trace lengths from
    a normal distribution with mean ``num_times_mean`` and std 3.0, truncated
    to int and clamped to at least 1. Then times the Default, Numba, Ray and
    Numba + Ray samplers on the same input.

    Prints a table with each implementation's wall-clock time and whether its
    traces equal the Default (pure-Python) reference.

    ``num_cpus`` is forwarded to ``ray.init``. Ray is shut down even if a
    sampler raises.
    """
    # Random, row-normalized HMM parameters.
    nprng = onp.random.RandomState(seed)
    initials = nprng.uniform(0.0, 1.0, (num_hiddens,))
    transitions = nprng.uniform(0.0, 1.0, (num_hiddens, num_hiddens))
    emissions = nprng.uniform(0.0, 1.0, (num_hiddens, num_outputs))
    initials = initials / onp.sum(initials)
    transitions = transitions / onp.sum(transitions, axis=1, keepdims=True)
    emissions = emissions / onp.sum(emissions, axis=1, keepdims=True)
    # Clamp so a small mean cannot produce empty or negative trace lengths.
    trace_num_times = [
        max(1, int(nprng.normal(num_times_mean, 3.0)))
        for _ in range(num_samples)
    ]

    # Row-wise cumulative distributions consumed by the samplers.
    cdf_initials = onp.cumsum(initials)
    cdf_transitions = onp.cumsum(transitions, axis=1)
    cdf_emissions = onp.cumsum(emissions, axis=1)
    cdfs = (cdf_initials, cdf_transitions, cdf_emissions)

    # Compile first, so the Numba timing below excludes JIT compilation.
    # NOTE(review): this warms up only the driver process. Ray workers are
    # not warmed up here, so the Numba + Ray timing may include any
    # compilation the workers perform.
    multi_sample_hmm_njit(*cdfs, [num_times_mean], seed)

    # Allocate ray resources; they are released in the `finally` below.
    ray.init(num_cpus=num_cpus)
    try:
        orefs = (
            ray.put(cdf_initials), ray.put(cdf_transitions),
            ray.put(cdf_emissions),
        )
        (traces_def, time_def) = _timed(
            multi_sample_hmm, *cdfs, trace_num_times, seed,
        )
        (traces_njit, time_njit) = _timed(
            multi_sample_hmm_njit, *cdfs, trace_num_times, seed,
        )
        (traces_ray, time_ray) = _timed(
            multi_sample_hmm_ray, *orefs, trace_num_times, seed,
        )
        (traces_njit_ray, time_njit_ray) = _timed(
            multi_sample_hmm_njit_ray, *orefs, trace_num_times, seed,
        )
    finally:
        # Shutdown ray resources.
        ray.shutdown()

    # Compare every implementation against the pure-Python reference.
    (ref_hiddens, ref_outputs) = traces_def
    rows = [
        (
            name, seconds,
            _traces_match(hiddens, outputs, ref_hiddens, ref_outputs),
        )
        for (name, seconds, (hiddens, outputs)) in (
            ("Default", time_def, traces_def),
            ("Numba", time_njit, traces_njit),
            ("Ray", time_ray, traces_ray),
            ("Numba + Ray", time_njit_ray, traces_njit_ray),
        )
    ]
    _print_report(rows)

def main() -> None:
    R"""
    Script entry point: run the benchmark with a fixed configuration.

    Uses 3 hidden states, 4 output symbols, a mean trace length of 300,
    25 traces and seed 47.
    """
    num_hiddens, num_outputs = 3, 4
    num_times_mean, num_samples = 300, 25
    seed = 47
    test(num_hiddens, num_outputs, num_times_mean, num_samples, seed)

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not on import.
    main()

Use case

It would further improve the efficiency of each Ray process. JIT compilation has also recently become a popular toolkit for improving ML efficiency, so it would be nice for Ray to have better JIT support.

Related issues

No response

Are you willing to submit a PR?

jhallard commented 1 year ago

Hey OP, did you ever figure out a mitigation for this? This is a huge issue for us: we're seeing 50x slowdowns using Ray despite the improvements in parallelism.