delphi-suite / delphi

small language models training made easy
Apache License 2.0
9 stars 1 fork

`model(X, labels=Y, return_dict=True).loss` is wrong #133

Closed jettjaniak closed 4 months ago

jettjaniak commented 5 months ago

It should be `model(X, labels=X)`. Ideally we would force it to do what it was supposed to instead of shifting tokens on its own, but if we can't, we need to adjust the design of the tokenization script to produce sequences of seq_len (512) instead of seq_len+1 (513).
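To make the failure mode concrete: HuggingFace causal-LM models shift labels internally, so the labels you pass must be the *same* ids as the input; pre-shifting them yourself (the `labels=Y` bug) shifts twice and misaligns targets by one. A minimal standalone sketch of that internal shift, using a toy numpy "model" (the function name and toy values are illustrative, not delphi's actual code):

```python
import numpy as np

def causal_lm_loss(logits, labels):
    """Mimic the HF causal-LM loss: the shift happens INSIDE the model,
    so `labels` should be the same token ids as the input."""
    shift_logits = logits[:-1]              # predictions at positions 0..L-2
    shift_labels = labels[1:]               # target for position i is token i+1
    log_probs = shift_logits - np.log(np.exp(shift_logits).sum(-1, keepdims=True))
    return -log_probs[np.arange(len(shift_labels)), shift_labels].mean()

vocab = 10
tokens = np.array([3, 1, 4, 1, 5])

# Fake logits that "predict" every next token perfectly.
logits = np.full((len(tokens), vocab), -10.0)
for i in range(len(tokens) - 1):
    logits[i, tokens[i + 1]] = 10.0

good = causal_lm_loss(logits, tokens)                   # labels = X: near zero
bad = causal_lm_loss(logits, np.append(tokens[1:], 0))  # pre-shifted Y: off by one
assert good < 0.01 < bad
```

With `labels=X` the loss is near zero, as expected for a perfect predictor; with the pre-shifted `Y` the internal shift fires a second time and the loss blows up, which is exactly the bug in the issue title.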

jettjaniak commented 5 months ago

We need some performance test to catch issues like this in the future.
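One cheap test along these lines: a double-shifted loss on a trained model tends to land near (or above) the cross entropy of the uniform distribution, log(vocab_size), so asserting the reported loss clearly beats that baseline catches the bug. A sketch, with a hypothetical helper name and margin:

```python
import numpy as np

def check_loss_sane(loss, vocab_size, margin=0.5):
    """Fail if the reported loss is no better than predicting uniformly
    at random, which is what a label-shifting bug tends to produce."""
    uniform_loss = np.log(vocab_size)  # cross entropy of the uniform distribution
    assert loss < uniform_loss - margin, (
        f"loss {loss:.3f} is not clearly below the uniform baseline "
        f"{uniform_loss:.3f}; check label alignment"
    )

check_loss_sane(loss=2.1, vocab_size=4096)  # log(4096) ≈ 8.32, so this passes
```

It is a coarse check, not a substitute for a real eval, but it would have flagged `labels=Y` immediately.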

jaidhyani commented 5 months ago

Don't we still want to pass it seq_len+1? If it's doing the shift internally, we still get 512 predicted positions from inputs of length 513, right?

https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1233

SrGonao commented 5 months ago

I agree with jai, we don't need to change the tokenizer. @jaidhyani could you make a PR to fix this loss?

jaidhyani commented 5 months ago

Already merged it a few days ago. Still need to add performance tests though

jettjaniak commented 4 months ago

> Don't we still want to pass it seq_len+1? If it's doing the shift internally, we still get 512 predicted positions from inputs of length 513, right?
>
> https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L1233

I'm not sure what you're linking to; the line number probably shifted with new commits on main. I believe it computes logits for all input ids, and then discards the logits for the last position when computing the loss.

https://github.com/huggingface/transformers/blob/ccdabc5642bf84849af93f591e207dc625c8e1e1/src/transformers/models/llama/modeling_llama.py#L1213-L1214

It's not a big deal, but considering this I think inputs should be of length seq_len. LMK if you have strong takes.
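The shape arithmetic behind this point can be checked standalone (a numpy sketch; the 512/513 values come from this thread, the shift mirrors the permalinked transformers code):

```python
import numpy as np

seq_len, vocab = 512, 4096

# With inputs of length seq_len + 1 the model computes logits for all
# 513 positions, then the loss keeps logits[:-1] and labels[1:]:
# 512 supervised positions, with the last position's logits discarded.
logits_513 = np.zeros((seq_len + 1, vocab))
assert logits_513[:-1].shape[0] == 512

# With inputs of length seq_len the same shift yields 511 supervised
# positions; either way, the final position's logits never enter the loss.
logits_512 = np.zeros((seq_len, vocab))
assert logits_512[:-1].shape[0] == 511
```

So the trade-off is 512 vs 511 supervised positions per sequence, not correctness; both shapes train fine once `labels` equals the input ids.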