Closed harukaki closed 1 year ago
workspace/test_entropy.py
にて検証
二つの方法でmaskしたpolicyに対するentropyを実装
def entropy_from_dif(logits, mask):
    """Compute the entropy of an illegal-action-masked policy from the definition.

    Args:
        logits: 1-D array of raw action logits.
        mask: 1-D boolean array of the same length; True marks a legal action.

    Returns:
        Tuple of (entropy, masked_logits, probs, log_probs), where entropy is
        -sum_{legal i} p_i * log p_i.
    """
    # Drive illegal-action logits to -inf: float64 min overflows to -inf
    # when added to a float32 array, so softmax assigns them probability 0.
    logits = logits + jnp.finfo(np.float64).min * (~mask)
    log_probs = jax.nn.log_softmax(logits)
    probs = jax.nn.softmax(logits)
    # Sum p * log p over legal actions only.  jnp.where zeroes the illegal
    # entries so 0 * (-inf) never produces NaN.  This replaces the original
    # hard-coded `for i in range(38)` loop: JAX clamps out-of-bounds indices
    # to the last element, so for inputs shorter than 38 the loop silently
    # re-added the final term on every extra iteration (the issue's own
    # 4-element run showed 3.8608 instead of the correct 0.7906).
    plogp = jnp.where(mask, probs * log_probs, 0.0)
    entropy = jnp.sum(plogp).astype(jnp.float32)
    return (
        -entropy,
        logits,
        probs,
        log_probs,
    )
def entropy_from_distrax(logits, mask):
    """Compute the entropy of an illegal-action-masked policy via distrax.

    Args:
        logits: 1-D array of raw action logits.
        mask: 1-D boolean array of the same length; True marks a legal action.

    Returns:
        Tuple of (entropy, masked_logits, probs) for the masked Categorical.
    """
    # Drive illegal-action logits to -inf so their probabilities become zero.
    masked_logits = logits + jnp.finfo(np.float64).min * (~mask)
    masked_pi = distrax.Categorical(logits=masked_logits)
    entropy = masked_pi.entropy()
    probs = masked_pi.probs
    return (entropy, masked_logits, probs)
実行結果
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
Calc from difinition
entropy: 3.860790729522705
logits: [ 1. 1. -inf -2.]
probs: [0.48785555 0.48785555 0. 0.0242889 ]
log_probs: [-0.7177359 -0.7177359 -inf -3.7177358]
entropy: -3.860790729522705
grad: [ 0.03554842 0.03554842 0. -0.07109684]
Calc from distrax
entropy: 0.7906026244163513
logits: [ 1. 1. -inf -2.]
probs: [0.48785555 0.48785555 0. 0.0242889 ]
entropy: -0.7906026244163513
grad: [ 0.03554843 0.03554843 0. -0.07109685]
Loading dds results from dds_results/train_000.npy ...
Calc from difinition
entropy: 3.569485902786255
logits: [-0.08427836 -inf -inf 0.15760802 0.28321218 -0.05905859
0.02925954 0.13305202 -0.13916473 -0.18824172 0.08766825 -0.00051884
0.1281999 0.09974058 0.00066989 0.01761878 -0.2606694 0.00373229
0.06415141 -0.12595929 -0.10677278 -0.39286864 -0.02238724 -0.02819624
0.38166225 0.13034666 -0.16226953 0.21659629 -0.09879474 0.07691625
-0.08494915 -0.21141459 -0.13202426 -0.07100749 -0.22356787 0.11439289
0.30908436 -0.17015982]
probs: [0.02541129 0. 0. 0.03236508 0.0366966 0.0260603
0.0284666 0.03157999 0.02405414 0.02290214 0.03017881 0.02763141
0.03142714 0.03054535 0.02766427 0.02813715 0.02130203 0.02774912
0.02947738 0.02437389 0.02484606 0.01866413 0.02703371 0.02687713
0.0404932 0.03149467 0.02350475 0.03433167 0.02504507 0.02985607
0.02539425 0.02237753 0.02422651 0.02575077 0.02210722 0.03099621
0.0376584 0.02332002]
log_probs: [-3.6725616 -inf -inf -3.4306755 -3.3050714 -3.647342
-3.5590239 -3.4552314 -3.7274482 -3.776525 -3.5006151 -3.5888023
-3.4600835 -3.4885428 -3.5876136 -3.5706646 -3.8489528 -3.584551
-3.524132 -3.7142427 -3.6950562 -3.981152 -3.6106706 -3.6164796
-3.2066212 -3.4579368 -3.750553 -3.3716872 -3.6870782 -3.511367
-3.6732326 -3.7996979 -3.7203078 -3.6592908 -3.8118513 -3.4738905
-3.2791991 -3.7584434]
entropy: -3.569485902786255
grad: [-2.6192833e-03 0.0000000e+00 0.0000000e+00 4.4926140e-03
9.7031277e-03 -2.0289514e-03 2.9782928e-04 3.6081651e-03
-3.7996452e-03 -4.7416282e-03 2.0784489e-03 -5.3373410e-04
3.4382178e-03 2.4724468e-03 -5.0148804e-04 -3.3163444e-05
-5.9532044e-03 -4.1803677e-04 1.3369240e-03 -3.5282797e-03
-3.1199241e-03 -7.6833852e-03 -1.1133708e-03 -1.2630540e-03
1.4693558e-02 3.5132086e-03 -4.2559239e-03 6.7907651e-03
-2.9451023e-03 1.7352093e-03 -2.6345621e-03 -5.1515726e-03
-3.6538844e-03 -2.3125391e-03 -5.3580189e-03 2.9630959e-03
1.0931742e-02 -4.4064834e-03]
Calc from distrax
entropy: 3.569486141204834
logits: [-0.08427836 -inf -inf 0.15760802 0.28321218 -0.05905859
0.02925954 0.13305202 -0.13916473 -0.18824172 0.08766825 -0.00051884
0.1281999 0.09974058 0.00066989 0.01761878 -0.2606694 0.00373229
0.06415141 -0.12595929 -0.10677278 -0.39286864 -0.02238724 -0.02819624
0.38166225 0.13034666 -0.16226953 0.21659629 -0.09879474 0.07691625
-0.08494915 -0.21141459 -0.13202426 -0.07100749 -0.22356787 0.11439289
0.30908436 -0.17015982]
probs: [0.02541129 0. 0. 0.03236507 0.03669659 0.02606031
0.0284666 0.03157999 0.02405414 0.02290214 0.03017881 0.0276314
0.03142714 0.03054535 0.02766427 0.02813715 0.02130203 0.02774912
0.02947738 0.02437389 0.02484606 0.01866413 0.02703371 0.02687713
0.0404932 0.03149467 0.02350475 0.03433166 0.02504507 0.02985607
0.02539425 0.02237753 0.02422651 0.02575077 0.02210722 0.03099621
0.0376584 0.02332001]
entropy: -3.569486141204834
grad: [-2.6192830e-03 0.0000000e+00 0.0000000e+00 4.4926135e-03
9.7031137e-03 -2.0289503e-03 2.9782232e-04 3.6081586e-03
-3.7996417e-03 -4.7416347e-03 2.0784410e-03 -5.3373526e-04
3.4382117e-03 2.4724395e-03 -5.0148595e-04 -3.3163637e-05
-5.9532160e-03 -4.1803997e-04 1.3369160e-03 -3.5282865e-03
-3.1199232e-03 -7.6833903e-03 -1.1133720e-03 -1.2630549e-03
1.4693556e-02 3.5131981e-03 -4.2559309e-03 6.7907646e-03
-2.9451011e-03 1.7351999e-03 -2.6345670e-03 -5.1515764e-03
-3.6538814e-03 -2.3125438e-03 -5.3580184e-03 2.9631017e-03
1.0931744e-02 -4.4064852e-03]
どちらの実装でも非合法手に対する出力に対しての勾配は0となり、非合法手の確率を上げることは起きない。
distraxのcategorical分布内のentropyの計算実装 https://github.com/google-deepmind/distrax/blob/079849a0e778e3f9073a076b3589ab434a2ea33c/distrax/_src/distributions/categorical.py#L117
def entropy(self) -> Array:
    """See `Distribution.entropy`."""
    # Prefer the numerically stable log-softmax path when logits are stored;
    # fall back to logging the stored probabilities otherwise.
    if self._logits is not None:
        log_probs = jax.nn.log_softmax(self._logits)
    else:
        log_probs = jnp.log(self._probs)
    # mul_exp gives zero value and zero gradient wherever exp(log p) == 0,
    # so -inf log-probs of masked actions contribute nothing.
    return -jnp.sum(math.mul_exp(log_probs, log_probs), axis=-1)
entropy計算における p * log pの計算実装 https://github.com/google-deepmind/distrax/blob/079849a0e778e3f9073a076b3589ab434a2ea33c/distrax/_src/utils/math.py#L91
def mul_exp(x: Array, logp: Array) -> Array:
    """Returns `x * exp(logp)`, treating `exp(logp) == 0` as an exact zero.

    Args:
        x: An array.
        logp: An array.

    Returns:
        `x * exp(logp)`, with zero output and zero gradient wherever
        `exp(logp) == 0`, even if `x` is NaN or infinite there.
    """
    prob = jnp.exp(logp)
    # Where prob == 0 the gradient w.r.t. logp vanishes anyway, so any
    # non-finite `x` can safely be swapped for zero before multiplying.
    safe_x = jnp.where(prob == 0, 0.0, x)
    return safe_x * prob
log pが計算される際にはlog pが負の無限大になる時、つまり確率pが0になるときには、勾配が0になるように計算される。 illegal action maskをした際に、非合法手に対しての確率は0になるように計算されるので勾配が計算されない。
実際の行動数である38要素の例では二つの結果は一致した。ただし4要素のテスト例では、`entropy_from_dif`の`range(38)`固定ループがJAXの範囲外インデックスのクランプにより最終要素の項を重複加算するため、3.86と0.79で結果が食い違う点に注意。
entropy正則化項の実装を検証する