google-research / electra

ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators

The implementation of layerwise learning rate decay #51

Closed: importpandas closed this issue 4 years ago

importpandas commented 4 years ago

https://github.com/google-research/electra/blob/79111328070e491b287c307906701ebc61091eb2/model/optimization.py#L188-L193

According to the code here, assume that n_layers=24; then key_to_depths["encoder/layer_23/"] = 24, which is the depth of the last encoder layer. But the learning rate for that layer works out to learning_rate * (layer_decay ** (24 + 2 - 24)) = learning_rate * (layer_decay ** 2).

That's what confuses me. Why is the learning rate for the last layer learning_rate * (layer_decay ** 2) rather than learning_rate? Am I missing something?
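For reference, here is a minimal sketch of the logic in the linked snippet that reproduces the numbers above. The function name `_get_layer_lrs` and the `/embeddings/` and `task_specific/` prefixes are assumed from context; the repo code may differ in details.

```python
import collections


def _get_layer_lrs(learning_rate, layer_decay, n_layers):
    """Map variable-name prefixes to depths, then to decayed learning rates.

    Sketch of the linked snippet, not a verbatim copy of the repo code.
    """
    key_to_depths = collections.OrderedDict({
        "/embeddings/": 0,                  # embeddings are the shallowest "layer"
        "task_specific/": n_layers + 2,     # task head counted as the deepest layer
    })
    for layer in range(n_layers):
        key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1
    # Each prefix gets learning_rate * layer_decay ** (n_layers + 2 - depth).
    return {
        key: learning_rate * (layer_decay ** (n_layers + 2 - depth))
        for key, depth in key_to_depths.items()
    }


lrs = _get_layer_lrs(learning_rate=1e-4, layer_decay=0.8, n_layers=24)
# encoder/layer_23/ sits at depth 24, so its multiplier is 0.8 ** (24 + 2 - 24) = 0.8 ** 2.
print(lrs["encoder/layer_23/"])   # 1e-4 * 0.64
print(lrs["task_specific/"])      # 1e-4 * 1.0
```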

clarkkev commented 4 years ago

For layerwise learning rate decay we count the task-specific layer added on top of the pre-trained transformer as an additional layer of the model, so the learning rate for the last layer of ELECTRA should be learning_rate * 0.8. But you've still found a bug: instead it is learning_rate * 0.8^2.

The bug happened because there used to be a pooler layer in ELECTRA before we removed the next-sentence-prediction task. In that case the top transformer layer sat two levels below the task-specific head (transformer layer, then pooler, then task layer), so the per-layer learning rates, including the 0.8^2 multiplier for the top transformer layer, were consistent.
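As an illustration only, one way to get the intended schedule now that the pooler is gone (task head at learning_rate, top transformer layer at learning_rate * layer_decay) would be to shrink the offset by one, as sketched below. This is a hypothetical fix for clarity, not necessarily the patch that landed in the repo.

```python
def _get_layer_lrs_fixed(learning_rate, layer_decay, n_layers):
    """Hypothetical fix sketch: task head directly above the top transformer layer."""
    key_to_depths = {"/embeddings/": 0, "task_specific/": n_layers + 1}
    for layer in range(n_layers):
        key_to_depths["encoder/layer_" + str(layer) + "/"] = layer + 1
    # Offset of n_layers + 1 instead of n_layers + 2, since there is no pooler in between.
    return {
        key: learning_rate * (layer_decay ** (n_layers + 1 - depth))
        for key, depth in key_to_depths.items()
    }


lrs = _get_layer_lrs_fixed(learning_rate=1e-4, layer_decay=0.8, n_layers=24)
print(lrs["task_specific/"])      # 1e-4              (decay ** 0)
print(lrs["encoder/layer_23/"])   # 1e-4 * 0.8        (decay ** 1)
print(lrs["encoder/layer_0/"])    # 1e-4 * 0.8 ** 24
```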

importpandas commented 4 years ago

I got it, thanks for your explanation.