MushroomRL / mushroom-rl

Python library for Reinforcement Learning.
MIT License

Adapting A2C and deep policy gradient methods to Discrete envs #20

Closed lionely closed 4 years ago

lionely commented 4 years ago

Describe the bug: Hi, I am getting the following error. I’m trying to use the A2C algorithm, which samples actions from a Gaussian distribution when given a state. The code seems to expect torch.float32. I have checked, and my inputs are indeed torch.float32, not Long. I’m not sure what to do. I know the algorithm runs, as I have run it before, but I get this error when I try to use it in an environment with discrete actions.

Here is the library’s draw_action function:

def draw_action_t(self, state):
    print('draw_action_t', state.dtype)
    return self.distribution_t(state).sample().detach()

and

def distribution_t(self, state):
    mu, sigma = self.get_mean_and_covariance(state)
    return torch.distributions.MultivariateNormal(loc=mu, covariance_matrix=sigma)

Final error:

RuntimeError: _th_normal_ not supported on CPUType for Long

Help is much appreciated!

Additional context: I'm working with the Gym 'CartPole-v0' environment. I need these deep policy gradient methods to work for discrete environments as I am testing something for my research. Any advice on this? Help is greatly appreciated!

boris-il-forte commented 4 years ago

Hi, currently we don't support actor-critic methods with finite actions. However, I think it would be quite easy to make it work for A2C, given the simplicity of the algorithm. The problem is that you are trying to use a Gaussian policy in a discrete-action environment. What you should do instead is implement a discrete policy (e.g. a Boltzmann policy) that extends the torch policy interface. It should be quite easy.
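A minimal sketch of what such a discrete policy might look like in plain PyTorch, assuming a small network that maps a state to one logit per action. The class name `BoltzmannPolicySketch`, the `beta` inverse-temperature parameter, the network layout, and the method names (modelled on the `draw_action_t` shown in the question) are illustrative assumptions, not the library's actual interface:

```python
import torch
import torch.nn as nn


class BoltzmannPolicySketch(nn.Module):
    """Hypothetical softmax (Boltzmann) policy over a finite action set."""

    def __init__(self, n_features, n_actions, beta=1.0):
        super().__init__()
        # Assumed architecture: one hidden layer, one output logit per action.
        self.network = nn.Sequential(
            nn.Linear(n_features, 32),
            nn.ReLU(),
            nn.Linear(32, n_actions),
        )
        self.beta = beta  # inverse temperature; larger values are greedier

    def distribution_t(self, state):
        # Categorical applies a softmax to the scaled logits internally.
        logits = self.network(state) * self.beta
        return torch.distributions.Categorical(logits=logits)

    def draw_action_t(self, state):
        # Returns integer action indices (torch.int64), one per state.
        return self.distribution_t(state).sample().detach()

    def log_prob_t(self, state, action):
        # Actions must be integer indices here, hence the cast to long.
        return self.distribution_t(state).log_prob(action.long().reshape(-1))

    def entropy_t(self, state):
        return self.distribution_t(state).entropy().mean()


policy = BoltzmannPolicySketch(n_features=4, n_actions=2)  # CartPole-sized
states = torch.randn(5, 4)
actions = policy.draw_action_t(states)
log_probs = policy.log_prob_t(states, actions)
```

Unlike the Gaussian case, the sampled actions here are indices into the action set rather than real-valued vectors, which is what a discrete environment such as CartPole expects.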

The only tricky part is remembering to cast the action back to long, since it is converted to a float tensor inside the _loss method (or to change that line in the A2C code).
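To illustrate the cast: a hedged sketch, assuming the actions arrive as a float tensor of shape (batch, 1). The tensors and variable names below are made up for illustration, and this is not the actual A2C `_loss` code:

```python
import torch

# Hypothetical batch: logits for 3 states over 2 actions.
logits = torch.tensor([[2.0, 0.5], [0.1, 1.2], [0.0, 0.0]])
# Actions as they might look after being converted to a float tensor.
actions_float = torch.tensor([[0.0], [1.0], [1.0]])

# torch.gather needs int64 indices, so cast back to long first.
actions_long = actions_float.long()
# Pick the log-probability of the action actually taken in each state.
log_probs = torch.log_softmax(logits, dim=1).gather(1, actions_long).squeeze(1)
```

Passing `actions_float` directly as the index would fail, because index tensors must be integer typed.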

lionely commented 4 years ago

Hi, thanks for the response! Do you know if there is a deep RL method in your library that supports finite actions? Apart from DQN? I will make the necessary changes for A2C if there aren't any available.

boris-il-forte commented 4 years ago

What's mostly missing is the policy. Once you have that, you should be almost done.

All variants of DQN (double, averaged, categorical) support finite actions.

No actor-critic method currently supports finite actions.

You could also try fitted Q-iteration with deep networks, although this algorithm works better with extra trees.

boris-il-forte commented 4 years ago

We just pushed the BoltzmannTorchPolicy to the dev branch; see commit 4d20e68. Also, there is an example of usage of this policy with A2C here.

This should fix your issue. If not, feel free to open another issue/bug report.