google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

Latest orbax-export release is incompatible with latest orbax-checkpoint release #1314

Open jakevdp opened 2 weeks ago

jakevdp commented 2 weeks ago

Caught by the jax-ai-stack nightly tests: https://github.com/jax-ml/jax-ai-stack/issues/90#issuecomment-2466190697

Package versions:

$ pip list | grep orbax
orbax-checkpoint                   0.9.0
orbax-export                       0.0.5
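The same version information can also be read from inside the Python session that runs the repro. This is a minimal sketch using the standard-library `importlib.metadata`; the helper name `installed_versions` is hypothetical and not part of Orbax.

```python
from importlib import metadata


def installed_versions(packages=('orbax-checkpoint', 'orbax-export')):
  """Map each distribution name to its installed version, or None if absent."""
  versions = {}
  for name in packages:
    try:
      versions[name] = metadata.version(name)
    except metadata.PackageNotFoundError:
      # The distribution is not installed in this environment.
      versions[name] = None
  return versions


print(installed_versions())
```

Reading the versions in-process avoids a mismatch when `pip` on the shell path belongs to a different environment than the running interpreter.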

Repro:

import jax.numpy as jnp
from orbax.export import JaxModule

params = {'a': jnp.array(5.0), 'b': jnp.array(1.1), 'c': jnp.array(0.55)}

def model_fn(params, inputs):
  a, b, c = params['a'], params['b'], params['c']
  return a * jnp.sin(inputs) + b * inputs + c

jax_module = JaxModule(params, model_fn)
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-3dff93cebc33> in <cell line: 10>()
      8   return a * jnp.sin(inputs) + b * inputs + c
      9 
---> 10 jax_module = JaxModule(params, model_fn)

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in __init__(self, params, apply_fn, trainable, input_polymorphic_shape, jax2tf_kwargs, jit_compile, name, pspecs, allow_multi_axis_sharding_conslidation)
    190       self._methods = dict()
    191     else:
--> 192       tf_vars = _jax_params_to_tf_variables(
    193           params, trainable, pspecs, allow_multi_axis_sharding_conslidation
    194       )

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in _jax_params_to_tf_variables(params, trainable, pspecs, allow_multi_axis_sharding_conslidation)
    393       )
    394 
--> 395   names = _get_param_names(params)
    396   if pspecs is None:
    397     pspecs = jax.tree_util.tree_map(lambda x: None, params)

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in _get_param_names(params)
    317     return name.replace('~', '_')
    318 
--> 319   names = jax.tree_util.tree_map_with_path(
    320       lambda kp, _: _param_name_from_keypath(kp), params
    321   )

/usr/local/lib/python3.10/dist-packages/jax/_src/tree_util.py in tree_map_with_path(f, tree, is_leaf, *rest)
   1198   keypath_leaves = list(zip(*keypath_leaves))
   1199   all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1200   return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
   1201 
   1202 

/usr/local/lib/python3.10/dist-packages/jax/_src/tree_util.py in <genexpr>(.0)
   1198   keypath_leaves = list(zip(*keypath_leaves))
   1199   all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
-> 1200   return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
   1201 
   1202 

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in <lambda>(kp, _)
    318 
    319   names = jax.tree_util.tree_map_with_path(
--> 320       lambda kp, _: _param_name_from_keypath(kp), params
    321   )
    322 

/usr/local/lib/python3.10/dist-packages/orbax/export/jax_module.py in _param_name_from_keypath(keypath)
    309   def _param_name_from_keypath(keypath: Tuple[Any, ...]) -> str:
    310     if hasattr(ocp, 'tree'):
--> 311       get_key_name = ocp.tree.get_key_name
    312     else:
    313       get_key_name = ocp.utils.get_key_name

AttributeError: module 'orbax.checkpoint.tree' has no attribute 'get_key_name'
cpgaffney1 commented 2 weeks ago

I think orbax-export just needs a version update. (b/378500084 for your reference, Jake).
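
For readers hitting the same traceback: the failing frame shows that `hasattr(ocp, 'tree')` passes while `ocp.tree` has no `get_key_name`, so the fallback branch is never reached. Below is a minimal sketch of a more defensive lookup that probes for the function itself rather than for the submodule. The helper name `resolve_get_key_name` and the stand-in `fake_ocp` object are hypothetical, and nothing in this thread confirms whether `ocp.utils.get_key_name` is still available in orbax-checkpoint 0.9.0. This is not the fix the maintainers applied.

```python
from types import SimpleNamespace


def resolve_get_key_name(ocp):
  """Return the first callable `get_key_name` found on `ocp.tree` or `ocp.utils`."""
  for submodule_name in ('tree', 'utils'):
    submodule = getattr(ocp, submodule_name, None)
    fn = getattr(submodule, 'get_key_name', None)
    if callable(fn):
      return fn
  raise AttributeError(
      'get_key_name not found in ocp.tree or ocp.utils; the installed '
      'orbax-checkpoint and orbax-export releases may be incompatible.'
  )


# Hypothetical stand-in for a checkpoint package whose `tree` submodule
# exists but lacks the helper, mirroring the situation in the traceback.
fake_ocp = SimpleNamespace(
    tree=SimpleNamespace(),
    utils=SimpleNamespace(get_key_name=lambda key: str(key)),
)
get_key_name = resolve_get_key_name(fake_ocp)
```

Checking for the attribute that is actually used, instead of its parent module, keeps the fallback reachable when a submodule is present but its contents have moved.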