datawhalechina / easy-rl

强化学习中文教程(蘑菇书🍄),在线阅读地址:https://datawhalechina.github.io/easy-rl/
Other
9.04k stars 1.81k forks source link

SAC代码问题 #134

Closed zichunxx closed 1 year ago

zichunxx commented 1 year ago

源代码:

        # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()

        ## 计算reparameterization参数损失
        mean_loss = mean_lambda * mean.pow(2).mean()
        std_loss  = std_lambda  * log_std.pow(2).mean()
        z_loss    = z_lambda    * z.pow(2).sum(1).mean()  

1) 请问policy_loss的这种计算方式有无参考文献呢?

源代码中没有涉及对熵的调整,所以我理解源代码复现的是原始的SAC

但是,按照我的理解,原始SAC对policy_loss的计算不应该是:policy_loss = (log_prob - expected_new_q_value).mean() ?

2) 此外,关于源代码中reparameterization参数损失的计算是否有相关资料可以参考?

谢谢!

johnjim0816 commented 1 year ago

源代码:

        # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        policy_loss = (log_prob * (log_prob - log_prob_target).detach()).mean()

        ## 计算reparameterization参数损失
        mean_loss = mean_lambda * mean.pow(2).mean()
        std_loss  = std_lambda  * log_std.pow(2).mean()
        z_loss    = z_lambda    * z.pow(2).sum(1).mean()  

1) 请问policy_loss的这种计算方式有无参考文献呢?

源代码中没有涉及对熵的调整,所以我理解源代码复现的是原始的SAC

但是,按照我的理解,原始SAC对policy_loss的计算不应该是:policy_loss = (log_prob - expected_new_q_value).mean() ?

2) 此外,关于源代码中reparameterization参数损失的计算是否有相关资料可以参考?

谢谢!

熵的计算确实有点问题,先看这里的PPO吧,有熵的计算。

第二个问题搜reparameterization trick

zichunxx commented 1 year ago

谢谢