hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

[question] A problem with how to use MlpLstmPolicy in GAIL training? #1148

Closed: LongchaoDa closed this issue 2 years ago

LongchaoDa commented 2 years ago

I was training a GAIL model with MlpLstmPolicy in Stable Baselines 2, but I could not get the training process to run, even though I set the policy assertion in the TRPO part to

assert issubclass(self.policy, LstmPolicy)

Are there any other changes I should make? Or, if there is no other possible solution, how can I customize an LSTM policy myself so that it is compatible with GAIL?

Looking forward to your reply!

Here is the error:

Traceback (most recent call last):
  File "/home/.../train-recurrentGail.py", line 18, in <module>
    model.learn(total_timesteps=100000)
  File "/home/.../model/gail/model.py", line 57, in learn
    return super().learn(total_timesteps, callback, log_interval, tb_log_name, reset_num_timesteps)
  File "/home/.../model/gail/myTrpo.py", line 364, in learn
    seg = seg_gen.__next__()
  File "/home/.../common/runners.py", line 118, in traj_segment_generator
    action, vpred, states, _ = policy.step(observation.reshape(-1, *observation.shape), states, done)
  File "/home/.../stable_baselines/common/policies.py", line 508, in step
    {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})
  File "/home/.../python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/home/.../python/client/session.py", line 1156, in _run
    (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape () for Tensor 'input_1/dones_ph:0', which has shape '(1,)'
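The error says the runner fed `done` as a scalar (shape `()`), while the recurrent policy's `dones_ph` placeholder expects one flag per environment (shape `(1,)`). Below is a minimal NumPy sketch of that mismatch and one way to fix the feed shape. `as_dones_mask` is a hypothetical helper for illustration, not part of Stable Baselines.

```python
import numpy as np

# Shape of 'input_1/dones_ph:0' as reported in the traceback above.
DONES_PH_SHAPE = (1,)


def as_dones_mask(done, n_envs=1):
    """Hypothetical helper: convert a done flag (scalar or per-env sequence)
    into a float32 array of shape (n_envs,), one entry per environment."""
    mask = np.asarray(done, dtype=np.float32).reshape(-1)
    if mask.shape != (n_envs,):
        raise ValueError(f"expected {n_envs} done flag(s), got shape {mask.shape}")
    return mask


done = False  # a single boolean flag, as passed in the failing policy.step call
print(np.asarray(done).shape)     # () : the shape rejected by the placeholder
print(as_dones_mask(done).shape)  # (1,): matches DONES_PH_SHAPE
```

Under that assumption, the call in `traj_segment_generator` would pass `[done]` (or `as_dones_mask(done)`) instead of `done`. This only fixes the feed shape for the rollout step. It does not, by itself, make the rest of the TRPO/GAIL update recurrent-aware.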
Miffyli commented 2 years ago

LSTM is not supported in the pretraining. There was a PR adding this, but it has since died out: https://github.com/hill-a/stable-baselines/pull/315 . For better imitation learning algorithms, see the imitation library, for example. Note that SB2 is not maintained anymore.

You may close this issue if you have no further questions.

LongchaoDa commented 2 years ago

Thank you. So you mean the structure of Stable Baselines 2 makes it difficult to support a custom LSTM policy in GAIL training? Maybe I will turn to the link you mentioned to explore more.

Miffyli commented 2 years ago

Regarding "so you mean the structure of Stable Baselines 2 makes it difficult to support a custom LSTM policy in GAIL training?": yes, it requires a bit of work (you may check PR #315).
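Part of that work is threading the LSTM state and the dones mask through every rollout step and every update. Below is a minimal NumPy sketch of the masking idea. It assumes a Baselines-style LSTM layer, which to my understanding is what SB2's `LstmPolicy` builds on, where the cell and hidden state are multiplied by `1 - mask` before each step. `reset_lstm_state` is a hypothetical helper for illustration only.

```python
import numpy as np


def reset_lstm_state(state, mask):
    """Hypothetical helper mirroring a masked LSTM: zero the carried state of
    every environment whose episode just ended (mask == 1) before the next step.

    state: array of shape (n_envs, 2 * n_lstm), cell and hidden state stacked
    mask:  array of shape (n_envs,), 1.0 where the episode is done, else 0.0
    """
    state = np.asarray(state, dtype=np.float32)
    mask = np.asarray(mask, dtype=np.float32)
    if mask.shape != (state.shape[0],):
        raise ValueError(f"mask must have shape ({state.shape[0]},), got {mask.shape}")
    # Broadcast (n_envs,) -> (n_envs, 1) so each row is kept or zeroed as a whole.
    return state * (1.0 - mask)[:, None]


n_envs, n_lstm = 2, 3
state = np.ones((n_envs, 2 * n_lstm), dtype=np.float32)
mask = np.array([0.0, 1.0], dtype=np.float32)  # env 1 just finished an episode
new_state = reset_lstm_state(state, mask)
print(new_state.sum(axis=1))  # [6. 0.]: env 0 keeps its memory, env 1 starts fresh
```

This is also why the dones input needs one entry per environment: each row of the state is reset independently.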

Closing issue as resolved :).