google-deepmind / rlax

https://rlax.readthedocs.io
Apache License 2.0

Questions re: RLax Value Learning? #9

Open · RylanSchaeffer opened this issue 3 years ago

RylanSchaeffer commented 3 years ago

Hi! I have several questions/requests regarding value learning https://github.com/deepmind/rlax/blob/master/rlax/_src/value_learning.py

  1. If I want to use the _quantile_regression_loss without the Huber aspect, does setting huber_param equal to 0 accomplish this? That's my understanding, but I'd like to check :)

  2. I'm interested in exploring expectile regression-naive DQN and expectile regression DQN, but code for these two related algorithms doesn't seem to exist. Is that correct? If code does exist, could you point me in the right direction?

  3. If functions for expectile regression indeed do not exist, what would be the most straightforward way to implement them? If I just want expectile regression-naive, I'm thinking I would need to do the following:

a. Copy _quantile_regression_loss() to create _expectile_regression_loss(), replacing the quantile loss with the expectile loss (see the sketch at the end of this comment)
b. Copy quantile_q_learning() to create expectile_q_learning(), replacing the _quantile_regression_loss() call with an _expectile_regression_loss() call

Is this correct? If so, would you be open to PRs?

  4. Full (non-naive) expectile regression is a little trickier, due to its imputation strategy. Are you planning on implementing & releasing that? If not, how would you recommend implementing it?
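
To make 3a concrete, the elementwise change I have in mind is roughly the following (a rough sketch, not rlax API; the helper names are just for illustration). Both losses share the same asymmetric weight; the quantile (pinball) loss applies it to the absolute error, while the expectile loss applies it to the squared error:

import jax.numpy as jnp

def quantile_loss_elementwise(delta, tau):
  # Pinball loss: asymmetric weight on the absolute pairwise error.
  weight = jnp.abs(tau - (delta < 0.).astype(jnp.float32))
  return weight * jnp.abs(delta)

def expectile_loss_elementwise(delta, tau):
  # Expectile loss: the same asymmetric weight, but on the squared error.
  weight = jnp.abs(tau - (delta < 0.).astype(jnp.float32))
  return weight * jnp.square(delta)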
RylanSchaeffer commented 3 years ago

To be clear about 3a, I specifically mean

def _expectile_naive_regression_loss(
    dist_src: ArrayLike,
    tau_src: ArrayLike,
    dist_target: ArrayLike,
    huber_param: float = 0.
) -> ArrayOrScalar:
  """Compute ER-naive loss between two discrete expectile-valued distributions.

  See "Statistics and Samples in Distributional Reinforcement Learning" by
  Rowland et al. (http://proceedings.mlr.press/v97/rowland19a).

  Args:
    dist_src: source probability distribution.
    tau_src: source distribution probability thresholds.
    dist_target: target probability distribution.
    huber_param: ignored

  Returns:
    Expectile (naive) regression loss.
  """
  base.rank_assert([dist_src, tau_src, dist_target], 1)
  base.type_assert([dist_src, tau_src, dist_target], float)

  # Calculate expectile error.
  delta = dist_target[None, :] - dist_src[:, None]
  delta_neg = (delta < 0.).astype(jnp.float32)
  delta_neg = jax.lax.stop_gradient(delta_neg)
  weight = jnp.abs(tau_src[:, None] - delta_neg)

  loss = jnp.abs(jnp.square(delta))
  loss *= weight

  # Average over target-samples dimension, sum over src-samples dimension.
  return jnp.sum(jnp.mean(loss, axis=-1))
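
As a quick sanity check on the loss above (toy values only, not part of the proposed diff): with tau = 0.5 the asymmetric weight is 0.5 everywhere, so the loss reduces to half the summed mean squared error between target values and source expectiles.

import jax.numpy as jnp

dist_src = jnp.array([0., 1., 2.])      # source expectiles
tau_src = jnp.array([0.5, 0.5, 0.5])    # symmetric thresholds
dist_target = jnp.array([1., 1., 1.])   # target values

loss = _expectile_naive_regression_loss(dist_src, tau_src, dist_target)
# Expected: 0.5 * ((1-0)^2 + (1-1)^2 + (1-2)^2) = 1.0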
RylanSchaeffer commented 3 years ago

And 3b, I mean

def expectile_naive_q_learning(
    dist_q_tm1: ArrayLike,
    tau_q_tm1: ArrayLike,
    a_tm1: ArrayOrScalar,
    r_t: ArrayOrScalar,
    discount_t: ArrayOrScalar,
    dist_q_t_selector: ArrayLike,
    dist_q_t: ArrayLike,
    huber_param: float = 0.
) -> ArrayOrScalar:
  """Implements Q-learning for expectile-valued Q distributions.

  See "Statistics and Samples in Distributional Reinforcement Learning" by
  Rowland et al. (http://proceedings.mlr.press/v97/rowland19a).

  Args:
    dist_q_tm1: Q distribution at time t-1.
    tau_q_tm1: Q distribution probability thresholds.
    a_tm1: action index at time t-1.
    r_t: reward at time t.
    discount_t: discount at time t.
    dist_q_t_selector: Q distribution at time t for selecting greedy action in
      target policy. This is separate from dist_q_t as in Double Q-Learning, but
      can be computed with the target network and a separate set of samples.
    dist_q_t: target Q distribution at time t.
    huber_param: Huber loss parameter, defaults to 0 (no Huber loss).

  Returns:
    Quantile regression Q learning loss.
  """
  base.rank_assert([
      dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector, dist_q_t
  ], [2, 1, 0, 0, 0, 2, 2])
  base.type_assert([
      dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector, dist_q_t
  ], [float, float, int, float, float, float, float])

  # Only update the taken actions.
  dist_qa_tm1 = dist_q_tm1[:, a_tm1]

  # Select target action according to greedy policy w.r.t. dist_q_t_selector.
  q_t_selector = jnp.mean(dist_q_t_selector, axis=0)
  a_t = jnp.argmax(q_t_selector)
  dist_qa_t = dist_q_t[:, a_t]

  # Compute target, do not backpropagate into it.
  dist_target = r_t + discount_t * dist_qa_t
  dist_target = jax.lax.stop_gradient(dist_target)

  return _expectile_naive_regression_loss(
      dist_qa_tm1, tau_q_tm1, dist_target, huber_param)
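
For completeness, here's how I'd exercise it on toy inputs (a rough sketch; shapes follow the rank asserts above, and in practice the loss would be jax.vmap-ed over a batch like the existing rlax losses):

import jax.numpy as jnp

num_atoms, num_actions = 4, 2
dist_q_tm1 = jnp.zeros((num_atoms, num_actions))        # expectiles per action at t-1
tau_q_tm1 = (jnp.arange(num_atoms) + 0.5) / num_atoms   # thresholds in (0, 1)
dist_q_t_selector = jnp.ones((num_atoms, num_actions))  # used only to pick the greedy action
dist_q_t = jnp.ones((num_atoms, num_actions))           # target-network expectiles at t

loss = expectile_naive_q_learning(
    dist_q_tm1, tau_q_tm1,
    a_tm1=jnp.asarray(0), r_t=jnp.asarray(1.), discount_t=jnp.asarray(0.99),
    dist_q_t_selector=dist_q_t_selector, dist_q_t=dist_q_t)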
mtthss commented 3 years ago

If I want to use the _quantile_regression_loss without the Huber aspect, does setting huber_param equal to 0 accomplish this?

Yes that is correct.
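
Something like this (a rough sketch with toy shapes) then gives you the plain pinball loss, with no Huber smoothing:

import jax.numpy as jnp
import rlax

num_quantiles, num_actions = 4, 3
dist_q_tm1 = jnp.zeros((num_quantiles, num_actions))
tau_q_tm1 = (jnp.arange(num_quantiles) + 0.5) / num_quantiles
dist_q_t = jnp.ones((num_quantiles, num_actions))

# huber_param=0. disables the Huber smoothing entirely.
loss = rlax.quantile_q_learning(
    dist_q_tm1, tau_q_tm1,
    a_tm1=jnp.asarray(0), r_t=jnp.asarray(1.), discount_t=jnp.asarray(0.9),
    dist_q_t_selector=dist_q_t, dist_q_t=dist_q_t,
    huber_param=0.)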

code for these two related algorithms doesn't seem to exist. Is that correct? If code does exist, could you point me in the right direction?

Correct, rlax doesn't currently implement them.

If so, would you be open to PRs?

Yes, I'd welcome a PR adding these!

RylanSchaeffer commented 3 years ago

@mtthss regarding the PR, what do you think of my two comments above for implementing expectile regression-naive?

mtthss commented 3 years ago

I asked one of the original authors of that paper to provide comments on the proposed implementation.

RylanSchaeffer commented 3 years ago

I kept the Huber loss parameter for consistency with the quantile regression API, but the parameter is never used, so maybe it should be excluded.

Mark-Rowland commented 3 years ago

Hi Rylan and Matteo,

Will Dabney and I have taken a look through expectile_naive_regression_loss and expectile_naive_q_learning, and the code looks correct to us.

One comment is that expectile_naive_regression_loss could be reused as part of an ER-DQN (non-naive) implementation, in which case dist_target would be a vector of samples rather than expectiles. For that reason, expectile_naive_regression_loss could be renamed to expectile_regression_loss, with the understanding that the elements of dist_target may be either expectiles or samples. If going down this route, the docstring could be updated to spell out which input semantics are allowed for both dist_target and dist_src.

A couple of other minor comments: the jnp.abs around jnp.square(delta) in expectile_naive_regression_loss can be removed, since the square is already non-negative, and the end of the docstring for expectile_naive_q_learning should say "Expectile" rather than "Quantile". As Rylan mentions, the huber_param input to both functions could also safely be removed.
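
Putting those suggestions together, the loss might end up looking roughly like this (an untested sketch):

def expectile_regression_loss(
    dist_src: ArrayLike,
    tau_src: ArrayLike,
    dist_target: ArrayLike
) -> ArrayOrScalar:
  """Compute expectile regression loss.

  `dist_target` may contain either expectiles (ER-naive) or samples (ER-DQN).
  """
  base.rank_assert([dist_src, tau_src, dist_target], 1)
  base.type_assert([dist_src, tau_src, dist_target], float)

  # Pairwise errors between target values and source expectiles.
  delta = dist_target[None, :] - dist_src[:, None]
  delta_neg = (delta < 0.).astype(jnp.float32)
  delta_neg = jax.lax.stop_gradient(delta_neg)
  weight = jnp.abs(tau_src[:, None] - delta_neg)

  # No jnp.abs needed: the square is already non-negative.
  loss = jnp.square(delta) * weight

  # Average over target-samples dimension, sum over src-samples dimension.
  return jnp.sum(jnp.mean(loss, axis=-1))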

Thanks!