raresdolga closed this issue 1 month ago.
If I go to "orbax/checkpoint/multihost/utils.py" and comment out the following check, it works.
Original:
def process_index() -> int:
  if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
    logging.info('Using distributed process id.')
    return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
  else:
    return jax.process_index()
New:
def process_index() -> int:
  # if EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value:
  #   logging.info('Using distributed process id.')
  #   return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
  # else:
  return jax.process_index()
You must have an older version of Orbax, because the process_index implementation doesn't look like that anymore. The latest version wraps the flag read in a try/except and falls back to False, so a problem with the flag no longer breaks anyone. Can you try updating the version? The new check is:
try:
  experimental_orbax_use_distributed_process_id = (
      EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.value
  )
except Exception:  # pylint: disable=broad-exception-caught
  logging.log_first_n(
      logging.INFO,
      '[thread=%s] Failed to get flag value for'
      ' EXPERIMENTAL_ORBAX_USE_DISTRIBUTED_PROCESS_ID.',
      1,
      threading.current_thread().name,
  )
  experimental_orbax_use_distributed_process_id = False
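To illustrate why this fallback helps, here is a minimal, self-contained sketch of the same pattern. It does not import Orbax, JAX, or absl. FakeFlag, UnparsedFlagError, read_flag_or_default and process_index_sketch are hypothetical stand-ins, and the two process ids are passed in as plain integers rather than read from jax._src.distributed.global_state.process_id or jax.process_index().

```python
import logging
import threading


class UnparsedFlagError(Exception):
    """Hypothetical stand-in for the error raised when a flag is read before parsing."""


class FakeFlag:
    """Hypothetical flag holder whose .value raises until the flag is 'parsed'."""

    def __init__(self, default: bool, parsed: bool) -> None:
        self._default = default
        self._parsed = parsed

    @property
    def value(self) -> bool:
        if not self._parsed:
            raise UnparsedFlagError("Flag accessed before flags were parsed.")
        return self._default


def read_flag_or_default(flag: FakeFlag, default: bool = False) -> bool:
    """Return flag.value, or `default` if the flag cannot be read."""
    try:
        return flag.value
    except Exception:  # Broad catch, mirroring the Orbax snippet above.
        logging.info(
            "[thread=%s] Failed to get flag value; using %s.",
            threading.current_thread().name,
            default,
        )
        return default


def process_index_sketch(
    flag: FakeFlag, distributed_process_id: int, jax_process_index: int
) -> int:
    """Hypothetical process_index: use the distributed id only if the flag says so."""
    if read_flag_or_default(flag):
        return distributed_process_id
    return jax_process_index


# An unparsed flag no longer raises; the sketch falls back to the JAX index.
print(process_index_sketch(FakeFlag(True, parsed=False), 3, 0))  # prints 0
print(process_index_sketch(FakeFlag(True, parsed=True), 3, 0))   # prints 3
```

With an unparsed flag, the helper returns False instead of raising, so the sketch uses the JAX process index. The Orbax snippet appears to aim for the same behavior.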
Hi,
Indeed, I updated and now it works.
Closing this issue. Thank you!
Hi, I am trying to save checkpoints using the following code:
The problem happens when I try to create the checkpoint folder: the code checks a flag to handle the multiprocessing case, but the flag has not been parsed. The error is:
It seems to be triggered by this line of code in the library:
I should note that I am running my code from a child process.
Do you have any suggestions on how to avoid this? I am not sure why the flag is not picked up, since it has a default value. I tried parsing the flags in my main file using args, but then I am asked to define my own arguments as flags as well, which seems like unnecessary overhead.
Originally posted by @raresdolga in https://github.com/google/orbax/discussions/962
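If updating Orbax is not an option, one possible workaround for the child-process case is to parse the flags inside the child before calling Orbax. The sketch passes only the program name, so your own command-line arguments do not have to be defined as flags. It assumes the flag is an absl flag and that absl-py is installed. ensure_absl_flags_parsed is a hypothetical helper name, meant to be called at the start of the child process's target function.

```python
import sys

from absl import flags


def ensure_absl_flags_parsed() -> None:
    """Hypothetical helper: parse absl flags using only the program name.

    Parsing just argv[0] lets every defined flag take its default value
    without requiring the script's own CLI arguments to be defined as flags.
    """
    if flags.FLAGS.is_parsed():
        return  # Already parsed (for example by app.run); nothing to do.
    program_name = sys.argv[0] if sys.argv else "python"
    # absl expects argv[0] to be the program name; no other args are passed.
    flags.FLAGS([program_name])
```

Because only argv[0] is parsed, every flag keeps its default value in the child. Any real command-line values for Orbax flags are not forwarded.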