hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License
4.14k stars 723 forks

[question] How to get the model architecture when using recurrent policy? #1145

Closed borninfreedom closed 2 years ago

borninfreedom commented 2 years ago

When I run

from stable_baselines import PPO2
from stable_baselines.common.policies import LstmPolicy

class CustomLSTMPolicy(LstmPolicy):
    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=64, reuse=False, **_kwargs):
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm, reuse,
                         net_arch=[8, 'lstm', dict(vf=[5, 10], pi=[10])],
                         layer_norm=True, feature_extraction="mlp", **_kwargs)

model = PPO2(CustomLSTMPolicy, 'CartPole-v1', nminibatches=1, verbose=1)
print(model.policy)
print(model.policy_kwargs)

The output is

__main__.CustomLSTMPolicy
{}

Why can't I get the model architecture? If I use stable-baselines3, writing it this way works.
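For context on why `print(model.policy)` shows the architecture in SB3: SB3 policies are PyTorch `nn.Module`s, whose `repr` recursively lists all submodules. The sketch below imitates that mechanism with a hypothetical stdlib `Module` class (not the real `torch.nn.Module`; the layer names are made up for illustration):

```python
# Minimal sketch of PyTorch-style recursive module printing.
# "Module" here is a hypothetical stand-in, not torch.nn.Module.
class Module:
    def __init__(self, name, children=()):
        self.name = name
        self.children = list(children)

    def __repr__(self, indent=0):
        pad = "  " * indent
        if not self.children:
            return f"{pad}{self.name}"
        body = "\n".join(c.__repr__(indent + 1) for c in self.children)
        return f"{pad}{self.name}(\n{body}\n{pad})"

# Illustrative structure only; names do not come from a real SB3 policy.
policy = Module("CustomLSTMPolicy", [
    Module("Linear(in=4, out=8)"),
    Module("LSTM(hidden=64)"),
    Module("value_net", [Module("Linear(in=64, out=5)"),
                         Module("Linear(in=5, out=10)")]),
    Module("policy_net", [Module("Linear(in=64, out=10)")]),
])
print(policy)  # nested architecture listing, like print(model.policy) in SB3
```

Because the `repr` walks the module tree, a single print statement is enough in PyTorch; TF1 graphs have no equivalent nested-object structure, which is the crux of the question above.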

Miffyli commented 2 years ago

I'd recommend sticking with stable-baselines3, as it is more actively updated and TF1 is really outdated by this point. The (boring) answer is that the code structure is different and that TF models work in different ways than PyTorch ones (IIRC, "printing out the model" in TF1 was more involved than just a simple print statement).
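One practical workaround in SB2 is to list parameter names and shapes via `model.get_parameters()`, which in stable-baselines 2 returns an OrderedDict mapping TF variable names to numpy arrays. The sketch below simulates that dict with nested lists so it runs without TF installed; the variable names shown are illustrative assumptions, not guaranteed output:

```python
# Inspecting an SB2/TF1 model by listing parameter names and shapes.
# model.get_parameters() returns an OrderedDict of variable name -> array;
# here we simulate it with nested lists so the sketch needs no TF.
from collections import OrderedDict

def shape(a):
    """Shape of a nested-list 'array' (stands in for ndarray.shape)."""
    dims = []
    while isinstance(a, list):
        dims.append(len(a))
        a = a[0]
    return tuple(dims)

# Simulated output of model.get_parameters() for a tiny policy
# (names are hypothetical examples of TF1 variable naming).
params = OrderedDict([
    ("model/pi_fc0/w:0", [[0.0] * 8 for _ in range(4)]),   # 4x8 weight
    ("model/pi_fc0/b:0", [0.0] * 8),                        # bias
    ("model/lstm1/wx:0", [[0.0] * 256 for _ in range(8)]),  # LSTM input weights
])

def summarize(parameters):
    """One line per parameter: name and shape."""
    return [f"{name}: shape={shape(value)}" for name, value in parameters.items()]

for line in summarize(params):
    print(line)
```

With a real SB2 model you would call `summarize(model.get_parameters())` (using `value.shape` directly instead of the `shape` helper); the printed names and shapes let you reconstruct the layer sizes even though TF1 offers no one-line module printout.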

araffin commented 2 years ago

and for LSTM with SB3, you can take a look at that comment: https://github.com/DLR-RM/stable-baselines3/issues/18#issuecomment-979338510

borninfreedom commented 2 years ago

> and for LSTM with SB3, you can take a look at that comment: DLR-RM/stable-baselines3#18 (comment)

Does the ppo-lstm work properly now?

araffin commented 2 years ago

> Does the ppo-lstm work properly now?

It has been tested, yes; only the dict obs support is missing. But I would still recommend trying frame-stacking first (and using it together with the lstm).
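The frame-stacking suggested above is what SB3's `VecFrameStack` wrapper does: keep the last n observations and feed them together, so a feed-forward policy gets short-term memory without an LSTM. A stdlib sketch of the idea (the `FrameStacker` class is a toy stand-in, not the real SB3 wrapper):

```python
# Sketch of frame stacking: concatenate the last n observations so the
# policy can infer velocities etc. without recurrence.
from collections import deque

class FrameStacker:
    def __init__(self, n_stack, obs_dim):
        # Start with zero-filled frames, analogous to a post-reset state.
        self.frames = deque([[0.0] * obs_dim for _ in range(n_stack)],
                            maxlen=n_stack)

    def add(self, obs):
        """Append the newest observation and return the stacked input."""
        self.frames.append(list(obs))  # oldest frame falls off the left
        stacked = []
        for frame in self.frames:  # oldest first, newest last
            stacked.extend(frame)
        return stacked

stacker = FrameStacker(n_stack=4, obs_dim=2)
out = stacker.add([1.0, 2.0])
print(len(out))  # 4 frames * 2 features = 8 values
```

In SB3 the equivalent is wrapping the vectorized env, e.g. `VecFrameStack(env, n_stack=4)`; it is cheaper to train than a recurrent policy, which is why it is worth trying first.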

borninfreedom commented 2 years ago

> > Does the ppo-lstm work properly now?
>
> It has been tested, yes; only the dict obs support is missing. But I would still recommend trying frame-stacking first (and using it together with the lstm).

OK, that's cool. But why hasn't the feat/ppo-lstm branch been merged into the master branch yet?

borninfreedom commented 2 years ago

> and for LSTM with SB3, you can take a look at that comment: DLR-RM/stable-baselines3#18 (comment)

How to resolve the error?

Traceback (most recent call last):
  File "test_lstm.py", line 132, in <module>
    test_cnn()
  File "test_lstm.py", line 50, in test_cnn
    model.learn(total_timesteps=32)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 496, in learn
    continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 276, in collect_rollouts
    actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/common/recurrent/policies.py", line 189, in forward
    latent_pi = self.mlp_extractor.forward_actor(latent_pi)
  File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
    type(self).__name__, name))
AttributeError: 'MlpExtractor' object has no attribute 'forward_actor'

araffin commented 2 years ago

> How to resolve the error?
>
> AttributeError: 'MlpExtractor' object has no attribute 'forward_actor'

This issue is not the place for that. You need the master version of SB3.

It is not merged into master yet because we need to polish it (docs, tests, comments, ...).

charlo1998 commented 1 year ago

> I'd recommend sticking with stable-baselines3 as it is more updated and TF1 is really outdated by this point. The (boring) answer is that the code structure is different and that TF models work in different ways than PyTorch ones (IIRC, to "print out the model" in TF1 was more difficult than just a simple print statement)

@Miffyli Can you, by any chance, point to resources on "printing out the model" with stable-baselines 2? It's been pretty hard to find anything in the docs. Thanks