ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[RLlib] New Env Runners cannot do the inference on the GPU #47874

Open brieyla1 opened 2 days ago

brieyla1 commented 2 days ago

What happened + What you expected to happen

When using the new enable_env_runner_and_connector_v2 feature in RLlib, the env_runners do not have access to the GPU for inference inside the env_runner Ray actor. When the feature is disabled, the old workers function correctly and can utilize the GPU for inference (the GPU at least shows ~97% utilization). However, the old runners are not fully compatible with the new RLModule on my side, specifically when using LSTM models (seqlen in the rnn-sequencing). This shouldn't be an issue, since everything on my side has already been migrated to the new stack. The CPU is fine for smaller models, but larger models would benefit from GPU compute.

Thoughts

Could this lack of GPU access be by design? If so, what is the reasoning behind this choice for an on-policy algorithm? The GPU remains idle during most of the experience-gathering process, when it could be used to improve performance.

Versions / Dependencies

ray: 2.37.0
OS: WSL2 Ubuntu 22
GPU: RTX 4090
CPU: Intel i9-14900K

Reproduction script

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework(framework='torch')
    .environment(env='CartPole-v1')
    .api_stack(
        enable_rl_module_and_learner=True,
        # When this is enabled, the env_runners can't have access to the GPU for inference.
        enable_env_runner_and_connector_v2=True,
    )
)
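
For reference, here is the same config with an explicit GPU request for the env runners; num_gpus_per_env_runner is an assumption about the .env_runners() options available in this Ray version:

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .framework(framework='torch')
    .environment(env='CartPole-v1')
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # Assumption: a (fractional) GPU can be requested per EnvRunner actor.
    .env_runners(num_env_runners=1, num_gpus_per_env_runner=0.5)
)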

Issue Severity

High: It blocks me from completing my task.

brieyla1 commented 1 day ago

More tests have been done using ray 3.0.0.dev, with the same behaviour.

I've also tried compiling the model in the TorchRLModule; unfortunately, it didn't seem to help. I've noticed that the Algorithm doesn't seem to pick up the AlgorithmConfig's torch_compile_worker flags that the RLModule needs in order to be compiled with the new API, so I had to compile it manually:

# ray/rllib/core/rl_module/torch/torch_rl_module.py
class TorchRLModule(nn.Module, RLModule):
    ...
    def __init__(self, *args, **kwargs) -> None:
        ... 
        # Workaround: force-compile the module at construction time, since the
        # config's torch_compile_worker flags are not picked up automatically.
        if torch.cuda.is_available():
            self.compile(TorchCompileConfig(torch_dynamo_mode="default", torch_dynamo_backend="inductor"))

Adding a custom on_algorithm_init callback that moves the nn.Module to the GPU device for each worker in the env_runner_group doesn't work either: some components are still on the CPU, so Torch complains and throws.
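
Roughly along these lines (a sketch only; the env_runner.module attribute, the foreach_worker call, and the class name are assumptions, not confirmed RLlib API):

import torch
from ray.rllib.algorithms.callbacks import DefaultCallbacks

class MoveRLModuleToGpu(DefaultCallbacks):
    # Sketch of the callback approach: on algorithm init, push each EnvRunner's
    # RLModule onto the GPU (assumes the env runner exposes it as `.module`).
    def on_algorithm_init(self, *, algorithm, **kwargs):
        def _to_cuda(env_runner):
            if torch.cuda.is_available():
                env_runner.module.to("cuda")

        # Runs the function on every env runner actor in the group.
        algorithm.env_runner_group.foreach_worker(_to_cuda)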

I've found examples like ray/rllib/benchmarks/torch_compile/run_ppo_with_inference_bm.py, but they seem to be outdated.

Any guess as to what the culprit could be?

brieyla1 commented 1 day ago

I got it to work, but with terrible performance (a 3x decrease in throughput), by doing the following:

# ray/rllib/core/rl_module/torch/torch_rl_module.py
class TorchRLModule(nn.Module, RLModule):
    ...
    def __init__(self, *args, **kwargs) -> None:
        ...
        self.use_gpu_if_available()

    def use_gpu_if_available(self):
        # Workaround: move the module to the GPU, enable TF32 matmuls, compile
        # it, and explicitly move every submodule as well.
        if torch.cuda.is_available():
            self.to(device="cuda")
            torch.set_float32_matmul_precision('high') # enables TF32 rather than standard FP32
            self.compile(TorchCompileConfig(torch_dynamo_mode="default", torch_dynamo_backend="inductor"))
            for module in self.modules():
                module.to(device="cuda")

and

# ray/rllib/algorithms/ppo/torch/ppo_torch_rl_module.py
class PPOTorchRLModule(TorchRLModule, PPORLModule):
    ...
    def _forward_inference(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        output = {}

        # Workaround: move every tensor in the (flat) batch to the module's device.
        device = next(self.parameters()).device
        batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
        ...

No wonder the performance is crap: my fix is dirty.

It also doesn't work with LSTMs, because the hidden states are not sent to the GPU. I'm guessing this is a bug or a work in progress, since there is no real implementation of the compilation in the Algorithm class and the GPU is never picked up anywhere for now. I'll stick with CPUs for now, but I would love to get some more information about the possibility of using GPUs for inference.
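
If the remaining blocker is just the recurrent state, one way to cover it would be to move nested state tensors to the device as well, e.g. with dm-tree (a sketch only; batch_to_device is a made-up helper, not RLlib API):

import tree  # dm-tree, already an RLlib dependency
import torch

def batch_to_device(batch, device):
    # Recursively move every tensor in the (possibly nested) batch -- including
    # LSTM state_in dicts -- to the target device; leave everything else as-is.
    return tree.map_structure(
        lambda v: v.to(device) if isinstance(v, torch.Tensor) else v,
        batch,
    )

Calling something like batch_to_device(batch, device) in _forward_inference, instead of the flat dict comprehension above, would cover the hidden/cell states too.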

Having some kind of middle-man worker that batches up inference requests would be extremely efficient, possibly decoupling inference from PPO's sampling performance. I'm especially thinking of a project I've helped with called SEED RL (paper), and a newer project called SRL (paper), which use a separate worker that batches requests and handles the inference (a bit like a remote RLModule); see the rough sketch below the figure.

[Figure: SRL architecture, from the SRL paper]
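
Something along these lines, as a very rough Ray-level sketch (InferenceServer, model_builder, and forward_batch are made-up names, not an RLlib or SRL API):

import numpy as np
import ray
import torch

@ray.remote(num_gpus=1)
class InferenceServer:
    # Central GPU actor that serves batched forward passes for many env runners,
    # in the spirit of SEED RL / SRL.
    def __init__(self, model_builder):
        # model_builder: user-supplied callable returning a torch.nn.Module.
        self.model = model_builder().to("cuda").eval()

    def forward_batch(self, obs_batch: np.ndarray) -> np.ndarray:
        # Env runners stack their observations and send one batched request.
        with torch.no_grad():
            obs = torch.as_tensor(obs_batch, dtype=torch.float32, device="cuda")
            return self.model(obs).cpu().numpy()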