arampacha opened 3 years ago
Casting weights to bf16 is not recommended, so it has been removed for now.
Here's the gradient accumulation helper from the vision_transformer codebase:
https://github.com/google-research/vision_transformer/blob/ba9a85bdc430daf4da7b9da67b486a4e0f5bb278/vit_jax/hyper.py#L77

And here's a small example of it in use:
https://github.com/google-research/vision_transformer/blob/ba9a85bdc430daf4da7b9da67b486a4e0f5bb278/vit_jax/train.py#L63-L66
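For anyone skimming, here is a minimal sketch of the technique in JAX. It is written in the same spirit as the linked `accumulate_gradient` helper but is not a copy of it; the function and argument names are illustrative, and it assumes `loss_and_grad_fn` is something like `jax.value_and_grad(loss_fn)` and that the batch dimension divides evenly by `accum_steps`:

```python
import jax
import jax.numpy as jnp

def accumulate_gradient(loss_and_grad_fn, params, batch, accum_steps):
    # Slice the large batch into `accum_steps` micro-batches along the
    # leading axis (assumes it is divisible by accum_steps).
    def micro_batch(i):
        return jax.tree_util.tree_map(
            lambda x: x.reshape((accum_steps, -1) + x.shape[1:])[i], batch)

    # Evaluate the first micro-batch to initialize the accumulators.
    loss, grads = loss_and_grad_fn(params, micro_batch(0))

    def body(i, carry):
        loss_acc, grad_acc = carry
        loss_i, grad_i = loss_and_grad_fn(params, micro_batch(i))
        return (loss_acc + loss_i,
                jax.tree_util.tree_map(jnp.add, grad_acc, grad_i))

    loss, grads = jax.lax.fori_loop(1, accum_steps, body, (loss, grads))
    # Average so the result matches a single pass over the full batch.
    return loss / accum_steps, jax.tree_util.tree_map(
        lambda g: g / accum_steps, grads)
```

The averaged gradients can then be fed to the optimizer update exactly as if they came from one large batch, which is what lets a small-memory GPU emulate a TPU-sized batch.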
For gradient accumulation, I have opened a PR: https://github.com/ncoop57/gpt-code-clippy/pull/29. Let me know if we can sync up on this.
Hello, what are the minimum hardware requirements to run the training script?
Hi @celsofranssa, the hyperparameters in the HF model cards (for example here) are tuned for a TPU-v3-8. But you can run the script on a GPU by adjusting the batch size accordingly and maybe switching `dtype` from `bfloat16` to `float16` for your hardware. I'm not sure what the minimum requirement would be exactly. You can also consider decreasing `block_size` if you run out of memory.
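As a rough sketch of those adjustments, assuming the standard `dtype` argument of the Flax `from_pretrained` API; the checkpoint name and the batch/block sizes below are placeholders, not values from this thread:

```python
import jax.numpy as jnp
from transformers import FlaxAutoModelForCausalLM

# Placeholder checkpoint; substitute the model you are training.
model = FlaxAutoModelForCausalLM.from_pretrained(
    "gpt2",
    dtype=jnp.float16,  # run computations in fp16 instead of bf16 on GPU
)

# Memory-saving knobs, tuned down until training fits on your GPU.
per_device_batch_size = 4   # TPU-v3-8 values will likely be too large
block_size = 512            # shorter sequences use less memory
```

Reducing either knob trades throughput for memory; gradient accumulation (see the sketch above) can recover the effective batch size if needed.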