lee-ny / teaching_arithmetic

MIT License

Follow-Up Issue: Question on the Efficacy of the `reverse` Method and Its Underlying Rationale #2

Open james016 opened 1 year ago

james016 commented 1 year ago

Hello again,

I recently submitted an issue titled "Potential Bug: Improvement in 3-Digit Addition Baseline by Adjusting Prompt Formatting" and would like to follow up with another query related to your "Teaching Arithmetic to Small Transformers" work.

Concerns:

I have some questions about the principal argument in your paper, which suggests that solving arithmetic problems starting from the most significant digit requires a more global approach, making the task significantly harder to learn.

When examining a 3-digit addition task such as A3A2A1 + B3B2B1 = C3C2C1, the paper claims that computing C3 requires comprehensive, global information. However, in most instances C3 can be computed from A3, B3, and the carry out of the tens position, and that carry can be read off A2 + B2 alone: there is always a carry when A2 + B2 ≥ 10 and never one when A2 + B2 ≤ 8. The task only requires information from all digits when A2 + B2 = 9.
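To make the case split concrete, here is a minimal Python sketch (not code from the paper or the repository; `hundreds_digit_plain` is a hypothetical helper) that computes C3 from A3, B3, A2, and B2, and looks at A1 and B1 only when A2 + B2 = 9. It assumes both operands are 3-digit numbers (100 to 999) and ignores any thousands digit of the sum. The exhaustive loop checks the rule against ordinary addition and counts how often the units digits are actually needed.

```python
def hundreds_digit_plain(a: int, b: int) -> tuple[int, bool]:
    """Return (C3, used_units) for a + b with 3-digit operands a and b.

    C3 is the hundreds digit of the sum (a possible thousands digit is ignored).
    used_units is True only when the units digits had to be consulted.
    """
    a3, a2, a1 = a // 100, a // 10 % 10, a % 10
    b3, b2, b1 = b // 100, b // 10 % 10, b % 10
    tens = a2 + b2
    if tens >= 10:
        carry, used_units = 1, False  # a carry into the hundreds is certain
    elif tens <= 8:
        carry, used_units = 0, False  # even a units carry cannot reach 10
    else:  # tens == 9: the outcome depends on whether a1 + b1 carries
        carry, used_units = (1 if a1 + b1 >= 10 else 0), True
    return (a3 + b3 + carry) % 10, used_units


# Exhaustive check over all pairs of 3-digit operands.
total = needs_units = 0
for a in range(100, 1000):
    for b in range(100, 1000):
        c3, used = hundreds_digit_plain(a, b)
        assert c3 == (a + b) // 100 % 10  # agrees with the true hundreds digit
        total += 1
        needs_units += used
print(f"{needs_units / total:.1%} of pairs need the units digits")
```

Because the tens digits are uniformly distributed over this operand range, the fallback branch fires for exactly one pair of tens digits in ten, so the script prints `10.0% of pairs need the units digits`.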

For the reverse method (A3A2A1 + B3B2B1 = C1C2C3), computing C3 seems to depend similarly on the carry out of the tens position, which is determined by A2, B2, and the already generated C2. It is therefore unclear to me why there would be a substantial difference in complexity between the plain and reverse methods for calculating C3.

Additional Evidence:

In my earlier investigation, I observed that the bad cases from the plain2 method rarely included situations where A2 + B2 = 9. This makes me wonder whether the primary reason the reverse method performs better differs from the one discussed in the paper.

I'm eager to hear your insights on this matter.

Best regards, Su Wang

james016 commented 1 year ago

Hello again,

I'm following up on my own issue to provide some additional data and clarifications. Upon further testing, I have new results to share regarding the plain2 method's performance and bad cases.

Updated Performance:

The plain2 method achieved an accuracy of 9663 out of 9900 examples, which equates to approximately 97.61%.

Updated Statistics on Bad Cases:

I realized that my initial observation about the rarity of errors involving A2 + B2 = 9 was inaccurate. After more thorough testing, I found that about 61.6% of the errors do involve cases where A2 + B2 = 9. However, a significant proportion of the errors, around 38.4%, are not accounted for by this scenario.
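A breakdown like this can be reproduced with a small script. The sketch below is an assumption about how one might do it, not the script behind the numbers above: it takes hypothetical `(a, b, predicted_sum)` triples, already parsed from the model's output, and reports overall accuracy plus the share of wrong answers whose tens digits sum to 9. The helper names `tens_sum_is_nine` and `summarize` are illustrative only.

```python
from collections.abc import Iterable


def tens_sum_is_nine(a: int, b: int) -> bool:
    """True when the tens digits of a and b sum to exactly 9."""
    return (a // 10 % 10) + (b // 10 % 10) == 9


def summarize(results: Iterable[tuple[int, int, int]]) -> dict[str, float]:
    """Summarize (a, b, predicted_sum) triples.

    Returns accuracy over all examples and, among the wrong answers,
    the share whose tens digits sum to 9.
    """
    n = wrong = wrong_nine = 0
    for a, b, pred in results:
        n += 1
        if pred != a + b:
            wrong += 1
            wrong_nine += tens_sum_is_nine(a, b)
    return {
        "accuracy": (n - wrong) / n,
        "error_share_tens_sum_9": wrong_nine / wrong if wrong else 0.0,
    }


# Toy data: two correct answers, one error with tens digits 4 + 5 = 9,
# and one error with tens digits 2 + 5 = 7.
toy = [(145, 256, 401), (145, 256, 391), (123, 456, 579), (123, 456, 589)]
print(summarize(toy))  # {'accuracy': 0.5, 'error_share_tens_sum_9': 0.5}
```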

I am keen to further discuss this topic once the original issue has been reviewed. The new findings could offer additional perspectives on both the task's complexity and the effectiveness of various approaches.

Best regards, Su Wang

lee-ny commented 1 year ago

Hello Su,

Thank you for your interest and the analysis.

Our intuition is as follows (please refer to Lemma 1 and Lemma 2 in the paper for proofs and explanation). Your analysis is on point in observing that C3 can often be computed with only partial information. However, the model lacks knowledge of whether A2 + B2 = 9 during its generation process, so it must learn a mapping that considers all digits to determine whether a carry should be propagated. This is why the plain format requires learning a mapping that incorporates every digit.

In contrast, the reverse method has the carry information available: because it outputs the less significant digits first, it can use them when generating the next digit. For example, C2 is simply (A2 + B2 + carry) mod 10, where the carry out of the units position is already fixed once C1 has been generated, which makes the mapping easier to learn.
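The reverse format can be read as a digit-by-digit procedure. The following Python sketch only illustrates that structure (it is not the paper's model or training code, and `add_reversed` is a hypothetical name): each emitted digit depends on the two operand digits at that position plus a single carry fixed by the previous step.

```python
def add_reversed(a: str, b: str) -> str:
    """Emit the digits of a + b least-significant first (the reverse format).

    a and b are equal-length digit strings written most-significant first,
    e.g. "145" and "256".
    """
    assert len(a) == len(b) and a.isdigit() and b.isdigit()
    out = []
    carry = 0
    for da, db in zip(reversed(a), reversed(b)):
        s = int(da) + int(db) + carry
        out.append(str(s % 10))  # C_i = (A_i + B_i + carry_in) mod 10
        carry = s // 10          # carry_out feeds the next, more significant digit
    if carry:
        out.append("1")          # a final carry becomes the leading digit of the sum
    return "".join(out)
```

For example, `add_reversed("145", "256")` returns `"104"`, which is 401 written least-significant digit first. No step ever needs to look ahead at lower digits it has not yet processed, which is the locality the reply above describes.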

We hope this explanation clarifies the rationale for our reverse method. Please feel free to reach out to us by email (nayoung.lee@wisc.edu, and please cc my co-author at ksreenivasa2@wisc.edu) if you have further questions or insights to share.

Best regards, Nayoung

james016 commented 1 year ago

Thank you for your detailed response, Nayoung.

I appreciate the efforts to explain the rationale behind the reverse and plain methods. However, I would like to clarify some points based on my own analysis, specifically concerning the complexity of calculating C3 in both methods.

For the reverse method, deducing C3 requires knowledge of C2, A2, B2, A3, and B3, because A2, B2, and C2 are needed to infer whether a carry from the lower digits affects C3.
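The claim that C2, A2, and B2 are enough to recover the carry can be checked directly. In this minimal sketch (`carry_out` and `c3_reverse` are hypothetical helpers, not code from the repository), the carry out of a position is 1 exactly when the emitted digit is smaller than the sum of the two operand digits, because C_i = (A_i + B_i + carry_in) mod 10 with carry_in equal to 0 or 1.

```python
def carry_out(a_i: int, b_i: int, c_i: int) -> int:
    """Carry out of position i, recovered from visible digits only.

    Since c_i = (a_i + b_i + carry_in) % 10 with carry_in in {0, 1}:
      no carry -> c_i = a_i + b_i + carry_in      >= a_i + b_i
      carry    -> c_i = a_i + b_i + carry_in - 10 <= a_i + b_i - 9
    so a carry occurred exactly when c_i < a_i + b_i.
    """
    return int(c_i < a_i + b_i)


def c3_reverse(a3: int, b3: int, a2: int, b2: int, c2: int) -> int:
    """C3 in the reverse format, from A3, B3, A2, B2 and the emitted C2."""
    return (a3 + b3 + carry_out(a2, b2, c2)) % 10


# Exhaustive check over every digit pair and both possible incoming carries.
for a_i in range(10):
    for b_i in range(10):
        for carry_in in (0, 1):
            s = a_i + b_i + carry_in
            assert carry_out(a_i, b_i, s % 10) == s // 10
```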

Similarly, in the plain method, for cases where A2 + B2 ≠ 9, calculating C3 requires only A2, B2, A3, and B3. Here, too, the carry can be determined from A2 and B2 alone, without knowing the sum of A1 and B1. This leads me to think that, for this subset of problems, the complexity of calculating C3 is similar in both methods.

However, for cases where A2 + B2 = 9 in the plain method, the complexity undoubtedly increases, as one needs all digits from A1 to A3 and B1 to B3 to deduce C3 accurately.

In my previous tests, I found that about 61.6% of the errors in the plain method occurred when A2 + B2 = 9, which supports your statement. But that leaves around 38.4% of the errors unaccounted for (given that the reverse method achieves nearly 100% accuracy). I suspect the reverse method may bring further advantages that are not explained by the reduced complexity of calculating C3. I would love to hear your thoughts on this.

I don't fully agree with the point that the model lacks knowledge of whether A2 + B2 = 9, because, from a task-complexity perspective, both the plain and reverse methods should have access to this information when computing C3, since A2 and B2 are part of the input in both formats.

Looking forward to your thoughts on this.

If you wish for a more in-depth discussion, feel free to reach out to me via email at swang016@gmail.com.

Best regards, Su