yongchanghao opened this issue 1 year ago
The doc says the dtype of `mu` is inferred from `grads` and `updates` if `mu_dtype=None`.
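For reference, the behaviour the doc describes is what `jnp.zeros_like` does when `dtype=None`: the result keeps the dtype of its input, including half-precision dtypes. A minimal sketch follows. `init_mu` is a hypothetical helper written for illustration and is not optax's code.

```python
import jax.numpy as jnp

def init_mu(param, mu_dtype=None):
    # Hypothetical first-moment initialiser: when mu_dtype is None,
    # zeros_like falls back to the dtype of `param`.
    return jnp.zeros_like(param, dtype=mu_dtype)

for dt in (jnp.bfloat16, jnp.float16, jnp.float32):
    p = jnp.ones((2, 2), dtype=dt)
    mu = init_mu(p)
    print(p.dtype, "->", mu.dtype)
```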
But this line actually casts `jnp.bfloat16` and `jnp.float16` to `jnp.float32` when `mu_dtype=None`.
Example on GPUs:
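```python
>>> import jax
>>> import jax.numpy as jnp
>>> # x: a floating-point array (its definition was not shown)
>>> jax.__version__
'0.4.4'
>>> x.astype(jnp.float16).dtype
dtype('float16')
>>> x.astype(jnp.float16).astype(None).dtype
dtype('float32')
```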
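One way to get the documented behaviour is to skip the cast entirely when no dtype is requested, rather than calling `astype(None)`. The sketch below assumes that approach. `cast_tree_if_set` is a hypothetical helper, not optax's API. It uses `jax.tree_util.tree_map` so that it works on a pytree of moments.

```python
import jax
import jax.numpy as jnp

def cast_tree_if_set(tree, dtype):
    # Hypothetical guarded cast: when dtype is None, return the tree unchanged
    # instead of calling astype(None), which resolves to the default float
    # dtype (float32 without x64), as in the session above.
    if dtype is None:
        return tree
    return jax.tree_util.tree_map(lambda t: t.astype(dtype), tree)

mu = {"w": jnp.ones(3, dtype=jnp.bfloat16), "b": jnp.ones(3, dtype=jnp.float16)}

kept = cast_tree_if_set(mu, None)            # half-precision dtypes preserved
upcast = cast_tree_if_set(mu, jnp.float32)   # an explicit mu_dtype is still honoured

print({k: str(v.dtype) for k, v in kept.items()})
print({k: str(v.dtype) for k, v in upcast.items()})
```

With this guard, `mu_dtype=None` means "leave the dtype alone", which matches the docstring. An explicit `mu_dtype` keeps its current meaning.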