google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

Convert string dtype to jnp dtype during evaluation #54

Closed: ashors1 closed this 1 year ago

ashors1 commented 1 year ago

During evaluation, this change converts the given fprop_dtype to a proper jnp dtype. This handles the case in which the user passes a string dtype (e.g. "bfloat16" rather than jnp.bfloat16).
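The normalization described above can be sketched roughly as follows. This is not the code from the PR: the helper name to_jnp_dtype and the eval_step function are hypothetical, used only to illustrate the idea. The assumption is that a string name is resolved by looking it up on jax.numpy (so "bfloat16" maps to jnp.bfloat16), and the result is then passed through jnp.dtype so that both strings and dtype objects end up as the same concrete dtype.

```python
import jax.numpy as jnp


def to_jnp_dtype(dtype):
    """Hypothetical helper: return a concrete dtype for a name or dtype-like."""
    if isinstance(dtype, str):
        # Resolve names such as "bfloat16" or "float32" via jax.numpy
        # attributes (jnp.bfloat16, jnp.float32).
        candidate = getattr(jnp, dtype, None)
        if candidate is None:
            raise ValueError(f"Unknown dtype name: {dtype!r}")
        dtype = candidate
    try:
        # jnp.dtype normalizes scalar types and dtype objects to one form.
        return jnp.dtype(dtype)
    except TypeError as e:
        # e.g. a name that exists on jnp but is a function, not a dtype.
        raise ValueError(f"Not a dtype: {dtype!r}") from e


def eval_step(x, fprop_dtype="bfloat16"):
    """Hypothetical eval step that accepts fprop_dtype as a string or dtype."""
    fprop_dtype = to_jnp_dtype(fprop_dtype)
    return x.astype(fprop_dtype)
```

With this normalization, downstream code that inspects the dtype (for example, reading its itemsize or comparing it to jnp.bfloat16) behaves the same whether the caller wrote "bfloat16" or jnp.bfloat16.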