ericjang / gumbel-softmax

categorical variational autoencoder using the Gumbel-Softmax estimator
MIT License

Why should hard sampling use stop_gradient? #12

Open kelvinleen opened 4 years ago

kelvinleen commented 4 years ago

```python
def gumbel_softmax(logits, temperature, hard=False):
  """Sample from the Gumbel-Softmax distribution and optionally discretize.

  Args:
    logits: [batch_size, n_class] unnormalized log-probs
    temperature: non-negative scalar
    hard: if True, take argmax, but differentiate w.r.t. soft sample y
  Returns:
    [batch_size, n_class] sample from the Gumbel-Softmax distribution.
    If hard=True, then the returned sample will be one-hot, otherwise it will
    be a probability distribution that sums to 1 across classes
  """
  y = gumbel_softmax_sample(logits, temperature)
  if hard:
    k = tf.shape(logits)[-1]
    # y_hard = tf.cast(tf.one_hot(tf.argmax(y, 1), k), y.dtype)
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, 1, keep_dims=True)), y.dtype)
    y = tf.stop_gradient(y_hard - y) + y
  return y
```
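
For reference, this is roughly how I am calling it (a minimal sketch, assuming TF 1.x graph mode as in the repo, with `gumbel_softmax_sample` from the same file in scope; the logits values here are made up):

```python
logits = tf.constant([[1.0, 2.0, 0.5]])  # hypothetical toy logits

y_soft = gumbel_softmax(logits, temperature=0.5, hard=False)  # sums to 1 over classes
y_hard = gumbel_softmax(logits, temperature=0.5, hard=True)   # one-hot in the forward pass

with tf.Session() as sess:
    print(sess.run([y_soft, y_hard]))
```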

I have a question here: why does there need to be a stop_gradient wrapped around y_hard - y? y_hard comes from tf.equal, and as I see it, the equal function could be backpropagated through just like the max function. Am I wrong?
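
To make the question concrete, this is the kind of gradient check I have in mind (a minimal sketch only, assuming TF 2.x eager mode with tf.GradientTape; a plain temperature softmax stands in for gumbel_softmax_sample, so there is no Gumbel noise, and the logits/weights values are made up):

```python
import tensorflow as tf

logits = tf.Variable([[1.0, 2.0, 0.5]])   # hypothetical toy logits
temperature = 0.5
weights = tf.constant([[1.0, 2.0, 3.0]])  # arbitrary downstream weights

with tf.GradientTape(persistent=True) as tape:
    # Soft sample; plain temperature softmax stands in for gumbel_softmax_sample here.
    y = tf.nn.softmax(logits / temperature)
    # Hard one-hot via equal + cast, as in the posted code (TF 2 spells it keepdims).
    y_hard = tf.cast(tf.equal(y, tf.reduce_max(y, axis=1, keepdims=True)), y.dtype)
    # Straight-through combination from the posted code.
    y_st = tf.stop_gradient(y_hard - y) + y
    loss_hard = tf.reduce_sum(y_hard * weights)  # uses only the equal/cast path
    loss_st = tf.reduce_sum(y_st * weights)      # uses the stop_gradient combination

print(tape.gradient(loss_hard, logits))  # gradient through the equal/cast path alone
print(tape.gradient(loss_st, logits))    # gradient through the straight-through version
```

Comparing these two printed gradients is what I mean by asking whether the equal path can be backpropagated on its own.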