Closed hartikainen closed 4 years ago
Okay, after spending a bit more time looking into the implementation, I believe I have an answer to both of these questions.
1) The discounts should be included in the `pcontinues` values passed as inputs to retrace. I initially thought these were used only to account for terminal conditions, but it makes a lot of sense to fold the discounts into them as well (see the sketch after these two points).
2) This actually seems to be just a typo in one of the comments. The first comment about the recurrent form seems to correctly state that the timesteps are equal for the recursion, and it's just the second comment (which I referred to in my question above) that has an incorrect time index. I can send a PR to fix this.
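To make point 1) concrete, here is a minimal sketch of folding the discount into the `pcontinues` argument, mirroring the `retrace_ops.retrace` call from the test case below. The tensors are illustrative placeholders, and the snippet assumes the same tf2-compatible version of trfl used there:

```python
import tensorflow as tf

from trfl import retrace_ops

lambda_ = 0.99
discount = 0.9
T, B, A = 4, 2, 1  # timesteps, batch size, number of actions (placeholders)

# Placeholder inputs with the same layout as the test case below.
q_values = tf.random.uniform((T, B, A))
target_q_values = tf.random.uniform((T, B, A))
actions = tf.zeros((T, B), dtype=tf.int32)
rewards = tf.random.uniform((T, B))
target_policy_probs = tf.ones((T, B, A))
behavior_policy_probs = tf.ones((T, B))

# Episode-continuation indicators (0 after terminal steps), shape [T, B].
continuation = tf.ones((T, B))

# Fold the per-step discount into pcontinues instead of passing plain ones.
pcontinues = discount * continuation

retrace_result = retrace_ops.retrace(
    lambda_, q_values, target_q_values, actions, rewards,
    pcontinues, target_policy_probs, behavior_policy_probs)
print(retrace_result.extra.target)
```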
Hey,
I was looking at the retrace ops provided by trfl and there are a couple of implementation details that seem a bit confusing to me.
1) It seems like trfl retrace drops the discount terms from the `𝔼_π Q(x_t, .)` term. This is in line with the retrace formulation in Equation 13 of the MPO paper [1], but differs from Equation 4 in the original retrace paper [2]. I have included a small test case below that shows this. Is this a bug or a conscious choice? Edit: actually, it seems like at least one of the terms is included in the continuation probs.
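For reference, the retrace target from [2] (Equation 4, with notation lightly adapted to the Q(x_t, a_t) convention above) looks roughly like this:

```latex
% Retrace target from [2], with truncated importance weights c_s.
\mathcal{R}Q(x, a) = Q(x, a)
  + \mathbb{E}_{\mu}\!\left[\sum_{t \ge 0} \gamma^{t}\, (c_{1} \cdots c_{t})
      \big(r_{t} + \gamma\, \mathbb{E}_{\pi} Q(x_{t+1}, \cdot) - Q(x_{t}, a_{t})\big)\right],
\qquad
c_{s} = \lambda \min\!\left(1, \frac{\pi(a_{s} \mid x_{s})}{\mu(a_{s} \mid x_{s})}\right)
```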
2) In the comments of `retrace_ops._general_off_policy_corrected_multistep_target`, it's mentioned that `exp_q_t = 𝔼_π Q(x_{t+1}, .)` and `qa_t = Q(x_t, a_t)`, indicating that `exp_q_t` should be one timestep ahead of `qa_t`: https://github.com/deepmind/trfl/blob/e633edbd9d326b8bebc7c7c7d53f37118b48a440/trfl/retrace_ops.py#L252-L253 However, if I understand this correctly, when those values are actually assigned, they come from the same time indices: https://github.com/deepmind/trfl/blob/e633edbd9d326b8bebc7c7c7d53f37118b48a440/trfl/retrace_ops.py#L263-L264 It's possible that the `target_policy_t` values used to index `exp_q_t` somehow account for this, but I can't wrap my head around how that would work. Am I misunderstanding something here, or is it possible that these indices are actually off? (A small illustrative sketch follows the references below.)

[1] Abdolmaleki, A., Springenberg, J. T., Tassa, Y., Munos, R., Heess, N. and Riedmiller, M., 2018. Maximum a posteriori policy optimisation. arXiv preprint arXiv:1806.06920.

[2] Munos, R., Stepleton, T., Harutyunyan, A. and Bellemare, M., 2016. Safe and efficient off-policy reinforcement learning. In Advances in Neural Information Processing Systems (pp. 1054-1062).
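To make the index question in 2) concrete, here is a small, purely illustrative numpy sketch (not trfl code) of the two alignments I have in mind, using the [T, B, A] layout from the test case below:

```python
import numpy as np

T, B, A = 4, 2, 3  # timesteps, batch size, number of actions (illustrative)
rng = np.random.default_rng(0)

target_q_values = rng.normal(size=(T, B, A)).astype(np.float32)
target_policy_probs = rng.dirichlet(np.ones(A), size=(T, B)).astype(np.float32)
actions = rng.integers(0, A, size=(T, B))

# qa_t = Q(x_t, a_t): target Q-values of the actions actually taken.
qa_t = np.take_along_axis(
    target_q_values, actions[..., None], axis=-1).squeeze(-1)

# Reading the comment at L252-L253 literally, exp_q_t = E_pi Q(x_{t+1}, .)
# would be one timestep ahead of qa_t:
exp_q_shifted = (target_policy_probs[1:] * target_q_values[1:]).sum(axis=-1)

# Whereas the assignments at L263-L264 look like both quantities come from
# the same time indices:
exp_q_same = (target_policy_probs * target_q_values).sum(axis=-1)

print(qa_t.shape, exp_q_shifted.shape, exp_q_same.shape)
```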
Code related to question 1:
The test case is simplified (e.g. just one action) and I have used a slightly modified version of trfl to make it compatible with tf2, but all the logic should be correct.

```python
import numpy as np
import tensorflow as tf

from trfl import retrace_ops

lambda_ = 0.99
discount = 0.9

Q_values = np.array([
    [[2.2], [5.2]],
    [[7.2], [4.2]],
    [[3.2], [4.2]],
    [[2.2], [9.2]]], dtype=np.float32)
target_Q_values = np.array([
    [[2.], [5.]],
    [[7.], [4.]],
    [[3.], [4.]],
    [[2.], [9.]]], dtype=np.float32)
actions = np.array([
    [0, 0],
    [0, 0],
    [0, 0],
    [0, 0]])
rewards = np.array([
    [1.9, 2.9],
    [3.9, 4.9],
    [5.9, 6.9],
    [np.nan, np.nan],  # nan marks entries we should never use.
], dtype=np.float32)
pcontinues = np.array([
    [0.8, 0.9],
    [0.7, 0.8],
    [0.6, 0.5],
    [np.nan, np.nan]], dtype=np.float32)
target_policy_probs = np.array([
    [[np.nan] * 1, [np.nan] * 1],
    [[1.0], [1.0]],
    [[1.0], [1.0]],
    [[1.0], [1.0]]], dtype=np.float32)
behavior_policy_probs = np.array([
    [np.nan, np.nan],
    [1.0, 1.0],
    [1.0, 1.0],
    [1.0, 1.0]], dtype=np.float32)


def retrace_original_v1(
        lambda_, discount, target_Q_values, actions, rewards,
        target_policy_probs, behavior_policy_probs):
    """Reference implementation of the retrace target, following [2]."""
    actions = actions[1:, ...]
    rewards = rewards[:-1, ...]
    target_policy_probs = target_policy_probs[1:, ...]
    behavior_policy_probs = behavior_policy_probs[1:, ...]

    # Truncated importance weights c_s = lambda * min(1, pi / mu).
    traces = lambda_ * np.minimum(
        1.0, target_policy_probs / behavior_policy_probs[..., None])
    # Per-action TD errors: r_t + discount * Q(x_{t+1}, .) - Q(x_t, .).
    deltas = (
        rewards[..., None]
        + discount * target_Q_values[1:]
        - target_Q_values[:-1])

    retraces = []
    for i in range(tf.shape(traces)[0]):
        sum_terms = []
        for t in range(i, tf.shape(traces)[0]):
            trace = tf.reduce_prod([
                traces[k] for k in range(i + 1, t + 1)
            ], axis=0)
            sum_term = discount ** (t - i) * trace * deltas[t]
            sum_terms.append(sum_term)
        result = tf.reduce_sum(sum_terms, axis=0)
        retraces.append(result)

    retraces = tf.stack(retraces) + target_Q_values[:-1]

    return retraces


output_original_v1 = retrace_original_v1(
    lambda_, 1.0, target_Q_values, actions, rewards,
    target_policy_probs, behavior_policy_probs)
print(f"output_original_v1:\n{output_original_v1.numpy().round(3)}\n")

output_original_discounted_v1 = retrace_original_v1(
    lambda_, discount, target_Q_values, actions, rewards,
    target_policy_probs, behavior_policy_probs)
print(f"output_original_discounted_v1:\n"
      f"{output_original_discounted_v1.numpy().round(3)}\n")

output_trfl_v1 = retrace_ops.retrace(
    lambda_, Q_values, target_Q_values, actions, rewards,
    tf.ones_like(rewards), target_policy_probs, behavior_policy_probs,
).extra.target[..., None]

tf.debugging.assert_near(output_original_v1, output_trfl_v1)  # succeeds
tf.debugging.assert_near(output_original_discounted_v1, output_trfl_v1)  # fails
```