openai / grok


The label is in the input #8

Open zhaoyanlyu opened 9 months ago

zhaoyanlyu commented 9 months ago

Thank you for sharing your implementation.

If I understood correctly, the toy example in this paper trains a network (a Transformer) to solve the equation:

$$a \circ b = c$$

given $a$ and $b$ as inputs, predicting the correct $c$.
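For concreteness, here is a small sketch of how such triples could be enumerated, assuming modular addition with a hypothetical modulus p = 97 as the binary operation (illustrative only, not the repo's data pipeline):

```python
# Illustrative sketch: enumerate all (a, b, c) triples for one binary operation,
# here modular addition mod a hypothetical p = 97.
p = 97

def make_triples(p):
    """Every equation a ∘ b = c for the chosen operation."""
    return [(a, b, (a + b) % p) for a in range(p) for b in range(p)]

triples = make_triples(p)
print(len(triples), triples[:2])  # 9409 [(0, 0, 0), (0, 1, 1)]
```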

To translate this for the Transformer, we tokenize everything and add an end-of-sentence <|EOS|> token in the following fashion (as suggested by this code):

<|EOS|> <a> <OP> <b> <=> <c> <|EOS|>

where <a>, <b>, and <c> are integers.

By design, we could use <|EOS|> <a> <OP> <b> <=> <?> <|EOS|> as the input, where <?> is a placeholder token for the solution of the equation. The output of the Transformer would then be the predicted equation <|EOS|> <a> <OP> <b> <=> <c_> <|EOS|>, where <c_> denotes the predicted token for <c>, and the target would be the correct, full equation <|EOS|> <a> <OP> <b> <=> <c> <|EOS|>. We would then compute the loss based only on the second-to-last tokens, <c> and <c_>.
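For what it's worth, here is a minimal sketch of that placeholder-style setup; the token ids, the `encode` helper, and the `model` are my own illustration, not the repo's code:

```python
import torch
import torch.nn.functional as F

# Illustrative token ids and names -- not the repo's actual vocabulary or API.
EOS, EQ, OP, QMARK = 0, 1, 2, 3
NUM_OFFSET = 4  # integer operands/answers start after the special tokens

def encode(a, b, c=None):
    """Build <|EOS|> <a> <OP> <b> <=> <c or ?> <|EOS|> as a tensor of ids."""
    answer = QMARK if c is None else NUM_OFFSET + c
    return torch.tensor([EOS, NUM_OFFSET + a, OP, NUM_OFFSET + b, EQ, answer, EOS])

x = encode(3, 4)          # input with the <?> placeholder
target = encode(3, 4, 7)  # the correct, full equation

# With a hypothetical `model` returning per-position logits of shape (1, 7, vocab):
# logits = model(x.unsqueeze(0))
# loss = F.cross_entropy(logits[:, -2], target[-2:-1])  # loss only on the <c> position
```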

However, in this implementation, the input is the first 6 tokens, i.e. <|EOS|> <a> <OP> <b> <=> <c>, while the target is the last 6 tokens, i.e. <a> <OP> <b> <=> <c> <|EOS|>. The attached figure shows the first batch of x (input) and y (target) obtained while debugging at the following position:

https://github.com/openai/grok/blob/43efed280af24a8837b05fd9c97a3d14f295666f/grok/training.py#L292C1-L293C63
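In other words, x and y appear to be the same 7-token sequence offset by one position, as in standard next-token language modeling. A minimal reconstruction of that split (illustrative, not the exact code at the link above):

```python
# Full tokenized equation, 7 tokens: <|EOS|> <a> <OP> <b> <=> <c> <|EOS|>
eq = ["<|EOS|>", "<a>", "<OP>", "<b>", "<=>", "<c>", "<|EOS|>"]

x = eq[:-1]  # <|EOS|> <a> <OP> <b> <=> <c>   -> input,  6 tokens
y = eq[1:]   # <a> <OP> <b> <=> <c> <|EOS|>   -> target, 6 tokens (shifted by one)
```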

[Figure: debugger view of the first batch of x (input) and y (target)]

In the figure above, 0 indicates the <|EOS|> token, 1 indicates the '<=>' token, and 6 indicates the '**2+' operation (a conditional operation whose result depends on whether <a> is odd or even).

The problem is that the solution <c> is already present in the input x. Therefore, I think the model is being trained on the wrong task.