avacaondata closed this issue 1 year ago.
This should work. Can you tell me whether you have a model folder in the experiment results folder? If you pass that path to from_pretrained(), it should work.
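For reference, a minimal loading sketch; the path below is a placeholder, so point it at the model folder inside your own results directory:

```python
from transformers import AutoModelForSeq2SeqLM

# Placeholder path: the final "model" folder lives under
# <base_path_to_store_results>/<project_name>/<experiment_name>/model
model_path = "experiments/my_project/my_experiment/model"
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
```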
@rajcscw I have a "checkpoints" folder with one checkpoint binary per step. That is the binary I tried renaming to flax_model.msgpack; however, it throws the error above. Does anyone have example code showing how to load an RL-trained checkpoint?
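Concretely, this is roughly what I tried (the checkpoint name and paths are illustrative):

```python
import shutil
from transformers import AutoModelForSeq2SeqLM

# Illustrative reconstruction of the failing attempt: rename a raw
# RL4LMs checkpoint binary to the filename transformers expects for
# Flax weights, then try to load the folder as a Flax model.
shutil.copy(
    "experiments/checkpoints/checkpoint_10",
    "experiments/flax_model.msgpack",
)
model = AutoModelForSeq2SeqLM.from_pretrained("experiments", from_flax=True)
# -> fails with the error discussed above
```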
I have a directory structure like this:
```
experiments/
├── checkpoints/
│   ├── checkpoint_0
│   ├── checkpoint_1
│   └── ...
├── config.json
├── epoch_*_val_split_predictions.json
├── rollout_info.jsonl
├── test_split_metrics.jsonl
├── training_info.jsonl
└── val_split_metrics.jsonl
```
Any help is appreciated, thank you very much :)
Okay, after removing the config.json I advanced a little, but now I get another error when making the .from_pretrained() call with from_flax=True. The error is:
```
/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 446, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1870, in from_pretrained
    model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py", line 174, in load_flax_checkpoint_in_pytorch_model
    flax_state_dict = from_bytes(flax_cls, state_f.read())
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/flax/serialization.py", line 425, in from_bytes
    state_dict = msgpack_restore(encoded_bytes)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/flax/serialization.py", line 407, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.
```
That would not work, because each checkpoint is not just a language model; it also contains the trainer, algorithm states, etc.
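You can verify this by inspecting a checkpoint directly. A hypothetical sketch, assuming the checkpoint was written with torch.save (adjust if your version pickles it differently):

```python
import torch

# The checkpoint is a pickled training-state object rather than a raw
# Flax msgpack file, which is why msgpack.unpackb rejects it above.
state = torch.load("experiments/checkpoints/checkpoint_10", map_location="cpu")
print(type(state))
if isinstance(state, dict):
    print(list(state.keys()))  # expect trainer/algorithm state, not bare LM weights
```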
If you just need the final model, there will be a folder named model; that is the path you need to pass to from_pretrained().
But if your training has not finished and you only have checkpoints, you can just run train_text_generation.py again with train_evaluation/n_iters=0 set in the config, and with the same base_path_to_store_results, project_name, and experiment_name, so that it correctly loads the latest checkpoint from base_path_to_store_results/project_name/experiment_name/checkpoints. After this step, it should create a folder named model, which you can pass to from_pretrained(). A sketch of this workflow is below.
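Here is a hypothetical end-to-end sketch of that resume-and-export step. The config path, project, and experiment names are placeholders, and the CLI flags follow the RL4LMs README, so double-check them against your checkout:

```python
import subprocess
import yaml  # pip install pyyaml

# 1) Disable further training: set train_evaluation/n_iters to 0 in the
#    same config YAML you trained with.
config_path = "path/to/your_training_config.yml"
with open(config_path) as f:
    config = yaml.safe_load(f)
config["train_evaluation"]["n_iters"] = 0
with open(config_path, "w") as f:
    yaml.safe_dump(config, f)

# 2) Re-run the training script with the same results path, project, and
#    experiment names so it restores the latest checkpoint and writes out
#    the final "model" folder.
subprocess.run(
    [
        "python", "scripts/training/train_text_generation.py",
        "--config_path", config_path,
        "--base_path_to_store_results", "experiments",
        "--project_name", "my_project",
        "--experiment_name", "my_experiment",
    ],
    check=True,
)
```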
Great, that worked!! Thank you so much for helping me out, @rajcscw :heart: We can close the issue.
Hi, I trained a MarianMT model by maximizing bert_score, and everything worked fine until I tried to load the resulting weights with transformers, where I ran into an issue. I created a folder containing the last checkpoint binary and tried to run AutoModelForSeq2SeqLM.from_pretrained(<folder_name>) from there, but it threw an error. Then I looked into the transformers documentation and saw that loading Flax models requires a flax_model.msgpack file in the directory, so I renamed the checkpoint binary to that and retried with from_flax=True in the from_pretrained() call. However, there is still an issue loading the model. I figured that if this library is suitable for training models from the transformers library, there should be no problem loading the RL-trained models back with transformers afterwards, so I would like to ask whether I am doing something wrong when loading the models, or what the correct way is to load checkpoints with transformers after training with RL4LMs.
Thanks in advance :)