Open tadashiK opened 5 years ago
How about changing base.py to the following for readability? The main modifications are to the comments and to some if statements that make the code unnecessarily deeply nested.
import copy
import gym
import numpy as np
import torch.nn as nn
from machina.utils import get_device
class BasePol(nn.Module):
    """
    A base class of the policy. This class works as a "head" of a given neural network.
    The head can be RNN, appropriately normalize the output range of the neural network,
    and make the computation parallel.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
    net : torch.nn.Module
        The network wrapped by this policy head.
    rnn : bool
        If True, observations are expected to carry two leading dimensions
        in front of the observation shape instead of one
        (see ``_check_obs_shape``), and ``hs`` holds the hidden state.
    normalize_ac : bool
        If True, the output of net is scaled such that it covers the entire action_space.
        It is assumed that the output is continuous, and each of its dimensions ranges between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel.
        If data_parallel is ddp, network computation is executed in distributed parallel.
    parallel_dim : int
        Split dimension in data parallel.

    Raises
    ------
    ValueError
        If data_parallel is neither a Boolean nor 'ddp', or if a
        MultiDiscrete action space has sub-spaces of different sizes.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        nn.Module.__init__(self)
        self.observation_space = observation_space
        self.action_space = action_space
        self.net = net
        self.rnn = rnn
        self.hs = None  # A hidden state vector of the RNN.
        self.normalize_ac = normalize_ac
        self.data_parallel = data_parallel

        if data_parallel is True:
            self.dp_net = nn.DataParallel(self.net, dim=parallel_dim)
        elif data_parallel == 'ddp':
            self.net.to(get_device())
            self.dp_net = nn.parallel.DistributedDataParallel(
                self.net, device_ids=[get_device()], dim=parallel_dim)
        elif data_parallel is not False:
            raise ValueError(
                'data_parallel must be either Boolean value or str(ddp).')
        # Parallel execution is opt-in at run time; subclasses read this flag.
        self.dp_run = False

        self.multi = isinstance(action_space, gym.spaces.MultiDiscrete)
        self.discrete = \
            self.multi or isinstance(action_space, gym.spaces.Discrete)

        if not self.discrete:
            self.a_i_shape = action_space.shape
        elif self.multi:
            nvec = action_space.nvec
            # a_i_shape uses nvec[0] for every dimension, so all sub-spaces
            # must have the same size. The previous check used any(), which
            # is always true because nvec[0] == nvec[0].
            if not all(nvec[0] == nv for nv in nvec):
                raise ValueError(
                    'All dimensions of a MultiDiscrete action space '
                    'must have the same number of actions.')
            self.a_i_shape = (len(nvec), nvec[0])
        else:
            self.a_i_shape = (action_space.n, )

    def __getstate__(self):
        """
        Returns the instance state without the 'dp_net' sub-module.
        """
        state = self.__dict__.copy()
        if 'dp_net' in state['_modules']:
            # Work on a copy so that the live module registry keeps dp_net.
            _modules = copy.deepcopy(state['_modules'])
            del _modules['dp_net']
            state['_modules'] = _modules
        return state

    def __setstate__(self, state):
        """
        Restores the instance state, dropping a top-level 'dp_net' entry if present.
        """
        state.pop('dp_net', None)
        self.__dict__.update(state)

    def convert_ac_for_real(self, x):
        """
        Scales an action (x), which is the output of self.net, such
        that the action is appropriate for a real world task.

        Discrete actions are returned unchanged. Continuous actions are
        optionally rescaled from [-1, 1] to [low, high] (when normalize_ac
        is True) and always clipped to the bounds of the action space.
        """
        if self.discrete:
            return x
        lb, ub = self.action_space.low, self.action_space.high
        if self.normalize_ac:
            x = lb + (x + 1.) * 0.5 * (ub - lb)
        return np.clip(x, lb, ub)

    def reset(self):
        """
        Resets a hidden state vector of the RNN.
        """
        self.hs = None

    def _check_obs_shape(self, obs):
        """
        Reshapes input (obs) appropriately.

        Leading singleton dimensions are prepended until obs has one
        (two when rnn is True) more dimensions than observation_space.shape.
        An obs that already has enough dimensions is returned unchanged.
        """
        additional_shape = 2 if self.rnn else 1
        expected_ndim = additional_shape + len(self.observation_space.shape)
        for _ in range(expected_ndim - len(obs.shape)):
            obs = obs.unsqueeze(0)
        return obs
The code in categorical_pol.py and multi_categorical_pol.py seems to be almost identical. Rather than having two separate files, how about combining them into a single file called discrete.py with the following code?
import torch
from machina.pols import BasePol
from machina.pds.categorical_pd import CategoricalPd
from machina.pds.multi_categorical_pd import MultiCategoricalPd
from machina.utils import get_device
class CategoricalPol(BasePol):
    r"""
    A policy for a discrete action space with one dimension.
    For example, such action space is given as
    :math:`\{ 0, 1, \dots, n-1 \}`.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
        This must be an instance of gym.spaces.Discrete
    net : torch.nn.Module
        The network whose output parameterizes the categorical distribution.
    rnn : bool
        If True, net is called with a hidden state and hidden-state masks.
    normalize_ac : bool
        If True, the output of net is scaled such that it covers the entire action_space.
        It is assumed that the output is continuous, and each of its dimensions ranges between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel.
        If data_parallel is ddp, network computation is executed in distributed parallel.
    parallel_dim : int
        Split dimension in data parallel.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        BasePol.__init__(self, observation_space, action_space, net, rnn,
                         normalize_ac, data_parallel, parallel_dim)
        self.pd = CategoricalPd()
        self.to(get_device())

    def forward(self, obs, hs=None, h_masks=None):
        """
        Samples an action for the given observation.

        Returns
        -------
        ac_real : numpy.ndarray
            The sampled action after convert_ac_for_real.
        ac : torch.Tensor
            The sampled action.
        a_i : dict
            ``pi`` (the output of the network) and ``hs`` (for an RNN
            policy, the hidden state produced by this call; otherwise the
            ``hs`` argument unchanged).
        """
        obs = self._check_obs_shape(obs)
        if self.rnn:
            pi = self.forward_with_rnn(obs, hs, h_masks, self.dp_run)
            # Report the hidden state produced by this step, not the
            # (possibly None) one that was passed in.
            hs = self.hs
        elif self.dp_run:
            pi = self.dp_net(obs)
        else:
            pi = self.net(obs)
        ac = self.pd.sample(dict(pi=pi))
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pi=pi, hs=hs)

    def forward_with_rnn(self, obs, hs=None, h_masks=None, dp_run=False):
        """
        Runs the recurrent network and stores the new hidden state in self.hs.

        obs is unpacked as (time_seq, batch_size, ...). When hs is None the
        stored self.hs is used, and when that is None too a fresh state is
        requested from self.net.init_hs(batch_size). When h_masks is None
        an all-zero mask is used.

        Returns
        -------
        pi : torch.Tensor
            The first output of the network.
        """
        time_seq, batch_size, *_ = obs.shape
        if hs is None:
            hs = self.hs
        if hs is None:
            hs = self.net.init_hs(batch_size)
        if dp_run:
            # NOTE(review): presumably adds the dimension DataParallel
            # splits along - confirm against the network implementation.
            hs = tuple(h.unsqueeze(0) for h in hs[0:2])
        # Compare with None explicitly: taking the truth value of a tensor
        # raises for multi-element tensors, and would discard an all-zero
        # single-element mask.
        if h_masks is None:
            h_masks = hs[0].new(time_seq, batch_size, 1).zero_()
        h_masks = h_masks.reshape(time_seq, batch_size, 1)
        if dp_run:
            pi, hs = self.dp_net(obs, hs, h_masks)
        else:
            pi, hs = self.net(obs, hs, h_masks)
        self.hs = hs
        return pi

    def deterministic_ac_real(self, obs, hs=None, h_masks=None):
        """
        action for deployment

        Selects the index of the largest entry of pi along the last
        dimension instead of sampling. The data-parallel network is never
        used here.
        """
        obs = self._check_obs_shape(obs)
        if self.rnn:
            pi = self.forward_with_rnn(obs, hs, h_masks, dp_run=False)
            hs = self.hs
        else:
            # A non-recurrent network takes only the observation.
            pi = self.net(obs)
        _, ac = torch.max(pi, dim=-1)
        ac_real = self.convert_ac_for_real(ac.detach().cpu().numpy())
        return ac_real, ac, dict(pi=pi, hs=hs)
class MultiCategoricalPol(CategoricalPol):
    r"""
    A policy for a discrete action space with multiple dimensions.
    For example, such action space is given as
    :math:`\{ 0, 1, \dots, n-1 \} \times \{ 0, 1, \dots, m-1 \}`.

    The sampling and deployment methods are inherited from
    CategoricalPol; only the probability distribution differs.

    Parameters
    ----------
    observation_space : gym.Space
        Observation space
    action_space : gym.Space
        Action space
    net : torch.nn.Module
    rnn : bool
    normalize_ac : bool
        If True, the output of net is scaled such that it covers the entire action_space.
        It is assumed that the output is continuous, and each of its dimensions ranges between -1 and 1.
    data_parallel : bool or str
        If True, network computation is executed in parallel.
        If data_parallel is ddp, network computation is executed in distributed parallel.
    parallel_dim : int
        Split dimension in data parallel.
    """

    def __init__(self, observation_space, action_space, net, rnn=False,
                 normalize_ac=True, data_parallel=False, parallel_dim=0):
        # Initialise through BasePol directly so that the distribution set
        # below is the only one ever created for this instance.
        BasePol.__init__(
            self,
            observation_space,
            action_space,
            net,
            rnn=rnn,
            normalize_ac=normalize_ac,
            data_parallel=data_parallel,
            parallel_dim=parallel_dim,
        )
        self.pd = MultiCategoricalPd()
        device = get_device()
        self.to(device)
Thanks for suggesting improvements to the documentation and the sharing of components between `Categorical` and `MultiCategorical`. I agree with the documentation improvements. I also agree with the sharing, but we need to check it carefully. Could you send two PRs?
Thank you for the reply.
Yes, I will. However, before sending PRs, I would like to ask you two questions.
First, `ArgmaxQfPol` requires an instance of `SAVfunc` as its `qfunc` argument. However, in its `forward` method, `ArgmaxQfPol` uses the `SAVfunc.max` method, which `SAVfunc` does not have in general. Should I comment on this in the code, or should I change `SAVfunc` to `CEMDeterministicSAVfunc`, which seems to be the only subclass of `SAVfunc` that has a `max` method? I can implement a `max` method in the other subclasses of `SAVfunc` if you want me to.
The second question is related to the first one. When an MDP has only a finite number of actions, the `max` method can be drastically simpler (no optimization is needed). Furthermore, in such a situation, a Q-value function is frequently represented by a neural network that accepts a state and outputs the Q-values of all actions. Is there any possibility of dividing `SAVfunc` into two classes, such as `ContSDiscAVFunc` and `ContSContAVFunc`? Again, if you want me to, I can give it a try.
As of now, some files in the pols directory seem to need some modifications. For example, the comments of the BasePol class are difficult to understand. I suggest some improvements below.