Linear95 / SPAG

Self-playing Adversarial Language Game Enhances LLM Reasoning
Apache License 2.0

Extremely unbalanced attacker/defender win rate and strange behavior of SFTed model #5

Open thwu1 opened 2 months ago

thwu1 commented 2 months ago

In the gpt4_game_top30k_results.json file, there are 20067 attacker-win samples and 3287 defender-win samples, an att/def ratio of ~6.1.

However, after SFTing the model with:

torchrun --nproc_per_node=8 --master_port=6000 train.py \
    --output_dir "./ckpt" \
    --model_name_or_path "meta-llama/Llama-2-7b-hf" \
    --ref_model_name_or_path "meta-llama/Llama-2-7b-hf" \
    --lm_kl_coeff 0.1 \
    --train_method "SFTwithKL" \
    --train_data_path "./data/train_imitation_gpt4.json" \
    --remove_unused_columns False \
    --num_train_epochs 1 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --evaluation_strategy no \
    --padding_side "right" \
    --truncation_side "left" \
    --max_length 2048 \
    --save_strategy epoch \
    --learning_rate 5e-6 \
    --lr_scheduler_type "cosine" \
    --warmup_ratio 0.03 \
    --logging_steps 1 \
    --weight_decay 0. \
    --deepspeed "./configs/default_offload_opt_param.json" \
    --gradient_checkpointing True \
    --tf32 True \
    --bf16 True

Here are the resulting statistics of the generated self-play histories (P.S. I optimized the play_llm_game code for speed):

totally get 99740 data from 33973 game, with 31273 attacker games;  2700 defender games.
reweight the sample with attacker_weight: 0.5431682281840565 ; defender_weight: 6.291296296296296

The att/def ratio of ~11.58 is even higher, indicating more imbalance!
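
(For reference, the reweighting above appears consistent with weighting each role by total_games / (2 * games_won_by_role); a minimal sketch under that assumption, with illustrative names:)

# Hypothetical reconstruction of the reweighting, assuming
# weight = total_games / (2 * games_won_by_role).
def role_weights(total_games, attacker_games, defender_games):
    attacker_weight = total_games / (2 * attacker_games)
    defender_weight = total_games / (2 * defender_games)
    return attacker_weight, defender_weight

# role_weights(33973, 31273, 2700)
# -> (0.5431682281840565, 6.291296296296296), matching the numbers above.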

More strangely, around 10-15% of the responses (by eyeballing) are empty across the generated self-play histories after the first SFT. For instance:

     "history": [
            {
                "role": "attacker",
                "content": " let's dive right in. If you think of the saying \"once upon a time\", afterwards comes what?"
            },
            {
                "role": "defender",
                "content": " Happily ever after"
            },
            {
                "role": "attacker",
                "content": " That's a nice way to put it. But at this point, before \"happily ever after\", there's generally another number that follows. Can you guess what that number might be?"
            },
            {
                "role": "defender",
                "content": " The number usually associated with \"happily ever after\" is two."
            },
            {
                "role": "attacker",
                "content": " That's a good guess, but it's actually four. Think about the fairy tale \"The Three Little Pigs\". There are three pigs, right?"
            },
            {
                "role": "defender",
                "content": " Yes, you're right."
            },
            {
                "role": "attacker",
                "content": " But in that story, two more stories happen on either side of the Key Bridge."
            },
            {
                "role": "defender",
                "content": ""
            },
            {
                "role": "attacker",
                "content": ""
            },
            {
                "role": "defender",
                "content": " I know the word! It is \"Third.\""
            }
        ],
        "target_word": "three",
        "max_turns": 5,

Could the authors confirm the imbalance and the strange behavior of the SFTed model? @Linear95 @underwoodnoble

Linear95 commented 2 months ago

Hi, sorry for the mistake. We have fixed the imbalance bug by applying lower() to the content in the rule-based outcome judgment. Please check #6.
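
(A minimal sketch of the kind of case-insensitive check meant here; the actual judgment lives in assign_rewards.py and may differ:)

def defender_mentions_target(content: str, target_word: str) -> bool:
    # Lowercase both sides so that e.g. "Three" vs. "three" is judged the
    # same way, instead of being miscounted as an attacker win.
    return target_word.lower() in content.lower()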

Linear95 commented 2 months ago

For the empty-response problem, I guess there are some bugs in the response-collecting process. Could you describe in detail how you accelerated play_llm_game.py?

thwu1 commented 2 months ago

I have already applied the lower() fix, but the results are still very unbalanced. Is 50-50 the expected outcome?

I'm spinning up the generation backend with vLLM, which cuts the generation time from 10h to 0.5h.
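
(A rough sketch of one way to do batched generation with vLLM; not necessarily the exact setup used here, and the checkpoint path and sampling settings are illustrative:)

from vllm import LLM, SamplingParams

llm = LLM(model="./ckpt/im_model")  # hypothetical path to the SFTed checkpoint
sampling = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=256)

# Prompts are built per game turn; batching them lets vLLM generate in parallel.
prompts = ["<attacker prompt>", "<defender prompt>"]
for out in llm.generate(prompts, sampling):
    print(repr(out.outputs[0].text))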

Linear95 commented 2 months ago

A 50-50 win rate is not expected here. With imitation learning, the LLM only learns how to play the game; there is no regularizer guaranteeing that the attacker and defender play at the same level. But if you continue with SPAG-1 training on the im_model, the win rate will get closer to 50-50 in expectation.

Linear95 commented 2 months ago

I have already applied the lower() fix, but the results are still very unbalanced. Is 50-50 the expected outcome?

I'm spinning up the generation backend with vLLM, which cuts the generation time from 10h to 0.5h.

I'm still not sure where the bug causing the empty replies is... Maybe print out each prompt and its generation in the dialogue-collection loop?
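
(A hypothetical debugging snippet along those lines, with illustrative variable names, to make empty generations visible inside the collection loop:)

for turn_idx, (prompt, generation) in enumerate(zip(prompts, generations)):
    print(f"--- turn {turn_idx} ---")
    print("PROMPT:", repr(prompt))
    # repr() makes empty or whitespace-only generations easy to spot.
    print("GENERATION:", repr(generation))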

thwu1 commented 2 months ago

    "history": [
      {
        "role": "attacker",
        "content": " This is an unusual type of mineral that I've noticed some people use as insulation, especially when they've been working with clay-type products. It's unique in its ability to really adsorb heat because of its tiny particles. Have you ever had to deal with that sort of material for construction work? How did it go?"
      },
      {
        "role": "defender",
        "content": " No, I haven't had to deal with that kind of material for construction work, but it's interesting to find out about such materials and their properties."
      },
      {
        "role": "attacker",
        "content": ""
      },
      {
        "role": "defender",
        "content": ""
      },
      {
        "role": "attacker",
        "content": " And it's fascinating how various items can have properties that seem so different, often with their own internal workings in the form of substances with quite particular structures and ways of acting on stimuli. They can be quite surprising. There's an element that's often included in fire-resistant products. It's sort of a miracle material for what it's made of, given its specific chemical composition. It's used primarily for one of those purposes, yet it's also been associated with a certain health concern for those people who were handling a lot of that insulation over many years. Curiosity, don't you think?"
      },
      {
        "role": "defender",
        "content": ""
      },
      {
        "role": "attacker",
        "content": ""
      },
      {
        "role": "defender",
        "content": ""
      },
      {
        "role": "attacker",
        "content": "\n  - attacker: "
      },
      {
        "role": "defender",
        "content": ""
      }
    ],
    "target_word": "vermiculite",
    "max_turns": 5

@Linear95 This is one example I got using the original play_llm_game.py, and the statistics:

totally get 121563 data from 40364 game, with 35893 attacker games;  4471 defender games.
reweight the sample with attacker_weight: 0.5622823391747694 ; defender_weight: 4.513978975620667

thwu1 commented 2 months ago

Some hypotheses:

  1. The data is wrong
  2. SFT went wrong due to some strange machine issue
  3. Some tokenizer issue causes the decoding to be strange

HanClinto commented 2 months ago

Is there any chance that the original im_llama2_sampling_all_words_results.json results from the paper authors could be uploaded to Google Drive or something? I would love to look through the results in more depth!

thwu1 commented 2 months ago

@Linear95 Here's one sample I found in gpt4_game_top30k_results.json that's likely labelled wrong. I don't know how many have a similar problem.

  {
    "target_word": "think",
    "max_turns": 6,
    "history": [
      {
        "role": "attacker",
        "content": "let's get started. When you're trying to solve a problem, what's the first thing you do?"
      },
      {
        "role": "defender",
        "content": "I usually start by understanding the problem thoroughly. I try to gather as much information as possible about it. Then, I brainstorm possible solutions and evaluate each one of them based on their feasibility and effectiveness."
      },
      {
        "role": "attacker",
        "content": "That's a great approach! So, when you're evaluating these solutions, what process do you use? Do you just go with your gut feeling, or is there a more systematic way you approach it?"
      },
      {
        "role": "defender",
        "content": "I prefer a more systematic approach. I usually weigh the pros and cons of each solution, consider the resources required, and evaluate the potential impact. I also consider the time frame in which the solution needs to be implemented. This helps me to make a more informed decision."
      },
      {
        "role": "attacker",
        "content": "That's a very thorough and logical approach. So, when you're in this process, what would you say is the most important action you're doing? It's something that's crucial to problem-solving and decision-making."
      },
      {
        "role": "defender",
        "content": "The most important action in this process is probably analysis. It's crucial to analyze the problem, the potential solutions, and the impacts of those solutions. Without proper analysis, it's difficult to make informed decisions."
      },
      {
        "role": "attacker",
        "content": "Absolutely, analysis is key. But there's also another crucial step that involves using your brain to form conclusions or ideas. It's something we do all the time, especially when we're problem-solving. What would you call that process?"
      },
      {
        "role": "defender",
        "content": "I know the word! It is \"thinking\"."
      }
    ],
    "winner": "attacker"
  },
Linear95 commented 2 months ago

@Linear95 Here's one sample I found in gpt4_game_top30k_results.json that's likely labelled wrong. I don't know how many have a similar problem.


This is because we only accepted exact matches of the target word in the GPT-4 data collection. Later we found this rule was too strict for the defender, so we also considered tense and number variations of the word in the self-play data collection. However, this does not make a significant change in the game outcomes.

We just applied the current judgment in assign_rewards.py to gpt4_game_top30k_results.json; the number of defender-winning games only increases from 4293 to 4456.
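
(A minimal sketch of a tense/number-aware check along the lines described; the actual variant handling in assign_rewards.py may be more elaborate:)

import re

def target_variations(word: str) -> set:
    # Crude tense/number variants, e.g. "think" -> {"think", "thinks", "thinking", ...}.
    return {word, word + "s", word + "es", word + "d", word + "ed", word + "ing"}

def defender_inferred(utterance: str, target_word: str) -> bool:
    tokens = re.findall(r"[a-z]+", utterance.lower())
    return any(tok in target_variations(target_word.lower()) for tok in tokens)

# defender_inferred('I know the word! It is "thinking".', "think") -> True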

Linear95 commented 2 months ago

Is there any chance that the original im_llama2_sampling_all_words_results.json results from the paper authors could be uploaded to Google Drive or something? I would love to look through the results in more depth!

I have provided all the self-play game episodes in the data/ folder. Please check.

Linear95 commented 2 months ago

Some hypotheses:

  1. The data is wrong
  2. SFT went wrong due to some strange machine issue
  3. Some tokenizer issue causes the decoding to be strange

For the third hypothesis, there is a risk that the tokenizer is initialized differently in train.py and play_llm_game.py... Since LLaMA-2 does not have a pad_token, we initially set pad_token_id=0. Later, we found this inappropriate for other LLMs because token id 0 is already in use. So we changed the tokenizer special-token setup again.

I just defined a new function for special token initialization and called it in both train.py and play_llm_game.py. Could you please try the imitation training and sampling again to check if empty responses still exist?
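
(An illustrative sketch of what a shared special-token setup could look like; the actual function added to the repo may differ:)

def setup_special_tokens(tokenizer, model=None):
    # Avoid hard-coding pad_token_id=0, which collides with a real token id
    # on some LLMs; reuse the eos token as padding instead.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if model is not None:
        model.config.pad_token_id = tokenizer.pad_token_id
    return tokenizer

# Called identically in train.py and play_llm_game.py so both sides agree.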

thwu1 commented 2 months ago

I just defined a new function for special token initialization and called it in both train.py and play_llm_game.py. Could you please try the imitation training and sampling again to check if empty responses still exist?

Yes, I'll try that. Thank you!

thwu1 commented 1 month ago

@Linear95 I tried to replicate the results. Here's a summary:

During generation, passing min_tokens=2 avoids the empty responses observed before; as a result, the defender win rate increases. Here are the results for llama-2-7b on MMLU, evaluated with lm-eval-harness as suggested by the paper:

IM:     46.5% +- 0.4%
SPAG-1: 46.2% +- 0.4%
SPAG-2: 46.5% +- 0.4%
SPAG-3: 46.4% +- 0.4%

There seems to be no improvement beyond noise in this evaluation. I also observed that during training the defender win rate increases slowly: the defender weight decreases from around 4.29 (im) -> 3.92 (spag-1) -> 3.49 (spag-2). (For reference, the author-provided data gives 4.28 (im) -> 4.27 (spag-1) -> 3.56 (spag-2).)
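
(For reference, the min_tokens=2 setting mentioned above can be passed through vLLM's SamplingParams, assuming a vLLM version that supports it; sampling settings here are illustrative:)

from vllm import SamplingParams

# Forcing at least two generated tokens prevents the empty responses
# observed after the first SFT.
sampling = SamplingParams(temperature=1.0, top_p=0.9, max_tokens=256, min_tokens=2)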

thwu1 commented 1 month ago

I also want to add some results for the ai2_arc and bbh benchmarks:

         arc_c (+- 1.46)                  arc_e (+- 0.8)                    bbh (+- 1.12)
IM:          44.37                            76.77                             26.84
SPAG1:       45.20                            76.77                             27.99
SPAG2:       46.76                            77.48                             28.82
SPAG3:       47.78                            77.82                             28.56

I see clear improvement on ARC and BBH, which at least indicates the code is working as expected, though I don't observe reliable improvement on MMLU.

Linear95 commented 1 month ago

I see clear improvement on ARC and BBH, which at least indicates the code is working as expected, though I don't observe reliable improvement on MMLU.

We didn't claim that SPAG can significantly improve MMLU. Actually, I don't think MMLU is a reasoning benchmark; it evaluates general language capability. If you check our paper, MMLU even gets worse with Baichuan-2.

PinzhengWang322 commented 1 week ago

Hi @thwu1, unfortunately the original authors can't share the SPAG checkpoints due to some restrictions (#8). I was wondering if you might be able to share the checkpoints from your re-implementation? It would be very helpful. Thanks a lot for considering my request!

thwu1 commented 1 week ago

Hi @thwu1, unfortunately the original authors can't share the SPAG checkpoints due to some restrictions (#8). I was wondering if you might be able to share the checkpoints from your re-implementation? It would be very helpful. Thanks a lot for considering my request!

Hi @PinzhengWang322 I've uploaded:

SFT ckpt: https://huggingface.co/ThWu/spag_im_ckpt
SPAG1 ckpt: https://huggingface.co/ThWu/spag_spag1_ckpt
SPAG2 ckpt: https://huggingface.co/ThWu/spag_spag2_ckpt
SPAG3 ckpt: https://huggingface.co/ThWu/spag_spag3_ckpt

And my code is public at https://github.com/thwu1/spag

PinzhengWang322 commented 1 week ago

Thank you so much for sharing these resources!