google-research / t5x

[Question] Weight init for lm_head #1477

Open Birch-san opened 9 months ago

Birch-san commented 9 months ago

Hi t5x community,

I am trying to align the HF transformers implementation of T5 with MTF and T5X, such that it can be relied upon for pretraining.

Could you possibly confirm whether what I've found here regarding HF's pretraining weight initialization is a difference compared to the official T5X or MTF implementations?
https://github.com/huggingface/transformers/pull/26441

HF initialize their (untied) lm_head with std=1:
https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/t5/modeling_t5.py#L831

I suspect this is a mistake. I think they are approaching its initialization as though the lm_head were tied to the embedding layer (though in this code path it is not).
Perhaps they could compensate for the increased variance by scaling the decoder's hidden states down before passing them to the lm_head, as the tied code path does:
https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/t5/modeling_t5.py#L1769

…but they are not applying this compensation in the untied case.

Consequently, I see that the lm_head initially outputs huge activations with variance ~= hidden_dim, which results in an initial cross-entropy loss of ~110.
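
To make the failure mode concrete, here is a minimal sketch of the effect — illustrative shapes, the standard 32128-token vocab, and a plain nn.Linear standing in for the lm_head, not HF's actual code path:

```python
# Minimal sketch, not HF's actual code: a plain nn.Linear stands in for lm_head,
# and shapes are illustrative (d_model=512, vocab=32128, as in t5-small).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq, hidden_dim, vocab = 8, 128, 512, 32128

# Decoder output with roughly unit variance (what a normalized model emits at init).
hidden = torch.randn(batch, seq, hidden_dim)

# Untied lm_head initialized with std=1, as in the HF code path linked above.
lm_head = torch.nn.Linear(hidden_dim, vocab, bias=False)
torch.nn.init.normal_(lm_head.weight, mean=0.0, std=1.0)

logits = lm_head(hidden)
print(logits.var().item())  # ≈ hidden_dim (≈ 512 here)

targets = torch.randint(0, vocab, (batch, seq))
loss = F.cross_entropy(logits.view(-1, vocab), targets.view(-1))
print(loss.item())  # ~100+, far above the ln(vocab) ≈ 10.4 you'd expect at init
```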

I see that nanoT5 copied the same convention, and empirically exhibits the same behaviour: huge initial loss.
https://github.com/PiotrNawrot/nanoT5/issues/25
I think this will have had consequences for the paper they wrote about pretraining with the HF codebase.

As for the fix…

I think I found evidence in t5x that the lm_head initialization needs to be changed to std=hidden_dim**-.5:
https://github.com/huggingface/transformers/pull/26441#issuecomment-1743477649

though there is a competing theory (based on MTF) that lm_head ought to be initialized to std=0.05:
https://github.com/huggingface/transformers/pull/26441#issuecomment-1741840284
I do note, though, that 0.05 is of a very similar magnitude to hidden_dim**-.5, so perhaps it would work similarly well:

512**-.5 ≈ 0.044
768**-.5 ≈ 0.036
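
For concreteness, here is a hedged sketch of what the fix could look like when constructing an untied HF model for pretraining, assuming the std=hidden_dim**-.5 theory; swapping in the MTF-based std=0.05 is a one-line change:

```python
# Sketch of the proposed fix, not an official HF API: build an untied T5 and
# re-initialize lm_head with std = d_model ** -0.5 instead of HF's current std=1.
# (The tied code path reaches a similar scale by multiplying the decoder output
# by d_model ** -0.5 before applying the ~std=1 shared embedding matrix.)
import torch
from transformers import T5Config, T5ForConditionalGeneration

config = T5Config(tie_word_embeddings=False)  # defaults: d_model=512, vocab_size=32128
model = T5ForConditionalGeneration(config)    # randomly initialized, untied lm_head

with torch.no_grad():
    model.lm_head.weight.normal_(mean=0.0, std=config.d_model ** -0.5)
    # or, per the MTF-based theory:
    # model.lm_head.weight.normal_(mean=0.0, std=0.05)
```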

Does this sound about right? Should an untied lm_head be initialized with std=hidden_dim**-.5 normally-distributed noise?

Bonus question: are any of HF's other layers initialized incorrectly?
https://github.com/huggingface/transformers/blob/3b7675b2b844b02d4821b827871a21ad16dd446c/src/transformers/models/t5/modeling_t5.py#L818

Thanks!

Taytay commented 8 months ago

Thanks @Birch-san for opening this issue. I’m an interested observer who is anxious to hear the response. This would mean a lot to those of us who want to pre-train T5 the “correct” way.