Open Devesan opened 2 years ago
checkpoint.load() should only be used with the pre-trained checkpoints (we use a different format there, and those checkpoints contain only the parameters, without any other train state).
To load a checkpoint from fine-tuning, you can use flax.training.checkpoints.restore_checkpoint(), like here:
@andsteing But how can I use the loaded checkpoint to run inference?
By calling model.apply(); see the example in the Colab:
After fine-tuning, the model is stored as checkpoint_10000 without any extension, so when I try to load it I get this error. Code:
checkpoint.load('checkpoint_10000')
"""
Traceback (most recent call last):
  File "vit_jax/test_celeb.py", line 9, in <module>
    checkpoint.load('checkpoint_10000')
  File "/home/gdevesan_gmail_com/resh_trans/vision_transformer/vit_jax/checkpoint.py", line 113, in load
    ckpt_dict = np.load(f, allow_pickle=False)
  File "/home/gdevesan_gmail_com/env/lib/python3.8/site-packages/numpy/lib/npyio.py", line 445, in load
    raise ValueError("Cannot load file containing pickled data "
ValueError: Cannot load file containing pickled data when allow_pickle=False
"""
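The traceback shows that checkpoint.load() hands the file to np.load(f, allow_pickle=False). A minimal sketch of why that fails on a file that is not a NumPy archive: np.load only recognizes .npy/.npz content, and anything else is treated as pickled data and rejected. The helper name loads_as_numpy and the stand-in bytes below are hypothetical, chosen for illustration; the second file is not a real fine-tuning checkpoint, just arbitrary non-NumPy content.

```python
import os
import tempfile

import numpy as np


def loads_as_numpy(path):
    """Return True if np.load can read `path` without pickle support."""
    try:
        with np.load(path, allow_pickle=False) as data:
            list(data.keys())  # touch the archive to confirm it is readable
        return True
    except ValueError:
        # np.load raises ValueError for content that is neither .npy nor .npz.
        return False


with tempfile.TemporaryDirectory() as tmp:
    # A parameters-only .npz archive loads fine.
    npz_path = os.path.join(tmp, "params.npz")
    np.savez(npz_path, kernel=np.zeros((2, 2)))

    # Hypothetical stand-in for a checkpoint written in some other format.
    other_path = os.path.join(tmp, "checkpoint_10000")
    with open(other_path, "wb") as f:
        f.write(b"not a numpy archive")

    npz_result = loads_as_numpy(npz_path)
    other_result = loads_as_numpy(other_path)

print(npz_result, other_result)
```

So the error does not come from the missing file extension; it suggests the file content is not in the format checkpoint.load() expects, which matches the advice to use flax.training.checkpoints.restore_checkpoint() for fine-tuning checkpoints.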