google / automl

Google Brain AutoML
Apache License 2.0

Potentially wrong type inference #1188

Open yongchanghao opened 1 year ago

The docstring says that the dtype of `mu` is inferred from the grads and updates when `mu_dtype=None`.

However, this line actually casts `jnp.bfloat16` and `jnp.float16` values to `jnp.float32` when `mu_dtype=None`.

Example on GPUs:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> jax.__version__
'0.4.4'
>>> x.astype(jnp.float16).dtype
dtype('float16')
>>> x.astype(jnp.float16).astype(None).dtype
dtype('float32')
```
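One possible way to get the documented behavior is to skip the cast entirely when `mu_dtype` is `None`, so the momentum keeps whatever dtype the updates already have. The sketch below is only an illustration: it assumes an exponential-moving-average momentum update, and `cast_tree_if_set` and `update_momentum` are hypothetical names, not functions from this repository.

```python
import jax
import jax.numpy as jnp


def cast_tree_if_set(tree, dtype):
    """Hypothetical helper: cast every leaf to `dtype`, or return `tree` unchanged if dtype is None.

    Guarding on None avoids `x.astype(None)`, which JAX treats as
    "cast to the default float dtype" (float32 unless x64 is enabled).
    """
    if dtype is None:
        return tree
    return jax.tree_util.tree_map(lambda x: x.astype(dtype), tree)


def update_momentum(updates, mu, b2, mu_dtype=None):
    """Hypothetical EMA momentum update b2 * mu + (1 - b2) * g, followed by a guarded cast."""
    # Python-float coefficients are weakly typed, so bf16/fp16 leaves stay bf16/fp16 here.
    new_mu = jax.tree_util.tree_map(lambda g, m: b2 * m + (1 - b2) * g, updates, mu)
    return cast_tree_if_set(new_mu, mu_dtype)


# Usage: bfloat16 grads and momentum, no explicit mu_dtype.
grads = {"w": jnp.ones((2,), dtype=jnp.bfloat16)}
mu = {"w": jnp.zeros((2,), dtype=jnp.bfloat16)}
print(update_momentum(grads, mu, 0.99)["w"].dtype)  # bfloat16
```

With the guard in place, an explicit `mu_dtype=jnp.float32` still upcasts the momentum, while `mu_dtype=None` leaves it in the dtype inferred from the inputs.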