hill-a / stable-baselines

A fork of OpenAI Baselines, implementations of reinforcement learning algorithms
http://stable-baselines.readthedocs.io/
MIT License

Training with recurrent cells from keras #161

Open ernestum opened 5 years ago

ernestum commented 5 years ago

Just in case anybody is interested, here is some example code that implements a policy based on the recurrent cells in tensorflow.python.keras.layers. Pros:

Cons:

Now that I have written down this list, it seems rather pointless to use it as it is. But I put a lot of work into it, so maybe it is a starting point for something else.

I am thinking about putting it in the custom models section of the documentation. What do you guys think?

import numpy as np

import tensorflow as tf
from stable_baselines.a2c.utils import linear
from stable_baselines.common.policies import register_policy, LstmPolicy
from tensorflow.python.keras.layers import SimpleRNNCell, LSTMCell, GRUCell

def batch_to_seq(tensor_batch, n_batch, n_steps):
    # split a flat batch [n_batch * n_steps, ...] into a python list of n_steps tensors,
    # each of shape [n_batch, ...] (n_batch is the number of parallel environments here)
    return tf.unstack(tf.transpose(tf.reshape(tensor_batch, [n_batch, n_steps, -1]), [1, 0, 2]))

def batch_to_ta(tensor_batch, n_batch, n_steps):
    # same as batch_to_seq, but returns a TensorArray for use inside tf.while_loop
    return tf.TensorArray(tf.float32, size=n_steps, dynamic_size=False).unstack(
        tf.transpose(tf.reshape(tensor_batch, [n_batch, n_steps, -1]), [1, 0, 2]))

def seq_to_batch(tensor_sequence):
    # inverse of batch_to_seq: flatten a list of [n_batch, n_hidden] tensors back into a
    # single [n_batch * n_steps, n_hidden] tensor (batch-major order)
    shape = tensor_sequence[0].get_shape().as_list()
    assert len(shape) > 1
    n_hidden = tensor_sequence[0].get_shape()[-1].value
    return tf.reshape(tf.concat(axis=1, values=tensor_sequence), [-1, n_hidden])

def ta_to_batch(tensor_array, n_hidden):
    # inverse of batch_to_ta: stack() yields [n_steps, n_batch, n_hidden], so transpose to
    # batch-major order first to stay consistent with seq_to_batch, then flatten
    return tf.reshape(tf.transpose(tensor_array.stack(), [1, 0, 2]), [-1, n_hidden])

class RecurrentPolicy(LstmPolicy):
    rnn_cell = None

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, n_lstm=128, reuse=False, cell_type=LSTMCell, unroll_network=True, **_kwargs):
        super(LstmPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse,
                                         scale=False)

        act_fun = tf.nn.relu

        # the cell is stored as a class attribute so that all policy instances (train and step
        # models) call the same cell object and therefore share its weights
        # (keras layers only create their variables on the first call)
        if RecurrentPolicy.rnn_cell is None and not reuse:
            with tf.variable_scope("model", reuse=True):
                RecurrentPolicy.rnn_cell = cell_type(n_lstm, name="lstm1")
        cell_state_size = RecurrentPolicy.rnn_cell.state_size
        if isinstance(cell_state_size, int):
            cell_state_size = [cell_state_size]

        with tf.variable_scope("input", reuse=True):
            self.masks_ph = tf.placeholder(tf.float32, [n_batch], name="masks_ph")
            self.states_ph = tf.placeholder(tf.float32, [n_env, sum(cell_state_size)], name="states_ph")

        with tf.variable_scope("model", reuse=reuse):

            # zero state of the cell, used to reset the recurrent state at episode boundaries;
            # get_initial_state returns a single tensor for GRU/SimpleRNN cells and a list for
            # LSTM cells, so normalize it to a list to match the output of tf.split below
            initial_state = RecurrentPolicy.rnn_cell.get_initial_state(batch_size=n_env, dtype=tf.float32)
            if not isinstance(initial_state, (list, tuple)):
                initial_state = [initial_state]
            else:
                initial_state = list(initial_state)
            # build shared network part consisting of parallel lstm and linear layer
            latent = tf.layers.flatten(self.processed_obs)
            state = tf.split(self.states_ph, cell_state_size, axis=1)

            if unroll_network:
                rnn_outputs = []
                masks = batch_to_seq(self.masks_ph, self.n_env, n_steps)
                input_sequence = batch_to_seq(latent, self.n_env, n_steps)
                with tf.variable_scope("my_lstm_scope"):
                    for mask, input in zip(masks, input_sequence):
                        state = mask * initial_state + (1 - mask) * state
                        output, state = RecurrentPolicy.rnn_cell(input, state)
                        rnn_outputs.append(output)
                latent = seq_to_batch(rnn_outputs)
            else:
                input_ta = batch_to_ta(latent, self.n_env, n_steps)
                mask_ta = batch_to_ta(self.masks_ph, self.n_env, n_steps)
                output_ta = tf.TensorArray(tf.float32, size=n_steps, dynamic_size=False)

                def loop_cond(idx, state, output_ta):
                    return idx < n_steps

                def loop_body(idx, state, output_ta):
                    mask = mask_ta.read(idx)
                    step_input = input_ta.read(idx)
                    # reset the state to the initial state at episode boundaries (mask == 1)
                    state = [mask * init_s + (1 - mask) * s
                             for init_s, s in zip(initial_state, state)]
                    # note: if the cell weights have not been built yet, creating them inside
                    # tf.while_loop can be problematic in graph mode
                    output, state = RecurrentPolicy.rnn_cell(step_input, state)
                    output_ta = output_ta.write(idx, output)
                    return idx + 1, state, output_ta

                idx, state, output_ta = tf.while_loop(loop_cond, loop_body, (0, state, output_ta))
                latent = ta_to_batch(output_ta, n_lstm)

            self.snew = tf.concat(state, axis=1)

            # build deeper value network part
            latent_value = act_fun(linear(latent, "vf_fc1", 128, init_scale=np.sqrt(2)))
            latent_value = act_fun(linear(latent_value, "vf_fc2", 256, init_scale=np.sqrt(2)))

            self.value_fn = linear(latent_value, 'vf', 1)
            self.proba_distribution, self.policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(latent, latent_value)
        self.initial_state = np.zeros(self.states_ph.shape.as_list(), dtype=np.float32)
        self._setup_init()

register_policy("RecurrentPolicy", RecurrentPolicy)
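
For completeness, here is a rough and untested sketch of how the registered policy could then be used for training; the A2C algorithm, the CartPole-v1 env and the hyperparameters are only placeholders:

import gym

from stable_baselines import A2C
from stable_baselines.common.vec_env import DummyVecEnv

# recurrent policies need a vectorized env with a fixed number of parallel environments (n_env)
env = DummyVecEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

# the string name is resolved through the registry filled by register_policy above
model = A2C("RecurrentPolicy", env, verbose=1)
model.learn(total_timesteps=10000)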
araffin commented 5 years ago

I am thinking about putting it in the custom models section of the documentation. What do you guys think?

It seems quite advanced to put it in the user guide. Or maybe in a new "advanced" section? However, a simple example of using a custom Keras policy could be a good idea (for instance, an MlpPolicy).
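
Something along those lines, for example (a rough, untested sketch following the custom policy pattern from the docs: the class name, layer sizes and the exact helper attributes like self._value and self.policy_proba are placeholders that may differ between versions):

import tensorflow as tf

from stable_baselines.common.policies import ActorCriticPolicy, register_policy

class KerasMlpPolicy(ActorCriticPolicy):
    """MlpPolicy-like actor-critic built from tf.keras Dense layers."""

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
        super(KerasMlpPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                                             reuse=reuse, scale=False)

        with tf.variable_scope("model", reuse=reuse):
            flat = tf.layers.flatten(self.processed_obs)

            # separate keras MLPs for the policy and the value function
            pi_latent = tf.keras.layers.Dense(64, activation=tf.nn.relu, name="pi_fc1")(flat)
            pi_latent = tf.keras.layers.Dense(64, activation=tf.nn.relu, name="pi_fc2")(pi_latent)
            vf_latent = tf.keras.layers.Dense(64, activation=tf.nn.relu, name="vf_fc1")(flat)
            vf_latent = tf.keras.layers.Dense(64, activation=tf.nn.relu, name="vf_fc2")(vf_latent)

            self.value_fn = tf.keras.layers.Dense(1, name="vf")(vf_latent)
            self.proba_distribution, self.policy, self.q_value = \
                self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent)

        self.initial_state = None
        self._setup_init()

    def step(self, obs, state=None, mask=None, deterministic=False):
        action_op = self.deterministic_action if deterministic else self.action
        action, value, neglogp = self.sess.run([action_op, self._value, self.neglogp],
                                               {self.obs_ph: obs})
        return action, value, self.initial_state, neglogp

    def proba_step(self, obs, state=None, mask=None):
        return self.sess.run(self.policy_proba, {self.obs_ph: obs})

    def value(self, obs, state=None, mask=None):
        return self.sess.run(self._value, {self.obs_ph: obs})

register_policy("KerasMlpPolicy", KerasMlpPolicy)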

ernestum commented 5 years ago

Yes, maybe there could be an advanced section. But I remember the above example being somewhat buggy, so we cannot just copy-paste it; we need to test it again first.