google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Unwanted type promotion in prodigy optimizer #1051

Closed Grutschus closed 2 months ago

Grutschus commented 2 months ago

When the prodigy update function is called inside a jitted function with jax_enable_x64 = True, some fields of the prodigy state, as well as the updates, are promoted to float64 even if all inputs are float32.

From my observations, the promotion happens here: https://github.com/google-deepmind/optax/blob/faeb7215ece7f42c3f7144d37a1c83b77643023c/optax/contrib/_prodigy.py#L120

A potential fix would be to set the dtype of bc explicitly to jnp.float32 as is done during initialization:

    bc = jnp.array(((1 - beta2 ** (count + 1)) ** 0.5) / (1 - beta1 ** (count + 1)), dtype=jnp.float32)
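The proposed change can be tried in isolation. Below is a minimal sketch, assuming a hypothetical standalone helper `bias_correction` (not part of Optax) that computes the same expression and pins its dtype; it illustrates the suggestion above and is not necessarily the fix that was merged.

```python
import jax
import jax.numpy as jnp


def bias_correction(beta1, beta2, count, dtype=jnp.float32):
    """Hypothetical helper: Adam-style bias correction with a pinned dtype."""
    step = count + 1
    bc = ((1 - beta2 ** step) ** 0.5) / (1 - beta1 ** step)
    # The explicit cast keeps the result in `dtype`, however JAX promoted
    # the intermediate values (Python floats combined with an int32 count).
    return jnp.asarray(bc, dtype=dtype)


count = jnp.zeros([], dtype=jnp.int32)
bc = jax.jit(bias_correction)(0.9, 0.999, count)
print(bc.dtype)
```

Pinning the dtype at the point where the scalar is built means that downstream products such as the scaled learning rate inherit it, instead of relying on every caller passing typed arrays.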

Steps to reproduce

Script to reproduce

# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "jax",
#     "optax",
# ]
# ///
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from optax.contrib import prodigy

jax.config.update("jax_enable_x64", True)

params = jax.random.normal(jax.random.PRNGKey(1), (3, 3), dtype=jnp.float16)
grad = 0.01 * params
print(f"Dtypes of inputs -- params: {params.dtype}, grads: {grad.dtype}")

optimizer = prodigy()
opt_state = optimizer.init(params)
print("Dtypes of optimizer state after initialization")
print(jtu.tree_map(lambda x: x.dtype, opt_state))

update, opt_state = optimizer.update(grad, opt_state, params)
print("Dtypes of optimizer state after update w/o jitting")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
print("Dtypes of update after update w/o jitting")
print(jtu.tree_map(lambda x: x.dtype, update))

@jax.jit
def opt_update(params, grad, opt_state):
    update, opt_state = optimizer.update(grad, opt_state, params)
    return update, opt_state

update, opt_state = opt_update(params, grad, opt_state)
print("Dtypes of optimizer state after update w/ jitting")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
print("Dtypes of update after update w/ jitting")
print(jtu.tree_map(lambda x: x.dtype, update))

Run with:

    uv run --script script.py

Output

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
Dtypes of inputs -- params: float16, grads: float16
Dtypes of optimizer state after initialization
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of optimizer state after update w/o jitting
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of update after update w/o jitting
float32
Dtypes of optimizer state after update w/ jitting
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float64'), params0=dtype('float16'), estim_lr=dtype('float64'), numerator_weighted=dtype('float64'), count=dtype('int32'))
Dtypes of update after update w/ jitting
float64
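
Promoted fields like the float64 entries above can also be located programmatically, which is handy for a regression test. The following is a rough sketch: `leaves_with_dtype` and `fake_state` are hypothetical names, and NumPy arrays stand in for the real optimizer state so that the example does not depend on jax_enable_x64.

```python
import jax
import numpy as np


def leaves_with_dtype(tree, dtype):
    """Return the key paths of all leaves in `tree` whose dtype equals `dtype`."""
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
    return [
        jax.tree_util.keystr(path)
        for path, leaf in leaves_with_paths
        if np.dtype(leaf.dtype) == np.dtype(dtype)
    ]


# Stand-in for an optimizer state (assumption: a plain dict, not ProdigyState).
fake_state = {
    "exp_avg": np.zeros((3, 3), dtype=np.float32),
    "grad_sum": np.zeros((3, 3), dtype=np.float64),
    "count": np.zeros((), dtype=np.int32),
}
promoted = leaves_with_dtype(fake_state, np.float64)
print(promoted)
```

Applied to the real `opt_state` returned by the jitted update, an assertion that this list is empty would flag the promotion reported here.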

Environment info

uv.lock

version = 1
requires-python = ">=3.12"
# sdist and wheel entries (URLs, hashes, sizes) omitted; all packages come from https://pypi.org/simple unless noted

[[package]] name = "absl-py" version = "2.1.0"
[[package]] name = "chex" version = "0.1.86" dependencies = absl-py, jax, jaxlib, numpy, setuptools, toolz, typing-extensions
[[package]] name = "etils" version = "1.9.4" optional-dependencies: epy = typing-extensions
[[package]] name = "jax" version = "0.4.31" dependencies = jaxlib, ml-dtypes, numpy, opt-einsum, scipy
[[package]] name = "jaxlib" version = "0.4.31" dependencies = ml-dtypes, numpy, scipy
[[package]] name = "ml-dtypes" version = "0.4.0" dependencies = numpy
[[package]] name = "numpy" version = "2.1.1"
[[package]] name = "opt-einsum" version = "3.3.0" dependencies = numpy
[[package]] name = "optax" version = "0.2.3" dependencies = absl-py, chex, etils[epy], jax, jaxlib, numpy
[[package]] name = "optax-debug" version = "0.1.0" source = { virtual = "." } dependencies = jax, optax; requires-dist = jax>=0.4.31, optax>=0.2.3
[[package]] name = "scipy" version = "1.14.1" dependencies = numpy
[[package]] name = "setuptools" version = "74.1.2"
[[package]] name = "toolz" version = "0.12.1"
[[package]] name = "typing-extensions" version = "4.12.2"

optax version: 0.2.3
jax version: 0.4.31 (w/o cuda)
OS: Linux 5.15.0-119-generic #129-Ubuntu SMP Fri Aug 2 19:25:20 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

vroulet commented 2 months ago

Hello @Grutschus,

Thank you for reporting this, and thank you very much for sending a minimal working example! The culprits are the default Python floats in the definition of prodigy (the learning_rate and betas arguments). Namely, the following code behaves as you want:

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from optax.contrib import prodigy

jax.config.update("jax_enable_x64", True)

params = jax.random.normal(jax.random.PRNGKey(1), (3, 3), dtype=jnp.float16)
grad = 0.01 * params
print(f"Dtypes of inputs -- params: {params.dtype}, grads: {grad.dtype}")

optimizer = prodigy(
    learning_rate=jnp.asarray(1., dtype=jnp.float32),
    betas=(
        jnp.asarray(0.9, dtype=jnp.float32),
        jnp.asarray(0.999, dtype=jnp.float32),
    ),
)
opt_state = optimizer.init(params)
print("Dtypes of optimizer state after initialization")
print(jtu.tree_map(lambda x: x.dtype, opt_state))

update, opt_state = optimizer.update(grad, opt_state, params)
print("Dtypes of optimizer state after update w/o jitting")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
print("Dtypes of update after update w/o jitting")
print(jtu.tree_map(lambda x: x.dtype, update))

@jax.jit
def opt_update(params, grad, opt_state):
    update, opt_state = optimizer.update(grad, opt_state, params)
    return update, opt_state

update, opt_state = opt_update(params, grad, opt_state)
print("Dtypes of optimizer state after update w/ jitting")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
print("Dtypes of update after update w/ jitting")
print(jtu.tree_map(lambda x: x.dtype, update))

Outputs:

Dtypes of inputs -- params: float16, grads: float16
Dtypes of optimizer state after initialization
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of optimizer state after update w/o jitting
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of update after update w/o jitting
float32
Dtypes of optimizer state after update w/ jitting
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of update after update w/ jitting
float32

So we may simply convert betas and learning_rate to jnp.float32 inside prodigy to get the desired behavior. There is a more general problem with this code, though: the state dtypes are forced to float32 even when the params use another dtype. It would be great to fix that so that the elements of the state that are "params-like" share the dtype of the params.
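
One way the "params-like" idea could look is sketched below. This is only an illustration under stated assumptions: `init_like_params` and `cast_like` are hypothetical helpers, not Optax functions, and the parameter dict is made up for the example.

```python
import jax
import jax.numpy as jnp


def init_like_params(params):
    """Hypothetical initializer: zero leaves that inherit each parameter's dtype."""
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def cast_like(tree, reference):
    """Cast every leaf of `tree` to the dtype of the matching leaf in `reference`."""
    return jax.tree_util.tree_map(
        lambda x, ref: x.astype(ref.dtype), tree, reference
    )


params = {
    "w": jnp.ones((3, 3), dtype=jnp.float16),
    "b": jnp.ones((3,), dtype=jnp.bfloat16),
}
state = init_like_params(params)

# Pretend the update was computed in float32, then cast it back per leaf.
update_f32 = jax.tree_util.tree_map(lambda p: p.astype(jnp.float32) * 0.5, params)
update = cast_like(update_f32, params)
print(jax.tree_util.tree_map(lambda x: x.dtype, update))
```

Whether accumulators such as the second-moment estimate should really be stored in a low-precision dtype is a separate numerical question; the sketch only shows how the dtypes can be made to follow the params.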

Would you be willing to make a PR to fix these type issues? If not, could you open an issue specifically about promoting the "params-like" fields of the state so that their types match the ones of params?

Thank you again!