facebookresearch / torchbeast

A PyTorch Platform for Distributed RL
Apache License 2.0

Continuous Action Space #36

Open · MXD6 opened this issue 2 years ago

MXD6 commented 2 years ago

Hello, how can I apply V-trace to a continuous action space? I treat the output of the policy head, `policy_logits`, as the parameters (mean and standard deviation) of a normal distribution:

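```python
import torch.distributions as tdist
import torch.nn as nn
import torch.nn.functional as F

def __init__(self, observation_shape, num_actions, use_lstm=False):
    ...
    # One mean and one unconstrained scale for a 1-D continuous action.
    self.policy = nn.Linear(core_output_size, 2)
    ...

def forward(self, inputs, core_state=()):
    ...
    policy_logits = self.policy(core_output)  # shape [T * B, 2]
    # Split along the last (feature) dimension, not the batch dimension.
    mu = policy_logits[..., 0]
    # The standard deviation must be positive; softplus maps any real to (0, inf).
    sigma = F.softplus(policy_logits[..., 1]) + 1e-5
    # sample() takes a shape, not an int; the default draws one action per row.
    action = tdist.Normal(mu, sigma).sample()
    ...
```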
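One possible approach, sketched under assumptions and not taken from torchbeast itself: V-trace only needs the log importance weights log ρ_t = log π(a_t|x_t) − log μ(a_t|x_t) between the learner's policy π and the actor's behaviour policy μ. It does not care whether the actions are discrete. For a Gaussian policy, these weights come from `Normal.log_prob` rather than from softmax logits, so the actor must store the behaviour policy's `mu` and `sigma` (or the action's log-probability) alongside each action. The helper names `gaussian_log_rhos` and `vtrace_from_log_rhos` are hypothetical. The second function is a plain re-implementation of the V-trace recursion from the IMPALA paper (Espeholt et al., 2018), not torchbeast's API.

```python
import torch
import torch.distributions as tdist


def gaussian_log_rhos(target_mu, target_sigma, behavior_mu, behavior_sigma, actions):
    """Hypothetical helper: log(pi(a|x) / mu(a|x)) for diagonal Gaussian policies.

    Inputs have shape [T, B, action_dim]; the result has shape [T, B].
    """
    target = tdist.Normal(target_mu, target_sigma)
    behavior = tdist.Normal(behavior_mu, behavior_sigma)
    # Action dimensions are assumed independent, so per-dimension log-densities add up.
    return (target.log_prob(actions) - behavior.log_prob(actions)).sum(-1)


@torch.no_grad()
def vtrace_from_log_rhos(log_rhos, discounts, rewards, values, bootstrap_value,
                         clip_rho_threshold=1.0, clip_pg_rho_threshold=1.0):
    """Hypothetical helper: V-trace targets and policy-gradient advantages.

    log_rhos, discounts, rewards, values: [T, B]; bootstrap_value: [B].
    """
    rhos = torch.exp(log_rhos)
    clipped_rhos = torch.clamp(rhos, max=clip_rho_threshold)  # rho-bar clipping
    cs = torch.clamp(rhos, max=1.0)                            # c-bar = 1 trace cutting
    values_t_plus_1 = torch.cat([values[1:], bootstrap_value.unsqueeze(0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward recursion: v_s - V(x_s) = delta_s + gamma_s * c_s * (v_{s+1} - V(x_{s+1})).
    acc = torch.zeros_like(bootstrap_value)
    vs_minus_v = []
    for t in reversed(range(log_rhos.shape[0])):
        acc = deltas[t] + discounts[t] * cs[t] * acc
        vs_minus_v.append(acc)
    vs = torch.stack(vs_minus_v[::-1], dim=0) + values

    # Advantages for the policy-gradient term use v_{s+1} as the bootstrap target.
    vs_t_plus_1 = torch.cat([vs[1:], bootstrap_value.unsqueeze(0)], dim=0)
    clipped_pg_rhos = torch.clamp(rhos, max=clip_pg_rho_threshold)
    pg_advantages = clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values)
    return vs, pg_advantages


# Usage on random data (shapes only, not a training loop).
T, B, A = 5, 2, 3  # unroll length, batch size, action dimensions
behavior_mu, behavior_sigma = torch.zeros(T, B, A), torch.ones(T, B, A)
actions = tdist.Normal(behavior_mu, behavior_sigma).sample()  # what the actor took
target_mu = torch.full((T, B, A), 0.1, requires_grad=True)    # learner's current mean
target_sigma = torch.ones(T, B, A)
log_rhos = gaussian_log_rhos(target_mu, target_sigma, behavior_mu, behavior_sigma, actions)
vs, pg_advantages = vtrace_from_log_rhos(
    log_rhos, discounts=torch.full((T, B), 0.99), rewards=torch.randn(T, B),
    values=torch.randn(T, B), bootstrap_value=torch.randn(B))
# Policy-gradient loss: gradients flow only through the learner's log-density.
target_log_prob = tdist.Normal(target_mu, target_sigma).log_prob(actions).sum(-1)
pg_loss = -(target_log_prob * pg_advantages).sum()
pg_loss.backward()
print(vs.shape, pg_advantages.shape)  # torch.Size([5, 2]) torch.Size([5, 2])
```

With the single action dimension used in the snippet above, `A = 1` and the sum over the last axis is a no-op. The value targets `vs` would then be regressed against the baseline exactly as in the discrete case.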