allenai / RL4LMs

A modular RL library to fine-tune language models to human preferences
https://rl4lms.apps.allenai.org/
Apache License 2.0
2.18k stars 191 forks

Error when trying to load a checkpoint from Transformers after RL training #43

Closed avacaondata closed 1 year ago

avacaondata commented 1 year ago

Hi, I have tried training a MarianMT model by maximizing bert_score, and everything worked fine until I tried to load the resulting weights with transformers, at which point I ran into an issue.

I created a folder with the last checkpoint binary and tried to run AutoModelForSeq2SeqLM.from_pretrained(<folder_name>) on it, but it threw the following error:

OSError: Error no file named pytorch_model.bin found in directory <directory> but there is a file for Flax weights. Use `from_flax=True` to load this model from those weights.

Then I looked into the transformers documentation and saw that loading Flax models requires a flax_model.msgpack file in the directory, so I renamed the checkpoint binary to that and retried with from_flax=True in the from_pretrained call. However, the model still fails to load:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 446, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1843, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/marian/modeling_marian.py", line 1281, in __init__
    self.model = MarianModel(config)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/marian/modeling_marian.py", line 1090, in __init__
    self.shared = nn.Embedding(vocab_size, config.d_model, padding_idx)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 132, in __init__
    assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
AssertionError: Padding_idx must be within num_embeddings

Since this library is meant for fine-tuning models from the transformers library, I would expect the RL-trained models to load back into transformers without problems. So I would like to ask: am I doing something wrong when loading the models, or what is the correct way to load checkpoints with transformers after training with RL4LMs?

Thanks in advance :)

rajcscw commented 1 year ago

This should work. Can you tell me if you have a model folder in the experiment results folder? If you pass that path to from_pretrained(), it should load.
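
For concreteness, loading that folder is just a regular transformers call. A minimal sketch (the paths below are placeholders for your own results folder and base model):

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# placeholder: <base_path_to_store_results>/<project_name>/<experiment_name>/model
model_dir = "experiments/my_project/my_experiment/model"

model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
# if the tokenizer was not saved alongside the model, load it from the
# base MarianMT checkpoint you fine-tuned (hypothetical name here)
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")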

avacaondata commented 1 year ago

@rajcscw I have a "checkpoints" folder with one checkpoint binary per step. That's the binary I tried to rename to flax_model.msgpack, but it throws the above error. Does anyone have example code showing how to load an RL-trained checkpoint?

I have a directory structure like this:

experiments/
    checkpoints/
        checkpoint_0
        checkpoint_1
        ....
    config.json
    epoch_*_val_split_predictions.json
    rollout_info.jsonl
    test_split_metrics.jsonl
    training_info.jsonl
    val_split_metrics.jsonl

Any help is appreciated, thank you very much :)

avacaondata commented 1 year ago

Okay, after removing the config.json I got a little further, but now I get another error from the .from_pretrained() call with from_flax=True:

/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/flax/core/frozen_dict.py:169: FutureWarning: jax.tree_util.register_keypaths is deprecated, and will be removed in a future release. Please use `register_pytree_with_keys()` instead.
  jax.tree_util.register_keypaths(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 446, in from_pretrained
    return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/modeling_utils.py", line 1870, in from_pretrained
    model = load_flax_checkpoint_in_pytorch_model(model, resolved_archive_file)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/transformers/modeling_flax_pytorch_utils.py", line 174, in load_flax_checkpoint_in_pytorch_model
    flax_state_dict = from_bytes(flax_cls, state_f.read())
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/flax/serialization.py", line 425, in from_bytes
    state_dict = msgpack_restore(encoded_bytes)
  File "/home/alejandro.vaca/miniconda3/envs/nlp_rl/lib/python3.9/site-packages/flax/serialization.py", line 407, in msgpack_restore
    state_dict = msgpack.unpackb(
  File "msgpack/_unpacker.pyx", line 201, in msgpack._cmsgpack.unpackb
msgpack.exceptions.ExtraData: unpack(b) received extra data.

rajcscw commented 1 year ago

That would not work, because each checkpoint is not just a language model; it also contains the trainer, algorithm state, etc. If you just need the final model, there will be a folder named model; that is the path you need to pass to from_pretrained().
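
If you want to verify this, you can open a checkpoint directly. A quick sketch, assuming the checkpoint is an ordinary torch pickle (an assumption on my part; the path is a placeholder):

import torch

# inspect what a raw RL4LMs checkpoint actually contains
ckpt = torch.load("experiments/checkpoints/checkpoint_0", map_location="cpu")
if isinstance(ckpt, dict):
    # expect trainer/algorithm state alongside the policy weights,
    # which is why transformers cannot read it directly
    print(list(ckpt.keys()))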

But if your training did not finish and you only have checkpoints, you can just run train_text_generation.py again with train_evaluation/n_iters set to 0 in the config, and set base_path_to_store_results, project_name, and experiment_name so that it loads the latest checkpoint correctly from base_path_to_store_results/project_name/experiment_name/checkpoints.
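
Roughly, the warm-start invocation looks like this (the flag names follow the RL4LMs README, but treat them as assumptions and check against your version; the config path is a placeholder for a YAML with train_evaluation/n_iters set to 0):

python scripts/training/train_text_generation.py \
    --config_path <your_config.yml> \
    --project_name <project_name> \
    --experiment_name <experiment_name> \
    --base_path_to_store_results <base_path_to_store_results>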

After this step, it should create a folder named model, which you can pass to from_pretrained().

avacaondata commented 1 year ago

Great, that worked!! Thank you so much for helping me out, @rajcscw :heart: We can close the issue.