harukaki / brl

reinforcement learning for bridge

Validation of the entropy regularization term for illegal action #4

Closed: harukaki closed this issue 1 year ago

harukaki commented 1 year ago

Verify the implementation of the entropy regularization term.

harukaki commented 1 year ago

Verified in workspace/test_entropy.py. The entropy H(π) = -Σ_{legal i} π(i) log π(i) of the illegal-action-masked policy is implemented in two ways:

import distrax
import jax
import jax.numpy as jnp
import numpy as np


def entropy_from_dif(logits, mask):
    """Compute the entropy of the illegal-action-masked policy from the definition."""
    # Push the logits of illegal actions to a huge negative value so that
    # their softmax probability becomes 0.
    logits = logits + jnp.finfo(np.float64).min * (~mask)
    log_probs = jax.nn.log_softmax(logits)
    probs = jax.nn.softmax(logits)
    entropy = jnp.array(0, dtype=jnp.float32)
    # Accumulate p * log p over legal actions only; 38 is the size of the
    # bridge bidding action space. (NOTE: on shorter toy inputs, jnp's
    # clamped out-of-bounds indexing repeats the last entry.)
    for i in range(38):
        entropy = jax.lax.cond(
            mask[i], lambda: entropy + log_probs[i] * probs[i], lambda: entropy
        )
    return (
        -entropy,
        logits,
        probs,
        log_probs,
    )


def entropy_from_distrax(logits, mask):
    """Compute the entropy of the illegal-action-masked policy using distrax."""
    illegal_action_masked_logits = logits + jnp.finfo(np.float64).min * (~mask)
    illegal_action_masked_pi = distrax.Categorical(logits=illegal_action_masked_logits)
    return (
        illegal_action_masked_pi.entropy(),
        illegal_action_masked_logits,
        illegal_action_masked_pi.probs,
    )
harukaki commented 1 year ago

Execution results

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Calc from difinition
entropy: 3.860790729522705
logits: [  1.   1. -inf  -2.]
probs: [0.48785555 0.48785555 0.         0.0242889 ]
log_probs: [-0.7177359 -0.7177359       -inf -3.7177358]

entropy: -3.860790729522705
grad: [ 0.03554842  0.03554842  0.         -0.07109684]

Calc from distrax
entropy: 0.7906026244163513
logits: [  1.   1. -inf  -2.]
probs: [0.48785555 0.48785555 0.         0.0242889 ]

entropy: -0.7906026244163513
grad: [ 0.03554843  0.03554843  0.         -0.07109685]

Loading dds results from dds_results/train_000.npy ...
Calc from difinition
entropy: 3.569485902786255
logits: [-0.08427836        -inf        -inf  0.15760802  0.28321218 -0.05905859
  0.02925954  0.13305202 -0.13916473 -0.18824172  0.08766825 -0.00051884
  0.1281999   0.09974058  0.00066989  0.01761878 -0.2606694   0.00373229
  0.06415141 -0.12595929 -0.10677278 -0.39286864 -0.02238724 -0.02819624
  0.38166225  0.13034666 -0.16226953  0.21659629 -0.09879474  0.07691625
 -0.08494915 -0.21141459 -0.13202426 -0.07100749 -0.22356787  0.11439289
  0.30908436 -0.17015982]
probs: [0.02541129 0.         0.         0.03236508 0.0366966  0.0260603
 0.0284666  0.03157999 0.02405414 0.02290214 0.03017881 0.02763141
 0.03142714 0.03054535 0.02766427 0.02813715 0.02130203 0.02774912
 0.02947738 0.02437389 0.02484606 0.01866413 0.02703371 0.02687713
 0.0404932  0.03149467 0.02350475 0.03433167 0.02504507 0.02985607
 0.02539425 0.02237753 0.02422651 0.02575077 0.02210722 0.03099621
 0.0376584  0.02332002]
log_probs: [-3.6725616       -inf       -inf -3.4306755 -3.3050714 -3.647342
 -3.5590239 -3.4552314 -3.7274482 -3.776525  -3.5006151 -3.5888023
 -3.4600835 -3.4885428 -3.5876136 -3.5706646 -3.8489528 -3.584551
 -3.524132  -3.7142427 -3.6950562 -3.981152  -3.6106706 -3.6164796
 -3.2066212 -3.4579368 -3.750553  -3.3716872 -3.6870782 -3.511367
 -3.6732326 -3.7996979 -3.7203078 -3.6592908 -3.8118513 -3.4738905
 -3.2791991 -3.7584434]

entropy: -3.569485902786255
grad: [-2.6192833e-03  0.0000000e+00  0.0000000e+00  4.4926140e-03
  9.7031277e-03 -2.0289514e-03  2.9782928e-04  3.6081651e-03
 -3.7996452e-03 -4.7416282e-03  2.0784489e-03 -5.3373410e-04
  3.4382178e-03  2.4724468e-03 -5.0148804e-04 -3.3163444e-05
 -5.9532044e-03 -4.1803677e-04  1.3369240e-03 -3.5282797e-03
 -3.1199241e-03 -7.6833852e-03 -1.1133708e-03 -1.2630540e-03
  1.4693558e-02  3.5132086e-03 -4.2559239e-03  6.7907651e-03
 -2.9451023e-03  1.7352093e-03 -2.6345621e-03 -5.1515726e-03
 -3.6538844e-03 -2.3125391e-03 -5.3580189e-03  2.9630959e-03
  1.0931742e-02 -4.4064834e-03]

Calc from distrax
entropy: 3.569486141204834
logits: [-0.08427836        -inf        -inf  0.15760802  0.28321218 -0.05905859
  0.02925954  0.13305202 -0.13916473 -0.18824172  0.08766825 -0.00051884
  0.1281999   0.09974058  0.00066989  0.01761878 -0.2606694   0.00373229
  0.06415141 -0.12595929 -0.10677278 -0.39286864 -0.02238724 -0.02819624
  0.38166225  0.13034666 -0.16226953  0.21659629 -0.09879474  0.07691625
 -0.08494915 -0.21141459 -0.13202426 -0.07100749 -0.22356787  0.11439289
  0.30908436 -0.17015982]
probs: [0.02541129 0.         0.         0.03236507 0.03669659 0.02606031
 0.0284666  0.03157999 0.02405414 0.02290214 0.03017881 0.0276314
 0.03142714 0.03054535 0.02766427 0.02813715 0.02130203 0.02774912
 0.02947738 0.02437389 0.02484606 0.01866413 0.02703371 0.02687713
 0.0404932  0.03149467 0.02350475 0.03433166 0.02504507 0.02985607
 0.02539425 0.02237753 0.02422651 0.02575077 0.02210722 0.03099621
 0.0376584  0.02332001]

entropy: -3.569486141204834
grad: [-2.6192830e-03  0.0000000e+00  0.0000000e+00  4.4926135e-03
  9.7031137e-03 -2.0289503e-03  2.9782232e-04  3.6081586e-03
 -3.7996417e-03 -4.7416347e-03  2.0784410e-03 -5.3373526e-04
  3.4382117e-03  2.4724395e-03 -5.0148595e-04 -3.3163637e-05
 -5.9532160e-03 -4.1803997e-04  1.3369160e-03 -3.5282865e-03
 -3.1199232e-03 -7.6833903e-03 -1.1133720e-03 -1.2630549e-03
  1.4693556e-02  3.5131981e-03 -4.2559309e-03  6.7907646e-03
 -2.9451011e-03  1.7351999e-03 -2.6345670e-03 -5.1515764e-03
 -3.6538814e-03 -2.3125438e-03 -5.3580184e-03  2.9631017e-03
  1.0931744e-02 -4.4064852e-03]
harukaki commented 1 year ago

Results

With either implementation, the gradient with respect to the outputs for illegal actions is zero, so the probability of an illegal action is never pushed up.
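
As a sketch of how this can be re-checked mechanically (assuming the two functions defined above; the legality pattern here is arbitrary):

# Sketch: for a random 38-action logit vector, the gradient of -entropy at
# illegal positions is exactly zero under both implementations.
key = jax.random.PRNGKey(0)
logits = jax.random.normal(key, (38,))
mask = jnp.arange(38) % 5 != 0  # arbitrary pattern of legal/illegal actions
for fn in (entropy_from_dif, entropy_from_distrax):
    grad = jax.grad(lambda l: -fn(l, mask)[0])(logits)
    assert jnp.all(grad[~mask] == 0.0)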

Discussion

The implementation of the entropy computation inside distrax's Categorical distribution: https://github.com/google-deepmind/distrax/blob/079849a0e778e3f9073a076b3589ab434a2ea33c/distrax/_src/distributions/categorical.py#L117

  def entropy(self) -> Array:
    """See `Distribution.entropy`."""
    if self._logits is None:
      log_probs = jnp.log(self._probs)
    else:
      log_probs = jax.nn.log_softmax(self._logits)
    return -jnp.sum(math.mul_exp(log_probs, log_probs), axis=-1)

The implementation of p * log p within the entropy computation: https://github.com/google-deepmind/distrax/blob/079849a0e778e3f9073a076b3589ab434a2ea33c/distrax/_src/utils/math.py#L91

def mul_exp(x: Array, logp: Array) -> Array:
  """Returns `x * exp(logp)` with zero output if `exp(logp)==0`.

  Args:
    x: An array.
    logp: An array.

  Returns:
    `x * exp(logp)` with zero output and zero gradient if `exp(logp)==0`,
    even if `x` is NaN or infinite.
  """
  p = jnp.exp(logp)
  # If p==0, the gradient with respect to logp is zero,
  # so we can replace the possibly non-finite `x` with zero.
  x = jnp.where(p == 0, 0.0, x)
  return x * p

When p * log p is evaluated, the computation is set up so that the gradient is zero whenever log p is negative infinity, i.e. whenever the probability p is zero. Because the illegal action mask drives the probabilities of illegal actions to zero, no gradient is propagated to them.
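
A minimal standalone sketch (not from distrax) of why the jnp.where guard matters: differentiating the naive product p * log p at log p = -inf gives NaN, while the guarded form gives a clean zero.

import jax
import jax.numpy as jnp

def naive(logp):
    # p * log p written directly: x = logp, p = exp(logp)
    return logp * jnp.exp(logp)

def guarded(logp):
    # Mirrors distrax's mul_exp: zero out the non-finite factor where p == 0.
    p = jnp.exp(logp)
    x = jnp.where(p == 0, 0.0, logp)
    return x * p

print(jax.grad(naive)(-jnp.inf))    # nan
print(jax.grad(guarded)(-jnp.inf))  # 0.0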

harukaki commented 1 year ago

Results of actually running reinforcement learning with the entropy implemented in both ways:

https://wandb.ai/hrkkt1213/ppo-bridge/reports/230905--Vmlldzo1MzExMTYz?accessToken=6qecsrr3e0gy99nmurveugllm6e1qhvd1y0qwtrzla83e25iivorqokxatvel8hm

The results did not differ between the two.