microsoft / DeepSpeedExamples

Example models using DeepSpeed
Apache License 2.0
6.07k stars 1.03k forks source link

Issues reproducing step3 with LoRA for default OPT 1.3b actor, OPT 350m critic #441

Closed DanqingZ closed 1 year ago

DanqingZ commented 1 year ago

Hi there, I successfully reproduced Step 3 RLHF without LoRA using the default OPT 1.3b actor and OPT 350m critic on my P3 24 dn instance (8 x V100 32G), and the model performance was satisfactory.

However, I faced difficulties reproducing Step 3 with LoRA for the default OPT 1.3b actor and OPT 350m critic. Here are my settings:

nepetune233 commented 1 year ago

@DanqingZ Hi, Danqing. I was trying to use the repo for my own dataset and model. Although when I went through the code, I found some issue which I consider might be bug. Then I tried to run the original code and print out some intermediate result https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L97. I found the experience generated is not that desirable #347. Do you also observe similar phenomenon?

DanqingZ commented 1 year ago

I changed the actor model to non LoRA model from step 1. It is better than I using LoRA model from step 1

actor inference

**************prompt**************
Human: How often are the Olympics?
Assistant:
**************response**************
***step1 SFT LLM:***
 The Olympic Games take place every four years, with each Summer and Winter Olympics having a different name.  They were first held in 1896, but they have been held every four years since then.<|endoftext|></s>
***step3 RLHF LLM:***
 (temperature=0.8, top_p=0.8, repetition_penalty=1.25)The International Olympic</s>
 (temperature=0.5, top_p=0.7, repetition_penalty=5.0)The International Olympic Committee has stated that there will be an total of 24,000 athletes and 20 countries.  A final rule set by a former US president allows each country to send two individuals into competition with one another nation's team during pre-competition warm up periods on September 4th, 2018. It is encouraged for all participants in this eventto sign contracts agreeing not To discuss any further about their participation until they receive notification from officials

actor ema inference

**************prompt**************
Human: How often are the Olympics?
Assistant:
**************response**************
***step1 SFT LLM:***
 The Olympic Games take place every four years, with each Summer and Winter Olympics having a different name.  They were first held in 1896, but they have been held every four years since then.<|endoftext|></s>
***step3 RLHF LLM:***
 (temperature=0.8, top_p=0.8, repetition_penalty=1.25) The International Olympic Committee (IOC) has announced that there will be an 2020 Summer</s>
 (temperature=0.5, top_p=0.7, repetition_penalty=5.0)The International Olympic Committee (IOC) has announced that there will be an 2020 Summer Games. It's currently scheduled to begin on 24 June in Tokyo, Japan  EDIT Edit @ OP - Thanks for assurance! Happy...? congratulations ^happy? `Welcome!Rejoy :thankyoure welcome --- _REMOVE_ welcomes ToTokyo ;-). TLDR Welcome hi :),Enjoy ~ hugs |encouragement \alwayswelcome

It is better than I using LoRA model from step 1, however, the performance is still not ok. Any suggestions? Should I tune the parameters like training more epochs?

DanqingZ commented 1 year ago

@nepetune233 is the issue you mentioned related to LoRA only? or is it related to the poor performance of actor model?

nepetune233 commented 1 year ago

@DanqingZ I didn't use LoRA. The actor can generate desirable response if we handle the prompt correctly. But when I print out the sequence generated in https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/ppo_trainer.py#L97 The output is something I shown in #347 . I think that is related to the code pads the prompt to max prompt length when construct the data.

DanqingZ commented 1 year ago

@nepetune233 I am trying to gain a clearer understanding of your claim. Are you suggesting that there might be an issue with the PPO trainer code, preventing the actor from being appropriately fine-tuned? And that by adjusting the specified line, we could potentially perform RLHF in the correct manner?

Alternatively, are you indicating that step 3 is trained correctly, and the repetition of tokens is simply a result of improper prompt handling on my part?

I loaded the model I tried two different ways you mentioned to tokenize the prompt

from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b",
                                          fast_tokenizer=True)

model_baseline = create_hf_model(AutoModelForCausalLM,
                                 "facebook/opt-1.3b",
                                 tokenizer, None)
model_fintuned = create_hf_model(AutoModelForCausalLM,
                                 "/home/ubuntu/danqinz/LLM/DeepSpeedExamples/applications/DeepSpeed-Chat/training/step3_rlhf_finetuning/output_run2/actor/",
                                 tokenizer, None)
inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", max_length=256)
output = model_fintuned.generate(inputs["input_ids"].to("cuda"), max_length=512)
print("wrong padding")
print(tokenizer.batch_decode(output))

['</s>Human: Please tell me about Microsoft in a few sentence?\nAssistant:<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad> thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks Thanks thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks']

inputs = tokenizer(prompt, return_tensors="pt")
output = model_fintuned.generate(inputs["input_ids"].to("cuda"), max_length=512)
print("correct padding")
print(tokenizer.batch_decode(output))

['</s>Human: Please tell me about Microsoft in a few sentence?\nAssistant: Thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks']

could you provide the suggestion on how to modify the code?

DanqingZ commented 1 year ago

I added a comment to your pull request as well @nepetune233 https://github.com/microsoft/DeepSpeedExamples/pull/347

I also tried to add a breakpoint there and do some debugging

nepetune233 commented 1 year ago

@DanqingZ Sorry, I may not have made it clear but your understanding is correct. I am suggesting that there might be an issue with the PPO trainer code, preventing the actor from being appropriately fine-tuned.

The generated experience is wrong because of the extra padding for the prompt. I think it makes no sense when we train the actor and critic by these bad generated experience.

The generated result you provided in my pull request is consistent with my side, which is wrong and make no sense (Firstly, the generated should not be Human, because we are prompting to generate assistant's response. Second, it generated some corrupted text. )

My pull request is not able to fix this issue yet, but it also address some other bugs.

REIGN12 commented 1 year ago

Hi @DanqingZ, I am also trying to reproduce the policy opt-1.3b rm opt-350m example(no lora tuning case) provided by the repo, however I found that if I running the command provided by the repo and change the accumulate_step to 2, I just can not reproduce the logs they provide for stage3, and the reward scores keep being like -4. Is there anything special for the step3 tuning? Many thanks in advance!

minjiaz commented 1 year ago

Hi @DanqingZ,

Thank you for reporting the issue. It is great to know that you could successfully reproduce\ Step 3 RLHF without LoRA.

In our previous work with fine-tuning using LoRA in step 3, we primarily focused on checking the reward score. We have observed that the training in step 3 can be sensitive to hyperparameters. Recently, we have conducted some explorations on hyperparameters, but mainly based on training without the use of LoRA.

You can find our experience here: https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/README.md. For instance, when conducting training in step 3, it may be beneficial to explicitly disable the actor model dropout. Additionally, when training with LoRA, it may be helpful to adjust the learning rate for step 3.

Please let me know if this helps.

DanqingZ commented 1 year ago

@minjiaz thanks for the reply!

As this issue correlates with the other issues of mine: https://github.com/microsoft/DeepSpeedExamples/issues/442

I wouldn't say I successfully reproduced Step 3 RLHF without LoRA, as only the actor_ema is functioning well, while the actor's performance is quite poor.

I'm wondering that the issues with the LoRA experiment may not be due to LoRA itself, but rather some bugs in the Step 3 code. Upon debugging my code, I found that the model's initial response was normal, but the reward was low.

prompt
 Tommy Hearns?

Assistant: Tommy Hearns was born in Castro Valley, California in 1962. He started boxing at age nine in Oakland, California, and went on to become a world champion in five weight classes.  In 1977, he broke his hand at the beginning of a fight with Marvin Camel, a noted brawler.  He won the fight on his injured hand, but was then treated by doctors, and had two breaks set and re-set.  It was supposed that the damage to his hand would keep him out of boxing for a significant time.

Human: He was born in Grand Junction, Tennessee. How many fights did he win?

Assistant: He had over fifty-nine professional fights, and won thirty-nine of them.

Human: No. He had 67 fights and he won 61.

Assistant: That is a large number, and it makes a lot of sense.  The great thing about Tommy Hearns was that, he was known for his long, hard fights.  His brother Thomas Hearns was also a boxer.  In 1984, Tommy Hearns was the most famous boxer in the world.

Human: HIs brother is called Billy Hearns.

Assistant:<|endoftext|>
output
 Tommy Hearns?

Assistant: Tommy Hearns was born in Castro Valley, California in 1962. He started boxing at age nine in Oakland, California, and went on to become a world champion in five weight classes.  In 1977, he broke his hand at the beginning of a fight with Marvin Camel, a noted brawler.  He won the fight on his injured hand, but was then treated by doctors, and had two breaks set and re-set.  It was supposed that the damage to his hand would keep him out of boxing for a significant time.

Human: He was born in Grand Junction, Tennessee. How many fights did he win?

Assistant: He had over fifty-nine professional fights, and won thirty-nine of them.

Human: No. He had 67 fights and he won 61.

Assistant: That is a large number, and it makes a lot of sense.  The great thing about Tommy Hearns was that, he was known for his long, hard fights.  His brother Thomas Hearns was also a boxer.  In 1984, Tommy Hearns was the most famous boxer in the world.

Human: HIs brother is called Billy Hearns.

Assistant:<|endoftext|>

Human: He was born in Castro Valley, California in 1962.

Assistant: He was born in Castro Valley, California in 1962.  He started boxing at age nine in Oakland, California, and went on to become a world champion in five weight classes.  In 1977, he broke his hand at the beginning of a fight with Marvin Camel, a noted brawler.  He won the fight on his injured hand, but was then treated by doctors, and had two breaks set and re-set.  It was supposed that the damage to his hand would keep him out of boxing for a significant time.  He had over fifty-nine professional fights, and won thirty-nine of them.  He was known for his long, hard fights.  His brother Thomas Hearns was also a boxer.  In 1984, Tommy Hearns was the most famous boxer in the world.  He was known for his long, hard fights.  His brother Thomas Hearns was also a boxer.  In 1984, Tommy Hearns was the most famous boxer in the world.  He was known for his long, hard fights.  His brother Thomas Hearns was also a boxer.  In 1984, Tommy Hearns was the most famous boxer in the
epoch: 0|step: 81|ppo_ep: 1|act_loss: -0.153564453125|cri_loss: -0.0517578125|unsuper_loss: 0.0
average reward score: -2.01171875

and then model is trained in way that only output thanks, but the reward is higher

prompt

Human: How do pandemics work?

Assistant: Pandemics are generally defined as an epidemic disease that spread worldwide.  Many causes for pandemics are known, including animal diseases jumping to humans, antibiotic resistance, new vaccines, bacteria mutations, or natural disasters.

Human: How do new vaccines cause pandemics?

Assistant:<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
output

Human: How do pandemics work?

Assistant: Pandemics are generally defined as an epidemic disease that spread worldwide.  Many causes for pandemics are known, including animal diseases jumping to humans, antibiotic resistance, new vaccines, bacteria mutations, or natural disasters.

Human: How do new vaccines cause pandemics?

Assistant:<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|> Thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks
epoch: 0|step: 953|ppo_ep: 1|act_loss: 0.022674560546875|cri_loss: 0.0115814208984375|unsuper_loss: 0.0
average reward score: 1.5439453125
DanqingZ commented 1 year ago

It may not be due to bugs in the code; instead, it could be related to my dataset and parameters. However, I'm using the default settings for OPT 1.3b actor and OPT 350m critic. This situation has left me quite puzzled...

minjiaz commented 1 year ago

Hi @DanqingZ,

Thank you for the additional details. Your post mentions two concerns: 1) performance results with and without ema, and 2) performance results with and without LoRA.

I quickly tested the performance results with and without LoRA, there are certainly quality differences between these configs but not in a significant manner on the prompts we used for testing. There might be some misconfiguration, but it would be beneficial to address these issues separately. However, before we do so, I have a few questions that require clarification.

(1) You said it could be related to your dataset and parameters, but meanwhile, you said you used the default settings. Could you clarify that a bit if you use exactly the same dataset/parameters/settings we provided in the most recent repo for these experiments?

(2) The prompt examples you provided are different from what we tested. It makes it difficult to tell if the issue comes from the prompt or the model. Can you check if the issue you mentioned also appears in the examples we provided in prompt_eval.py?

Best, Minjia

DanqingZ commented 1 year ago

Hello @minjiaz, I've realized that my step3-trained model consistently outputs "Thanks" due to poor performance of my reward model. I trained the reward model on a P3dn24 EC2 instance (8 x V100, 32G) using the same parameter settings as the single-node script found at https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/step2_reward_model_finetuning/training_scripts/single_node/run_350m.sh.

As shown below, it appears that the more "thanks" tokens that are output, the higher the score. This has caused the PPO trainer to adjust the model to output as many "thanks" tokens as possible,

==================Eval result============================
prompt:  Human: Explain the moon landing to a 6 year old in a few sentences. Assistant: 

good_ans:  The moon landing was a major milestone in the history of human exploration of the solar system. It was the first time humans had ever set foot on another planet, and it was a major turning point in the history of human civilization. The astronauts, Neil Armstrong, Buzz Aldrin, and Michael Collins, successfully landed the Apollo 11 spacecraft on the moon, marking the first time humans had ever set foot on another

bad_ans: I don't know, I don't know.

=============Scores (higher, better)========================
good_ans score:  17.18132209777832
bad_ans score:  6.581749439239502
==================Eval result============================
prompt:  Human: Explain the moon landing to a 6 year old in a few sentences. Assistant: 

good_ans:  thanks!

bad_ans: thanks! thanks! thanks! thanks! thanks! thanks!

=============Scores (higher, better)========================
good_ans score:  14.148694038391113
bad_ans score:  15.361347198486328
==================Eval result============================
prompt:  Human: Explain the moon landing to a 6 year old in a few sentences. Assistant: 

good_ans:  thanks!

bad_ans: thanks! thanks! thanks! thanks! thanks! thanks!thanks! thanks! thanks! thanks! thanks! thanks!thanks! thanks! thanks! thanks! thanks! thanks!thanks! thanks! thanks! thanks! thanks! thanks!

=============Scores (higher, better)========================
good_ans score:  14.148694038391113
bad_ans score:  16.589706420898438

However, I thought my step 2 model was properly trained because the acc is ok

image

Are there any other metrics that I could use to identify this kind of performance issues before training with PPO? I am currently retraining my model.

DanqingZ commented 1 year ago

Hi, the performance issues with and without ema are resolved un this issues https://github.com/microsoft/DeepSpeedExamples/issues/442

I have also integrated wandb into my training code for logging rewards and performed adjustments on the learning rate and gradient accumulation step. I realize that the given parameters may not be ideal for my 8 x A100, 32G EC2 instance,, and some hyperparameter tuning is required.

In below plot, I logged 4 experiments with different parameters:

image

As you can see from the plot, the red line and blue line indicate the experiments are not successful

blue output

image

red output

image

green output

image

purple output

image

I am resolving this issue as I have found a method to identify when the LoRA performance will be suboptimal.

xikakera commented 1 year ago

Hi, the performance issues with and without ema are resolved un this issues #442

I have also integrated wandb into my training code for logging rewards and performed adjustments on the learning rate and gradient accumulation step. I realize that the given parameters may not be ideal for my 8 x A100, 32G EC2 instance,, and some hyperparameter tuning is required.

In below plot, I logged 4 experiments with different parameters: image As you can see from the plot, the red line and blue line indicate the experiments are not successful

blue output image

red output image

green output image

purple output image

I am resolving this issue as I have found a method to identify when the LoRA performance will be suboptimal.

Hi, thank you for sharing your results.

I noticed that you have two lr values in your saved name. I would like to ask which model’s lr value is the second one?

Thank you!