Hello, my issue concerns the usage of the net_arch parameter inside LstmPolicy. This will help to implement a custom CnnLstmPolicy.
Now, LstmPolicy from stable_baselines.common.policies has the following code, which raises NotImplementedError() when net_arch is not None:
class LstmPolicy(RecurrentActorCriticPolicy):
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=256, reuse=False, layers=None,
net_arch=None, act_fun=tf.tanh, cnn_extractor=nature_cnn, layer_norm=False, feature_extraction="cnn",
**kwargs):
# state_shape = [n_lstm * 2] dim because of the cell and hidden states of the LSTM
super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
state_shape=(2 * n_lstm, ), reuse=reuse,
scale=(feature_extraction == "cnn"))
self._kwargs_check(feature_extraction, kwargs)
if net_arch is None: # Legacy mode
if layers is None:
layers = [64, 64]
else:
warnings.warn("The layers parameter is deprecated. Use the net_arch parameter instead.")
with tf.variable_scope("model", reuse=reuse):
if feature_extraction == "cnn":
extracted_features = cnn_extractor(self.processed_obs, **kwargs)
else:
extracted_features = tf.layers.flatten(self.processed_obs)
for i, layer_size in enumerate(layers):
extracted_features = act_fun(linear(extracted_features, 'pi_fc' + str(i), n_hidden=layer_size,
init_scale=np.sqrt(2)))
input_sequence = batch_to_seq(extracted_features, self.n_env, n_steps)
masks = batch_to_seq(self.dones_ph, self.n_env, n_steps)
rnn_output, self.snew = lstm(input_sequence, masks, self.states_ph, 'lstm1', n_hidden=n_lstm,
layer_norm=layer_norm)
rnn_output = seq_to_batch(rnn_output)
value_fn = linear(rnn_output, 'vf', 1)
self._proba_distribution, self._policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(rnn_output, rnn_output)
self._value_fn = value_fn
else: # Use the new net_arch parameter
if layers is not None:
warnings.warn("The new net_arch parameter overrides the deprecated layers parameter.")
if feature_extraction == "cnn":
raise NotImplementedError()
....
So, the solution is simple: do it by analogy with the net_arch is None case. Just use the provided cnn_extractor when feature_extraction == "cnn" to preprocess the input images into a flattened layer, which would then pass through the net_arch layers.
Hello, my issue concerns the usage of the `net_arch` parameter inside `LstmPolicy`. This will help to implement a custom `CnnLstmPolicy`.

Now, `LstmPolicy` from `stable_baselines.common.policies` has the following code, which raises `NotImplementedError()` when `net_arch` is not None.

So, the solution is simple: do it by analogy with the `net_arch is None` case. Just use the provided `cnn_extractor` when `feature_extraction == "cnn"` to preprocess the input images into a flattened layer, which would then pass through the `net_arch` layers.