carlosgmartin closed this issue 7 months ago.
I asked about this here. Because the decision and chance functions alternate, the unused portion is discarded (the invalid values you are seeing wouldn't be used).
@evanatyourservice Hmm, doesn't that mean computation is being wasted on the invalid actions and outcomes?
Yeah, here's another issue I opened about that. There's no simple workaround right now because jaxlib hasn't implemented a batched cond: under vmap, a cond with a batched predicate is converted into a select that evaluates both branches. Here's a related issue in jax (there are a few). It doesn't matter much for mctx on TPU or GPU, because the two branches can mostly be computed in parallel, but the issue is there nonetheless. It becomes a bigger problem when the two branches differ greatly in computational cost.
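The "compute both, keep one" pattern described above can be sketched in NumPy. Here decision_branch, chance_branch and batched_step are hypothetical stand-ins, not mctx functions, and the sizes assume the 7-action, 3-outcome example from the issue. The np.where call plays the role of the select that a vmapped cond with a batched predicate turns into.

```python
import numpy as np

NUM_ACTIONS = 7           # hypothetical sizes, matching the example in the issue
NUM_CHANCE_OUTCOMES = 3


def decision_branch(action):
    # Hypothetical stand-in for decision_recurrent_fn: any per-element computation.
    return action * 10


def chance_branch(outcome):
    # Hypothetical stand-in for chance_recurrent_fn.
    return outcome * 100


def batched_step(action_or_chance, is_decision_node):
    # Both branches run on EVERY element of the batch, including elements whose
    # index is meaningless for that branch (e.g. actions 8 and 9, or negative outcomes).
    decision_out = decision_branch(action_or_chance)
    chance_out = chance_branch(action_or_chance - NUM_ACTIONS)
    # The select keeps one result per element and discards the other one,
    # so the wasted work never reaches the output.
    return np.where(is_decision_node, decision_out, chance_out)


idx = np.array([2, 8, 5, 9])
is_decision = np.array([True, False, True, False])
print(batched_step(idx, is_decision).tolist())  # [20, 100, 50, 200]
```

The discarded half is harmless for correctness, but its cost is still paid, which is why the gap matters more when one branch is much more expensive than the other.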
Maybe the MCTS could be rewritten so that only one branch is computed; I haven't really thought it through. If you see a solution, please share it. I might look into it, since my nets are sometimes large and the difference in speed and compute would be significant.
Thanks @evanatyourservice for explaining the issue. If you rewrite the MCTS to better support Stochastic MuZero, maybe keep that in your fork. I do not plan big changes to mctx.
@fidlej According to the JAX docs on out-of-bounds indexing:
Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out of bounds indexing. Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of undefined behavior.
If passing out-of-bounds indices into functions that expect valid actions or outcomes could cause undefined behavior somewhere downstream, would it be a good idea to clip the indices passed to decision_recurrent_fn and chance_recurrent_fn to their valid ranges?
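A minimal NumPy sketch of the proposed clipping, assuming the 7-action, 3-outcome example from the issue. clip_action and clip_outcome are hypothetical helpers, not part of mctx; in JAX, jnp.clip would be the equivalent call.

```python
import numpy as np

NUM_ACTIONS = 7           # hypothetical sizes, matching the example in the issue
NUM_CHANCE_OUTCOMES = 3


def clip_action(action):
    # Hypothetical helper: force an action index into [0, NUM_ACTIONS - 1].
    return np.clip(action, 0, NUM_ACTIONS - 1)


def clip_outcome(outcome):
    # Hypothetical helper: force a chance-outcome index into [0, NUM_CHANCE_OUTCOMES - 1].
    return np.clip(outcome, 0, NUM_CHANCE_OUTCOMES - 1)


# All combined indices that can reach the two functions in the example.
combined = np.arange(NUM_ACTIONS + NUM_CHANCE_OUTCOMES)

print(clip_action(combined).tolist())                  # [0, 1, 2, 3, 4, 5, 6, 6, 6, 6]
print(clip_outcome(combined - NUM_ACTIONS).tolist())   # [0, 0, 0, 0, 0, 0, 0, 0, 1, 2]
```

Clipping only guarantees that every index is in range; the results for the wrong node type are still computed and then discarded, so it does not remove the wasted work.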
Will you close the enhancement issue about Stochastic MuZero?
Thanks for noticing. Done.
Example:
Output:
The actions range from 0 to 9 (7+3=10 in total), even though there are only 7 actions. The outcomes range from -7 (a negative integer!) to 2 (7+3=10 in total), even though there are only 3 outcomes.
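The reported ranges are what you would get if stochastic_recurrent_fn received a single combined index from 0 to num_actions + num_chance_outcomes - 1 and derived both interpretations from it: the action as the index itself, and the chance outcome as the index minus num_actions. That convention is an assumption here; this NumPy sketch just reproduces the arithmetic with illustrative variable names.

```python
import numpy as np

NUM_ACTIONS = 7           # sizes from the example in the issue
NUM_CHANCE_OUTCOMES = 3

# Assumed convention: one combined index space covers both decisions and outcomes.
combined = np.arange(NUM_ACTIONS + NUM_CHANCE_OUTCOMES)  # 0..9, 10 values

action = combined                        # read as an action: 0..9
chance_outcome = combined - NUM_ACTIONS  # read as a chance outcome: -7..2

print(action.min(), action.max())                  # 0 9
print(chance_outcome.min(), chance_outcome.max())  # -7 2
```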
This may have something to do with the math inside stochastic_recurrent_fn.

Version information: