ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] Enabling sample_async on remote workers when using PyTorch framework #10628

Closed roireshef closed 3 years ago

roireshef commented 4 years ago

Describe your feature request

Given the two references below, I'm thinking there's a way to enable sample_async for the PyTorch framework as well. Since async sampling is a major runtime booster, this feature would have a relatively wide impact on PyTorch/RLlib users who run long training sessions.

Thanks!

ericl commented 4 years ago

Thanks for reporting this. Can you provide a reproduction case with CartPole that doesn't work properly (rllib train --run=A3C --env=CartPole-v0 --torch --config=...)?

roireshef commented 3 years ago

Hi @ericl

I was able to reproduce it and track it down to a race condition on reads and writes of the class variable I use to store the value function estimate in my TorchModelV2; let's call it self.last_value. The problem is that by the time one thread runs:

  1. TorchModelV2.forward()
  2. TorchModelV2.value_function()

another thread may have already overwritten self.last_value by calling TorchModelV2.forward() again.
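A self-contained toy illustration of that interleaving (plain PyTorch and threading, not RLlib code; the layer sizes and iteration counts are arbitrary and only there to make the stale cache observable):

    import threading

    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):
        """Mimics the TorchModelV2 pattern described above: forward() caches
        the value estimate on the instance, value_function() reads it later."""

        def __init__(self):
            super().__init__()
            self.actor_net = nn.Linear(4, 2)
            self.critic_net = nn.Linear(4, 1)
            self.last_value = None

        def forward(self, obs):
            # Shared, mutable state -- this is what the race is about.
            self.last_value = self.critic_net(obs).squeeze(-1)
            return self.actor_net(obs)

        def value_function(self):
            return self.last_value


    model = ToyModel()


    def sampler_thread():
        # Stands in for the async sampler: forward() on single observations.
        for _ in range(10000):
            model(torch.randn(1, 4))


    t = threading.Thread(target=sampler_thread)
    t.start()

    # Stands in for the learner: forward() on a batch, then value_function().
    for _ in range(10000):
        logits = model(torch.randn(32, 4))
        values = model.value_function()
        if values.shape[0] != logits.shape[0]:
            print("stale value cache:", tuple(values.shape),
                  "vs batch", tuple(logits.shape))
            break

    t.join()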

To reproduce, run A3C through Tune with ray.init(local_mode=True) (it reproduces without local mode as well, but local mode makes debugging this case easier). The configuration should include: "framework": "torch", "sample_async": True

The validator in https://github.com/ray-project/ray/blob/master/rllib/agents/a3c/a3c.py#L60 should also be commented out, since it overrides sample_async. I had already removed it and run A3C in Ray v0.7.3 with great success in sample_async mode with PyTorch, and I discussed removing this validator with @sven1977, since the limitation is not really necessary.
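For reference, a minimal repro script along these lines (a sketch assuming the Tune/RLlib API of the Ray versions from that time; the stop criterion and worker count are arbitrary) would be:

    import ray
    from ray import tune

    # local_mode makes it easier to step through the race in a debugger.
    ray.init(local_mode=True)

    tune.run(
        "A3C",
        config={
            "env": "CartPole-v0",
            "framework": "torch",
            "sample_async": True,  # requires the a3c.py validator to be disabled
            "num_workers": 1,
        },
        stop={"training_iteration": 10},
    )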

Now, once you've done that, end your TorchModelV2 model's forward() method with something like:

        ...
        # Compute the value estimate and cache it on the model instance;
        # value_function() will read it back later.
        last_value = self.critic_net(shared_embeddings).squeeze(-1)
        self.last_value = last_value

        return self.actor_net(shared_embeddings), state

If you set a conditional breakpoint on the return statement that only triggers when the number of samples in the input is greater than 1, it will break only on the master thread, when TorchModelV2.forward() is called from the A3C loss function. At that point, if you compare them, last_value (which is local to that specific forward() call) already differs from self.last_value (instance scope, which can be overwritten from outside that forward() call), even though we assigned last_value to self.last_value just one line above.
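The same check can also be written as a purely illustrative assertion right before the return, instead of a conditional breakpoint:

        ...
        last_value = self.critic_net(shared_embeddings).squeeze(-1)
        self.last_value = last_value

        # Illustrative debugging check, equivalent to the conditional
        # breakpoint described above: on the batched call from the loss,
        # the cached tensor may already have been replaced by another
        # thread's forward() call.
        if last_value.shape[0] > 1:
            assert last_value is self.last_value, \
                "self.last_value was overwritten by a concurrent forward() call"

        return self.actor_net(shared_embeddings), state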

This has been an issue since the code migration to TorchModelV2, which forces the developer to use an instance variable outside the scope of the forward() call. With the original TorchModel, forward() was "atomic" in that sense: it returned the value estimates together with the rest of the outputs, so this problem was avoided.
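Conceptually, the older "atomic" pattern looked something like this (a sketch reusing the names from the snippet above, not the actual TorchModel signature):

    # "Atomic" forward(): the value estimate is returned together with the
    # policy outputs of the same call, so no shared instance state is needed.
    def forward(self, input_dict, state):
        shared_embeddings = self.shared_net(input_dict["obs"])
        logits = self.actor_net(shared_embeddings)
        values = self.critic_net(shared_embeddings).squeeze(-1)
        return logits, state, values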

I believe the racing threads are the one running AsyncSampler.run() and the one running RolloutWorker.compute_gradients(): the former always calls TorchModelV2.forward() with a single sample, while the latter accumulates episodes at the RolloutWorker and sends them as a batch to TorchModelV2.forward(). That is one way this logic can fail with mismatched tensor shapes in the loss calculation.

roireshef commented 3 years ago

@ericl, @sven1977 - I tried to find a use case where value_function() is called independently of forward(), but it seems they are always called together, one after the other. I assume you were trying to save computation by decoupling them...? But the fact is that the same design that decouples the two really limits the performance of PyTorch + A3C.

Is it worth recoupling them together into forward()?

I'm not sure how clean it is, but you could make the value computation lazy by having forward() return a Callable (e.g. a lambda) that an upstream function can choose to call or not. That way forward()'s outputs still agree: even if the Callable is materialized at a later stage, it references the inputs from that same forward() call.
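A rough sketch of that idea (illustrative only; the extra return value is not part of the actual TorchModelV2 interface, and shared_net/critic_net/actor_net follow the earlier snippet):

    def forward(self, input_dict, state, seq_lens):
        shared_embeddings = self.shared_net(input_dict["obs"])
        # Instead of caching the value estimate on self, hand back a closure
        # that computes it from this call's own inputs, whenever (and if)
        # the caller decides to materialize it.
        value_fn = lambda: self.critic_net(shared_embeddings).squeeze(-1)
        return self.actor_net(shared_embeddings), state, value_fn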

sven1977 commented 3 years ago

@roireshef Here is a PR that may fix this issue: https://github.com/ray-project/ray/pull/11935. Waiting for tests to pass.

roireshef commented 3 years ago

Hi @ericl @sven1977 - I have implemented a solution based on (1) from https://github.com/ray-project/ray/pull/11935#issuecomment-730644787

I don't think it unlocks the full potential of the async sampler's runtime improvements, but I do see a 20%-50% speedup when switching from sample_async=False to sample_async=True with this feature enabled (depending on how many envs_per_worker I use).

Could you please review? https://github.com/ray-project/ray/pull/12922

stale[bot] commented 3 years ago

Hi, I'm a bot from the Ray team :)

To help human contributors focus on more relevant issues, I will automatically add the stale label to issues that have had no activity for more than 4 months.

If there is no further activity in the next 14 days, the issue will be closed!

You can always ask for help on our discussion forum or Ray's public slack channel.

stale[bot] commented 3 years ago

Hi again! The issue will be closed because there has been no further activity in the 14 days since the last message.

Please feel free to reopen or open a new issue if you'd still like it to be addressed.

Again, you can always ask for help on our discussion forum or Ray's public slack channel.

Thanks again for opening the issue!