miyosuda / async_deep_reinforce

Asynchronous Methods for Deep Reinforcement Learning
Apache License 2.0
592 stars 192 forks

Would like to convert this to Keras but having issues #26

Open rnunziata opened 7 years ago

rnunziata commented 7 years ago

I would like to convert this to Keras but I'm having issues... has anyone tried this?


 # Module-level imports this snippet assumes (Keras 1.x API):
 #   import tensorflow as tf
 #   from keras import backend as K
 #   from keras.layers import Input, Convolution2D, Flatten, Dense
 #   from keras.models import Model
 with tf.device(self._device):

        # Weight initializers to try: 'uniform', 'glorot_normal',
        # 'glorot_uniform', 'lecun_uniform', or a custom function such as
        # lambda shape, name: normal(shape, scale=0.01, name=name)
        init = 'lecun_uniform'

        K.set_image_dim_ordering('tf')
        self.s = Input(shape=(NETWORK_IMAGE_HEIGHT, NETWORK_IMAGE_WIDTH, NETWORK_IMAGE_CHANNELS))

        K.set_learning_phase(learning_phase)

        # Convolution2D args (Keras 1.x): nb_filter, nb_row, nb_col;
        # subsample = stride, border_mode = padding.
        shared = Convolution2D(16, 8, 8, name="conv1", subsample=(4, 4), activation='relu', border_mode='same', init=init)(self.s)
        shared = Convolution2D(32, 4, 4, name="conv2", subsample=(2, 2), activation='relu', border_mode='same', init=init)(shared)

        shared = Flatten()(shared)
        shared = Dense(name="h1", output_dim=256, activation='relu', init=init)(shared)

        self.pi = Dense(name="p", output_dim=action_size, activation='softmax', init=init)(shared)

        self.v = Dense(name="v", output_dim=1, activation='linear', init=init)(shared)

        self.policy_network = Model(input=self.s, output=self.pi)
        self.value_network  = Model(input=self.s, output=self.v)

        self.p_params = self.policy_network.trainable_weights
        self.v_params = self.value_network.trainable_weights

        # self.pi and self.v are already the symbolic outputs of the shared
        # graph, so use them directly rather than re-applying the models
        # to the same input tensor.
        self.p_out = self.pi
        self.v_out = tf.reshape(self.v, [-1])

  def run_policy_and_value(self, sess, s_t):
    # Evaluate the policy and value heads in a single session run.
    pi_out, v_out = sess.run([self.p_out, self.v_out], feed_dict={self.s: [s_t]})
    return (pi_out[0], v_out[0])

  def run_policy(self, sess, s_t):
    probs = self.p_out.eval(session=sess, feed_dict={self.s: [s_t]})[0]
    return probs

  def run_value(self, sess, s_t):
    values = self.v_out.eval(session=sess, feed_dict={self.s: [s_t]})[0]
    return values

  def get_vars(self):
    # The policy and value models share the conv/dense trunk, so return
    # the union of both weight lists; returning only one list misses the
    # other head's weights when syncing.
    return self.p_params + [v for v in self.v_params if v not in self.p_params]
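
The part this conversion still needs is attaching the A3C loss to the Keras-built tensors. Below is a minimal sketch of a prepare_loss method in the spirit of the original repo's actor-critic loss; the placeholder names a, td, and r, the entropy_beta argument, and the reuse of action_size are my assumptions, not code from the repo or the snippet above.


  def prepare_loss(self, entropy_beta):
    # Placeholders for the A3C loss (names are assumptions mirroring the
    # standard actor-critic formulation):
    self.a = tf.placeholder("float", [None, action_size])  # one-hot chosen action; action_size as in __init__
    self.td = tf.placeholder("float", [None])               # advantage estimate (R - V)
    self.r = tf.placeholder("float", [None])                # discounted return R

    # Clip probabilities to avoid log(0) and NaN gradients.
    log_pi = tf.log(tf.clip_by_value(self.p_out, 1e-20, 1.0))

    # Entropy bonus encourages exploration.
    entropy = -tf.reduce_sum(self.p_out * log_pi, reduction_indices=1)

    # Policy gradient loss: -log pi(a|s) * advantage, plus entropy term.
    policy_loss = -tf.reduce_sum(
        tf.reduce_sum(log_pi * self.a, reduction_indices=1) * self.td
        + entropy * entropy_beta)

    # Value loss: (1/2) * sum of squared errors against the return R.
    value_loss = tf.nn.l2_loss(self.r - self.v_out)

    self.total_loss = policy_loss + value_loss

With self.total_loss defined, the trainer thread can build gradients against get_vars() just as the TensorFlow version does.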