google-research / vision_transformer

Apache License 2.0

Issue when loading from fine-tuned checkpoints. #160

Open Devesan opened 2 years ago

Devesan commented 2 years ago

After fine-tuning, the model is saved as checkpoint_10000 (without any file extension), and when I try to load it I get the following error. Code:

from vit_jax import checkpoint  # vit_jax/checkpoint.py, the module shown in the traceback
checkpoint.load('checkpoint_10000')

"""
Traceback (most recent call last):
  File "vit_jax/test_celeb.py", line 9, in <module>
    checkpoint.load('checkpoint_10000')
  File "/home/gdevesan_gmail_com/resh_trans/vision_transformer/vit_jax/checkpoint.py", line 113, in load
    ckpt_dict = np.load(f, allow_pickle=False)
  File "/home/gdevesan_gmail_com/env/lib/python3.8/site-packages/numpy/lib/npyio.py", line 445, in load
    raise ValueError("Cannot load file containing pickled data "
ValueError: Cannot load file containing pickled data when allow_pickle=False
"""

andsteing commented 2 years ago

checkpoint.load() should only be used with the pre-trained checkpoints (we use a different format there, and those checkpoints contain only the parameters, without any other train state).
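A minimal sketch of why the error appears, assuming the fine-tuned file is written in Flax's msgpack format (the legacy flax.training.checkpoints format): np.load only recognizes .npz (zip) and .npy headers and treats anything else as a pickle, which it refuses when allow_pickle=False. The key name, the msgpack bytes, and the helper names below are illustrative assumptions, not taken from a real checkpoint.

```python
import io

import numpy as np


def load_like_vit_jax(raw: bytes):
    """Mimic the call in vit_jax/checkpoint.py's load(): np.load(f, allow_pickle=False)."""
    return np.load(io.BytesIO(raw), allow_pickle=False)


def make_npz_bytes() -> bytes:
    """Build a small .npz archive with a '/'-joined key (illustrative, not a real ViT checkpoint)."""
    buf = io.BytesIO()
    np.savez(buf, **{"Transformer/encoder_norm/scale": np.ones(4, np.float32)})
    return buf.getvalue()


# msgpack encoding of {"params": {}}; a hypothetical stand-in for a Flax fine-tuning checkpoint.
MSGPACK_LIKE = b"\x81\xa6params\x80"


def describe(raw: bytes) -> str:
    """Report what np.load makes of the bytes: the .npz keys, or the ValueError it raises."""
    try:
        npz = load_like_vit_jax(raw)
        return "npz keys: " + ", ".join(npz.files)
    except ValueError as err:
        # Neither a zip (.npz) nor an .npy header, so np.load assumes pickle and refuses it.
        return f"ValueError: {err}"


if __name__ == "__main__":
    print(describe(make_npz_bytes()))
    print(describe(MSGPACK_LIKE))
```

The .npz bytes load fine, while the msgpack bytes reproduce the same ValueError as in the traceback above, which is consistent with the two checkpoint kinds simply being different file formats.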

To load a checkpoint produced by fine-tuning, you can use flax.training.checkpoints.restore_checkpoint(), as done here:

https://github.com/google-research/vision_transformer/blob/3f3b50d3490f8b2f46f084a14683407b41f5c5a7/vit_jax/train.py#L147-L148

yjqiu commented 2 years ago

@andsteing But how can I use the loaded checkpoint to run inference?

andsteing commented 2 years ago

By calling model.apply() with the restored parameters; see the example in the Colab:

https://colab.sandbox.google.com/github/google-research/vision_transformer/blob/main/vit_jax.ipynb#scrollTo=N-wIdj_qnbIM