YeWR / EfficientZero

Open-source codebase for EfficientZero, from "Mastering Atari Games with Limited Data" at NeurIPS 2021.
GNU General Public License v3.0

Slight discrepancy with implementation of value scaling #13

Closed henrycharlesworth closed 2 years ago

henrycharlesworth commented 2 years ago

Hey, firstly just wanted to say thank you because this is an amazing repo for understanding how MuZero/EfficientZero work in detail!

I've been digging into exactly how the value prediction is done, since it seems like a significant detail that is tucked away in an appendix, and I think there is a slight discrepancy (one that probably doesn't make much difference, but may still be worth highlighting).

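In the original paper (https://arxiv.org/pdf/1805.11593.pdf) they define the scaling function as:

$$h(x) = \operatorname{sign}(x)\left(\sqrt{|x| + 1} - 1\right) + \epsilon x$$

with the inverse function given by Proposition A.2 (iii):

$$h^{-1}(x) = \operatorname{sign}(x)\left(\left(\frac{\sqrt{1 + 4\epsilon\left(|x| + 1 + \epsilon\right)} - 1}{2\epsilon}\right)^2 - 1\right)$$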

but in the MuZero appendix they have:

$$h(x) = \operatorname{sign}(x)\left(\sqrt{|x| + 1} - 1 + \epsilon x\right)$$

(with the final term inside the bracket).

Unless I'm mistaken, the code uses the MuZero version of h(x), but takes the inverse from Proposition A.2 (iii) of the first paper, which is no longer quite correct, right?
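To see why Proposition A.2 (iii) inverts the first paper's h(x), one can solve h(y) = x for y >= 0 (a worked sketch, not reproduced from either paper). For y >= 0 both versions reduce to sqrt(y + 1) - 1 + εy, and substituting s = sqrt(y + 1) turns the equation into a quadratic in s:

```latex
% Solve h(y) = x for y >= 0, where h(y) = \sqrt{y + 1} - 1 + \epsilon y.
% Substitute s = \sqrt{y + 1}, so that y = s^2 - 1.
\begin{aligned}
s - 1 + \epsilon\left(s^2 - 1\right) &= x \\
\epsilon s^2 + s - \left(1 + \epsilon + x\right) &= 0 \\
s &= \frac{\sqrt{1 + 4\epsilon\left(x + 1 + \epsilon\right)} - 1}{2\epsilon}
  \qquad \text{(positive root)} \\
y = s^2 - 1 &= \left(\frac{\sqrt{1 + 4\epsilon\left(x + 1 + \epsilon\right)} - 1}{2\epsilon}\right)^2 - 1
\end{aligned}
```

Because the first paper's h is odd, h(-y) = -h(y), the negative branch follows by applying sign(x) to this result with |x| in place of x, which gives Proposition A.2 (iii). The MuZero-appendix form keeps εx inside the sign(x) bracket, so it is not odd, and the same sign trick no longer inverts it for negative inputs.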

To show the discrepancy, consider the following code:

import torch

def scalar_transform(x, epsilon=0.001):
    # MuZero-appendix form: h(x) = sign(x) * (sqrt(|x| + 1) - 1 + epsilon * x)
    sign = torch.ones(x.shape).float().to(x.device)
    sign[x < 0] = -1.0
    output = sign * (torch.sqrt(torch.abs(x) + 1) - 1 + epsilon * x)
    return output

def inverse_scalar_transform(value, epsilon=0.001):
    # Inverse from Proposition A.2 (iii) of the first paper, which assumes
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x
    sign = torch.ones(value.shape).float().to(value.device)
    sign[value < 0] = -1.0
    output = (((torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon)) ** 2 - 1)
    output = sign * output
    return output

a = torch.randn(1000)
b = scalar_transform(a)
c = inverse_scalar_transform(b)

# Total absolute round-trip error; ideally this should be close to zero
print(torch.sum(torch.abs(a-c)))

This is how the functions are implemented in this codebase. Running it prints a value of ~2.4, whereas if I change the scalar transform to the one from the first paper, I get ~0.04.
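For comparison, here is a minimal sketch of the change described above: the forward transform rewritten in the first paper's form, so that the existing inverse becomes exact. It is a hypothetical drop-in for illustration, not the repository's actual fix. `torch.sign` replaces the manual sign mask (the two differ only at exactly zero), and `torch.manual_seed(0)` and the float64 run are added only to make the comparison reproducible.

```python
import torch


def scalar_transform(x: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
    # First-paper form: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x
    # The epsilon * x term sits OUTSIDE the sign(x) * (...) bracket.
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + epsilon * x


def inverse_scalar_transform(value: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor:
    # Proposition A.2 (iii): the exact inverse of the h(x) defined above
    inner = (torch.sqrt(1 + 4 * epsilon * (torch.abs(value) + 1 + epsilon)) - 1) / (2 * epsilon)
    return torch.sign(value) * (inner ** 2 - 1)


torch.manual_seed(0)  # added only so the run is reproducible
for dtype in (torch.float32, torch.float64):
    a = torch.randn(1000, dtype=dtype)
    c = inverse_scalar_transform(scalar_transform(a))
    # Total absolute round-trip error; only floating-point rounding remains
    print(dtype, torch.sum(torch.abs(a - c)).item())
```

Some residual error is expected in float32, because the inverse subtracts two nearly equal numbers inside the square-root term and then divides by 2 * epsilon; the float64 run shrinks it substantially.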

Hwhitetooth commented 2 years ago

Hi @henrycharlesworth,

Good catch! To add to your comment, I think it is a typo in the MuZero preprint. They actually fixed it in the Nature paper (page 8, right column).

Best

YeWR commented 2 years ago

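Hi @henrycharlesworth @Hwhitetooth, thank you very much for correcting the formula! This is indeed a mistake: it should be h1 rather than h2, where

$$h_1(x) = \operatorname{sign}(x)\left(\sqrt{|x| + 1} - 1\right) + \epsilon x, \qquad h_2(x) = \operatorname{sign}(x)\left(\sqrt{|x| + 1} - 1 + \epsilon x\right)$$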

The current implementation leads to a larger round-trip error for negative values, and we will correct this formula later.

More importantly, note that when x >= 0, h1(x) is equal to h2(x), whereas when x < 0 we have

$$h_1(x) - h_2(x) = 2\epsilon x$$

Since eps = 0.001, the difference between h1(x) and h2(x) is small, so it probably doesn't make much difference in practice.
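The claim that the two forms agree for x >= 0 and differ by 2εx for x < 0 can be checked numerically. The sketch below uses plain Python with the standard `math` module; `h1`, `h2`, `sign` and `EPS` are illustrative names chosen here, not functions from the repository.

```python
import math

EPS = 0.001  # epsilon used by the transforms in this thread


def sign(x: float) -> int:
    # Returns -1, 0 or 1, like sign(x) in the formulas
    return (x > 0) - (x < 0)


def h1(x: float, eps: float = EPS) -> float:
    # Original form: sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return sign(x) * (math.sqrt(abs(x) + 1) - 1) + eps * x


def h2(x: float, eps: float = EPS) -> float:
    # MuZero-preprint form: sign(x) * (sqrt(|x| + 1) - 1 + eps * x)
    return sign(x) * (math.sqrt(abs(x) + 1) - 1 + eps * x)


for x in [-100.0, -10.0, -1.0, 0.0, 1.0, 10.0, 100.0]:
    diff = h1(x) - h2(x)
    # Predicted gap: 2 * eps * x for negative x, zero otherwise
    expected = 2 * EPS * x if x < 0 else 0.0
    print(f"x={x:8.1f}  h1-h2={diff:+.6f}  predicted={expected:+.6f}")
```

Because the gap is 2ε|x|, it grows linearly with the magnitude of the value, so the relative difference stays small for modest values but becomes more noticeable for very large ones.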

And thank you again for your detailed discussion!