google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Question about proof that Gumbel guarantees a policy improvement #40

Closed: p3achyjr closed this issue 1 year ago

p3achyjr commented 1 year ago

Hey all! Thanks for writing the paper and this library :)

I feel like I'm missing something in the Appendix B proof for "Policy Improvement Proof for Planning with Gumbel". Specifically, the paper mentions that we can replace $\arg\max_a(g(a) + \mathrm{logits}(a))$ with $\arg\max_{a \in A_{\mathrm{topn}}}(g(a) + \mathrm{logits}(a))$. While the action with the highest $g(a) + \mathrm{logits}(a)$ is guaranteed to be in $A_{\mathrm{topn}}$, how can we guarantee that $\mathbb{E}[q(\arg\max_{a \in A_{\mathrm{topn}}}(g(a) + \mathrm{logits}(a)))] \ge \mathbb{E}[q(\arg\max_a(g(a) + \mathrm{logits}(a)))]$? We don't take the $q$ function into account while sampling, so how can we be sure the inequality holds when taking $q$ into account?
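To make the sampling concrete, here is a minimal sketch (my own illustration, not mctx code): with Gumbel noise $g$, $\arg\max_a(g(a) + \mathrm{logits}(a))$ is an exact sample from $\mathrm{softmax}(\mathrm{logits})$, and the $n$ actions with the largest $g(a) + \mathrm{logits}(a)$ form $A_{\mathrm{topn}}$, which by construction always contains that argmax.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = np.log(np.array([0.5, 0.3, 0.2]))  # pi = softmax(logits) = (.5, .3, .2)
n = 2

for _ in range(5):
    g = rng.gumbel(size=logits.shape)
    scores = g + logits
    a_star = int(np.argmax(scores))       # argmax over all actions
    a_topn = np.argsort(scores)[-n:]      # A_topn: the n highest g + logits
    assert a_star in a_topn               # the global argmax is always in A_topn
    print(a_star, sorted(a_topn.tolist()))
```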

There's a short proof that the "select $n$ most probable actions" heuristic does not guarantee a policy improvement, using the example $q = (0, 0, 1)$ and $\pi = (.5, .3, .2)$. AIUI, for Gumbel to guarantee a policy improvement on this example, the selection process would have to include action 2 (the one with probability $\pi[2]$) in its final set. How is this guaranteed?
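For reference, a quick check of that example (my own sketch): with $n = 2$, the deterministic "most probable actions" heuristic keeps $\{0, 1\}$, where $q$ is zero everywhere, so no policy restricted to that set can beat simply sampling from $\pi$, which already gives $\mathbb{E}_\pi[q] = 0.2$.

```python
import numpy as np

q = np.array([0.0, 0.0, 1.0])
pi = np.array([0.5, 0.3, 0.2])
n = 2

topn = np.argsort(pi)[-n:]   # the n most probable actions: {0, 1}
print(q[topn].max())         # 0.0: best expected q achievable on that set
print(float(pi @ q))         # 0.2: expected q when sampling directly from pi
```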

fidlej commented 1 year ago

Thanks for asking. First, instead of $\mathbb{E}[q(\arg\max_{a \in A_{\mathrm{topn}}}(g(a) + \mathrm{logits}(a)))] \ge \mathbb{E}[q(\arg\max_a(g(a) + \mathrm{logits}(a)))]$, we want to guarantee $\mathbb{E}[q(\arg\max_{a \in A_{\mathrm{topn}}}(g(a) + \mathrm{logits}(a) + \sigma(q(a))))] \ge \mathbb{E}[q(\arg\max_a(g(a) + \mathrm{logits}(a)))]$. The left-hand side is mentioned in Equation (17). Hopefully, the rest of the proof makes sense.

About the example: $\arg\max_{a \in A_{\mathrm{topn}}}(g(a) + \mathrm{logits}(a) + \sigma(q(a)))$ will select action 2 more often than unbiased sampling from the policy (i.e., than $\arg\max_a(g(a) + \mathrm{logits}(a))$).
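A Monte Carlo sketch of this on the same example (my own illustration, not mctx code): it estimates $\mathbb{E}[q(\arg\max_{a \in A_{\mathrm{topn}}}(g(a) + \mathrm{logits}(a) + \sigma(q(a))))]$ against $\mathbb{E}[q(\arg\max_a(g(a) + \mathrm{logits}(a)))] = \mathbb{E}_\pi[q] = 0.2$. Here $\sigma$ is just a positive scaling, assumed for the sketch; the argument only needs a monotonically increasing transformation.

```python
import numpy as np

rng = np.random.default_rng(0)
q = np.array([0.0, 0.0, 1.0])
pi = np.array([0.5, 0.3, 0.2])
logits = np.log(pi)
sigma = lambda x: 10.0 * x   # assumed monotone transform for illustration
n, trials = 2, 200_000

g = rng.gumbel(size=(trials, 3))
scores = g + logits                          # shape (trials, 3)

# Unbiased samples from pi: argmax of g + logits over all actions.
plain = q[np.argmax(scores, axis=-1)].mean()

# Gumbel top-n (sampling without replacement), then argmax with the sigma(q) bonus.
topn = np.argsort(scores, axis=-1)[:, -n:]   # indices of A_topn for each trial
bonus_scores = np.take_along_axis(scores + sigma(q), topn, axis=-1)
chosen = topn[np.arange(trials), np.argmax(bonus_scores, axis=-1)]
improved = q[chosen].mean()

print(plain, improved)                       # ~0.2 vs. roughly 0.49 with this sigma
```

With this $\sigma$, action 2 is picked essentially whenever it lands in $A_{\mathrm{topn}}$, which happens far more often than the 0.2 probability of drawing it directly from $\pi$, so the estimated expected $q$ is clearly larger on the left-hand side.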

p3achyjr commented 1 year ago

Thank you! Really appreciate you getting back to me :)