davidbrandfonbrener / onestep-rl


Potential bug in density calculation #4

ezhang7423 opened this issue 2 years ago (status: Open)

ezhang7423 commented 2 years ago

I believe I have found a bug in your density estimation for the normal distribution. This is your current GaussMLP:

class GaussMLP(nn.Module):
    def __init__(self, state_dim, action_dim, width, depth, dist_type):
        super().__init__()
        self.net = utils.MLP(input_shape=(state_dim), output_dim=2*action_dim,
                        width=width, depth=depth)
        self.log_std_bounds = (-5., 0.)
        self.mu_bounds = (-1., 1.)
        self.dist_type = dist_type

    def forward(self, s):
        s = torch.flatten(s, start_dim=1)
        mu, log_std = self.net(s).chunk(2, dim=-1)

        mu = soft_clamp(mu, *self.mu_bounds)
        log_std = soft_clamp(log_std, *self.log_std_bounds)

        std = log_std.exp()
        if self.dist_type == 'normal':
            dist = D.Normal(mu, std)
        elif self.dist_type == 'trunc':
            dist = utils.TruncatedNormal(mu, std)
        elif self.dist_type == 'squash':
            dist = utils.SquashedNormal(mu, std)
        else:
            raise TypeError("Expected dist_type to be 'normal', 'trunc', or 'squash'")
        return dist

In the normal case, since you're performing a tanh transformation (inside soft_clamp) before returning a normal distribution, the density function should include an atanh and a normalization term:

$$\log \pi(a \mid s) = \log \mathcal{N}(\operatorname{atanh}(a);\, \mu, \sigma) - \sum_i \log(1 - a_i^2)$$
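
To make this concrete, here is a minimal standalone sketch (not taken from your repo) comparing PyTorch's own tanh-transformed distribution with the manual atanh-plus-Jacobian computation:

import torch
import torch.distributions as D

mu, std = torch.zeros(4), torch.ones(4)
base = D.Normal(mu, std)

# If a = tanh(u) with u ~ N(mu, std^2), the change of variables gives
#   log p(a) = log N(atanh(a); mu, std) - log(1 - a^2)   (per dimension)
squashed = D.TransformedDistribution(base, [D.TanhTransform(cache_size=1)])

a = squashed.sample()
u = torch.atanh(a.clamp(-1 + 1e-6, 1 - 1e-6))  # guard against a = +/-1
manual = base.log_prob(u) - torch.log1p(-a.pow(2))
print(torch.allclose(squashed.log_prob(a), manual, atol=1e-4))  # True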

Do you agree?

davidbrandfonbrener commented 2 years ago

No, this is not a bug. The soft_clamp is applied to the mean and std, i.e. to the parameters of the distribution, not to its samples, so we can stack any distribution on top of that mean and variance. For the case of the "normal" distribution, this just means it is a normal distribution whose mean and variance fall in some bounded range; the samples themselves are never passed through a tanh, so no atanh or Jacobian correction belongs in the density.

The truncated and squashed distributions put the distribution itself through a transformation, not the mean and variance (i.e. they only draw samples within fixed bounds).
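
To make that concrete, here is a minimal sketch using one common tanh-based soft_clamp (the exact utils.soft_clamp in the repo may differ): the clamp squashes the network outputs that parameterize the distribution, while the samples drawn from D.Normal remain unbounded, so its log_prob is already the correct density.

import torch
import torch.distributions as D

def soft_clamp(x, low, high):
    # Smoothly squash x into (low, high); this reparameterizes the
    # *network output*, it never touches the distribution's samples.
    return low + 0.5 * (high - low) * (torch.tanh(x) + 1.)

raw_mu, raw_log_std = torch.randn(4), torch.randn(4)
mu = soft_clamp(raw_mu, -1., 1.)               # mean bounded in (-1, 1)
std = soft_clamp(raw_log_std, -5., 0.).exp()   # std bounded in (e^-5, 1)

dist = D.Normal(mu, std)
a = dist.sample()            # samples are *unbounded*
print(dist.log_prob(a))      # valid normal density, no Jacobian term

Only in the 'squash' case is the sample itself passed through a tanh, which is why SquashedNormal (and not Normal) carries the Jacobian term.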