AkariAsai / self-rag

This includes the original implementation of SELF-RAG: Learning to Retrieve, Generate and Critique through self-reflection by Akari Asai, Zeqiu Wu, Yizhong Wang, Avirup Sil, and Hannaneh Hajishirzi.
https://selfrag.github.io/
MIT License

Training Critic Model #11

Open etherion-1337 opened 1 year ago

etherion-1337 commented 1 year ago

Hey @AkariAsai awesome work your team has done :)

I am trying to get access to the 7B critic model mentioned in the paper and I noticed it is not released. If you have a trained model I am happy to test it as well.

At the same time, I am trying to train this critic model with your provided "gpt4_reward_all_0813_train.json", but it does not seem to be compatible with running "/data_creation/train_special_tokens.py" directly. Do you happen to have a preprocessing script, or could you provide the training data you have processed?

EDIT: I just realised that Line 239 is the culprit. I have changed it to: prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] and it runs OK.
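For anyone hitting the same error, here is a minimal sketch of that change in isolation. The template strings below are placeholders I am assuming for illustration, not the actual prompts in the repository, and build_prompt is a hypothetical helper; only the unpacking line matches what I changed.

```python
# Hypothetical stand-in for PROMPT_DICT in train_special_tokens.py.
# The template strings are assumed placeholders, not the repository's prompts.
PROMPT_DICT = {
    "prompt_input": (
        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
    ),
    "prompt_no_input": "### Instruction:\n{instruction}\n\n### Response:\n",
}

# Unpack only the two templates that are actually needed here.
prompt_input, prompt_no_input = (
    PROMPT_DICT["prompt_input"],
    PROMPT_DICT["prompt_no_input"],
)


def build_prompt(example):
    """Pick a template depending on whether the example has a non-empty input.

    Hypothetical helper for illustration only.
    """
    if example.get("input", "") != "":
        return prompt_input.format_map(example)
    return prompt_no_input.format_map(example)
```

With this, an example that has an "input" field is formatted with prompt_input, and one without it falls back to prompt_no_input.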

For training the critic model, is it OK to remove prompt_no_input_paragraph and prompt_no_input_separated?

AkariAsai commented 1 year ago

Hi @etherion-1337, thank you for your interest! Yes, we are happy to release it too. I will upload the critic model later today.

Re the error: I apologize! During refactoring I removed or modified some old implementations and haven't tested the critic model training code. Yes, those two aren't used when training the latest model, so feel free to remove them!

AkariAsai commented 1 year ago

We started uploading the Critic model, and it should be available tomorrow. https://huggingface.co/selfrag/self_rag_critic

etherion-1337 commented 1 year ago

Hi @AkariAsai thanks for uploading. I will do some testing.

minstar commented 12 months ago

Hi @AkariAsai, thanks for uploading the trained critic model. I have tried testing it, but it does not seem to follow the instructions in other domains (e.g., biomedicine) during inference.

For example,

instruction: When given instruction and evidence, evaluate whether the evidence is relevant to the instruction and provides valuable information for generating meaningful responses.\nUse a rating of [Relevant] to indicate relevance and usefulness, and [Irrelevant] to indicate irrelevance.
input: 'Task instruction: Constipation is caused by all of the following drugs EXCEPT : What of the following is the right choice?
(A) Neostigmine (B) Atropine (C) Morphine (D) Fentanyl
Evidence: Opioid effect permitting smaller doses of opioids be used. Consequently, several opioid/antihistamine combination products have been marketed, such as "Meprozine" (meperidine/promethazine) and "Diconal" (dipipanone/cyclizine), and these may also reduce opioid induced nausea. Opioid-induced constipation (OIC) develops in 90 to 95% of people taking opioids long-term. Since tolerance to this problem does not develop readily, most people on long-term opioids need to take a laxative or enemas. While all opioids cause constipation, there are some differences between drugs, with studies suggesting tramadol, tapentadol, methadone and fentanyl may cause relatively less constipation, while with codeine, morphine, oxycodone or hydromorphone constipation may be comparatively
GPT-4 output: [Relevant]
Prediction: [Fully supported]

Do you have any ideas that I can refer to?
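In case it helps with debugging, this is a rough sketch of how such mismatches could be flagged in bulk. The label set only contains the tokens named in the instruction above, and both the "relevance" task name and the helper functions are hypothetical, not part of the repository.

```python
# Assumed label set, taken only from the tokens named in the instruction above.
# The task name "relevance" is a hypothetical key used for illustration.
ALLOWED_LABELS = {
    "relevance": {"[Relevant]", "[Irrelevant]"},
}


def check_prediction(task, prediction):
    """Return True if the critic's prediction is a valid label for the task."""
    return prediction.strip() in ALLOWED_LABELS[task]


def agreement(task, pairs):
    """Fraction of (gold, prediction) pairs where the prediction is a valid
    label for the task and matches the gold label exactly."""
    if not pairs:
        return 0.0
    hits = sum(
        1
        for gold, pred in pairs
        if check_prediction(task, pred) and pred.strip() == gold
    )
    return hits / len(pairs)
```

For the example above, check_prediction("relevance", "[Fully supported]") is False, because that token is not one of the two labels the instruction asks for.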

AkariAsai commented 12 months ago

Heh, let me double check this week. We trained several variants of the critic, and we might have uploaded a different checkpoint: in one of the checkpoints the instructions were mismapped, so the label predictions are somewhat flipped. Sorry for the inconvenience!

minstar commented 11 months ago

Hi @AkariAsai, is there any progress on this?

roynirmal commented 11 months ago

Hello @AkariAsai @etherion-1337

I am trying to train the critic model using the training data downloaded from the Google Drive link. First, I had to remove the import statement "from ..retrieval_lm.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn", as it was throwing an import error.

Removing that makes the code run.
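An alternative to deleting the import outright would be to treat the patch as optional. The following is only a sketch under my own assumptions: load_optional_patch is a hypothetical helper, and the absolute module name passed to it is a guess at how the package might be importable in a given setup, not something the repository guarantees.

```python
import importlib


def load_optional_patch(module_name, attr_name):
    """Return the named attribute from a module, or None if the module
    cannot be imported or does not define the attribute."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Module (or one of its own dependencies) is not available.
        return None
    return getattr(module, attr_name, None)


# Assumed absolute module path; adjust to however the package is laid out.
replace_llama_attn_with_flash_attn = load_optional_patch(
    "retrieval_lm.llama_flash_attn_monkey_patch",
    "replace_llama_attn_with_flash_attn",
)

if replace_llama_attn_with_flash_attn is not None:
    # Apply the flash-attention patch only when it could be imported.
    replace_llama_attn_with_flash_attn()
```

This way training proceeds without the patch when the import fails, and still uses it when it is available.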