jakevdp opened 2 weeks ago
Caught by the jax-ai-stack nightly tests: https://github.com/jax-ml/jax-ai-stack/issues/90#issuecomment-2466190697
Package versions:
```
$ pip list | grep orbax
orbax-checkpoint 0.9.0
orbax-export 0.0.5
```
Repro:
```python
import jax.numpy as jnp
from orbax.export import JaxModule

params = {'a': jnp.array(5.0), 'b': jnp.array(1.1), 'c': jnp.array(0.55)}

def model_fn(params, inputs):
  a, b, c = params['a'], params['b'], params['c']
  return a * jnp.sin(inputs) + b * inputs + c

jax_module = JaxModule(params, model_fn)
```
```
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-3dff93cebc33> in <cell line: 10>()
      8     return a * jnp.sin(inputs) + b * inputs + c
      9 
---> 10 jax_module = JaxModule(params, model_fn)

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in __init__(self, params, apply_fn, trainable, input_polymorphic_shape, jax2tf_kwargs, jit_compile, name, pspecs, allow_multi_axis_sharding_conslidation)
    190       self._methods = dict()
    191     else:
--> 192       tf_vars = _jax_params_to_tf_variables(
    193           params, trainable, pspecs, allow_multi_axis_sharding_conslidation
    194       )

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in _jax_params_to_tf_variables(params, trainable, pspecs, allow_multi_axis_sharding_conslidation)
    393   )
    394 
--> 395   names = _get_param_names(params)
    396   if pspecs is None:
    397     pspecs = jax.tree_util.tree_map(lambda x: None, params)

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in _get_param_names(params)
    317     return name.replace('~', '_')
    318 
--> 319   names = jax.tree_util.tree_map_with_path(
    320       lambda kp, _: _param_name_from_keypath(kp), params
    321   )

/usr/local/lib/python3.10/dist-packages/jax/_src/tree_util.py in tree_map_with_path(f, tree, is_leaf, *rest)
   1198   keypath_leaves = list(zip(*keypath_leaves))
   1199   all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1200   return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
   1201 
   1202 

/usr/local/lib/python3.10/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
   1198   keypath_leaves = list(zip(*keypath_leaves))
   1199   all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1200   return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
   1201 
   1202 

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in <lambda>(kp, _)
    318 
    319   names = jax.tree_util.tree_map_with_path(
--> 320       lambda kp, _: _param_name_from_keypath(kp), params
    321   )
    322 

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in _param_name_from_keypath(keypath)
    309 def _param_name_from_keypath(keypath: Tuple[Any, ...]) -> str:
    310   if hasattr(ocp, 'tree'):
--> 311     get_key_name = ocp.tree.get_key_name
    312   else:
    313     get_key_name = ocp.utils.get_key_name

AttributeError: module 'orbax.checkpoint.tree' has no attribute 'get_key_name'
```
I think orbax-export just needs a version update. (b/378500084 for your reference, Jake).