google-deepmind / acme

A library of reinforcement learning components and agents
Apache License 2.0

D4PG jax agent computes n-step learning target in reverse? #292

Closed joeryjoery closed 1 year ago

joeryjoery commented 1 year ago

Hi, can I ask why, in the reference agent acme.agents.jax.d4pg.builder.py, this function uses reverse=True in tf.scan (line 52)? Also, why isn't this implemented in d4pg.learner with a call to rlax? At the moment it is done as an awkward post-processing step inside the make_dataset_iterator function of D4PGBuilder.

For the discount this doesn't really matter, since scalar multiplication is commutative, but because the scan simultaneously compounds the rewards, shouldn't this (code at the bottom) compute:

G_t:t+n = (gamma)^n r_t+1 + (gamma)^(n-1) r_t+2 + ... + gamma * r_t+n + r_t+n+1

Instead of

G_t:t+n = r_t+1 + gamma r_t+2 + (gamma)^2 r_t+3 + ... + (gamma)^n * r_t+n+1

I've included the code of the function below. Maybe I'm overlooking something; perhaps flat_trajectory stores transitions in reverse time-order by default for some reason?

import reverb
import tensorflow as tf
import tree

from acme import types


def _as_n_step_transition(flat_trajectory: reverb.ReplaySample,
                          agent_discount: float) -> reverb.ReplaySample:
  """Compute discounted return and total discount for N-step transitions... """
  trajectory = flat_trajectory.data

  def compute_discount_and_reward(
      state: types.NestedTensor,
      discount_and_reward: types.NestedTensor) -> types.NestedTensor:
    compounded_discount, discounted_reward = state
    return (agent_discount * discount_and_reward[0] * compounded_discount,
            discounted_reward + discount_and_reward[1] * compounded_discount)  # Problem on this line 1/2

  initializer = (tf.constant(1, dtype=tf.float32),
                 tf.constant(0, dtype=tf.float32))
  elems = tf.stack((trajectory.discount, trajectory.reward), axis=-1)
  total_discount, n_step_return = tf.scan(
      compute_discount_and_reward, elems, initializer, reverse=True)  # Problem on this line 2/2
  return reverb.ReplaySample(
      info=flat_trajectory.info,
      data=types.Transition(
          observation=tree.map_structure(lambda x: x[0],
                                         trajectory.observation),
          action=tree.map_structure(lambda x: x[0], trajectory.action),
          reward=n_step_return[0],
          discount=total_discount[0],
          next_observation=tree.map_structure(lambda x: x[-1],
                                              trajectory.observation),
          extras=tree.map_structure(lambda x: x[0], trajectory.extras)))
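
To make the difference concrete, here is a small standalone check (a sketch only, not Acme code) of what the reverse=True scan above produces for a toy 3-step trajectory with gamma = 0.9 and no terminal steps:

import tensorflow as tf

gamma = 0.9
rewards = tf.constant([1.0, 2.0, 3.0])    # r_t+1, r_t+2, r_t+3
discounts = tf.constant([1.0, 1.0, 1.0])  # no terminal step in this toy trajectory

def step(state, discount_and_reward):
  compounded_discount, discounted_reward = state
  return (gamma * discount_and_reward[0] * compounded_discount,
          discounted_reward + discount_and_reward[1] * compounded_discount)

elems = tf.stack((discounts, rewards), axis=-1)
initializer = (tf.constant(1.0), tf.constant(0.0))
total_discount, n_step_return = tf.scan(step, elems, initializer, reverse=True)

print(n_step_return[0].numpy())      # ~5.61 = 3 + 0.9 * 2 + 0.81 * 1 (discounts applied back-to-front)
print(1.0 + 0.9 * 2.0 + 0.81 * 3.0)  # 5.23, the forward-discounted return I would expect

So the later rewards end up with the smaller discounts, which is the reversed weighting written in the first formula above.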
joeryjoery commented 1 year ago

To add to this: could the D4PG agent be refactored to use the NStepTransitionAdder instead of the StructuredWriter? That would prevent silent bugs like this one.

Jogima-cyber commented 1 year ago

I don't know if we're talking about the same thing, but from Sutton's "Reinforcement Learning: An Introduction":

[Screenshot of the n-step return definition from the book]

It seems to me it should be: G_t:t+n = r_t+1 + gamma r_t+2 + (gamma)^2 r_t+3 + ... + (gamma)^n * r_t+n+1, so the code is correct.

joeryjoery commented 1 year ago

> I don't know if we're talking about the same thing, but from Sutton's "Reinforcement Learning: An Introduction": it seems to me it should be G_t:t+n = r_t+1 + gamma r_t+2 + (gamma)^2 r_t+3 + ... + (gamma)^n * r_t+n+1, so the code is correct.

Yes, this is what it should compute, but the code computes this incorrectly.

ethanluoyc commented 1 year ago

+1 on this issue.

I tried plugging the function above into the test cases for the adders, and it's not computing the correct quantity.

Regarding not using rlax: I think the issue is that the transition adder (and the new structured writer configured to insert transitions) will produce sequences of varying length (at the start and end of an episode), so you can't actually batch the computation.
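
To illustrate the variable lengths (a toy sketch, not Acme code): near an episode boundary the emitted reward/discount slices are shorter than n, so they are ragged and can't be stacked into one rectangular batch for a single vectorized call.

import numpy as np

n = 5
episode_rewards = np.arange(1.0, 9.0)  # a toy 8-step episode

# Reward slices as an n-step adder would emit them towards the end of the episode.
slices = [episode_rewards[t:t + n] for t in range(len(episode_rewards))]
print([len(s) for s in slices])  # [5, 5, 5, 5, 4, 3, 2, 1] -> ragged, not directly batchable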

bshahr commented 1 year ago

Hi everyone!

Thanks joeryjoery for bringing this to my attention and apologies it took so long to get to this in my stack.

You're right: for nearly a year now, the JAX version of D4PG has been incorrectly computing n-step returns. I've just pushed a fix which uses forward scanning of the rewards/discounts (I agree with you that this is more intuitive). The code is now tested against the existing test cases for the NStepTransitionAdder, so, barring uncovered corner cases, the n-step computations should now be correct and verified. Please do confirm this on your use case and re-open the issue if the current code does not cover it.
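
For anyone following along, roughly what forward scanning looks like here (an illustrative sketch, not the exact code that was pushed): running the same accumulation with reverse=False and reading the last element recovers the expected forward-discounted return.

import tensorflow as tf

gamma = 0.9
rewards = tf.constant([1.0, 2.0, 3.0])
discounts = tf.constant([1.0, 1.0, 1.0])

def step(state, discount_and_reward):
  compounded_discount, discounted_reward = state
  # The current reward is weighted by the discount compounded over the previous steps.
  return (gamma * discount_and_reward[0] * compounded_discount,
          discounted_reward + discount_and_reward[1] * compounded_discount)

elems = tf.stack((discounts, rewards), axis=-1)
initializer = (tf.constant(1.0), tf.constant(0.0))
total_discount, n_step_return = tf.scan(step, elems, initializer)  # reverse defaults to False

print(n_step_return[-1].numpy())   # ~5.23 = 1 + 0.9 * 2 + 0.81 * 3, the expected n-step return
print(total_discount[-1].numpy())  # ~0.729 = 0.9^3, the total discount used for bootstrapping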

Thanks again for the catch and engagement, and sincere apologies for any lost time/effort.

Bobak