Hi @YichiRockyZhang Thanks for the issue. Looking at your environment (transformers == 4.13.0), I would give it a try with one of the newest versions of transformers. It seems the config didn't save the model identifier properly for some reason. Would it be possible for you to use a recent version of the library?
Hi @YichiRockyZhang
If @younesbelkada's suggestion above is still not working, it would help a lot if you could provide a short but a bit more complete code example.
This way, it's easier and faster for us to reproduce and look into the issue. Thank you in advance.
Hi @younesbelkada . Thanks for the response! This did help as running the finetuning script now results in a more sensible saved checkpoint.
I can now load the model with the following:
path = "/fs/nexus-scratch/yzhang42/rag_end2end/model_checkpoints_MS/checkpoint31"
rag_tokenizer = RagTokenizer.from_pretrained(path)
rag_retriever = RagRetriever.from_pretrained(
path,
use_dummy_dataset=False,
indexed_dataset=ds,
index_name="compressed",
)
rag_model = RagTokenForGeneration.from_pretrained(path, retriever=rag_retriever)
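As a quick sanity check on the loaded checkpoint, generation can be run along these lines (a minimal sketch; the question string is only a placeholder, and the tokenizer call may differ slightly across transformers versions):

# Placeholder question; any string works for a smoke test.
question = "who wrote the declaration of independence"
inputs = rag_tokenizer(question, return_tensors="pt")
generated_ids = rag_model.generate(input_ids=inputs["input_ids"])
print(rag_tokenizer.batch_decode(generated_ids, skip_special_tokens=True))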
Hi @ydshieh ! Unfortunately, I believe my problem is specific to fine-tuning. I'm using the only fine-tuning script for this model that I can find (in the Hugging Face documentation, and even on the internet). The script uses PyTorch Lightning to train and save the model. The snippet below from finetune_rag.py details how the model is saved.
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
    self.model.config.save_step = self.step_count
    # self.model.save_pretrained(save_path)
    self.tokenizer.save_pretrained(save_path)
    if self.custom_config.end2end:
        modified_state_dict = self.model.state_dict()
        for key in self.model.state_dict().keys():
            if key.split(".")[1] == "ctx_encoder":
                del modified_state_dict[key]
        self.model.save_pretrained(save_directory=save_path, state_dict=modified_state_dict)
        save_path_dpr = os.path.join(self.dpr_ctx_check_dir, "checkpoint{}".format(self.step_count))
        self.model.rag.ctx_encoder.save_pretrained(save_path_dpr)
        self.context_tokenizer.save_pretrained(save_path_dpr)
I understand HF does not maintain these scripts, but for what it's worth, I think retrieval-augmented models are very important and should have a bit more support!
@YichiRockyZhang
Thanks for sharing more details. What I mean is that you can still make a self-contained code snippet out of the on_save_checkpoint you provided. You don't need to go through the training part of the script, just the create/save part. By self-contained, I mean we can run it directly to see the failure you have. Of course, you will have to wrap things up in your own way (not just show us the definition of on_save_checkpoint). I hope this makes my previous comment a bit clearer, and I look forward to seeing a reproducible code snippet 🤗
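For example, something along these lines would be enough (a rough sketch, not taken from the script; facebook/rag-token-nq just stands in for your fine-tuned model):

# Self-contained sketch of the create/save/reload path, with no training and no Ray.
# In the real script the model also carries a ctx_encoder, which the filter below removes.
import tempfile

from transformers import RagTokenForGeneration, RagTokenizer

model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq")
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

save_path = tempfile.mkdtemp()
tokenizer.save_pretrained(save_path)

# Mirror on_save_checkpoint: drop any ctx_encoder weights before saving.
state_dict = {k: v for k, v in model.state_dict().items() if k.split(".")[1] != "ctx_encoder"}
model.save_pretrained(save_directory=save_path, state_dict=state_dict)

# Reload to check that the saved checkpoint is usable.
reloaded = RagTokenForGeneration.from_pretrained(save_path)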
@ydshieh Hi, thank you for the quick responses! I've edited my reply above to reflect that upgrading to transformers==4.30.2 seems to have worked after making sure my data was ASCII encoded. However, the fine-tuning script only saves the whole model after the first epoch. I've adjusted the code to be:
@pl.utilities.rank_zero_only
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    save_path = self.output_dir.joinpath("checkpoint{}".format(self.step_count))
    self.model.config.save_step = self.step_count
    # self.model.save_pretrained(save_path)
    self.tokenizer.save_pretrained(save_path)
    if self.custom_config.end2end:
        modified_state_dict = self.model.state_dict()
        for key in self.model.state_dict().keys():
            if key.split(".")[1] == "ctx_encoder":
                del modified_state_dict[key]
        self.model.save_pretrained(save_directory=save_path, state_dict=modified_state_dict)
        save_path_dpr = os.path.join(self.dpr_ctx_check_dir, "checkpoint{}".format(self.step_count))
        self.model.rag.ctx_encoder.save_pretrained(save_path_dpr)
        self.context_tokenizer.save_pretrained(save_path_dpr)
    else:  # NEW
        state_dict = self.model.state_dict()
        self.model.save_pretrained(save_directory=save_path, state_dict=state_dict)
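Separately, since the script only seems to write a checkpoint after the first epoch, the Lightning checkpoint callback frequency may be worth checking. Here is a minimal sketch of forcing a save every epoch (assuming a recent pytorch-lightning API; the real finetune_rag.py wires its own callbacks, so this only illustrates the knobs involved):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    dirpath="model_checkpoints_MS",  # placeholder output directory
    every_n_epochs=1,                # write a checkpoint at the end of every epoch
    save_top_k=-1,                   # keep all checkpoints instead of only the best one
)
# trainer = pl.Trainer(..., callbacks=[checkpoint_callback])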
I will update this thread in the morning once fine-tuning is finished. If my fix doesn't work out, I'll try to put together a more minimal and self-contained script for debugging purposes! 🤗
Nice and good luck :-) !
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
@YichiRockyZhang Were you able to successfully load the fine-tuned RAG model from a checkpoint? I'm running into the same issue that you reported. If I try the Lightning load_from_checkpoint() function to load the model from a checkpoint, I run into format errors:
Traceback (most recent call last):
  File "/mnt/task_runtime/examples/research_projects/rag-end2end-retriever/finetune_rag.py", line 886, in <module>
    main(args)
  File "/mnt/task_runtime/examples/research_projects/rag-end2end-retriever/finetune_rag.py", line 812, in main
    model: GenerativeQAModule = GenerativeQAModule(args)
  File "/mnt/task_runtime/examples/research_projects/rag-end2end-retriever/finetune_rag.py", line 149, in __init__
    model = BaseTransformer.load_from_checkpoint(os.path.join(args.model_name_or_path, "last_checkpoint.ckpt"))
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/pytorch/utilities/model_helpers.py", line 125, in wrapper
    return self.method(cls, *args, **kwargs)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1582, in load_from_checkpoint
    loaded = _load_from_checkpoint(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/pytorch/core/saving.py", line 63, in _load_from_checkpoint
    checkpoint = pl_load(checkpoint_path, map_location=map_location)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/lightning/fabric/utilities/cloud_io.py", line 60, in _load
    return torch.load(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/torch/serialization.py", line 1097, in load
    return _load(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/torch/serialization.py", line 1525, in _load
    result = unpickler.load()
  File "/miniforge/envs/iris/lib/python3.10/pickle.py", line 1213, in load
    dispatch[key[0]](self)
  File "/miniforge/envs/iris/lib/python3.10/pickle.py", line 1590, in load_reduce
    stack[-1] = func(*args)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/actor.py", line 1631, in _deserialization_helper
    return worker.core_worker.deserialize_and_register_actor_handle(
  File "python/ray/_raylet.pyx", line 4478, in ray._raylet.CoreWorker.deserialize_and_register_actor_handle
  File "python/ray/_raylet.pyx", line 4429, in ray._raylet.CoreWorker.make_actor_handle
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/_private/function_manager.py", line 539, in load_actor_class
    actor_class = self._load_actor_class_from_gcs(
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/_private/function_manager.py", line 634, in _load_actor_class_from_gcs
    class_name = ensure_str(class_name)
  File "/miniforge/envs/iris/lib/python3.10/site-packages/ray/_private/utils.py", line 243, in ensure_str
    assert isinstance(s, bytes), f"Expected str or bytes, got {type(s)}"
AssertionError: Expected str or bytes, got <class 'NoneType'>
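A possible workaround, sketched below under the assumption that the save_pretrained directory written by on_save_checkpoint exists alongside the .ckpt: the traceback above points at a pickled Ray actor handle inside the Lightning .ckpt that cannot be restored outside the original Ray run, so loading that directory with from_pretrained sidesteps the pickling entirely.

# Hypothetical paths and index settings; swap in your own indexed dataset as earlier in this thread.
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

checkpoint_dir = "model_checkpoints/checkpoint31"  # placeholder path

tokenizer = RagTokenizer.from_pretrained(checkpoint_dir)
retriever = RagRetriever.from_pretrained(
    checkpoint_dir,
    index_name="exact",
    use_dummy_dataset=True,
)
model = RagTokenForGeneration.from_pretrained(checkpoint_dir, retriever=retriever)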
@hyadava This was a while ago, and I'm fuzzy on whether I ever got it to work, but I do recall the process was rather troublesome. If possible, I would recommend using a different RAG implementation.
@YichiRockyZhang Thanks for your reply! Yes, I agree, the implementation seems half-baked. Do you recommend any RAG implementation that has fine-tuning capability?
@hyadava While I didn't use it too much, Pinecone seemed pretty solid :). Best of luck!
System Info
Python 3.9.16
Transformers 4.13.0
WSL
Who can help?
@ArthurZucker @younesbelkada @shamanez
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
After finetuning RAG, I'm left with the following directory, and I'm not sure how to load the resulting checkpoint.
I should note the checkpoint is ~6 GB while the original Hugging Face checkpoint is ~2 GB. I suspect this is because I used the finetune_rag_ray_end2end.sh script, so it includes all 3 models (reader, retriever, generator).
Below are my attempts to load the checkpoint.
Attempt 1
The program runs forever with the following traceback when I interrupt it:
Attempt 2
Gives the following error:
Expected behavior
I'm not sure what expected behavior is supposed to be.