chainer / chainerrl

ChainerRL is a deep reinforcement learning library built on top of Chainer.
MIT License

Define a softmax LSTM architecture #512

Closed: oribarel closed this issue 5 years ago

oribarel commented 5 years ago

Hi,

I'm trying to combine the A3CLSTMGaussian and A3CFFSoftmax examples into an A3CLSTMSoftmax architecture. Is the following the right way to go? Would you change anything?

BTW, since I managed to use A3CFFSoftmax successfully, should I change anything in the observations? Namely, should the observation contain a history of previous observations, or is everything handled for me by Chainer / ChainerRL? One more question: what is the argument t-max used for?

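```python
import chainer
import chainer.functions as F
import chainer.links as L
from chainerrl import links
from chainerrl import policies
from chainerrl import v_function
from chainerrl.agents import a3c
from chainerrl.recurrent import RecurrentChainMixin


# RecurrentChainMixin (as in the A3CLSTMGaussian example) lets the agent reset the LSTM states.
class A3CLSTMSoftmax(chainer.ChainList, a3c.A3CModel, RecurrentChainMixin):
    def __init__(self, obs_size, action_size, hidden_size=200, lstm_size=128):
        # Separate input layers for the policy and value branches.
        self.pi_head = L.Linear(obs_size, hidden_size)
        self.v_head = L.Linear(obs_size, hidden_size)
        # Each branch keeps its own recurrent state.
        self.pi_lstm = L.LSTM(hidden_size, lstm_size)
        self.v_lstm = L.LSTM(hidden_size, lstm_size)
        # SoftmaxPolicy wraps a link that outputs one score per action,
        # as in the A3CFFSoftmax example.
        self.pi = policies.SoftmaxPolicy(
            model=links.MLP(lstm_size, action_size, hidden_sizes=(hidden_size,)))
        self.v = v_function.FCVFunction(lstm_size)
        super().__init__(self.pi_head, self.v_head,
                         self.pi_lstm, self.v_lstm, self.pi, self.v)

    def pi_and_v(self, state):

        def forward(head, lstm, tail):
            # head -> ReLU -> LSTM -> policy or value output
            h = F.relu(head(state))
            h = lstm(h)
            return tail(h)

        pout = forward(self.pi_head, self.pi_lstm, self.pi)
        vout = forward(self.v_head, self.v_lstm, self.v)
        return pout, vout
```
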
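
For reference, the "Softmax" part of this architecture turns the per-action scores produced by the policy MLP into a categorical distribution over discrete actions. Below is a dependency-free sketch of that idea. The helper names `softmax` and `sample_action` and the inverse-CDF sampling are illustrative assumptions, not ChainerRL internals; `beta` is assumed to act as an inverse temperature, similar in spirit to the `beta` argument of ChainerRL's `SoftmaxPolicy`.

```python
import math
import random


def softmax(scores, beta=1.0):
    """Convert per-action scores into probabilities that sum to 1."""
    scaled = [beta * s for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_action(scores, rng, beta=1.0):
    """Draw one discrete action index from softmax(scores)."""
    probs = softmax(scores, beta)
    u = rng.random()
    cumulative = 0.0
    for action, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return action
    return len(probs) - 1  # guard against floating-point round-off


rng = random.Random(0)
print(softmax([1.0, 2.0, 3.0]))
print(sample_action([1.0, 2.0, 3.0], rng))
```

Higher scores get higher probability, but every action keeps a nonzero chance of being sampled, which is what gives the policy its exploration.

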
muupan commented 5 years ago

Hi, your implementation of A3CLSTMSoftmax looks correct.

> Namely, should the observation contain history of previous observations or everything is handled for me by Chainer / ChainerRL?

This completely depends on what environment you use. It is the environment, not Chainer or ChainerRL, that determines what information is contained in an observation.
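
If the environment's single observation does not carry enough information, one common option is to build the history in an environment-side wrapper before observations reach the agent. Below is a minimal plain-Python sketch, assuming flat list observations. The `ObservationHistory` class is a hypothetical helper, not part of Chainer or ChainerRL. Note that the model's `obs_size` would then have to be `k` times the size of one observation. With a recurrent model such as A3CLSTMSoftmax, the LSTM state is another way of carrying information across steps.

```python
from collections import deque


class ObservationHistory:
    """Hypothetical helper: concatenates the last k observations."""

    def __init__(self, k):
        self.k = k
        self.frames = deque(maxlen=k)  # oldest frame is dropped automatically

    def reset(self, first_obs):
        # Fill the history with copies of the first observation of an episode.
        self.frames.clear()
        for _ in range(self.k):
            self.frames.append(list(first_obs))
        return self.stacked()

    def step(self, obs):
        # Append the newest observation and return the stacked history.
        self.frames.append(list(obs))
        return self.stacked()

    def stacked(self):
        # Flatten oldest-to-newest into a single feature vector.
        return [x for frame in self.frames for x in frame]


history = ObservationHistory(k=3)
obs = history.reset([0.0, 1.0])  # what the agent would see after env.reset()
obs = history.step([2.0, 3.0])   # ... and after each env.step()
```

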

> what is the argument t-max used for?

t_max is the maximum length of the rollouts used for each A3C update, as defined in Algorithm S3 of https://arxiv.org/abs/1602.01783.
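
To make the role of t_max concrete: in Algorithm S3 each actor-learner collects up to t_max steps (fewer if the episode ends first). It then starts from R = V(s_t), or R = 0 if s_t is terminal, and walks backward with R ← r_i + γR to get the targets for the update. Below is a plain-Python sketch of that recursion. The function names are hypothetical, and this is not ChainerRL's implementation.

```python
def rollout_targets(rewards, values, bootstrap_value, gamma):
    """Backward n-step return recursion for one rollout (Algorithm S3).

    rewards[i] and values[i] are r_i and V(s_i) for the (at most t_max)
    steps collected since the last update; bootstrap_value is V(s_t) for
    the state reached after the rollout, or 0.0 if that state is terminal.
    Returns the targets R_i and the advantages R_i - V(s_i).
    """
    R = bootstrap_value
    returns = [0.0] * len(rewards)
    for i in reversed(range(len(rewards))):
        R = rewards[i] + gamma * R
        returns[i] = R
    advantages = [ret - v for ret, v in zip(returns, values)]
    return returns, advantages


def rollout_lengths(episode_length, t_max):
    """Split one episode into update windows of at most t_max steps."""
    lengths = []
    remaining = episode_length
    while remaining > 0:
        n = min(t_max, remaining)
        lengths.append(n)
        remaining -= n
    return lengths
```

For example, with t_max=5 a 12-step episode is split into rollouts of 5, 5, and 2 steps, so a larger t_max means fewer, longer-horizon updates.

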