Currently, trying to train on multiple Cloud TPUs results in the training being stuck at 0%. It is highly likely that the issue is in the TFRecord dataset-building pipeline, but I couldn't definitively single out a root cause. Note: training works just fine on GPU.
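For reference, the kind of pipeline being debugged looks roughly like the minimal sketch below (TF 2.x with `tf.distribute.TPUStrategy`). The feature spec, file pattern, GCS bucket path, and toy model are hypothetical placeholders, not the actual code in this repo; the TPU-relevant details are `drop_remainder=True` (TPUs need static batch shapes) and reading the TFRecords from a `gs://` bucket, since the classic Colab TPU runtime cannot read local files.

```python
import tensorflow as tf

# Hypothetical feature spec; the real schema lives in the TFRecord-writing code.
FEATURE_SPEC = {
    "image": tf.io.FixedLenFeature([], tf.string),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def parse_example(serialized):
    example = tf.io.parse_single_example(serialized, FEATURE_SPEC)
    image = tf.io.decode_jpeg(example["image"], channels=3)
    image = tf.image.resize(image, [224, 224])  # static shape for the TPU
    return image, example["label"]

def build_dataset(file_pattern, batch_size):
    files = tf.data.Dataset.list_files(file_pattern, shuffle=True)
    ds = files.interleave(
        tf.data.TFRecordDataset,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=False,
    )
    ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(2048).repeat()
    # TPUs require fixed shapes, so drop the ragged final batch.
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)

# Connect to the TPU; the resolver argument depends on the runtime
# (empty string picks up the Colab-provided TPU address).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# Build and compile the model under the strategy scope (toy model for illustration).
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(224, 224, 3)),
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy")

# TFRecords are assumed to live in a GCS bucket; the classic Colab TPU
# workers cannot read the local Colab filesystem.
model.fit(build_dataset("gs://my-bucket/train-*.tfrecord", batch_size=128),
          steps_per_epoch=100, epochs=1)
```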
It seems that the issue was with the internal Colaboratory TPU driver rather than the code itself. Since distributed training is working as expected, I am closing this issue for now.