Burstaholic opened this issue 5 years ago
I had a similar problem with DeepQ and I made some changes: a mask parameter is passed to the function, and then the q_values are re-defined. This is not a perfect method, but it works for my problems. Hopefully OpenAI can provide a better solution later. This is the code:
observations_ph = make_obs_ph("observation")
stochastic_ph = tf.placeholder(tf.bool, (), name="stochastic")
update_eps_ph = tf.placeholder(tf.float32, (), name="update_eps")
mask_ph = tf.placeholder(shape=(1, num_actions), dtype=tf.int64, name="mask")
eps = tf.get_variable("eps", (), initializer=tf.constant_initializer(0))
q_values = q_func(observations_ph.get(), num_actions, scope="q_func")
small_constant = tf.constant(-999, dtype=tf.float32, shape=(1, num_actions))
zeros = tf.zeros(shape=tf.shape(q_values), dtype=tf.int64)
masked_q_values = tf.where(tf.math.equal(zeros, mask_ph), q_values, small_constant)
deterministic_actions = tf.argmax(masked_q_values, axis=1)
batch_size = tf.shape(observations_ph.get())[0]
random_actions = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=num_actions, dtype=tf.int64)
chose_random = tf.random_uniform(tf.stack([batch_size]), minval=0, maxval=1, dtype=tf.float32) < eps
stochastic_actions = tf.where(chose_random, random_actions, deterministic_actions)
output_actions = tf.cond(stochastic_ph, lambda: stochastic_actions, lambda: deterministic_actions)
update_eps_expr = eps.assign(tf.cond(update_eps_ph >= 0, lambda: update_eps_ph, lambda: eps))
# note: the [[0, 0, 0]] defaults below hardcode num_actions == 3
_act = U.function(inputs=[observations_ph, stochastic_ph, update_eps_ph, mask_ph],
                  outputs=output_actions,
                  givens={update_eps_ph: -1.0, stochastic_ph: True, mask_ph: [[0, 0, 0]]},
                  updates=[update_eps_expr])
def act(ob, stochastic=True, update_eps=-1, mask=[[0, 0, 0]]):
    return _act(ob, stochastic, update_eps, mask)
return act
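The core trick in the snippet above can be shown in isolation. This is a minimal NumPy sketch (not part of baselines; `masked_argmax` is a name I made up for illustration): entries where the mask is nonzero are treated as illegal and replaced with a large negative constant before the argmax, so they can never be selected.

```python
import numpy as np

def masked_argmax(q_values, mask, neg=-999.0):
    """Pick the greedy action among legal ones.

    q_values: (batch, num_actions) array of Q-values.
    mask:     (batch, num_actions) array, 0 = legal, 1 = illegal
              (same convention as the tf.where call above).
    """
    # Where mask == 0 keep the Q-value, otherwise substitute a large
    # negative constant so illegal actions lose the argmax.
    masked = np.where(mask == 0, q_values, neg)
    return np.argmax(masked, axis=1)

q = np.array([[0.5, 2.0, 1.0]])
mask = np.array([[0, 1, 0]])  # action 1 is illegal
print(masked_argmax(q, mask))  # -> [2]
```

Note that `-999` only works if it is below any Q-value the network can produce; a very large negative number (or `-np.inf` outside of gradient paths) is safer.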
Could you show all the detailed changes needed to implement that action masking in deepq, including how to pass the mask vector from the env to the agent? I have been stuck on this issue for a while... Your response would be greatly appreciated!
I'd like to try out some of these algorithms in board game environments, which require masking and re-normalization of the action space to handle legal vs. illegal moves.
Is there a good way to do this in the provided implementations? If not, I'd like to make this a feature request. I think it shouldn't be too difficult to implement, and would make these easily usable in a large additional class of environments.
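For policy-based methods, the mask + re-normalization mentioned above is usually done on the logits rather than the probabilities. A hedged NumPy sketch of that idea (not from baselines; `masked_policy` is an illustrative name): illegal actions get `-inf` logits, so the softmax assigns them exactly zero probability and the legal probabilities renormalize automatically.

```python
import numpy as np

def masked_policy(logits, legal):
    """Softmax over legal actions only.

    logits: (num_actions,) raw network outputs.
    legal:  (num_actions,) array, 1 = legal move, 0 = illegal.
    """
    # Illegal logits become -inf, so exp(-inf) == 0 drops them from
    # both the numerator and the normalizing sum.
    masked = np.where(legal == 1, logits, -np.inf)
    exp = np.exp(masked - masked.max())  # subtract max for stability
    return exp / exp.sum()

logits = np.array([1.0, 2.0, 3.0])
legal = np.array([1, 0, 1])  # move 1 is illegal
probs = masked_policy(logits, legal)
# probs[1] is exactly 0 and probs sums to 1
```

This is only a sketch of the re-normalization step; wiring it into a specific baselines algorithm would still require passing the legality vector through the observation or an extra placeholder, as in the DeepQ snippet earlier in the thread.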
For example, I tried writing a simple Tic-Tac-Toe player with DeepQ. I looked for a way to use the provided callback to handle action legality, but couldn't figure out a good way to do it. Any suggestions are of course welcome. Thanks!