zhen8838 / Circle-Loss

Tensorflow2 implementation of CircleLoss. Support class-level, sparse class-level, pair-wise labels
MIT License

potential bug in batch #1

Open AddASecond opened 4 years ago

AddASecond commented 4 years ago

When I use Keras with TensorFlow's ImageDataGenerator, the following shape mismatch occurs during model.fit. My batch size is 128.

tensorflow.python.framework.errors_impl.InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  ConcatOp : Dimensions of inputs should match: shape[0] = [128,1] vs. shape[1] = [117,1]
     [[node loss/dense_40_loss/concat (defined at media/src/models/circle_loss.py:129) ]]
     [[Shape_12/_148]]
  (1) Invalid argument:  ConcatOp : Dimensions of inputs should match: shape[0] = [128,1] vs. shape[1] = [117,1]
     [[node loss/dense_40_loss/concat (defined at mediar/src/models/circle_loss.py:129) ]]

zhen8838 commented 4 years ago

To speed up the loss calculation, I set batch_idxs to a constant tensor built from the fixed batch_size, so it cannot adapt when a batch has fewer samples:

    if batch_size:
      self.batch_size = batch_size
      self.batch_idxs = tf.expand_dims(
          tf.range(0, batch_size, dtype=tf.int32), 1)  # shape [batch,1]
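A likely explanation for the [128,1] vs. [117,1] mismatch in the report is that the generator yields a smaller final batch whenever the number of samples is not a multiple of the batch size, while the constant index tensor always has 128 rows. The sketch below illustrates this in plain Python. The dataset size of 1013 is hypothetical, chosen only so that the remainder is 117, and `batch_sizes` is an illustrative helper, not part of this repository.

```python
def batch_sizes(num_samples, batch_size):
    """Return the size of every batch in one pass over the data."""
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)  # final, incomplete batch
    return sizes


# Hypothetical dataset size (assumption): 7 full batches plus 117 samples.
sizes = batch_sizes(num_samples=1013, batch_size=128)
print(sizes)
```

Every batch except the last one matches the constant index tensor, which would explain why the error only shows up partway through model.fit.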

If you need to support a dynamic batch size, you can use the following code:

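# Assumes `import tensorflow as tf` and this repository's CircleLoss are in scope.
class SparseCircleLoss(CircleLoss):

  def call(self, y_true: tf.Tensor, y_pred: tf.Tensor) -> tf.Tensor:
    # Build the row indices from the runtime batch size instead of a constant.
    batch_size = tf.shape(y_true)[0]
    batch_idxs = tf.expand_dims(
        tf.range(0, batch_size, dtype=tf.int32), 1)  # shape [batch, 1]
    # (row, class) index pairs; y_true holds sparse labels of shape [batch, 1].
    idxs = tf.concat([batch_idxs, tf.cast(y_true, tf.int32)], 1)
    # Similarity of each sample to its own class, shape [batch, 1].
    sp = tf.expand_dims(tf.gather_nd(y_pred, idxs), 1)
    # Mask that is True everywhere except at each sample's own class.
    mask = tf.logical_not(
        tf.scatter_nd(idxs, tf.ones(tf.shape(idxs)[0], tf.bool),
                      tf.shape(y_pred)))
    # Similarities to all other classes, one row per sample.
    sn = tf.reshape(tf.boolean_mask(y_pred, mask), (batch_size, -1))

    # Weighting factors; gradients do not flow through them.
    alpha_p = tf.nn.relu(self.O_p - tf.stop_gradient(sp))
    alpha_n = tf.nn.relu(tf.stop_gradient(sn) - self.O_n)

    r_sp_m = alpha_p * (sp - self.Delta_p)
    r_sn_m = alpha_n * (sn - self.Delta_n)
    _Z = tf.concat([r_sn_m, r_sp_m], 1)
    _Z = _Z * self.gamma
    # Log-sum-exp over all scaled similarities (negatives and the positive).
    logZ = tf.math.reduce_logsumexp(_Z, 1, keepdims=True)
    # Subtract the scaled positive term from the log-sum.
    return -r_sp_m * self.gamma + logZ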
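
To make the last few lines of the loss easier to follow, here is a minimal plain-Python sketch of the same arithmetic for a single sample. It is not the repository's code. The function name `circle_loss_row`, the default values of `gamma` and `margin`, and the way O_p, O_n, Delta_p and Delta_n are derived from the margin (taken from the Circle Loss paper's definitions) are all assumptions made for illustration.

```python
import math


def circle_loss_row(sp, sn, gamma=64.0, margin=0.25):
    """Loss for one sample: sp is the positive similarity, sn the negatives."""
    # Assumed definitions, following the Circle Loss paper.
    O_p = 1.0 + margin
    O_n = -margin
    Delta_p = 1.0 - margin
    Delta_n = margin

    # Weighting factors (treated as constants in the TensorFlow version).
    alpha_p = max(O_p - sp, 0.0)
    alpha_n = [max(s - O_n, 0.0) for s in sn]

    r_sp_m = alpha_p * (sp - Delta_p)
    r_sn_m = [a * (s - Delta_n) for a, s in zip(alpha_n, sn)]

    # Numerically stable log-sum-exp over negatives plus the positive term.
    z = [gamma * r for r in r_sn_m + [r_sp_m]]
    z_max = max(z)
    log_z = z_max + math.log(sum(math.exp(v - z_max) for v in z))

    # Same form as the return statement above: -gamma * r_sp_m + logZ.
    return -gamma * r_sp_m + log_z


loss = circle_loss_row(sp=0.9, sn=[0.1, 0.2])
print(loss)
```

Because the positive term is included in the log-sum and then subtracted, the result equals log(1 + sum_j exp(gamma * r_sn_j) * exp(-gamma * r_sp)), which is the single-positive form of the circle loss.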