openreasoner / openr

OpenR: An Open Source Framework for Advanced Reasoning with Large Language Models
https://openreasoner.github.io/
MIT License

Is this normal? MCTS is worse than the CoT method. #49

Open Dada-Cloudzxy opened 3 days ago

Dada-Cloudzxy commented 3 days ago

System Info

python==3.10.15
cuda==11.8-8.8.1
torch==2.4.0
Code: latest version
GPU: 8 × A100 40G

Who can help?

@ziyuwan @Gebro13


Reproduction

To compare MCTS and CoT, I conducted two experiments.

Experiment 1: I used mistral-7b-sft and math-shepherd-mistral-7b-prm, downloaded from Hugging Face, to test.

  1. Run create_service_math_shepherd.sh to start the server.
  2. Run cot_greedy.sh, vanila_mcts.sh, and cot_rerank.sh (without changing the configs in the scripts). I got three results (the majority_vote metric is sketched after these steps):
     - Method: cot. Average result: ({'majority_vote': 0.28, 'total_completion_tokens': 287.582},)
     - Method: vanila_mcts. Average result: ({'majority_vote': 0.258, 'total_completion_tokens': 974.864},)
     - Method: best_of_n. Average result: ({'majority_vote': 0.264, 'total_completion_tokens': 249.592},)

Experiment 2: To rule out the effect of the math-shepherd model performing poorly, I used the MATH-APS dataset to train a Qwen2.5-Math-1.5B PRM.

  3. Run torchrun --nproc_per_node=4 prm/code/finetune_qwen.py --model_path Qwen/Qwen2.5-Math-1.5B-Instruct --per_device_train_batch_size 2 --learning_rate 1e-4, which trains the PRM for 3 epochs and saves 3 checkpoints.
  4. Run create_service_qwen2.5_math_vllm.sh 3 times with the three checkpoints, then run cot_greedy.sh and vanila_mcts.sh (changing tree_max_width from 4 to 5 and tree_max_depth from 50 to 40). I got four results (the baseline and three PRM-guided results):
     - Baseline: cot. Average result: ({'majority_vote': 0.744, 'total_completion_tokens': 553.754},)
     - Checkpoint 1: vanila_mcts. Average result: ({'majority_vote': 0.748, 'total_completion_tokens': 2873.62},)
     - Checkpoint 2: vanila_mcts. Average result: ({'majority_vote': 0.728, 'total_completion_tokens': 2718.342},) (a really bad result)
     - Checkpoint 3: vanila_mcts. Average result: ({'majority_vote': 0.746, 'total_completion_tokens': 2873.62},)
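
For reference, a minimal sketch of what a majority-vote metric over N sampled solutions typically computes (illustrative names; answer extraction is task-specific and this is not OpenR's exact code):

```python
from collections import Counter

def majority_vote_correct(final_answers: list[str], gold: str) -> bool:
    # final_answers: the answer extracted from each sampled solution.
    # The prediction is the most frequent answer; the problem counts as
    # solved if that prediction matches the gold answer.
    prediction, _ = Counter(final_answers).most_common(1)[0]
    return prediction == gold

# e.g. majority_vote_correct(["4", "4", "5"], gold="4") -> True;
# the reported 'majority_vote' is the mean of this over the eval set.
```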

Expected behavior

Compared to the report, which shows these methods improving over CoT, the MCTS method here shows a serious performance degradation while increasing the computational cost. I wonder if there is a problem somewhere? Many thanks to the repository owners and contributors!

ziyuwan commented 3 days ago

Could you share your MCTS config here? I guess you set select_by_prior=True.
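
For context, a minimal sketch of what a select_by_prior-style switch usually changes in MCTS child selection (illustrative only, not OpenR's exact implementation; the Node type and the PUCT constant are assumptions):

```python
import math
from dataclasses import dataclass, field

@dataclass
class Node:
    # Hypothetical MCTS node: prior from the policy LM, value stats
    # accumulated from PRM-backed evaluations.
    prior: float
    visit_count: int = 0
    value_sum: float = 0.0
    children: list["Node"] = field(default_factory=list)

def select_child(node: Node, c_puct: float = 1.5, select_by_prior: bool = False) -> Node:
    if select_by_prior:
        # Follows the policy prior alone and ignores the PRM values,
        # so the search degenerates toward plain CoT sampling.
        return max(node.children, key=lambda ch: ch.prior)
    total = sum(ch.visit_count for ch in node.children)
    def puct(ch: Node) -> float:
        q = ch.value_sum / ch.visit_count if ch.visit_count else 0.0
        u = c_puct * ch.prior * math.sqrt(total) / (1 + ch.visit_count)
        return q + u
    return max(node.children, key=puct)
```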

Dada-Cloudzxy commented 2 days ago

> Could you share your MCTS config here? I guess you set select_by_prior=True.

OK, as I said previously, I just modified the tree width and depth in the official config.

```sh
python reason/evaluation/evaluate.py \
    --LM Qwen2.5-Math-1.5B-Instruct \
    --RM checkpoint-99591 \
    --task_name MATH \
    --temperature 0.7 \
    --max_new_tokens 2048 \
    --num_sequence 1 \
    --tree_max_width 5 \
    --tree_max_depth 40 \
    --save_dir debug \
    --method vanila_mcts \
    --num_worker 4 \
    --controller_addr http://0.0.0.0:28777
```

As for select_by_prior: in evaluate.py, line 208, it is set to False in the code.

ziyuwan commented 2 days ago

Thanks for the information. Have you tried using our open-sourced PRM in vanila_mcts? I'm not sure whether it's because your reward model is not good enough.

Dada-Cloudzxy commented 2 days ago

> Thanks for the information. Have you tried using our open-sourced PRM in vanila_mcts? I'm not sure whether it's because your reward model is not good enough.

In experiment one, I used an open-sourced PRM, math-shepherd-mistral-7b-prm. In experiment two, I trained Qwen2.5 on MATH-APS following this work. I will try the open-sourced model from this work. I also want to reproduce the PRM, so could you give me some information about training it (e.g. training hyperparameters)? For now I just use the config in the official scripts and the MATH-APS dataset.
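
For readers reproducing this, a minimal sketch of the Math-Shepherd-style PRM finetuning objective (an assumption about the general recipe, not OpenR's finetune_qwen.py): the PRM is a plain causal LM, supervised only at the positions of per-step label tokens such as " +" (good step) and " -" (bad step).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-Math-1.5B-Instruct")

# Hypothetical training example: each reasoning step ends with a label token.
text = "Question: 2+2?\nStep 1: 2+2 = 4. +\nStep 2: The answer is 4. +"
enc = tok(text, return_tensors="pt")

# Take the token ids that " +" / " -" actually produce in context.
plus_id = tok.encode(" +", add_special_tokens=False)[-1]
minus_id = tok.encode(" -", add_special_tokens=False)[-1]

# Cross-entropy is computed only at label-token positions; everything
# else is masked out with -100 (ignored by the loss).
labels = torch.full_like(enc.input_ids, -100)
mask = (enc.input_ids == plus_id) | (enc.input_ids == minus_id)
labels[mask] = enc.input_ids[mask]

loss = model(**enc, labels=labels).loss  # then loss.backward(), optimizer step
```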

ziyuwan commented 2 days ago

> I also want to reproduce the PRM, so could you give me some information about training it (e.g. training hyperparameters)? For now I just use the config in the official scripts and the MATH-APS dataset.

Thanks for your response. I think @Gebro13 can help with PRM training; we are currently updating the PRM labeling code.
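
While that labeling code lands, here is a minimal sketch of the Math-Shepherd-style automatic step labeling it presumably follows (hard estimation; sample_completions and is_correct are hypothetical hooks onto a generator and an answer checker):

```python
def label_steps_by_rollout(question: str, steps: list[str],
                           sample_completions, is_correct, k: int = 8) -> list[int]:
    # For each step prefix, sample k continuations; the step is labeled 1
    # ("good") if any continuation reaches the gold answer, else 0.
    labels = []
    prefix = question
    for step in steps:
        prefix = prefix + "\n" + step
        rollouts = sample_completions(prefix, n=k)
        labels.append(1 if any(is_correct(r) for r in rollouts) else 0)
    return labels
```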

ziyuwan commented 2 days ago

> In experiment one, I used an open-sourced PRM, math-shepherd-mistral-7b-prm.

I'll try to reproduce the results and see if there are any bugs/errors here.

Dada-Cloudzxy commented 1 day ago

@ziyuwan @Gebro13

Hi! I used your open-sourced PRM in vanila_mcts and it seemed to work. But when I use other open-sourced PRMs, or reproduce my own PRM with the MATH-APS dataset, it fails. So could @Gebro13 help me with the PRM training details? Thank you very much!


Using the open-sourced math-psa:
- Method: cot. Average result: ({'majority_vote': 0.826, 'total_completion_tokens': 642.766},)
- Method: vanila_mcts. Average result: ({'majority_vote': 0.84, 'total_completion_tokens': 3617.946},)

FanqingM commented 17 hours ago

In this file: reason/evaluation/methods.py, at line 44:

```python
def best_of_n(
    config: BestOfNConfig,
    gen_config: LMCallingConfig,
    problem_inst: Dict[str, str],
    lm_call: LanguageModelCallingFunction,
    rm_call: RewardModelCallingFunction,
) -> SolutionOutput:
    if gen_config.max_new_tokens < 256:
        print("Warning: max_new_tokens is less than 256")

    gen_config.n = config.num_sequence
    task = Task(task_name=config.task_name)
    prompt = task.prompt_fn(problem_inst["question"])
    output = lm_call(prompt, gen_config)
    completion_tokens = output.num_tokens
    return SolutionOutput(
        solutions=output.text,
        completion_tokens=completion_tokens,
    )
```

Why does the function take the rm_call parameter when it never occurs in the function body?

ziyuwan commented 17 hours ago

> Why does the function take the rm_call parameter when it never occurs in the function body?

We will evaluate all answers' scores in MathEvaluator.
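
In other words, rm_call is threaded through for the scoring stage that happens after generation. A minimal sketch of that downstream use (illustrative only; the real MathEvaluator API and rm_call signature differ):

```python
def rerank_with_prm(question: str, solutions: list[str], rm_call) -> str:
    # Score each candidate solution with the PRM and return the best one.
    # rm_call is assumed here to map (question, solution) to a list of
    # per-step scores; min-aggregation is one common way to get a scalar.
    def solution_score(sol: str) -> float:
        step_scores = rm_call((question, sol))
        return min(step_scores)
    return max(solutions, key=solution_score)
```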

Dada-Cloudzxy commented 9 hours ago


@ziyuwan @Gebro13 Hi! This time I trained the 7B PRM on the MATH-APS dataset, and again I did not get good results. I would appreciate it if you could provide some details on PRM training! Looking forward to your reply!

| Model | Baseline (cot_greedy) | Epoch-1 | Epoch-2 | Epoch-3 |
| --- | --- | --- | --- | --- |
| Qwen2.5-Math-1.5b | 0.744 | 0.748 | 0.728 | 0.746 |
| Qwen2.5-Math-7b | 0.826 | 0.816 | 0.824 | 0.818 |