arnon-1 opened this issue 3 weeks ago.
I would recommend using the latest version of JAX. It's possible that changes made to support newer versions of JAX have broken compatibility with older versions, e.g. the switch from jax.tree_util.tree_map to jax.tree.map. The 'source_info_util' error is strange to me, since we have a version check for that. Maybe your JAX installation is corrupted, or maybe our version check on that line is wrong.
Thank you for your answer. I would use the latest version if I could... I don't think the JAX library is the cause, since I got the same errors on multiple systems with different installation methods. Increasing the version check to 0.4.13 did seem to fix that error (I wasn't sure how much detail I was allowed to provide, since PRs have been explicitly disallowed). I guess the version bump resolves this issue.
> Increasing the version check to 0.4.13 indeed seemed to fix that error
Sorry, which error was this? The 'source_info_util' one?
How did you fix the other errors if you are still using 0.4.13?
For the first error: I copied the DTypeLike definition from a later JAX version:
```diff
+from jax._src.typing import SupportsDType
-DType = jax.typing.DTypeLike
+DType = Union[
+    str,  # like 'float32', 'int32'
+    type[Any],  # like np.float32, np.int32, float, int
+    np.dtype,  # like np.dtype('float32'), np.dtype('int32')
+    SupportsDType,  # like jnp.float32, jnp.int32
+]
```
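The backport imports SupportsDType from jax._src, which is a private module. A minimal sketch of a variant that avoids that import is shown below, assuming JAX's SupportsDType is just a protocol for objects that expose a dtype attribute. The to_numpy_dtype helper and the ExampleDTypeHolder class are hypothetical and only illustrate how a value of this union is typically consumed.

```python
from typing import Any, Protocol, Union

import numpy as np


class SupportsDType(Protocol):
    """Local stand-in for jax._src.typing.SupportsDType (assumed shape):
    any object exposing a `dtype` attribute, such as jnp.float32."""

    @property
    def dtype(self) -> np.dtype: ...


# Same alternatives as the backported union; `type[Any]` needs Python 3.9+.
DType = Union[
    str,            # like 'float32', 'int32'
    type[Any],      # like np.float32, np.int32, float, int
    np.dtype,       # like np.dtype('float32'), np.dtype('int32')
    SupportsDType,  # like jnp.float32, jnp.int32
]


def to_numpy_dtype(dtype: DType) -> np.dtype:
    """Hypothetical helper: normalize any DType value to a NumPy dtype.

    np.dtype() accepts strings, scalar types, dtype instances and objects
    with a `.dtype` attribute, so it covers every member of the union.
    """
    return np.dtype(dtype)


class ExampleDTypeHolder:
    """Hypothetical object satisfying SupportsDType, for illustration only."""

    dtype = np.dtype("float16")
```

For example, to_numpy_dtype(np.int32) returns dtype('int32') and to_numpy_dtype(ExampleDTypeHolder()) returns dtype('float16'), both through the same np.dtype call.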
For the second error: the jax.tree namespace is not available in 0.4.13, but most of its functions have equivalents in jax.tree_util. For example, I changed jax.tree.map to jax.tree_util.tree_map.
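Rather than hard-coding one spelling, the call site could pick whichever tree-map function the installed JAX provides. The sketch below shows one way to do that. resolve_tree_map is a hypothetical helper, not part of kfac_jax or JAX; it relies only on jax.tree.map (newer releases) and jax.tree_util.tree_map (present in 0.4.13), which take the same (f, tree, *rest) arguments.

```python
import types
from typing import Any, Callable


def resolve_tree_map(jax_module: types.ModuleType) -> Callable[..., Any]:
    """Return jax.tree.map when it exists, else jax.tree_util.tree_map.

    Hypothetical compatibility helper: newer JAX releases expose the
    jax.tree namespace, while older ones such as 0.4.13 only provide
    jax.tree_util. Both functions apply f leaf-wise to a pytree.
    """
    tree_namespace = getattr(jax_module, "tree", None)
    if tree_namespace is not None and hasattr(tree_namespace, "map"):
        return tree_namespace.map
    return jax_module.tree_util.tree_map
```

With JAX installed, tree_map = resolve_tree_map(jax) would be evaluated once at import time, and tree_map(lambda x: x * 2, params) then behaves the same on either version. Taking the module as an argument also makes the fallback easy to exercise with a stand-in object.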
For the third error: I updated the version check mentioned earlier, which seems to work:
```diff
- if jax_version > (0, 4, 11):
+ if jax_version > (0, 4, 13):
```
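Why the bump helps follows from how Python compares tuples: element by element, left to right, so (0, 4, 13) > (0, 4, 11) is true while (0, 4, 13) > (0, 4, 13) is false. The sketch below assumes jax_version is a tuple of integers derived from jax.__version__; parse_version and takes_newer_branch are hypothetical helpers for illustration, not kfac_jax's actual code.

```python
import re


def parse_version(version: str) -> tuple[int, ...]:
    """Hypothetical helper: '0.4.13' -> (0, 4, 13), '0.4.14.dev20230615' -> (0, 4, 14).

    Keeps the leading digits of the first three dot-separated components.
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def takes_newer_branch(jax_version: tuple[int, ...], threshold: tuple[int, ...]) -> bool:
    """Mirror of the `if jax_version > threshold:` check (strictly greater)."""
    return jax_version > threshold


installed = parse_version("0.4.13")
print(takes_newer_branch(installed, (0, 4, 11)))  # True: original threshold
print(takes_newer_branch(installed, (0, 4, 13)))  # False: bumped threshold
```

With the original threshold, 0.4.13 enters the guarded branch; after the bump it falls through to the other one, while 0.4.14 and later still take the guarded branch.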
Hello,
While trying to use kfac_jax with JAX 0.4.13 (which is supported, if I am not mistaken), I had to fix some errors. I installed commit a4531e90bf7d3cc46b423b01f61eaafeccb02fed, which is fairly recent.