google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

`save_checkpoint` fails with the most recent orbax release #4015

Closed apaleyes closed 1 week ago

apaleyes commented 1 week ago

Hello, flax team!

In the past few days we have observed calls to `save_checkpoint` failing with the most recent orbax release (0.5.17). After downgrading to orbax-checkpoint==0.5.16 everything works again.

The reproduction example comes from the Flax docs; for convenience it is copied below.

With orbax-checkpoint==0.5.17 this code fails with an exception. With orbax-checkpoint==0.5.16 it works.
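To double-check which orbax-checkpoint release is actually installed before running the reproduction, here is a quick sketch (standard library only; not part of the original docs example):

# Quick check of the installed orbax-checkpoint release.
from importlib.metadata import version

print(version("orbax-checkpoint"))  # e.g. '0.5.17' (failing) or '0.5.16' (working)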

System information

Logs, error messages, etc:

Traceback (most recent call last):
  File "/tmp/min_example.py", line 46, in <module>
    checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
  File "/tmp/.venv/lib/python3.11/site-packages/flax/training/checkpoints.py", line 697, in save_checkpoint
    orbax_checkpointer.save(
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 165, in save
    tmpdir = utils.create_tmp_directory(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/path/step.py", line 607, in create_tmp_directory
    if multihost.is_primary_host(primary_host):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 246, in is_primary_host
    if primary_host is None or primary_host == process_index():
                                               ^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/orbax/checkpoint/multihost/utils.py", line 252, in process_index
    if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/absl/flags/_flagvalues.py", line 1426, in value
    val = getattr(self._flagvalues, self._name)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/.venv/lib/python3.11/site-packages/absl/flags/_flagvalues.py", line 498, in __getattr__
    raise _exceptions.UnparsedFlagAccessError(
absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --experimental_orbax_use_distributed_process_id before flags were parsed.
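The exception comes from orbax reading the --experimental_orbax_use_distributed_process_id absl flag before absl has parsed its flags. As a possible stopgap (my assumption, not something confirmed by the orbax or flax maintainers), explicitly marking absl flags as parsed before calling save_checkpoint may avoid the UnparsedFlagAccessError:

# Possible workaround sketch (unconfirmed assumption): mark absl flags as
# parsed so orbax can read its experimental flag without raising.
from absl import flags

if not flags.FLAGS.is_parsed():
    flags.FLAGS.mark_as_parsed()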

Steps to reproduce:

import os
from typing import Optional, Any
import shutil

import numpy as np
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn
from flax.training import checkpoints, train_state
from flax import struct, serialization

import optax

ckpt_dir = '/tmp/flax_ckpt'

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

key1, key2 = random.split(random.key(0))
x1 = random.normal(key1, (5,))      # A simple JAX array.
model = nn.Dense(features=3)
variables = model.init(key2, x1)

# Flax's TrainState is a pytree dataclass and is supported in checkpointing.
# Define your class with `@flax.struct.dataclass` decorator to make it compatible.
tx = optax.sgd(learning_rate=0.001)      # An Optax SGD optimizer.
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=tx)
# Perform a simple gradient update similar to the one during a normal training workflow.
state = state.apply_gradients(grads=jax.tree_util.tree_map(jnp.ones_like, state.params))

# Some arbitrary nested pytree with a dictionary and a NumPy array.
config = {'dimensions': np.array([5, 3])}

# Bundle everything together.
ckpt = {'model': state, 'config': config, 'data': [x1]}

checkpoints.save_checkpoint(ckpt_dir='/tmp/flax_ckpt/flax-checkpointing',
                            target=ckpt,
                            step=0,
                            overwrite=True,
                            keep=2)
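Not part of the failing example, but for completeness: when the save succeeds (i.e. with orbax-checkpoint==0.5.16), the checkpoint can be read back with restore_checkpoint to confirm the round trip:

# Completeness check (only meaningful when the save above succeeds, e.g. on
# orbax-checkpoint==0.5.16): restore the raw checkpoint as nested dicts.
raw_restored = checkpoints.restore_checkpoint(
    ckpt_dir='/tmp/flax_ckpt/flax-checkpointing', target=None)
print(raw_restored['config']['dimensions'])  # [5 3]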
IvyZX commented 1 week ago

They pushed some new releases today and it's at 0.5.19 now, which should fix this error.

apaleyes commented 1 week ago

Thanks @IvyZX! It's getting quite scary... it's 0.5.20 now! I guess it's probably best to wait for things to stabilise.