ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

DQN Minibatch Option #8870

Open houcharlie opened 4 years ago

houcharlie commented 4 years ago

Describe your feature request

Would it be possible to allow gradient accumulation for DQN? Or is there an algorithmic reason why huge batches for gradient calculation aren't useful for DQN?

ericl commented 4 years ago

Sure. I don't know of any such reason, though I would guess it would only be helpful if the model activations are using a huge amount of memory (i.e., much bigger than your typical RL model).

I haven't tried this out but the patch would be something like the following:

diff --git a/rllib/agents/dqn/dqn.py b/rllib/agents/dqn/dqn.py
index 76bc21817..d613ecef7 100644
--- a/rllib/agents/dqn/dqn.py
+++ b/rllib/agents/dqn/dqn.py
@@ -276,7 +276,10 @@ def execution_plan(workers, config):
     post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
     replay_op = Replay(local_buffer=local_replay_buffer) \
         .for_each(lambda x: post_fn(x, workers, config)) \
-        .for_each(TrainOneStep(workers)) \
+        .for_each(ComputeGradients(workers))  \
+        .batch(num_microbatches)  \
+        .for_each(AverageGradients())  \
+        .for_each(ApplyGradients(workers)) \
         .for_each(update_prio) \
         .for_each(UpdateTargetNetwork(
             workers, config["target_network_update_freq"]))
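
For reference, the three operators in the patch boil down to a standard gradient-accumulation pattern: compute per-microbatch gradients, average them, then apply a single update. Here is a minimal framework-agnostic sketch of that pattern in NumPy; the loss and the `compute_gradients` / `average_gradients` / `apply_gradients` helpers are hypothetical stand-ins, not RLlib's actual implementations:

```python
import numpy as np

def compute_gradients(params, batch):
    # Stand-in for ComputeGradients: gradient of a toy squared loss
    # 0.5 * ||w - mean(batch)||^2 on one microbatch.
    return {k: params[k] - batch[k].mean(axis=0) for k in params}

def average_gradients(grad_list):
    # Stand-in for AverageGradients: elementwise mean over microbatches.
    return {k: np.mean([g[k] for g in grad_list], axis=0)
            for k in grad_list[0]}

def apply_gradients(params, grads, lr=0.1):
    # Stand-in for ApplyGradients: one SGD step with the averaged gradient.
    return {k: params[k] - lr * grads[k] for k in params}

params = {"w": np.zeros(3)}
# Three microbatches; only one set of activations/gradients is "live"
# at a time, which is where the memory saving comes from.
microbatches = [{"w": np.full((4, 3), v)} for v in (1.0, 2.0, 3.0)]

grads = [compute_gradients(params, mb) for mb in microbatches]
params = apply_gradients(params, average_gradients(grads))
print(params["w"])  # one update using the average over all microbatches
```

The update is mathematically equivalent to training on the full batch at once (for losses that average over samples), while peak memory scales with the microbatch size instead of the full batch size.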