google-deepmind / kfac-jax

Second Order Optimization and Curvature Estimation with K-FAC in JAX.
Apache License 2.0

Bug: Jax 0.4.13 support #279

Open arnon-1 opened 3 weeks ago

arnon-1 commented 3 weeks ago

Hello,

While trying to use kfac_jax with JAX 0.4.13 (which is supported, if I am not mistaken), I had to fix several errors. I installed commit a4531e90bf7d3cc46b423b01f61eaafeccb02fed, which is fairly recent.

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/utils/types.py", line 27, in <module>
    DType = jax.typing.DTypeLike
AttributeError: module 'jax.typing' has no attribute 'DTypeLike'

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/curvature_blocks/curvature_block.py", line 123, in parameters_shapes
    return tuple(jax.tree.map(
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/jax/_src/deprecations.py", line 53, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax' has no attribute 'tree'

  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tracer.py", line 790, in forward
    write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
  File "/opt/miniconda3/envs/jax/lib/python3.10/site-packages/kfac_jax/_src/tag_graph_matcher.py", line 68, in eval_jaxpr_eqn
    user_context = jax_extend.source_info_util.user_context
AttributeError: module 'jax.extend' has no attribute 'source_info_util'
james-martens commented 5 days ago

I would recommend using the latest version of JAX. It's possible that changes made to support newer versions of JAX have broken compatibility with older versions, e.g. the switch from jax.tree_util.tree_map to jax.tree.map. The 'source_info_util' error is strange to me, since we have a version check for that. Maybe you have a corrupted installation of the JAX library, or maybe our version check on that line is wrong.
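An alternative to a version check is feature detection: look for user_context where newer releases expose it and fall back otherwise. This is a minimal sketch under assumptions. load_user_context is a hypothetical helper, not part of kfac_jax. The sketch assumes that the public jax.extend.source_info_util path exists only in newer releases and that the private jax._src.source_info_util module, which can change without notice, is present in older ones.

```python
import importlib


def load_user_context():
  """Hypothetical helper: returns source_info_util.user_context from this JAX install."""
  # Try the public path first (assumed to exist only in newer JAX releases),
  # then the private module that older releases ship.
  for name in ("jax.extend.source_info_util", "jax._src.source_info_util"):
    try:
      module = importlib.import_module(name)
    except ImportError:
      continue  # Module path not available in this JAX version.
    if hasattr(module, "user_context"):
      return module.user_context
  raise ImportError("no source_info_util.user_context found in this JAX install")


user_context = load_user_context()
```

Because it probes for the attribute itself, this does not depend on knowing the exact release in which the public path appeared.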

arnon-1 commented 4 days ago

Thank you for your answer. I would use the latest version if I could... I don't think the JAX installation is the problem, as I received the same errors on multiple systems with different installation methods. Increasing the version check to 0.4.13 indeed seemed to fix that error (I wasn't sure how much information I was allowed to provide, since PRs have been explicitly disallowed). I guess the version bump solves this issue.

james-martens commented 4 days ago

Increasing the version check to 0.4.13 indeed seemed to fix that error

Sorry, which error was this? The 'source_info_util' one?

How did you fix the other errors if you are still using 0.4.13?

arnon-1 commented 4 days ago

For the first error: I copied the DTypeLike definition from a later JAX version.

+from jax._src.typing import SupportsDType
-DType = jax.typing.DTypeLike
+DType = Union[
+  str,            # like 'float32', 'int32'
+  type[Any],      # like np.float32, np.int32, float, int
+  np.dtype,       # like np.dtype('float32'), np.dtype('int32')
+  SupportsDType,  # like jnp.float32, jnp.int32
+]
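Instead of replacing the alias outright, a patch could keep the public alias when it exists and fall back to the copied definition only on older releases. This is a minimal sketch, assuming Python 3.9+ (for type[Any]). _SupportsDType and as_numpy_dtype are hypothetical local names. _SupportsDType stands in for the private jax._src.typing.SupportsDType, so the fallback does not depend on JAX internals.

```python
from typing import Any, Protocol, Union

import jax
import jax.typing  # Imported explicitly in case `import jax` does not load it.
import numpy as np


class _SupportsDType(Protocol):
  """Hypothetical stand-in for JAX's private SupportsDType protocol."""

  @property
  def dtype(self) -> np.dtype: ...


if hasattr(jax.typing, "DTypeLike"):
  # Newer JAX releases export the alias publicly.
  DType = jax.typing.DTypeLike
else:
  # Older releases (e.g. 0.4.13, per this report): mirror the later definition.
  DType = Union[
      str,             # like 'float32', 'int32'
      type[Any],       # like np.float32, np.int32, float, int
      np.dtype,        # like np.dtype('float32'), np.dtype('int32')
      _SupportsDType,  # like jnp.float32, jnp.int32
  ]


def as_numpy_dtype(dtype: DType) -> np.dtype:
  """Hypothetical helper: normalizes a DType-like value to a NumPy dtype."""
  return np.dtype(dtype)
```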

For the second error: most of the functionality in jax.tree used to live in jax.tree_util (now a legacy API). For example, I changed jax.tree.map to jax.tree_util.tree_map.
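The same rename can also be handled with a small shim that picks jax.tree.map when that namespace exists and otherwise uses jax.tree_util.tree_map, so one code path runs on both old and new JAX. This is a sketch under that assumption. tree_map here is a hypothetical module-level alias, not something kfac_jax defines.

```python
import jax
import jax.numpy as jnp

# On older JAX (e.g. 0.4.13), accessing jax.tree raises AttributeError, so
# hasattr returns False and the jax.tree_util function is used instead.
if hasattr(jax, "tree") and hasattr(jax.tree, "map"):
  tree_map = jax.tree.map
else:
  tree_map = jax.tree_util.tree_map

# Usage: collect the shape of every leaf in a parameter pytree.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}
shapes = tree_map(lambda x: x.shape, params)
```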

For the third error: I updated the previously mentioned version check; this seems to work:

-  if jax_version > (0, 4, 11):
+  if jax_version > (0, 4, 13):
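The check compares tuples of integers, so raising the threshold from (0, 4, 11) to (0, 4, 13) changes which branch 0.4.13 takes. The sketch below shows one way such a tuple might be built and compared. parse_version is a hypothetical helper, and this thread does not show how kfac_jax actually computes jax_version.

```python
import re

import jax


def parse_version(version: str) -> tuple[int, ...]:
  """Hypothetical helper: '0.4.14.dev20230701' -> (0, 4, 14)."""
  parts = []
  # Keep the leading digits of the first three dot-separated components.
  for piece in version.split(".")[:3]:
    match = re.match(r"\d+", piece)
    if match is None:
      break
    parts.append(int(match.group()))
  return tuple(parts)


jax_version = parse_version(jax.__version__)

# Tuples compare element by element: 0.4.13 passes the old threshold
# (> 0.4.11) but not the bumped one (> 0.4.13), so it takes the other branch.
old_check = parse_version("0.4.13") > (0, 4, 11)
new_check = parse_version("0.4.13") > (0, 4, 13)
```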