In the past few days, we have seen calls to save_checkpoint fail with the most recent Orbax release (orbax-checkpoint 0.5.17). Downgrading to orbax-checkpoint==0.5.16 makes everything work again.
The reproduction example is taken from the Flax docs; for convenience, it is copied below.
With orbax-checkpoint==0.5.17, this code fails with absl.flags._exceptions.UnparsedFlagAccessError; with orbax-checkpoint==0.5.16, it works.
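Until the regression is resolved, a script can check which orbax-checkpoint release is installed before it tries to save. The following is a minimal sketch using the standard library's importlib.metadata; the AFFECTED_ORBAX_VERSIONS set and the orbax_checkpoint_is_affected function are hypothetical names, and only 0.5.17 is listed because that is the only release this report identifies as failing.

```python
from importlib import metadata

# Hypothetical guard list: the only release this report identifies as failing.
AFFECTED_ORBAX_VERSIONS = {"0.5.17"}


def orbax_checkpoint_is_affected() -> bool:
    """Return True if the installed orbax-checkpoint is a known-bad release."""
    try:
        installed = metadata.version("orbax-checkpoint")
    except metadata.PackageNotFoundError:
        # Package not installed: nothing to check.
        return False
    return installed in AFFECTED_ORBAX_VERSIONS


if orbax_checkpoint_is_affected():
    print("orbax-checkpoint 0.5.17 detected; consider: pip install orbax-checkpoint==0.5.16")
```

This only reports the problem; the downgrade itself is still the pip pin described above.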
System information
OS Platform and Distribution (e.g., Linux Ubuntu 16.04): seen on both Ubuntu and macOS
Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): flax 0.8.4, jax 0.4.30, jaxlib 0.4.30
Python version: 3.11
Logs, error messages, etc.:
Traceback (most recent call last):
  File "/tmp/min_example.py", line 46, in <module>
    checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
  File "/tmp/.venv/lib/python3.11/site-packages/flax/training/checkpoints.py", line 697, in save_checkpoint
    orbax_checkpointer.save(
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 165, in save
    tmpdir = utils.create_tmp_directory(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/path/step.py", line 607, in create_tmp_directory
    if multihost.is_primary_host(primary_host):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 246, in is_primary_host
    if primary_host is None or primary_host == process_index():
                                               ^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 252, in process_index
    if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/absl/flags/_flagvalues.py", line 1426, in value
    val = getattr(self._flagvalues, self._name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
    raise _exceptions.UnparsedFlagAccessError(
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --experimental_orbax_use_distributed_process_id before flags were parsed.
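The traceback shows orbax-checkpoint 0.5.17 reading an absl flag (--experimental_orbax_use_distributed_process_id) inside process_index(), and absl raising UnparsedFlagAccessError because a plain script or notebook never parses flags (that normally happens inside absl.app.run). A possible workaround, untested against orbax itself and assuming the flag is registered on absl's global flags.FLAGS, is to mark the flags as parsed before calling save_checkpoint so that every flag reports its default value. The sketch below reproduces the mechanism with a hypothetical stand-in flag, demo_use_distributed_process_id, instead of importing orbax.

```python
from absl import flags

# Hypothetical stand-in for the boolean flag that orbax-checkpoint 0.5.17
# registers on absl's global FLAGS (--experimental_orbax_use_distributed_process_id).
_DEMO_FLAG = flags.DEFINE_bool(
    "demo_use_distributed_process_id", False, "Stand-in for the orbax flag."
)


def flag_read_raises() -> bool:
    """Return True if reading the flag raises UnparsedFlagAccessError."""
    try:
        _DEMO_FLAG.value
    except flags.UnparsedFlagAccessError:
        return True
    return False


# No app.run() here, so absl has not parsed flags yet: this read fails.
raised_before = flag_read_raises()

# Workaround sketch: treat flags as parsed so reads return their defaults.
flags.FLAGS.mark_as_parsed()
value_after = _DEMO_FLAG.value

print(raised_before, value_after)
```

Applied to the reproduction script, this would amount to adding from absl import flags and flags.FLAGS.mark_as_parsed() before the save_checkpoint call; pinning orbax-checkpoint==0.5.16 remains the fix confirmed in this report.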
Steps to reproduce:
import os
from typing import Optional, Any
import shutil
import numpy as np
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization
import optax
ckpt_dir = '/tmp/flax_ckpt'
if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir) # Remove any existing checkpoints from the last notebook run.
key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,)) # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)
# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001) # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))
# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}
# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}
# Import Flax Checkpoints.
from flax.training import checkpoints
checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=0,
                            overwrite=True,
                            keep=2)