tensorflow / similarity

TensorFlow Similarity is a python package focused on making similarity learning quick and easy.
Apache License 2.0

TPU Error while trying to train SimCLR contrastive model on Colab #240

Open kechan opened 2 years ago

kechan commented 2 years ago

I followed the hello world notebook closely, added the required TPU setup, and tried to run it on Google Colab.
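For context, the TPU setup is essentially the standard Colab TPU boilerplate (a sketch only; the exact cell in my notebook may differ slightly):

import tensorflow as tf

# Standard Colab TPU initialization (TF 2.x)
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')  # auto-detect the Colab TPU
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # build the contrastive_model here, as in the hello world notebook
    ...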

Got some errors at contrastive_model.fit(...):

[/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py] in _numpy(self)
   1189       return self._numpy_internal()
   1190     except core._NotOkStatusException as e:  # pylint: disable=protected-access
-> 1191       raise core._status_to_exception(e) from None  # pylint: disable=protected-access
   1192 
   1193   @property

InvalidArgumentError: 9 root error(s) found.
  (0) INVALID_ARGUMENT: {{function_node __inference_train_function_128376}} Input 1 to node `StatefulPartitionedCall/simclr/PartitionedCall/one_hot` with op OneHot must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

     [[{{function_node __forward_call_76175}}{{node one_hot}}]]
     [[TPUReplicate/_compile/_15029931604347861884/_4]]
  (1) INVALID_ARGUMENT: {{function_node __inference_train_function_128376}} Input 1 to node `StatefulPartitionedCall/simclr/PartitionedCall/one_hot` with op OneHot must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

     [[{{function_node __forward_call_76175}}{{node one_hot}}]]
     [[TPUReplicate/_compile/_15029931604347861884/_4]]
     [[tpu_compile_succeeded_assert/_16637135070362013654/_5/_63]]
  (2) INVALID_ARGUMENT: {{function_node __inference_train_function_128376}} Input 1 to node `StatefulPartitionedCall/simclr/PartitionedCall/one_hot` with op OneHot must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

     [[{{function_node __forward_call_76175}}{{node one_hot}}]]
     [[TPUReplicate/_compile/_15029931604347861884/_4]]
     [[tpu_compile_succeeded_assert/_16637135070362013654/_5/_47]]
  (3) INVALID_ARGUMENT: {{function_node __inference_train_function_128376}} Input 1 to node `StatefulPartitionedCall/simclr/PartitionedCall/one_hot` with op OneHot must be a compile-time constant.

XLA compilation requires that operator arguments that represent shapes or dimensions be evaluated to concrete values at compile time. This error means that a shape or dimension argument could not be evaluated at compile time, usually because the value of the argument depends on a parameter to the computation, on a variable, or on a stateful operation such as a random number generator.

     [[{{function_node __forward_call_76175}}{{node one_hot}}]]
     [[TPUReplicate/_compile/_15029931604347861884/_4]]
     [[tpu_compile_succeeded_assert/_16637135070362013654/_5/_143]]
  (4) INVALID_ARGUMENT: {{function_node __inference_ ... [truncated]

Best solution so far (TL;DR for this issue):

Doing this to your dataset solves it:

.batch(BATCH_SIZE, drop_remainder=True)

If you are interested, see my notebook here.
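In context, the fix is just the final batching step of the input pipeline. A minimal sketch, assuming a tf.data pipeline along the lines of the hello world notebook (the toy data and names here are placeholders):

import tensorflow as tf

BATCH_SIZE = 256  # global batch size across all TPU cores

# placeholder pipeline; substitute your own source (TFDS, TFRecords, ...)
train_ds = (
    tf.data.Dataset.from_tensor_slices(tf.zeros([1024, 32, 32, 3]))
    .shuffle(1024)
    .batch(BATCH_SIZE, drop_remainder=True)  # key change: every batch gets a static first dimension
    .prefetch(tf.data.AUTOTUNE)
)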

kechan commented 2 years ago

I think this is pretty significant for SimCLR, since TPUs can dramatically speed up training and allow the large batch sizes these methods require without exhausting accelerator memory too soon.

owenvallis commented 2 years ago

I think it doesn't like the use of tf.range in the one_hot call. I could see how that would be stateful if it is implemented as a generator. Can you try the following instead?

diag = tf.linalg.diag(tf.ones(batch_size))

....

labels = tf.linalg.diag(tf.ones(batch_size), num_cols=batch_size*2)
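For a static batch_size the two formulations should produce the same matrices; a quick sketch of what I mean (assuming the loss currently builds its labels/masks with one_hot over tf.range):

import tensorflow as tf

batch_size = 4  # static here, just to show the equivalence

# current style: one-hot rows built from tf.range
masks_onehot = tf.one_hot(tf.range(batch_size), batch_size)
labels_onehot = tf.one_hot(tf.range(batch_size), batch_size * 2)

# suggested style: identity-like matrices built with tf.linalg.diag
masks_diag = tf.linalg.diag(tf.ones(batch_size))
labels_diag = tf.linalg.diag(tf.ones(batch_size), num_cols=batch_size * 2)

print(tf.reduce_all(masks_onehot == masks_diag).numpy())    # True
print(tf.reduce_all(labels_onehot == labels_diag).numpy())  # True
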
kechan commented 2 years ago

@owenvallis Yes, this seems to make sense. Let me make your suggested change and try it.

kechan commented 2 years ago

Unfortunately, tf.linalg.diag still raises an error:

StatefulPartitionedCall/simclr/PartitionedCall/diag_1 with op MatrixDiagV3 must be a compile-time constant.

Actually, batch_size is also not known at compile time. Maybe the next thing I can try is to just hardcode it.

I am now a bit surprised at the broader TPU/TF limitation this seems to imply. I have read about a more recent framework, JAX, which is supposed to run on TPU or GPU without code changes. I haven't explored JAX in depth, but it would be a major inconvenience, to say the least, if the same issue shows up there. I am just not sure at which layer the limitation lives.

kechan commented 2 years ago

@owenvallis

I hardcoded batch_size to 32 (for testing). The next error I got was that it expected 4x4 (while 32x32 was passed) in "aa = aa - diag * LARGE_NUM".

Now I realize Colab assigned me 8 TPU cores, so I am guessing it distributed the batch evenly across them, resulting in 4 per core. So I changed the hardcoded value to 4, and the TPU is finally happy. Since my batch_size and number of cores are both unlikely to change, I can probably live with hardcoding them as constants. I will give it a proper round of training and hope there are no more errors.

But this is too much hacking around. If you have a better solution, please let me know.
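For anyone hitting the same 32-vs-4 surprise: the TPU strategy splits the global batch evenly across the replicas, so the loss only sees the per-replica batch. The arithmetic, as a sketch (strategy here is the tf.distribute.TPUStrategy created during setup):

GLOBAL_BATCH_SIZE = 32

# Colab's TPU v2-8 exposes 8 replicas, so each core sees 32 // 8 = 4 examples
per_replica_batch_size = GLOBAL_BATCH_SIZE // strategy.num_replicas_in_sync
print(per_replica_batch_size)  # 4 on an 8-core TPU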

kechan commented 2 years ago

OK, I also needed

.batch(BATCH_SIZE, drop_remainder=True)

to ensure the batch size stays the same throughout training. After that, training proceeded without any errors (including validation and saving checkpoints back to GCS).

Despite all the errors and apparent limitations, the convenience and cost-effectiveness of running large batches faster than a GPU on Colab make this worth a try. Feel free to share your experience on this thread if you want to try it.

@owenvallis I can gladly share a version of my notebook later (removing anything sensitive, of course), following the hello world notebook format. Let me know.

owenvallis commented 2 years ago

Glad to hear you have a workaround for now. I can run some TPU tests next week. If you can share your notebook, that would be super helpful.

Thanks!

kechan commented 2 years ago

@owenvallis Running on TPU seems to require the tf.data dataset to come directly from TFRecords stored on GCS (as far as I know), so instantiating it from .from_tensor_slices(x_train) is likely not to work. This will involve some uglier boilerplate code to parse the TFRecords and construct a tf.data.Dataset. I plan to use this opportunity to consolidate all that into a friendly helper utility that hides the lower-level details, so sharing the notebook will take a bit longer.

Update: I could be wrong about this; maybe it doesn't care where the data comes from if it is in memory, like x_train is. Let me give this a try. However, in my own work my dataset won't fit in memory, so I have to use TFRecords stored in a cloud bucket.

kechan commented 2 years ago

It looks like the TPU is not friendly with TFDS + Colab: it refuses to read any tf dataset from a local directory. So in order to reuse the same CIFAR10 dataset, I will have to deal with some TFRecord manipulation. I think I will just place the final notebook plus the helper Python code in a new GitHub project and share it.

kechan commented 2 years ago

I actually found out that none of the hardcoding hacks mentioned above for the loss class are needed, provided you do this to your dataset:

.batch(BATCH_SIZE, drop_remainder=True)

So the compiler presumably recognizes that each batch must have exactly BATCH_SIZE examples, and therefore treats the batch dimension as known at XLA compilation time.
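A quick way to see the difference: with drop_remainder=True the batch dimension in the dataset's element_spec becomes a concrete number instead of None. A minimal sketch on a toy dataset:

import tensorflow as tf

ds = tf.data.Dataset.range(100).map(lambda i: tf.zeros([32, 32, 3]))

print(ds.batch(8).element_spec.shape)                       # (None, 32, 32, 3) -> dynamic batch dim
print(ds.batch(8, drop_remainder=True).element_spec.shape)  # (8, 32, 32, 3)    -> static batch dim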

kechan commented 2 years ago

Here's my TPU notebook on GitHub, as promised. It turned out no special TFRecord handling was needed: you can direct TFDS to download to a cloud bucket, and that solved it.
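Concretely, the part that made TFDS play nicely with the TPU was pointing data_dir at a GCS bucket; a minimal sketch (the bucket name is a placeholder):

import tensorflow_datasets as tfds

# download/prepare CIFAR-10 directly into a GCS bucket the TPU workers can read from
train_ds = tfds.load(
    'cifar10',
    split='train',
    as_supervised=True,
    data_dir='gs://your-bucket/tensorflow_datasets',  # placeholder bucket
)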

@owenvallis Please let me know if it works for you.