Closed: joeryjoery closed this issue 1 year ago.
Perhaps to add to this, could the D4PG agent be refactored to use the NStepTransitionAdder instead of the StructuredWriter? This would prevent silent bugs like this.
I don't know if we're talking about the same thing, but from Sutton's "Reinforcement Learning: An Introduction" it seems to me it should be:

G_{t:t+n} = r_{t+1} + gamma * r_{t+2} + gamma^2 * r_{t+3} + ... + gamma^n * r_{t+n+1}

So the code is correct.
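As a quick sanity check of what that expression evaluates to, here is a tiny loop-based version (a hypothetical helper for illustration only, not code from Acme):

```python
def n_step_return(rewards, gamma):
    """Evaluates the expression above; rewards[k] holds r_{t+1+k}."""
    # The discount power grows with the reward's offset in time.
    return sum((gamma ** k) * r for k, r in enumerate(rewards))

# Example: rewards (1, 1, 1) with gamma = 0.9 give 1 + 0.9 + 0.81 = 2.71.
assert abs(n_step_return([1.0, 1.0, 1.0], 0.9) - 2.71) < 1e-9
```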
Yes, this is what it should compute, but the code computes this incorrectly.
+1 on this issue.
I tried plugging the function above into the test cases of the adders and it's not computing the correct quantity.
Regarding not using rlax: I think the issue is that the transition adder (and the new structured writer configured to insert transitions) will produce sequences of varying length (at the start and end of an episode), so you can't actually batch the computation.
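My mental model of the per-chunk computation is something like the following sketch (plain Python for illustration, not Acme or rlax code); because every chunk can have a different length, the reduction has to happen per chunk rather than over a fixed [batch, n] array:

```python
def n_step_return_and_discount(rewards, discounts):
    """Folds one variable-length chunk into (accumulated reward, accumulated discount)."""
    total_reward, total_discount = 0.0, 1.0
    for r, d in zip(rewards, discounts):
        total_reward += total_discount * r
        total_discount *= d
    return total_reward, total_discount

# A 3-step chunk and a 1-step chunk (e.g. right after the episode start) both work.
print(n_step_return_and_discount([1.0, 1.0, 1.0], [0.9, 0.9, 0.9]))  # ~(2.71, 0.729)
print(n_step_return_and_discount([1.0], [0.9]))                      # ~(1.0, 0.9)
```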
Hi everyone!
Thanks, joeryjoery, for bringing this to my attention, and apologies that it took so long to get to this in my stack.
You're right: for nearly a year now, the JAX version of D4PG has been incorrectly computing n-step returns. I've just pushed a fix for this, which uses forward scanning of the rewards/discounts (I agree with you that this is more intuitive). This code is now tested against the existing test cases for the NStepTransitionAdder, so, barring uncovered corner cases, the n-step computations should now be correct and verified. Please do confirm this on your use case and re-open the issue if the current code does not cover it.
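For intuition, forward accumulation amounts to something like the following (an illustrative sketch only, not the code that was pushed):

```python
import tensorflow as tf

def forward_accumulate(rewards, discounts):
  """Sketch: scan forward over one chunk's (reward, discount) pairs."""
  def step(carry, elem):
    total_reward, total_discount = carry
    reward, discount = elem
    # Discount powers grow in the direction of time.
    return total_reward + total_discount * reward, total_discount * discount

  returns, cumulative_discounts = tf.scan(
      step, (rewards, discounts),
      initializer=(tf.constant(0.0), tf.constant(1.0)))
  # The last element holds the full n-step return and the compounded discount.
  return returns[-1], cumulative_discounts[-1]

# With rewards (1, 1, 1) and per-step discount 0.9 this gives (2.71, 0.729).
print(forward_accumulate(tf.constant([1.0, 1.0, 1.0]),
                         tf.constant([0.9, 0.9, 0.9])))
```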
Thanks again for the catch and engagement, and sincere apologies for any lost time/effort.
Bobak
Hi, can I ask why in the reference agents (acme.agents.jax.d4pg.builder.py), in this function on line 52, there is a reverse=True in tf.scan? Also, why isn't this implemented in d4pg.learner using a call to rlax? At the moment it is called as part of an awkward post-processing step within the make_dataset_iterator function of D4PGBuilder.

For the discount this doesn't really matter, since scalar multiplication is commutative, but as it simultaneously compounds the rewards, shouldn't this (code at the bottom) compute:

G_{t:t+n} = gamma^n * r_{t+1} + gamma^(n-1) * r_{t+2} + ... + gamma * r_{t+n} + r_{t+n+1}

instead of

G_{t:t+n} = r_{t+1} + gamma * r_{t+2} + gamma^2 * r_{t+3} + ... + gamma^n * r_{t+n+1}?
I've added the code to the function below. Maybe I'm overlooking something; perhaps flat_trajectory stores transitions in reverse time-order by default for some reason?
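Since the snippet itself isn't reproduced here, the following is only an illustrative reconstruction of the concern (hypothetical code, not the actual builder implementation): compounding rewards and discounts together in a single scan that runs with reverse=True attaches the largest discount powers to the earliest rewards.

```python
import tensorflow as tf

def reversed_scan_return(rewards, discounts):
  """Sketch of the suspected behaviour: a reverse scan that compounds rewards."""
  def step(carry, elem):
    total_reward, total_discount = carry
    reward, discount = elem
    return total_reward + total_discount * reward, total_discount * discount

  returns, _ = tf.scan(
      step, (rewards, discounts),
      initializer=(tf.constant(0.0), tf.constant(1.0)),
      reverse=True)
  # With reverse=True, index 0 holds the accumulation over the whole chunk.
  return returns[0]

# Rewards (1, 2, 3) with per-step discount 0.9 yield 3 + 0.9*2 + 0.81*1 = 5.61,
# whereas the intended forward quantity is 1 + 0.9*2 + 0.81*3 = 5.23.
print(reversed_scan_return(tf.constant([1.0, 2.0, 3.0]),
                           tf.constant([0.9, 0.9, 0.9])))
```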