google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

Question regarding `qtransform_by_parent_and_siblings` in `muzero_policy` #84

sotetsuk closed this issue 8 months ago

sotetsuk commented 8 months ago

Hello,

I have a question about the `qtransform_by_parent_and_siblings` function, which `muzero_policy` uses as its default Q-transform. It appears to normalize Q-values differently from the method described in the original MuZero paper and its pseudocode. The original MuZero normalizes using the min/max values over the entire search tree, whereas `qtransform_by_parent_and_siblings` seems to use the min/max values of the sibling nodes. Could you please clarify the origin of this function? Is there a specific reference or source for this approach?

Description in the original MuZero paper: (image not reproduced)
Pseudo-code in the original MuZero paper:

```python
import collections
from typing import Optional

# Defined elsewhere in the MuZero pseudocode; reproduced so the class runs.
MAXIMUM_FLOAT_VALUE = float('inf')
KnownBounds = collections.namedtuple('KnownBounds', ['min', 'max'])


class MinMaxStats(object):
  """A class that holds the min-max values of the tree."""

  def __init__(self, known_bounds: Optional[KnownBounds]):
    self.maximum = known_bounds.max if known_bounds else -MAXIMUM_FLOAT_VALUE
    self.minimum = known_bounds.min if known_bounds else MAXIMUM_FLOAT_VALUE

  def update(self, value: float):
    self.maximum = max(self.maximum, value)
    self.minimum = min(self.minimum, value)

  def normalize(self, value: float) -> float:
    if self.maximum > self.minimum:
      # We normalize only when we have set the maximum and minimum values.
      return (value - self.minimum) / (self.maximum - self.minimum)
    return value
```
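To make the contrast in the question concrete, here is a plain-Python sketch of the two normalizations applied to a single node. It is not the mctx code, which works on JAX arrays and a search-tree structure. The function names are hypothetical, and the handling of unvisited children (excluded from the bounds, then mapped to 0) is an assumption of this sketch.

```python
from typing import List, Sequence


def normalize_by_parent_and_siblings(
    parent_value: float,
    child_qvalues: Sequence[float],
    child_visits: Sequence[int],
    epsilon: float = 1e-8,
) -> List[float]:
  """Min-max normalizes one node's child Q-values using local bounds only.

  The bounds come from the parent's value estimate and the Q-values of its
  visited children (siblings of each other), not from the whole search tree.
  """
  # Unvisited children have no Q-value yet; substitute the parent value so
  # they cannot widen the bounds (assumption of this sketch).
  safe = [q if n > 0 else parent_value
          for q, n in zip(child_qvalues, child_visits)]
  low = min(parent_value, min(safe))
  high = max(parent_value, max(safe))
  # Fill unvisited children with the lower bound, so they normalize to 0.
  completed = [q if n > 0 else low
               for q, n in zip(child_qvalues, child_visits)]
  # epsilon guards against division by zero when all values coincide.
  scale = max(high - low, epsilon)
  return [(q - low) / scale for q in completed]


def normalize_by_tree_min_max(tree_values: Sequence[float],
                              value: float) -> float:
  """MuZero-style normalization: bounds span every value seen in the tree."""
  low, high = min(tree_values), max(tree_values)
  if high > low:
    return (value - low) / (high - low)
  return value
```

For example, with a parent value of 0.5, visited children at 0.2 and 0.9, and one unvisited child, `normalize_by_parent_and_siblings(0.5, [0.2, 0.9, 0.0], [3, 1, 0])` maps the children to 0.0, 1.0 and 0.0. With tree-wide bounds, a single extreme value elsewhere in the tree (say -1.0) widens the range and compresses the gap between these same siblings.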

As an additional note, in my preliminary experiments with AlphaZero-style training on 9x9 Go, I observed that `qtransform_by_parent_and_siblings` seems to perform better than the original tree-wide normalization (and `qtransform_by_min_max`).

Thank you for your time and assistance.

fidlej commented 8 months ago

Thanks for asking. This normalization is now recommended instead of the original normalization. As you noticed, it works better.

`qtransform_completed_by_mix_value`, which Gumbel MuZero uses, applies a similar normalization when `rescale_values=True`.
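As a rough illustration of why the two are similar, the sketch below completes one node's Q-values (every unvisited child receives a single fill value) and then min-max rescales them across the siblings. `mixed_value` assumes the mixed value estimate described in the Gumbel MuZero paper. Both function names are hypothetical, this is not the mctx implementation, and any further scaling the library applies after rescaling is left out.

```python
from typing import List, Sequence


def mixed_value(raw_value: float,
                child_qvalues: Sequence[float],
                child_visits: Sequence[int],
                prior_probs: Sequence[float]) -> float:
  """Mixes the network value with a prior-weighted average of visited Qs."""
  total_visits = sum(child_visits)
  visited_prob = sum(p for p, n in zip(prior_probs, child_visits) if n > 0)
  if total_visits == 0 or visited_prob <= 0.0:
    # Simplification for this sketch: fall back to the raw value.
    return raw_value
  # Prior-weighted mean Q over visited children, renormalized by their mass.
  weighted_q = sum(p * q for p, q, n
                   in zip(prior_probs, child_qvalues, child_visits)
                   if n > 0) / visited_prob
  return (raw_value + total_visits * weighted_q) / (total_visits + 1)


def complete_and_rescale(child_qvalues: Sequence[float],
                         child_visits: Sequence[int],
                         fill_value: float,
                         epsilon: float = 1e-8) -> List[float]:
  """Fills unvisited children with fill_value, then min-max rescales them.

  The bounds are taken over this node's completed child Q-values only,
  i.e. a per-node normalization rather than a tree-wide one.
  """
  completed = [q if n > 0 else fill_value
               for q, n in zip(child_qvalues, child_visits)]
  low, high = min(completed), max(completed)
  return [(q - low) / max(high - low, epsilon) for q in completed]
```

Here the unvisited children take the mixed value, so after rescaling they can land anywhere in [0, 1] rather than at a fixed endpoint.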