aakashrkumar opened 1 year ago
Hi, I've been planning to train this model. I have a TPU pod (v3-128) through TRC, which should equate to roughly 5 TB of host RAM and 2 TB of HBM (the TPU equivalent of VRAM). I had a few questions about how to begin training the model:
- What would be the appropriate dataset to train on? I'm currently considering using The Pile for pre-training, but gathering human feedback for RLHF still seems like a challenge.
- How large of a model would you recommend?
- I saw you mentioned Flash Attention; are there any drawbacks to using it? It seems to be practically the best attention implementation available.
Thanks for all of your implementations; they have been really helpful to learn from.
I am currently doing a distributed training run and will be open-sourcing all of the weights for a PaLM-rlhf-pytorch model. I will open a PR when training finishes.
To answer these questions:
I will update as training progresses.
Best,
Enrico
Hi Enrico, thanks for your response. On the note of Flash Attention not being possible on TPUs, does this imply that TPU context size/efficiency will be substantially behind GPUs for the foreseeable future? Or is there an alternative way to get similar improvements on TPUs? Memory management and partitioning have been what I struggle with the most when trying to train on TPUs with JAX.
Thanks, Aakash
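For reference, partitioning in recent JAX is usually expressed through `jax.sharding`. The sketch below is minimal and illustrative: the mesh shape, axis names, and array sizes are assumptions, and on a multi-host pod slice like a v3-128 you would also need to call `jax.distributed.initialize()` on every host before building the mesh.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Illustrative 2-D mesh over the pod slice: 64-way data parallel x 2-way
# model parallel. The factorization is an assumption; pick whatever fits.
devices = np.array(jax.devices()).reshape(64, 2)
mesh = Mesh(devices, axis_names=("data", "model"))

# Shard a weight's output dimension across the "model" axis (replicated
# over "data"), and shard the batch dimension of activations over "data".
w = jax.device_put(jnp.zeros((4096, 16384)), NamedSharding(mesh, P(None, "model")))
x = jax.device_put(jnp.zeros((1024, 4096)), NamedSharding(mesh, P("data", None)))

# jit propagates the input shardings and inserts the needed collectives.
@jax.jit
def forward(x, w):
    return x @ w

y = forward(x, w)
print(y.sharding)  # inspect how the output ended up partitioned
```

Once arrays are placed this way, `jit` handles most of the per-device memory bookkeeping, which tends to be the part people struggle with when doing it by hand.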
TPUs are already highly efficient at scale. You can benchmark Flash Attention against standard JAX attention yourself if you are interested in exact speed comparisons.
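If it helps, a rough harness for that kind of comparison could look like the sketch below. It times a naive JAX attention against a query-chunked variant as a stand-in for memory-efficient attention (not Flash Attention itself, whose reference kernel is CUDA-only); the shapes, chunk size, and iteration count are arbitrary assumptions.

```python
import time
import jax
import jax.numpy as jnp

def naive_attention(q, k, v):
    # Materializes the full (seq_len, seq_len) score matrix.
    scores = (q @ k.T) / q.shape[-1] ** 0.5
    return jax.nn.softmax(scores, axis=-1) @ v

def chunked_attention(q, k, v, chunk=512):
    # Exact attention, but processed in query chunks so the largest
    # intermediate is (chunk, seq_len) rather than (seq_len, seq_len).
    def one_chunk(q_chunk):
        scores = (q_chunk @ k.T) / q.shape[-1] ** 0.5
        return jax.nn.softmax(scores, axis=-1) @ v
    q_chunks = q.reshape(-1, chunk, q.shape[-1])
    return jax.lax.map(one_chunk, q_chunks).reshape(q.shape)

def bench(fn, *args, iters=10):
    fn = jax.jit(fn)
    fn(*args).block_until_ready()  # compile / warm up once
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / iters

key = jax.random.PRNGKey(0)
seq_len, dim = 4096, 128
q, k, v = (jax.random.normal(k_, (seq_len, dim))
           for k_ in jax.random.split(key, 3))

print("naive  :", bench(naive_attention, q, k, v))
print("chunked:", bench(chunked_attention, q, k, v))
```

On TPU the chunked version mostly buys memory headroom at long context rather than raw speed; whether it ends up faster will depend on sequence length and how XLA fuses it.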