datawhalechina / easy-rl

Reinforcement learning tutorial in Chinese (the "Mushroom Book" 🍄), read online at: https://datawhalechina.github.io/easy-rl/

/chapter9/chapter9_questions&keywords #59

Open qiwang067 opened 3 years ago

qiwang067 commented 3 years ago

https://datawhalechina.github.io/easy-rl/#/chapter9/chapter9_questions&keywords


Strawberry47 commented 2 years ago

Thanks♪(・ω・)ノ

Strawberry47 commented 2 years ago

Is actor-critic an off-policy algorithm?

qiwang067 commented 2 years ago

> Is actor-critic an off-policy algorithm?

Hi, both A2C and A3C are on-policy algorithms.
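
For intuition: the actor-critic policy gradient is an expectation over state-action pairs drawn from the current policy $\pi_\theta$, so every update needs fresh samples from the very policy being optimized. This is the standard statement of the result, independent of the book's notation:

$$\nabla_\theta J(\theta) = \mathbb{E}_{(s,a)\sim\pi_\theta}\!\left[\nabla_\theta \log \pi_\theta(a\mid s)\, A^{\pi_\theta}(s,a)\right]$$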

15138922051 commented 2 years ago

Is there any A3C code? Thanks!

qiwang067 commented 2 years ago

> Is there any A3C code? Thanks!

There is A2C code here: https://github.com/datawhalechina/easy-rl/tree/master/codes/A2C
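
There is no A3C implementation in the repo yet; the core difference from A2C is only the asynchronous update: each worker keeps a local copy of the network, computes ordinary A2C gradients on its own rollouts, and applies them to a shared model without locking. Below is a minimal sketch of that mechanism, not the book's code: random tensors stand in for real environment transitions, and `ActorCritic`, `worker`, and all hyperparameters are illustrative. It assumes the fork start method (the Linux default), which is what lets every worker's optimizer see the shared parameters.

import torch
import torch.nn as nn
import torch.multiprocessing as mp
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    # Shared torso with separate policy and value heads, as in A2C
    def __init__(self, obs_dim=4, n_actions=2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())
        self.pi = nn.Linear(64, n_actions)
        self.v = nn.Linear(64, 1)

    def forward(self, x):
        h = self.body(x)
        return torch.softmax(self.pi(h), dim=-1), self.v(h)

def worker(rank, shared_model, optimizer, n_updates=100):
    torch.manual_seed(rank)
    local_model = ActorCritic()
    for _ in range(n_updates):
        # 1) pull the latest shared parameters into the local copy
        local_model.load_state_dict(shared_model.state_dict())
        # 2) compute an ordinary A2C loss on a local rollout; random
        #    tensors stand in for environment transitions in this sketch
        states = torch.randn(8, 4)
        actions = torch.randint(0, 2, (8,))
        returns = torch.randn(8)
        probs, values = local_model(states)
        dist = Categorical(probs)
        advantages = returns - values.squeeze()
        loss = (-(dist.log_prob(actions) * advantages.detach()).mean()
                + 0.5 * advantages.pow(2).mean())
        # 3) push the local gradients into the shared model and step;
        #    workers do this without locking, hence "asynchronous"
        local_model.zero_grad()
        loss.backward()
        for lp, sp in zip(local_model.parameters(), shared_model.parameters()):
            sp.grad = lp.grad
        optimizer.step()

if __name__ == "__main__":
    shared = ActorCritic()
    shared.share_memory()  # place the parameters in shared memory
    # a shared-statistics optimizer (e.g. a custom SharedAdam) is common in
    # practice; plain Adam keeps per-process moments, which still works
    opt = torch.optim.Adam(shared.parameters(), lr=1e-3)
    workers = [mp.Process(target=worker, args=(r, shared, opt)) for r in range(4)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()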

chenjiaqiang-a commented 1 year ago

The A2C implementation in the code repo differs somewhat from the theoretical formulas, so I implemented a version that follows the theory directly, and it trains reasonably well. Could you take a look and tell me whether this implementation has any problems?

import numpy as np
import torch
from torch.distributions import Categorical

def update(self):
    # Drain the whole buffer: A2C is on-policy, so every update uses
    # only data generated by the current policy
    state_pool, action_pool, reward_pool, next_state_pool, done_pool = self.memory.sample(len(self.memory), True)
    self.memory.clear()

    states = torch.tensor(state_pool, dtype=torch.float32, device=self.device)
    actions = torch.tensor(action_pool, dtype=torch.int64, device=self.device)  # int64 indices for Categorical.log_prob
    next_states = torch.tensor(next_state_pool, dtype=torch.float32, device=self.device)
    rewards = torch.tensor(reward_pool, dtype=torch.float32, device=self.device)
    # mask = 0 at terminal steps so no value is bootstrapped past the episode end
    masks = torch.tensor(1.0 - np.array(done_pool, dtype=np.float32), device=self.device)

    probs, values = self.model(states)
    _, next_values = self.model(next_states)

    dist = Categorical(probs)
    log_probs = dist.log_prob(actions)
    # one-step TD error as the advantage estimate: r + gamma * V(s') - V(s);
    # the bootstrapped target is detached so the critic regresses toward it
    advantages = rewards + self.gamma * next_values.squeeze().detach() * masks - values.squeeze()
    # detach() here keeps actor gradients from flowing into the value head
    actor_loss = -(log_probs * advantages.detach()).mean()
    critic_loss = advantages.pow(2).mean()
    # entropy bonus discourages premature collapse to a deterministic policy
    loss = actor_loss + self.critic_factor * critic_loss - self.entropy_coef * dist.entropy().mean()

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
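
Two details in the version above that follow the theory: `advantages.detach()` in the actor term keeps policy gradients out of the value head, while the undetached `advantages` in `critic_loss` is exactly what regresses the critic toward the bootstrapped TD target; and since the buffer is cleared on every update, training data always comes from the current policy, consistent with A2C being on-policy as discussed above.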