dcjones (closed 3 years ago)
When `jax_enable_x64` is set, Adam's `apply_gradient` method will promote all float32 arrays to float64, potentially unexpectedly degrading performance.
This is due to jax's wonky type promotion semantics. The offending line is: https://github.com/google/flax/blob/3e36db3e5e3b8e6e1777d612f270e7948238aa9c/flax/optim/adam.py#L82
which promotes like:

```python
jnp.array([0], dtype=jnp.int32) + 1.  # == DeviceArray([1.], dtype=float64)
```

and then cascades from there, promoting everything to float64.
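A minimal standalone demonstration of the promotion (variable names here are illustrative): adding a Python float to an int32 array crosses the int-to-float boundary, so under `jax_enable_x64` the result materializes as float64, while an explicit cast keeps the intended dtype.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

step = jnp.array([0], dtype=jnp.int32)

# int32 + Python float: the result takes the default float dtype,
# which is float64 once x64 mode is enabled.
promoted = step + 1.
assert promoted.dtype == jnp.float64

# Casting explicitly to the dtype we actually want avoids the promotion.
safe = (step + 1).astype(jnp.float32)
assert safe.dtype == jnp.float32
```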
Arrays should retain their dtypes on optimizer updates.
```python
from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import flax

opt = flax.optim.Adam(1e-3).create(
    {"x": jnp.zeros(10, dtype=jnp.float32)})
assert opt.target["x"].dtype == jnp.float32

opt = opt.apply_gradient({"x": jnp.zeros(10, dtype=jnp.float32)})
# This fails, since dtype was promoted to float64
assert opt.target["x"].dtype == jnp.float32
```
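One way the cascade can be cut off, sketched for Adam's bias-correction term (the function name and signature here are illustrative, not Flax's actual fix): cast the incremented step counter to the parameter dtype before it enters any float arithmetic.

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

def bias_correction(step, beta, dtype):
    # step + 1. would cross the int->float boundary and materialize as
    # float64 under jax_enable_x64; incrementing in int and casting
    # explicitly keeps everything in the parameters' dtype.
    t = (step + 1).astype(dtype)
    return 1.0 - beta ** t  # weak Python floats do not promote float32

t_corr = bias_correction(jnp.array(0, dtype=jnp.int32), 0.9, jnp.float32)
assert t_corr.dtype == jnp.float32
```

Because the Python scalars `1.0` and `beta` are weakly typed, they follow the dtype of the strongly typed `t` rather than forcing float64, so the whole update stays in float32.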
Sorry for the delay - thanks so much for letting us know, we'll try to fix it (and same issue in LAMB) in #965