Closed: borninfreedom closed this issue 2 years ago
I'd recommend sticking with stable-baselines3, as it is more up to date and TF1 is really outdated at this point. The (boring) answer is that the code structure is different and TF models work differently from PyTorch ones (IIRC, "printing out the model" in TF1 was more involved than a simple print statement)
and for LSTM with SB3, you can take a look at that comment: https://github.com/DLR-RM/stable-baselines3/issues/18#issuecomment-979338510
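To illustrate the "simple print statement" point above: a PyTorch module's repr lists its full architecture, which is why inspecting an SB3 model is easy. A minimal sketch with a toy `torch.nn` network (a stand-in, not an actual SB3 policy):

```python
import torch.nn as nn

# A toy two-layer network standing in for a policy network.
net = nn.Sequential(
    nn.Linear(4, 64),
    nn.Tanh(),
    nn.Linear(64, 2),
)

# In PyTorch, printing a module lists every submodule with its
# dimensions; with an SB3 model the equivalent is `print(model.policy)`.
print(net)
```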
Does the ppo-lstm work properly now?
It has been tested, yes; only dict observation support is missing. But I would still recommend trying frame-stacking first (and using it together with the LSTM).
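Frame-stacking (what SB3's `VecFrameStack` wrapper does for vectorized environments) means feeding the agent a rolling window of the last k observations instead of only the current one, which gives the policy short-term memory without an LSTM. A minimal self-contained sketch of the idea (a hypothetical stand-alone helper, not the SB3 API):

```python
from collections import deque
import numpy as np

class FrameStacker:
    """Keep the last `n_stack` observations, concatenated along the last axis."""

    def __init__(self, n_stack):
        self.n_stack = n_stack
        self.frames = deque(maxlen=n_stack)

    def reset(self, obs):
        # On episode start, fill the window with copies of the first observation.
        for _ in range(self.n_stack):
            self.frames.append(obs)
        return self._stacked()

    def step(self, obs):
        # The deque drops the oldest frame automatically when full.
        self.frames.append(obs)
        return self._stacked()

    def _stacked(self):
        return np.concatenate(self.frames, axis=-1)

stacker = FrameStacker(n_stack=4)
first = stacker.reset(np.zeros(2))   # shape (8,): four stacked 2-dim observations
latest = stacker.step(np.ones(2))    # oldest zero frame replaced by a ones frame
```

With SB3 itself, the same effect comes from wrapping the vectorized env, e.g. `VecFrameStack(env, n_stack=4)`.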
OK, that's cool. But why hasn't the feat/ppo-lstm branch been merged into the master branch yet?
How can I resolve this error?
Traceback (most recent call last):
File "test_lstm.py", line 132, in <module>
test_cnn()
File "test_lstm.py", line 50, in test_cnn
model.learn(total_timesteps=32)
File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 496, in learn
continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)
File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/ppo_lstm/ppo_lstm.py", line 276, in collect_rollouts
actions, values, log_probs, lstm_states = self.policy.forward(obs_tensor, lstm_states, episode_starts)
File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/sb3_contrib-1.3.1a3-py3.7.egg/sb3_contrib/common/recurrent/policies.py", line 189, in forward
latent_pi = self.mlp_extractor.forward_actor(latent_pi)
File "/home/yan/miniconda3/envs/sb3_env/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1131, in __getattr__
type(self).__name__, name))
AttributeError: 'MlpExtractor' object has no attribute 'forward_actor'
This issue is not the place for that; you need the master version of SB3.
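Installing the master branch of SB3 directly from GitHub (a setup fragment; run inside the same environment that has sb3_contrib installed):

```shell
# Replace the released stable-baselines3 with the current master branch
pip install git+https://github.com/DLR-RM/stable-baselines3
```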
It is not merged into master because we still need to polish it (docs, tests, comments, ...).
@Miffyli Can you, by any chance, point me to resources for "printing out the model" with stable-baselines2? It's been pretty hard to find anything in the docs. Thanks
When I run
The output is
Why can't I get the model architecture? With stable-baselines3, it works when written like this.