Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOPs utilization (MFU) rates.
Apache License 2.0
458 stars, 69 forks
Make num_train_steps configurable in gpu configs #14
Minor change to `paxml/contrib/gpu/scripts_gpu/configs.py` that allows `num_train_steps` to be configured.
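The general pattern behind a change like this can be sketched in plain Python. This is a hedged illustration, not the code in `configs.py`. The names `TrainConfig`, `BaseGpuExperiment`, `ShortSmokeTest`, and `NUM_TRAIN_STEPS`, along with the default values, are all hypothetical. The sketch assumes the step count becomes a class attribute that subclasses can override, similar to how Pax experiment classes often expose hyperparameters. The PR text does not confirm this design.

```python
"""Hypothetical sketch: exposing num_train_steps as an overridable config attribute.

None of these names are Pax APIs; they only illustrate the pattern.
"""
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
  # Hypothetical stand-in for the training hyperparameters a task carries.
  num_train_steps: int
  learning_rate: float


class BaseGpuExperiment:
  # Class-level defaults (hypothetical values). Subclasses override these
  # instead of editing a number hard-coded inside train_config().
  NUM_TRAIN_STEPS = 500
  LEARNING_RATE = 1e-4

  def train_config(self) -> TrainConfig:
    # Read attributes through self so subclass overrides take effect.
    return TrainConfig(
        num_train_steps=self.NUM_TRAIN_STEPS,
        learning_rate=self.LEARNING_RATE,
    )


class ShortSmokeTest(BaseGpuExperiment):
  # Only the step count changes; the learning rate is inherited.
  NUM_TRAIN_STEPS = 10


if __name__ == "__main__":
  print(BaseGpuExperiment().train_config().num_train_steps)  # 500
  print(ShortSmokeTest().train_config().num_train_steps)  # 10
```

Keeping the value on the class, rather than inside the method body, allows a short debugging run to change one attribute while the rest of the configuration stays shared.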