Closed Grutschus closed 2 months ago
Hello @Grutschus,
Thank you for reporting this, and thank you very much for sending a minimal working example! The culprits are the default float hyperparameters (learning_rate and betas) in the definition of prodigy. The following code, which passes them explicitly as float32 arrays, behaves as you expect:
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from optax.contrib import prodigy
jax.config.update("jax_enable_x64", True)
params = jax.random.normal(jax.random.PRNGKey(1), (3, 3), dtype=jnp.float16)
grad = 0.01 * params
print(f"Dtypes of inputs -- params: {params.dtype}, grads: {grad.dtype}")
optimizer = prodigy(
    learning_rate=jnp.asarray(1., dtype=jnp.float32),
    betas=(
        jnp.asarray(0.9, dtype=jnp.float32),
        jnp.asarray(0.999, dtype=jnp.float32),
    ),
)
opt_state = optimizer.init(params)
print("Dtypes of optimizer state after initialization")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
update, opt_state = optimizer.update(grad, opt_state, params)
print("Dtypes of optimizer state after update w/o jitting")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
print("Dtypes of update after update w/o jitting")
print(jtu.tree_map(lambda x: x.dtype, update))
@jax.jit
def opt_update(params, grad, opt_state):
    update, opt_state = optimizer.update(grad, opt_state, params)
    return update, opt_state
update, opt_state = opt_update(params, grad, opt_state)
print("Dtypes of optimizer state after update w/ jitting")
print(jtu.tree_map(lambda x: x.dtype, opt_state))
print("Dtypes of update after update w/ jitting")
print(jtu.tree_map(lambda x: x.dtype, update))
Outputs:
Dtypes of inputs -- params: float16, grads: float16
Dtypes of optimizer state after initialization
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of optimizer state after update w/o jitting
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of update after update w/o jitting
float32
Dtypes of optimizer state after update w/ jitting
ProdigyState(exp_avg=dtype('float32'), exp_avg_sq=dtype('float32'), grad_sum=dtype('float32'), params0=dtype('float16'), estim_lr=dtype('float32'), numerator_weighted=dtype('float32'), count=dtype('int32'))
Dtypes of update after update w/ jitting
float32
So we can simply pass betas and learning_rate as jnp.float32 arrays to get the desired dtypes. There is a more general problem with this code, though: the types are forced to float32 even when the params are in another dtype (in the output above, params are float16, yet exp_avg, exp_avg_sq and grad_sum are float32). It would be great to fix that so that the elements of the state that are "params-like" share the same dtype as the params.
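To illustrate why the explicit float32 hyperparameters help, here is a minimal sketch of the general JAX promotion rule with jax_enable_x64 turned on. It is not the exact code path inside prodigy; the names x, beta_f64 and beta_f32 are made up for this example.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# A float32 array standing in for a gradient or a state entry.
x = jnp.ones((3,), dtype=jnp.float32)

# The same scalar, once as a strongly typed float64 and once as float32.
beta_f64 = jnp.asarray(0.9, dtype=jnp.float64)
beta_f32 = jnp.asarray(0.9, dtype=jnp.float32)


@jax.jit
def scale(beta, x):
    return beta * x


# A strongly typed float64 scalar promotes the float32 array to float64.
print(scale(beta_f64, x).dtype)  # float64
# A float32 scalar leaves the result in float32.
print(scale(beta_f32, x).dtype)  # float32

Any intermediate scalar that ends up as a strongly typed float64 has the same effect, which is why pinning the hyperparameters to float32 keeps the state and updates in float32.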
So would you be willing to make a PR to fix these type issues? If not, can you open an issue specifically about promoting the types of "params-like" fields in the state so that they match the ones of params?
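As a rough idea of what "matching the params dtype" could look like, here is a small sketch. The helper cast_like is hypothetical (it is not part of optax), and exp_avg below is only a stand-in for a "params-like" state field.

import jax.numpy as jnp
import jax.tree_util as jtu


def cast_like(tree, reference):
    """Cast each leaf of `tree` to the dtype of the matching leaf in `reference`."""
    return jtu.tree_map(lambda x, ref: x.astype(ref.dtype), tree, reference)


params = {
    "w": jnp.ones((3, 3), dtype=jnp.float16),
    "b": jnp.zeros((3,), dtype=jnp.float16),
}
# Stand-in for a "params-like" state field that was created in float32.
exp_avg = jtu.tree_map(lambda p: jnp.zeros_like(p, dtype=jnp.float32), params)

exp_avg_cast = cast_like(exp_avg, params)
print(jtu.tree_map(lambda x: x.dtype, exp_avg_cast))

Whether to cast after the fact like this or to create the fields in the params dtype at initialization is a design choice for the actual fix; the sketch only shows the intended end result.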
Thank you again!
When using the prodigy update fn in a jitted function with jax_enable_x64 = True, some types in the prodigy state and the updates are promoted to float64 even if all inputs are float32. From my observations, the promotion happens here: https://github.com/google-deepmind/optax/blob/faeb7215ece7f42c3f7144d37a1c83b77643023c/optax/contrib/_prodigy.py#L120
A potential fix would be to set the dtype of bc explicitly to jnp.float32, as is done during initialization:
Steps to reproduce
Script to reproduce
Output
Environment info
uv.lock
version = 1 requires-python = ">=3.12" [[package]] name = "absl-py" version = "2.1.0" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/7a/8f/fc001b92ecc467cc32ab38398bd0bfb45df46e7523bf33c2ad22a505f06e/absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff", size = 118055 } wheels = [ { url = "https://files.pythonhosted.org/packages/a2/ad/e0d3c824784ff121c03cc031f944bc7e139a8f1870ffd2845cc2dd76f6c4/absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308", size = 133706 }, ] [[package]] name = "chex" version = "0.1.86" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, { name = "jax" }, { name = "jaxlib" }, { name = "numpy" }, { name = "setuptools" }, { name = "toolz" }, { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/26/a2/46649fb9f6a33cc7c2822161cc5481f0ffe5965fde1e6fc4c3003cd22323/chex-0.1.86.tar.gz", hash = "sha256:e8b0f96330eba4144659e1617c0f7a57b161e8cbb021e55c6d5056c7378091d1", size = 89021 } wheels = [ { url = "https://files.pythonhosted.org/packages/e6/ed/625d545d08c6e258d2d63a93a0bf8ed8a296c09d254208e73f9d4fb0b746/chex-0.1.86-py3-none-any.whl", hash = "sha256:251c20821092323a3d9c28e1cf80e4a58180978bec368f531949bd9847eee568", size = 98167 }, ] [[package]] name = "etils" version = "1.9.4" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/19/e0/d8e99c24e7c55a9cb6a405fa502c059f77ed789f916bffbcaa8e1cc65f2d/etils-1.9.4.tar.gz", hash = "sha256:fad950414f0a1ca58c70c70915b0014f9953dd9bcf8aa951a0f75ff9becbeb24", size = 103161 } wheels = [ { url = "https://files.pythonhosted.org/packages/2e/35/7f8fcc9c23a504cf09e2795164eeb39a39ade1b2c7c8724ee207b2019ae6/etils-1.9.4-py3-none-any.whl", hash = "sha256:4387e7a4911a3b5cc4b92b99a9211386d176b43bae1dac8e2fe345fc2cb95e4b", size = 
164341 }, ] [package.optional-dependencies] epy = [ { name = "typing-extensions" }, ] [[package]] name = "jax" version = "0.4.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxlib" }, { name = "ml-dtypes" }, { name = "numpy" }, { name = "opt-einsum" }, { name = "scipy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/73/e4/c1a4c0e7dafbc53fff9f42f9c1bf5918dabd1f91325512d6b382bea8750b/jax-0.4.31.tar.gz", hash = "sha256:fd2d470643a0073d822737f0788f71391656af7e62cc5b2e7995ee390ceac287", size = 1743359 } wheels = [ { url = "https://files.pythonhosted.org/packages/7e/cf/5f51b43bd692e90585c0ef6e8d1b0db5d254fe0224a6570daa59a1be014f/jax-0.4.31-py3-none-any.whl", hash = "sha256:5688703735133d0dc537e99a1d646198a49c9d472d4715fde4bd437c44151bd7", size = 2038969 }, ] [[package]] name = "jaxlib" version = "0.4.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "ml-dtypes" }, { name = "numpy" }, { name = "scipy" }, ] wheels = [ { url = "https://files.pythonhosted.org/packages/fa/27/3eee15d1b168d434498c388780114d7629f715e19c2d08754ab4be82ad2d/jaxlib-0.4.31-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:185fb615ab6bd95315fbcbd951d84e71f9835d603db8c03c91faee98ce95ff4d", size = 88818529 }, { url = "https://files.pythonhosted.org/packages/68/cf/28895a4a89d88d18592507d7a35218b6bb2d8bced13615065c9f925f2ae1/jaxlib-0.4.31-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9f89c185287e40ee8173a7142d6495311e772cd139a93dca93f0d99c1872832", size = 70079551 }, { url = "https://files.pythonhosted.org/packages/e0/af/10b49f8de2acc7abc871478823579d7241be52ca0d6bb0d2b2c476cc1b68/jaxlib-0.4.31-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:4d867a1a0565b31cfdaabbec81e0302c6461bb2ac4b92c04670328d795819803", size = 73053401 }, { url = "https://files.pythonhosted.org/packages/b1/09/58d35465d48c8bee1d9a4e7a3c5db2edaabfc7ac94f4576c9f8c51b83e70/jaxlib-0.4.31-cp312-cp312-manylinux2014_x86_64.whl", hash = 
"sha256:1f1afa5fd58a60f67f0ca586e26714aece62eaa2c8334c24d0e8285afc4a7ccd", size = 88162291 }, { url = "https://files.pythonhosted.org/packages/c8/13/1bb2bcb4d9f4719dd5f3d98f5c2fc2235f961ced576366b040372eebdb17/jaxlib-0.4.31-cp312-cp312-win_amd64.whl", hash = "sha256:c4bfd15315e30525514b7262d555bea00745b09ac9818bb14c20ef8afbbab072", size = 56299104 }, ] [[package]] name = "ml-dtypes" version = "0.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/dd/50/17ab8a66d66bdf55ff6dea6fe2df424061cee65c6d772abc871bb563f91b/ml_dtypes-0.4.0.tar.gz", hash = "sha256:eaf197e72f4f7176a19fe3cb8b61846b38c6757607e7bf9cd4b1d84cd3e74deb", size = 692650 } wheels = [ { url = "https://files.pythonhosted.org/packages/30/9d/890e8c9cb556cec121f784fd84089e1e52939ba6eabf5dc62f6435db28d6/ml_dtypes-0.4.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:ee9f91d4c4f9959a7e1051c141dc565f39e54435618152219769e24f5e9a4d06", size = 394380 }, { url = "https://files.pythonhosted.org/packages/37/d5/3f3085b3a155e1b84c7fc680f05538d31cf01b835aa19cb17edd4994693f/ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad6849a2db386b38e4d54fe13eb3293464561780531a918f8ef4c8169170dd49", size = 2181698 }, { url = "https://files.pythonhosted.org/packages/8c/ef/5635b60d444db9c949b32d4e1a0a30b3ac237afbd71cce8bd1ccfb145723/ml_dtypes-0.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eaa32979ebfde3a0d7c947cafbf79edc1ec77ac05ad0780ee86c1d8df70f2259", size = 2158784 }, { url = "https://files.pythonhosted.org/packages/0f/b7/7cfca987ca898b64c0b7d185e957fbd8dccb64fe5ae9e44f68ec83371df5/ml_dtypes-0.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:3b67ec73a697c88c1122038e0de46520e48dc2ec876d42cf61bc5efe3c0b7675", size = 127498 }, ] [[package]] name = "numpy" version = "2.1.1" source = { registry = "https://pypi.org/simple" } sdist = { url = 
"https://files.pythonhosted.org/packages/59/5f/9003bb3e632f2b58f5e3a3378902dcc73c5518070736c6740fe52454e8e1/numpy-2.1.1.tar.gz", hash = "sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd", size = 18874860 } wheels = [ { url = "https://files.pythonhosted.org/packages/36/11/c573ef66c004f991989c2c6218229d9003164525549409aec5ec9afc0285/numpy-2.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e", size = 20884403 }, { url = "https://files.pythonhosted.org/packages/6b/6c/a9fbef5fd2f9685212af2a9e47485cde9357c3e303e079ccf85127516f2d/numpy-2.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe", size = 13493375 }, { url = "https://files.pythonhosted.org/packages/34/f2/1316a6b08ad4c161d793abe81ff7181e9ae2e357a5b06352a383b9f8e800/numpy-2.1.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f", size = 5088823 }, { url = "https://files.pythonhosted.org/packages/be/15/fabf78a6d4a10c250e87daf1cd901af05e71501380532ac508879cc46a7e/numpy-2.1.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521", size = 6619825 }, { url = "https://files.pythonhosted.org/packages/9f/8a/76ddef3e621541ddd6984bc24d256a4e3422d036790cbbe449e6cad439ee/numpy-2.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b", size = 13696705 }, { url = "https://files.pythonhosted.org/packages/cb/22/2b840d297183916a95847c11f82ae11e248fa98113490b2357f774651e1d/numpy-2.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201", size = 16041649 }, { url = 
"https://files.pythonhosted.org/packages/c7/e8/6f4825d8f576cfd5e4d6515b9eec22bd618868bdafc8a8c08b446dcb65f0/numpy-2.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a", size = 16409358 }, { url = "https://files.pythonhosted.org/packages/bf/f8/5edf1105b0dc24fd66fc3e9e7f3bca3d920cde571caaa4375ec1566073c3/numpy-2.1.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313", size = 14172488 }, { url = "https://files.pythonhosted.org/packages/f4/c2/dddca3e69a024d2f249a5b68698328163cbdafb7e65fbf6d36373bbabf12/numpy-2.1.1-cp312-cp312-win32.whl", hash = "sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed", size = 6237195 }, { url = "https://files.pythonhosted.org/packages/b7/98/5640a09daa3abf0caeaefa6e7bf0d10c0aa28a77c84e507d6a716e0e23df/numpy-2.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270", size = 12568082 }, { url = "https://files.pythonhosted.org/packages/6b/9e/8bc6f133bc6d359ccc9ec051853aded45504d217685191f31f46d36b7065/numpy-2.1.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5", size = 20834810 }, { url = "https://files.pythonhosted.org/packages/32/1b/429519a2fa28681814c511574017d35f3aab7136d554cc65f4c1526dfbf5/numpy-2.1.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5", size = 13507739 }, { url = "https://files.pythonhosted.org/packages/25/18/c732d7dd9896d11e4afcd487ac65e62f9fa0495563b7614eb850765361fa/numpy-2.1.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136", size = 5074465 }, { url = "https://files.pythonhosted.org/packages/3e/37/838b7ae9262c370ab25312bab365492016f11810ffc03ebebbd54670b669/numpy-2.1.1-cp313-cp313-macosx_14_0_x86_64.whl", 
hash = "sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0", size = 6606418 }, { url = "https://files.pythonhosted.org/packages/8b/b9/7ff3bfb71e316a5b43a124c4b7a5881ab12f3c32636014bef1f757f19dbd/numpy-2.1.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb", size = 13692464 }, { url = "https://files.pythonhosted.org/packages/42/78/75bcf16e6737cd196ff7ecf0e1fd3f953293a34dff4fd93fb488e8308536/numpy-2.1.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df", size = 16037763 }, { url = "https://files.pythonhosted.org/packages/23/99/36bf5ffe034d06df307bc783e25cf164775863166dcd878879559fe0379f/numpy-2.1.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78", size = 16410374 }, { url = "https://files.pythonhosted.org/packages/7f/16/04c5dab564887d4cd31a9ed30e51467fa70d52a4425f5a9bd1eed5b3d34c/numpy-2.1.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556", size = 14169873 }, { url = "https://files.pythonhosted.org/packages/09/e0/d1b5adbf1731886c4186c59a9fa208585df9452a43a2b60e79af7c649717/numpy-2.1.1-cp313-cp313-win32.whl", hash = "sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b", size = 6234118 }, { url = "https://files.pythonhosted.org/packages/d0/9c/2391ee6e9ebe77232ddcab29d92662b545e99d78c3eb3b4e26d59b9ca1ca/numpy-2.1.1-cp313-cp313-win_amd64.whl", hash = "sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0", size = 12561742 }, { url = "https://files.pythonhosted.org/packages/38/0e/c4f754f9e73f9bb520e8bf418c646f2c4f70c5d5f2bc561e90f884593193/numpy-2.1.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553", size = 20858403 }, { 
url = "https://files.pythonhosted.org/packages/32/fc/d69092b9171efa0cb8079577e71ce0cac0e08f917d33f6e99c916ed51d44/numpy-2.1.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480", size = 13519851 }, { url = "https://files.pythonhosted.org/packages/14/2a/d7cf2cd9f15b23f623075546ea64a2c367cab703338ca22aaaecf7e704df/numpy-2.1.1-cp313-cp313t-macosx_14_0_arm64.whl", hash = "sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f", size = 5115444 }, { url = "https://files.pythonhosted.org/packages/8e/00/e87b2cb4afcecca3b678deefb8fa53005d7054f3b5c39596e5554e5d98f8/numpy-2.1.1-cp313-cp313t-macosx_14_0_x86_64.whl", hash = "sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468", size = 6628903 }, { url = "https://files.pythonhosted.org/packages/ab/9d/337ae8721b3beec48c3413d71f2d44b2defbf3c6f7a85184fc18b7b61f4a/numpy-2.1.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef", size = 13665945 }, { url = "https://files.pythonhosted.org/packages/c0/90/ee8668e84c5d5cc080ef3beb622c016adf19ca3aa51afe9dbdcc6a9baf59/numpy-2.1.1-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f", size = 16023473 }, { url = "https://files.pythonhosted.org/packages/38/a0/57c24b2131879183051dc698fbb53fd43b77c3fa85b6e6311014f2bc2973/numpy-2.1.1-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c", size = 16400624 }, { url = "https://files.pythonhosted.org/packages/bb/4c/14a41eb5c9548c6cee6af0936eabfd985c69230ffa2f2598321431a9aa0a/numpy-2.1.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec", size = 14155072 }, ] [[package]] name = "opt-einsum" version = "3.3.0" source = { registry = 
"https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/7d/bf/9257e53a0e7715bc1127e15063e831f076723c6cd60985333a1c18878fb8/opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549", size = 73951 } wheels = [ { url = "https://files.pythonhosted.org/packages/bc/19/404708a7e54ad2798907210462fd950c3442ea51acc8790f3da48d2bee8b/opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147", size = 65486 }, ] [[package]] name = "optax" version = "0.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "absl-py" }, { name = "chex" }, { name = "etils", extra = ["epy"] }, { name = "jax" }, { name = "jaxlib" }, { name = "numpy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/d6/5f/e8b09028b37a8c1c159359e59469f3504b550910d472d8ee59543b1735d9/optax-0.2.3.tar.gz", hash = "sha256:ec7ab925440b0c5a512e1f24fba0fb3e7d760a7fd5d2496d7a691e9d37da01d9", size = 205212 } wheels = [ { url = "https://files.pythonhosted.org/packages/a3/8b/7032a6788205e9da398a8a33e1030ee9a22bd9289126e5afed9aac33bcde/optax-0.2.3-py3-none-any.whl", hash = "sha256:083e603dcd731d7e74d99f71c12f77937dd53f79001b4c09c290e4f47dd2e94f", size = 289647 }, ] [[package]] name = "optax-debug" version = "0.1.0" source = { virtual = "." 
} dependencies = [ { name = "jax" }, { name = "optax" }, ] [package.metadata] requires-dist = [ { name = "jax", specifier = ">=0.4.31" }, { name = "optax", specifier = ">=0.2.3" }, ] [[package]] name = "scipy" version = "1.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] sdist = { url = "https://files.pythonhosted.org/packages/62/11/4d44a1f274e002784e4dbdb81e0ea96d2de2d1045b2132d5af62cc31fd28/scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417", size = 58620554 } wheels = [ { url = "https://files.pythonhosted.org/packages/c0/04/2bdacc8ac6387b15db6faa40295f8bd25eccf33f1f13e68a72dc3c60a99e/scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d", size = 39128781 }, { url = "https://files.pythonhosted.org/packages/c8/53/35b4d41f5fd42f5781dbd0dd6c05d35ba8aa75c84ecddc7d44756cd8da2e/scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07", size = 29939542 }, { url = "https://files.pythonhosted.org/packages/66/67/6ef192e0e4d77b20cc33a01e743b00bc9e68fb83b88e06e636d2619a8767/scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5", size = 23148375 }, { url = "https://files.pythonhosted.org/packages/f6/32/3a6dedd51d68eb7b8e7dc7947d5d841bcb699f1bf4463639554986f4d782/scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc", size = 25578573 }, { url = "https://files.pythonhosted.org/packages/f0/5a/efa92a58dc3a2898705f1dc9dbaf390ca7d4fba26d6ab8cfffb0c72f656f/scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310", size = 35319299 }, { url = 
"https://files.pythonhosted.org/packages/8e/ee/8a26858ca517e9c64f84b4c7734b89bda8e63bec85c3d2f432d225bb1886/scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066", size = 40849331 }, { url = "https://files.pythonhosted.org/packages/a5/cd/06f72bc9187840f1c99e1a8750aad4216fc7dfdd7df46e6280add14b4822/scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1", size = 42544049 }, { url = "https://files.pythonhosted.org/packages/aa/7d/43ab67228ef98c6b5dd42ab386eae2d7877036970a0d7e3dd3eb47a0d530/scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f", size = 44521212 }, { url = "https://files.pythonhosted.org/packages/50/ef/ac98346db016ff18a6ad7626a35808f37074d25796fd0234c2bb0ed1e054/scipy-1.14.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79", size = 39091068 }, { url = "https://files.pythonhosted.org/packages/b9/cc/70948fe9f393b911b4251e96b55bbdeaa8cca41f37c26fd1df0232933b9e/scipy-1.14.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e", size = 29875417 }, { url = "https://files.pythonhosted.org/packages/3b/2e/35f549b7d231c1c9f9639f9ef49b815d816bf54dd050da5da1c11517a218/scipy-1.14.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73", size = 23084508 }, { url = "https://files.pythonhosted.org/packages/3f/d6/b028e3f3e59fae61fb8c0f450db732c43dd1d836223a589a8be9f6377203/scipy-1.14.1-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e", size = 25503364 }, { url = 
"https://files.pythonhosted.org/packages/a7/2f/6c142b352ac15967744d62b165537a965e95d557085db4beab2a11f7943b/scipy-1.14.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d", size = 35292639 }, { url = "https://files.pythonhosted.org/packages/56/46/2449e6e51e0d7c3575f289f6acb7f828938eaab8874dbccfeb0cd2b71a27/scipy-1.14.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e", size = 40798288 }, { url = "https://files.pythonhosted.org/packages/32/cd/9d86f7ed7f4497c9fd3e39f8918dd93d9f647ba80d7e34e4946c0c2d1a7c/scipy-1.14.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06", size = 42524647 }, { url = "https://files.pythonhosted.org/packages/f5/1b/6ee032251bf4cdb0cc50059374e86a9f076308c1512b61c4e003e241efb7/scipy-1.14.1-cp313-cp313-win_amd64.whl", hash = "sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84", size = 44469524 }, ] [[package]] name = "setuptools" version = "74.1.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/3e/2c/f0a538a2f91ce633a78daaeb34cbfb93a54bd2132a6de1f6cec028eee6ef/setuptools-74.1.2.tar.gz", hash = "sha256:95b40ed940a1c67eb70fc099094bd6e99c6ee7c23aa2306f4d2697ba7916f9c6", size = 1356467 } wheels = [ { url = "https://files.pythonhosted.org/packages/cb/9c/9ad11ac06b97e55ada655f8a6bea9d1d3f06e120b178cd578d80e558191d/setuptools-74.1.2-py3-none-any.whl", hash = "sha256:5f4c08aa4d3ebcb57a50c33b1b07e94315d7fc7230f7115e47fc99776c8ce308", size = 1262071 }, ] [[package]] name = "toolz" version = "0.12.1" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/3e/bf/5e12db234df984f6df3c7f12f1428aa680ba4e101f63f4b8b3f9e8d2e617/toolz-0.12.1.tar.gz", hash = 
"sha256:ecca342664893f177a13dac0e6b41cbd8ac25a358e5f215316d43e2100224f4d", size = 66550 } wheels = [ { url = "https://files.pythonhosted.org/packages/b7/8a/d82202c9f89eab30f9fc05380daae87d617e2ad11571ab23d7c13a29bb54/toolz-0.12.1-py3-none-any.whl", hash = "sha256:d22731364c07d72eea0a0ad45bafb2c2937ab6fd38a3507bf55eae8744aa7d85", size = 56121 }, ] [[package]] name = "typing-extensions" version = "4.12.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/df/db/f35a00659bc03fec321ba8bce9420de607a1d37f8342eee1863174c69557/typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8", size = 85321 } wheels = [ { url = "https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d", size = 37438 }, ]
optax version: 0.2.3
jax version: 0.4.31 (w/o cuda)
OS: Linux 5.15.0-119-generic #129-Ubuntu SMP Fri Aug 2 19:25:20 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux