google / paxml

Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
Apache License 2.0

Use bfloat16 for eval #66

Open tbaker2 opened 7 months ago

tbaker2 commented 7 months ago

I'm running paxml on an Intel Xeon CPU server via the paxml/main.py program. I'm trying to configure a model that stores its weights in bfloat16 and uses that datatype during eval. I modified the LmCloudSpmd2B configuration with the following lines:

MODEL_DTYPE = jnp.bfloat16
ICI_MESH_SHAPE = [1, 1, 1]

The training status output includes the following lines:

model.dtype : dtype[float32]
model.fprop_dtype : dtype[bfloat16]

All of the other operator datatypes are float32. When I run that model with the --eval switch, all of the computation is done in float32. How can I direct paxml to use bfloat16?
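(As a point of reference while waiting on a paxml-level answer: in plain JAX, one possible manual workaround is to cast every floating-point leaf of the parameter pytree to bfloat16 before running eval. The sketch below is not paxml API; the `params` tree and `cast_to_bf16` helper are hypothetical stand-ins for a restored checkpoint.)

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree standing in for restored model weights.
params = {
    "dense": {
        "w": jnp.ones((2, 2), dtype=jnp.float32),
        "b": jnp.zeros((2,), dtype=jnp.float32),
    },
}

def cast_to_bf16(tree):
    """Cast every floating-point leaf of a pytree to bfloat16.

    Non-floating leaves (e.g. integer step counters) are left untouched.
    """
    def cast(x):
        if jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(jnp.bfloat16)
        return x
    return jax.tree_util.tree_map(cast, tree)

bf16_params = cast_to_bf16(params)
print(bf16_params["dense"]["w"].dtype)  # bfloat16
```

With the weights cast this way, subsequent matmuls in the eval step run in bfloat16 on backends that support it; whether paxml exposes a config knob that does this automatically is exactly the open question here.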

Tom

tbaker2 commented 7 months ago

Any comments on this?