nslyubaykin / rnns_for_pomdp

Recurrent Policies for Handling Partially Observable Environments

lstm+sac #1

Closed · 1900360 closed this 1 year ago

1900360 commented 1 year ago

I am using the LSTM+SAC code, but it reports an error. Do you have a solution?

```python
# Actor & critic:
actor = SAC(
    policy_net=NormalLSTM(obs_dim, acs_dim,
                          nlayers_lstm=2, seq_len=1+n_lags,
                          nunits_lstm=32, nunits_dense=8,
                          out_activation=torch.nn.Identity(),
                          init_log_std=-1.0),
    device=torch.device('cpu'),
    auto_tune_alpha=True,
    target_entropy=-acs_dim,
    learning_rate=policy_lr,
    batch_size=512,
    n_random_steps=1000,
    min_acs=-2, max_acs=2,
    obs_nlags=n_lags,
    obs_expand_axis=0, obs_concat_axis=0,
    obs_padding='zeros'
)

critic = CDQN(
    # critic_net=ContQMLP(obs_dim, acs_dim, hidden1=512, hidden2=512),
    critic_net=VLSTM(obs_dim, nlayers_lstm=2,
                     seq_len=1+n_lags,
                     nunits_lstm=32, nunits_dense=8),
    critic_net2=VLSTM(obs_dim, nlayers_lstm=2,
                      seq_len=1+n_lags,
                      nunits_lstm=32, nunits_dense=8),
    device=torch.device('cpu'),
    learning_rate=critic_lr,
    gamma=0.99,
    obs_nlags=n_lags,
    obs_expand_axis=0, obs_concat_axis=0,
    obs_padding='zeros',
    weight_decay=0.0
)
```

lstm_parallel_sac.txt

1900360 commented 1 year ago

Hi @nslyubaykin, I hope you can help me with this; I've been stuck on this code for days :(

nslyubaykin commented 1 year ago

Hi @1900360!

Could you please provide the error message itself, and the line that raises it?

1900360 commented 1 year ago

Hi @nslyubaykin! Sure, the error is here:

```
Traceback (most recent call last):
  File "D:/desktop/lunwen_dabao/xinsuanfa0912/lstm-rnns_ppo_0929-master/lags_for_pomdp_parallel_sac.py", line 296, in <module>
    actor_logs = actor.update(replay_buffer)
  File "D:\desktop\lunwen_dabao\xinsuanfa0912\lstm-rnns_ppo_0929-master\relax\rl\actors.py", line 1593, in update
    sample=sample)
  File "D:\desktop\lunwen_dabao\xinsuanfa0912\lstm-rnns_ppo_0929-master\relax\rl\critics.py", line 1491, in _sac_update
    target=True).squeeze()
  File "D:\desktop\lunwen_dabao\xinsuanfa0912\lstm-rnns_ppo_0929-master\relax\rl\critics.py", line 1230, in forward
    return self.target_net(obs=obs, acs=acs)
  File "C:\Users\1900\.conda\envs\tf\lib\site-packages\torch\nn\modules\module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
TypeError: forward() got an unexpected keyword argument 'obs'
```

And the line that raises it is:

`actor_logs = actor.update(replay_buffer)`

I don't know whether my settings are correct. The actor network is NormalLSTM, and the critic network is VLSTM:

```python
actor = SAC(
    policy_net=NormalLSTM(obs_dim, acs_dim,
                          nlayers_lstm=2, seq_len=1+n_lags,
                          nunits_lstm=32, nunits_dense=8,
                          out_activation=torch.nn.Identity(),
                          init_log_std=-1.0),
    device=torch.device('cpu'),
    auto_tune_alpha=True,
    target_entropy=-acs_dim,
    learning_rate=policy_lr,
    batch_size=512,
    n_random_steps=1000,
    min_acs=-2, max_acs=2,
    obs_nlags=n_lags,
    obs_expand_axis=0, obs_concat_axis=0,
    obs_padding='zeros'
)

critic = CDQN(
    critic_net=VLSTM(obs_dim, nlayers_lstm=2,
                     seq_len=1+n_lags,
                     nunits_lstm=32, nunits_dense=8),
    critic_net2=VLSTM(obs_dim, nlayers_lstm=2,
                      seq_len=1+n_lags,
                      nunits_lstm=32, nunits_dense=8),
    device=torch.device('cpu'),
    learning_rate=critic_lr,
    gamma=0.99,
    obs_nlags=n_lags,
    obs_expand_axis=0, obs_concat_axis=0,
    obs_padding='zeros',
    weight_decay=0.0
)
```

nslyubaykin commented 1 year ago

OK, here are my recommendations for debugging your code:

1) Do not use ParallelSampler for this task. Here we sample only 1 transition each training step (a Q-iteration-like setting), so there is no need for parallel training sampling; in fact, it will only make the program slower. Refer to the "Parallel Sampling Takeaways" section of this README. Consider just using Sampler here.

2) Do not use relax.zoo.critics.VLSTM with the CDQN critic. The "V" here stands for "Value", which means this network is a value function, not a Q-function: it depends only on the state, while a Q-function depends on both state and action. That is why you receive the "unexpected keyword argument" error. Instead, use relax.zoo.critics.ContQMLP. Note that this network does not use an LSTM architecture: for now it just flattens the lagged observation into a 1-dimensional array and feeds it into dense layers. Right now there are no built-in LSTM-based continuous Q-function critic nets in this package. Still, you can create your own custom critic that processes the observation array with an LSTM by following the relax.zoo.critics.ContQMLP interface while modifying the network's architecture and its forward() method, and it should work with the CDQN critic (a rough sketch follows this list).

3) Also, do not use NormalLSTM with SAC. It uses a normal distribution with state-independent action variance. While this should technically work, the authors of the original paper note that state-independent action variance does not work well with SAC in practice. Consider using relax.zoo.policies.TanhNormalMLP, whose architecture makes the action variance state-dependent. This policy also only flattens the observation and applies dense layers but, once again, you can create your own custom LSTM-based policy by following the relax.zoo.policies.TanhNormalMLP interface while modifying the network's architecture and its forward() method, and it should work with SAC (a second sketch follows the note below).
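
For item 2, here is a rough sketch of what such a custom LSTM Q-network could look like. ContQLSTM is a hypothetical name, not part of relax: the forward(obs, acs) signature is inferred from the traceback above, and the flattened lagged-observation layout is an assumption based on how ContQMLP is described, so treat it as a starting point rather than a drop-in implementation:

```python
import torch
import torch.nn as nn


class ContQLSTM(nn.Module):
    """Hypothetical LSTM-based continuous Q-network (not part of relax).

    Follows the forward(obs, acs) calling convention seen in the traceback;
    the flattened lagged-observation layout is an assumption.
    """

    def __init__(self, obs_dim, acs_dim, seq_len,
                 nlayers_lstm=2, nunits_lstm=32, nunits_dense=8):
        super().__init__()
        self.seq_len, self.obs_dim = seq_len, obs_dim
        # LSTM summarizes the lagged observation sequence
        self.lstm = nn.LSTM(obs_dim, nunits_lstm,
                            num_layers=nlayers_lstm, batch_first=True)
        # Dense head maps (LSTM summary, action) -> scalar Q-value
        self.head = nn.Sequential(
            nn.Linear(nunits_lstm + acs_dim, nunits_dense),
            nn.ReLU(),
            nn.Linear(nunits_dense, 1),
        )

    def forward(self, obs, acs):
        # Un-flatten lags: (batch, seq_len * obs_dim) -> (batch, seq_len, obs_dim)
        obs_seq = obs.view(-1, self.seq_len, self.obs_dim)
        out, _ = self.lstm(obs_seq)
        summary = out[:, -1, :]  # hidden state at the last time step
        return self.head(torch.cat([summary, acs], dim=-1))
```

If the layout assumption holds, it could then be passed as critic_net=ContQLSTM(obs_dim, acs_dim, seq_len=1+n_lags) to CDQN.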

!Note: the SAC example may be a useful clarification for this.
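
And for item 3, a similar sketch of an LSTM policy head with state-dependent variance. TanhNormalLSTM is an invented name and its mean/log-std return convention is assumed; check it against the actual relax.zoo.policies.TanhNormalMLP forward() before using:

```python
import torch
import torch.nn as nn


class TanhNormalLSTM(nn.Module):
    """Hypothetical LSTM policy with state-dependent action variance
    (not part of relax; the mean/log-std output convention is assumed)."""

    def __init__(self, obs_dim, acs_dim, seq_len,
                 nlayers_lstm=2, nunits_lstm=32, nunits_dense=8):
        super().__init__()
        self.seq_len, self.obs_dim = seq_len, obs_dim
        self.lstm = nn.LSTM(obs_dim, nunits_lstm,
                            num_layers=nlayers_lstm, batch_first=True)
        self.dense = nn.Sequential(
            nn.Linear(nunits_lstm, nunits_dense), nn.ReLU())
        self.mean = nn.Linear(nunits_dense, acs_dim)
        # log-std is a function of the state, unlike NormalLSTM's
        # state-independent variance
        self.log_std = nn.Linear(nunits_dense, acs_dim)

    def forward(self, obs):
        # Same assumed un-flattening of lagged observations as above
        obs_seq = obs.view(-1, self.seq_len, self.obs_dim)
        out, _ = self.lstm(obs_seq)
        h = self.dense(out[:, -1, :])
        # clamp keeps the std in a numerically safe range
        return self.mean(h), torch.clamp(self.log_std(h), -20.0, 2.0)
```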

So these settings should do the job for now, without using LSTM layers:

```python
# Actor & critic:
actor = SAC(
    policy_net=TanhNormalMLP(obs_dim*(1+n_lags),  # account for lags
                             acs_dim),
    device=torch.device('cpu'),
    auto_tune_alpha=True,
    target_entropy=-acs_dim,
    learning_rate=policy_lr,
    batch_size=512,
    n_random_steps=1000,
    min_acs=-2,
    max_acs=2,
    obs_nlags=n_lags,
    obs_expand_axis=0,
    obs_concat_axis=0,
    obs_padding='zeros'
)

critic = CDQN(
    critic_net=ContQMLP(obs_dim*(1+n_lags), acs_dim),   # account for lags
    critic_net2=ContQMLP(obs_dim*(1+n_lags), acs_dim),  # account for lags
    device=torch.device('cpu'),
    learning_rate=critic_lr,
    gamma=0.99,
    obs_nlags=n_lags,
    obs_expand_axis=0,
    obs_concat_axis=0,
    obs_padding='zeros',
    weight_decay=0.0
)
```

P.S. I had to update the package to remove an annoying warning about distributions' arg constraints, so please consider installing the updated version.

1900360 commented 1 year ago

Hi @nslyubaykin! Thank you for your help! I will try to create my own custom critic that uses an LSTM architecture. By the way, could you check issue #2? It seems that LSTM+PPO does not work well in Pendulum-v0.