thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

RNN for continuous CQL algorithm #513

Open BFAnas opened 2 years ago

BFAnas commented 2 years ago

This is a request for RNN support in the continuous CQL algorithm. Thanks for this awesome lib!

Trinkle23897 commented 2 years ago

@thkkk

BFAnas commented 2 years ago

I think the problem is in the RecurrentActorProb class. From this part of the code, it seems that it expects an input of shape [bsz, len*dim]:

self.nn = nn.LSTM(
    input_size=int(np.prod(state_shape)),
    hidden_size=hidden_layer_size,
    num_layers=layer_num,
    batch_first=True,
)

But this part of the code suggests that the obs passed to self.nn is of shape [bsz, len, dim]:

# obs [bsz, len, dim] (training) or [bsz, dim] (evaluation)
obs, (hidden, cell) = self.nn(obs)

Could you investigate this?

Trinkle23897 commented 2 years ago

I think we can still use this setting by setting the buffer's stack_num to a value greater than 1. In short, when training RNN+CQL we use [bsz, len, dim] to train a recurrent network with trajectory length len, and in the test phase we use [bsz, dim], because at that point the neural network can maintain a hidden state to infer actions.
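
A minimal sketch of the buffer side of this, assuming a tianshou 0.4.x-style ReplayBuffer (the sizes and dim here are made up):

import numpy as np
from tianshou.data import Batch, ReplayBuffer

dim = 3
buf = ReplayBuffer(size=100, stack_num=4)  # len = stack_num = 4
for i in range(100):
    buf.add(Batch(
        obs=np.full(dim, i, dtype=np.float32),
        act=np.zeros(1, dtype=np.float32),
        rew=0.0,
        done=(i % 20 == 19),
        obs_next=np.full(dim, i + 1, dtype=np.float32),
        info={},
    ))
batch, indices = buf.sample(64)
print(batch.obs.shape)  # expected: (64, 4, 3), i.e. [bsz, len, dim]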

BFAnas commented 2 years ago

@Trinkle23897 Thank you for your answer. I understand that, but that's not the problem. To explain better: I see a part of the code where the expected input is of shape [bsz, len*dim], whereas the input actually passed is of shape [bsz, len, dim]. In this part of the code, self.nn expects an input of shape [bsz, len*dim] (note that int(np.prod(state_shape)) = len*dim):

self.nn = nn.LSTM(
    input_size=int(np.prod(state_shape)),
    hidden_size=hidden_layer_size,
    num_layers=layer_num,
    batch_first=True,
)

And later, the obs that is passed to self.nn is of shape [bsz, len, dim], which is therefore different from the shape that self.nn expects. Do you agree?

Trinkle23897 commented 2 years ago

note that int(np.prod(state_shape)) = len*dim

I don't think so. state_shape should always be a single frame, i.e., int(np.prod(state_shape)) = dim. If that's not the case, you should modify it outside accordingly.

BFAnas commented 2 years ago

You mean I should have dim instead of len*dim? Even when I'm working with stack_num != 1? But anyway, self.nn is getting obs of shape [bsz, len, dim] when it expects [bsz, int(np.prod(state_shape))], whatever that is.

Trinkle23897 commented 2 years ago

In [16]: m = nn.LSTM(input_size=3, hidden_size=10, num_layers=1, batch_first=True)

In [17]: s = torch.zeros([64, 1, 3])

In [18]: ns, (h, c) = m(s)

In [19]: ns.shape, h.shape, c.shape
Out[19]: (torch.Size([64, 1, 10]), torch.Size([1, 64, 10]), torch.Size([1, 64, 10]))

In [20]: s = torch.zeros([64, 16, 3])

In [21]: ns, (h, c) = m(s)

In [22]: ns.shape, h.shape, c.shape
Out[22]: (torch.Size([64, 16, 10]), torch.Size([1, 64, 10]), torch.Size([1, 64, 10]))

The input of self.nn.forward is always a 3-dim tensor, not a 2-dim one.

BFAnas commented 2 years ago

If I have an observation of shape [bsz, len, dim], what state_shape argument should I pass to RecurrentActorProb?

Trinkle23897 commented 2 years ago

It should be dim. Take the Atari example: the observation space is (4, 84, 84), where 4 is len. However, when defining the recurrent network, state_shape should be 84*84 instead of 4*84*84; the trajectory length is determined by the replay buffer's sampling method.
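
For example, a minimal sketch of constructing RecurrentActorProb this way (assuming the tianshou 0.4.x signature; the dimensions are made up):

import torch
from tianshou.utils.net.continuous import RecurrentActorProb

dim, act_dim, bsz, length = 17, 6, 64, 4     # made-up sizes
actor = RecurrentActorProb(
    layer_num=1,
    state_shape=(dim,),        # a single frame, not len * dim
    action_shape=(act_dim,),
    hidden_layer_size=128,
)
obs = torch.zeros(bsz, length, dim)          # [bsz, len, dim] during training
(mu, sigma), state = actor(obs)
print(mu.shape)                              # expected: torch.Size([64, 6])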

BFAnas commented 2 years ago

Okay, thanks for the support. It is a little bit confusing, since state_shape for the normal ActorProb is equal to obs.shape; maybe you could consider making ActorProb and RecurrentActorProb coherent in this regard. Also, more ambitiously, maybe you could make RecurrentActorProb and RecurrentCritic constructible from a RecurrentNet, the same way ActorProb and Critic are constructed from a Net.

Trinkle23897 commented 2 years ago

But here comes the problem: there are two ways to perform this kind of observation stacking:

  1. gym.Env outputs a single frame -- stacking is done by buffer.sample();
  2. gym.Env outputs a stacked frame via the FrameStack env wrapper -- no stacking at all, or de-stack -> save to buffer -> stack by buffer.sample();

I cannot make any assumption here, so that's why the current code looks the way it does.
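
A rough sketch of the two conventions (the env id and shapes here are only illustrative):

import gym

# 1. the env emits single frames; stacking happens in buffer.sample()
env = gym.make("Pendulum-v1")                               # obs shape: (3,)
# store single frames, then ReplayBuffer(size, stack_num=4) -> sampled obs (bsz, 4, 3)

# 2. the env wrapper stacks frames itself
env = gym.wrappers.FrameStack(gym.make("Pendulum-v1"), 4)   # obs shape: (4, 3)
# the buffer then already holds stacked frames (stack_num=1),
# unless the frames are de-stacked before being stored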

BFAnas commented 2 years ago

To make CQL work with an RNN, I changed tmp_obs and tmp_obs_next in cql.py > CQLPolicy.learn as follows:

tmp_obs = obs.unsqueeze(1) \
    .repeat(1, self.num_repeat_actions, 1, 1) \
    .view(batch_size * self.num_repeat_actions, obs.shape[-2], obs.shape[-1])
tmp_obs_next = obs_next.unsqueeze(1) \
    .repeat(1, self.num_repeat_actions, 1, 1) \
    .view(batch_size * self.num_repeat_actions, obs_next.shape[-2], obs_next.shape[-1])

Now the code executes without errors, but maybe I'm missing something else that is necessary for the RNN to work correctly.
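
A quick shape check of that repeat/reshape, with made-up sizes:

import torch

batch_size, length, dim, num_repeat = 8, 4, 17, 10   # made-up sizes
obs = torch.zeros(batch_size, length, dim)
tmp_obs = obs.unsqueeze(1) \
    .repeat(1, num_repeat, 1, 1) \
    .view(batch_size * num_repeat, length, dim)
print(tmp_obs.shape)   # torch.Size([80, 4, 17]), i.e. [bsz * num_repeat, len, dim]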

Trinkle23897 commented 2 years ago

Glad to hear that!

BFAnas commented 2 years ago

Which task would you recommend for testing this solution? Ideally it should be a task from the d4rl datasets where SAC has been tried with RNNs and worked correctly, since CQL inherits from SAC.

thkkk commented 2 years ago

Which task would you recommend for testing this solution? Ideally it should be a task from the d4rl datasets where SAC has been tried with RNNs and worked correctly, since CQL inherits from SAC.

I think the task for testing CQL is the same as the task for testing SAC, e.g., Pendulum for a unit test or halfcheetah-medium in d4rl. I don't know whether the presence of an RNN will affect the choice of task.