google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

When jax_enable_x64 is set Adam promotes everything to float64 #924

Closed dcjones closed 3 years ago

dcjones commented 3 years ago

Problem you have encountered:

When jax_enable_x64 is set, Adam's apply_gradient method promotes all float32 arrays to float64, which can silently and unexpectedly degrade performance.

This is due to jax's wonky type promotion semantics. The offending line is: https://github.com/google/flax/blob/3e36db3e5e3b8e6e1777d612f270e7948238aa9c/flax/optim/adam.py#L82

which triggers a promotion like:

jnp.array([0], dtype=jnp.int32) + 1.  # == DeviceArray([1.], dtype=float64)

and the resulting float64 then cascades through the update, promoting everything to float64.

What you expected to happen:

Arrays should retain their dtypes on optimizer updates.

Logs, error messages, etc:

Steps to reproduce:

from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import flax

opt = flax.optim.Adam(1e-3).create(
        {"x": jnp.zeros(10, dtype=jnp.float32)})

assert opt.target["x"].dtype == jnp.float32

opt = opt.apply_gradient({"x": jnp.zeros(10, dtype=jnp.float32)})

# This fails, since dtype was promoted to float64
assert opt.target["x"].dtype == jnp.float32
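Until the optimizer itself is fixed, one workaround is to cast the parameter pytree back down after each update. A hedged sketch using a stand-in dict in place of opt.target:

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Hypothetical params pytree standing in for opt.target after a promoted update.
params = {"x": jnp.zeros(10, dtype=jnp.float64)}

# Force every leaf back to float32.
params = jax.tree_util.tree_map(lambda a: a.astype(jnp.float32), params)
print(params["x"].dtype)  # float32
```

This wastes a cast per step but keeps all downstream computation in float32.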

levskaya commented 3 years ago

Sorry for the delay - thanks so much for letting us know, we'll try to fix it (and the same issue in LAMB) in #965.
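One way such a fix can keep dtypes stable is to cast the integer step counter into the parameter dtype before mixing it with Python floats, so the bias-correction term never leaves float32. A sketch of the pattern (an illustration only, not necessarily what #965 does; bias_corrected is a hypothetical helper):

```python
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

def bias_corrected(ema, beta, step):
    # Cast the int32 step into the EMA's own dtype before arithmetic with
    # Python floats, so the correction factor stays in float32.
    t = (step + 1).astype(ema.dtype)
    return ema / (1. - beta ** t)

ema = jnp.ones(3, dtype=jnp.float32)
step = jnp.array(0, dtype=jnp.int32)
print(bias_corrected(ema, 0.9, step).dtype)  # float32
```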