google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

`checkpoints.restore_checkpoint` should raise an error for explicit paths that don't exist #1631

Open jheek opened 2 years ago

jheek commented 2 years ago

Discussed in https://github.com/google/flax/discussions/1612

Originally posted by **PgLoLo** October 9, 2021

From the documentation of `flax.training.checkpoints.restore_checkpoint`:

> Returns:
>   Restored `target` updated from checkpoint file, or if no step specified and
>   no checkpoint files present, returns the passed-in `target` unchanged.
>   If a file path is specified and is not found, the passed-in `target` will be
>   returned. This is to match the behavior of the case where a directory path
>   is specified but the directory has not yet been created.

Why is silently hiding the absence of a checkpoint a good idea? It is a great potential source of bugs: a misspecified path, a misspecified step, etc. It makes no sense to me. Is there any logic behind this decision?
jheek commented 2 years ago

I think the difficulty here is that it's hard to distinguish between a nonexistent checkpoint file and a nonexistent checkpoint directory. In the latter case, we probably don't want to raise an error.
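One way around this ambiguity is to make the caller state their intent explicitly. The sketch below assumes a local filesystem (it uses `os.path`, not the `gfile` layer Flax uses for remote storage), and `check_checkpoint_path` and its `allow_missing_dir` flag are hypothetical names, not part of Flax. A missing path is treated as an error unless the caller opts in to the "directory not created yet" case.

```python
import os


def check_checkpoint_path(path: str, allow_missing_dir: bool = False) -> str:
    """Classify `path` before restoring (hypothetical helper, not a Flax API).

    Returns "file", "dir", or "missing_dir"; raises FileNotFoundError when
    nothing exists at `path` and the caller has not opted in to the
    "directory not created yet" case.
    """
    if os.path.isfile(path):
        return "file"
    if os.path.isdir(path):
        return "dir"
    # Nothing exists at `path`: on disk, a missing file and a missing
    # directory look identical, so only the caller can tell them apart.
    if allow_missing_dir:
        return "missing_dir"
    raise FileNotFoundError(f"No checkpoint file or directory at {path}")
```

A caller would run this before restoring and pass `allow_missing_dir=True` only on a fresh run whose checkpoint directory is expected not to exist yet; every other missing path then fails loudly instead of silently returning the initial state.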

nalzok commented 2 years ago

I think it makes sense to return an additional value indicating whether the restoration actually took place. Silent failures are the scariest.

EDIT: I just realized that we can use the following pattern to check whether the restoration succeeded, since `restore_checkpoint` guarantees that "if no step specified and no checkpoint files present, returns the passed-in target unchanged".

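```python
from flax.training import checkpoints

state = create_train_state(...)  # your own function that builds the initial state
restored = checkpoints.restore_checkpoint(ckpt_dir, state)
# When nothing is found, the very same `state` object is returned unchanged.
if state is restored:
    raise FileNotFoundError(f"Cannot load checkpoint from {ckpt_dir}")
```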
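The "additional value" idea can be packaged on top of that identity check without changing Flax itself. The wrapper below is a hypothetical sketch: `restore_with_status` is not a Flax API, and it assumes the restore function hands back the passed-in `target` object itself when no checkpoint is found (as the quoted docstring says) and builds a new object whenever it actually loads one.

```python
from typing import Any, Callable, Tuple


def restore_with_status(
    restore_fn: Callable[[str, Any], Any], ckpt_dir: str, target: Any
) -> Tuple[Any, bool]:
    """Return (state, restored) instead of silently falling back.

    Hypothetical wrapper, not a Flax API. `restore_fn` is assumed to behave
    like restore_checkpoint(ckpt_dir, target) and to return `target` itself
    when no checkpoint was found.
    """
    result = restore_fn(ckpt_dir, target)
    # Identity rather than equality: a loaded state could compare equal to
    # the initial one, so `is` is what detects the fall-back path.
    return result, result is not target
```

With Flax this would be called as `restore_with_status(checkpoints.restore_checkpoint, ckpt_dir, state)`, and the returned boolean replaces the manual `is` check in the snippet above.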