DeepX-inc / machina

Control section: Deep Reinforcement Learning framework
MIT License
279 stars, 45 forks

Some refactoring to files in pols directory #252

Open tadashiK opened 5 years ago

tadashiK commented 5 years ago

As of now, some files in the pols directory seem to need some modifications. For example, the comments in the BasePol class are difficult to understand. I suggest some improvements below.

tadashiK commented 5 years ago

How about changing base.py to the following for readability? The main modifications are to the comments and to some if statements that make the code unnecessarily deeply nested.

import copy

import gym
import numpy as np
import torch.nn as nn

from machina.utils import get_device

class BasePol(nn.Module):
    """
    Base class for policies. It works as a "head" on top of a given neural network:
    the head can run the network as an RNN, rescale the network's output to the
    range of the action space, and parallelize the computation.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
    net : torch.nn.Module
        Network that maps observations to the policy's output.
    rnn : bool
        If True, net is treated as a recurrent network with a hidden state.
    normalize_ac : bool
        If True, the output of net is scaled so that it covers the entire action_space.
        The output is assumed to be continuous, with each dimension ranging between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel with nn.DataParallel.
        If 'ddp', it is executed in distributed parallel with nn.parallel.DistributedDataParallel.
    parallel_dim : int
        Dimension along which inputs are split for data parallelism.
    """

    def __init__(self, observation_space, action_space, net, rnn=False, normalize_ac=True, data_parallel=False, parallel_dim=0):
        nn.Module.__init__(self)
        self.observation_space = observation_space
        self.action_space = action_space
        self.net = net

        self.rnn = rnn
        self.hs = None  # A hidden state vector of the RNN.

        self.normalize_ac = normalize_ac
        self.data_parallel = data_parallel
        if data_parallel is True:
            self.dp_net = nn.DataParallel(self.net, dim=parallel_dim)
        elif data_parallel == 'ddp':
            self.net.to(get_device())
            self.dp_net = nn.parallel.DistributedDataParallel(
                self.net, device_ids=[get_device()], dim=parallel_dim)
        elif data_parallel is not False:
            raise ValueError(
                'data_parallel must be either Boolean value or str(ddp).')
        self.dp_run = False

        self.multi = isinstance(action_space, gym.spaces.MultiDiscrete)
        self.discrete =\
            self.multi or isinstance(action_space, gym.spaces.Discrete)

        if not self.discrete:
            self.a_i_shape = action_space.shape
        elif self.multi:
            nvec = action_space.nvec
            # Every dimension must have the same number of categories.
            assert all(nv == nvec[0] for nv in nvec)
            self.a_i_shape = (len(nvec), nvec[0])
        else:
            self.a_i_shape = (action_space.n, )

    def __getstate__(self):
        state = self.__dict__.copy()
        if 'dp_net' in state['_modules']:
            _modules = copy.deepcopy(state['_modules'])
            del _modules['dp_net']
            state['_modules'] = _modules
        return state

    def __setstate__(self, state):
        if 'dp_net' in state:
            state.pop('dp_net')
        self.__dict__.update(state)

    def convert_ac_for_real(self, x):
        """
        Converts an action x output by self.net into an action that can be
        passed to the environment. For a continuous action space, x is first
        rescaled from [-1, 1] to [low, high] if normalize_ac is True, and then
        clipped to [low, high]. Discrete actions are returned unchanged.
        """
        if not self.discrete:
            lb, ub = self.action_space.low, self.action_space.high
            if self.normalize_ac:
                x = lb + (x + 1.) * 0.5 * (ub - lb)
            x = np.clip(x, lb, ub)
        return x

    def reset(self):
        """
        Resets the hidden state of the RNN.
        """
        self.hs = None

    def _check_obs_shape(self, obs):
        """
        Adds leading singleton dimensions to obs until it has a batch axis
        (and, for an RNN, also a time axis) in front of the observation shape.
        """
        n_dims = (2 if self.rnn else 1) + len(self.observation_space.shape)
        while len(obs.shape) < n_dims:
            obs = obs.unsqueeze(0)
        return obs

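For reference, the rescaling in convert_ac_for_real is the affine map x_real = low + (x + 1) / 2 * (high - low), which sends -1 to low and 1 to high, followed by clipping. Below is a standalone NumPy sketch of just that arithmetic; the function name scale_action is hypothetical and not part of machina.

```python
import numpy as np


def scale_action(x, low, high, normalize_ac=True):
    """Hypothetical helper mirroring the continuous branch of convert_ac_for_real.

    With normalize_ac=True, x is assumed to lie in [-1, 1] per dimension and is
    mapped linearly so that -1 -> low and 1 -> high. In both cases the result
    is clipped to [low, high].
    """
    x = np.asarray(x, dtype=float)
    low = np.asarray(low, dtype=float)
    high = np.asarray(high, dtype=float)
    if normalize_ac:
        x = low + (x + 1.0) * 0.5 * (high - low)
    return np.clip(x, low, high)


# 2-D action space with low = [0, -2] and high = [10, 2].
print(scale_action([1.0, 0.5], [0.0, -2.0], [10.0, 2.0]))
```

Out-of-range outputs (e.g. 1.5 from an unbounded net) end up exactly on the bounds, which is why the clip is applied in both branches.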
tadashiK commented 5 years ago

The code in categorical_pol.py and multi_categorical_pol.py seems almost identical. Rather than keeping two separate files, how about combining them into a single file, discrete.py, with the following code?

import torch
from machina.pols import BasePol
from machina.pds.categorical_pd import CategoricalPd
from machina.pds.multi_categorical_pd import MultiCategoricalPd
from machina.utils import get_device

class CategoricalPol(BasePol):
    r"""
    A policy for a discrete action space with one dimension.
    Such an action space has the form
        :math:`\{ 0, 1, \dots, n-1 \}`.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space.
        This must be an instance of gym.spaces.Discrete.
    net : torch.nn.Module
        Network that maps observations to the policy's output.
    rnn : bool
        If True, net is treated as a recurrent network with a hidden state.
    normalize_ac : bool
        Accepted for compatibility with BasePol. It has no effect here, because
        convert_ac_for_real leaves discrete actions unchanged.
    data_parallel : bool or str
        If True, network computation is executed in parallel with nn.DataParallel.
        If 'ddp', it is executed in distributed parallel with nn.parallel.DistributedDataParallel.
    parallel_dim : int
        Dimension along which inputs are split for data parallelism.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.pd = CategoricalPd()
        self.to(get_device())

    def forward(self, obs, hs=None, h_masks=None):
        obs = self._check_obs_shape(obs)

        if self.rnn:
            pi, hs = self.forward_with_rnn(obs, hs, h_masks, self.dp_run)
        else:
            pi = self.dp_net(obs) if self.dp_run else self.net(obs)

        ac = self.pd.sample(dict(pi=pi))
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pi=pi, hs=hs)

    def forward_with_rnn(self, obs, hs=None, h_masks=None, dp_run=False):
        """
        Runs the recurrent net and stores its new hidden state in self.hs.
        Returns pi and the hidden state that was fed into the net.
        """
        time_seq, batch_size, *_ = obs.shape

        # If hs is None, use self.hs; if self.hs is also None, initialize it
        # with self.net.init_hs(batch_size). Explicit None checks are used
        # because `or` raises an error on multi-element tensors.
        if hs is None:
            if self.hs is None:
                self.hs = self.net.init_hs(batch_size)
            hs = self.hs
            if dp_run:
                hs = tuple(h.unsqueeze(0) for h in hs[0:2])

        if h_masks is None:
            h_masks = hs[0].new(time_seq, batch_size, 1).zero_()
        h_masks = h_masks.reshape(time_seq, batch_size, 1)

        if dp_run:
            pi, self.hs = self.dp_net(obs, hs, h_masks)
        else:
            pi, self.hs = self.net(obs, hs, h_masks)

        return pi, hs

    def deterministic_ac_real(self, obs, hs=None, h_masks=None):
        """
        Returns the most probable action, e.g. for deployment.
        """
        obs = self._check_obs_shape(obs)

        if self.rnn:
            pi, hs = self.forward_with_rnn(obs, hs, h_masks, dp_run=False)
        else:
            pi = self.net(obs)

        ac = torch.argmax(pi, dim=-1)
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pi=pi, hs=hs)

class MultiCategoricalPol(CategoricalPol):
    r"""
    A policy for a discrete action space with multiple dimensions.
    For example, with two dimensions such an action space has the form
        :math:`\{ 0, 1, \dots, n-1 \} \times \{ 0, 1, \dots, n-1 \}`.
    BasePol assumes that every dimension has the same number of categories n.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space.
        This must be an instance of gym.spaces.MultiDiscrete.
    net : torch.nn.Module
        Network that maps observations to the policy's output.
    rnn : bool
        If True, net is treated as a recurrent network with a hidden state.
    normalize_ac : bool
        Accepted for compatibility with BasePol. It has no effect here, because
        convert_ac_for_real leaves discrete actions unchanged.
    data_parallel : bool or str
        If True, network computation is executed in parallel with nn.DataParallel.
        If 'ddp', it is executed in distributed parallel with nn.parallel.DistributedDataParallel.
    parallel_dim : int
        Dimension along which inputs are split for data parallelism.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.pd = MultiCategoricalPd()
        self.to(get_device())

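Why sharing should work: a MultiDiscrete action with equal category counts (as BasePol assumes when it sets a_i_shape = (len(nvec), nvec[0])) is just a batch of independent categoricals over the last axis, so sampling and argmax can be written once for both shapes. A minimal NumPy sketch, assuming pi holds probabilities; sample_categorical and deterministic_categorical are hypothetical helpers, not machina's CategoricalPd or MultiCategoricalPd.

```python
import numpy as np


def sample_categorical(pi, rng):
    """Sample one index along the last axis of pi (probabilities).

    pi has shape (batch, n) for a single categorical, or (batch, n_dims, n)
    for n_dims independent categoricals. Returns indices of shape pi.shape[:-1].
    """
    cdf = np.cumsum(pi, axis=-1)
    u = rng.random(pi.shape[:-1] + (1,))  # one uniform draw per categorical
    # The sampled index is the number of CDF entries that are <= u.
    idx = (cdf <= u).sum(axis=-1)
    # Guard against rounding that leaves cdf[-1] slightly below u.
    return np.minimum(idx, pi.shape[-1] - 1)


def deterministic_categorical(pi):
    """Most probable index along the last axis, for either shape."""
    return pi.argmax(axis=-1)


rng = np.random.default_rng(0)
single = np.full((4, 3), 1.0 / 3.0)    # batch of 4, one 3-way choice each
multi = np.full((4, 2, 3), 1.0 / 3.0)  # batch of 4, two 3-way choices each
print(sample_categorical(single, rng).shape)  # (4,)
print(sample_categorical(multi, rng).shape)   # (4, 2)
```

Only the shape of pi differs between the two cases, which is the part that still needs careful checking in the real distributions.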
rarilurelo commented 5 years ago

Thanks for suggesting improvements to the documentation and sharing components between Categorical and MultiCategorical. I agree with the documentation improvements. I also agree with the sharing, but it needs to be checked carefully. Could you send two separate PRs?

tadashiK commented 5 years ago

Thank you for the reply.

Yes, I will. However, before sending PRs, I would like to ask you two questions.

First, ArgmaxQfPol takes an instance of SAVfunc as its qfunc argument. However, its forward method calls qfunc.max, which SAVfunc does not define in general. Should I add a comment about this in the code, or should I change the required type from SAVfunc to CEMDeterministicSAVfunc, which seems to be the only subclass of SAVfunc that has a max method? I can also implement max in the other subclasses of SAVfunc if you want me to.
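For context on what max involves with continuous actions: the name CEMDeterministicSAVfunc suggests its max uses the cross-entropy method (CEM). The generic NumPy sketch below is not machina's code; cem_argmax and its parameters (n_samples, n_elites, n_iters, init_std) are hypothetical, and starting from a zero-mean Gaussian is an arbitrary choice.

```python
import numpy as np


def cem_argmax(q, action_dim, n_samples=64, n_elites=8, n_iters=10,
               init_std=1.0, seed=0):
    """Approximate argmax_a q(a) over a continuous action with CEM.

    q maps an array of shape (n_samples, action_dim) to an array of shape
    (n_samples,) holding one Q-value per candidate action.
    Returns (approximate best action, its Q-value).
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros(action_dim)
    std = np.full(action_dim, init_std)
    for _ in range(n_iters):
        # Sample candidate actions from the current Gaussian.
        acs = rng.normal(mean, std, size=(n_samples, action_dim))
        values = q(acs)
        # Keep the best-scoring candidates and refit the Gaussian to them.
        elites = acs[np.argsort(values)[-n_elites:]]
        mean = elites.mean(axis=0)
        std = elites.std(axis=0) + 1e-6
    return mean, float(q(mean[None])[0])


# Toy Q-function with its maximum at a = 0.3.
best_ac, best_q = cem_argmax(lambda acs: -((acs - 0.3) ** 2).sum(axis=1), action_dim=1)
```

This iterative search is what a general max needs, and it is exactly the part that becomes unnecessary when the action set is finite.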

The second question is related to the first. When an MDP has only a finite number of actions, the max method can be much simpler: no optimization is needed, since every action can be evaluated and the best one taken. Furthermore, in that setting, a Q-value function is often represented by a neural network that takes a state and outputs the Q-values of all actions. Would it be possible to split SAVfunc into two classes, such as ContSDiscAVFunc and ContSContAVFunc (continuous state with discrete or continuous actions)? Again, if you want me to, I can give it a try.
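To illustrate the finite-action case: if the Q-function outputs the values of all actions at once, max reduces to a max/argmax over the last axis. A minimal sketch under that assumption; DiscreteActionQfunc is a hypothetical name (roughly what a class like the proposed ContSDiscAVFunc might offer), and a linear map stands in for the neural network.

```python
import numpy as np


class DiscreteActionQfunc:
    """Hypothetical Q-function for the action set {0, ..., n_actions - 1}.

    Maps a batch of observations to the Q-values of all actions at once.
    A linear map stands in for the neural network here.
    """

    def __init__(self, weight):
        self.weight = np.asarray(weight, dtype=float)  # (obs_dim, n_actions)

    def __call__(self, obs):
        # (batch, obs_dim) @ (obs_dim, n_actions) -> (batch, n_actions)
        return np.asarray(obs, dtype=float) @ self.weight

    def max(self, obs):
        """Return (max_a Q(s, a), argmax_a Q(s, a)) for each observation."""
        q_all = self(obs)
        return q_all.max(axis=-1), q_all.argmax(axis=-1)


qf = DiscreteActionQfunc([[1.0, 0.0, 2.0],
                          [0.0, 3.0, 0.0]])
q_max, best_ac = qf.max([[1.0, 0.0], [0.0, 1.0]])
print(q_max, best_ac)  # [2. 3.] [2 1]
```

One forward pass gives the exact maximum, compared with the repeated sampling an iterative method such as CEM needs for continuous actions.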