google / orbax

Orbax provides common checkpointing and persistence utilities for JAX users
https://orbax.readthedocs.io/
Apache License 2.0

Hi, I am trying to save checkpoints using the following code: #1225

Closed: raresdolga closed this 1 month ago

raresdolga commented 1 month ago

Hi, I am trying to save checkpoints using the following code:

import os
import orbax.checkpoint as ocp

options = ocp.CheckpointManagerOptions(
    max_to_keep=self.max_checkpoints,
    create=True,
    best_fn=best_loss,
    best_mode="min",
)

self.checkpoint_manager = ocp.CheckpointManager(
    os.path.join(self._out_dir, "checkpoints"),
    options=options,
    item_names=("state", "metadata"),
    item_handlers={
        "state": ocp.StandardCheckpointHandler(),
        "metadata": ocp.JsonCheckpointHandler(),
    },
)

The problem happens when the manager tries to create the checkpoint directory. It checks for multi-process execution using an absl flag, but the flag has not been parsed. The error is:

absl.flags._exceptions.UnparsedFlagAccessError: Trying to access flag --experimental_orbax_use_distributed_process_id before flags were parsed.

It seems to be triggered by this check in process_index() in the library:

if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:

I should note that I am running my code from a child process.

Do you have any suggestions on how to avoid this? I am not sure why the flag is not picked up, since it has a default value. I tried parsing the flags from the main file using my args, but then absl requires me to define my own arguments as flags, which seems like unnecessary overhead.
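The error itself is standard absl behaviour: a flag that was not set on the command line cannot be read until its FlagValues registry has been marked as parsed, default value or not. The sketch below reproduces that and shows one common workaround, parsing with known_only=True so that script-specific arguments do not have to be declared as absl flags. It uses an isolated FlagValues object and a hypothetical flag name (demo_use_distributed_process_id) rather than Orbax's real flag. Applying the idea here would mean calling flags.FLAGS(sys.argv, known_only=True) (or flags.FLAGS.mark_as_parsed()) early in the child process, before creating the CheckpointManager; that is an untested assumption, and this thread was ultimately resolved by upgrading Orbax instead.

```python
from absl import flags


def make_demo_flags():
  """Builds an isolated flag registry with one hypothetical boolean flag.

  A fresh FlagValues object is used instead of the global absl.flags.FLAGS,
  so the demo does not affect (or depend on) real Orbax flags.
  """
  fv = flags.FlagValues()
  holder = flags.DEFINE_bool(
      'demo_use_distributed_process_id',  # hypothetical stand-in name
      False,
      'Hypothetical demo flag.',
      flag_values=fv,
  )
  return fv, holder


def read_flag(holder):
  """Returns the flag value, or the exception raised if flags are unparsed."""
  try:
    return holder.value
  except flags.UnparsedFlagAccessError as e:
    return e


def parse_known_flags(fv, argv):
  """Parses only flags registered in `fv` and returns the leftover argv.

  known_only=True passes unrecognized arguments (for example ones handled
  by argparse) through instead of raising UnrecognizedFlagError.
  """
  return fv(argv, known_only=True)


# Usage: reading before parsing fails even though the flag has a default.
fv, holder = make_demo_flags()
print(type(read_flag(holder)).__name__)  # UnparsedFlagAccessError
leftover = parse_known_flags(fv, ['prog', '--my_script_arg=1'])
print(read_flag(holder))  # False: the default is readable after parsing
```

The known_only route keeps argparse-style arguments untouched in the returned list, which avoids redefining every script argument as an absl flag.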

Originally posted by @raresdolga in https://github.com/google/orbax/discussions/962

raresdolga commented 1 month ago

If I go to "orbax/checkpoint/multihost/utils.py" and comment out the following check, it works.

Original:

def process_index() -> int:
  if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
    logging.info('Using distributed process id.')
    return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
  else:
    return jax.process_index()

New:

def process_index() -> int:
  # if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
  #   logging.info('Using distributed process id.')
  #   return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
  # else:
  return jax.process_index()

cpgaffney1 commented 1 month ago

You must have an older version of Orbax, because the process_index implementation doesn't look like that anymore. The latest version has a try/except around the flag read so that a problem with the flag cannot block anyone. Can you try updating Orbax?

try:
  experimental_orbax_use_distributed_process_id = (
      EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value
  )
except Exception:  # pylint: disable=broad-exception-caught
  logging.log_first_n(
      logging.INFO,
      '[thread=%s] Failed to get flag value for'
      ' EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.',
      1,
      threading.current_thread().name,
  )
  experimental_orbax_use_distributed_process_id = False
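The guarded read quoted above is a general pattern: try the flag read and fall back to a default when absl flags were never parsed, which can happen, for example, in a child process where absl.app.run was not called. The following is a minimal standalone sketch of that pattern, not Orbax's code; the helper names and the flag demo_guarded_flag are hypothetical, and the flag lives in an isolated FlagValues registry.

```python
import threading

from absl import flags
from absl import logging


def make_demo_flag():
  """Creates an isolated registry holding one hypothetical boolean flag."""
  fv = flags.FlagValues()
  holder = flags.DEFINE_bool(
      'demo_guarded_flag', False, 'Hypothetical demo flag.', flag_values=fv
  )
  return fv, holder


def read_flag_or_default(holder, default=False):
  """Reads a flag, falling back to `default` if flags were never parsed."""
  try:
    return holder.value
  except flags.UnparsedFlagAccessError:
    # Log only once so repeated calls from many threads do not flood logs.
    logging.log_first_n(
        logging.INFO,
        '[thread=%s] Flags not parsed; using default for %s.',
        1,
        threading.current_thread().name,
        holder.name,
    )
    return default


fv, holder = make_demo_flag()
print(read_flag_or_default(holder))  # False: unparsed, so default is used
fv(['prog', '--demo_guarded_flag'])
print(read_flag_or_default(holder))  # True: value from the parsed argv
```

One design difference: this sketch catches only UnparsedFlagAccessError, so unrelated bugs still surface, whereas the library snippet catches any Exception to guarantee the flag can never block checkpointing.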

raresdolga commented 1 month ago

Hi,

Indeed, after updating, it works now.

Closing this issue. Thank you!