da03 / Internalize_CoT_Step_by_Step

https://huggingface.co/spaces/yuntian-deng/gpt2-multiplication
MIT License

Performance from stage 1 to stage 6 #1

Open Lucas-TY opened 1 month ago

Lucas-TY commented 1 month ago

Hi, thanks for this great work.

May I ask about the intermediate performance from stage 1 to stage 6? I read the paper, but I'm still unclear about why it works. Also, did you compare against a model trained without CoT (i.e., stage 6 only)?

da03 commented 1 month ago

Thank you for your interest! For the intermediate performance from stage 1 to stage 6, please see the blue curve in Figure 3 of https://arxiv.org/pdf/2405.14838. In this experiment, we removed 8 CoT tokens per epoch and evaluated the intermediate accuracy on the validation set, so every point's x-coordinate is a multiple of 8. Note that the accuracy fluctuates a lot in this curve because we wanted to remove tokens as fast as possible. In general, we found that the slower we remove tokens, the more stable the intermediate stages are, so I would expect a very flat curve if we removed only 1 token per epoch.
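The linear removal schedule described above can be sketched in a few lines of Python. This is a minimal illustration, assuming that CoT tokens are removed from the front and that a fixed number are removed per completed epoch; the names `tokens_removed` and `stage_cot` are hypothetical and not taken from the repository.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def tokens_removed(epoch: int, per_epoch: int = 8) -> int:
    """Leading CoT tokens removed after `epoch` completed epochs (linear schedule)."""
    return per_epoch * epoch


def stage_cot(cot_tokens: Sequence[T], epoch: int, per_epoch: int = 8) -> List[T]:
    """CoT tokens the model still sees at a given epoch (assumed: removal from the front)."""
    k = min(tokens_removed(epoch, per_epoch), len(cot_tokens))
    return list(cot_tokens[k:])


if __name__ == "__main__":
    # Evaluating once per epoch puts every point at a multiple of per_epoch.
    print([tokens_removed(e) for e in range(5)])
    # A slower schedule (1 token per epoch) spreads the same removal over more stages.
    print([tokens_removed(e, per_epoch=1) for e in range(5)])
```

With `per_epoch=8` the evaluation points are 0, 8, 16, 24, 32; with `per_epoch=1` they are 0, 1, 2, 3, 4, which corresponds to the slower, more stable schedule mentioned in the comment.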

Regarding why it works: I think there is no free lunch. The high-level reason it outperforms training without CoT (see Table 3, first block, No CoT) is that we use CoT steps as supervision during training. Stage 0 is the same as normal explicit CoT training, so it's trivial for the model. The key intuition is that when we remove one more CoT token and finetune the model to start directly by producing the next CoT token, the model is forced to internalize that CoT step into its hidden states, and internalizing a single CoT token is not too challenging for it. It's like using a scaffold to help the model: we gradually remove the scaffold, and the model gradually figures out how to work on its own.
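To make the stage-by-stage intuition concrete, here is a hypothetical sketch of how a training example at stage k could be assembled: the first k CoT tokens are dropped, so the first supervised token after the question is CoT token k. The `<sep>` separator and the loss-mask convention are assumptions for illustration, not the repository's actual data format.

```python
from typing import List, Tuple

SEP = "<sep>"  # hypothetical separator token, not the repository's actual token


def build_stage_example(
    question: List[str], cot: List[str], answer: List[str], k: int
) -> Tuple[List[str], List[int]]:
    """Drop the first k CoT tokens; return the sequence and a loss mask (1 = supervised)."""
    remaining = cot[k:]               # k = 0 is explicit CoT; k = len(cot) is no CoT
    prompt = question + [SEP]         # conditioning context, not trained on
    target = remaining + [SEP] + answer
    tokens = prompt + target
    mask = [0] * len(prompt) + [1] * len(target)
    return tokens, mask


if __name__ == "__main__":
    q = ["1", "2", "x", "3", "4"]
    cot = ["48", "+", "360"]
    ans = ["408"]
    for k in range(len(cot) + 1):
        tokens, mask = build_stage_example(q, cot, ans, k)
        print(k, tokens, mask)
```

Each increment of k removes one more piece of the scaffold, so the model must produce what used to be the next CoT token without having seen the removed ones.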

da03 commented 1 month ago

Another thing I want to add regarding Figure 3: it plots the accuracy of generating the entire sequence correctly, so any mistake in the middle counts as 0. If we look at token-level accuracy, it's always above 90% and mostly around 99%, so the model almost always stays close to being able to solve the task (under the CoT budget of each stage).
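The gap between whole-sequence and token-level accuracy can be illustrated with two small metric functions. This is a sketch under the assumption that token-level accuracy compares predicted and gold tokens position by position; the exact definition behind Figure 3 may differ.

```python
from typing import Sequence


def sequence_accuracy(preds: Sequence[Sequence[str]], golds: Sequence[Sequence[str]]) -> float:
    """Exact match: an example scores 1 only if every token is correct."""
    correct = sum(1 for p, g in zip(preds, golds) if list(p) == list(g))
    return correct / len(golds)


def token_accuracy(preds: Sequence[Sequence[str]], golds: Sequence[Sequence[str]]) -> float:
    """Fraction of gold token positions where the prediction has the same token."""
    total = sum(len(g) for g in golds)
    correct = 0
    for p, g in zip(preds, golds):
        correct += sum(1 for i, tok in enumerate(g) if i < len(p) and p[i] == tok)
    return correct / total


if __name__ == "__main__":
    golds = [list("408"), list("1001")]
    preds = [list("408"), list("1011")]  # one wrong digit in the second example
    print(sequence_accuracy(preds, golds))  # 0.5
    print(token_accuracy(preds, golds))     # 6 of 7 tokens correct
```

A single wrong digit zeroes out the whole second example under exact match, while token-level accuracy still credits the 3 correct digits, which is why the two curves can look so different.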

Lucas-TY commented 1 month ago

Ah, maybe I misunderstood. How can you show that this method works better than training the model without CoT? As I understand it, your approach removes the middle part of the calculation: 12 x 34 = 12 x 4 + 12 x 30 = 48 + 360 = 408. My question is whether your method outperforms training the model directly without CoT (12 x 34 = 408). Figure 3 only shows your method and some other removal methods.
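The two training targets in this question can be generated programmatically. The hypothetical helpers below produce the explicit partial-product CoT string and the corresponding no-CoT target for the same problem; the format mirrors the example in this comment rather than the token format of any dataset in the paper.

```python
def multiplication_cot(a: int, b: int) -> str:
    """Explicit CoT: split b into place values, then sum the partial products."""
    # Place values of b, least significant first: 34 -> [4, 30]
    places = [int(d) * 10**i for i, d in enumerate(reversed(str(b)))]
    terms = " + ".join(f"{a} x {p}" for p in places)
    partials = " + ".join(str(a * p) for p in places)
    return f"{a} x {b} = {terms} = {partials} = {a * b}"


def multiplication_no_cot(a: int, b: int) -> str:
    """No-CoT target: the answer follows the question directly."""
    return f"{a} x {b} = {a * b}"


if __name__ == "__main__":
    print(multiplication_cot(12, 34))     # 12 x 34 = 12 x 4 + 12 x 30 = 48 + 360 = 408
    print(multiplication_no_cot(12, 34))  # 12 x 34 = 408
```

Stepwise internalization starts from the first kind of target and removes CoT tokens until only the second remains, whereas the No CoT baseline trains on the second kind from the start.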

da03 commented 4 weeks ago

The results without CoT are shown in Table 2 and Table 3 (No CoT). Also, I thought it was pretty obvious that we cannot directly train GPT-2 Small to solve 9-by-9 multiplication without CoT...

da03 commented 4 weeks ago

This is also shown in prior work such as https://arxiv.org/pdf/2112.00114. In their Figure 3 (In-Distribution Accuracy), you can see that without CoT (scratchpad), even the accuracy on multi-digit addition is very low (baseline).