We trained the 11B models on v3-8 TPUs with model parallelism.
Could I know which library you are using for model parallelism? It looks like the code on GitHub only contains the BART code (Hugging Face code), so I am just curious how you trained the T5 model, especially with model parallelism.
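For context on what I imagine the setup might look like: the original T5 codebase (google-research/text-to-text-transfer-transformer) handles TPU model parallelism through Mesh TensorFlow, where a single `model_parallelism` argument shards the weights across the TPU cores. Below is a minimal sketch of that kind of configuration on a v3-8. This is only my assumption about one possible setup, not a claim about what you actually ran; the bucket paths, TPU name, and task name are placeholders.

```python
# Sketch only: assumes the Mesh TensorFlow T5 library (pip install t5[gcp]).
import t5

model = t5.models.MtfModel(
    model_dir="gs://my-bucket/t5-11b",     # placeholder GCS output directory
    tpu="my-tpu",                          # placeholder Cloud TPU name/address
    tpu_topology="v3-8",
    model_parallelism=8,                   # shard the 11B parameters across all 8 cores
    batch_size=8,
    sequence_length={"inputs": 512, "targets": 128},
    learning_rate_schedule=0.001,
    save_checkpoints_steps=5000,
    iterations_per_loop=100,
)

# Continue training from the released 11B checkpoint for 100k steps.
model.finetune(
    mixture_or_task_name="my_task",        # placeholder registered t5 Task/Mixture
    pretrained_model_dir="gs://t5-data/pretrained_models/11B",
    finetune_steps=100000,
)
```

Is this roughly what you did, or did you use a different library for parallelism?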
Meanwhile, the paper mentions you used a v3-8 TPU (8 TPU cores with 16 GB of memory each), and that it takes 36 hours to pre-train such a model for 100k steps (mini-batch size of 8). My training speed is far slower than yours, so I am curious what your config is for the T5-11B (3B) model.
I tried to reproduce the model on an A100 (40 GB) GPU, but T5-3B/11B does not fit without model parallelism or DeepSpeed. I am wondering how you fit such a large model onto the TPU: are you using half precision (fp16) or something else?
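For reference, here is a minimal sketch of one way I could imagine squeezing T5-3B onto a single 40 GB A100 with Hugging Face Transformers: bf16 weights, gradient checkpointing, and Adafactor (much smaller optimizer state than Adam). This is only my assumption about a workable single-GPU setup, not your recipe, and T5-11B would presumably still need DeepSpeed ZeRO-3 or multi-GPU model parallelism; the inputs and hyperparameters below are illustrative.

```python
import torch
from transformers import Adafactor, T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-3b")
model = T5ForConditionalGeneration.from_pretrained(
    "t5-3b",
    torch_dtype=torch.bfloat16,        # ~6 GB of weights; bf16 avoids T5's known fp16 overflow issues
)
model.config.use_cache = False         # cache is incompatible with gradient checkpointing during training
model.gradient_checkpointing_enable()  # trade recompute for much lower activation memory
model.to("cuda")

optimizer = Adafactor(
    model.parameters(),
    lr=1e-3,
    scale_parameter=False,
    relative_step=False,
)

inputs = tokenizer("summarize: a long source document ...", return_tensors="pt").to("cuda")
labels = tokenizer("a short summary", return_tensors="pt").input_ids.to("cuda")

loss = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    labels=labels,
).loss
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Did you do something along these lines on the TPU side (e.g. bfloat16 activations), or is the memory saving coming entirely from the model parallelism?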