hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

How to create an actor-critic network with two separate LSTMs #1177

Open · ashleychung830 opened this issue 1 year ago

ashleychung830 commented 1 year ago

Hi, I was wondering if it's possible to have an actor-critic network with two separate LSTMs, where one LSTM outputs the value (critic) and the other outputs the actions (actor)? This is similar to #1002, except that here the two LSTMs would receive the same input from the CNN layers.

Based on the LstmPolicy source code, the net_arch parameter can only contain a single 'lstm' entry, and LSTMs are only supported in the shared part of the policy network:

# Build the non-shared part of the policy-network
latent_policy = latent
for idx, pi_layer_size in enumerate(policy_only_layers):
    if pi_layer_size == "lstm":
        raise NotImplementedError("LSTMs are only supported in the shared part of the policy network.")
    assert isinstance(pi_layer_size, int), "Error: net_arch[-1]['pi'] must only contain integers."
    latent_policy = act_fun(
        linear(latent_policy, "pi_fc{}".format(idx), pi_layer_size, init_scale=np.sqrt(2)))

# Build the non-shared part of the value-network
latent_value = latent
for idx, vf_layer_size in enumerate(value_only_layers):
    if vf_layer_size == "lstm":
        raise NotImplementedError("LSTMs are only supported in the shared part of the value function "
                                  "network.")
    assert isinstance(vf_layer_size, int), "Error: net_arch[-1]['vf'] must only contain integers."
    latent_value = act_fun(
        linear(latent_value, "vf_fc{}".format(idx), vf_layer_size, init_scale=np.sqrt(2)))
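
For reference, here is roughly what I mean in terms of usage (an untested sketch, not something taken from the docs): the first call is what I understand is supported today, and the commented-out net_arch is what I would like to be able to write.

from stable_baselines import PPO2
from stable_baselines.common.policies import LstmPolicy

# Supported: one LSTM in the shared part, then separate MLP heads for pi and vf.
model = PPO2(LstmPolicy, "CartPole-v1", nminibatches=1,
             policy_kwargs=dict(net_arch=["lstm", dict(pi=[64], vf=[64])],
                                feature_extraction="mlp"))

# What I would like to express instead, which hits the NotImplementedError above:
# policy_kwargs=dict(net_arch=[dict(pi=["lstm", 64], vf=["lstm", 64])])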

I was wondering if there is a way around this, or if this policy class simply does not support having two LSTMs. I tried writing code to get around it, but the TensorBoard model graph for the resulting network looked funky enough that I might just switch to Stable-Baselines3.
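
For reference, this is the kind of custom policy I have in mind: subclass RecurrentActorCriticPolicy directly instead of going through net_arch, with a shared nature_cnn extractor feeding two separate LSTMs. It is only a rough, untested sketch that mirrors the structure of LstmPolicy, and it assumes the helpers in stable_baselines.common.tf_layers (in older releases they live in stable_baselines.a2c.utils).

import tensorflow as tf

from stable_baselines.common.policies import RecurrentActorCriticPolicy, nature_cnn
from stable_baselines.common.tf_layers import batch_to_seq, seq_to_batch, lstm, linear


class TwoLstmPolicy(RecurrentActorCriticPolicy):
    """Sketch: a shared CNN extractor feeding two separate LSTMs (actor and critic)."""
    recurrent = True

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch,
                 n_lstm=256, reuse=False, **kwargs):
        # Two LSTMs means twice the usual (cell, hidden) state: 4 * n_lstm instead of 2 * n_lstm.
        super(TwoLstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                            state_shape=(4 * n_lstm,), reuse=reuse, scale=True)

        with tf.variable_scope("model", reuse=reuse):
            # Shared CNN features, fed identically to both recurrent branches.
            extracted_features = nature_cnn(self.processed_obs, **kwargs)
            input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
            masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)

            # Give each LSTM its own slice of the recurrent state placeholder.
            states_pi, states_vf = tf.split(self.states_ph, 2, axis=1)

            rnn_pi, snew_pi = lstm(input_sequence, masks, states_pi, 'lstm_pi', n_hidden=n_lstm)
            rnn_vf, snew_vf = lstm(input_sequence, masks, states_vf, 'lstm_vf', n_hidden=n_lstm)
            latent_pi = seq_to_batch(rnn_pi)
            latent_vf = seq_to_batch(rnn_vf)
            self.snew = tf.concat([snew_pi, snew_vf], axis=1)

            # Critic head from the value LSTM, actor head from the policy LSTM.
            self._value_fn = linear(latent_vf, 'vf', 1)
            self._proba_distribution, self._policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(latent_pi, latent_vf)

        self._setup_init()

    def step(self, obs, state=None, mask=None, deterministic=False):
        action_op = self.deterministic_action if deterministic else self.action
        return self.sess.run([action_op, self.value_flat, self.snew, self.neglogp],
                             {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})

    def proba_step(self, obs, state=None, mask=None):
        return self.sess.run(self.policy_proba,
                             {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})

    def value(self, obs, state=None, mask=None):
        return self.sess.run(self.value_flat,
                             {self.obs_ph: obs, self.states_ph: state, self.dones_ph: mask})

The idea is to make state_shape twice as large as LstmPolicy's so each LSTM gets its own slice of the recurrent state, and then to use the class like any other recurrent policy, e.g. PPO2(TwoLstmPolicy, env, nminibatches=1). Is something along these lines expected to work, or is the two-LSTM case simply out of scope for this policy class?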