huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Loading Trained RAG Model #24386

Closed YichiRockyZhang closed 1 year ago

YichiRockyZhang commented 1 year ago

System Info

Python 3.9.16, Transformers 4.13.0, WSL

Who can help?

@ArthurZucker @younesbelkada @shamanez

Reproduction

After finetuning RAG, I'm left with the following directory, and I'm not sure how to load the resulting checkpoint.

[screenshot of the checkpoint directory]

I should note the checkpoint is ~6 GB, while the original Hugging Face checkpoint is ~2 GB. I suspect this is because I used the finetune_rag_ray_end2end.sh script, so the checkpoint includes all three models (reader, retriever, generator).

Below are my attempts to load the checkpoint.

Attempt 1

import torch
from datasets import load_dataset
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

# Load the wiki_dpr dataset (passages plus precomputed DPR embeddings)
ds = load_dataset(path='wiki_dpr', name='psgs_w100.multiset.compressed', split='train')

rag_tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
rag_retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-base",
    use_dummy_dataset=False,
    indexed_dataset=ds,
    index_name="embeddings",
)

rag_model = RagTokenForGeneration.from_pretrained("facebook/rag-token-base", retriever=rag_retriever)

checkpoint_path = "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/val_avg_em=0.0026-step_count=601.0.ckpt"

rag_model.load_state_dict(torch.load(checkpoint_path))

The program hangs indefinitely; when I interrupt it, I get the following traceback:

Some weights of RagTokenForGeneration were not initialized from the model checkpoint at facebook/rag-token-base and are newly initialized: ['rag.generator.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
/fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/ray/_private/services.py:238: UserWarning: Not all Ray Dashboard dependencies were found. To use the dashboard please install Ray using `pip install ray[default]`. To disable this message, set RAY_DISABLE_IMPORT_WARNING env var to '1'.
  warnings.warn(warning_message)
^CTraceback (most recent call last):
  File "/nfshomes/yzhang42/rag/notebooks/rag_eval.py", line 37, in <module>
    rag_model.load_state_dict(torch.load(checkpoint_path))
  File "/fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/torch/serialization.py", line 712, in load
    return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
  File "/fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/torch/serialization.py", line 1049, in _load
    result = unpickler.load()
  File "/fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/ray/actor.py", line 1005, in _deserialization_helper
    return worker.core_worker.deserialize_and_register_actor_handle(
  File "python/ray/_raylet.pyx", line 1594, in ray._raylet.CoreWorker.deserialize_and_register_actor_handle
  File "python/ray/_raylet.pyx", line 1563, in ray._raylet.CoreWorker.make_actor_handle
  File "/fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/ray/_private/function_manager.py", line 402, in load_actor_class
    actor_class = self._load_actor_class_from_gcs(
  File "/fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/ray/_private/function_manager.py", line 487, in _load_actor_class_from_gcs
    time.sleep(0.001)
KeyboardInterrupt
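(The hang happens inside torch.load itself: the end2end script pickles Ray actor handles into the Lightning checkpoint, and deserializing them blocks while Ray waits on a cluster that is no longer there, as the ray._raylet frames above show. Even if the file did deserialize, a Lightning .ckpt nests the weights under a "state_dict" key, with each key prefixed by the LightningModule attribute name, "model." in this script, so they would need unwrapping before load_state_dict. A minimal sketch of that unwrapping, assuming the checkpoint can be deserialized at all:)

import torch

ckpt = torch.load(checkpoint_path, map_location="cpu")  # hangs here in this case
# Strip the LightningModule's "model." prefix (an assumption based on the
# script's structure) so the keys match RagTokenForGeneration's state dict.
state_dict = {
    k[len("model."):]: v
    for k, v in ckpt["state_dict"].items()
    if k.startswith("model.")
}
# strict=False because the raw end2end checkpoint also carries ctx_encoder
# weights that a vanilla RagTokenForGeneration does not have.
missing, unexpected = rag_model.load_state_dict(state_dict, strict=False)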

Attempt 2

from transformers import AutoConfig, AutoModel, PretrainedConfig, RagTokenizer, RagRetriever, BartForConditionalGeneration, RagTokenForGeneration, RagSequenceForGeneration, RagConfig
from transformers import BartModel

qe_config = PretrainedConfig(
    name_or_path=\
        "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint601/generator_tokenizer/tokenizer_config.json")
gen_config = PretrainedConfig(
    name_or_path=\
    "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint601/question_encoder_tokenizer/tokenizer_config.json")

RagConfig.from_question_encoder_generator_configs(
    question_encoder_config=qe_config,
    generator_config=gen_config
)

Gives the following error:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 11
      4 qe_config = PretrainedConfig(
      5     name_or_path=\
      6         "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint601/generator_tokenizer/tokenizer_config.json")
      7 gen_config = PretrainedConfig(
      8     name_or_path=\
      9     "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint601/question_encoder_tokenizer/tokenizer_config.json")
---> 11 RagConfig.from_question_encoder_generator_configs(
     12     question_encoder_config=qe_config,
     13     generator_config=gen_config
     14 )

File /fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/transformers/models/rag/configuration_rag.py:183, in RagConfig.from_question_encoder_generator_configs(cls, question_encoder_config, generator_config, **kwargs)
    172 @classmethod
    173 def from_question_encoder_generator_configs(
    174     cls, question_encoder_config: PretrainedConfig, generator_config: PretrainedConfig, **kwargs
    175 ) -> PretrainedConfig:
    176     r"""
    177     Instantiate a :class:`~transformers.EncoderDecoderConfig` (or a derived class) from a pre-trained encoder model
    178     configuration and decoder model configuration.
   (...)
    181         :class:`EncoderDecoderConfig`: An instance of a configuration object
    182     """
--> 183     return cls(question_encoder=question_encoder_config.to_dict(), generator=generator_config.to_dict(), **kwargs)

File /fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/transformers/models/rag/configuration_rag.py:140, in RagConfig.__init__(self, vocab_size, is_encoder_decoder, prefix, bos_token_id, pad_token_id, eos_token_id, decoder_start_token_id, title_sep, doc_sep, n_docs, max_combined_length, retrieval_vector_size, retrieval_batch_size, dataset, dataset_split, index_name, index_path, passages_path, use_dummy_dataset, reduce_loss, label_smoothing, do_deduplication, exclude_bos_score, do_marginalize, output_retrieved, use_cache, forced_eos_token_id, **kwargs)
    136 decoder_model_type = decoder_config.pop("model_type")
    138 from ..auto.configuration_auto import AutoConfig
--> 140 self.question_encoder = AutoConfig.for_model(question_encoder_model_type, **question_encoder_config)
    141 self.generator = AutoConfig.for_model(decoder_model_type, **decoder_config)
    143 self.reduce_loss = reduce_loss

File /fs/nexus-scratch/yzhang42/miniconda3/envs/qa3/lib/python3.9/site-packages/transformers/models/auto/configuration_auto.py:492, in AutoConfig.for_model(cls, model_type, *args, **kwargs)
    490     config_class = CONFIG_MAPPING[model_type]
    491     return config_class(*args, **kwargs)
--> 492 raise ValueError(
    493     f"Unrecognized model identifier: {model_type}. Should contain one of {', '.join(CONFIG_MAPPING.keys())}"
    494 )

ValueError: Unrecognized model identifier: . Should contain one of imagegpt, qdqbert, vision-encoder-decoder, trocr, fnet, segformer, vision-text-dual-encoder, perceiver, gptj, layoutlmv2, beit, rembert, visual_bert, canine, roformer, clip, bigbird_pegasus, deit, luke, detr, gpt_neo, big_bird, speech_to_text_2, speech_to_text, vit, wav2vec2, m2m_100, convbert, led, blenderbot-small, retribert, ibert, mt5, t5, mobilebert, distilbert, albert, bert-generation, camembert, xlm-roberta, pegasus, marian, mbart, megatron-bert, mpnet, bart, blenderbot, reformer, longformer, roberta, deberta-v2, deberta, flaubert, fsmt, squeezebert, hubert, bert, openai-gpt, gpt2, transfo-xl, xlnet, xlm-prophetnet, prophetnet, xlm, ctrl, electra, speech-encoder-decoder, encoder-decoder, funnel, lxmert, dpr, layoutlm, rag, tapas, splinter, sew-d, sew, unispeech-sat, unispeech
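(The empty model identifier in the ValueError points at the root cause: a bare PretrainedConfig() has model_type set to the empty string, and name_or_path is merely stored as a string, not loaded, so AutoConfig.for_model has nothing to resolve. The two paths in Attempt 2 also appear swapped between qe_config and gen_config, but the crash happens regardless. A sketch of the intended usage, with hypothetical sub-model directories that each contain a real config.json; if checkpoint601 itself was written by save_pretrained, RagConfig.from_pretrained on that directory should read the combined config directly:)

from transformers import AutoConfig, RagConfig

base = "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint601"

# Hypothetical layout; each directory must contain its own config.json
qe_config = AutoConfig.from_pretrained(f"{base}/question_encoder")
gen_config = AutoConfig.from_pretrained(f"{base}/generator")

rag_config = RagConfig.from_question_encoder_generator_configs(
    question_encoder_config=qe_config,
    generator_config=gen_config,
)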

Expected behavior

I'm not sure what expected behavior is supposed to be.

younesbelkada commented 1 year ago

Hi @YichiRockyZhang Thanks for the issue. Looking at your environment (transformers == 4.13.0), I would give it a try with one of the newest versions of transformers. It seems the config didn't save the model identifier properly for some reason. Would it be possible for you to use a recent version of the library?

ydshieh commented 1 year ago

Hi @YichiRockyZhang

If @younesbelkada's suggestion above still doesn't work, it would help a lot if you could provide a short but somewhat more complete code example showing how you create and save the checkpoint.

This way, it's easier and faster for us to reproduce and look into the issue. Thank you in advance.

YichiRockyZhang commented 1 year ago

Hi @younesbelkada. Thanks for the response! This did help: running the fine-tuning script now results in a more sensible saved checkpoint.

[screenshot of the new checkpoint directory]

I can now load the model with the following:

# `ds` is the indexed wiki_dpr dataset loaded as in Attempt 1 above
path = "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint31"

rag_tokenizer = RagTokenizer.from_pretrained(path)
rag_retriever = RagRetriever.from_pretrained(
    path,
    use_dummy_dataset=False,
    indexed_dataset=ds,
    index_name="compressed",
)

rag_model = RagTokenForGeneration.from_pretrained(path, retriever=rag_retriever)
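As a quick sanity check that the reloaded model runs end to end, generation can be exercised directly; a sketch adapted from the RAG docs (the question is just an example):

inputs = rag_tokenizer("How many people live in Paris?", return_tensors="pt")
generated = rag_model.generate(input_ids=inputs["input_ids"])
print(rag_tokenizer.batch_decode(generated, skip_special_tokens=True))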

Hi @ydshieh! Unfortunately, I believe my problem is specific to fine-tuning. I'm using the only fine-tuning script for this model that I could find (in the Hugging Face documentation, and even on the internet). The script uses PyTorch Lightning to train and save the model. The snippet below from finetune_rag.py details how the model is saved.

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
        self.model.config.save_step = self.step_count
        # self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        if self.custom_config.end2end:
            modified_state_dict = self.model.state_dict()
            for key in self.model.state_dict().keys():
                if key.split(".")[1] == "ctx_encoder":
                    del modified_state_dict[key]
            self.model.save_pretrained(save_directory=save_path, state_dict=modified_state_dict)

            save_path_dpr = os.path.join(self.dpr_ctx_check_dir, "checkpoint{}".format(self.step_count))
            self.model.rag.ctx_encoder.save_pretrained(save_path_dpr)
            self.context_tokenizer.save_pretrained(save_path_dpr)
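Given this save logic, the question encoder and generator land in save_path via save_pretrained (with the ctx_encoder keys stripped out), while the context encoder and its tokenizer are saved separately to save_path_dpr. Loading would presumably mirror that; a sketch reusing the names above, with ds being the indexed dataset from earlier:

from transformers import (
    DPRContextEncoder,
    DPRContextEncoderTokenizerFast,
    RagRetriever,
    RagTokenForGeneration,
    RagTokenizer,
)

# Main RAG model (question encoder + generator), as written by save_pretrained
tokenizer = RagTokenizer.from_pretrained(save_path)
retriever = RagRetriever.from_pretrained(save_path, indexed_dataset=ds)
model = RagTokenForGeneration.from_pretrained(save_path, retriever=retriever)

# Context encoder, saved separately for re-indexing during end2end training
ctx_encoder = DPRContextEncoder.from_pretrained(save_path_dpr)
ctx_tokenizer = DPRContextEncoderTokenizerFast.from_pretrained(save_path_dpr)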

I understand HF does not maintain these scripts, but for what it's worth, I think retrieval-augmented models are very important and should have a bit more support!

ydshieh commented 1 year ago

@YichiRockyZhang

Thanks for sharing more details. What I mean is that you can still make a self-contained code snippet that creates and saves the model the same way the script does.

You don't need to go through the training part in the script, just the create/save part. By self-contained, I mean we can run it directly and see the failure you have. Of course, you will have to wrap things up in your own way (not just show us the definition of on_save_checkpoint). I hope this makes my previous comment a bit clearer, and I look forward to seeing a reproducible code snippet 🤗

YichiRockyZhang commented 1 year ago

@ydshieh Hi, thank you for the quick responses! I've edited my reply above to reflect the fact that upgrading to transformers==4.30.2 seems to have worked, after making sure my data was ASCII-encoded. However, the fine-tuning script only saves the whole model after the first epoch. I've adjusted the code to the following:

    @pl.utilities.rank_zero_only
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
        self.model.config.save_step = self.step_count
        # self.model.save_pretrained(save_path)
        self.tokenizer.save_pretrained(save_path)

        if self.custom_config.end2end:
            modified_state_dict = self.model.state_dict()
            for key in self.model.state_dict().keys():
                if key.split(".")[1] == "ctx_encoder":
                    del modified_state_dict[key]
            self.model.save_pretrained(save_directory=save_path, state_dict=modified_state_dict)

            save_path_dpr = os.path.join(self.dpr_ctx_check_dir, "checkpoint{}".format(self.step_count))
            self.model.rag.ctx_encoder.save_pretrained(save_path_dpr)
            self.context_tokenizer.save_pretrained(save_path_dpr)
        else:  # NEW: also save the full model in the non-end2end case
            state_dict = self.model.state_dict()
            self.model.save_pretrained(save_directory=save_path, state_dict=state_dict)

I will update this thread in the morning once fine-tuning is finished. If my fix doesn't work out, I'll try to put together a more minimal, self-contained script for debugging purposes! 🤗

ydshieh commented 1 year ago

Nice and good luck :-) !

github-actions[bot] commented 1 year ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

hyadava commented 1 month ago

@YichiRockyZhang Were you able to successfully load the fine-tuned RAG model from a checkpoint? I'm running into the same issue that you reported. If I try Lightning's load_from_checkpoint() function to load the model from a checkpoint, I run into format errors:

Traceback (most recent call last):
  File "/mnt/task_runtime/examples/research_projects/rag-end2end-retriever/finetune_rag.py", line 886, in <module>
    main(args)
  File "/mnt/task_runtime/examples/research_projects/rag-end2end-retriever/finetune_rag.py", line 812, in main
    model: GenerativeQAModule = GenerativeQAModule(args)
  File "/mnt/task_runtime/examples/research_projects/rag-end2end-retriever/finetune_rag.py", line 149, in __init__
    model = BaseTransformer.load_from_checkpoint(os.path.join(args.model_name_or_path, "last_checkpoint.ckpt"))
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1582, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/pytorch/core/saving.py", line 63, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/fabric/utilities/cloud_io.py", line 60, in _load
    return torch.load(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/torch/serialization.py", line 1097, in load
    return _load(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/torch/serialization.py", line 1525, in _load
    result = unpickler.load()
  File "/miniforge/envs/iris/lib/python3.10/pickle.py", line 1213, in load
    dispatch[key[0]](self)
  File "/miniforge/envs/iris/lib/python3.10/pickle.py", line 1590, in load_reduce
    stack[-1] = func(*args)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/actor.py", line 1631, in _deserialization_helper
    return worker.core_worker.deserialize_and_register_actor_handle(
  File "python/ray/_raylet.pyx", line 4478, in ray._raylet.CoreWorker.deserialize_and_register_actor_handle
  File "python/ray/_raylet.pyx", line 4429, in ray._raylet.CoreWorker.make_actor_handle
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/_private/function_manager.py", line 539, in load_actor_class
    actor_class = self._load_actor_class_from_gcs(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/_private/function_manager.py", line 634, in _load_actor_class_from_gcs
    class_name = ensure_str(class_name)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/_private/utils.py", line 243, in ensure_str
    assert isinstance(s, bytes), f"Expected str or bytes, got {type(s)}"
AssertionError: Expected str or bytes, got <class 'NoneType'>
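(This failure and the hang in the original report share a root cause: Ray actor handles get pickled into the Lightning .ckpt, and they cannot be deserialized outside the Ray cluster that created them, so load_from_checkpoint dies, or hangs, before it ever reaches the weights. One possible workaround, untested here, is to stub out everything from the ray package during unpickling and keep only the state dict:)

import pickle
import types

import torch

class SkipRayUnpickler(pickle.Unpickler):
    """Replace anything pickled from the `ray` package with an inert stub."""
    def find_class(self, module, name):
        if module == "ray" or module.startswith("ray."):
            return lambda *args, **kwargs: None  # stub out actor handles etc.
        return super().find_class(module, name)

# For zip-format checkpoints, torch.load only needs an object exposing an
# Unpickler attribute; on torch >= 2.6 you may also need weights_only=False.
shim = types.SimpleNamespace(Unpickler=SkipRayUnpickler)

ckpt = torch.load("last_checkpoint.ckpt", map_location="cpu", pickle_module=shim)
state_dict = ckpt["state_dict"]  # the tensors survive; ray objects become None

(The stubbed checkpoint is no longer usable with load_from_checkpoint, since its hyperparameters will contain Nones, but the extracted state dict can be loaded into a freshly constructed model as in the earlier comments.)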

YichiRockyZhang commented 1 month ago

@hyadava This was a while ago, and I'm fuzzy on whether I ever got it to work, but I do recall the process was rather troublesome. If possible, I would recommend using a different RAG implementation.

hyadava commented 1 month ago

@YichiRockyZhang Thanks for your reply! Yes, I agree, the implementation seems to be half-baked. Do you recommend any RAG implementation that has fine-tuning capability?

YichiRockyZhang commented 1 month ago

@hyadava While I didn't use it too much, Pinecone seemed pretty solid :). Best of luck!