Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOPs utilization (MFU) rates.
Apache License 2.0
Convert string dtype to jnp dtype during evaluation #54
During evaluation, converts the given fprop_dtype to a proper jnp dtype. This handles the case in which the user passes the dtype as a string (e.g. “bfloat16” rather than jnp.bfloat16).
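The conversion described in this PR can be sketched as a small normalization step. The helper name `to_jnp_dtype` is hypothetical and this is not the code merged in Pax. It is a minimal sketch that assumes string dtypes are exactly the attribute names of `jax.numpy`, such as "bfloat16" or "float32":

```python
import jax.numpy as jnp


def to_jnp_dtype(dtype):
  """Hypothetical helper: normalize a dtype given as a string or a JAX dtype.

  Strings such as "bfloat16" are looked up as attributes of jax.numpy.
  The result (or a non-string input such as jnp.bfloat16) is passed
  through jnp.dtype, so callers always receive a canonical dtype object.
  """
  if isinstance(dtype, str):
    try:
      dtype = getattr(jnp, dtype)
    except AttributeError as e:
      raise ValueError(f"Unknown dtype string: {dtype!r}") from e
  try:
    return jnp.dtype(dtype)
  except TypeError as e:
    # e.g. a string naming a jax.numpy function rather than a dtype.
    raise ValueError(f"Not a dtype: {dtype!r}") from e


# Both spellings resolve to the same dtype, so downstream code such as
# jnp.ones(..., dtype=fprop_dtype) behaves identically for either input.
fprop_dtype = to_jnp_dtype("bfloat16")
x = jnp.ones((2, 2), dtype=fprop_dtype)
```

Normalizing once at the entry point of evaluation, rather than at every use site, keeps the rest of the code free to assume it always receives a real dtype object.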