google-research / simclr

SimCLRv2 - Big Self-Supervised Models are Strong Semi-Supervised Learners
https://arxiv.org/abs/2006.10029
Apache License 2.0

Error in Generalized contrastive loss #162

Open aakash-saboo opened 2 years ago

aakash-saboo commented 2 years ago

Hi. Thank you so much for your work and its implementation. I have been trying to evaluate the generalized contrastive loss between random normal vectors. The code works for small batches but fails for large ones (say 100): the last line of the "sort" function raises an error saying the index passed to the "gather" call is out of bounds:

x = tf.gather(x, tf.cast(rank_inv, tf.int32), axis=-1, batch_dims=-1)

The command I use is the following:

import numpy as np
import tensorflow as tf

for i in range(0, 100):
    np_arr = np.random.randn(512, 256)
    tf_x = tf.cast(tf.convert_to_tensor(np_arr), dtype=tf.float32)
    print(generalized_contrastive_loss(tf_x[0:256, :], tf_x[256:, :], hidden_norm=False, dist='normal'))
    # print(sort(tf_x))

traceback:


InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-22-0f3921532cb1> in <module>()
      4     tf_x = tf.cast(tf.convert_to_tensor(np_arr), dtype=tf.float32)
      5     # print(generalized_contrastive_loss(torch_x,hidden_norm=True,dist='logsumexp'))
----> 6     print(generalized_contrastive_loss(tf_x[0:50,:],tf_x[50:,:],hidden_norm=False,dist='normal'))
      7     # print(sort(tf_x))

10 frames
<ipython-input-2-777f2f5008ad> in generalized_contrastive_loss(hidden1, hidden2, lambda_weight, temperature, dist, hidden_norm, loss_scaling)
     31     loss_dist_match = get_swd_loss(hiddens, rand_w,
     32                             prior=dist,
---> 33                             hidden_norm=hidden_norm)
     34   return loss_scaling * (loss_align + lambda_weight * loss_dist_match)
     35 

<ipython-input-2-777f2f5008ad> in get_swd_loss(states, rand_w, prior, stddev, hidden_norm)
     59   states_shape = tf.shape(states)
     60   states = tf.matmul(states, rand_w)
---> 61   states_t = sort(tf.transpose(states))  # (dim, bsz)
     62 
     63   if prior == 'normal':

<ipython-input-2-777f2f5008ad> in sort(x)
     52       tf.transpose(tf.cast(tf.one_hot(rank, xshape[1]), tf.float32), [0, 2, 1]),
     53       tf.range(xshape[1], dtype='float32'))  # (dim, bsz)
---> 54   x = tf.gather(x, tf.cast(rank_inv, tf.int32), axis=-1, batch_dims=-1)
     55   return x
     56 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    204     """Call target, and fall back on dispatchers if there is a TypeError."""
    205     try:
--> 206       return target(*args, **kwargs)
    207     except (TypeError, ValueError):
    208       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py in gather_v2(params, indices, validate_indices, axis, batch_dims, name)
   5073       name=name,
   5074       axis=axis,
-> 5075       batch_dims=batch_dims)
   5076 
   5077 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    547                 'in a future version' if date is None else ('after %s' % date),
    548                 instructions)
--> 549       return func(*args, **kwargs)
    550 
    551     doc = _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/dispatch.py in wrapper(*args, **kwargs)
    204     """Call target, and fall back on dispatchers if there is a TypeError."""
    205     try:
--> 206       return target(*args, **kwargs)
    207     except (TypeError, ValueError):
    208       # Note: convert_to_eager_tensor currently raises a ValueError, not a

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/array_ops.py in gather(***failed resolving arguments***)
   5050   if tensor_util.constant_value(axis) != 0:
   5051     return gen_array_ops.gather_v2(
-> 5052         params, indices, axis, batch_dims=batch_dims, name=name)
   5053   try:
   5054     # TODO(apassos) find a less bad way of detecting resource variables

/usr/local/lib/python3.7/dist-packages/tensorflow/python/ops/gen_array_ops.py in gather_v2(params, indices, axis, batch_dims, name)
   3806       return _result
   3807     except _core._NotOkStatusException as e:
-> 3808       _ops.raise_from_not_ok_status(e, name)
   3809     except _core._FallbackException:
   3810       pass

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6939   message = e.message + (" name: " + name if name is not None else "")
   6940   # pylint: disable=protected-access
-> 6941   six.raise_from(core._status_to_exception(e.code, message), None)
   6942   # pylint: enable=protected-access
   6943 

/usr/local/lib/python3.7/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: indices[20,2] = 130 is not in [0, 100) [Op:GatherV2]

aakash-saboo commented 2 years ago

Hi

This error will occur whenever the "sort" function hits a tie. As currently defined, "sort" only works when NO two elements are greater than the SAME number of other elements, i.e. when all the elements in any row of the input matrix are distinct. (A small demonstration follows the snippet below.)

def sort(x):
  """Returns the matrix x where each row is sorted (ascending)."""
  xshape = tf.shape(x)
  rank = tf.reduce_sum(
      tf.cast(tf.expand_dims(x, 2) > tf.expand_dims(x, 1), tf.int32), axis=2) 
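
To make the failure mode concrete, here is a minimal sketch of the tie problem (my own reconstruction; the rank_inv line is cut off in the traceback, so tf.linalg.matvec below stands in for whatever op the notebook actually uses):

import tensorflow as tf

# One row with a tie at indices 1 and 2: both 1.0 entries are greater than
# exactly one other element, so they share rank 1 and rank 2 is never assigned.
x = tf.constant([[0.5, 1.0, 1.0]])
rank = tf.reduce_sum(
    tf.cast(tf.expand_dims(x, 2) > tf.expand_dims(x, 1), tf.int32), axis=2)
print(rank)  # [[0 1 1]]

# Inverting the "permutation" via one-hot rows: both tied elements land in the
# rank-1 row, so their indices are summed (1 + 2 = 3), while the empty rank-2
# row collapses to 0.
one_hot_t = tf.transpose(tf.cast(tf.one_hot(rank, 3), tf.float32), [0, 2, 1])
rank_inv = tf.linalg.matvec(one_hot_t, tf.range(3, dtype=tf.float32))
print(rank_inv)  # [[0. 3. 0.]]

# Index 3 is outside [0, 3), so this raises the same InvalidArgumentError.
x = tf.gather(x, tf.cast(rank_inv, tf.int32), axis=-1, batch_dims=-1)

This would also explain the number in the original error: with a batch of 100, indices[20,2] = 130 looks like the sum of two tied indices (e.g. 58 + 72) rather than a real position.
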
chentingpc commented 2 years ago

The sort function was written this way so it could run on TPU. If you just use GPU, you should be able to make do with tf.sort, which probably doesn't have this problem.
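
For anyone else hitting this on GPU/CPU, a drop-in sketch of that substitution (my own, not from the repo):

import tensorflow as tf

def sort(x):
  """Returns the matrix x where each row is sorted (ascending).

  tf.sort handles duplicate values correctly, unlike the one-hot/gather
  trick, at the cost of using an op that the original code avoided for TPU.
  """
  return tf.sort(x, axis=-1, direction='ASCENDING')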

aakash-saboo commented 2 years ago

@chentingpc Thank you so much for your response. Yes, tf.sort is working just fine.

But I liked the implementation of the sort function. Can you tell me whether this exact implementation was used on TPU? If so, I am wondering how this problem did not come up on TPU. Also, the function works fine in float64 instead of float32.

Thanks Aakash

chentingpc commented 2 years ago

This is the code that would run on TPU, and we didn't experience this issue. I wonder if there's any collapsing in the input features that leads to identical numbers in the output. If that's the case, you can try adding a tiny bit of random noise to the input to break the ties.
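
A minimal sketch of that tie-breaking idea (my own; the helper name and the epsilon are arbitrary choices, not from the repo):

import tensorflow as tf

def break_ties(x, eps=1e-6):
  """Adds tiny uniform noise so no two entries in a row compare equal."""
  return x + tf.random.uniform(tf.shape(x), minval=-eps, maxval=eps, dtype=x.dtype)

# e.g. perturb the hiddens right before they reach sort():
# states = break_ties(states)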

aakash-saboo commented 2 years ago

@chentingpc Thanks a lot! Will try that out.