renmengye / few-shot-ssl-public

Meta Learning for Semi-Supervised Few-Shot Classification

What does m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) mean? #16

Status: Open. jinghanSunn opened this issue 4 years ago.

jinghanSunn commented 4 years ago

In the clustering step, I don't understand this code:

```python
# Run clustering.
for tt in range(num_cluster_steps):
    protos_1 = tf.expand_dims(protos, 2)     # [B, K, 1, D]
    protos_2 = tf.expand_dims(h_unlabel, 1)  # [B, 1, N, D]
    pair_dist = tf.reduce_sum((protos_1 - protos_2)**2, [3])  # [B, K, N]
    m_dist = tf.reduce_mean(pair_dist, [2])  # [B, K]
    m_dist_1 = tf.expand_dims(m_dist, 1)  # [B, 1, K]
    m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0))
```
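To make the tensor shapes concrete, here is a minimal NumPy sketch of one clustering step up to the line in question. It assumes protos has shape [B, K, D] and h_unlabel has shape [B, N, D]; the embedding size D and the function name mean_proto_distance are hypothetical and not taken from the repository, and NumPy indexing with None stands in for tf.expand_dims.

```python
import numpy as np

def mean_proto_distance(protos, h_unlabel):
    """Hypothetical NumPy mirror of the TF snippet (a sketch, not repo code).

    protos:    [B, K, D] cluster centers (prototypes)
    h_unlabel: [B, N, D] unlabeled embeddings
    Returns m_dist_1 of shape [B, 1, K] with exact zeros replaced by 1.0.
    """
    protos_1 = protos[:, :, None, :]       # [B, K, 1, D]
    protos_2 = h_unlabel[:, None, :, :]    # [B, 1, N, D]
    # Broadcasting gives [B, K, N, D]; summing over D gives squared distances.
    pair_dist = ((protos_1 - protos_2) ** 2).sum(axis=3)  # [B, K, N]
    m_dist = pair_dist.mean(axis=2)        # [B, K] mean over unlabeled points
    m_dist_1 = m_dist[:, None, :]          # [B, 1, K]
    # Same guard as the TF line: add 1.0 only where the entry is exactly 0.0.
    m_dist_1 = m_dist_1 + (m_dist_1 == 0.0).astype(m_dist_1.dtype)
    return m_dist_1

# B=1, K=2 prototypes, N=2 unlabeled points, D=2.
protos = np.array([[[0.0, 0.0], [1.0, 1.0]]])
h_unlabel = np.array([[[1.0, 1.0], [1.0, 1.0]]])
print(mean_proto_distance(protos, h_unlabel).tolist())  # [[[2.0, 1.0]]]
```

In this example the second prototype coincides with both unlabeled points, so its mean squared distance is exactly 0.0 and the guard turns it into 1.0, while the first prototype's mean distance of 2.0 passes through unchanged.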

Does m_dist_1 += tf.to_float(tf.equal(m_dist_1, 0.0)) mean that 1 is added whenever the mean squared distance between a cluster center and the unlabeled examples is 0? If so, why add 1?

renmengye commented 4 years ago

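This is to prevent it from being zero. tf.equal(m_dist_1, 0.0) is true only where an entry is exactly 0.0, and tf.to_float turns that into 1.0 there and 0.0 everywhere else. So a 0.0 entry becomes 1.0, and every other entry is left unchanged.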
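The sketch below isolates the guard in NumPy and shows why a zero would be a problem. It assumes, without confirmation from this thread, that m_dist_1 is later used as a divisor (for example to normalize distances); the helper name guard_zero and the dist array are hypothetical.

```python
import numpy as np

def guard_zero(x):
    """Replace exact zeros in x with 1.0 (NumPy mirror of the TF line)."""
    # Like tf.to_float(tf.equal(x, 0.0)): 1.0 where x == 0.0, else 0.0.
    mask = (x == 0.0).astype(x.dtype)
    # 0.0 + 1.0 -> 1.0; every other entry gets + 0.0 and is unchanged.
    return x + mask

m_dist_1 = np.array([0.0, 0.5, 3.0])  # mean distances, one exactly zero
dist = np.array([0.0, 1.0, 6.0])      # hypothetical values to normalize

safe = guard_zero(m_dist_1)
print(safe.tolist())           # [1.0, 0.5, 3.0]
print((dist / safe).tolist())  # [0.0, 2.0, 2.0]
```

Without the guard, the first division would be 0.0 / 0.0, which gives nan (NumPy also emits a RuntimeWarning) and could propagate into later computations. An equivalent formulation is tf.where(tf.equal(x, 0.0), tf.ones_like(x), x); unlike clamping with a small epsilon, both forms change only exact zeros. In TensorFlow 2.x, tf.to_float is no longer in the top-level API, and tf.cast(mask, tf.float32) is the usual replacement.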