kechan opened this issue 2 years ago
I think this is pretty significant for SimCLR, since a TPU can dramatically speed up training and allow the large batch sizes these methods require without exhausting accelerator memory too soon.
I think it doesn't like the use of tf.range in the one_hot call. I could see where that would be stateful if it is implemented as a generator. Can you try the following instead?
diag = tf.linalg.diag(tf.ones(batch_size))
....
labels = tf.linalg.diag(tf.ones(batch_size), num_cols=batch_size*2)
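For context, the two patterns build the same matrices; it's the tf.range feeding tf.one_hot that XLA seems to object to. A quick sketch (assuming the loss builds its diag/labels the way the reference SimCLR implementation does):

```python
import tensorflow as tf

batch_size = 4  # illustrative only; in the loss this comes from the inputs

# Assumed original pattern (following the reference SimCLR loss):
# tf.range feeding tf.one_hot is what XLA appears to object to.
diag_orig = tf.one_hot(tf.range(batch_size), batch_size)
labels_orig = tf.one_hot(tf.range(batch_size), batch_size * 2)

# Suggested replacement: build the same matrices with tf.linalg.diag.
diag = tf.linalg.diag(tf.ones(batch_size))
labels = tf.linalg.diag(tf.ones(batch_size), num_cols=batch_size * 2)

# Both pairs are the same identity / padded-identity matrices.
tf.debugging.assert_near(diag_orig, diag)
tf.debugging.assert_near(labels_orig, labels)
```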
@owenvallis Yes, this seems to make sense. Let me make your suggested change and give it a try.
Unfortunately, tf.linalg.diag still produces an error:
StatefulPartitionedCall/simclr/PartitionedCall/diag_1 with op MatrixDiagV3 must be a compile-time constant.
Actually, batch_size is also not known at compile time. Maybe the next thing I can try is just hardcoding it.
I am now a bit surprised at the broader TPU/TF limitation this seems to imply. I have read about a more recent framework, JAX, that is supposed to run on TPU or GPU without code changes. I haven't explored JAX in depth, but it would be a major inconvenience, to say the least, if the same thing happens there. I am just not sure at what layer this limitation lives.
@owenvallis
I hardcoded the batch_size to 32 (for testing). The next error I got was that it expected 4x4 (while 32x32 was passed) in "aa = aa - diag * LARGE_NUM".
Now I realize Colab assigned me 8 TPU cores, so I am guessing it distributed the batch evenly across them, resulting in 4 per core. So I changed my hardcoded value to 4, and the TPU is finally happy. Since my batch_size and number of cores are both unlikely to change, I can probably live with hardcoding them as constants. I will give it a proper round of training and hope there are no more errors.
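For reference, the 4 is just the global batch divided across the replicas; strategy.num_replicas_in_sync would give the divisor without hardcoding it. A minimal arithmetic sketch with my numbers (assuming Colab's 8-core TPU):

```python
# Minimal sketch; 8 cores and a global batch of 32 are my Colab numbers.
GLOBAL_BATCH_SIZE = 32
NUM_REPLICAS = 8  # in practice: strategy.num_replicas_in_sync from the TPUStrategy

# model.fit under TPUStrategy splits each global batch evenly across the cores,
# so the per-replica tensors the loss sees have this many examples.
PER_REPLICA_BATCH_SIZE = GLOBAL_BATCH_SIZE // NUM_REPLICAS  # == 4
```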
But this is too much hacking around. If you have better solutions, please let me know.
OK, the dataset also needs .batch(BATCH_SIZE, drop_remainder=True) to ensure the batch size stays the same throughout training. After that, training proceeded without any errors (including validation and saving checkpoints back to GCS).
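For the checkpoints, pointing the Keras checkpoint callback at a gs:// path is enough, since TensorFlow's file I/O understands GCS paths. A sketch with a placeholder bucket path:

```python
import tensorflow as tf

# Placeholder bucket path; Keras/tf.io can write directly to gs:// locations.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="gs://my-bucket/simclr/ckpt-{epoch:02d}",
    save_weights_only=True,
)
```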
Despite all the errors and apparent limitations, the convenience and cost-effectiveness of running large batches faster than on a GPU in Colab make this worth a try. Feel free to share your experience on this thread if you want to try it.
@owenvallis I can gladly share a version of my notebook later (removing anything sensitive, of course), following the hello world notebook format. Let me know.
Glad to hear you have a workaround for now. I can run some TPU tests next week. If you can share your notebook, that would be super helpful.
Thanks!
@owenvallis Running on TPU seems to require the tf.data.Dataset to come directly from TFRecords stored on GCS (as far as I know), so instantiating it with .from_tensor_slices(x_train) is likely not to work. This will involve some uglier boilerplate code to parse TFRecords into a tf.data.Dataset. I plan to use this opportunity to consolidate that into a friendly helper utility that hides all the lower-level details, so sharing the notebook will take a bit longer.
Update: I could be wrong about that; maybe it doesn't care where the data comes from if it is in memory, like x_train is. Let me give this a try. However, in my own work my dataset won't fit in memory, so I have to use TFRecords stored in a cloud bucket.
It looks like the TPU is not friendly with TFDS + Colab; it refuses to read any tf.data.Dataset from a local directory. So in order to reuse the same CIFAR10 dataset, I will have to deal with some TFRecord manipulation. I think I will just place the final notebook plus the helper Python in a new GitHub project and share it.
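The helper will be roughly along these lines (a rough sketch only: the feature keys, the image encoding, and the gs:// pattern are all assumptions about how I will serialize CIFAR10, not the final utility):

```python
import tensorflow as tf

# Hypothetical TFRecord schema: encoded image bytes plus an integer label.
FEATURES = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized):
    example = tf.io.parse_single_example(serialized, FEATURES)
    image = tf.io.decode_image(example["image"], channels=3, expand_animations=False)
    image = tf.reshape(image, [32, 32, 3])  # TPU needs fully static shapes
    image = tf.cast(image, tf.float32) / 255.0
    return image, example["label"]

def dataset_from_tfrecords(pattern, batch_size):
    """e.g. pattern = 'gs://my-bucket/cifar10/train-*.tfrec' (placeholder path)."""
    files = tf.data.Dataset.list_files(pattern)
    ds = files.interleave(tf.data.TFRecordDataset, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    # drop_remainder=True keeps the batch dimension static for XLA.
    return ds.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE)
```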
I actually found out that none of the hardcoding hacks in the loss class mentioned above are needed, provided you do this to your dataset:
.batch(BATCH_SIZE, drop_remainder=True)
So the compiler presumably recognizes that each batch must have exactly BATCH_SIZE examples, and therefore treats the batch size as known at XLA compilation time.
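The difference is visible directly in the dataset's element_spec (a toy sketch, unrelated to the SimCLR data):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(100).map(lambda i: tf.fill([3], i))

print(ds.batch(32).element_spec)
# TensorSpec(shape=(None, 3), dtype=tf.int64, ...)  -> batch dim unknown at trace time

print(ds.batch(32, drop_remainder=True).element_spec)
# TensorSpec(shape=(32, 3), dtype=tf.int64, ...)    -> batch dim static, usable as an XLA compile-time constant
```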
I followed the hello world notebook closely, added the required TPU setup, and tried to run it on Google Colab.
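The TPU setup was roughly the standard Colab boilerplate, with the model built under the strategy scope (a sketch; build_contrastive_model and train_ds/val_ds are hypothetical stand-ins for the hello-world notebook's model and dataset code):

```python
import tensorflow as tf

# Standard Colab TPU boilerplate (TF 2.x names).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# The model (and its optimizer) must be created inside the strategy scope.
with strategy.scope():
    contrastive_model = build_contrastive_model()  # hypothetical stand-in for the notebook's model code

contrastive_model.fit(train_ds, validation_data=val_ds, epochs=10)  # datasets built elsewhere
```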
I got some errors at contrastive_model.fit(...):
Best solution so far (the TL;DR for this issue):
Doing this to your dataset will solve it:
.batch(BATCH_SIZE, drop_remainder=True)
If you are interested, see my notebook here.