[Open] jheek opened 2 years ago
I think the difficulty here is that it's hard to distinguish between a non-existent checkpoint file and a non-existent checkpoint directory. In the latter case we probably don't want to raise an error.
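As an illustration of that distinction, a small helper can tell a missing directory apart from a directory that exists but holds no checkpoint yet. This is a sketch, not part of Flax: the name checkpoint_status is hypothetical, and it assumes checkpoints are entries directly inside ckpt_dir whose names start with a prefix ("checkpoint_" here, taken as an assumption about the default naming).

```python
from pathlib import Path


def checkpoint_status(ckpt_dir, prefix="checkpoint_"):
    """Hypothetical helper (not a Flax API) that classifies a checkpoint directory.

    Returns one of:
      "missing_dir"     - the directory itself does not exist
      "no_checkpoints"  - the directory exists but has no entry starting with `prefix`
      "has_checkpoints" - at least one entry starting with `prefix` is present
    """
    path = Path(ckpt_dir)
    if not path.is_dir():
        return "missing_dir"
    # Assumption: any entry whose name starts with `prefix` counts as a checkpoint;
    # a stricter check could also parse the step number after the prefix.
    if any(entry.name.startswith(prefix) for entry in path.iterdir()):
        return "has_checkpoints"
    return "no_checkpoints"
```

With this split, a caller can decide that "missing_dir" means a fresh run while "no_checkpoints" in an existing directory deserves a closer look.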
I think it makes sense to return an additional value indicating whether the restoration actually took place. Silent failures are the scariest.
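A minimal sketch of that return-value idea, written as a wrapper so it assumes no change to Flax itself. The name restore_with_flag is hypothetical; it relies only on the documented contract that, when nothing is restored, the passed-in target comes back unchanged (the very same object).

```python
def restore_with_flag(restore_fn, ckpt_dir, target):
    """Hypothetical wrapper returning (restored, did_restore).

    `restore_fn` is assumed to behave like restore_checkpoint: if no checkpoint
    is loaded it returns `target` itself, so an identity check separates the
    "restored" and "nothing restored" outcomes.
    """
    restored = restore_fn(ckpt_dir, target)
    return restored, restored is not target
```

Usage would look like `state, did_restore = restore_with_flag(checkpoints.restore_checkpoint, ckpt_dir, state)`. Identity (`is`) is used rather than `==` because comparing pytrees that contain arrays with `==` compares array contents (and can raise for multi-element arrays), whereas identity is cheap and exact.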
EDIT: I just realized that we can use the following pattern to check whether the restoration was successful, since restore_checkpoint guarantees that "if no step specified and no checkpoint files present, returns the passed-in target unchanged":
from flax.training import checkpoints

state = create_train_state(...)
restored = checkpoints.restore_checkpoint(ckpt_dir, state)
if state is restored:
    raise FileNotFoundError(f"Cannot load checkpoint from {ckpt_dir}")
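One possible way to combine this identity check with the directory distinction from the first comment: treat a missing directory as a fresh run (no error) and raise only when the directory exists but nothing was restored. This is a hedged sketch: restore_or_raise is a hypothetical name, restore_fn stands in for checkpoints.restore_checkpoint, and the policy is a design choice rather than Flax behavior.

```python
import os


def restore_or_raise(restore_fn, ckpt_dir, target):
    """Hypothetical helper: restore from ckpt_dir, or fail loudly.

    - Missing directory: treated as a fresh run, `target` is returned as-is.
    - Directory exists but `restore_fn` returns `target` unchanged:
      raise, because a checkpoint was expected but none was loaded.
    """
    if not os.path.isdir(ckpt_dir):
        return target  # fresh run: nothing to restore, not an error
    restored = restore_fn(ckpt_dir, target)
    if restored is target:
        # Same object back means restore_fn found no checkpoint to load.
        raise FileNotFoundError(f"Cannot load checkpoint from {ckpt_dir}")
    return restored
```

One trade-off: an existing but empty directory (for example one created before the first save) still raises here, so whether this policy fits depends on when the directory gets created.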
Discussed in https://github.com/google/flax/discussions/1612