toshikwa / gail-airl-ppo.pytorch

PyTorch implementation of GAIL and AIRL based on PPO.
MIT License
189 stars, 30 forks

about reparametrize #7

Closed ZhenZhuHuang closed 2 years ago

ZhenZhuHuang commented 2 years ago
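import math
import torch

def calculate_log_pi(log_stds, noises, actions):
    # Log-density of the pre-tanh Gaussian sample, written in terms of the noise.
    gaussian_log_probs = (-0.5 * noises.pow(2) - log_stds).sum(
        dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)

    # Change-of-variables correction for the tanh squashing (1e-6 avoids log(0)).
    return gaussian_log_probs - torch.log(
        1 - actions.pow(2) + 1e-6).sum(dim=-1, keepdim=True)

def reparameterize(means, log_stds):
    # Reparameterization trick: sample standard normal noise, then shift and scale it.
    noises = torch.randn_like(means)
    us = means + noises * log_stds.exp()
    actions = torch.tanh(us)
    return actions, calculate_log_pi(log_stds, noises, actions)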

Hello, I would like to ask why the calculate_log_pi function calculates log pi this way. I can't find the algorithm it is based on.

toshikwa commented 2 years ago

Computing us = means + noises * log_stds.exp() is called the "reparametrization trick". Try calculating the log probability of the Gaussian distribution with noises[i] ~ N(\mu=0, \sigma=I), and you'll see.

I avoided using torch.distributions.Normal().log_prob() in order to reduce computation.

Does it make sense?
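
For reference, the calculation suggested above can be written out as follows. This is a sketch of the standard derivation, not text from the repository: the symbols $d$ (action dimension), $u$ (the pre-tanh sample `us`), $\epsilon$ (`noises`), $\mu$ (`means`) and $\sigma$ (`log_stds.exp()`) are notation introduced here. The last line is the usual change-of-variables correction for the tanh squashing, which appears to be what the second term of calculate_log_pi computes (the code adds 1e-6 inside the logarithm for numerical stability).

```latex
\begin{align}
u &= \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I) \\
\log \mathcal{N}(u;\, \mu, \sigma^2)
  &= \sum_{i=1}^{d} \left[ -\frac{(u_i - \mu_i)^2}{2\sigma_i^2} - \log \sigma_i - \frac{1}{2}\log(2\pi) \right] \\
  &= \sum_{i=1}^{d} \left[ -\frac{1}{2}\epsilon_i^2 - \log \sigma_i \right] - \frac{d}{2}\log(2\pi) \\
a &= \tanh(u), \qquad
\log \pi(a) = \log \mathcal{N}(u;\, \mu, \sigma^2) - \sum_{i=1}^{d} \log\left(1 - a_i^2\right)
\end{align}
```

The step from the second to the third line uses $(u_i - \mu_i)/\sigma_i = \epsilon_i$, which is why the mean never needs to appear in the code.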

ZhenZhuHuang commented 2 years ago

What I don't understand is that the formula used by Normal.log_prob seems to be different from your code, so I am confused. Thank you for your answer, Sir.

toshikwa commented 2 years ago

Normal(means, stds).log_prob(means + noises * stds) equals Normal(0, stds).log_prob(noises * stds), because subtracting the mean from the sample leaves noises * stds in both cases. So I can skip the unnecessary calculation.
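
The equality above can be checked numerically. The following is a minimal sketch, assuming PyTorch is installed; the tensor shapes, the seed, and the variable names `manual_gaussian`, `reference_gaussian`, `zero_mean_gaussian` and `gaussian_match` are made up for illustration and do not come from the repository. It only compares the Gaussian part of calculate_log_pi, not the tanh correction.

```python
import math

import torch
from torch.distributions import Normal

torch.manual_seed(0)

# Illustrative batch of 4 samples with 3 action dimensions (float64 for a tight comparison).
means = torch.randn(4, 3, dtype=torch.float64)
log_stds = 0.5 * torch.randn(4, 3, dtype=torch.float64)
stds = log_stds.exp()

# Reparameterized sample, as in reparameterize() from the issue.
noises = torch.randn_like(means)
us = means + noises * stds

# Gaussian term exactly as written in calculate_log_pi.
manual_gaussian = (-0.5 * noises.pow(2) - log_stds).sum(
    dim=-1, keepdim=True) - 0.5 * math.log(2 * math.pi) * log_stds.size(-1)

# Same quantity via torch.distributions, summed over the action dimension.
reference_gaussian = Normal(means, stds).log_prob(us).sum(dim=-1, keepdim=True)

# Zero-mean form mentioned in the reply.
zero_mean_gaussian = Normal(torch.zeros_like(means), stds).log_prob(
    noises * stds).sum(dim=-1, keepdim=True)

# True when all three ways of computing the log-probability agree.
gaussian_match = (
    torch.allclose(manual_gaussian, reference_gaussian)
    and torch.allclose(manual_gaussian, zero_mean_gaussian)
)
print(gaussian_match)
```

All three expressions agree because the shortcut simply never forms (us - means) / stds, which is already available as noises.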

ZhenZhuHuang commented 2 years ago

Okay, I think I get it. One last question: are there any relevant papers or blog posts on GAIL based on PPO and SAC? Thank you for your help, Sir.

toshikwa commented 2 years ago

I'm sorry, but I don't remember...

ZhenZhuHuang commented 2 years ago

Well, thank you for your patience!
