muhd-umer / pvt-flax

Unofficial JAX/Flax implementation of Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions.
https://arxiv.org/abs/2102.12122
MIT License

Distributed training on TPU not functional #4

Closed: muhd-umer closed this issue 2 years ago

muhd-umer commented 2 years ago

Currently, attempting distributed training across multiple Cloud TPU cores leaves the training progress stuck at 0%.

JAX Process: 0 / 1
JAX Local Devices: [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
[15:59:22] Epoch: 1
           Training:   0%|                              | 0/937 [00:00<?, ?it/s]
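For reference, the process/device summary above can be reproduced with JAX's standard introspection calls. This is only a minimal sketch of how such a banner is typically printed; the exact logging code in this repo may differ.

```python
import jax

# Which host this is and how many hosts participate in training.
print(f"JAX Process: {jax.process_index()} / {jax.process_count()}")

# TPU cores visible to this host (8 on a single v2-8/v3-8 board).
print(f"JAX Local Devices: {jax.local_devices()}")
```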

It is highly likely that the issue lies in the TFRecord dataset-building pipeline, but I couldn't definitively single out a root cause. Note: training works just fine on GPU.
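For context, a TFRecord pipeline feeding a pmap-style training loop is usually shaped roughly as below. This is a hedged sketch, not the repo's actual input code: `build_dataset`, `parse_fn`, `file_pattern`, and `per_device_batch` are placeholder names. A stall at 0% often means the iterator never yields (e.g. blocked on remote storage) or the batch lacks the leading device axis that `jax.pmap` expects.

```python
import jax
import tensorflow as tf


def build_dataset(file_pattern, parse_fn, per_device_batch):
    # Hypothetical pipeline; names and parse_fn are placeholders.
    ds = tf.data.TFRecordDataset(tf.io.gfile.glob(file_pattern))
    ds = ds.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.shuffle(10_000).repeat()
    # Batch twice so the leading axis equals the local device count,
    # as required by pmap: [num_devices, per_device_batch, ...].
    ds = ds.batch(per_device_batch, drop_remainder=True)
    ds = ds.batch(jax.local_device_count(), drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)


def to_numpy(batch):
    # Convert eager TF tensors to NumPy before feeding a pmapped step.
    return jax.tree_util.tree_map(lambda x: x.numpy(), batch)
```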

muhd-umer commented 2 years ago

It seems that the issue was with the internal Colaboratory TPU driver rather than the code itself. Since distributed training now works as expected, I am closing this issue.
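For context, on Colab TPU runtimes of that era JAX was attached to the TPU through the Colab TPU driver, and a stale or mismatched driver could hang device execution even though the cores enumerate correctly. On JAX versions current at the time, the driver was (re)initialized with the utility below; this is a general note, not something specific to this repo.

```python
# Must run before any other JAX TPU call in the Colab session.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

import jax
print(jax.devices())  # should list 8 TpuDevice entries on a v2-8 / v3-8
```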